from typing import Callable

import pandas as pd
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, EarlyStopping

from .model import MLPDiffusion
from .model import SimpDM as SimpDiffusion
from ...base import *
from ...base.base import _IDENTITY


class SimpDM(
    BaseImputerMixIn,
    Base
):
    """
    Imputer built around the SimpDM diffusion model.

    Inputs are centred with the per-column training mean (the ``mean`` buffer),
    missing entries are zero-filled and flagged with a mask that is 1 where the
    value was NaN and 0 elsewhere, and the wrapped diffusion model is trained on /
    sampled from that representation.

    Reference:
    https://github.com/yixinliu233/SimpDM
    Liu, Y., Ajanthan, T., Husain, H., & Nguyen, V. (2024).
    Self-Supervision Improves Diffusion Models for Tabular Data Imputation.
    In CIKM'24: The 33rd ACM International Conference on Information and Knowledge Management.
    Association for Computing Machinery.
    """

    def __init__(self, args, **kwargs):
        """
        Build the diffusion model from the given hyper-parameters.

        Args:
            args: Attribute container. Read here: ``hidden_units``, ``num_layers``,
                ``num_timesteps``, ``gammas`` (underscore-separated floats, e.g.
                ``"0.1_0.5"``) and ``ssl_loss_weight``. ``lr`` and ``weight_decay``
                are read later by ``configure_optimizers``.
            **kwargs: Accepted but not used by this initialiser.
        """
        # NOTE(review): ``self.column_dim`` is not defined in this class; it is
        # presumably provided by the base initialiser — confirm.
        super().__init__()
        self.model = SimpDiffusion(
            num_numerical_features=self.column_dim,
            denoise_fn=MLPDiffusion(d_in=self.column_dim, d_out=self.column_dim,
                                    d_layers=[args.hidden_units] * args.num_layers),
            num_timesteps=args.num_timesteps,
            gammas=[float(gamma) for gamma in args.gammas.split('_')],
            ssl_loss_weight=args.ssl_loss_weight
        )
        self.args = args
        # Per-column mean used for centring; registered as a buffer so it follows
        # the module across devices and is stored in the state dict.
        self.register_buffer('mean', torch.zeros(1, self.column_dim))

    def configure_optimizers(self):
        """
        Create the AdamW optimizer and a ReduceLROnPlateau scheduler.

        The scheduler multiplies the learning rate by 0.95 with a patience of 20
        epochs, monitoring the epoch-level training loss logged as
        ``t.loss_epoch``.

        Returns:
            dict: ``{"optimizer": ..., "lr_scheduler": {...}}``.
        """
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.95, patience=20
        )
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": "t.loss_epoch"
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """
        Provide callbacks for:
         - LearningRateMonitor (logged once per epoch)
         - EarlyStopping (patience=500 on the epoch-level training loss)

        Returns:
            list: A list of callback instances.
        """
        callbacks = [
            LearningRateMonitor(logging_interval='epoch'),
            EarlyStopping(monitor='t.loss_epoch', patience=500, min_delta=0)
        ]
        return callbacks

    def training_step(self, x):
        """
        Compute the training loss for one batch.

        Args:
            x: Batch tensor in which missing entries are NaN.

        Returns:
            The sum of the two losses returned by ``self.model.train_iter``.
        """
        # The mask must be taken before centring / zero-filling removes the NaNs.
        mask = x.isnan().to(x)
        x = x - self.mean
        x = x.nan_to_num()
        loss_gauss, loss_ssl = self.model.train_iter(x, mask)
        loss = loss_gauss + loss_ssl
        # Logging with both on_step and on_epoch yields the ``t.loss_epoch`` key
        # that the scheduler and early stopping monitor.
        self.log('t.loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = _IDENTITY):  # type: ignore[override]
        """
        Estimate the per-column training mean, then delegate to the base ``fit``.

        Args:
            scenario: Callable applied to the training DataFrame after it is read
                from ``cfg.dataset.train_path``.

        Returns:
            Whatever the base class ``fit`` returns.
        """
        cfg = self._cfg
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        self._transform.fit(train_df)
        train_data = self.tabular_transform.transform(train_df, return_as_tensor=True)
        # nanmean yields NaN for a column with no observed value; a NaN mean would
        # be added back in ``_impute`` and turn that whole column into NaN, so fall
        # back to 0 for such columns.
        column_mean = train_data.nanmean(dim=0, keepdim=True).nan_to_num()
        # copy_ keeps the buffer's registered dtype/device instead of swapping in
        # whatever tensor the transform produced.
        self.mean.copy_(column_mean)
        return super().fit(scenario)

    def _impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Impute the missing entries of ``df`` with the diffusion model.

        Args:
            df: DataFrame to impute.
            **kwargs: Accepted but not used.

        Returns:
            The result of ``self.tabular_transform.inverse_transform`` applied to
            the model output after the mean has been added back.
        """
        tokens = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype)
        # Same recipe as ``training_step``: mask from the raw values, in the
        # tensor's own dtype, then centre out-of-place and zero-fill.
        mask = tokens.isnan().to(tokens)
        centred = (tokens - self.mean).nan_to_num()
        sample = self.model.impute(centred, mask) + self.mean
        return self.tabular_transform.inverse_transform(sample)
