import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Any

import yaml


@dataclass
class ModelConfig:
    """
    Model configuration.

    Attributes
    ----------
    model_name : str
        Model name.
    model_params : Optional[Dict[str, Any]]
        Additional model parameters (default is None, i.e. none supplied).
    checkpoint_path : Optional[str]
        Name of the weights file to load (default is "model_weights.pth").
    force_overwrite : bool
        Whether to force loading the model weights even if it used a different
        model config (default is True).
    """
    model_name: str
    # None (not an empty dict) is the "nothing supplied" value, hence Optional.
    model_params: Optional[Dict[str, Any]] = None

    # For pretrained models
    checkpoint_path: Optional[str] = "model_weights.pth"
    force_overwrite: bool = True


@dataclass
class TrainingConfig:
    """
    Training configuration.

    Attributes
    ----------
    batch_size : int
        Batch size for training.
    n_epochs : int
        Number of training epochs.
    optimizer : Dict[str, Any]
        Optimizer settings. The expected keys are defined by the code that
        consumes this config, not by this class.
    scheduler : str
        Scheduler name (default is "reduce_on_plateau").
    early_stopping_patience : int
        Early stopping patience (default is 10).
    loss : str
        Name of the training loss (default is "mse").
    val_loss : Optional[str]
        Name of the validation loss (default is None).
    inference_params : Optional[Dict[str, Any]]
        Parameters for inference engine (default is None).
    device : str
        Device to use (e.g., 'cuda', 'cpu'); default is "cuda".
    n_workers : int
        Number of data loader workers (default is 4).
    use_wandb : bool
        Whether to use wandb for logging (default is True).
    wandb : Optional[Dict[str, Any]]
        Wandb configuration parameters (default is None).
    """
    batch_size: int
    n_epochs: int

    optimizer: Dict[str, Any]
    scheduler: str = "reduce_on_plateau"
    early_stopping_patience: int = 10

    loss: str = "mse"
    val_loss: Optional[str] = None
    inference_params: Optional[Dict[str, Any]] = None

    device: str = "cuda"
    n_workers: int = 4

    use_wandb: bool = True
    wandb: Optional[Dict[str, Any]] = None


@dataclass
class DatasetConfig:
    """
    Dataset configuration.

    Attributes
    ----------
    train_path : str
        Path to the training data file.
    val_path : str
        Path to the validation data file.
    dataset_params : Optional[Dict[str, Any]]
        Additional dataset parameters (default is None, i.e. none supplied).
    """

    train_path: str
    val_path: str
    # None (not an empty dict) is the "nothing supplied" value, hence Optional.
    dataset_params: Optional[Dict[str, Any]] = None


@dataclass
class Config:
    """
    Full experiment configuration.

    Attributes
    ----------
    debug : bool
        Debug-mode flag. Required by the constructor; ``from_dict`` and
        ``from_yaml`` fall back to False when the key is absent.
    model : ModelConfig
        Configuration for the model.
    training : TrainingConfig
        Configuration for the training run.
    dataset : DatasetConfig
        Configuration for the dataset.
    seed : int
        Seed value (default is 42).
    output_dir : str
        Output directory (default is "outputs").
    """
    debug: bool
    model: ModelConfig
    training: TrainingConfig
    dataset: DatasetConfig
    seed: int = 42
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, yaml_path: Path | str) -> 'Config':
        """
        Load configuration from a YAML file.

        Parameters
        ----------
        yaml_path : Path or str
            Path to the YAML config file.

        Returns
        -------
        Config
            Loaded configuration object.

        Raises
        ------
        ValueError
            If the file does not contain a top-level mapping (e.g. it is empty).
        KeyError
            If one of the 'model', 'training' or 'dataset' sections is missing.
        """
        with open(yaml_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        # safe_load returns None for an empty file and may return a list or
        # scalar; fail with a clear message instead of an AttributeError later.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {yaml_path}, "
                f"got {type(config_dict).__name__}"
            )
        # Single construction path shared with from_dict.
        return cls.from_dict(config_dict)

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'Config':
        """
        Build a configuration from a dictionary.

        Parameters
        ----------
        config_dict : dict
            Mapping with required 'model', 'training' and 'dataset' sections
            (each unpacked as keyword arguments into the matching config
            class) and optional 'debug', 'seed' and 'output_dir' entries.

        Returns
        -------
        Config
            Loaded configuration object.

        Raises
        ------
        KeyError
            If one of the 'model', 'training' or 'dataset' sections is missing.
        TypeError
            If a section contains a key its config class does not accept.
        """
        return cls(
            debug=config_dict.get('debug', False),
            model=ModelConfig(**config_dict['model']),
            training=TrainingConfig(**config_dict['training']),
            dataset=DatasetConfig(**config_dict['dataset']),
            seed=config_dict.get('seed', 42),
            output_dir=config_dict.get('output_dir', 'outputs')
        )

    def save(self, path: Path | str) -> None:
        """
        Save configuration to a YAML file.

        Parent directories are created if they do not exist.

        Parameters
        ----------
        path : Path or str
            Path to save the YAML config file.
        """
        # Path("config.yaml").parent is ".", so a bare filename is handled;
        # os.makedirs("") would raise FileNotFoundError here.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the configuration to a dictionary.

        The section dictionaries are copies, so modifying the result does not
        add, remove or rebind attributes on this configuration. The copies are
        shallow: nested values (e.g. parameter dicts) are still shared.
        ``Path`` values in the dataset section and ``output_dir`` are converted
        to strings.

        Returns
        -------
        Dict[str, Any]
            Dictionary representation of the configuration, accepted by
            ``from_dict``.
        """
        # Build the dataset section as a new dict; writing the converted
        # strings back into the instance __dict__ would silently turn the
        # dataset's Path attributes into str.
        dataset_section = {
            key: str(value) if isinstance(value, Path) else value
            for key, value in vars(self.dataset).items()
        }
        return {
            # Included so that a save/load round-trip preserves the flag.
            'debug': self.debug,
            'model': dict(vars(self.model)),
            'training': dict(vars(self.training)),
            'dataset': dataset_section,
            'seed': self.seed,
            'output_dir': str(self.output_dir),
        }
