import os
import pandas as pd
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, EarlyStopping
from lightning.pytorch.utilities.model_summary import ModelSummary
from omegaconf import OmegaConf
from tqdm import tqdm

from impugen.base import *
from impugen.base.base import DataSet
from impugen.models.diffputer.diffusion_utils import impute_mask
from impugen.models.diffputer.model import MLPDiffusion, Model
from impugen.utils import setup_logger, create_latest_symlink, rank_zero_print, SeedContext


class Diffputer(
    BaseImputerMixIn,
    Base
):
    """
    DiffPuter: missing-data imputation with a diffusion model trained in an EM loop.

    https://github.com/hengruizhang98/DiffPuter
    Zhang, H., Fang, L., Wu, Q., & Yu, P. S. (2025).
    Diffputer: Empowering diffusion models for missing data imputation.
    In The Thirteenth International Conference on Learning Representations.

    ``fit`` runs ``em_steps`` iterations. Each iteration trains a freshly initialized
    diffusion model on the current imputation of the training set (M-step), then
    re-imputes the training set with the best model of that iteration (E-step).
    Comments of the form "line N in main.py" point into the reference implementation.
    """

    def __init__(self, lr, d_model, em_steps=10, **kwargs):
        """
        Args:
            lr: Adam learning rate used by ``configure_optimizers``.
            d_model: hidden width passed to ``MLPDiffusion``.
            em_steps: number of EM iterations performed by ``fit`` (must be >= 1 to fit).
            **kwargs: accepted and ignored here (not forwarded to ``super().__init__``).
        """
        super().__init__()
        self.in_dim = self.column_dim
        self.d_model = d_model
        self.lr = lr
        self.em_steps = em_steps
        self.current_em_step = 0
        # Per-column normalization statistics; populated by fit() and stored as buffers
        # so they travel with the state dict.
        self.register_buffer('mean', torch.zeros(1, self.column_dim))
        self.register_buffer('std', torch.ones(1, self.column_dim))
        self.register_buffer('max', torch.zeros(1, self.column_dim))
        self.register_buffer('min', torch.zeros(1, self.column_dim))

        self.model = Model(MLPDiffusion(self.in_dim, self.d_model), hid_dim=self.in_dim)
        # Training frame and its latest in-sample imputation (used as a cache by impute()).
        self.train_df = None
        self.in_sample_imputed = None
        # Weights with the lowest epoch loss in the current EM step; used for imputation.
        self.best_model = Model(MLPDiffusion(self.in_dim, self.d_model), hid_dim=self.in_dim)
        self.best_loss = torch.inf

    @torch.no_grad()
    def impute(
            self,
            df: pd.DataFrame,
            batch_size: int = 4096,
            *,
            seed=None,  # passed to seed_context
            mask_target_column: bool = False,
            **kwargs
    ) -> pd.DataFrame:
        """
        Impute ``df``, short-circuiting to the cached EM result for the training frame.

        If ``df`` equals the training frame seen by ``fit`` and an in-sample imputation
        is cached, a *copy* of that cached frame is returned (so callers cannot mutate
        the cache) instead of re-running the sampler. The cache is bypassed when
        ``mask_target_column`` is True, because the cached frame was produced without
        masking the target column.
        NOTE(review): assumes the base ``impute`` masks the target column before
        imputing when ``mask_target_column`` is True — confirm.
        """
        use_cache = (
            not mask_target_column
            and self.in_sample_imputed is not None
            and df.equals(self.train_df)
        )
        if use_cache:
            return self.in_sample_imputed.copy()

        return super().impute(df, batch_size, seed=seed, mask_target_column=mask_target_column, **kwargs)

    @torch.no_grad()
    def _impute(self, dataframe: pd.DataFrame, num_average=10, num_steps=50, progress=True, **kwargs) -> pd.DataFrame:
        """
        Fill NaN entries of ``dataframe`` with ``best_model`` and return the imputed frame.

        The sampler is run ``num_average`` times; observed entries are kept and the
        ``num_average`` reconstructions are averaged (line 188 in main.py).

        Args:
            dataframe: frame to impute; NaN marks missing entries.
            num_average: number of independent imputations to average.
            num_steps: number of sampler steps passed to ``impute_mask``.
            progress: show a tqdm progress bar over the ``num_average`` runs.
        """
        tokens = self.tabular_transform.transform(dataframe, return_as_tensor=True).to(self.device, self.dtype)
        nan_mask = tokens.isnan()
        X = self.normalize(tokens)

        # Loop-invariant inputs to impute_mask (lines 173-184 in main.py).
        net = self.best_model.denoise_fn_D
        num_samples, dim = X.shape[0], X.shape[1]
        device = self.device

        rec_Xs = []
        for _ in tqdm(range(num_average), desc='Imputation', disable=not progress):
            impute_X = X.clone()
            rec_X = impute_mask(net, impute_X, nan_mask, num_samples, dim, num_steps, device)
            # Keep observed entries; take sampled values only where the input was NaN.
            rec_X = rec_X * nan_mask + impute_X * ~nan_mask
            rec_Xs.append(rec_X)

        # line 188 in main.py
        rec_X = torch.stack(rec_Xs, dim=0).mean(0)
        num, cat = self.denormalize(rec_X)
        return self.tabular_transform.inverse_transform(num, cat)

    def normalize(self, df_or_tensor, *args, **kwargs) -> torch.Tensor:
        """
        Standardize with the fitted column mean/std, halve, and replace NaN with 0.

        Because of the standardization, NaN -> 0 amounts to mean imputation.
        Accepts a DataFrame (run through ``tabular_transform`` first) or a tensor.
        """
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)

        x = x.to(self.device, self.dtype)
        x = (x - self.mean) / self.std / 2
        x = x.nan_to_num()  # Mean imputation
        return x

    def denormalize(self, tensor: torch.Tensor, *args, **kwargs):
        """
        Invert ``normalize``, clip to the training range, and split the columns.

        Returns:
            ``(num, cat)``: columns ``[:numerical_dim]`` and ``[numerical_dim:]`` of the
            de-normalized 2-D tensor.
        """
        x = tensor.to(self.device, self.dtype) * 2
        x = x * self.std + self.mean
        x = torch.minimum(torch.maximum(x, self.min), self.max)  # clip to the observed training range

        # x is 2-D after broadcasting with the (1, D) buffers, so plain slices already have
        # shape (n, k). Avoid .view(n, -1): it is ambiguous and raises for a zero-width block,
        # e.g. a table with no categorical (or no numerical) columns.
        num = x[:, :self.numerical_dim]
        cat = x[:, self.numerical_dim:]
        return num, cat

    def _fit_normalization_stats(self, train_df: pd.DataFrame) -> None:
        """Compute per-column mean/std/min/max over observed (non-NaN) training values into the buffers."""
        data = self.tabular_transform.transform(train_df, return_as_tensor=True)

        col_mean = data.nanmean(dim=0, keepdim=True)
        # Sample std over observed values only. A column with fewer than two observations
        # yields NaN; treat it as zero spread so it does not turn the column into NaN.
        col_std = torch.stack([col[~col.isnan()].std() for col in data.T])[None]
        col_std = torch.nan_to_num(col_std, nan=0.0) + 1e-4
        col_min = data.nanquantile(0, dim=0, keepdim=True)
        col_max = data.nanquantile(1, dim=0, keepdim=True)

        self.mean.data = col_mean
        self.std.data = col_std
        self.min.data = col_min
        self.max.data = col_max

    def _evaluate_em_step(self, cfg, step: int) -> None:
        """Run in-sample evaluation for an intermediate EM step under ``<log_dir>/em_step_<step>``."""
        parent_log_dir = self.log_dir
        self.log_dir = os.path.join(parent_log_dir, f'em_step_{step}')
        self.model_flags['in_sample_only'] = True
        try:
            self.evaluation(cfg)
        finally:
            # Restore state even if evaluation fails, so later steps log to the right place.
            self.model_flags['in_sample_only'] = False
            self.log_dir = parent_log_dir

    def fit(self, scenario=lambda x: x):
        """
        Run the DiffPuter EM loop on ``cfg.dataset.train_path`` and save a checkpoint.

        Args:
            scenario: callable applied to the raw training frame before fitting;
                identity by default.

        Returns:
            self

        Raises:
            ValueError: if ``em_steps`` < 1 (no model would be trained or saved).
        """
        if self.em_steps < 1:
            raise ValueError(f"em_steps must be >= 1, got {self.em_steps}")

        cfg = self._cfg
        logger = setup_logger(self, cfg)
        self.log_dir = logger.log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        create_latest_symlink(self.log_dir)
        OmegaConf.save(cfg, os.path.join(self.log_dir, "config.yaml"))
        rank_zero_print(OmegaConf.to_yaml(cfg))
        rank_zero_print(ModelSummary(self))

        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        self._transform.fit(train_df)
        self.train_df = train_df
        # The first M-step trains on the raw frame; normalize() performs mean imputation.
        # Assumes the num scaler is standard and categoricals are one-hot encoded.
        self.in_sample_imputed = train_df
        self._fit_normalization_stats(train_df)

        trainer = None
        for step in range(self.em_steps):
            self.current_em_step = int(step)

            with SeedContext(cfg.seed + step):
                # M-step: fresh model for each iteration - line 102 in main.py
                self.model = Model(MLPDiffusion(self.in_dim, self.d_model), hid_dim=self.in_dim)
                self.best_loss = torch.inf
                trainer = self.prepare_trainer(logger)
                # Train on the current imputation of the training set.
                train_loader = self.dataset_to_dataloader(
                    DataSet(self.normalize(self.in_sample_imputed)), shuffle=True,
                    drop_last=len(train_df) > cfg.model.batch_size
                )
                trainer.fit(self, train_dataloaders=train_loader)

                # E-step: re-impute the training set in continuous space. Move the module
                # to the trainer's device rather than hard-coding 'cuda', which fails on
                # CPU-only or non-CUDA accelerators.
                self.to(trainer.strategy.root_device)
                # Clear the cache first so impute() runs the sampler instead of returning
                # the previous iterate.
                self.in_sample_imputed = None
                self.in_sample_imputed = self.impute(train_df)

                # Monitor intermediate performance (the final step is evaluated by the caller).
                if step != self.em_steps - 1:
                    self._evaluate_em_step(cfg, step)

        trainer.save_checkpoint(os.path.join(self.log_dir, f"{self.name}.ckpt"), weights_only=True)
        return self

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau on the epoch training loss, as in the reference implementation."""
        # optimizer: line 109 in main.py
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0)
        # lr_scheduler: line 110 in main.py
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.9,
                                                                       patience=50)
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": 't.loss_epoch'
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """Learning-rate monitoring plus early stopping on the epoch training loss."""
        callbacks = [LearningRateMonitor(logging_interval='epoch'),
                     # EarlyStopping: line 144-147 in main.py
                     EarlyStopping(monitor='t.loss_epoch', patience=500)]
        return callbacks

    def training_step(self, x):
        """Return the diffusion training loss for one batch, flattened to ``(batch, -1)``."""
        loss = self.model(x.view(len(x), -1))
        # Log a detached tensor instead of loss.item(), which forces a device sync every step.
        self.log('t.loss', loss.detach(), on_step=True, on_epoch=True, prog_bar=True)
        self.log('em_step', self.current_em_step, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Snapshot ``model`` into ``best_model`` whenever the epoch loss improves."""
        current = self.trainer.callback_metrics['t.loss_epoch'].squeeze()
        # save best model: line 142 in main.py
        if current < self.best_loss:
            self.best_loss = current
            self.best_model.load_state_dict(self.model.state_dict())

    def on_save_checkpoint(self, checkpoint):
        """Store the training frame and its imputation alongside the weights."""
        checkpoint['train_df'] = self.train_df
        checkpoint['in_sample_imputed'] = self.in_sample_imputed

    def on_load_checkpoint(self, checkpoint):
        """Restore the training frame and its imputation (``None`` if absent)."""
        self.train_df = checkpoint.get('train_df', None)
        self.in_sample_imputed = checkpoint.get('in_sample_imputed', None)
