from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from omegaconf import DictConfig, OmegaConf

from tabicl.config.config_benchmark_sweep import ConfigPlotting
from tabicl.config.config_data import ConfigData
from tabicl.config.config_optim import ConfigOptim
from tabicl.config.config_plotting import ConfigPlottingTabzilla, ConfigPlottingWhytrees
from tabicl.config.config_preprocessing import ConfigPreprocessing
from tabicl.config.config_save_load_mixin import ConfigSaveLoadMixin
from tabicl.config.config_testing import ConfigTesting
from tabicl.core.enums import GeneratorName, ModelName, Task


@dataclass
class ConfigPretrain(ConfigSaveLoadMixin):
    """Top-level configuration for a pretraining run.

    Normally constructed with :meth:`from_hydra`, which turns a composed
    Hydra ``DictConfig`` into typed values: a ``Path``, ``torch.device``
    objects, enum members, plain containers and the nested ``Config*``
    dataclasses.
    """

    output_dir: Path
    seed: int
    devices: list[torch.device]
    # Device of the current process. from_hydra initializes it to CPU;
    # DDP setup is expected to overwrite it.
    device: torch.device
    max_cpus_per_device: Optional[int]
    # from_hydra sets this to True whenever more than one device is listed.
    use_ddp: bool
    workers_per_gpu: int
    # Container copy (via OmegaConf.to_container) of pretrain_model.hyperparams.
    model: dict
    model_name: ModelName
    data: ConfigData
    optim: ConfigOptim
    preprocessing: ConfigPreprocessing
    testing: ConfigTesting
    plotting: ConfigPlotting
    # Container copy of cfg.hyperparams[<model name, lowercased>].
    hyperparams_finetuning: dict

    @classmethod
    def from_hydra(cls, cfg_hydra: DictConfig) -> "ConfigPretrain":
        """Build a ``ConfigPretrain`` from a composed Hydra config.

        Args:
            cfg_hydra: Hydra/OmegaConf config. It is read for ``output_dir``,
                ``devices``, ``pretrain_model``, ``hyperparams``,
                ``max_cpus_per_device``, ``seed``, ``workers_per_gpu``,
                ``data``, ``optim``, ``preprocessing``, ``testing`` and
                ``plotting``.

        Returns:
            A fully populated ``ConfigPretrain``. ``device`` is CPU, and
            ``use_ddp`` is True exactly when more than one device is
            configured.
        """
        out_path = Path(cfg_hydra.output_dir)
        device_list = [torch.device(spec) for spec in cfg_hydra.devices]

        # The enum member's lowercased name selects the fine-tuning
        # hyperparameter section for this model.
        chosen_model = ModelName[cfg_hydra.pretrain_model.name]
        finetune_section = cfg_hydra.hyperparams[chosen_model.name.lower()]
        model_section = cfg_hydra.pretrain_model

        # Every process starts on CPU; DDP setup will overwrite this.
        start_device = torch.device("cpu")

        max_cpus = cfg_hydra.max_cpus_per_device
        run_seed = cfg_hydra.seed
        n_workers = cfg_hydra.workers_per_gpu
        model_hparams = OmegaConf.to_container(model_section.hyperparams)  # type: ignore
        finetune_hparams = OmegaConf.to_container(finetune_section)  # type: ignore

        data_section = cfg_hydra.data
        data_config = ConfigData(
            generator=GeneratorName(data_section.generator),
            min_samples_support=data_section.min_samples_support,
            max_samples_support=data_section.max_samples_support,
            n_samples_query=data_section.n_samples_query,
            min_features=data_section.min_features,
            max_features=data_section.max_features,
            max_classes=data_section.max_classes,
            task=Task[data_section.task],
            generator_hyperparams=OmegaConf.to_container(data_section.generator_hyperparams),  # type: ignore
        )

        optim_config = ConfigOptim.from_hydra(cfg_hydra.optim)
        preprocessing_config = ConfigPreprocessing(**cfg_hydra.preprocessing)
        testing_config = ConfigTesting.from_hydra(cfg_hydra.testing)

        whytrees_section = cfg_hydra.plotting.whytrees
        whytrees_config = ConfigPlottingWhytrees(
            n_runs=whytrees_section.n_runs,
            n_random_shuffles=whytrees_section.n_random_shuffles,
            confidence_bound=whytrees_section.confidence_bound,
            plot_default_value=whytrees_section.plot_default_value,
            benchmark_model_names=[ModelName[label] for label in whytrees_section.benchmark_models],
        )
        tabzilla_config = ConfigPlottingTabzilla(
            benchmark_model_names=[ModelName[label] for label in cfg_hydra.plotting.tabzilla.benchmark_models],
        )
        plotting_config = ConfigPlotting(
            whytrees=whytrees_config,
            tabzilla=tabzilla_config,
        )

        return cls(
            output_dir=out_path,
            devices=device_list,
            device=start_device,
            max_cpus_per_device=max_cpus,
            use_ddp=len(device_list) > 1,
            seed=run_seed,
            workers_per_gpu=n_workers,
            model=model_hparams,  # type: ignore
            model_name=chosen_model,
            hyperparams_finetuning=finetune_hparams,  # type: ignore
            data=data_config,
            optim=optim_config,
            preprocessing=preprocessing_config,
            testing=testing_config,
            plotting=plotting_config,
        )
    











