from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


@dataclass
class DatasetConfig:
    """Base structured-config schema shared by all dataset presets.

    Concrete datasets subclass this and override ``name``, ``data_path`` and
    whichever preprocessing defaults differ. ``name`` and ``data_path`` default
    to OmegaConf's ``MISSING``, i.e. they are mandatory values that must be
    supplied (by a subclass or by the composed config) before being read.

    ``num_dataset_classes`` is declared here, defaulting to ``None``, because
    every concrete preset defines it. Declaring it on the base lets code typed
    against ``DatasetConfig`` rely on the key being present. Subclasses that
    redefine it keep it in the same (last) field position.
    """

    name: str = MISSING
    # NOTE(review): presumably the Hydra config group this schema is stored
    # under — confirm against the ConfigStore registration.
    group: str = "dataset"
    data_path: str = MISSING
    grayscale: bool = False
    # Class-selection options; semantics (e.g. a one-class / anomaly split)
    # are assumed from the names — TODO confirm with the dataset loader.
    normal_class: Optional[int] = None
    n_normal_class: int = 1
    anomaly_class: Optional[int] = None
    # Typed ``Any`` rather than ``int``; presumably so non-int forms (e.g. an
    # (h, w) pair) can be passed — TODO confirm which forms the loader accepts.
    size: Any = 224
    crop: bool = True
    random_crop: bool = False
    # Bugfix of the original, faulty random cropping. It defaults to False,
    # so the fix is opt-in.
    fixed_random_crop: bool = False
    is_tensor: bool = False
    random_horizontal_flip: bool = False
    # Number of classes in the underlying dataset; ``None`` unless a preset
    # (subclass) or the composed config sets it.
    num_dataset_classes: Optional[int] = None


@dataclass
class Cifar10(DatasetConfig):
    """CIFAR-10 preset: 32-pixel images, 10 classes."""

    # Dataset identity.
    name: str = "cifar10"
    num_dataset_classes: int = 10
    # Storage location and image geometry.
    data_path: str = "/datasets/CIFAR10/"
    size: int = 32


@dataclass
class Cifar100(DatasetConfig):
    """CIFAR-100 preset: 32-pixel images, 20 classes."""

    # Dataset identity.
    name: str = "cifar100"
    # NOTE(review): CIFAR-100 has 100 fine labels; 20 presumably refers to the
    # coarse superclasses — confirm against the dataset loader.
    num_dataset_classes: int = 20
    # Storage location and image geometry.
    data_path: str = "/datasets/CIFAR100/"
    size: int = 32


@dataclass
class DogsVsCats(DatasetConfig):
    """Dogs-vs-Cats preset: 128-pixel images, 2 classes, random cropping on."""

    # Dataset identity.
    name: str = "dogs_vs_cats"
    num_dataset_classes: int = 2
    # Storage location.
    data_path: str = "/datasets/dogs-vs-cats/"
    # Preprocessing overrides.
    random_crop: bool = True
    size: int = 128


@dataclass
class MuffinVsChihuahua(DatasetConfig):
    """Muffin-vs-Chihuahua preset: 128-pixel images, 2 classes, random cropping on."""

    # Dataset identity.
    name: str = "muffin_vs_chihuahua"
    num_dataset_classes: int = 2
    # Storage location.
    data_path: str = "/datasets/muffin-vs-chihuahua/"
    # Preprocessing overrides.
    random_crop: bool = True
    size: int = 128


@dataclass
class ImageNet30(DatasetConfig):
    """ImageNet-30 preset: 224-pixel images, 30 classes, crop + flip augmentation."""

    # Dataset identity.
    name: str = "imagenet30"
    num_dataset_classes: int = 30
    # Storage location.
    data_path: str = "/datasets/imagenet30/"
    # Image geometry and augmentation overrides.
    size: int = 224
    random_horizontal_flip: bool = True
    random_crop: bool = True


@dataclass
class Pneumonia(DatasetConfig):
    """Pneumonia preset: 224-pixel grayscale images, 2 classes."""

    # Dataset identity.
    name: str = "pneumonia"
    num_dataset_classes: int = 2
    # Storage location.
    # NOTE(review): unlike the other presets this path has no trailing slash —
    # confirm the loader joins paths rather than concatenating strings.
    data_path: str = "/datasets/mimic_cxr"
    # Image format overrides.
    grayscale: bool = True
    size: int = 224


@dataclass
class Melanoma(DatasetConfig):
    """Melanoma preset: 128-pixel images, 2 classes, horizontal-flip augmentation."""

    # Dataset identity.
    name: str = "melanoma"
    num_dataset_classes: int = 2
    # Storage location.
    # NOTE(review): no trailing slash here, unlike most presets — confirm the
    # loader joins paths rather than concatenating strings.
    data_path: str = "/datasets/melanoma_cancer_dataset"
    # Image geometry and augmentation overrides.
    random_horizontal_flip: bool = True
    size: int = 128
