from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Any, Generator, Callable, Optional

import numpy as np
import torch.distributed
import torch.utils.data
import torch.utils.data
from pydantic import Field

from src.configs.amp import AMPConfig
from src.configs.edm_model import EDMModelConfig
from src.configs.edm_sampler import EDMSamplerConfig
from src.configs.edm_scheduler import EDMSchedulerConfig
from src.configs.optimizer import OptimizerConfig
from src.configs.style_gan_xl_discriminator import StyleGANXLDiscriminatorConfig
from src.datasets.image import ImageDataset
from src.datasets.source_target_sigma import SourceTargetSigmaDataset
from src.metrics.average_metric import AverageMetric
from src.models.utils import count_parameters, count_trainable_parameters
from src.train.trainer import Config as BaseConfig, StepStore as BaseStepStore, Trainer
from src.utils.style_gan_xl.loss import StyleGANXLDiscriminatorLoss
from src.utils.utils import create_infinite_data_loader, create_one_hot, extract_model, free_all_memory
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.distributed.utils import ddp_sync
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name, get_object_name


class Config(BaseConfig):
    # Settings for the StyleGAN-XL discriminator trainer, layered on top of the base trainer config.
    # NOTE(review): edm_scheduler / edm_sampler are not read by the trainer in this module; presumably they are
    # consumed by subclasses or the base trainer — confirm before removing.
    edm_scheduler: EDMSchedulerConfig
    edm_sampler: EDMSamplerConfig
    # Builds the EDM model that the trainer freezes and uses to produce fake images.
    model: EDMModelConfig
    # Builds the discriminator / feature-extractor pair and holds its augmentation knobs.
    discriminator: StyleGANXLDiscriminatorConfig
    # Optimizer applied to the discriminator parameters only.
    discriminator_optimizer: OptimizerConfig
    # Mixed-precision switches (autocast, fp16, gradient scaler).
    amp: AMPConfig
    # Seed of the DistributedSampler over the source/sigma (fake) dataset.
    distributed_sampler_seed: int = 0
    # Seed of the DistributedSampler over the real-image dataset.
    distributed_sampler_real_seed: int = 0


class StepStore(BaseStepStore):
    # Discriminator losses recorded for a single training step, in addition to the base step fields.
    discriminator_real_loss: float
    discriminator_fake_loss: float
    discriminator_loss: float

    def get_step_values(self) -> dict[str, Any]:
        """Return the base step values extended with the discriminator losses under ``gan_``-prefixed keys.

        The base values come first; the three loss entries are added after them (overriding any existing key of the
        same name in place).
        """
        step_values: dict[str, Any] = dict(super().get_step_values())
        step_values.update(
            gan_discriminator_real_loss=self.discriminator_real_loss,
            gan_discriminator_fake_loss=self.discriminator_fake_loss,
            gan_discriminator_loss=self.discriminator_loss,
        )
        return step_values


# Type parameters of the trainer: the concrete config type and the concrete step-store type a subclass uses.
# Plain assignments: annotating a TypeVar with itself ('C') carries no meaning and confuses type checkers.
C = TypeVar('C', bound=Config)
S = TypeVar('S', bound=StepStore)


class StyleGANXLDiscriminatorUnconditionalTrainer(Trainer[C, S], Generic[C, S], ABC):
    """Train a StyleGAN-XL discriminator against a frozen EDM model.

    Fake images come from the frozen EDM model (``self.model``) applied to the ``source``/``sigma`` entries of
    :attr:`train_dataset`; real images are the ``image`` entries of :attr:`train_real_dataset`. Only the
    discriminator is optimized: the EDM model and the discriminator feature extractor are kept in eval mode with
    ``requires_grad`` disabled. The model is called without class labels, hence "unconditional".

    Subclasses must implement the four dataset properties.
    """

    # Frozen EDM model acting as the generator of fake images.
    model: torch.nn.Module
    # Trainable discriminator; wrapped in DistributedDataParallel when DistributedManager is initialized.
    discriminator: torch.nn.Module
    # Frozen feature extractor created together with the discriminator.
    discriminator_feature_extractor: torch.nn.Module
    discriminator_optimizer: torch.optim.Optimizer
    # Enabled only when CUDA is available and the AMP config enables the grad scaler.
    scaler: torch.cuda.amp.GradScaler
    discriminator_loss: StyleGANXLDiscriminatorLoss
    # Running averages of the step losses, keyed like the dict returned by sub_step().
    metrics: dict[str, AverageMetric]
    # Samplers are None when DistributedManager is not initialized.
    train_sampler: Optional[torch.utils.data.DistributedSampler]
    train_data_loader: torch.utils.data.DataLoader
    train_infinite_data_loader: Generator[dict[str, torch.Tensor], None, None]
    train_real_sampler: Optional[torch.utils.data.DistributedSampler]
    train_real_data_loader: torch.utils.data.DataLoader
    train_real_infinite_data_loader: Generator[dict[str, torch.Tensor], None, None]
    # Small merged batches used for logging; the test variants are None when there is no test dataset.
    train_example_data: dict[str, np.ndarray]
    test_example_data: Optional[dict[str, np.ndarray]]
    train_real_example_data: dict[str, np.ndarray]
    test_real_example_data: Optional[dict[str, np.ndarray]]

    # Upper bound on how many items of each dataset are merged into its example data.
    example_data_size: int = 10

    def __init__(self, config: C):
        super().__init__(config)

    @property
    @abstractmethod
    def train_dataset(self) -> SourceTargetSigmaDataset:
        """Dataset whose ``source``/``sigma`` items are turned into fake images by the frozen model."""
        raise NotImplementedError('train_dataset method must be implemented')

    @property
    @abstractmethod
    def test_dataset(self) -> Optional[SourceTargetSigmaDataset]:
        """Optional held-out counterpart of :attr:`train_dataset`; in this class it is only used for logging."""
        raise NotImplementedError('test_dataset method must be implemented')

    @property
    @abstractmethod
    def train_real_dataset(self) -> ImageDataset:
        """Dataset whose ``image`` items are the real samples shown to the discriminator."""
        raise NotImplementedError('train_real_dataset method must be implemented')

    @property
    @abstractmethod
    def test_real_dataset(self) -> Optional[ImageDataset]:
        """Optional held-out counterpart of :attr:`train_real_dataset`; in this class it is only used for logging."""
        raise NotImplementedError('test_real_dataset method must be implemented')

    def start_callback(self) -> None:
        """Set up everything the training loop needs, after the base trainer's own start callback.

        Validates the dataset properties, builds the frozen EDM model, the discriminator (DDP-wrapped when
        distributed), its optimizer, the gradient scaler, the loss and the metrics, creates the infinite data loaders
        over the fake-source and real datasets, and logs example data and images at step 0.
        """
        Logger.debug(f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.start_callback)} start')
        super().start_callback()

        assert isinstance(self.train_dataset, SourceTargetSigmaDataset), \
            f'train_dataset must be an instance of SourceTargetSigmaDataset: {self.train_dataset}'
        assert self.test_dataset is None or isinstance(self.test_dataset, SourceTargetSigmaDataset), \
            f'test_dataset must be an instance of SourceTargetSigmaDataset: {self.test_dataset}'
        assert isinstance(self.train_real_dataset, ImageDataset), \
            f'train_real_dataset must be an instance of ImageDataset: {self.train_real_dataset}'
        assert self.test_real_dataset is None or isinstance(self.test_real_dataset, ImageDataset), \
            f'test_real_dataset must be an instance of ImageDataset: {self.test_real_dataset}'

        # The EDM model only generates fake images, so it is frozen and kept in eval mode.
        self.model: torch.nn.Module = self.config.model.get_model().requires_grad_(False).eval().to(
            get_default_device())
        Logger.debug(f'model parameters: {count_parameters(self.model)}')
        Logger.debug(f'model trainable parameters: {count_trainable_parameters(self.model)}')
        Logger.debug(f'model: {get_object_name(self.model)}')

        # NOTE(review): presumably image_shape is (C, H, W) and index 1 is the resolution the discriminator is
        # built for — confirm against the dataset config.
        self.discriminator, self.discriminator_feature_extractor = \
            self.config.discriminator.get_discriminator(self.config.dataset.image_shape[1])
        self.discriminator.train().to(get_default_device())
        self.discriminator_feature_extractor.requires_grad_(False).eval().to(get_default_device())

        Logger.debug(f'discriminator parameters: {count_parameters(self.discriminator)}')
        Logger.debug(f'discriminator trainable parameters: {count_trainable_parameters(self.discriminator)}')
        Logger.debug(
            f'discriminator feature extractor parameters: {count_parameters(self.discriminator_feature_extractor)}')
        Logger.debug(
            f'discriminator feature extractor trainable parameters: '
            f'{count_trainable_parameters(self.discriminator_feature_extractor)}'
        )

        # Only the trainable discriminator is wrapped in DDP; the frozen feature extractor needs no gradient sync.
        self.discriminator: torch.nn.Module = (
            torch.nn.parallel.DistributedDataParallel(self.discriminator, broadcast_buffers=False)
            if DistributedManager.initialized else self.discriminator
        )
        self.discriminator_optimizer: torch.optim.Optimizer = self.config.discriminator_optimizer.get_optimizer(
            self.discriminator,
            batch_repeats=self.config.batch_repeats,
            learning_rate_batch_repeats=self.config.learning_rate_batch_repeats
        )

        Logger.debug(f'discriminator: {get_object_name(self.discriminator)}')
        Logger.debug(f'discriminator_feature_extractor: {get_object_name(self.discriminator_feature_extractor)}')
        Logger.debug(f'discriminator_optimizer: {get_object_name(self.discriminator_optimizer)}')

        self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available() and self.config.amp.use_gard_scaler)

        self.discriminator_loss = StyleGANXLDiscriminatorLoss(
            discriminator_feature_extractor=self.discriminator_feature_extractor,
            discriminator=self.discriminator
        )

        self.metrics = {
            'discriminator_real_loss': AverageMetric(),
            'discriminator_fake_loss': AverageMetric(),
            'discriminator_loss': AverageMetric()
        }

        Logger.info('creating data loaders')
        self.train_sampler, self.train_data_loader, self.train_infinite_data_loader = \
            self._create_train_data_pipeline(self.train_dataset, self.config.distributed_sampler_seed)
        self.train_real_sampler, self.train_real_data_loader, self.train_real_infinite_data_loader = \
            self._create_train_data_pipeline(self.train_real_dataset, self.config.distributed_sampler_real_seed)

        Logger.info('creating example data')
        self.train_example_data = self._merge_examples(self.train_dataset, self.example_data_size)
        self.test_example_data = self._merge_examples(self.test_dataset, self.example_data_size)
        self.train_real_example_data = self._merge_examples(self.train_real_dataset, self.example_data_size)
        self.test_real_example_data = self._merge_examples(self.test_real_dataset, self.example_data_size)

        for name, data in (
                ('train_example_data', self.train_example_data),
                ('train_real_example_data', self.train_real_example_data),
                ('test_example_data', self.test_example_data),
                ('test_real_example_data', self.test_real_example_data),
        ):
            if data is not None:
                Logger.debug(str({name: {key: get_numpy_stats(data[key]) for key in data}}))

        Logger.info('logging images')
        self.log_example_data()
        # Source/target datasets are visualized through their 'target' entry, image datasets through 'image'.
        for tag, data, image_key in (
                ('dataset/train_example_data', self.train_example_data, 'target'),
                ('dataset/test_example_data', self.test_example_data, 'target'),
                ('dataset/train_real_example_data', self.train_real_example_data, 'image'),
                ('dataset/test_real_example_data', self.test_real_example_data, 'image'),
        ):
            if data is not None:
                self.tensorboard_logger.log_images(tag=tag, images=data[image_key], step=0)
        Logger.debug(f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.start_callback)} end')

    def _create_train_data_pipeline(
            self,
            dataset: torch.utils.data.Dataset,
            seed: int
    ) -> tuple[
        Optional[torch.utils.data.DistributedSampler],
        torch.utils.data.DataLoader,
        Generator[dict[str, torch.Tensor], None, None]
    ]:
        """Create the sampler, data loader and infinite iterator used to draw training batches from ``dataset``.

        When ``DistributedManager.initialized`` is true, a shuffling ``DistributedSampler`` seeded with ``seed``
        splits the data across ranks and the configured batch size is divided by the world size; otherwise there is
        no sampler and the loader shuffles on its own.
        """
        sampler = torch.utils.data.DistributedSampler(
            dataset=dataset,
            num_replicas=DistributedManager.world_size,
            rank=DistributedManager.rank,
            shuffle=True,
            drop_last=False,
            seed=seed
        ) if DistributedManager.initialized else None
        data_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=self.config.batch_size // DistributedManager.world_size
            if DistributedManager.initialized else self.config.batch_size,
            # DataLoader forbids shuffle=True together with an explicit sampler.
            shuffle=True if sampler is None else None,
            num_workers=self.config.data_loader_workers,
            drop_last=False
        )
        return sampler, data_loader, create_infinite_data_loader(data_loader, sampler)

    @staticmethod
    def _merge_examples(dataset: Any, count: int) -> Optional[dict[str, np.ndarray]]:
        """Merge the first ``count`` items of ``dataset`` (fewer if it is shorter); return ``None`` for ``None``."""
        if dataset is None:
            return None
        return dataset.merge_data([dataset[i] for i in range(min(count, len(dataset)))])

    def log_images_model(
            self,
            label: str,
            step: int,
            sources: np.ndarray,
            sigmas: np.ndarray,
            labels: Optional[np.ndarray] = None,
    ) -> None:
        """Run the frozen model on ``sources``/``sigmas`` and log its outputs as images under ``label``.

        Args:
            label: Tensorboard tag.
            step: Step the images are logged at.
            sources: First model input; converted with ``torch.from_numpy`` and moved to the default device.
            sigmas: Second model input, converted the same way.
            labels: Optional labels, one-hot encoded with ``config.dataset.num_classes`` and passed as the third
                model input (presumably integer class ids — confirm with ``create_one_hot``).
        """
        with torch.inference_mode():
            images: np.ndarray = self.model(
                torch.from_numpy(sources).to(get_default_device()),
                torch.from_numpy(sigmas).to(get_default_device()),
                torch.from_numpy(create_one_hot(
                    labels,
                    self.config.dataset.num_classes
                )).to(get_default_device()) if labels is not None else None
            ).detach().cpu().numpy()
        self.tensorboard_logger.log_images(label, images, step)
        Logger.debug(f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.log_images_model)} end')

    def check_consistency(self, func: Callable[[torch.nn.Module], bool]) -> bool:
        """Return whether ``func`` accepts the model, the discriminator and the feature extractor.

        ``and`` short-circuits, so ``func`` is not called on the remaining modules once one check fails. GPU/host
        memory is freed afterwards via ``free_all_memory``.
        """
        Logger.debug(f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.check_consistency)} start')
        value: bool = func(self.model) and func(self.discriminator) and func(self.discriminator_feature_extractor)
        free_all_memory()
        Logger.debug(
            f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.check_consistency)} end - value: {value}')
        return value

    def load_from_folder(self, folder: str) -> None:
        """Restore model, discriminator, feature extractor and discriminator optimizer states from ``folder``.

        Reads ``model.pth``, ``discriminator.pth``, ``discriminator_feature_extractor.pth`` and
        ``discriminator_optimizer.pth``, mapping tensors to the default device. The gradient scaler state is not
        restored.
        """
        # NOTE(review): extract_model presumably unwraps DDP; the unwrapped objects are expected to expose
        # .model / .discriminator / .discriminator_feature_extractor attributes — confirm against the config classes.
        Logger.debug(
            f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.load_from_folder)} start - folder: {folder}')
        extract_model(self.model).model.load_state_dict(
            torch.load(f'{folder}/model.pth', map_location=get_default_device()))
        extract_model(self.discriminator).discriminator.load_state_dict(
            torch.load(f'{folder}/discriminator.pth', map_location=get_default_device()))
        extract_model(self.discriminator_feature_extractor).discriminator_feature_extractor.load_state_dict(
            torch.load(f'{folder}/discriminator_feature_extractor.pth', map_location=get_default_device()))
        self.discriminator_optimizer.load_state_dict(
            torch.load(f'{folder}/discriminator_optimizer.pth', map_location=get_default_device()))
        Logger.debug(f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.load_from_folder)} end')

    def get_state_dicts(self, step: int) -> dict[str, dict]:
        """Return the base state dicts plus those of the model, discriminator, feature extractor and optimizer.

        The keys match the file names read back by :meth:`load_from_folder`.
        """
        Logger.debug(
            f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.get_state_dicts)} start - step: {step}')
        return {
            **super().get_state_dicts(step),
            'model': extract_model(self.model).model.state_dict(),
            'discriminator': extract_model(self.discriminator).discriminator.state_dict(),
            'discriminator_feature_extractor':
                extract_model(self.discriminator_feature_extractor).discriminator_feature_extractor.state_dict(),
            'discriminator_optimizer': self.discriminator_optimizer.state_dict()
        }

    def sub_step(self) -> dict[str, float]:
        """Run one gradient-accumulation micro-step of the discriminator.

        Draws one batch from each infinite loader, generates fake images with the frozen model, computes the
        discriminator loss and back-propagates it through ``self.scaler``. The optimizer step itself is performed by
        :meth:`train_step`.

        Returns:
            Python floats for ``discriminator_real_loss``, ``discriminator_fake_loss`` and their sum
            ``discriminator_loss``.
        """
        batch_data: dict[str, torch.Tensor] = next(self.train_infinite_data_loader)
        batch_real_data: dict[str, torch.Tensor] = next(self.train_real_infinite_data_loader)
        device = get_default_device()
        with torch.autocast(
                enabled=torch.cuda.is_available() and self.config.amp.use_autocast,
                # autocast expects a bare device type such as 'cuda' or 'cpu', not 'cuda:0' or a torch.device.
                device_type=torch.device(device).type,
                dtype=torch.float16 if self.config.amp.use_fp16 else torch.float32
        ):
            # no_grad rather than inference_mode: inference tensors cannot be updated in place or take part in
            # autograd-recorded computations after the mode exits, and the fake batch feeds the trainable
            # discriminator below.
            with torch.no_grad():
                generated: torch.Tensor = self.model(
                    batch_data['source'].to(device),
                    batch_data['sigma'].to(device)
                )
        loss_terms: dict[str, torch.Tensor] = self.discriminator_loss.get_discriminator_loss(
            real=batch_real_data['image'].to(torch.float32).to(device),
            # Match the real batch's dtype: under fp16 autocast the generator output may be half precision, while
            # the discriminator loss is evaluated outside autocast.
            fake=generated.detach().to(torch.float32),
            prob_aug=self.config.discriminator.prob_aug,
            shift_ratio=self.config.discriminator.shift_ratio,
            cutout_ratio=self.config.discriminator.cutout_ratio
        )
        # Each reported loss must come from the loss term of the same name.
        discriminator_real_loss: torch.Tensor = loss_terms['real_loss']
        discriminator_fake_loss: torch.Tensor = loss_terms['fake_loss']
        discriminator_loss: torch.Tensor = discriminator_real_loss + discriminator_fake_loss
        # Without learning_rate_batch_repeats the accumulated gradients are averaged over the batch repeats;
        # presumably the optimizer compensates otherwise (it receives learning_rate_batch_repeats) — confirm.
        self.scaler.scale(
            discriminator_loss if self.config.learning_rate_batch_repeats
            else discriminator_loss / self.config.batch_repeats
        ).backward()
        return {
            'discriminator_real_loss': discriminator_real_loss.item(),
            'discriminator_fake_loss': discriminator_fake_loss.item(),
            'discriminator_loss': discriminator_loss.item()
        }

    def train_step(self, step: int) -> S:
        """Accumulate gradients over ``config.batch_repeats`` sub-steps and apply one discriminator update.

        The per-sub-step losses are averaged, added to :attr:`metrics` and returned in a :class:`StepStore`.
        """
        assert not self.model.training, f'model must be in eval mode: {self.model.training}'
        assert not self.discriminator_feature_extractor.training, \
            f'discriminator feature extractor must be in eval mode: {self.discriminator_feature_extractor.training}'
        assert self.discriminator.training, f'discriminator must be in training mode: {self.discriminator.training}'

        values_list: dict[str, list[float]] = {
            'discriminator_real_loss': [],
            'discriminator_fake_loss': [],
            'discriminator_loss': []
        }

        self.discriminator_optimizer.zero_grad()
        for i in range(self.config.batch_repeats):
            # The sync flag is True only on the last sub-step; presumably ddp_sync suppresses the DDP gradient
            # all-reduce otherwise, so gradients are synchronized once per step — confirm in ddp_sync.
            with ddp_sync(self.discriminator, i == self.config.batch_repeats - 1):
                if self.config.free_memory_every_sub_step:
                    free_all_memory()
                sub_step_values: dict[str, float] = self.sub_step()
                for key in values_list:
                    values_list[key].append(sub_step_values[key])
        self.scaler.step(self.discriminator_optimizer)
        self.scaler.update()

        values: dict[str, float] = {key: float(np.mean(values_list[key])) for key in values_list}
        for key in values:
            self.metrics[key].add(values[key])

        return StepStore(
            step=step,
            discriminator_real_loss=values['discriminator_real_loss'],
            discriminator_fake_loss=values['discriminator_fake_loss'],
            discriminator_loss=values['discriminator_loss']
        )

    def reset_values(self) -> dict[str, Any]:
        """Return the base reset values merged with the result of resetting every metric."""
        return {
            **super().reset_values(),
            **{key: self.metrics[key].reset() for key in self.metrics}
        }

    def log_example_data(self) -> None:
        """Log the model outputs for the train (and, if present, test) example data, always at step 0."""
        self.log_images_model(
            label='model/train_example_data',
            step=0,
            sources=self.train_example_data['source'],
            labels=None,
            sigmas=self.train_example_data['sigma']
        )
        if self.test_example_data is not None:
            self.log_images_model(
                label='model/test_example_data',
                step=0,
                sources=self.test_example_data['source'],
                labels=None,
                sigmas=self.test_example_data['sigma']
            )

    def log_images_callback(self, step_store: StepStore) -> None:
        """Every ``config.report.log_images_steps`` steps, re-log the model outputs on the example data."""
        if self.config.report.log_images_steps is not None and \
                (step_store.step + 1) % self.config.report.log_images_steps == 0:
            Logger.debug(
                f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.log_images_callback)} '
                f'{step_store.step + 1} start')
            # NOTE(review): log_example_data always logs at step 0, so these images overwrite the step-0 entry.
            self.log_example_data()
            Logger.debug(
                f'{get_class_name(StyleGANXLDiscriminatorUnconditionalTrainer.log_images_callback)} '
                f'{step_store.step + 1} end')

    def callbacks(self) -> list[Callable[[S], None]]:
        """Return the per-step callbacks, in the order they are listed here."""
        return [
            self.report_train_loss,
            self.save_step_checkpoint_steps_callback,
            self.save_last_checkpoint_steps_callback,
            self.validate_ddp_consistency_steps_callback,
            self.log_images_callback,
            self.free_all_memory_callback
        ]
