from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from .logging import *
from .data import *
from .model import *
from .optimizer import *
from .enums import *
from .augmentations import *
from .loss import *


@dataclass
class TrainConfig:
    """Top-level structured config describing one training run.

    Bundles the per-component sub-configs (model, dataset, logging,
    optimizer, augmentation, loss) with the scalar training options.
    Fields defaulting to ``MISSING`` have no value here and must be
    provided when the config is composed.
    """

    name: str = MISSING

    # Subconfigs for specific experiment parts
    model: ModelConfig = MISSING
    dataset: DatasetConfig = MISSING
    # Built through default_factory so every TrainConfig gets its own
    # LogConfig. A bare ``LogConfig()`` default would be shared by all
    # instances, and dataclasses on Python 3.11+ rejects unhashable
    # (i.e. regular, non-frozen dataclass) instances as field defaults.
    logging: LogConfig = field(default_factory=LogConfig)
    optimizer: OptimizerConfig = MISSING
    # NOTE(review): the default is the SimclrAugmentationConfig class itself,
    # not an instance — presumably consumed as a structured-config default
    # by OmegaConf; confirm before switching to a default_factory.
    augmentation: AugmentationConfig = SimclrAugmentationConfig
    loss: LossConfig = MISSING
    seed: Optional[int] = 0

    # Training setup
    num_workers: int = 4
    num_nodes: int = 1
    num_gpus: int = 1
    use_cpu: bool = False
    epochs: int = 2048
    accum_batches: int = 1
    batch_size: int = 512
    batch_size_eval: int = 32
    fp16: bool = True
    skip_sanity_checks: bool = True
    # Set a path to continue training from
    resume_training_from: Optional[Union[str, bool]] = None

    # Enable debugging
    debug: bool = False

    # validation logging frequency
    eval_every_n_epochs: int = 256

    # Configure whether NND Index should be approximated to save memory
    approximate_index: bool = False
    index_pca_dim: Optional[int] = None


@dataclass
class ContextContrastingConfig(TrainConfig):
    """TrainConfig variant whose ``name`` defaults to "context_contrasting_ad".

    Only the ``name`` default is overridden; every other field is inherited
    from TrainConfig as-is.
    """

    name: str = "context_contrasting_ad"


def register_conf():
    """Register every structured config of this package with Hydra's ConfigStore.

    Model, optimizer, dataset, loss and logging nodes are stored under the
    ``group`` / ``name`` attributes they declare. Each augmentation node is
    stored twice: once under ``AugmentationConfig.group`` and once under
    ``"loss.augmentation_class"``. Finally the ContextContrastingConfig
    template is stored without a group under its own ``name``.
    """
    store = ConfigStore.instance().store

    # Nodes that carry their own group and name, in registration order:
    # models, then optimizers, then datasets.
    self_describing_nodes = (
        # Models
        ResNet,
        # Optimizer
        SGD,
        Adam,
        AdamW,
        # Datasets
        Cifar10,
        Cifar100,
        ImageNet30,
        DogsVsCats,
        Pneumonia,
        MuffinVsChihuahua,
        Melanoma,
    )
    for config_node in self_describing_nodes:
        store(group=config_node.group, name=config_node.name, node=config_node)

    # Augmentations: the same set of nodes is made available in both groups.
    augmentation_nodes = (
        SimclrAugmentationConfig,
        RandomRotationConfig,
        RandomInvertConfig,
        RandomEqualizeConfig,
        ContextFlipConfig,
        ContextFlipInvertConfig,
        ContextFlipEqualizeConfig,
        ContextInvertEqualizeConfig,
        ContextFlipInvertEqualizeConfig,
    )
    for target_group in (AugmentationConfig.group, "loss.augmentation_class"):
        for config_node in augmentation_nodes:
            store(group=target_group, name=config_node.name, node=config_node)

    # Loss, then logging
    for config_node in (ContextContrastingLossConfig, LogConfig):
        store(group=config_node.group, name=config_node.name, node=config_node)

    # Config template (no group)
    store(name=ContextContrastingConfig.name, node=ContextContrastingConfig)
