Source code for neurological_lrd_analysis.benchmark_registry.registry

"""
Registry for Hurst exponent estimation methods.

This module provides a registry system for managing and accessing different
Hurst exponent estimation methods.
"""

import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np

from ..biomedical_hurst_factory import (
    BiomedicalHurstEstimatorFactory,
    ConfidenceMethod,
    EstimatorType,
    HurstResult,
)


[docs] @dataclass class EstimatorResult: """Result from an estimator run.""" hurst_estimate: float convergence_flag: bool computation_time: float additional_metrics: Dict[str, Any] error_message: Optional[str] = None
[docs] def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary.""" from dataclasses import asdict return asdict(self)
[docs] class BaseEstimator(ABC): """Base class for all Hurst exponent estimators."""
[docs] def __init__(self, name: str): self.name = name
[docs] @abstractmethod def estimate(self, data: np.ndarray, **kwargs) -> EstimatorResult: """Estimate Hurst exponent from data.""" pass
[docs] class FactoryEstimator(BaseEstimator): """Wrapper for estimators from the biomedical factory."""
[docs] def __init__(self, name: str, factory_method: EstimatorType): super().__init__(name) self.factory_method = factory_method self.factory = BiomedicalHurstEstimatorFactory()
[docs] def estimate(self, data: np.ndarray, **kwargs) -> EstimatorResult: """Estimate using the biomedical factory.""" try: start_time = time.time() # Extract confidence method from kwargs or use default confidence_method = kwargs.pop( "confidence_method", ConfidenceMethod.BOOTSTRAP ) if isinstance(confidence_method, str): confidence_method = ConfidenceMethod(confidence_method) result = self.factory.estimate( data, method=self.factory_method, confidence_method=confidence_method, preprocess=False, # Skip preprocessing for benchmarking assess_quality=False, # Skip quality assessment for benchmarking **kwargs, ) computation_time = time.time() - start_time return EstimatorResult( hurst_estimate=result.hurst_estimate, convergence_flag=result.convergence_flag, computation_time=computation_time, additional_metrics={ "regression_r_squared": result.regression_r_squared, "scaling_range": result.scaling_range, "data_quality_score": result.data_quality_score, "confidence_interval": result.confidence_interval, "regression_p_value": result.regression_r_squared, # Use R² as proxy for p-value "regression_std_error": result.standard_error, # Include all additional metrics from the result **result.additional_metrics, }, ) except Exception as e: return EstimatorResult( hurst_estimate=np.nan, convergence_flag=False, computation_time=0.0, additional_metrics={}, error_message=str(e), )
[docs] class DFAEstimator(FactoryEstimator): """DFA estimator."""
[docs] def __init__(self): super().__init__("DFA", EstimatorType.DFA)
[docs] class HiguchiEstimator(FactoryEstimator): """Higuchi estimator."""
[docs] def __init__(self): super().__init__("Higuchi", EstimatorType.HIGUCHI)
[docs] class PeriodogramEstimator(FactoryEstimator): """Periodogram estimator."""
[docs] def __init__(self): super().__init__("Periodogram", EstimatorType.PERIODOGRAM)
[docs] class RSAnalysisEstimator(FactoryEstimator): """R/S Analysis estimator."""
[docs] def __init__(self): super().__init__("R/S", EstimatorType.RS_ANALYSIS)
[docs] class GPHEstimator(FactoryEstimator): """GPH estimator."""
[docs] def __init__(self): super().__init__("GPH", EstimatorType.GPH)
[docs] class WhittleMLEEstimator(FactoryEstimator): """Local Whittle MLE estimator."""
[docs] def __init__(self): super().__init__("Local-Whittle", EstimatorType.WHITTLE_MLE)
[docs] class GHEEstimator(FactoryEstimator): """Generalized Hurst Exponent estimator."""
[docs] def __init__(self): super().__init__("GHE", EstimatorType.GENERALIZED_HURST)
[docs] class MFDFAEstimator(FactoryEstimator): """MFDFA estimator."""
[docs] def __init__(self): super().__init__("MFDFA(q=2)", EstimatorType.MFDFA)
[docs] class MFDMAEstimator(FactoryEstimator): """MF-DMA estimator."""
[docs] def __init__(self): super().__init__("MF-DMA(q=2)", EstimatorType.MF_DMA)
[docs] class DWTLogscaleEstimator(FactoryEstimator): """DWT Logscale estimator."""
[docs] def __init__(self): super().__init__("DWT-Logscale", EstimatorType.DWT)
[docs] class AbryVeitchEstimator(FactoryEstimator): """Abry-Veitch estimator."""
[docs] def __init__(self): super().__init__("Abry-Veitch", EstimatorType.ABRY_VEITCH)
[docs] class NDWTLogscaleEstimator(FactoryEstimator): """NDWT Logscale estimator."""
[docs] def __init__(self): super().__init__("NDWT-Logscale", EstimatorType.NDWT)
# Registry of available estimators _ESTIMATORS: List[BaseEstimator] = [ DFAEstimator(), HiguchiEstimator(), PeriodogramEstimator(), RSAnalysisEstimator(), GPHEstimator(), WhittleMLEEstimator(), GHEEstimator(), MFDFAEstimator(), MFDMAEstimator(), DWTLogscaleEstimator(), AbryVeitchEstimator(), NDWTLogscaleEstimator(), ]
[docs] def get_registry() -> List[BaseEstimator]: """Get the list of available estimators.""" return _ESTIMATORS.copy()
[docs] def register_estimator(estimator: BaseEstimator) -> None: """Register a new estimator.""" _ESTIMATORS.append(estimator)
[docs] def get_estimator_by_name(name: str) -> Optional[BaseEstimator]: """Get an estimator by name.""" for estimator in _ESTIMATORS: if estimator.name == name: return estimator return None
[docs] def list_estimator_names() -> List[str]: """List all available estimator names.""" return [estimator.name for estimator in _ESTIMATORS]
[docs] def get_estimators_by_category(category: str) -> List[BaseEstimator]: """Get estimators by category.""" # This is a simplified categorization # In a full implementation, estimators would have category attributes temporal_estimators = ["DFA", "Higuchi", "R/S", "GHE", "MFDFA(q=2)", "MF-DMA(q=2)"] spectral_estimators = ["Periodogram", "GPH", "Local-Whittle"] wavelet_estimators = ["DWT-Logscale", "Abry-Veitch", "NDWT-Logscale"] if category.lower() == "temporal": return [est for est in _ESTIMATORS if est.name in temporal_estimators] elif category.lower() == "spectral": return [est for est in _ESTIMATORS if est.name in spectral_estimators] elif category.lower() == "wavelet": return [est for est in _ESTIMATORS if est.name in wavelet_estimators] else: return _ESTIMATORS.copy()
[docs] def benchmark_estimator( estimator: BaseEstimator, data: np.ndarray, n_runs: int = 5 ) -> Dict[str, Any]: """ Benchmark an estimator on given data. Parameters: ----------- estimator : BaseEstimator Estimator to benchmark data : np.ndarray Data to test on n_runs : int Number of runs for timing Returns: -------- Dict[str, Any] Benchmark results """ results = [] for _ in range(n_runs): result = estimator.estimate(data) results.append(result) # Compute statistics estimates = [r.hurst_estimate for r in results if r.convergence_flag] times = [r.computation_time for r in results] benchmark_results = { "estimator_name": estimator.name, "n_runs": n_runs, "success_rate": sum(1 for r in results if r.convergence_flag) / n_runs, "mean_estimate": np.mean(estimates) if estimates else np.nan, "std_estimate": np.std(estimates) if estimates else np.nan, "mean_time": np.mean(times), "std_time": np.std(times), "min_time": np.min(times), "max_time": np.max(times), } return benchmark_results