import math

import pandas as pd
import torch
from lightning.pytorch.callbacks import LearningRateMonitor

from ...base import *
from ...utils import get_kwargs
from impugen.models.remasker.torch_model import MaskedAutoencoder


class ReMasker(BaseImputerMixIn, Base):
    """ReMasker imputer: a masked autoencoder trained to reconstruct tabular rows.

    Reference implementation: https://github.com/tydusky/remasker

    Du, T., Melis, L. M., & Wang, T. (2024, January).
    ReMasker: Imputing Tabular Data with Masked Autoencoding.
    In International Conference on Learning Representations (ICLR’24).
    """

    def __init__(self, embed_dim, depth, num_heads, decoder_embed_dim, decoder_depth,
                 decoder_num_heads,
                 mlp_ratio, mask_ratio, max_epochs, warmup_epochs, lr, min_lr, **kwargs):
        """Build the autoencoder and store the training hyper-parameters.

        Args:
            embed_dim, depth, num_heads, decoder_embed_dim, decoder_depth,
            decoder_num_heads, mlp_ratio: architecture arguments, forwarded
                positionally to ``MaskedAutoencoder`` after ``self.column_dim``.
            mask_ratio: passed to the model as ``mask_ratio`` during training
                (see ``forward``).
            max_epochs, warmup_epochs: length of the whole LR schedule and of
                its linear-warmup phase, in (fractional) epochs.
            lr, min_lr: peak learning rate and the floor the cosine decays to.
            **kwargs: passed through ``get_kwargs`` to the base-class constructor.
        """
        super().__init__(**get_kwargs(**kwargs))
        # ``self.column_dim`` is read after ``super().__init__``; presumably the
        # base class derives it from its kwargs — TODO confirm.
        self.model = MaskedAutoencoder(self.column_dim, embed_dim, depth, num_heads, decoder_embed_dim,
                                       decoder_depth,
                                       decoder_num_heads, mlp_ratio)
        self.mask_ratio = mask_ratio
        self.max_epochs = max_epochs
        self.warmup_epochs = warmup_epochs
        self.lr = lr
        self.min_lr = min_lr

    def _impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Reconstruct ``df`` with the model and map the result back to table space.

        ``kwargs`` are accepted for interface compatibility and ignored here.
        """
        x = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype)
        # Inference only: do not build an autograd graph over the whole table.
        with torch.no_grad():
            # The model returns a 4-tuple whose first item is the loss (see
            # ``forward``); item [1] is presumably the reconstruction.
            # NOTE(review): no ``mask_ratio`` is passed here, so the model's own
            # default applies — confirm it does not randomly mask observed values.
            pred = self.model(x)[1].squeeze(-1)
        return self.tabular_transform.inverse_transform(pred)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``t.loss``) and return the loss for one batch.

        Raises:
            ValueError: if the loss contains NaN or +/-Inf, so training stops
                instead of back-propagating non-finite values.
        """
        # NOTE(review): Lightning normally enables train mode itself; the explicit
        # call presumably guards against an ``eval()`` issued elsewhere between steps.
        self.train()
        loss = self(batch)
        if not torch.isfinite(loss).all():
            raise ValueError("NaN or Inf detected in loss")
        # ``.item()`` already yields a detached Python float; no autograd context needed.
        self.log('t.loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def adjust_lr(self, step, lr, min_lr, max_epochs, warmup_epochs):
        """Return the ``LambdaLR`` multiplier for scheduler step ``step``.

        The schedule is a linear warmup from 0 to 1 over ``warmup_epochs``,
        followed by a half-cosine decay from 1 down to ``min_lr / lr`` that
        reaches the floor at ``max_epochs`` and stays there afterwards. Epochs
        are fractional: ``step / trainer.num_training_batches``.
        """
        epoch = step / self.trainer.num_training_batches
        if epoch < warmup_epochs:
            return epoch / warmup_epochs
        min_val = min_lr / lr
        decay_epochs = max_epochs - warmup_epochs
        if decay_epochs <= 0:
            # No room for a cosine phase (max_epochs <= warmup_epochs): any step
            # at or past the end of warmup lands here, and dividing by
            # ``decay_epochs`` would raise ZeroDivisionError.
            return min_val
        # Clamp so steps beyond ``max_epochs`` stay at the floor rather than the
        # cosine swinging back up towards ``lr``.
        progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
        return min_val + (1 - min_val) * 0.5 * (1. + math.cos(math.pi * progress))

    def configure_optimizers(self):
        """Create AdamW and a per-step warmup/cosine ``LambdaLR`` schedule.

        Both objects are also kept on ``self.optimizer`` / ``self.lr_scheduler``.
        """
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.95))
        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda step: self.adjust_lr(step, self.lr, self.min_lr,
                                        self.max_epochs, self.warmup_epochs),
        )
        # "step" interval: the scheduler advances after every optimizer step, so
        # ``adjust_lr`` receives a step count rather than an epoch count.
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "step",
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def forward(self, x):
        """Return the model's masked-reconstruction loss for a batch.

        ``x`` may be a tensor or a batch container (presumably the list/tuple a
        ``TensorDataset`` loader yields); for a container, its first item is used.
        """
        if not isinstance(x, torch.Tensor):
            x = x[0]
        loss, _, _, _ = self.model(x, mask_ratio=self.mask_ratio)
        return loss

    def configure_callbacks(self, *args, **kwargs):
        """Attach a per-step learning-rate monitor to the trainer."""
        callbacks = [
            LearningRateMonitor(logging_interval='step')
        ]
        return callbacks
