from dataclasses import dataclass

from omegaconf import DictConfig, OmegaConf

from tabicl.core.enums import BenchmarkName, DownstreamTask


@dataclass
class ConfigTesting:
    """Evaluation/testing settings parsed from a Hydra config node.

    Build instances with :meth:`from_hydra`. It converts the task and
    benchmark name lists into enum members and copies the remaining
    scalar settings through as-is.
    """

    # Downstream tasks, as members of DownstreamTask.
    downstream_tasks: list[DownstreamTask]
    # Default run counts per dataset for the "valid" and "test" settings,
    # read directly from the config.
    n_default_runs_per_dataset_valid: int
    n_default_runs_per_dataset_test: int
    # OpenML dataset ids listed in the config under openml_dataset_ids_to_ignore.
    openml_dataset_ids_to_ignore: list[int]
    # Values of decision_boundary_analysis.enabled / .grid_size in the config.
    decision_boundary_analysis_enabled: bool
    decision_boundary_analysis_grid_size: int
    # Value of loss_graph_min_step in the config.
    loss_graph_min_step: int
    # Benchmarks for the "valid" and "test" settings, as members of BenchmarkName.
    benchmarks_valid: list[BenchmarkName]
    benchmarks_test: list[BenchmarkName]

    @classmethod
    def from_hydra(cls, cfg_hydra: DictConfig):
        """Create a ``ConfigTesting`` from a Hydra ``DictConfig``.

        Args:
            cfg_hydra: Config node whose ``downstream_tasks``,
                ``benchmarks_valid`` and ``benchmarks_test`` entries are enum
                member *names*. Each name is resolved with ``Enum[name]``
                subscription; an unknown name propagates the enum lookup
                error (``KeyError`` for a standard ``Enum``).

        Returns:
            A fully populated ``ConfigTesting`` instance.
        """

        def names_to_members(enum_type, names):
            # Resolve each configured name to the matching enum member.
            return [enum_type[name] for name in names]

        tasks = names_to_members(DownstreamTask, cfg_hydra.downstream_tasks)
        valid_benchmarks = names_to_members(BenchmarkName, cfg_hydra.benchmarks_valid)
        test_benchmarks = names_to_members(BenchmarkName, cfg_hydra.benchmarks_test)

        runs_valid = cfg_hydra.n_default_runs_per_dataset_valid
        runs_test = cfg_hydra.n_default_runs_per_dataset_test
        # to_container turns the OmegaConf list node into a plain Python list.
        ignored_ids = OmegaConf.to_container(cfg_hydra.openml_dataset_ids_to_ignore)  # type: ignore
        boundary_cfg = cfg_hydra.decision_boundary_analysis
        boundary_enabled = boundary_cfg.enabled
        boundary_grid_size = boundary_cfg.grid_size
        min_loss_step = cfg_hydra.loss_graph_min_step

        return cls(
            downstream_tasks=tasks,
            n_default_runs_per_dataset_valid=runs_valid,
            n_default_runs_per_dataset_test=runs_test,
            openml_dataset_ids_to_ignore=ignored_ids,
            decision_boundary_analysis_enabled=boundary_enabled,
            decision_boundary_analysis_grid_size=boundary_grid_size,
            loss_graph_min_step=min_loss_step,
            benchmarks_valid=valid_benchmarks,
            benchmarks_test=test_benchmarks,
        )

