from dataclasses import dataclass, field
from typing_extensions import Literal
from typing import Dict, List, Optional, Tuple, Any
from emm.utils import (
    TemperatureAnnealer,
)


@dataclass
class TrainingConfig:
    """Comprehensive configuration for mixture model training.

    Every field has a default, so the class can be instantiated with no
    arguments and individual settings overridden by keyword. Defaults that
    contain mutable objects are created per instance through
    ``default_factory`` so that instances never share state by accident.

    Use :meth:`to_dict` to obtain a flat snapshot of the settings and
    :meth:`copy` to obtain an independent duplicate.
    """

    # General Settings
    seed: int = 1000
    device: str = "cpu"
    verbose: bool = False
    # NOTE(review): the very large default presumably disables history
    # recording in practice -- confirm against the trainer.
    record_history_every: int = int(1e10)

    # Data Preprocessing
    scaler_x_type: str = "standard"
    scaler_y_type: str = "standard"

    # Model Architecture
    n_mixture_components: int = 3
    temperature: float = 0.2
    optimize_weights: bool = True
    initializer_config: dict[str, Any] = field(
        default_factory=lambda: {
            "type": "sample_based",
            "interval_size": 0.1,
            "strategy": "diverse",
        }
    )

    # Loss Weights
    partition_weight: float = 0.0
    coverage_weight: float = 0.0
    entropy_weight: float = 0.0
    responsibility_weight: float = 0.0
    kl_weight: float = 0.0
    and_layer_entropy: float = 0.0

    # Flow Configuration
    # The tuple itself is immutable but the dict inside it is not, so the
    # default is built by a factory; a plain literal default would be a
    # single object shared by every instance.
    flow_gen: Tuple[str, dict[str, int]] = field(
        default_factory=lambda: ("zuko_gf", {"components": 6})
    )
    flow_steps: int = 2
    rules_steps: int = 2

    # Training Parameters
    pop_train_epochs: int = 0
    component_train_epochs: int = 1000

    # Learning Rates
    lr_flow: float = 5e-3
    lr_rules: float = 5e-3
    lr_gmm_remix: float = 5e-3
    # NOTE(review): -1 presumably means "use the full dataset as one batch"
    # -- confirm against the training loop.
    batchsize: int = -1

    # Annealing Schedules
    # Accepts an annealer instance, a dict, the literal "auto", or None.
    # How the non-instance forms are interpreted is decided by the consumer
    # of this config, not here.
    temp_anneal: Optional[TemperatureAnnealer | dict | Literal["auto"]] = field(
        default_factory=lambda: TemperatureAnnealer(0.2, 0.005, 0.1, 0.9)
    )

    # Component Management
    min_responsibility_threshold: float = 0.001
    check_responsibility_every: int = 50
    pruning_threshold: float = 1.0

    # GMM Remix Parameters
    use_gmm_remix: bool = False
    remix_pretrain_epochs: int = 0
    n_gmm_components: int = 20
    n_gmm_extra_components: int = 0
    gmm_reg_covar: float = 1e-6
    gmm_max_iter: int = 2000
    gmm_remix_l1_weight: float = 0
    gmm_div_weight: float = 0.0
    diagonal_gmm_init: bool = False
    component_scoring: str | None = "bic"

    # Background Component
    use_background_component: bool = False
    background_epsilon: float = 0.1
    background_pretrain_epochs: int = 0

    # Component Merging
    merge_components: bool = False
    merge_iou_threshold: float = 0.8
    merge_jsd_threshold: float = 0.1
    merge_adjacency_tol: float = 0.05
    merge_settle_epochs: int = 200

    # Model Finder
    use_model_finder: bool = False
    model_finder_component_range: Optional[List[int]] = None
    model_finder_return_history: bool = False

    def to_dict(self) -> Dict:
        """Return the instance attributes as a new dictionary.

        List, tuple and ``TemperatureAnnealer`` values are replaced by their
        ``str()`` representation. Every other value, including nested dicts
        such as ``initializer_config``, is placed in the result unchanged
        (by reference, not copied).

        Returns:
            A new dict keyed by attribute name, in attribute order.
        """
        stringified_types = (list, tuple, TemperatureAnnealer)
        return {
            name: str(value) if isinstance(value, stringified_types) else value
            for name, value in self.__dict__.items()
        }

    def copy(self) -> "TrainingConfig":
        """Return an independent copy of this configuration.

        Attribute values are deep-copied before the new instance is built,
        so changing a nested container (e.g. ``initializer_config``) or an
        object-valued field (e.g. ``temp_anneal``) on the copy leaves the
        original untouched.

        Returns:
            A new ``TrainingConfig`` with equal but unshared field values.
        """
        # Local import: keeps the block self-contained and avoids any
        # confusion between the ``copy`` module and this method's name.
        from copy import deepcopy

        return TrainingConfig(**deepcopy(self.__dict__))
