from typing import TypeVar, Generic, Optional

from pydantic import Field

from src.configs.numpy_folder import NumpyFolderNoiseImageSigmaConfig, NumpyFolderNoiseConfig
from src.datasets.lines import NumpyLinesUnconditionalDataset, LinesUnconditionalDataset
from src.datasets.noise import NumpyNoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from src.datasets.utils import create_lines_unconditional_dataset, create_noise_dataset
from src.train.source_target_unconditional_trainer import SourceTargetUnconditionalTrainer, Config as BaseConfig, \
    StepStore
from utils.logger.logger import Logger
from utils.utils import get_class_name


class Config(BaseConfig):
    """Configuration for ``LinesUnconditionalFolderTrainer``.

    Extends the base source/target trainer config with the lists of numpy-folder
    configs from which the trainer builds its train/test lines datasets and its
    FID train/test noise datasets.
    """

    # Folder configs passed to ``create_lines_unconditional_dataset`` to build the
    # training lines dataset. ``Field()`` without a default makes it required.
    train_dataset: list[NumpyFolderNoiseImageSigmaConfig] = Field()
    # Folder configs for the test lines dataset. Optional: when left as ``None`` the
    # trainer builds no test dataset.
    test_dataset: Optional[list[NumpyFolderNoiseImageSigmaConfig]] = Field(default=None)
    # Folder configs passed to ``create_noise_dataset`` to build the FID train noise dataset (required).
    fid_train_dataset: list[NumpyFolderNoiseConfig] = Field()
    # Folder configs passed to ``create_noise_dataset`` to build the FID test noise dataset (required).
    fid_test_dataset: list[NumpyFolderNoiseConfig] = Field()


# Type parameters of the trainer: ``C`` is the concrete config type (a ``Config``
# subclass) and ``S`` the concrete step-store type (a ``StepStore`` subclass).
# TypeVars are assigned directly; annotating them with themselves is not valid typing.
C = TypeVar('C', bound=Config)
S = TypeVar('S', bound=StepStore)


class LinesUnconditionalFolderTrainer(SourceTargetUnconditionalTrainer[C, S], Generic[C, S]):
    """Unconditional source/target trainer whose datasets come from numpy folder configs.

    On construction it builds, from ``self.config``:

    * the train lines dataset and, only when ``config.test_dataset`` is set, the
      test lines dataset (via ``create_lines_unconditional_dataset``);
    * the FID train and FID test noise datasets (via ``create_noise_dataset``).

    The datasets are exposed through read-only properties.
    """

    def __init__(self, config: C):
        """Initialise the base trainer, then build every dataset described by the config.

        Args:
            config: Trainer configuration. It is passed to the base initializer, and
                the datasets are built from ``self.config`` afterwards.
        """
        Logger.debug(f'{get_class_name(LinesUnconditionalFolderTrainer.__init__)} start')
        super().__init__(config)
        self._train_dataset: NumpyLinesUnconditionalDataset = self._create_lines_dataset(
            self.config.train_dataset
        )
        # The test split is optional in the config; without it there is no test dataset.
        self._test_dataset: Optional[NumpyLinesUnconditionalDataset] = None
        if self.config.test_dataset is not None:
            self._test_dataset = self._create_lines_dataset(self.config.test_dataset)
        self._fid_train_dataset: NumpyNoiseDataset = create_noise_dataset(
            self.config.dataset,
            self.config.fid_train_dataset
        )
        self._fid_test_dataset: NumpyNoiseDataset = create_noise_dataset(
            self.config.dataset,
            self.config.fid_test_dataset
        )
        Logger.debug(f'{get_class_name(LinesUnconditionalFolderTrainer.__init__)} end')

    def _create_lines_dataset(
            self,
            folders: list[NumpyFolderNoiseImageSigmaConfig],
    ) -> NumpyLinesUnconditionalDataset:
        """Build a lines dataset from ``folders`` with the config's EDM scheduler/sampler and dataset settings."""
        return create_lines_unconditional_dataset(
            self.config.edm_scheduler,
            self.config.edm_sampler,
            self.config.dataset,
            folders,
        )

    @property
    def train_dataset(self) -> LinesUnconditionalDataset:
        """Training lines dataset built from ``config.train_dataset``."""
        return self._train_dataset

    @property
    def test_dataset(self) -> Optional[LinesUnconditionalDataset]:
        """Test lines dataset built from ``config.test_dataset``, or ``None`` if none was configured."""
        return self._test_dataset

    @property
    def fid_train_dataset(self) -> NoiseLabelDataset:
        """Noise dataset built from ``config.fid_train_dataset``."""
        return self._fid_train_dataset

    @property
    def fid_test_dataset(self) -> NoiseLabelDataset:
        """Noise dataset built from ``config.fid_test_dataset``."""
        return self._fid_test_dataset
