from typing import Callable

import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from torch_ema import ExponentialMovingAverage

from .models.unified_ctime_diffusion import UnifiedCtimeDiffusion
from .modules.main_modules import Model
from .modules.main_modules import UniModMLP
from ...base import *
from ...utils import create_hook, get_kwargs
import pandas as pd
from copy import deepcopy


class EMAMixInTabDiff:
    """
    A mixin class that adds Exponential Moving Average (EMA) functionality to the model.
    This class hooks into various stages of the training and validation process to
    ensure that the EMA is created, updated, and transferred correctly.

    The hooks are patched onto the concrete class the first time it is instantiated
    (see ``__new__``); later instantiations reuse them instead of wrapping again.

    Methods:
        _ema_init: Lazily creates the EMA tracker over ``self.model``'s parameters.
        _ema_to: Moves the EMA to the correct device and dtype.
        _ema_update: Updates the EMA with the current model parameters.
        _ema_save: Saves the EMA state to the checkpoint.
        _ema_load: Loads the EMA state from the checkpoint.
    """

    def __new__(cls, *args, **kwargs):
        """
        Overrides __new__ to attach hooks for EMA updates and transfers during training and checkpointing.

        The hooks are assigned on ``cls`` itself, so they are installed only once per
        class. Wrapping them on every instantiation would stack duplicate hooks, e.g.
        applying ``_ema_update`` several times per epoch once a second instance exists.
        """
        # Check cls.__dict__ (not getattr) so the flag is per concrete class, not inherited.
        if not cls.__dict__.get('_ema_hooks_installed', False):
            cls.on_train_start = create_hook(cls.on_train_start, cls._ema_to)
            cls.on_validation_epoch_start = create_hook(cls.on_validation_epoch_start, cls._ema_to)
            cls.on_train_epoch_end = create_hook(cls.on_train_epoch_end, cls._ema_update)
            cls.on_load_checkpoint = create_hook(cls.on_load_checkpoint, cls._ema_load)
            cls.on_save_checkpoint = create_hook(cls.on_save_checkpoint, cls._ema_save)
            cls._ema_hooks_installed = True
        return super().__new__(cls)

    def _after_init(self, *args, **kwargs):
        """
        Read the EMA decay from ``kwargs['model']['ema_decay']`` (default 0.997) and
        reset the EMA tracker, which is created lazily by ``_ema_init``.
        """
        try:
            self.ema_decay = kwargs['model']['ema_decay']
        except (KeyError, TypeError, AttributeError):
            # No model config, or no 'ema_decay' entry in it: use the default decay.
            self.ema_decay = 0.997
        self.ema = None
        return super()._after_init(*args, **kwargs)

    def _ema_init(self):
        """
        Create the EMA tracker over ``self.model``'s parameters if it does not exist yet.

        Returns:
            The (possibly newly created) EMA object stored in ``self.ema``.
        """
        if self.ema is None:
            self.ema = ExponentialMovingAverage(self.model.parameters(), self.ema_decay)
        return self.ema

    def _ema_to(self, *args, **kwargs):
        """
        Moves the EMA object to the same device and dtype as the model's parameters,
        creating it first if necessary.
        """
        self._ema_init().to(self.device, self.dtype)

    def _ema_update(self, *args, **kwargs):
        """
        Updates the EMA with the current state of the model's parameters.
        """
        self.ema.update()

    def _ema_save(self, checkpoint) -> None:
        """
        Saves the EMA state to the checkpoint.

        Nothing is written if no EMA has been created yet (e.g. a checkpoint saved
        before training started).

        Args:
            checkpoint (dict): The checkpoint dictionary where the EMA state will be saved.
        """
        if self.ema is not None:
            checkpoint['ema'] = self.ema.state_dict()

    def _ema_load(self, checkpoint) -> None:
        """
        Loads the EMA state from the checkpoint.

        Checkpoints without an ``'ema'`` entry (see ``_ema_save``) are skipped.

        Args:
            checkpoint (dict): The checkpoint dictionary from which the EMA state will be loaded.
        """
        if 'ema' not in checkpoint:
            return
        self._ema_init().load_state_dict(checkpoint['ema'])


class TabDiff(
    EMAMixInTabDiff,
    BaseImputerMixIn,
    BaseUnconditionalGeneratorMixIn,
    BaseArbitraryConditionalGeneratorMixIn,
    BaseImbalanceMixin,
    Base
):

    """
    https://github.com/MinkaiXu/TabDiff
    Shi, J., Xu, M., Hua, H., Zhang, H., Ermon, S., & Leskovec, J.
    TabDiff: a Mixed-type Diffusion Model for Tabular Data Generation.
    In The Thirteenth International Conference on Learning Representations.

    Lightning wrapper around ``UnifiedCtimeDiffusion``. The best checkpoint is
    selected on the EMA validation loss (``ema.loss``) and restored, together with
    its EMA state, at the end of training. Unconditional sampling runs with the
    EMA weights.
    """

    def __init__(self, num_layers, d_token, n_head, factor, bias, dim_t, use_mlp,
                 precond, sigma_data, net_conditioning, num_timesteps, scheduler, cat_scheduler, noise_dist,
                 lr, weight_decay, reduce_lr_patience, reduce_lr_factor, noise_dist_params, noise_schedule_params,
                 max_epochs,
                 **kwargs):
        """
        Build the diffusion model and store the optimisation settings.

        ``num_layers`` .. ``use_mlp`` configure the ``UniModMLP`` denoiser.
        ``precond``, ``sigma_data`` and ``net_conditioning`` are passed both to the
        ``Model`` wrapper and as the EDM parameters. The remaining diffusion
        arguments are forwarded to ``UnifiedCtimeDiffusion``. ``lr``,
        ``weight_decay`` and ``reduce_lr_*`` configure AdamW / ReduceLROnPlateau.
        ``max_epochs`` drives the continuous-loss weight schedule and the EMA
        validation warm-up. Extra ``kwargs`` go through ``get_kwargs`` to the base
        classes.
        """
        super().__init__(**get_kwargs(**kwargs))

        self.model = UnifiedCtimeDiffusion(
            num_classes=np.array(self.n_categories_per_columns),
            num_numerical_features=self.numerical_dim,
            denoise_fn=Model(
                # +1 category per column, presumably for the absorbing/mask state — confirm.
                UniModMLP(d_numerical=self.numerical_dim,
                          categories=(np.array(self.n_categories_per_columns) + 1).tolist(),
                          num_layers=num_layers, d_token=d_token, n_head=n_head, factor=factor,
                          bias=bias, dim_t=dim_t, use_mlp=use_mlp),
                precond=precond, sigma_data=sigma_data, net_conditioning=net_conditioning
            ),
            y_only_model=None,
            num_timesteps=num_timesteps,
            scheduler=scheduler,
            cat_scheduler=cat_scheduler,
            noise_dist=noise_dist,
            sampler_params=dict(stochastic_sampler=True, second_order_correction=True),
            edm_params=dict(precond=precond, sigma_data=sigma_data, net_conditioning=net_conditioning),
            noise_dist_params=noise_dist_params,
            noise_schedule_params=noise_schedule_params,
        )
        self.lr = lr
        self.weight_decay = weight_decay
        self.reduce_lr_patience = reduce_lr_patience
        self.reduce_lr_factor = reduce_lr_factor
        self.max_epochs = max_epochs
        # Per-column training marginals; set in ``fit`` and persisted in checkpoints.
        self.distribution_info = None
        # Numerical-feature means (nanmean of the transformed training data, set in ``fit``).
        self.register_buffer('mean', torch.zeros(1, self.numerical_dim))

    def forward(self, x):
        """
        Compute the diffusion losses for one batch.

        ``x`` is either a tensor or a sequence whose first element is the tensor
        (presumably the batch tuple yielded by the dataloader). Returns the
        ``(dloss, closs)`` pair produced by ``self.model.mixed_loss``.
        """
        if not isinstance(x, torch.Tensor):
            x = x[0]
        x = x.to(self.device, self.dtype)
        return self.model.mixed_loss(x, self.mean)

    def training_step(self, x):
        """
        Optimise ``dloss + w * closs``, where ``w`` decays linearly from 1 over ``max_epochs``.

        ``t.loss`` (the unweighted sum, rounded to 4 decimals) is what
        ReduceLROnPlateau monitors (see ``configure_optimizers``).
        """
        closs_weight = 1 - (self.trainer.current_epoch / self.max_epochs)
        dloss, closs = self(x)
        loss = dloss + closs_weight * closs
        total_loss = np.around(dloss.item(), 4) + np.around(closs.item(), 4)
        self.log('t.loss', total_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('t.dloss', dloss.item(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('t.closs', closs.item(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('t.closs_weight', closs_weight, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, x):
        """
        Log ``ema.loss``: the unweighted loss evaluated with the EMA weights.

        During the first half of training ``inf`` is logged instead, so that no
        early epoch can win the ``ema.loss`` checkpoint selection.
        """
        if self.current_epoch < self.max_epochs / 2:
            self.log('ema.loss', np.inf, on_step=False, on_epoch=True, prog_bar=True)
        else:
            with self.ema.average_parameters():
                dloss, closs = self(x)
            total_loss = dloss + closs
            self.log('ema.loss', total_loss.item(), on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a per-epoch ReduceLROnPlateau scheduler monitoring ``t.loss``."""
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min',
                                                                       factor=self.reduce_lr_factor,
                                                                       patience=self.reduce_lr_patience)
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": 't.loss'
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """Log the learning rate each epoch and keep only the best-``ema.loss`` checkpoint."""
        callbacks = [
            LearningRateMonitor(logging_interval='epoch'),
            ModelCheckpoint(
                save_top_k=1, monitor='ema.loss', mode='min',
                filename="%s.best" % self.name,
                save_weights_only=True
            )
        ]
        return callbacks

    @torch.no_grad()
    def _generate_uncond(self, n, **kwargs):
        """Sample ``n`` rows with the EMA weights and map them back to the data space."""
        with self.ema.average_parameters():
            syn_data = self.model.sample(n)
        return self.tabular_transform.inverse_transform(syn_data)

    @torch.no_grad()
    def _impute(self, df: pd.DataFrame, num_average=1, resample_rounds=20, impute_condition='x_t', **kwargs) -> pd.DataFrame:
        """
        Fill the missing cells of ``df`` using ``self.model.sample_impute``.

        Missing cells are first initialised from the training marginals
        (``generate_samples_from_distribution``). The model then re-samples them
        ``num_average`` times. Categorical results are aggregated by mode and
        numerical results by mean. Only the originally missing cells are
        overwritten, and the result keeps ``df``'s original index.

        NOTE(review): unlike ``_generate_uncond`` this does not use the EMA
        weights — confirm whether that is intended.
        """
        index = df.index
        df = df.reset_index(drop=True)
        df_mask = df.isna()

        # Take the token-space NaN masks BEFORE pre-filling the missing cells: after
        # the fill below the tokens contain no NaNs, so the masks would always be
        # empty and nothing would be imputed.
        missing = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype).isnan()
        num_mask_idx = missing[:, :self.numerical_dim].nonzero()
        cat_mask_idx = missing[:, self.numerical_dim:].nonzero()

        # Initial values for the missing cells, drawn from the training marginals.
        syn = self.generate_samples_from_distribution(len(df))
        df[df_mask] = syn[df_mask]

        cat_results = []
        num_results = []
        for _ in range(num_average):
            tokens = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype)
            num, cat = tokens[:, :self.numerical_dim], tokens[:, self.numerical_dim:]
            tokens = self.model.sample_impute(num, cat.long(), num_mask_idx, cat_mask_idx, resample_rounds,
                                              impute_condition, 0, 0)
            num, cat = tokens[:, :self.numerical_dim], tokens[:, self.numerical_dim:]
            cat_results.append(cat)
            num_results.append(num)
        cat = torch.stack(cat_results).mode(dim=0).values
        num = torch.stack(num_results).mean(dim=0)
        imputed = self.tabular_transform.inverse_transform(num, cat)
        df[df_mask] = imputed[df_mask]
        df.index = index
        return df

    @torch.no_grad()
    def _generate_by_class(self, df: pd.DataFrame, resample_rounds=20, impute_condition='x_t', **kwargs) -> pd.DataFrame:
        """Delegate to ``_impute`` with a single pass (``num_average=1``)."""
        kwargs['num_average'] = 1
        # Forward the sampling options; they were previously accepted but silently dropped.
        return self._impute(df, resample_rounds=resample_rounds, impute_condition=impute_condition, **kwargs)

    @torch.no_grad()
    def _generate_by_condition(self, df: pd.DataFrame, resample_rounds=20, impute_condition='x_t', **kwargs) -> pd.DataFrame:
        """Delegate to ``_impute`` with a single pass (``num_average=1``)."""
        kwargs['num_average'] = 1
        # Forward the sampling options; they were previously accepted but silently dropped.
        return self._impute(df, resample_rounds=resample_rounds, impute_condition=impute_condition, **kwargs)

    def prepare_dataloader(self, scenario=lambda df: df) -> dict:
        """
        Build the train dataloader from ``cfg.dataset.train_path``.

        A deep copy of the same loader is used for validation. Returns ``{}`` when
        no config is set.
        """
        if self._cfg is None:
            return {}

        dataloaders = {}
        cfg = self._cfg

        # Load and scenario-ize training data
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)

        # Optionally drop missing data
        if not self.model_flags['allow_missing_on_dataset']:
            train_df = train_df.dropna(subset=self.tabular_transform.columns, how='any')

        # Build train dataloader
        train_dataset = self.dataframe_to_dataset(train_df)
        # Only drop the last incomplete batch when there is more than one batch of data.
        drop_last = len(train_df) > cfg.model.batch_size
        dataloaders['train_dataloaders'] = self.dataset_to_dataloader(
            train_dataset,
            shuffle=True,
            drop_last=drop_last
        )
        dataloaders['val_dataloaders'] = deepcopy(dataloaders['train_dataloaders'])
        return dataloaders

    def on_train_end(self) -> None:
        """
        Restore the best checkpoint (by ``ema.loss``) at the end of training.

        Both the raw weights and the EMA state are restored. The checkpoint was
        selected on the EMA loss and ``_generate_uncond`` samples with the EMA
        weights, so keeping the final-epoch EMA would not match the selected
        model. Does nothing if no best checkpoint was written.
        """
        best_path = getattr(self.trainer.checkpoint_callback, 'best_model_path', '')
        if not best_path:
            return
        # weights_only=False: the checkpoint also stores plain Python objects
        # (``distribution_info``); it is a file this trainer has just written.
        checkpoint = torch.load(best_path, map_location=self.device, weights_only=False)
        self.load_state_dict(checkpoint['state_dict'])
        ema = getattr(self, 'ema', None)
        if ema is not None and 'ema' in checkpoint:
            ema.load_state_dict(checkpoint['ema'])

    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x):
        """
        Fit the tabular transform and record the training statistics, then train.

        Sets the ``mean`` buffer (nanmean of the transformed numerical columns) and
        ``distribution_info`` (categorical value frequencies and numerical means,
        used to initialise imputation). Then delegates to the base ``fit``.
        """
        cfg = self._cfg
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        self._transform.fit(train_df)
        mean = self.tabular_transform.transform(train_df, return_as_tensor=True)[:, :self.numerical_dim].nanmean(dim=0, keepdim=True)
        self.mean.data = mean
        distribution_info = {}
        for col in self.tabular_transform.categorical_columns:
            distribution_info[col] = train_df[col].value_counts(normalize=True).to_dict()
        for col in self.tabular_transform.numerical_columns:
            distribution_info[col] = train_df[col].dropna().mean()
        self.distribution_info = distribution_info
        return super().fit(scenario)

    def generate_samples_from_distribution(self, n_samples):
        """
        Draw ``n_samples`` rows, each column independently.

        Categorical columns are sampled from their training frequencies. Numerical
        columns are set to their training mean. Requires ``fit`` (or a loaded
        checkpoint) to have populated ``distribution_info``.
        """
        data = {}
        for col, info in self.distribution_info.items():
            if isinstance(info, dict):
                categories = list(info.keys())
                probabilities = list(info.values())
                data[col] = np.random.choice(categories, size=n_samples, p=probabilities)
            else:
                data[col] = np.full(n_samples, info)

        return pd.DataFrame(data)

    def on_save_checkpoint(self, checkpoint):
        """Persist the training marginals alongside the weights."""
        if hasattr(self, 'distribution_info'):
            checkpoint['distribution_info'] = self.distribution_info

    def on_load_checkpoint(self, checkpoint):
        """Restore the training marginals (``None`` if the checkpoint has none)."""
        self.distribution_info = checkpoint.get('distribution_info', None)
