from typing import Generic, Optional, TypeVar

from pydantic import Field

from src.configs.numpy_folder import NumpyFolderNoiseImageSigmaConfig, NumpyFolderImageConfig
from src.datasets.image import NumpyImageDataset, ImageDataset
from src.datasets.lines import NumpyLinesUnconditionalDataset, LinesUnconditionalDataset
from src.datasets.utils import create_lines_unconditional_dataset, create_image_dataset
from src.train.style_gan_xl_discriminator_unconditional_trainer import StyleGANXLDiscriminatorUnconditionalTrainer, \
    Config as BaseConfig, StepStore
from utils.logger.logger import Logger
from utils.utils import get_class_name


class Config(BaseConfig):
    """Base trainer config extended with folder-based dataset definitions.

    ``train_dataset`` / ``test_dataset`` hold ``NumpyFolderNoiseImageSigmaConfig``
    entries; ``train_real_dataset`` / ``test_real_dataset`` hold
    ``NumpyFolderImageConfig`` entries. The ``test_*`` fields are optional and
    default to ``None``.
    """
    train_dataset: list[NumpyFolderNoiseImageSigmaConfig] = Field()
    # Optional: the default is None, so the annotation must admit None as well.
    test_dataset: Optional[list[NumpyFolderNoiseImageSigmaConfig]] = Field(default=None)
    train_real_dataset: list[NumpyFolderImageConfig] = Field()
    # Optional: the default is None, so the annotation must admit None as well.
    test_real_dataset: Optional[list[NumpyFolderImageConfig]] = Field(default=None)


# Type parameters of DiscriminatorUnconditionalFolderTrainer:
#   C - the concrete config type, bounded by this module's ``Config``;
#   S - the step-store type, bounded by ``StepStore`` from the base trainer module.
# TypeVars are declared without a variable annotation; the previous
# self-referential ``C: 'C'`` annotation was meaningless to type checkers.
C = TypeVar('C', bound=Config)
S = TypeVar('S', bound=StepStore)


class DiscriminatorUnconditionalFolderTrainer(StyleGANXLDiscriminatorUnconditionalTrainer[C, S], Generic[C, S]):
    """Discriminator trainer whose datasets are built from folder configs.

    At construction time it builds:

    * ``train_dataset`` / ``test_dataset`` with ``create_lines_unconditional_dataset``,
      from ``config.edm_scheduler``, ``config.edm_sampler``, ``config.dataset`` and
      ``config.train_dataset`` / ``config.test_dataset`` respectively;
    * ``train_real_dataset`` / ``test_real_dataset`` with ``create_image_dataset``,
      from ``config.dataset`` and ``config.train_real_dataset`` /
      ``config.test_real_dataset`` respectively.

    The test datasets are only built when the matching config entry is not
    ``None``; otherwise the corresponding property returns ``None``.
    """

    def __init__(self, config: C):
        """Initialise the base trainer, then build all datasets from ``self.config``.

        :param config: Trainer configuration (a ``Config`` subclass instance).
        """
        Logger.debug(f'{get_class_name(DiscriminatorUnconditionalFolderTrainer.__init__)} start')
        super().__init__(config)
        # Datasets are read from ``self.config`` (set up by the base class), not
        # from the raw ``config`` argument.
        self._train_dataset: NumpyLinesUnconditionalDataset = create_lines_unconditional_dataset(
            self.config.edm_scheduler,
            self.config.edm_sampler,
            self.config.dataset,
            self.config.train_dataset
        )
        # The test split is optional in the config; keep None when it is absent.
        self._test_dataset: Optional[NumpyLinesUnconditionalDataset] = create_lines_unconditional_dataset(
            self.config.edm_scheduler,
            self.config.edm_sampler,
            self.config.dataset,
            self.config.test_dataset
        ) if self.config.test_dataset is not None else None
        self._train_real_dataset: NumpyImageDataset = create_image_dataset(
            self.config.dataset,
            self.config.train_real_dataset
        )
        # The real test split is optional as well; keep None when it is absent.
        self._test_real_dataset: Optional[NumpyImageDataset] = create_image_dataset(
            self.config.dataset,
            self.config.test_real_dataset
        ) if self.config.test_real_dataset is not None else None
        Logger.debug(f'{get_class_name(DiscriminatorUnconditionalFolderTrainer.__init__)} end')

    @property
    def train_dataset(self) -> LinesUnconditionalDataset:
        """Training dataset built from ``config.train_dataset``."""
        return self._train_dataset

    @property
    def test_dataset(self) -> Optional[LinesUnconditionalDataset]:
        """Test dataset built from ``config.test_dataset``, or ``None`` if not configured."""
        return self._test_dataset

    @property
    def train_real_dataset(self) -> ImageDataset:
        """Image dataset built from ``config.train_real_dataset``."""
        return self._train_real_dataset

    @property
    def test_real_dataset(self) -> Optional[ImageDataset]:
        """Image dataset built from ``config.test_real_dataset``, or ``None`` if not configured."""
        return self._test_real_dataset
