from typing import Generic, Optional, TypeVar

from pydantic import Field

from src.configs.numpy_folder import NumpyFolderNoiseImageLabelSigmaConfig, NumpyFolderImageLabelConfig
from src.datasets.image_label import NumpyImageLabelDataset, ImageLabelDataset
from src.datasets.lines import NumpyLinesConditionalDataset, LinesConditionalDataset
from src.datasets.utils import create_lines_conditional_dataset, create_image_label_dataset
from src.train.style_gan_xl_discriminator_conditional_trainer import StyleGANXLDiscriminatorConditionalTrainer, \
    Config as BaseConfig, StepStore
from utils.logger.logger import Logger
from utils.utils import get_class_name


class Config(BaseConfig):
    """Trainer configuration extended with folder-based dataset entries.

    Adds four dataset fields on top of the base trainer config. The two
    ``train_*`` fields are required; the two ``test_*`` fields are optional
    and default to ``None``, in which case the trainer builds no test dataset
    for them.
    """
    # Folder entries for the training dataset (required: no default given).
    train_dataset: list[NumpyFolderNoiseImageLabelSigmaConfig] = Field()
    # Folder entries for the test dataset; ``None`` means "no test dataset".
    # Declared Optional so an explicit ``None``/``null`` in the config validates.
    test_dataset: Optional[list[NumpyFolderNoiseImageLabelSigmaConfig]] = Field(default=None)
    # Folder entries for the "real" image/label training dataset (required).
    train_real_dataset: list[NumpyFolderImageLabelConfig] = Field()
    # Folder entries for the "real" image/label test dataset; ``None`` means
    # "no real test dataset".
    test_real_dataset: Optional[list[NumpyFolderImageLabelConfig]] = Field(default=None)


# Type parameters of the generic trainer defined below:
#   C - the type of the trainer's config, bound to Config (this module's subclass).
#   S - bound to StepStore; forwarded unchanged to the base trainer's generic parameters.
# NOTE(review): the self-referential string annotations (``C: 'C'``) are not required
# by ``typing``; presumably a workaround for an IDE/type checker - confirm before removing.
C: 'C' = TypeVar('C', bound=Config)
S: 'S' = TypeVar('S', bound=StepStore)


class DiscriminatorConditionalFolderTrainer(StyleGANXLDiscriminatorConditionalTrainer[C, S], Generic[C, S]):
    """Discriminator trainer whose datasets are built from the folder entries in its config.

    On construction it creates up to four datasets from ``self.config``:
    a train and an optional test dataset via ``create_lines_conditional_dataset``,
    and a train and an optional test "real" dataset via
    ``create_image_label_dataset``. They are exposed through read-only
    properties; the two test datasets are ``None`` when the corresponding
    config field is ``None``.
    """

    def __init__(self, config: C):
        """Initialise the base trainer and build the datasets described by ``config``.

        Args:
            config: Trainer configuration; forwarded to the base class and then
                read back through ``self.config``.
        """
        Logger.debug(f'{get_class_name(DiscriminatorConditionalFolderTrainer.__init__)} start')
        super().__init__(config)
        cfg = self.config

        # Datasets are created in the order train, test, train_real, test_real.
        self._train_dataset: NumpyLinesConditionalDataset = create_lines_conditional_dataset(
            cfg.edm_scheduler,
            cfg.edm_sampler,
            cfg.dataset,
            cfg.train_dataset,
        )

        # The test dataset is optional: it stays None when no test folders are configured.
        self._test_dataset: Optional[NumpyLinesConditionalDataset] = None
        if cfg.test_dataset is not None:
            self._test_dataset = create_lines_conditional_dataset(
                cfg.edm_scheduler,
                cfg.edm_sampler,
                cfg.dataset,
                cfg.test_dataset,
            )

        self._train_real_dataset: NumpyImageLabelDataset = create_image_label_dataset(
            cfg.dataset,
            cfg.train_real_dataset,
        )

        # The real test dataset is optional in the same way.
        self._test_real_dataset: Optional[NumpyImageLabelDataset] = None
        if cfg.test_real_dataset is not None:
            self._test_real_dataset = create_image_label_dataset(
                cfg.dataset,
                cfg.test_real_dataset,
            )
        Logger.debug(f'{get_class_name(DiscriminatorConditionalFolderTrainer.__init__)} end')

    @property
    def train_dataset(self) -> LinesConditionalDataset:
        """Dataset built from ``config.train_dataset``."""
        return self._train_dataset

    @property
    def test_dataset(self) -> Optional[LinesConditionalDataset]:
        """Dataset built from ``config.test_dataset``, or ``None`` if that field is ``None``."""
        return self._test_dataset

    @property
    def train_real_dataset(self) -> ImageLabelDataset:
        """Dataset built from ``config.train_real_dataset``."""
        return self._train_real_dataset

    @property
    def test_real_dataset(self) -> Optional[ImageLabelDataset]:
        """Dataset built from ``config.test_real_dataset``, or ``None`` if that field is ``None``."""
        return self._test_real_dataset
