from typing import TypeVar, Generic, Optional

from pydantic import Field

from src.configs.numpy_folder import NumpyFolderNoiseImageLabelSigmaConfig, NumpyFolderNoiseLabelConfig
from src.datasets.lines import NumpyLinesConditionalDataset, LinesConditionalDataset
from src.datasets.noise_label import NumpyNoiseLabelDataset, NoiseLabelDataset
from src.datasets.utils import create_lines_conditional_dataset, create_noise_label_dataset
from src.train.source_target_conditional_trainer import SourceTargetConditionalTrainer, Config as BaseConfig, StepStore
from utils.logger.logger import Logger
from utils.utils import get_class_name


class Config(BaseConfig):
    # Configuration for LinesConditionalFolderTrainer: extends the base trainer
    # config with the dataset descriptions the trainer builds its datasets from.

    # Required. Passed to create_lines_conditional_dataset to build the
    # training dataset.
    train_dataset: list[NumpyFolderNoiseImageLabelSigmaConfig] = Field()
    # Optional (defaults to None). When None, the trainer builds no test
    # dataset and its `test_dataset` property returns None.
    test_dataset: Optional[list[NumpyFolderNoiseImageLabelSigmaConfig]] = Field(default=None)
    # Required. Passed to create_noise_label_dataset to build the dataset
    # exposed as `fid_train_dataset` (presumably used for FID evaluation -
    # TODO confirm).
    fid_train_dataset: list[NumpyFolderNoiseLabelConfig] = Field()
    # Required. Same as above, for the dataset exposed as `fid_test_dataset`.
    fid_test_dataset: list[NumpyFolderNoiseLabelConfig] = Field()


# Type parameters of the generic trainer below: C is bounded by this module's
# Config, S by the StepStore imported from the base trainer module.
# NOTE(review): the string self-annotations (`C: 'C'`) are unusual for TypeVar
# declarations; presumably a tooling hint - confirm before removing them.
C: 'C' = TypeVar('C', bound=Config)
S: 'S' = TypeVar('S', bound=StepStore)


class LinesConditionalFolderTrainer(SourceTargetConditionalTrainer[C, S], Generic[C, S]):
    """Source/target conditional trainer whose datasets are built from its config.

    On construction the trainer creates a training dataset, an optional test
    dataset and two noise/label datasets (``fid_train`` / ``fid_test``) through
    ``create_lines_conditional_dataset`` and ``create_noise_label_dataset``,
    and exposes each of them through a property.
    """

    def __init__(self, config: C):
        """Initialise the base trainer, then build every dataset.

        Args:
            config: Trainer configuration. It is forwarded to the base class;
                the datasets are then built from ``self.config``.
        """
        Logger.debug(f'{get_class_name(LinesConditionalFolderTrainer.__init__)} start')
        super().__init__(config)
        # Read the config back from the instance (as set up by the base class)
        # once, instead of repeating the attribute lookup for every argument.
        cfg = self.config
        self._train_dataset: NumpyLinesConditionalDataset = create_lines_conditional_dataset(
            cfg.edm_scheduler,
            cfg.edm_sampler,
            cfg.dataset,
            cfg.train_dataset,
        )
        # The test dataset is optional: it stays None when the config does not
        # describe one, hence the Optional annotation.
        self._test_dataset: Optional[NumpyLinesConditionalDataset] = (
            create_lines_conditional_dataset(
                cfg.edm_scheduler,
                cfg.edm_sampler,
                cfg.dataset,
                cfg.test_dataset,
            )
            if cfg.test_dataset is not None
            else None
        )
        self._fid_train_dataset: NumpyNoiseLabelDataset = create_noise_label_dataset(
            cfg.dataset,
            cfg.fid_train_dataset,
        )
        self._fid_test_dataset: NumpyNoiseLabelDataset = create_noise_label_dataset(
            cfg.dataset,
            cfg.fid_test_dataset,
        )
        Logger.debug(f'{get_class_name(LinesConditionalFolderTrainer.__init__)} end')

    @property
    def train_dataset(self) -> LinesConditionalDataset:
        """Training dataset built from ``config.train_dataset``."""
        return self._train_dataset

    @property
    def test_dataset(self) -> Optional[LinesConditionalDataset]:
        """Test dataset, or ``None`` when ``config.test_dataset`` is ``None``."""
        return self._test_dataset

    @property
    def fid_train_dataset(self) -> NoiseLabelDataset:
        """Noise/label dataset built from ``config.fid_train_dataset``."""
        return self._fid_train_dataset

    @property
    def fid_test_dataset(self) -> NoiseLabelDataset:
        """Noise/label dataset built from ``config.fid_test_dataset``."""
        return self._fid_test_dataset
