from dataclasses import dataclass, field
from typing import Any, List, Optional

from hydra.core.config_store import ConfigStore

from conf.checkpoint import CheckpointParams
from conf.dataset import (BlenderParams, BRATS2020Params, CelebAParams,
                          DatasetParams)
from conf.earlystop import EarlyStopParams
from conf.evaluation_params import EvaluationParams
from conf.ids import InceptionDistanceParams
from conf.model import (DiffusionParams_Cosine, DiffusionParams_DDIM,
                        DiffusionParams_DDPM, ModelParams)
from conf.ProfilerParams import ProfilerParams
from conf.slurm import CfgSlurm
from conf.system_params import SystemParams
from conf.trainer import TrainerParams
from conf.wandb_params import WandbParams


@dataclass
class GlobalConfiguration:
    """Root structured config that composes every parameter group of the project.

    The ``defaults`` list picks which option is used for the
    ``model_params/diffusion`` and ``dataset_params/data_params`` config groups;
    the remaining fields hold one instance of each nested parameter dataclass.
    """

    # region default values
    # Hydra defaults list. Because '_self_' is listed first, the selected group
    # options are merged after this config's own values (Hydra defaults-list
    # semantics), so they take precedence where they overlap.
    defaults: List[Any] = field(default_factory=lambda: [
        '_self_',

        {'model_params/diffusion'    : 'ddim'},
        {'dataset_params/data_params': 'blender'},
    ])

    seed: Optional[int] = 42

    # Forward slash so the relative path resolves on POSIX as well as Windows;
    # a backslash is a literal filename character on POSIX systems.
    yaml_conf: Optional[str] = 'yaml_conf/blender3.yaml'
    # endregion

    # Nested parameter groups. default_factory builds a fresh instance per
    # GlobalConfiguration: a plain instance default would be shared between all
    # configurations, and Python 3.11+ dataclasses reject it outright when the
    # class is unhashable (as regular, non-frozen dataclasses are).
    checkpoint_params: CheckpointParams = field(default_factory=CheckpointParams)
    dataset_params   : DatasetParams    = field(default_factory=DatasetParams)
    early_stop_params: EarlyStopParams  = field(default_factory=EarlyStopParams)
    model_params     : ModelParams      = field(default_factory=ModelParams)
    wandb_params     : WandbParams      = field(default_factory=WandbParams)
    cfgSlurm_params  : CfgSlurm         = field(default_factory=CfgSlurm)
    trainer_params   : TrainerParams    = field(default_factory=TrainerParams)
    system_params    : SystemParams     = field(default_factory=SystemParams)
    id_params        : InceptionDistanceParams = field(default_factory=InceptionDistanceParams)
    profiler_params  : ProfilerParams   = field(default_factory=ProfilerParams)

    evaluation_params: EvaluationParams = field(default_factory=EvaluationParams)


# region register config
cs = ConfigStore.instance()


def _register_group_options(store, group, options):
    """Store every ``name -> node`` pair of *options* in *store* under config *group*."""
    for option_name, option_node in options.items():
        store.store(group=group, name=option_name, node=option_node)


# Root schema, stored without a group under the name 'globalConfiguration'.
cs.store(name='globalConfiguration', node=GlobalConfiguration)

# Options for the 'model_params/diffusion' group selected in GlobalConfiguration.defaults.
_register_group_options(cs, 'model_params/diffusion', {
    'ddpm'  : DiffusionParams_DDPM,
    'ddim'  : DiffusionParams_DDIM,
    'cosine': DiffusionParams_Cosine,
})

# Options for the 'dataset_params/data_params' group selected in GlobalConfiguration.defaults.
_register_group_options(cs, 'dataset_params/data_params', {
    'blender'  : BlenderParams,
    'brats2020': BRATS2020Params,
    'celeba'   : CelebAParams,
})
# endregion
