from typing import Callable

import pandas as pd
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import LearningRateMonitor

from .tabnat.model import TabNAT as TabNatModel
from ...base import *


class TabNAT(
    BaseImputerMixIn,
    BaseUnconditionalGeneratorMixIn,
    Base
):

    """
    TabNAT wrapper providing unconditional generation and imputation.

    https://github.com/fangliancheng/TabNAT
    Zhang, H., Fang, L., Wu, Q., & Yu, P. S.
    TabNAT: A Continuous-Discrete Joint Generative Framework for Tabular Data.
    In Forty-second International Conference on Machine Learning.
    """

    def __init__(self, lr, embed_dim, buffer_size, depth, dropout_rate, **kwargs):
        """
        Build the underlying TabNAT network and the normalisation buffers.

        Args:
            lr: Learning rate used by Adam in ``configure_optimizers``.
            embed_dim, buffer_size, depth, dropout_rate: Forwarded unchanged
                to the TabNAT network.
            **kwargs: Accepted for configuration compatibility; unused here.
        """
        super().__init__()
        # NOTE(review): numerical_dim / categorical_dim / n_categories_per_columns
        # are presumably populated by Base.__init__ — confirm.
        self.model = TabNatModel(
            n_num=self.numerical_dim,
            n_cat=self.categorical_dim,
            categories=self.n_categories_per_columns,
            embed_dim=embed_dim,
            buffer_size=buffer_size,
            depth=depth,
            norm_layer=nn.LayerNorm,
            dropout_rate=dropout_rate
        )
        # Per-column statistics of the numerical features; overwritten by
        # ``fit`` (or restored from a checkpoint, since buffers are persistent).
        # ``std`` starts at one rather than zero so encode/decode stay finite
        # if they are called before either of those happens.
        self.register_buffer('mean', torch.zeros(1, self.numerical_dim))
        self.register_buffer('std', torch.ones(1, self.numerical_dim))
        self.lr = lr

    def encode(self, df_or_tensor, *args, **kwargs):
        """
        Split a table into standardised numerical and integer categorical parts.

        Args:
            df_or_tensor: A raw ``pd.DataFrame`` (run through
                ``tabular_transform`` first) or an already-transformed tensor.
                The first ``numerical_dim`` columns are treated as numerical,
                the remaining columns as categorical codes.
            *args, **kwargs: Ignored.

        Returns:
            tuple: ``(num, cat)`` where ``num`` is ``(x - mean) / std / 2`` and
            ``cat`` is cast to ``torch.long``. Either is ``None`` when the
            table has no columns of that kind.
        """
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)
        # Missing entries are zero-filled here; callers hand the missing mask
        # to the model separately.
        x = x.nan_to_num()
        num, cat = x[:, :self.numerical_dim], x[:, self.numerical_dim:].to(torch.long)
        num = (num - self.mean) / self.std / 2
        if self.numerical_dim == 0:
            num = None
        if self.categorical_dim == 0:
            cat = None
        return num, cat

    def decode(self, num, cat):
        """
        Invert ``encode``: undo the numerical scaling and flatten both parts.

        Args:
            num: Standardised numerical tensor, or ``None`` if there are no
                numerical columns.
            cat: Categorical tensor, or ``None`` if there are no categorical
                columns. At least one of ``num``/``cat`` must be given.

        Returns:
            tuple: ``(num, cat)`` reshaped to ``(rows, -1)``; ``cat`` is cast
            to ``self.dtype``. A missing part becomes an empty ``(rows, 0)``
            tensor.
        """
        if num is None:
            num = torch.zeros(size=cat[:, :0].shape, device=self.device, dtype=self.dtype)
        if cat is None:
            cat = torch.zeros(size=num[:, :0].shape, device=self.device, dtype=self.dtype)
        cat = cat.to(self.dtype)
        num = num * self.std * 2 + self.mean
        # reshape (not view) so non-contiguous model outputs are accepted too.
        return num.reshape(len(num), -1), cat.reshape(len(cat), -1)

    def _generate_uncond(self, n: int, **kwargs) -> pd.DataFrame:
        """Sample ``n`` rows from the model and map them back to a DataFrame."""
        num, cat = self.model.sample(n, cls=None, device=self.device)
        return self.tabular_transform.inverse_transform(*self.decode(num, cat))

    def _impute(self, df: pd.DataFrame, num_average=20, **kwargs) -> pd.DataFrame:
        """
        Fill the NaN entries of ``df`` by aggregating several model passes.

        The model's one-step imputation is run ``num_average`` times. The
        numerical outputs are averaged and the categorical outputs are
        combined by majority vote (``torch.mode``).
        NOTE(review): whether observed entries are passed through unchanged
        depends on ``TabNatModel.impute``, which is not visible here.

        Args:
            df: Table with missing values encoded as NaN.
            num_average: Number of imputation passes to aggregate (>= 1).

        Returns:
            pd.DataFrame: The imputed table, via ``inverse_transform``.

        Raises:
            ValueError: If ``num_average`` is smaller than 1.
        """
        if num_average < 1:
            raise ValueError(f"num_average must be >= 1, got {num_average}")
        tokens = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype)
        impute_mask = tokens.isnan()
        num_tokens = []
        cat_tokens = []
        for _ in range(num_average):
            num, cat = self.model.impute(*self.encode(tokens), impute_mask, one_step=True, device=self.device)
            num, cat = self.decode(num, cat)
            num_tokens.append(num)
            cat_tokens.append(cat)
        num = torch.stack(num_tokens).mean(dim=0)
        cat = torch.stack(cat_tokens).mode(dim=0).values
        return self.tabular_transform.inverse_transform(num, cat)

    def training_step(self, x):
        """
        Compute and log the TabNAT training loss for one batch.

        Args:
            x: Transformed batch tensor, or a sequence whose first element is
                that tensor. NaN entries mark missing values.

        Returns:
            torch.Tensor: The total loss.

        Raises:
            ValueError: If the loss is NaN or infinite.
        """
        if not isinstance(x, torch.Tensor):
            x = x[0]
        missing_mask = x.isnan()
        num_missing_mask = missing_mask[:, :self.numerical_dim]
        cat_missing_mask = missing_mask[:, self.numerical_dim:]
        loss, loss_num, loss_cat = self.model(*self.encode(x), num_missing_mask=num_missing_mask,
                                              cat_missing_mask=cat_missing_mask)
        # Logging with on_step and on_epoch makes Lightning emit 't.loss_epoch',
        # which the LR scheduler monitors.
        self.log('t.loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('t.loss_num', loss_num, on_step=False, on_epoch=True, prog_bar=True)
        self.log('t.loss_cat', loss_cat, on_step=False, on_epoch=True, prog_bar=True)
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            raise ValueError("NaN or Inf detected in loss")
        return loss

    def configure_optimizers(self):
        """
        Define the optimizer (Adam, weight decay 1e-6) and a ReduceLROnPlateau
        scheduler (factor 0.9, patience 50) stepped every epoch on the
        epoch-level training loss ``t.loss_epoch``.

        Returns:
            dict: Contains "optimizer" and "lr_scheduler" entries.
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-6)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.9, patience=50
        )
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": "t.loss_epoch"
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """
        Provide the training callbacks.

        Only a LearningRateMonitor (logging once per epoch) is returned;
        no early stopping is configured here.

        Returns:
            list: A list of callback instances.
        """
        callbacks = [
            LearningRateMonitor(logging_interval='epoch'),
        ]
        return callbacks

    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x):
        """
        Fit the tabular transform and numerical statistics, then train.

        Args:
            scenario: Transformation applied to the training DataFrame read
                from ``cfg.dataset.train_path`` before anything is fitted
                (identity by default).

        Returns:
            Whatever the base class ``fit`` returns.
        """
        cfg = self._cfg
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        self._transform.fit(train_df)
        train_data = self.tabular_transform.transform(train_df, return_as_tensor=True)
        # Only the numerical columns are normalised, so only they need stats.
        mean, std = self._numerical_stats(train_data[:, :self.numerical_dim])
        # copy_ keeps the buffers' device and dtype instead of rebinding them
        # to whatever device/dtype the transform produced.
        self.mean.copy_(mean)
        self.std.copy_(std)
        return super().fit(scenario)

    @staticmethod
    def _numerical_stats(num_data, eps=1e-4):
        """Return NaN-aware per-column ``(mean, std)`` of a 2-D tensor, each of shape ``(1, n_cols)``."""
        observed = (~num_data.isnan()).sum(dim=0, keepdim=True)
        mean = num_data.nanmean(dim=0, keepdim=True)
        # Unbiased (n - 1) variance over observed entries, matching Tensor.std's default.
        sq_dev = (num_data - mean).pow(2).nansum(dim=0, keepdim=True)
        std = (sq_dev / (observed - 1)).sqrt() + eps
        # Columns with no observed value (mean undefined) or a single one
        # (std undefined) would otherwise yield NaN and poison training;
        # fall back to the identity normalisation for them.
        mean = torch.where(observed > 0, mean, torch.zeros_like(mean))
        std = torch.where(observed > 1, std, torch.ones_like(std))
        return mean, std
