from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Any, Generator, Callable, Optional

import numpy as np
import torch.distributed
import torch.utils.data
import torch.utils.data
from ema_pytorch import EMA
from pydantic import Field

from src.configs.amp import AMPConfig
from src.configs.edm_model import EDMModelConfig
from src.configs.edm_sampler import EDMSamplerConfig
from src.configs.edm_scheduler import EDMSchedulerConfig
from src.configs.ema import EMAConfig
from src.configs.fid import FIDConfig
from src.configs.loss import LossConfig
from src.configs.optimizer import OptimizerConfig
from src.datasets.noise_label import NoiseLabelDataset
from src.datasets.source_target_label_sigma import SourceTargetLabelSigmaDataset
from src.fid.single_step import inference_single_step
from src.fid.utils import calculate_fid_for_dataset_and_model as calculate_fid_for_dataset_and_model_helper
from src.metrics.average_metric import AverageMetric
from src.models.utils import count_parameters, count_trainable_parameters
from src.train.trainer import Config as BaseConfig, StepStore as BaseStepStore, Trainer
from src.utils.utils import create_infinite_data_loader, create_one_hot, report_values, free_all_memory, extract_model
from src.utils.utils import create_one_hot_torch
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.distributed.utils import ddp_sync
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name, get_object_name


class Config(BaseConfig):
    """Configuration consumed by ``SourceTargetConditionalTrainer`` on top of the base trainer config."""
    # scheduler settings; ``sigma_max`` is the noise level used for single-step FID inference
    edm_scheduler: EDMSchedulerConfig = Field()
    # sampler settings; not read by SourceTargetConditionalTrainer itself (presumably used by subclasses — verify)
    edm_sampler: EDMSamplerConfig = Field()
    # builds the trained network via ``get_model()``
    model: EDMModelConfig = Field()
    # builds the optimizer via ``get_optimizer(model, batch_repeats=..., learning_rate_batch_repeats=...)``
    model_optimizer: OptimizerConfig = Field()
    # one EMA copy per entry; ``get_str()`` names the EMA checkpoint files and FID report keys
    ema: list[EMAConfig] = Field()
    # FID evaluation: ``steps`` (evaluation period, None disables it), ``reference_path``, ``batch_size``, ``dims``
    fid: FIDConfig = Field()
    # training objective, built via ``get_loss()``; compared between model output and ``target``
    reconstruction_loss: LossConfig = Field()
    # when True, training targets are clipped to [-1, 1] before computing the loss
    clip_train_data: bool = Field(default=False)
    # mixed precision switches: ``use_autocast``, ``use_fp16`` and the grad-scaler flag
    amp: AMPConfig = Field()
    # seed passed to the DistributedSampler when torch.distributed is initialized
    distributed_sampler_seed: int = Field(default=0)


class StepStore(BaseStepStore):
    """Step record that carries the reconstruction loss of a training step in addition to the base values."""
    reconstruction_loss: float = Field()

    def get_step_values(self) -> dict[str, Any]:
        """Return a new dict with the base step values plus ``reconstruction_loss``."""
        base_values = super().get_step_values()
        return dict(base_values, reconstruction_loss=self.reconstruction_loss)


# Type variables that let subclasses narrow the config and step-store types of the trainer.
# They are deliberately unannotated: type checkers (e.g. mypy) reject declaring the type of a TypeVar.
C = TypeVar('C', bound=Config)
S = TypeVar('S', bound=StepStore)


class SourceTargetConditionalTrainer(Trainer[C, S], Generic[C, S], ABC):
    """Trainer that maps ``source`` to ``target`` conditioned on a sigma value and a one-hot class label.

    Subclasses supply the datasets through the four abstract properties. Every training step feeds
    ``source``, ``sigma`` and the one-hot ``label`` of a batch to the model and back-propagates the
    configured reconstruction loss against ``target``. The trainer keeps one EMA copy per ``config.ema``
    entry (parked on the CPU between evaluations). It periodically logs model outputs on example data
    and computes FID for the model and each EMA copy. A ``best`` checkpoint is saved whenever the
    corresponding test FID improves.
    """
    model: torch.nn.Module
    model_optimizer: torch.optim.Optimizer
    scaler: torch.cuda.amp.GradScaler
    # one EMA wrapper per entry of config.ema, in the same order
    ema_models: list[EMA]
    # best test FID seen so far; None until the first successful (finite) evaluation
    best_fid_value: Optional[float]
    best_fid_ema_values: list[Optional[float]]
    reconstruction_loss: torch.nn.Module
    reconstruction_loss_metric: AverageMetric
    # None when torch.distributed is not initialized
    train_sampler: Optional[torch.utils.data.DistributedSampler]
    train_data_loader: torch.utils.data.DataLoader
    train_infinite_data_loader: Generator[dict[str, torch.Tensor], None, None]
    train_example_data: dict[str, np.ndarray]
    # None when the subclass provides no test dataset
    test_example_data: Optional[dict[str, np.ndarray]]

    def __init__(self, config: C):
        super().__init__(config)

    @property
    @abstractmethod
    def train_dataset(self) -> SourceTargetLabelSigmaDataset:
        raise NotImplementedError('train_dataset method must be implemented')

    @property
    @abstractmethod
    def test_dataset(self) -> SourceTargetLabelSigmaDataset:
        raise NotImplementedError('test_dataset method must be implemented')

    @property
    @abstractmethod
    def fid_train_dataset(self) -> NoiseLabelDataset:
        raise NotImplementedError('fid_train_dataset method must be implemented')

    @property
    @abstractmethod
    def fid_test_dataset(self) -> NoiseLabelDataset:
        raise NotImplementedError('fid_test_dataset method must be implemented')

    @staticmethod
    def _collect_example_data(
            dataset: Optional[SourceTargetLabelSigmaDataset],
            max_items: int = 10
    ) -> Optional[dict[str, np.ndarray]]:
        """Merge the first ``max_items`` items of ``dataset`` (fewer if it is shorter); None if there is no dataset."""
        if dataset is None:
            return None
        return dataset.merge_data([dataset[i] for i in range(min(max_items, len(dataset)))])

    def start_callback(self) -> None:
        """Build the model, EMA copies, optimizer, AMP scaler, loss, data loaders and example data.

        Also logs the example targets and the model's initial outputs on them at step 0.
        """
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.start_callback)} start')
        super().start_callback()

        # Read each abstract property once so the sampler, the loader and the example data share one instance.
        train_dataset = self.train_dataset
        test_dataset = self.test_dataset
        fid_train_dataset = self.fid_train_dataset
        fid_test_dataset = self.fid_test_dataset

        assert isinstance(train_dataset, SourceTargetLabelSigmaDataset), \
            f'train_dataset must be an instance of SourceTargetLabelSigmaDataset: {train_dataset}'
        assert test_dataset is None or isinstance(test_dataset, SourceTargetLabelSigmaDataset), \
            f'test_dataset must be an instance of SourceTargetLabelSigmaDataset: {test_dataset}'
        assert isinstance(fid_train_dataset, NoiseLabelDataset), \
            f'fid_train_dataset must be an instance of NoiseLabelDataset: {fid_train_dataset}'
        assert isinstance(fid_test_dataset, NoiseLabelDataset), \
            f'fid_test_dataset must be an instance of NoiseLabelDataset: {fid_test_dataset}'

        self.model = self.config.model.get_model().train().to(get_default_device())
        Logger.debug(f'model parameters: {count_parameters(self.model)}')
        Logger.debug(f'model trainable parameters: {count_trainable_parameters(self.model)}')

        # EMA copies are built from the unwrapped model (before DDP wrapping) and parked on the CPU;
        # fid_callback moves each one to the device only while it is being evaluated.
        self.ema_models = []
        for ema_config in self.config.ema:
            ema: EMA = ema_config.get_ema(self.model)
            ema.ema_model.to('cpu')
            self.ema_models.append(ema)

        if DistributedManager.initialized:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, find_unused_parameters=True)
        self.model_optimizer = self.config.model_optimizer.get_optimizer(
            self.model,
            batch_repeats=self.config.batch_repeats,
            learning_rate_batch_repeats=self.config.learning_rate_batch_repeats
        )

        Logger.debug(f'model: {get_object_name(self.model)}')
        Logger.debug(f'model_optimizer: {get_object_name(self.model_optimizer)}')

        self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available() and self.config.amp.use_gard_scaler)

        self.best_fid_value = None
        self.best_fid_ema_values = [None] * len(self.config.ema)

        self.reconstruction_loss = self.config.reconstruction_loss.get_loss()
        self.reconstruction_loss.to(get_default_device())

        self.reconstruction_loss_metric = AverageMetric()

        Logger.info('creating data loaders')
        if DistributedManager.initialized:
            self.train_sampler = torch.utils.data.DistributedSampler(
                dataset=train_dataset,
                num_replicas=DistributedManager.world_size,
                rank=DistributedManager.rank,
                shuffle=True,
                drop_last=False,
                seed=self.config.distributed_sampler_seed
            )
            # NOTE(review): integer division silently shrinks the global batch when
            # batch_size is not a multiple of world_size.
            per_rank_batch_size = self.config.batch_size // DistributedManager.world_size
        else:
            self.train_sampler = None
            per_rank_batch_size = self.config.batch_size
        self.train_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=self.train_sampler,
            batch_size=per_rank_batch_size,
            # DataLoader does not allow shuffle=True together with an explicit sampler;
            # the DistributedSampler does the shuffling in that case.
            shuffle=True if self.train_sampler is None else None,
            num_workers=self.config.data_loader_workers,
            drop_last=False
        )
        self.train_infinite_data_loader = create_infinite_data_loader(self.train_data_loader, self.train_sampler)

        Logger.info('creating example data')
        self.train_example_data = self._collect_example_data(train_dataset)
        self.test_example_data = self._collect_example_data(test_dataset)

        Logger.debug(str({
            'train_example_data':
                {key: get_numpy_stats(self.train_example_data[key]) for key in self.train_example_data}
        }))

        if self.test_example_data is not None:
            Logger.debug(str({
                'test_example_data':
                    {key: get_numpy_stats(self.test_example_data[key]) for key in self.test_example_data}
            }))

        Logger.info('logging images')
        self.log_example_data()
        self.tensorboard_logger.log_images(
            tag='dataset/train_example_data',
            images=self.train_example_data['target'],
            step=0
        )
        if self.test_example_data is not None:
            self.tensorboard_logger.log_images(
                tag='dataset/test_example_data',
                images=self.test_example_data['target'],
                step=0
            )
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.start_callback)} end')

    def log_images_model(
            self,
            label: str,
            step: int,
            sources: np.ndarray,
            sigmas: np.ndarray,
            labels: Optional[np.ndarray] = None,
    ) -> None:
        """Run the model on the given inputs and log its outputs to TensorBoard under ``label``.

        The model runs on ``sources``, ``sigmas`` and the one-hot encoded ``labels``, or with no label
        input when ``labels`` is None. It is switched to eval mode for the forward pass and is always
        restored to train mode afterwards, even if inference or logging raises.
        """
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.log_images_model)} start')
        device = get_default_device()
        self.model.eval()
        try:
            with torch.inference_mode():
                images: np.ndarray = self.model(
                    torch.from_numpy(sources).to(device),
                    torch.from_numpy(sigmas).to(device),
                    torch.from_numpy(create_one_hot(
                        labels,
                        self.config.dataset.num_classes
                    )).to(device) if labels is not None else None
                ).detach().cpu().numpy()
            self.tensorboard_logger.log_images(label, images, step)
        finally:
            # train_step asserts the model is in training mode, so never leave it in eval mode
            self.model.train()
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.log_images_model)} end')

    def update_ema_callback(self, _: StepStore) -> None:
        """Call ``update()`` on every EMA wrapper after a training step."""
        for ema in self.ema_models:
            ema.update()

    def check_consistency(self, func: Callable[[torch.nn.Module], bool]) -> bool:
        """Evaluate ``func`` on the (possibly DDP-wrapped) model, free memory and return the result."""
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.check_consistency)} start')
        value: bool = func(self.model)
        free_all_memory()
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.check_consistency)} end - value: {value}')
        return value

    def load_from_folder(self, folder: str) -> None:
        """Restore model, optimizer and EMA weights from ``.pth`` files in ``folder``.

        The file names mirror the keys produced by ``get_state_dicts``.
        """
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.load_from_folder)} start - folder: {folder}')
        extract_model(self.model).load_state_dict(torch.load(f'{folder}/model.pth', map_location=get_default_device()))
        self.model_optimizer.load_state_dict(
            torch.load(f'{folder}/model_optimizer.pth', map_location=get_default_device()))
        for ema_config, ema in zip(self.config.ema, self.ema_models):
            ema.ema_model.load_state_dict(
                torch.load(f'{folder}/ema/{ema_config.get_str()}.pth', map_location=get_default_device()))
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.load_from_folder)} end')

    def get_state_dicts(self, step: int) -> dict[str, dict]:
        """Return the base state dicts plus the unwrapped model, the optimizer and every EMA model."""
        Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.get_state_dicts)} start - step: {step}')
        return {
            **super().get_state_dicts(step),
            'model': extract_model(self.model).state_dict(),
            'model_optimizer': self.model_optimizer.state_dict(),
            **{
                f'ema/{ema_config.get_str()}': ema.ema_model.state_dict()
                for ema_config, ema in zip(self.config.ema, self.ema_models)
            }
        }

    def sub_step(self) -> float:
        """Run forward plus scaled backward on one batch and return the unscaled mean reconstruction loss.

        Gradients are only accumulated here; ``train_step`` applies the optimizer step.
        """
        batch_data: dict[str, torch.Tensor] = next(self.train_infinite_data_loader)
        device = get_default_device()
        # NOTE(review): torch.autocast expects a device-type string such as 'cuda' or 'cpu';
        # confirm get_default_device() returns that form (not 'cuda:0' or a torch.device).
        with torch.autocast(
                enabled=torch.cuda.is_available() and self.config.amp.use_autocast,
                device_type=device,
                dtype=torch.float16 if self.config.amp.use_fp16 else torch.float32
        ):
            generated: torch.Tensor = self.model(
                batch_data['source'].to(device),
                batch_data['sigma'].to(device),
                create_one_hot_torch(batch_data['label'], self.config.dataset.num_classes).to(device)
            )
        target: torch.Tensor = batch_data['target'].to(torch.float32).to(device)
        if self.config.clip_train_data:
            target = torch.clip(target, -1, 1)
        reconstruction_loss_value: torch.Tensor = self.reconstruction_loss(generated, target).mean()
        # train_step calls this batch_repeats times and gradients add up across calls. Without
        # learning_rate_batch_repeats the loss is averaged over sub-batches here; with it, the scale is
        # presumably compensated through the optimizer (get_optimizer receives the same flag).
        backward_loss = (
            reconstruction_loss_value if self.config.learning_rate_batch_repeats
            else reconstruction_loss_value / self.config.batch_repeats
        )
        self.scaler.scale(backward_loss).backward()
        return reconstruction_loss_value.item()

    def train_step(self, step: int) -> S:
        """Accumulate gradients over ``batch_repeats`` sub-batches, apply one optimizer step and record the mean loss."""
        assert self.model.training, f'model must be in training mode: {self.model.training}'

        reconstruction_loss_values: list[float] = []

        self.model_optimizer.zero_grad()
        last_sub_step = self.config.batch_repeats - 1
        for i in range(self.config.batch_repeats):
            # ddp_sync gets True only for the last sub-batch; presumably gradient synchronization
            # across ranks happens only then (verify against the helper).
            with ddp_sync(self.model, i == last_sub_step):
                if self.config.free_memory_every_sub_step:
                    free_all_memory()
                reconstruction_loss_values.append(self.sub_step())
        self.scaler.step(self.model_optimizer)
        self.scaler.update()

        reconstruction_loss_value: float = float(np.mean(reconstruction_loss_values))

        self.reconstruction_loss_metric.add(reconstruction_loss_value)
        return StepStore(step=step, reconstruction_loss=reconstruction_loss_value)

    def reset_values(self) -> dict[str, Any]:
        """Return the base reset values plus the averaged reconstruction loss, resetting its metric."""
        return {
            **super().reset_values(),
            'reconstruction_loss': self.reconstruction_loss_metric.reset()
        }

    def log_example_data(self) -> None:
        """Log the model's outputs on the train (and, if present, test) example data at step 0."""
        self.log_images_model(
            label='model/train_example_data',
            step=0,
            sources=self.train_example_data['source'],
            labels=self.train_example_data['label'],
            sigmas=self.train_example_data['sigma']
        )
        if self.test_example_data is not None:
            self.log_images_model(
                label='model/test_example_data',
                step=0,
                sources=self.test_example_data['source'],
                labels=self.test_example_data['label'],
                sigmas=self.test_example_data['sigma']
            )

    def log_images_callback(self, step_store: StepStore) -> None:
        """Every ``config.report.log_images_steps`` steps, re-log the model outputs on the example data."""
        if self.config.report.log_images_steps is not None and \
                (step_store.step + 1) % self.config.report.log_images_steps == 0:
            Logger.debug(
                f'{get_class_name(SourceTargetConditionalTrainer.log_images_callback)} {step_store.step + 1} start')
            self.log_example_data()
            Logger.debug(
                f'{get_class_name(SourceTargetConditionalTrainer.log_images_callback)} {step_store.step + 1} end')

    def calculate_fid_for_dataset_and_model(
            self,
            model: torch.nn.Module,
            dataset: NoiseLabelDataset
    ) -> float:
        """Compute the FID of single-step samples from ``model`` on ``dataset`` against the configured reference.

        Samples are generated with ``sigma = config.edm_scheduler.sigma_max``. On any failure the error
        is logged, ``model`` is put back into train mode, and NaN is returned. A real score can never be
        NaN, so callers can tell a failed evaluation apart and it never wins the best-FID comparison.
        """
        try:
            Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.calculate_fid_for_dataset_and_model)} start')
            fid: float = calculate_fid_for_dataset_and_model_helper(
                model=model,
                dataset=dataset,
                reference_path=self.config.fid.reference_path,
                inference_batch_func=lambda m, n, l: inference_single_step(
                    model=m,
                    noises=n,
                    labels=l,
                    num_classes=self.config.dataset.num_classes,
                    sigma=self.config.edm_scheduler.sigma_max
                ),
                batch_size=self.config.fid.batch_size,
                dims=self.config.fid.dims
            )
            Logger.debug(
                f'{get_class_name(SourceTargetConditionalTrainer.calculate_fid_for_dataset_and_model)} end - fid: {fid}')
            return fid
        except Exception as e:
            # Best effort: a failed FID evaluation must not stop training.
            Logger.exception(
                f'{get_class_name(SourceTargetConditionalTrainer.calculate_fid_for_dataset_and_model)} {e}', e)
            # the helper may have left the model in eval mode
            model.train()
            return float('nan')

    def _calculate_train_test_fid(self, model: torch.nn.Module) -> tuple[float, float]:
        """Return ``(train_fid, test_fid)`` of ``model`` on the FID datasets, freeing memory in between."""
        train_fid = self.calculate_fid_for_dataset_and_model(model=model, dataset=self.fid_train_dataset)
        free_all_memory()
        test_fid = self.calculate_fid_for_dataset_and_model(model=model, dataset=self.fid_test_dataset)
        return train_fid, test_fid

    @staticmethod
    def _is_new_best_fid(fid: float, best_fid: Optional[float]) -> bool:
        """True when ``fid`` is finite and lower than ``best_fid`` (or no best value exists yet)."""
        return bool(np.isfinite(fid)) and (best_fid is None or fid < best_fid)

    def fid_callback(self, step_store: StepStore) -> None:
        """Every ``config.fid.steps`` steps, evaluate train/test FID for the model and each EMA copy.

        Results are reported to TensorBoard. A ``best`` (or ``best_<ema>``) checkpoint is saved whenever
        the corresponding test FID improves; failed (NaN) evaluations never count as an improvement.
        Any exception is logged and swallowed, and all EMA copies are moved back to the CPU.
        """
        try:
            if self.config.fid.steps is None or (step_store.step + 1) % self.config.fid.steps != 0:
                return
            current_step = step_store.step + 1
            Logger.debug(
                f'{get_class_name(SourceTargetConditionalTrainer.fid_callback)} {current_step} start')
            free_all_memory()
            train_fid, test_fid = self._calculate_train_test_fid(self.model)
            free_all_memory()
            report_values(
                current_step,
                {'fid': {'train': train_fid, 'test': test_fid}},
                self.tensorboard_logger
            )
            if self._is_new_best_fid(test_fid, self.best_fid_value):
                self.best_fid_value = test_fid
                Logger.info(f'saving best checkpoint {current_step}')
                self.save_checkpoint(f'{self.config.base_folder}/{self.config.checkpoint.folder}/best',
                                     current_step)

            for i, (ema_config, ema) in enumerate(zip(self.config.ema, self.ema_models)):
                # EMA models live on the CPU; bring each one to the device only for its evaluation.
                free_all_memory()
                ema.ema_model.to(get_default_device())
                free_all_memory()
                train_fid_ema, test_fid_ema = self._calculate_train_test_fid(ema.ema_model)
                free_all_memory()
                ema.ema_model.to('cpu')
                free_all_memory()

                ema_name = ema_config.get_str()
                report_values(current_step, {
                    'fid': {
                        'train': {'ema': {ema_name: train_fid_ema}},
                        'test': {'ema': {ema_name: test_fid_ema}}
                    }
                }, self.tensorboard_logger)
                if self._is_new_best_fid(test_fid_ema, self.best_fid_ema_values[i]):
                    self.best_fid_ema_values[i] = test_fid_ema
                    Logger.info(f'saving best checkpoint {current_step} {ema_config.get_tuple()}')
                    self.save_checkpoint(
                        f'{self.config.base_folder}/{self.config.checkpoint.folder}/best_{ema_name}',
                        current_step
                    )
            Logger.info(f'best test fid value {current_step}: {self.best_fid_value}')
            Logger.info(f'best test fid ema values {current_step}: {self.best_fid_ema_values}')
            Logger.debug(f'{get_class_name(SourceTargetConditionalTrainer.fid_callback)} {current_step} end')
        except Exception as e:
            Logger.exception(f'{get_class_name(SourceTargetConditionalTrainer.fid_callback)} {e}', e)
            for ema in self.ema_models:
                ema.ema_model.to('cpu')

    def callbacks(self) -> list[Callable[[S], None]]:
        """Per-step callbacks; presumably invoked in this order by the base Trainer (verify)."""
        return [
            self.report_train_loss,
            self.update_ema_callback,
            self.save_step_checkpoint_steps_callback,
            self.save_last_checkpoint_steps_callback,
            self.validate_ddp_consistency_steps_callback,
            self.log_images_callback,
            self.fid_callback,
            self.free_all_memory_callback
        ]
