import numpy as np
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

from .torch_model.gaussian_multinomial_distribution import GaussianMultinomialDiffusion
from .torch_model.modules import MLPDiffusion
from ...base import *
from ...utils import get_kwargs


class TabDDPMUncond(BaseUnconditionalGeneratorMixIn, Base):

    """
    Unconditional TabDDPM generator: a Gaussian/multinomial diffusion model
    with an MLP denoiser, trained with a combined categorical + numerical loss.

    https://github.com/amazon-science/tabsyn/tree/main/baselines/tabddpm
    Kotelnikov, A., Baranchuk, D., Rubachev, I., & Babenko, A. (2023, July).
    Tabddpm: Modelling tabular data with diffusion models.
    In International conference on machine learning (pp. 17564-17579). PMLR.
    """

    # follows tabsyn's implementation
    def __init__(self, lr, diff_d_model=1024, **kwargs):
        """
        Build the diffusion model and store the optimizer learning rate.

        :param lr: base learning rate handed to AdamW in ``configure_optimizers``.
        :param diff_d_model: base hidden width of the MLP denoiser; the hidden
            layers are ``[d, 2d, 2d, d]`` and the timestep embedding size is ``d``.
        :param kwargs: filtered through ``get_kwargs`` and forwarded to the
            parent constructor.
        """
        super().__init__(**get_kwargs(**kwargs))
        # NOTE(review): n_categories_per_columns / numerical_dim are read from
        # self right after the parent constructor runs - presumably set there.
        # Denoiser input width: numerical features plus the sum of the
        # per-column category counts.
        denoiser_in_dim = self.numerical_dim + sum(self.n_categories_per_columns)
        hidden_layers = [diff_d_model, diff_d_model * 2, diff_d_model * 2, diff_d_model]
        self.model = GaussianMultinomialDiffusion(
            num_classes=np.array(self.n_categories_per_columns),
            num_numerical_features=self.numerical_dim,
            denoise_fn=MLPDiffusion(
                denoiser_in_dim, 0, is_y_cond=False,
                rtdl_params=dict(d_layers=hidden_layers, dropout=0.),
                dim_t=diff_d_model),
            gaussian_loss_type='mse',
            num_timesteps=1000,
            scheduler='cosine',
        )
        self.lr = lr

    @staticmethod
    def _unwrap_batch(x):
        """Return ``x`` itself if it is a tensor, otherwise its first element."""
        return x if isinstance(x, torch.Tensor) else x[0]

    def _generate_uncond(self, bsz, ddim=False, steps=1000, **kwargs):
        """
        Sample ``bsz`` rows from the diffusion model and map them back through
        ``self.tabular_transform.inverse_transform``.

        :param bsz: passed to ``sample_all`` as both of its first two positional
            arguments (presumably sample count and sampling batch size - confirm
            against ``GaussianMultinomialDiffusion.sample_all``).
        :param ddim: forwarded to ``sample_all``.
        :param steps: forwarded to ``sample_all``.
        :param kwargs: accepted for interface compatibility and ignored.
        """
        sampled = self.model.sample_all(bsz, bsz, ddim=ddim, steps=steps)
        return self.tabular_transform.inverse_transform(sampled)

    def forward(self, x):
        """
        Compute the diffusion training losses for a batch.

        :param x: a tensor, or an indexable batch whose first element is used.
        :return: whatever ``self.model.mixed_loss`` returns; ``training_step``
            unpacks it as ``(loss_multi, loss_gauss)``.
        """
        return self.model.mixed_loss(self._unwrap_batch(x))

    def training_step(self, x):
        """
        Run one optimisation step: compute both loss terms, log them and
        return their sum.

        :param x: a tensor, or an indexable batch whose first element is used.
        :return: ``loss_multi + loss_gauss``.
        """
        loss_multi, loss_gauss = self(self._unwrap_batch(x))
        loss = loss_multi + loss_gauss
        # on_step + on_epoch logging makes Lightning expose each metric under
        # both '<name>_step' and '<name>_epoch'; the checkpoint callback below
        # monitors 't.loss_step'.
        for metric_name, metric_value in (
            ('t.loss', loss),
            ('t.cat.loss', loss_multi),
            ('t.num.loss', loss_gauss),
        ):
            self.log(metric_name, metric_value.item(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Create an AdamW optimizer (no weight decay) with a per-step linear
        decay of the learning rate from ``lr`` down to 0 over the training run.

        The decay horizon is ``trainer.max_steps`` when it is positive.
        Lightning uses ``max_steps == -1`` for runs bounded by epochs; in that
        case the horizon falls back to ``trainer.estimated_stepping_batches``.
        Without this fallback the multiplier ``1 - step / -1`` would *grow*
        the learning rate every step. If no finite horizon exists the learning
        rate is kept constant.

        :return: the optimizer / lr-scheduler configuration dict for Lightning.
        """
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0)

        total_steps = self.trainer.max_steps
        if total_steps is None or total_steps <= 0:
            total_steps = self.trainer.estimated_stepping_batches
        # estimated_stepping_batches may be infinite for unbounded training.
        has_finite_horizon = 0 < total_steps < float("inf")

        def linear_decay(step):
            if not has_finite_horizon:
                return 1.0
            # Clamp so the factor never turns negative past the horizon.
            return max(0.0, 1 - step / total_steps)

        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, linear_decay)
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "step",
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """
        Return the callbacks attached to this module: a learning-rate monitor
        and a weights-only checkpoint that keeps the single best model.

        :return: list of Lightning callbacks.
        """
        # NOTE(review): the monitored key is the per-step training loss, not
        # the epoch aggregate ('t.loss_epoch') - confirm this is intended.
        best_checkpoint = ModelCheckpoint(
            save_top_k=1, monitor='t.loss_step', mode='min',
            filename=f"{self.name}.best",
            save_weights_only=True
        )
        return [LearningRateMonitor(logging_interval='epoch'), best_checkpoint]

    def on_train_end(self) -> None:
        """
        Restore the best checkpointed weights into this module after training.

        Does nothing if no checkpoint was written (``best_model_path`` is empty,
        e.g. when training stopped before the first save), instead of failing
        on ``torch.load('')``.
        """
        checkpoint_callback = self.trainer.checkpoint_callback
        best_path = getattr(checkpoint_callback, "best_model_path", "")
        if not best_path:
            return
        # NOTE: weights_only=False unpickles arbitrary objects; the file is the
        # one written by this run's checkpoint callback, so it is kept as-is.
        # map_location='cpu' avoids materialising a second copy of the weights
        # on the device they were saved from; load_state_dict copies the values
        # into the existing parameters.
        checkpoint = torch.load(best_path, map_location="cpu", weights_only=False)
        self.load_state_dict(checkpoint['state_dict'])
