

import yaml
from pathlib import Path
from typing import Dict, Any, Optional, Union, List
from dataclasses import dataclass, field
import copy


@dataclass
class OptimizerConfig:
    """Optimizer settings; every field carries a default.

    NOTE(review): ``momentum`` and ``betas`` look like they belong to
    different optimizer families (SGD-style vs. Adam-style) -- confirm
    against the code that consumes this config, which is not in this file.
    """

    # Optimizer identifier (string); defaults to "Adam".
    type: str = "Adam"
    learning_rate: float = 0.001
    weight_decay: float = 1e-4
    momentum: float = 0.9
    # Built through a factory so each instance owns its own list.
    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])


@dataclass
class SchedulerConfig:
    """Learning-rate scheduler settings; every field carries a default.

    NOTE(review): ``patience``/``factor``/``min_lr`` and
    ``step_size``/``gamma`` look like parameters for different scheduler
    types -- confirm which ones the consumer reads for a given ``type``.
    """

    # Whether a scheduler is enabled at all.
    enabled: bool = True
    # Scheduler identifier (string).
    type: str = "ReduceLROnPlateau"
    patience: int = 10
    factor: float = 0.5
    min_lr: float = 1e-6
    step_size: int = 30
    gamma: float = 0.5


@dataclass
class EarlyStoppingConfig:
    """Early-stopping settings; every field carries a default.

    The exact meaning of each field is defined by the training loop that
    reads this config, which is not part of this file.
    """

    enabled: bool = True
    patience: int = 20
    min_delta: float = 1e-4
    # Name of the quantity to watch (string); defaults to "val_accuracy".
    monitor: str = "val_accuracy"





@dataclass
class ClusteringConfig:
    """Clustering settings; every field carries a default.

    ``random_state`` defaults to ``None`` and is therefore annotated as
    optional.  NOTE(review): ``None`` presumably means "do not fix a seed"
    -- confirm against the consumer of this config.
    """

    n_clusters: int = 10
    # Algorithm identifier (string); defaults to "kmeans".
    algorithm: str = 'kmeans'
    # Optional: the default is None, so the annotation must allow it.
    random_state: Optional[int] = None
    max_iter: int = 300


@dataclass
class SamplingConfig:
    """Sampling settings; every field carries a default.

    ``random_seed`` defaults to ``None`` and is therefore annotated as
    optional.  NOTE(review): ``None`` presumably means "unseeded" --
    confirm against the consumer of this config.
    """

    # Sampling method identifier (string); defaults to "gaussian".
    method: str = 'gaussian'
    n_samples_per_cluster: int = 10
    gaussian_std: float = 1.0
    # Optional: the default is None, so the annotation must allow it.
    random_seed: Optional[int] = None


@dataclass
class AttentionConfig:
    """Attention-module settings; every field carries a default.

    How these values are applied is decided by the model code that reads
    this config, which is not part of this file.
    """

    hidden_dim: int = 128
    dropout_rate: float = 0.2
    use_gated: bool = True


@dataclass
class ClassifierConfig:
    """Classifier-head settings; every field carries a default.

    How these values are applied is decided by the model code that reads
    this config, which is not part of this file.
    """

    # Built through a factory so each instance owns its own list.
    hidden_layers: List[int] = field(default_factory=lambda: [256, 128])
    dropout_rate: float = 0.3
    # Activation identifier (string); defaults to "relu".
    activation: str = "relu"


@dataclass
class ModelConfig:
    """Model settings: an aggregation method plus two nested sub-configs.

    Each nested config is created through a ``default_factory`` so every
    ``ModelConfig`` instance gets its own ``AttentionConfig`` and
    ``ClassifierConfig`` objects.
    """

    # Aggregation method identifier (string); defaults to "attention".
    aggregation_method: str = "attention"
    attention: AttentionConfig = field(default_factory=AttentionConfig)
    classifier: ClassifierConfig = field(default_factory=ClassifierConfig)


@dataclass
class CrossValidationConfig:
    """Cross-validation settings; every field carries a default.

    ``random_state`` defaults to ``None`` and is therefore annotated as
    optional.  NOTE(review): ``None`` presumably means "do not fix a seed"
    -- confirm against the consumer of this config.
    """

    enabled: bool = True
    n_folds: int = 5
    shuffle: bool = True
    # Optional: the default is None, so the annotation must allow it.
    random_state: Optional[int] = None
    stratified: bool = True


@dataclass
class MILTrainingConfig:
    """Training-loop settings plus nested optimizer/scheduler/early-stop configs.

    The nested configs are created through ``default_factory`` so each
    instance owns independent sub-config objects.
    """

    epochs: int = 100
    batch_size: int = 1
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    early_stopping: EarlyStoppingConfig = field(
        default_factory=EarlyStoppingConfig
    )


@dataclass
class MILDataConfig:
    """Data-side settings: nested clustering and sampling configs.

    Both are created through ``default_factory`` so each instance owns
    independent sub-config objects.
    """

    clustering: ClusteringConfig = field(default_factory=ClusteringConfig)
    sampling: SamplingConfig = field(default_factory=SamplingConfig)


@dataclass
class MILConfig:
    """Top-level config grouping training, model, data and cross-validation.

    Every section is created through ``default_factory``, so ``MILConfig()``
    yields a fully populated tree of defaults with no shared sub-objects
    between instances.
    """

    training: MILTrainingConfig = field(default_factory=MILTrainingConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    data: MILDataConfig = field(default_factory=MILDataConfig)
    cross_validation: CrossValidationConfig = field(
        default_factory=CrossValidationConfig
    )


class TrainingConfigManager:
    """Load a training YAML file once and serve settings from the parsed data.

    The parsed document is kept in ``self._raw_config``.  Accessors return
    deep copies so callers cannot mutate the cached configuration by
    accident.

    NOTE(review): ``_update_dataclass_from_dict`` and ``_deep_merge_dict``
    are not called by any method of this class as shown here.
    """

    def __init__(self, config_path: Union[str, Path, None] = None):
        """Resolve the config path and load the file immediately.

        Args:
            config_path: Location of the YAML file.  When ``None``, defaults
                to ``configs/training.yaml`` three directory levels above
                this module's file.

        Raises:
            FileNotFoundError: If the resolved path does not exist.
            ValueError: If the YAML document's top level is not a mapping.
        """
        if config_path is None:
            config_path = Path(__file__).parent.parent.parent / "configs" / "training.yaml"

        self.config_path = Path(config_path)
        self._raw_config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Parse ``self.config_path`` with ``yaml.safe_load``.

        Returns:
            The parsed mapping; an empty dict when the file is empty.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the document's top level is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Training config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)

        # safe_load yields None for an empty document; normalise it so the
        # accessors below can rely on a dict.
        if loaded is None:
            return {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Training config must be a mapping at the top level, "
                f"got {type(loaded).__name__}: {self.config_path}"
            )
        return loaded

    def get_classifier_config(self, classifier_type: str) -> Dict[str, Any]:
        """Return the settings for one classifier, filled in from ``global``.

        Keys from the top-level ``global`` section are added only when the
        classifier section does not define them itself.

        Args:
            classifier_type: Key under the top-level ``classifiers`` section.

        Returns:
            A deep copy of the merged settings; mutating it does not affect
            the cached configuration.

        Raises:
            ValueError: If there is no ``classifiers`` section or it has no
                entry named ``classifier_type``.
        """
        if 'classifiers' not in self._raw_config:
            raise ValueError("No classifiers configuration found")

        # An empty ``classifiers:`` key parses as None; treat it as empty.
        classifiers = self._raw_config['classifiers'] or {}
        if classifier_type not in classifiers:
            available = list(classifiers.keys())
            raise ValueError(f"Classifier '{classifier_type}' not found. Available: {available}")

        # Deep copy: the manager is shared, so nested lists/dicts must not
        # be handed out by reference.  An empty entry parses as None.
        config = copy.deepcopy(classifiers[classifier_type] or {})

        global_config = self._raw_config.get('global') or {}
        for key, value in global_config.items():
            if key not in config:
                config[key] = copy.deepcopy(value)

        return config

    def _update_dataclass_from_dict(self, obj: Any, data: Dict[str, Any]):
        """Recursively copy values from ``data`` onto attributes of ``obj``.

        Keys that ``obj`` has no attribute for are ignored.  When the current
        attribute is itself a dataclass instance, the value is applied
        recursively if it is a dict and ignored otherwise.
        """
        for key, value in data.items():
            if hasattr(obj, key):
                attr = getattr(obj, key)
                if hasattr(attr, '__dataclass_fields__'):
                    if isinstance(value, dict):
                        self._update_dataclass_from_dict(attr, value)
                else:
                    setattr(obj, key, value)

    def _deep_merge_dict(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict: ``base`` with ``override`` merged on top.

        Nested dicts present on both sides are merged recursively; any other
        value from ``override`` replaces the one in ``base``.  Neither input
        is modified.
        """
        result = copy.deepcopy(base)

        for key, value in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = self._deep_merge_dict(result[key], value)
            else:
                result[key] = copy.deepcopy(value)

        return result

    def save_config(self, config_dict: Dict[str, Any], save_path: Union[str, Path, None] = None):
        """Write ``config_dict`` as YAML.

        Args:
            config_dict: Data to serialise.
            save_path: Destination file; defaults to ``self.config_path``.

        Note:
            The in-memory ``_raw_config`` is not updated by this call.
        """
        if save_path is None:
            save_path = self.config_path

        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)

    def list_available_classifiers(self) -> List[str]:
        """Return the keys of the ``classifiers`` section (empty if absent)."""
        # ``or {}`` covers an empty ``classifiers:`` key that parsed as None.
        return list(self._raw_config.get('classifiers') or {})

    def get_raw_config(self) -> Dict[str, Any]:
        """Return a deep copy of the whole parsed configuration."""
        return copy.deepcopy(self._raw_config)


# Lazily created, module-wide TrainingConfigManager instance.
_config_manager = None


def get_training_config_manager() -> TrainingConfigManager:
    """Return the shared manager, creating it on the first call.

    The instance is built with the default config path and cached in the
    module-level ``_config_manager`` variable for later calls.
    """
    global _config_manager
    if _config_manager is not None:
        return _config_manager

    _config_manager = TrainingConfigManager()
    return _config_manager

def get_classifier_config(classifier_type: str) -> Dict[str, Any]:
    """Module-level shortcut: look up ``classifier_type`` on the shared manager."""
    manager = get_training_config_manager()
    return manager.get_classifier_config(classifier_type)



def get_mil_config(dataset_name: Optional[str] = None) -> MILConfig:
    """Forward ``dataset_name`` to ``get_mil_config`` on the shared manager.

    NOTE(review): the ``TrainingConfigManager`` class visible in this file
    does not define a ``get_mil_config`` method; unless it is supplied
    elsewhere, this call raises ``AttributeError`` -- confirm.
    """
    manager = get_training_config_manager()
    return manager.get_mil_config(dataset_name)
