from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Any, Generator, Callable, Optional

import numpy as np
import torch.distributed
import torch.utils.data
import torch.utils.data
from ema_pytorch import EMA
from pydantic import Field

from src.configs.amp import AMPConfig
from src.configs.edm_model import EDMModelConfig
from src.configs.edm_sampler import EDMSamplerConfig
from src.configs.edm_scheduler import EDMSchedulerConfig
from src.configs.ema import EMAConfig
from src.configs.fid import FIDConfig
from src.configs.loss import LossConfig
from src.configs.optimizer import OptimizerConfig
from src.configs.style_gan_xl_discriminator import StyleGANXLDiscriminatorConfig
from src.datasets.image_label import ImageLabelDataset
from src.datasets.noise_label import NoiseLabelDataset
from src.datasets.source_target_label_sigma import SourceTargetLabelSigmaDataset
from src.fid.single_step import inference_single_step
from src.fid.utils import calculate_fid_for_dataset_and_model as calculate_fid_for_dataset_and_model_helper
from src.metrics.average_metric import AverageMetric
from src.models.utils import count_parameters, count_trainable_parameters
from src.train.trainer import Config as BaseConfig, StepStore as BaseStepStore, Trainer
from src.utils.style_gan_xl.loss import StyleGANXLGeneratorLoss, StyleGANXLDiscriminatorLoss
from src.utils.utils import create_infinite_data_loader, create_one_hot, report_values, extract_model, free_all_memory
from src.utils.utils import create_one_hot_torch
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.distributed.utils import ddp_sync
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name, get_object_name


class Config(BaseConfig):
    """Configuration for the source/target StyleGAN-XL conditional trainer.

    Extends the base trainer configuration with the generator, discriminator,
    EMA, FID, loss and mixed-precision settings read by
    ``SourceTargetStyleGANXLConditionalTrainer``. Fields declared with a bare
    ``Field()`` are given no default here.
    """
    # EDM scheduler settings; ``sigma_max`` is the sigma used for FID inference.
    edm_scheduler: EDMSchedulerConfig = Field()
    # EDM sampler settings (not read directly by the trainer in this module).
    edm_sampler: EDMSamplerConfig = Field()
    # Generator model factory (``get_model()``).
    model: EDMModelConfig = Field()
    # Optimizer factory for the generator (``get_optimizer(...)``).
    model_optimizer: OptimizerConfig = Field()
    # Discriminator factory plus the augmentation settings passed to the
    # discriminator loss (``prob_aug``, ``shift_ratio``, ``cutout_ratio``).
    discriminator: StyleGANXLDiscriminatorConfig = Field()
    # Optimizer factory for the discriminator.
    discriminator_optimizer: OptimizerConfig = Field()
    # One entry per EMA copy of the generator that the trainer maintains.
    ema: list[EMAConfig] = Field()
    # FID evaluation settings (``steps``, ``reference_path``, ``batch_size``, ``dims``).
    fid: FIDConfig = Field()
    # Factory for the reconstruction loss module (``get_loss()``).
    reconstruction_loss: LossConfig = Field()
    # Multiplier applied to the reconstruction loss in the generator objective.
    reconstruction_lambda: float = Field()
    # When true, training targets are clipped to [-1, 1] before the loss.
    clip_train_data: bool = Field(default=False)
    # Mixed-precision switches (autocast, fp16, gradient scaler).
    amp: AMPConfig = Field()
    # Seed of the DistributedSampler over the source/target training dataset.
    distributed_sampler_seed: int = Field(default=0)
    # Seed of the DistributedSampler over the real-image training dataset.
    distributed_sampler_real_seed: int = Field(default=0)


class StepStore(BaseStepStore):
    """Step record carrying the generator and discriminator loss values of one training step."""
    reconstruction_loss: float = Field()
    model_gan_loss: float = Field()
    adaptive_weight: float = Field()
    model_loss: float = Field()
    discriminator_real_loss: float = Field()
    discriminator_fake_loss: float = Field()
    discriminator_loss: float = Field()

    def get_step_values(self) -> dict[str, Any]:
        """Return the base step values extended with this step's loss values.

        The discriminator entries are reported under ``gan_discriminator_*`` keys.
        """
        step_values: dict[str, Any] = dict(super().get_step_values())
        step_values.update({
            'reconstruction_loss': self.reconstruction_loss,
            'model_gan_loss': self.model_gan_loss,
            'adaptive_weight': self.adaptive_weight,
            'model_loss': self.model_loss,
            'gan_discriminator_real_loss': self.discriminator_real_loss,
            'gan_discriminator_fake_loss': self.discriminator_fake_loss,
            'gan_discriminator_loss': self.discriminator_loss
        })
        return step_values


# Type variables used to parameterize the trainer:
# C is bounded by this module's Config, S by this module's StepStore.
C: 'C' = TypeVar('C', bound=Config)
S: 'S' = TypeVar('S', bound=StepStore)


class SourceTargetStyleGANXLConditionalTrainer(Trainer[C, S], Generic[C, S], ABC):
    """Abstract trainer that fits a conditional generator against a StyleGAN-XL discriminator.

    The generator maps ``(source, sigma, one-hot label)`` batches to images and is
    trained with a reconstruction loss plus an adversarial loss; the discriminator
    is trained on real images versus generated ones. Subclasses supply the datasets
    through the abstract ``*_dataset`` properties. All attributes annotated below
    are created in ``start_callback``, not in ``__init__``.
    """
    # Generator network (wrapped in DistributedDataParallel when distributed).
    model: torch.nn.Module
    model_optimizer: torch.optim.Optimizer
    # Trainable discriminator and its frozen, eval-mode feature extractor.
    discriminator: torch.nn.Module
    discriminator_feature_extractor: torch.nn.Module
    discriminator_optimizer: torch.optim.Optimizer
    # Single gradient scaler shared by both optimizers.
    scaler: torch.cuda.amp.GradScaler
    # One EMA wrapper per entry of ``config.ema``.
    ema_models: list[EMA]
    # Lowest test FID seen so far for the raw model / for each EMA model (None until first evaluation).
    best_fid_value: Optional[float]
    best_fid_ema_values: list[Optional[float]]
    reconstruction_loss: torch.nn.Module
    model_gan_loss: StyleGANXLGeneratorLoss
    discriminator_loss: StyleGANXLDiscriminatorLoss
    # Running averages keyed by loss name.
    metrics: dict[str, AverageMetric]
    # Loader over the source/target dataset (sampler is None when not distributed).
    train_sampler: torch.utils.data.DistributedSampler
    train_data_loader: torch.utils.data.DataLoader
    train_infinite_data_loader: Generator[dict[str, torch.Tensor], None, None]
    # Loader over the real-image dataset (sampler is None when not distributed).
    train_real_sampler: torch.utils.data.DistributedSampler
    train_real_data_loader: torch.utils.data.DataLoader
    train_real_infinite_data_loader: Generator[dict[str, torch.Tensor], None, None]
    # Small fixed samples used for image logging; the test variants are None
    # when the corresponding test dataset is None.
    train_example_data: dict[str, np.ndarray]
    test_example_data: dict[str, np.ndarray]
    train_real_example_data: dict[str, np.ndarray]
    test_real_example_data: dict[str, np.ndarray]

    def __init__(self, config: C):
        """Store the configuration via the base trainer; heavy setup happens in ``start_callback``."""
        super().__init__(config)

    @property
    @abstractmethod
    def train_dataset(self) -> SourceTargetLabelSigmaDataset:
        """Source/target training dataset; concrete trainers must provide it."""
        message: str = 'train_dataset method must be implemented'
        raise NotImplementedError(message)

    @property
    @abstractmethod
    def test_dataset(self) -> SourceTargetLabelSigmaDataset:
        """Source/target test dataset; concrete trainers must provide it.

        ``start_callback`` also accepts ``None`` here, in which case no test
        examples are logged.
        """
        message: str = 'test_dataset method must be implemented'
        raise NotImplementedError(message)

    @property
    @abstractmethod
    def train_real_dataset(self) -> ImageLabelDataset:
        """Real-image training dataset fed to the discriminator; concrete trainers must provide it.

        Raises:
            NotImplementedError: if a subclass delegates to this base implementation.
        """
        # The message names this property (it previously said 'train_dataset').
        raise NotImplementedError('train_real_dataset method must be implemented')

    @property
    @abstractmethod
    def test_real_dataset(self) -> ImageLabelDataset:
        """Real-image test dataset; concrete trainers must provide it.

        ``start_callback`` also accepts ``None`` here, in which case no real
        test examples are logged.

        Raises:
            NotImplementedError: if a subclass delegates to this base implementation.
        """
        # The message names this property (it previously said 'test_dataset').
        raise NotImplementedError('test_real_dataset method must be implemented')

    @property
    @abstractmethod
    def fid_train_dataset(self) -> NoiseLabelDataset:
        """Noise/label dataset used to compute the train FID; concrete trainers must provide it."""
        message: str = 'fid_train_dataset method must be implemented'
        raise NotImplementedError(message)

    @property
    @abstractmethod
    def fid_test_dataset(self) -> NoiseLabelDataset:
        """Noise/label dataset used to compute the test FID; concrete trainers must provide it."""
        message: str = 'fid_test_dataset method must be implemented'
        raise NotImplementedError(message)

    def start_callback(self) -> None:
        """Build every training component before the first step.

        In order: validates the dataset properties, creates the generator and its
        EMA copies, wraps the generator for distributed training, creates the
        discriminator (plus frozen feature extractor) and both optimizers, the
        gradient scaler, the loss objects and metric accumulators, the two
        infinite data loaders, and finally samples and logs example data.
        """
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.start_callback)} start')
        super().start_callback()

        # Dataset sanity checks; the two test datasets are allowed to be None.
        assert isinstance(self.train_dataset, SourceTargetLabelSigmaDataset), \
            f'train_dataset must be an instance of SourceTargetLabelSigmaDataset: {self.train_dataset}'
        assert self.test_dataset is None or isinstance(self.test_dataset, SourceTargetLabelSigmaDataset), \
            f'test_dataset must be an instance of SourceTargetLabelSigmaDataset: {self.test_dataset}'
        assert isinstance(self.train_real_dataset, ImageLabelDataset), \
            f'train_real_dataset must be an instance of ImageLabelDataset: {self.train_real_dataset}'
        assert self.test_real_dataset is None or isinstance(self.test_real_dataset, ImageLabelDataset), \
            f'test_real_dataset must be an instance of ImageLabelDataset: {self.test_real_dataset}'
        assert isinstance(self.fid_train_dataset, NoiseLabelDataset), \
            f'fid_train_dataset must be an instance of NoiseLabelDataset: {self.fid_train_dataset}'
        assert isinstance(self.fid_test_dataset, NoiseLabelDataset), \
            f'fid_test_dataset must be an instance of NoiseLabelDataset: {self.fid_test_dataset}'

        # Generator, in train mode on the default device.
        self.model: torch.nn.Module = self.config.model.get_model().train().to(get_default_device())
        Logger.debug(f'model parameters: {count_parameters(self.model)}')
        Logger.debug(f'model trainable parameters: {count_trainable_parameters(self.model)}')

        # EMA wrappers are built from the not-yet-DDP-wrapped generator; their
        # averaged copies are parked on the CPU.
        self.ema_models: list[EMA] = []
        for i in range(len(self.config.ema)):
            ema: EMA = self.config.ema[i].get_ema(self.model)
            ema.ema_model.to('cpu')
            self.ema_models.append(ema)

        # Wrap for distributed training only when the distributed manager is initialized.
        self.model: torch.nn.Module = (
            torch.nn.parallel.DistributedDataParallel(self.model, find_unused_parameters=True)
            if DistributedManager.initialized else self.model
        )
        # The optimizer is created from the (possibly wrapped) generator.
        self.model_optimizer: torch.optim.Optimizer = self.config.model_optimizer.get_optimizer(
            self.model,
            batch_repeats=self.config.batch_repeats,
            learning_rate_batch_repeats=self.config.learning_rate_batch_repeats
        )

        Logger.debug(f'model: {get_object_name(self.model)}')
        Logger.debug(f'model_optimizer: {get_object_name(self.model_optimizer)}')

        # Discriminator is trainable; its feature extractor is frozen and kept in eval mode.
        # NOTE(review): image_shape[1] is presumably the image resolution - confirm against the dataset config.
        self.discriminator, self.discriminator_feature_extractor = \
            self.config.discriminator.get_discriminator(self.config.dataset.image_shape[1])
        self.discriminator.train().to(get_default_device())
        self.discriminator_feature_extractor.requires_grad_(False).eval().to(get_default_device())

        Logger.debug(f'discriminator parameters: {count_parameters(self.discriminator)}')
        Logger.debug(f'discriminator trainable parameters: {count_trainable_parameters(self.discriminator)}')
        Logger.debug(
            f'discriminator feature extractor parameters: {count_parameters(self.discriminator_feature_extractor)}')
        Logger.debug(
            f'discriminator feature extractor trainable parameters: '
            f'{count_trainable_parameters(self.discriminator_feature_extractor)}'
        )

        # Only the trainable discriminator is DDP-wrapped; the feature extractor is not.
        self.discriminator: torch.nn.Module = (
            torch.nn.parallel.DistributedDataParallel(self.discriminator, broadcast_buffers=False)
            if DistributedManager.initialized else self.discriminator
        )
        self.discriminator_optimizer: torch.optim.Optimizer = self.config.discriminator_optimizer.get_optimizer(
            self.discriminator,
            batch_repeats=self.config.batch_repeats,
            learning_rate_batch_repeats=self.config.learning_rate_batch_repeats
        )

        Logger.debug(f'discriminator: {get_object_name(self.discriminator)}')
        Logger.debug(f'discriminator_feature_extractor: {get_object_name(self.discriminator_feature_extractor)}')
        Logger.debug(f'discriminator_optimizer: {get_object_name(self.discriminator_optimizer)}')

        # One gradient scaler shared by both optimizers; active only with CUDA and the config flag.
        # NOTE(review): 'use_gard_scaler' is the attribute name as read here - verify the spelling in AMPConfig.
        self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available() and self.config.amp.use_gard_scaler)

        # No FID has been measured yet for the raw model or any EMA model.
        self.best_fid_value = None
        self.best_fid_ema_values = [None for _ in range(len(self.config.ema))]

        self.reconstruction_loss: torch.nn.Module = self.config.reconstruction_loss.get_loss()
        self.reconstruction_loss.to(get_default_device())

        # Loss helpers receive the (possibly DDP-wrapped) networks created above.
        self.model_gan_loss = StyleGANXLGeneratorLoss(
            model=self.model,
            discriminator_feature_extractor=self.discriminator_feature_extractor,
            discriminator=self.discriminator,
            reconstruction_loss=self.reconstruction_loss,
            dataset_type=self.config.dataset.dataset_type
        )
        self.discriminator_loss = StyleGANXLDiscriminatorLoss(
            discriminator_feature_extractor=self.discriminator_feature_extractor,
            discriminator=self.discriminator
        )

        # One running average per value produced by train_step.
        self.metrics = {
            'reconstruction_loss': AverageMetric(),
            'model_gan_loss': AverageMetric(),
            'adaptive_weight': AverageMetric(),
            'model_loss': AverageMetric(),
            'discriminator_real_loss': AverageMetric(),
            'discriminator_fake_loss': AverageMetric(),
            'discriminator_loss': AverageMetric()
        }

        Logger.info('creating data loaders')
        # Source/target loader. When distributed, a DistributedSampler does the
        # shuffling and config.batch_size is divided by the world size;
        # otherwise the DataLoader shuffles itself.
        self.train_sampler = torch.utils.data.DistributedSampler(
            dataset=self.train_dataset,
            num_replicas=DistributedManager.world_size,
            rank=DistributedManager.rank,
            shuffle=True,
            drop_last=False,
            seed=self.config.distributed_sampler_seed
        ) if DistributedManager.initialized else None
        self.train_data_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            sampler=self.train_sampler,
            batch_size=self.config.batch_size // DistributedManager.world_size
            if DistributedManager.initialized else self.config.batch_size,
            shuffle=True if self.train_sampler is None else None,
            num_workers=self.config.data_loader_workers,
            drop_last=False
        )
        self.train_infinite_data_loader = create_infinite_data_loader(self.train_data_loader, self.train_sampler)

        # Real-image loader, built the same way but with its own sampler seed.
        self.train_real_sampler = torch.utils.data.DistributedSampler(
            dataset=self.train_real_dataset,
            num_replicas=DistributedManager.world_size,
            rank=DistributedManager.rank,
            shuffle=True,
            drop_last=False,
            seed=self.config.distributed_sampler_real_seed
        ) if DistributedManager.initialized else None
        self.train_real_data_loader = torch.utils.data.DataLoader(
            dataset=self.train_real_dataset,
            sampler=self.train_real_sampler,
            batch_size=self.config.batch_size // DistributedManager.world_size
            if DistributedManager.initialized else self.config.batch_size,
            shuffle=True if self.train_real_sampler is None else None,
            num_workers=self.config.data_loader_workers,
            drop_last=False
        )
        self.train_real_infinite_data_loader = \
            create_infinite_data_loader(self.train_real_data_loader, self.train_real_sampler)

        Logger.info('creating example data')
        # The first 10 items of each dataset serve as fixed examples.
        # NOTE(review): indices 0..9 are read unconditionally - assumes every dataset has at least 10 items.
        self.train_example_data = self.train_dataset.merge_data([self.train_dataset[i] for i in range(10)])
        self.test_example_data = self.test_dataset.merge_data([self.test_dataset[i] for i in range(10)]) \
            if self.test_dataset is not None else None
        self.train_real_example_data = self.train_real_dataset.merge_data(
            [self.train_real_dataset[i] for i in range(10)])
        self.test_real_example_data = self.test_real_dataset.merge_data([self.test_real_dataset[i] for i in range(10)]) \
            if self.test_real_dataset is not None else None

        # Log summary statistics of each example array for debugging.
        Logger.debug(str({
            'train_example_data':
                {key: get_numpy_stats(self.train_example_data[key]) for key in self.train_example_data}
        }))
        Logger.debug(str({
            'train_real_example_data':
                {key: get_numpy_stats(self.train_real_example_data[key]) for key in self.train_real_example_data}
        }))

        if self.test_example_data is not None:
            Logger.debug(str({
                'test_example_data':
                    {key: get_numpy_stats(self.test_example_data[key]) for key in self.test_example_data}
            }))

        if self.test_real_example_data is not None:
            Logger.debug(str({
                'test_real_example_data':
                    {key: get_numpy_stats(self.test_real_example_data[key]) for key in self.test_real_example_data}
            }))

        Logger.info('logging images')
        # Generator outputs for the example inputs, followed by the reference
        # images themselves ('target' / 'image'), all logged at step 0.
        self.log_example_data()
        self.tensorboard_logger.log_images(
            tag='dataset/train_example_data',
            images=self.train_example_data['target'],
            step=0
        )
        if self.test_example_data is not None:
            self.tensorboard_logger.log_images(
                tag='dataset/test_example_data',
                images=self.test_example_data['target'],
                step=0
            )
        self.tensorboard_logger.log_images(
            tag='dataset/train_real_example_data',
            images=self.train_real_example_data['image'],
            step=0
        )
        if self.test_real_example_data is not None:
            self.tensorboard_logger.log_images(
                tag='dataset/test_real_example_data',
                images=self.test_real_example_data['image'],
                step=0
            )
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.start_callback)} end')

    def log_images_model(
            self,
            label: str,
            step: int,
            sources: np.ndarray,
            sigmas: np.ndarray,
            labels: Optional[np.ndarray] = None,
    ) -> None:
        """Run the generator on example inputs and log the resulting images.

        The model is switched to eval mode for the forward pass and is always
        put back into train mode afterwards, even if inference fails, so that a
        failure here cannot leave the trainer with a model in eval mode.

        Args:
            label: tag under which the images are logged.
            step: step value attached to the logged images.
            sources: generator source inputs.
            sigmas: sigma values matching ``sources``.
            labels: optional class labels; converted to one-hot vectors with
                ``config.dataset.num_classes`` classes. ``None`` passes no
                label to the model.
        """
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.log_images_model)} start')
        device = get_default_device()
        self.model.eval()
        try:
            with torch.inference_mode():
                one_hot_labels: Optional[torch.Tensor] = (
                    torch.from_numpy(create_one_hot(labels, self.config.dataset.num_classes)).to(device)
                    if labels is not None else None
                )
                images: np.ndarray = self.model(
                    torch.from_numpy(sources).to(device),
                    torch.from_numpy(sigmas).to(device),
                    one_hot_labels
                ).detach().cpu().numpy()
        finally:
            # train_step asserts that the model is in training mode, so restore
            # it no matter how the forward pass ended.
            self.model.train()
        self.tensorboard_logger.log_images(label, images, step)
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.log_images_model)} end')

    def update_ema_callback(self, _: StepStore) -> None:
        """Advance every tracked EMA model by one update; the step record is not used."""
        for ema_model in self.ema_models:
            ema_model.update()

    def check_consistency(self, func: Callable[[torch.nn.Module], bool]) -> bool:
        """Apply ``func`` to the generator, discriminator and feature extractor, in that order.

        Evaluation stops at the first module for which ``func`` is falsy.
        Memory is freed before the result is logged and returned.
        """
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.check_consistency)} start')
        checked_modules = (self.model, self.discriminator, self.discriminator_feature_extractor)
        # The generator keeps the evaluation lazy, so later modules are skipped after a failure.
        value: bool = all(func(module) for module in checked_modules)
        free_all_memory()
        Logger.debug(
            f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.check_consistency)} end - value: {value}')
        return value

    def load_from_folder(self, folder: str) -> None:
        """Restore networks, optimizers and EMA models from the ``.pth`` files in ``folder``.

        Every file is loaded with the default device as ``map_location``. EMA
        weights are read from ``ema/<name>.pth``, where ``<name>`` comes from
        the matching ``config.ema`` entry.
        """
        Logger.debug(
            f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.load_from_folder)} start - folder: {folder}')
        device = get_default_device()

        generator = extract_model(self.model)
        generator.load_state_dict(torch.load(f'{folder}/model.pth', map_location=device))

        self.model_optimizer.load_state_dict(torch.load(f'{folder}/model_optimizer.pth', map_location=device))

        discriminator = extract_model(self.discriminator).discriminator
        discriminator.load_state_dict(torch.load(f'{folder}/discriminator.pth', map_location=device))

        feature_extractor = extract_model(self.discriminator_feature_extractor).discriminator_feature_extractor
        feature_extractor.load_state_dict(
            torch.load(f'{folder}/discriminator_feature_extractor.pth', map_location=device))

        self.discriminator_optimizer.load_state_dict(
            torch.load(f'{folder}/discriminator_optimizer.pth', map_location=device))

        for index, ema in enumerate(self.ema_models):
            ema.ema_model.load_state_dict(
                torch.load(f'{folder}/ema/{self.config.ema[index].get_str()}.pth', map_location=device))
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.load_from_folder)} end')

    def get_state_dicts(self, step: int) -> dict[str, dict]:
        """Collect the state dicts to checkpoint, keyed by checkpoint name.

        Starts from the base trainer's entries and adds the generator, both
        optimizers, the discriminator, its feature extractor and one
        ``ema/<name>`` entry per EMA model.
        """
        Logger.debug(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.get_state_dicts)} start - step: {step}')
        state_dicts: dict[str, dict] = dict(super().get_state_dicts(step))
        state_dicts['model'] = extract_model(self.model).state_dict()
        state_dicts['model_optimizer'] = self.model_optimizer.state_dict()
        state_dicts['discriminator'] = extract_model(self.discriminator).discriminator.state_dict()
        state_dicts['discriminator_feature_extractor'] = \
            extract_model(self.discriminator_feature_extractor).discriminator_feature_extractor.state_dict()
        state_dicts['discriminator_optimizer'] = self.discriminator_optimizer.state_dict()
        for index, ema in enumerate(self.ema_models):
            state_dicts[f'ema/{self.config.ema[index].get_str()}'] = ema.ema_model.state_dict()
        return state_dicts

    def model_sub_step(self) -> dict[str, float]:
        """Run one generator forward/backward pass on the next source/target batch.

        The objective is ``reconstruction_lambda * reconstruction_loss +
        adaptive_weight * model_gan_loss``. Gradients are accumulated through
        the gradient scaler; no optimizer step is taken here.

        Returns:
            The scalar values of the reconstruction loss, GAN loss, adaptive
            weight and combined generator loss for this batch.
        """
        device = get_default_device()
        batch: dict[str, torch.Tensor] = next(self.train_infinite_data_loader)
        one_hot_label: torch.Tensor = create_one_hot_torch(batch['label'], self.config.dataset.num_classes).to(device)

        autocast_enabled: bool = torch.cuda.is_available() and self.config.amp.use_autocast
        autocast_dtype = torch.float16 if self.config.amp.use_fp16 else torch.float32
        with torch.autocast(enabled=autocast_enabled, device_type=device, dtype=autocast_dtype):
            fake: torch.Tensor = self.model(
                batch['source'].to(device),
                batch['sigma'].to(device),
                one_hot_label
            )

        target: torch.Tensor = batch['target'].to(torch.float32).to(device)
        if self.config.clip_train_data:
            target = torch.clip(target, -1, 1)

        losses: dict[str, torch.Tensor] = self.model_gan_loss.get_generator_loss(
            fake=fake,
            target=target,
            label=one_hot_label
        )
        reconstruction_loss: torch.Tensor = losses['reconstruction_loss']
        model_gan_loss: torch.Tensor = losses['model_gan_loss']
        adaptive_weight: torch.Tensor = losses['adaptive_weight']
        model_loss: torch.Tensor = \
            self.config.reconstruction_lambda * reconstruction_loss + adaptive_weight * model_gan_loss

        # Unless the learning rate already accounts for the repeats, average
        # the accumulated gradient over the batch repeats.
        if self.config.learning_rate_batch_repeats:
            backward_loss: torch.Tensor = model_loss
        else:
            backward_loss = model_loss / self.config.batch_repeats
        self.scaler.scale(backward_loss).backward()

        return {
            'reconstruction_loss': reconstruction_loss.item(),
            'model_gan_loss': model_gan_loss.item(),
            'adaptive_weight': adaptive_weight.item(),
            'model_loss': model_loss.item()
        }

    def discriminator_sub_step(self) -> dict[str, float]:
        """Run one discriminator forward/backward pass on a real and a generated batch.

        Fake images are produced by the generator without gradient tracking,
        then the discriminator loss is computed on real versus fake images and
        back-propagated through the gradient scaler. No optimizer step is
        taken here.

        Returns:
            The scalar real loss, fake loss and their sum for this batch.
        """
        batch_data: dict[str, torch.Tensor] = next(self.train_infinite_data_loader)
        batch_real_data: dict[str, torch.Tensor] = next(self.train_real_infinite_data_loader)
        label: torch.Tensor = \
            create_one_hot_torch(batch_data['label'], self.config.dataset.num_classes).to(get_default_device())
        label_real: torch.Tensor = \
            create_one_hot_torch(batch_real_data['label'], self.config.dataset.num_classes).to(get_default_device())
        with torch.autocast(
                enabled=torch.cuda.is_available() and self.config.amp.use_autocast,
                device_type=get_default_device(),
                dtype=torch.float16 if self.config.amp.use_fp16 else torch.float32
        ):
            # The generator only supplies fake samples here; it receives no gradients.
            with torch.inference_mode():
                generated: torch.Tensor = self.model(
                    batch_data['source'].to(get_default_device()),
                    batch_data['sigma'].to(get_default_device()),
                    label
                )
        loss_values: dict[str, torch.Tensor] = self.discriminator_loss.get_discriminator_loss(
            real=batch_real_data['image'].to(torch.float32).to(get_default_device()),
            real_label=label_real,
            fake=generated.detach(),
            fake_label=label,
            prob_aug=self.config.discriminator.prob_aug,
            shift_ratio=self.config.discriminator.shift_ratio,
            cutout_ratio=self.config.discriminator.cutout_ratio
        )
        # Each reported value is read from the key of the same name; they were
        # previously crossed, which swapped the real/fake metrics (the summed
        # loss, and therefore the gradients, are unaffected).
        discriminator_real_loss: torch.Tensor = loss_values['real_loss']
        discriminator_fake_loss: torch.Tensor = loss_values['fake_loss']
        discriminator_loss: torch.Tensor = discriminator_real_loss + discriminator_fake_loss
        # Unless the learning rate already accounts for the repeats, average
        # the accumulated gradient over the batch repeats.
        self.scaler.scale(
            discriminator_loss if self.config.learning_rate_batch_repeats
            else discriminator_loss / self.config.batch_repeats
        ).backward()
        return {
            'discriminator_real_loss': discriminator_real_loss.item(),
            'discriminator_fake_loss': discriminator_fake_loss.item(),
            'discriminator_loss': discriminator_loss.item()
        }

    def train_step(self, step: int) -> S:
        """Perform one full training step: a discriminator update, then a generator update.

        Each update zeroes its optimizer, accumulates gradients over
        ``config.batch_repeats`` sub-steps (``ddp_sync`` receives ``True`` only
        on the last repeat) and steps the optimizer through the shared gradient
        scaler. The scaler is updated once at the end. Per-sub-step values are
        averaged, added to the running metrics and returned as a step record.
        """
        assert self.model.training, f'model must be in training mode: {self.model.training}'
        assert not self.discriminator_feature_extractor.training, \
            f'discriminator feature extractor must be in eval mode: {self.discriminator_feature_extractor.training}'
        assert self.discriminator.training, f'discriminator must be in training mode: {self.discriminator.training}'

        model_keys = ('reconstruction_loss', 'model_gan_loss', 'adaptive_weight', 'model_loss')
        discriminator_keys = ('discriminator_real_loss', 'discriminator_fake_loss', 'discriminator_loss')
        collected: dict[str, list[float]] = {key: [] for key in model_keys + discriminator_keys}
        repeats: int = self.config.batch_repeats

        def accumulate_and_step(module, optimizer, sub_step, keys) -> None:
            # One optimizer update built from `repeats` accumulated sub-steps.
            optimizer.zero_grad()
            for repeat in range(repeats):
                with ddp_sync(module, repeat == repeats - 1):
                    if self.config.free_memory_every_sub_step:
                        free_all_memory()
                    sub_values: dict[str, float] = sub_step()
                    for key in keys:
                        collected[key].append(sub_values[key])
            self.scaler.step(optimizer)

        accumulate_and_step(
            self.discriminator, self.discriminator_optimizer, self.discriminator_sub_step, discriminator_keys)
        accumulate_and_step(self.model, self.model_optimizer, self.model_sub_step, model_keys)

        self.scaler.update()

        averaged: dict[str, float] = {key: float(np.mean(samples)) for key, samples in collected.items()}
        for key, value in averaged.items():
            self.metrics[key].add(value)

        return StepStore(step=step, **averaged)

    def reset_values(self) -> dict[str, Any]:
        """Reset every running metric and return the base values merged with each ``reset()`` result."""
        merged: dict[str, Any] = dict(super().reset_values())
        merged.update({name: metric.reset() for name, metric in self.metrics.items()})
        return merged

    def log_example_data(self, step: int = 0) -> None:
        """Log generator outputs for the fixed train (and, if present, test) examples.

        Args:
            step: step value attached to the logged images. Defaults to 0,
                which is what callers that pass no argument get.
        """
        self.log_images_model(
            label='model/train_example_data',
            step=step,
            sources=self.train_example_data['source'],
            labels=self.train_example_data['label'],
            sigmas=self.train_example_data['sigma']
        )
        if self.test_example_data is not None:
            self.log_images_model(
                label='model/test_example_data',
                step=step,
                sources=self.test_example_data['source'],
                labels=self.test_example_data['label'],
                sigmas=self.test_example_data['sigma']
            )

    def log_images_callback(self, step_store: StepStore) -> None:
        """Log example generator outputs every ``config.report.log_images_steps`` steps.

        Does nothing when ``log_images_steps`` is ``None`` or the current
        (1-based) step is not a multiple of it.
        """
        log_images_steps = self.config.report.log_images_steps
        current_step: int = step_store.step + 1
        if log_images_steps is None or current_step % log_images_steps != 0:
            return
        Logger.debug(
            f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.log_images_callback)} {current_step} start')
        # Log under the current step; these images were previously always
        # written at step 0 regardless of training progress.
        self.log_example_data(step=current_step)
        Logger.debug(
            f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.log_images_callback)} {current_step} end')

    def calculate_fid_for_dataset_and_model(
            self,
            model: torch.nn.Module,
            dataset: NoiseLabelDataset
    ) -> float:
        """Compute the FID of ``model`` on ``dataset`` using single-step inference.

        Images are generated with ``inference_single_step`` at
        ``config.edm_scheduler.sigma_max`` and compared against the reference
        statistics at ``config.fid.reference_path``.

        Args:
            model: network to evaluate (the raw generator or an EMA copy).
            dataset: noise/label dataset providing the generation inputs.

        Returns:
            The FID value, or ``float('inf')`` if the computation failed. FID is
            lower-is-better, so a failure must not look like a perfect score
            (it previously returned ``0.0``).
        """
        try:
            Logger.debug(
                f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.calculate_fid_for_dataset_and_model)} start')
            fid: float = calculate_fid_for_dataset_and_model_helper(
                model=model,
                dataset=dataset,
                reference_path=self.config.fid.reference_path,
                inference_batch_func=lambda m, n, l: inference_single_step(
                    model=m,
                    noises=n,
                    labels=l,
                    num_classes=self.config.dataset.num_classes,
                    sigma=self.config.edm_scheduler.sigma_max
                ),
                batch_size=self.config.fid.batch_size,
                dims=self.config.fid.dims
            )
            Logger.debug(
                f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.calculate_fid_for_dataset_and_model)} end - fid: {fid}')
            return fid
        except Exception as e:
            # Best-effort evaluation: log the failure and keep training.
            Logger.exception(
                f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.calculate_fid_for_dataset_and_model)} {e}',
                e)
            # Put the model back into train mode in case the failure happened mid-evaluation.
            model.train()
            # Worst possible score, so a failed run can never be selected as "best".
            return float('inf')

    def fid_callback(self, step_store: StepStore) -> None:
        """Evaluate FID every ``config.fid.steps`` steps and checkpoint the best models.

        Train and test FID are computed for the raw generator and then for each
        EMA model, which is moved to the default device for evaluation and back
        to the CPU afterwards. Values are reported to TensorBoard. Whenever a
        finite test FID improves on the best value seen so far, a ``best`` (or
        ``best_<ema name>``) checkpoint is saved. Non-finite values (NaN/inf)
        are never recorded as best, because a NaN best would block every later
        update. Any exception is logged and swallowed, and the EMA models are
        returned to the CPU.
        """
        try:
            if self.config.fid.steps is None or (step_store.step + 1) % self.config.fid.steps != 0:
                return
            current_step: int = step_store.step + 1
            Logger.debug(
                f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.fid_callback)} {current_step} start')
            free_all_memory()
            train_fid: float = self.calculate_fid_for_dataset_and_model(
                model=self.model,
                dataset=self.fid_train_dataset
            )
            free_all_memory()
            test_fid: float = self.calculate_fid_for_dataset_and_model(
                model=self.model,
                dataset=self.fid_test_dataset
            )
            free_all_memory()
            report_values(
                current_step,
                {'fid': {'train': train_fid, 'test': test_fid}},
                self.tensorboard_logger
            )
            if np.isfinite(test_fid) and (self.best_fid_value is None or test_fid < self.best_fid_value):
                self.best_fid_value = test_fid
                Logger.info(f'saving best checkpoint {current_step}')
                self.save_checkpoint(f'{self.config.base_folder}/{self.config.checkpoint.folder}/best',
                                     current_step)

            for i, ema in enumerate(self.ema_models):
                ema_config = self.config.ema[i]
                # EMA copies live on the CPU; bring one onto the device only while it is evaluated.
                free_all_memory()
                ema.ema_model.to(get_default_device())
                free_all_memory()
                train_fid_ema: float = self.calculate_fid_for_dataset_and_model(
                    model=ema.ema_model,
                    dataset=self.fid_train_dataset
                )
                free_all_memory()
                test_fid_ema: float = self.calculate_fid_for_dataset_and_model(
                    model=ema.ema_model,
                    dataset=self.fid_test_dataset
                )
                free_all_memory()
                ema.ema_model.to('cpu')
                free_all_memory()

                report_values(current_step, {
                    'fid': {
                        'train': {'ema': {ema_config.get_str(): train_fid_ema}},
                        'test': {'ema': {ema_config.get_str(): test_fid_ema}}
                    }
                }, self.tensorboard_logger)
                best_ema_fid: Optional[float] = self.best_fid_ema_values[i]
                if np.isfinite(test_fid_ema) and (best_ema_fid is None or test_fid_ema < best_ema_fid):
                    self.best_fid_ema_values[i] = test_fid_ema
                    Logger.info(f'saving best checkpoint {current_step} {ema_config.get_tuple()}')
                    self.save_checkpoint(
                        f'{self.config.base_folder}/{self.config.checkpoint.folder}/best_{ema_config.get_str()}',
                        current_step
                    )
            Logger.info(f'best test fid value {current_step}: {self.best_fid_value}')
            Logger.info(f'best test fid ema values {current_step}: {self.best_fid_ema_values}')
            Logger.debug(
                f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.fid_callback)} {current_step} end')
        except Exception as e:
            # FID evaluation is best-effort: log and continue training.
            Logger.exception(f'{get_class_name(SourceTargetStyleGANXLConditionalTrainer.fid_callback)} {e}', e)
            # Make sure no EMA copy is left on the device after a failure.
            for ema in self.ema_models:
                ema.ema_model.to('cpu')

    def callbacks(self) -> list[Callable[[S], None]]:
        """Return the per-step callbacks in the order they are listed.

        EMA updates come before checkpointing, and image logging and FID
        evaluation come after it; memory is freed last.
        """
        ordered_callbacks: list[Callable[[S], None]] = [
            self.report_train_loss,
            self.update_ema_callback,
            self.save_step_checkpoint_steps_callback,
            self.save_last_checkpoint_steps_callback,
            self.validate_ddp_consistency_steps_callback,
            self.log_images_callback,
            self.fid_callback,
            self.free_all_memory_callback
        ]
        return ordered_callbacks
