from typing import Optional

from pydantic import Field, BaseModel, ConfigDict, model_validator

from enum import Enum


class DatasetType(Enum):
    """Catalogue of the supported datasets.

    Every member's value is a ``(dataset_name, num_classes, image_shape)``
    tuple; ``Enum`` unpacks it into ``__init__`` when the member is created.
    ``num_classes`` is ``None`` where no class count is given (AFHQv2 here).
    """

    CIFAR10_32_32 = ('cifar10-32x32', 10, (3, 32, 32))
    ImageNet_64_64 = ('imagenet-64x64', 1000, (3, 64, 64))
    AFHQv2_64_64 = ('afhqv2-64x64', None, (3, 64, 64))

    def __init__(self, dataset_name: str, num_classes: Optional[int], image_shape: tuple[int, ...]) -> None:
        # Invoked once per member with the value tuple spread as arguments.
        self.image_shape: tuple[int, ...] = image_shape
        self.num_classes: Optional[int] = num_classes
        self.dataset_name: str = dataset_name

    @property
    def name(self) -> str:
        """Return the dataset identifier string (e.g. ``'cifar10-32x32'``).

        NOTE: this shadows ``Enum.name``; the Python member name
        (e.g. ``'CIFAR10_32_32'``) remains reachable through ``_name_``.
        """
        identifier = self.dataset_name
        return identifier

    @classmethod
    def create_config(cls, dataset_name: str) -> 'DatasetType':
        """Look up the member whose ``dataset_name`` equals *dataset_name*.

        Raises:
            ValueError: if no member carries that dataset name.
        """
        found = next((member for member in cls if member.dataset_name == dataset_name), None)
        if found is None:
            raise ValueError(f'unknown dataset name: {dataset_name}')
        return found


class DatasetConfig(BaseModel):
    """Immutable, strictly-validated selection of a single dataset by name.

    ``dataset_name`` must equal the ``dataset_name`` of some ``DatasetType``
    member; the derived properties below are read from that member.
    """

    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    dataset_name: str = Field()

    @model_validator(mode='after')
    def check_dataset_name(self) -> 'DatasetConfig':
        """Reject unknown dataset names when the config is built.

        Previously an invalid name was only detected on the first access of
        ``dataset_type`` / ``num_classes`` / ``image_shape``. The
        ``ValueError`` raised by ``DatasetType.create_config`` is surfaced by
        pydantic as a ``ValidationError`` (a ``ValueError`` subclass), so
        existing ``except ValueError`` handlers still catch it.
        """
        DatasetType.create_config(self.dataset_name)
        return self

    @property
    def dataset_type(self) -> DatasetType:
        """Return the ``DatasetType`` member matching ``dataset_name``."""
        return DatasetType.create_config(self.dataset_name)

    @property
    def num_classes(self) -> Optional[int]:
        """Return the member's ``num_classes`` (``None`` for e.g. AFHQv2)."""
        return self.dataset_type.num_classes

    @property
    def image_shape(self) -> tuple[int, ...]:
        """Return the member's ``image_shape`` tuple, e.g. ``(3, 32, 32)``."""
        return self.dataset_type.image_shape
