import math

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from impugen.utils.io import setup_model_from_checkpoint
from .dataset import EncodedInfo
from .model import Model
from ...base import *
from ...utils import get_kwargs
from tqdm import tqdm

class MaCoDE(
    BasePredictorMixIn,
    BaseImputerMixIn,
    BaseUnconditionalGeneratorMixIn,
    BaseArbitraryConditionalGeneratorMixIn,
    BaseImbalanceMixin,
    Base
):
    """
    https://github.com/an-seunghwan/MaCoDE
    An, S., Woo, G., Lim, J., Kim, C., Hong, S., & Jeon, J. J. (2025, April).
    Masked Language Modeling Becomes Conditional Density Estimation for Tabular Data Synthesis.
    In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 39, No. 15, pp. 15356-15364).

    Token layout used by this class (see ``digitize`` and ``sample_tokens``):
      * ``0`` is the [MASKED] token.
      * Numerical columns (the first ``numerical_dim`` columns) are clipped to [0, 1) and
        bucketized into ``bins`` equal-width bins, giving tokens ``1..bins``.
      * Categorical columns are encoded as ``category index + 1``.
      * ``NaN`` marks a cell that is missing and should be generated.
    """

    def __init__(
            self,
            bins: int = 50,
            dim_transformer: int = 128,
            num_transformer_heads: int = 4,
            num_transformer_layer: int = 2,
            lr: float = 0.001,
            **kwargs
    ):
        """
        Args:
            bins: number of equal-width bins used to discretize each numerical column.
            dim_transformer: forwarded to ``Model``.
            num_transformer_heads: forwarded to ``Model``.
            num_transformer_layer: forwarded to ``Model``.
            lr: AdamW learning rate used by ``configure_optimizers``.
            **kwargs: accepted but not used by this constructor.
        """
        super().__init__()
        self.lr = lr
        # +1 per categorical column; presumably reserves room for the [MASKED] token 0
        # (TODO confirm against EncodedInfo / Model).
        self.encoded_info = EncodedInfo(
            self.column_dim, self.numerical_dim, [e + 1 for e in self.n_categories_per_columns]
        )
        self.model = Model(
            EncodedInfo=self.encoded_info,
            bins=bins,
            dim_transformer=dim_transformer,
            num_transformer_heads=num_transformer_heads,
            num_transformer_layer=num_transformer_layer
        )
        # Bin edges 0, 1/bins, ..., 1: used by digitize (encoding) and sample_tokens (decoding).
        self.register_buffer('bins', torch.Tensor(np.linspace(0, 1, bins + 1, endpoint=True)))

    @property
    def n_bins(self):
        """Number of numerical bins, read from the first continuous embedding's vocabulary
        size minus one (presumably the [MASKED] slot)."""
        return self.model.embedding.ContEmbed[0].num_embeddings - 1

    def configure_optimizers(self):
        """AdamW over the transformer parameters (weight decay 1e-3)."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=1e-3)

    @staticmethod
    def multiclass_loss(batch, pred, mask):
        """Sum of per-column cross-entropy losses over the selected cells.

        Args:
            batch: token matrix; a cell holding token ``t >= 1`` has class ``t - 1``.
            pred: sequence of per-column logits; ``pred[j]`` has one row per row of ``batch``.
            mask: boolean matrix selecting the cells that contribute to the loss.

        Returns:
            A scalar tensor. Columns with no selected cell produce a NaN mean and are skipped.
            If every column is skipped, the result is a zero that is still attached to
            ``pred``, so the caller always gets a differentiable tensor rather than a float.
        """
        class_loss = None
        for j, logits in enumerate(pred):
            col = mask[:, j]
            # Targets are shifted by -1 because token 0 is the [MASKED] token.
            tmp = F.cross_entropy(logits[col], batch[:, j][col].long() - 1)
            if not tmp.isnan():
                class_loss = tmp if class_loss is None else class_loss + tmp
        if class_loss is None:
            class_loss = sum(logits.sum() for logits in pred) * 0.
        return class_loss

    def digitize(self, df_or_tensor, *args, **kwargs) -> torch.Tensor:
        """Encode a DataFrame (or an already-transformed tensor) as the token matrix.

        DataFrames go through ``tabular_transform.transform`` first. The first
        ``numerical_dim`` columns are clipped to [0, 1) and bucketized into tokens
        ``1..bins``. The remaining (categorical) columns are shifted by +1. NaN cells stay
        NaN. Extra positional/keyword arguments are ignored.

        Returns:
            A 2-D tensor on ``self.device`` with dtype ``self.dtype``.
        """
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)
        # Move first so the NaN mask lives on the same device as the tensor it indexes.
        x = x.to(self.device, self.dtype)
        nan_mask = x.isnan()
        x = x.nan_to_num()
        num, cat = x[:, :self.numerical_dim], x[:, self.numerical_dim:]
        # With edges linspace(0, 1, bins + 1) and right=True, values in [0, 1) map to 1..bins.
        num = torch.bucketize(num.clip(0, 1 - 1e-6), self.bins, right=True).to(self.dtype)
        x = torch.cat([num, cat + 1], dim=1)
        x[nan_mask] = torch.nan
        return x

    def training_step(self, x):
        """Masked-token training step.

        Each row draws a threshold ``u ~ U(0, 1)``, and each of its cells is masked when an
        independent ``U(0, 1)`` draw exceeds it. Missing (NaN) cells are always fed as
        [MASKED] but never contribute to the loss.
        """
        if isinstance(x, (list, tuple)):  # DataLoader batches may wrap the tensor
            x = x[0]
        x = self.digitize(x)
        nan_mask = x.isnan()
        cell_draw = torch.rand(x.size(0), self.encoded_info.num_features, device=self.device)
        row_threshold = torch.rand(len(x), 1, device=self.device)
        mask1 = cell_draw > row_threshold
        mask = mask1 | nan_mask
        loss_mask = mask1 & ~nan_mask

        x_masked = x.clone()
        x_masked[mask] = 0.  # [MASKED] token

        pred = self.model(x_masked)

        loss = self.multiclass_loss(x, pred, loss_mask)
        self.log('t.loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def forward(self, x):
        """Return the underlying model's per-column logits for token matrix ``x``."""
        return self.model(x)

    @torch.no_grad()
    def sample_tokens(self, tokens, tau=1.0, mask_target_column=False, randperm=True,
                      num_average=1, **kwargs):
        """Fill the NaN cells of ``tokens`` column by column and decode the result.

        Args:
            tokens: token matrix from ``digitize``. NaN cells are generated; other cells are kept.
            tau: temperature (> 0) dividing the logits before sampling.
            mask_target_column: if True and a target is set, the target column is not sampled.
            randperm: visit columns in a random order (True) or in column order (False).
            num_average: number of independent passes. Numerical outputs are averaged and
                categorical outputs take the mode across passes.
            **kwargs: ignored.

        Returns:
            ``tabular_transform.inverse_transform`` applied to the decoded columns.

        Raises:
            ValueError: if ``tau`` is not positive.
        """
        if tau <= 0:
            raise ValueError(r"$\tau$ must be positive!")

        mask = tokens.isnan()
        observed = tokens.nan_to_num()  # missing cells start as the [MASKED] token 0
        num_features = self.encoded_info.num_features
        skip_target = mask_target_column and self.tgt is not None

        num_tokens = []
        cat_tokens = []
        passes = range(num_average) if num_average == 1 else tqdm(range(num_average))
        for _ in passes:
            # Each pass restarts from the observed cells. Samples drawn earlier in the pass stay
            # in ``current``, so later columns are conditioned on them. Re-masking them before
            # every forward call would make the column order irrelevant.
            current = observed.clone()
            order = torch.randperm(num_features).tolist() if randperm else range(num_features)
            for i in order:
                if skip_target and i == self.tgt_index:
                    continue
                col_mask = mask[:, i]
                if not col_mask.any():
                    continue  # nothing to fill in this column; skip the forward pass
                pred = self.model(current)
                sample = Categorical(logits=pred[i] / tau).sample().to(current.dtype) + 1
                current[col_mask, i] = sample[col_mask]

            num, cat = current[:, :self.numerical_dim], current[:, self.numerical_dim:]
            if num.shape[1] > 0:
                # token k -> left edge of bin k plus uniform jitter across the bin width
                num = self.bins[num.long() - 1] + torch.rand_like(num) / self.n_bins
            num_tokens.append(num)
            cat_tokens.append(cat - 1)
        num = torch.stack(num_tokens).mean(dim=0)
        cat = torch.stack(cat_tokens).mode(dim=0).values
        return self.tabular_transform.inverse_transform(num, cat)

    def _generate_uncond(self, n, tau=1.0, **kwargs) -> pd.DataFrame:
        """Generate ``n`` rows from an all-missing token matrix."""
        tokens = torch.full([n, self.encoded_info.num_features], torch.nan, device=self.device, dtype=self.dtype)
        return self.sample_tokens(tokens, tau, mask_target_column=False, randperm=True, **kwargs)

    def _predict(self, df: pd.DataFrame, tau=1e-4, **kwargs) -> np.ndarray:
        """Predict the target column by sampling with a very small temperature (near-greedy),
        filling missing columns in column order."""
        tokens = self.digitize(df)
        pred = self.sample_tokens(tokens, tau, randperm=False, **kwargs)
        return pred[self.tgt].values

    def _predict_proba(self, df: pd.DataFrame, **kwargs) -> np.ndarray:
        """Class probabilities for a categorical target; a numerical target falls back to
        ``_predict``.

        NOTE(review): the target cell is only masked if it is NaN in ``df``; this assumes the
        caller removes the target value beforehand.
        """
        if self.tgt in self.tabular_transform.numerical_columns:
            return self._predict(
                df,
                **kwargs
            )

        masked_tokens = self.digitize(df).nan_to_num()  # NaN -> 0 == [MASKED]
        with torch.no_grad():
            pred = self.model(masked_tokens)
        # NOTE(review): the last logit of the target head is dropped before the softmax;
        # confirm against Model's output layout.
        return F.softmax(pred[self.tgt_index][:, :-1], dim=1).cpu().numpy()

    def _impute(self, df: pd.DataFrame, tau=1.0, mask_target_column=True, **kwargs) -> pd.DataFrame:
        """Fill the missing cells of ``df`` in column order, aggregating 10 passes."""
        tokens = self.digitize(df)
        return self.sample_tokens(tokens, tau, mask_target_column=mask_target_column, randperm=False,
                                  num_average=10, **kwargs)

    def _generate_by_class(self, df: pd.DataFrame, tau=1.0, **kwargs) -> pd.DataFrame:
        """Generate the NaN cells of ``df`` conditioned on its observed cells (random order)."""
        tokens = self.digitize(df)
        return self.sample_tokens(tokens, tau, mask_target_column=False, randperm=True, **kwargs)


class ConfidenceMaCoDE(MaCoDE):
    """MaCoDE variant that fills missing cells following a step schedule.

    At every scheduled step the model runs on the whole token matrix, a sample is drawn for
    every column, and the samples are committed only for the cells selected by
    ``_sample_step``'s per-row ranking. ``num_iter`` sets how many steps are taken.
    """

    def _sample_step(self, tokens, last_confidence, given_mask, tau, i):
        """Run one fill step, writing into ``tokens`` in place.

        Args:
            tokens: token matrix in which 0 marks cells that are still [MASKED].
            last_confidence: per-cell ranking scores carried between steps.
            given_mask: True for cells that must not be sampled.
            tau: sampling temperature.
            i: schedule position. Still-masked, non-given cells whose row rank is among the
               top ``i + 1`` are filled.

        Returns:
            ``(tokens, last_confidence, logit, update)``. ``logit`` is the per-column model
            output and ``update`` marks the cells filled in this step.
        """
        logit = self.model(tokens)
        x = torch.stack([Categorical(logits=e / tau).sample().float() + 1 for e in logit], dim=1)
        # max raw logit per cell (not a normalized probability)
        confidence = torch.stack([e.max(dim=1).values for e in logit], dim=1)
        if i == 0:
            last_confidence = confidence  # alias: the writes below also modify ``confidence``
        # NOTE(review): every still-masked cell (tokens == 0) is forced to -inf here, and every
        # filled, non-given cell to +inf below. As a result ``confidence`` never affects which
        # masked cells are chosen; ties at -inf fall to torch.argsort's unstable ordering.
        # Confirm whether masked cells were meant to be ranked by ``confidence``.
        last_confidence[given_mask | (tokens == 0)] = -torch.inf
        last_confidence[~given_mask * (tokens != 0)] = torch.inf
        # rank 0 = lowest score in the row; higher rank = higher score
        rank = torch.argsort(torch.argsort(last_confidence, dim=1), dim=1)
        update = (rank >= (self.column_dim - i - 1)) * ~given_mask * (tokens == 0)
        tokens[update] = x[update]
        last_confidence[update] = confidence[update]
        return tokens, last_confidence, logit, update

    def _fill_schedule(self, num_iter):
        """Step indices fed to ``_sample_step``.

        ``num_iter == -1`` gives one step per column (``0..column_dim-1``). Otherwise the
        ``column_dim`` positions are split into ``num_iter`` near-equal chunks, with the larger
        chunks last, and the cumulative end index of each chunk is returned. The last index is
        always ``column_dim - 1``, and that step fills every remaining cell.

        Raises:
            ValueError: if ``num_iter`` is neither -1 nor a positive integer.
        """
        if num_iter == -1:
            return range(self.column_dim)
        if num_iter < 1:
            raise ValueError("num_iter must be -1 or a positive integer.")
        q, r = divmod(self.column_dim, num_iter)
        parts = q + (np.arange(num_iter) >= (num_iter - r)).astype(int)
        return np.cumsum(parts) - 1

    @torch.no_grad()
    def sample_tokens(self, tokens, tau=1.0, mask_target_column=False, num_iter=-1, num_average=1,
                      random=False, **kwargs):
        """Fill the NaN cells of ``tokens`` with the step schedule and decode the result.

        Args:
            tokens: token matrix from ``digitize``. NaN cells are generated.
            tau: temperature (> 0) dividing the logits before sampling.
            mask_target_column: if True and a target is set, the target column is treated as
                given and is never sampled.
            num_iter: number of schedule steps (-1 = one step per column).
            num_average: number of independent passes. Numerical outputs are averaged and
                categorical outputs take the mode across passes.
            random: delegate to ``MaCoDE.sample_tokens`` (column-by-column sampling).
            **kwargs: forwarded to ``MaCoDE.sample_tokens`` when ``random`` is True, otherwise
                ignored.

        Raises:
            ValueError: if ``tau`` is not positive or ``num_iter`` is invalid.
        """
        if random:
            return super().sample_tokens(tokens, tau, mask_target_column,
                                         num_average=num_average, **kwargs)
        if tau <= 0:
            raise ValueError(r"$\tau$ must be positive!")

        given_mask = ~tokens.isnan()
        if mask_target_column and self.tgt is not None:
            given_mask[:, self.tgt_index] = True
        observed = tokens.nan_to_num()  # missing cells start as the [MASKED] token 0
        schedule = self._fill_schedule(num_iter)

        num_tokens = []
        cat_tokens = []
        passes = range(num_average) if num_average == 1 else tqdm(range(num_average))
        for _ in passes:
            # Restart every pass from the observed cells. Filling in place would leave later
            # passes nothing to fill, so they would just repeat the first draw.
            current = observed.clone()
            last_confidence = torch.full_like(current, -torch.inf)
            for i in schedule:
                current, last_confidence, _, _ = self._sample_step(current, last_confidence, given_mask, tau, i)

            num, cat = current[:, :self.numerical_dim], current[:, self.numerical_dim:]
            if num.shape[1] > 0:
                # token k -> left edge of bin k plus uniform jitter across the bin width
                num = self.bins[num.long() - 1] + torch.rand_like(num) / self.n_bins
            num_tokens.append(num)
            cat_tokens.append(cat - 1)
        num = torch.stack(num_tokens).mean(dim=0)
        cat = torch.stack(cat_tokens).mode(dim=0).values
        return self.tabular_transform.inverse_transform(num, cat)


class GRPOMaCoDE(ConfidenceMaCoDE):
    """Policy-gradient fine-tuning of a loaded checkpoint with group-relative advantages.

    Each training row is randomly masked and completed ``n_rollouts`` times by ``_action``.
    Every completion is scored by ``_reward`` against the unmasked row. Rewards are then
    standardized within each row's group of rollouts, and the result weights that rollout's
    mean log-likelihood.
    """

    def __init__(self, ckpt, lr=3e-4, n_rollouts=4, **kwargs):
        """
        Args:
            ckpt: passed to ``setup_model_from_checkpoint``. The loaded model replaces
                ``self.model`` and supplies the data transform.
            lr: AdamW learning rate.
            n_rollouts: rollouts per training row (``K``). It must be >= 2 because the
                advantage divides by the per-row standard deviation over rollouts.
            **kwargs: passed through ``get_kwargs`` to ``ConfidenceMaCoDE``.

        Raises:
            ValueError: if ``n_rollouts < 2``.
        """
        if n_rollouts < 2:
            raise ValueError("n_rollouts must be at least 2 for group-relative advantages.")
        super().__init__(**get_kwargs(**kwargs))

        self.lr = lr
        # NOTE(review): ``self.model`` is replaced by the checkpoint's model, but ``self.bins``
        # (used by ``digitize``) and ``self.encoded_info`` still come from this instance's own
        # constructor arguments. Confirm they match the checkpoint.
        self.model, _ = setup_model_from_checkpoint(ckpt)
        self.K = n_rollouts
        self._update_transform(self.model)

    def _update_transform(self, model):
        """Adopt ``model``'s transform and its 'onehot'/'scaler' flags, then refresh the schema."""
        self._transform = model._transform
        self.model_flags['onehot'] = model.model_flags['onehot']
        self.model_flags['scaler'] = model.model_flags['scaler']
        self._refresh_schema()

    def configure_optimizers(self):
        """AdamW over the loaded model's parameters (weight decay 1e-2)."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-2)

    @property
    def n_bins(self):
        """Number of numerical bins, delegated to the loaded model."""
        return self.model.n_bins

    def _action(self, tokens, tau=1.0, num_iter=-1, **kwargs):
        """Roll out the fill schedule on ``tokens`` with gradients enabled.

        Returns:
            ``(tokens, trajectory)``. ``tokens`` is the completed token matrix. For each step,
            ``trajectory['logit']`` holds the model logits and ``trajectory['target']`` holds
            the 0-based class of every cell filled at that step, with -1 elsewhere (matched by
            ``ignore_index=-1`` in ``training_step``).

        Raises:
            ValueError: if ``tau`` is not positive or ``num_iter`` is invalid.
        """
        if tau <= 0:
            raise ValueError(r"$\tau$ must be positive!")

        given_mask = ~tokens.isnan()
        tokens = tokens.nan_to_num()

        if num_iter == -1:
            seq = range(self.column_dim)
        elif num_iter < 1:
            raise ValueError("num_iter must be -1 or a positive integer.")
        else:
            q, r = divmod(self.column_dim, num_iter)
            idx = np.arange(num_iter)
            parts = q + (idx >= (num_iter - r)).astype(int)
            seq = np.cumsum(parts) - 1

        last_confidence = torch.full_like(tokens, -torch.inf)
        trajectory = dict(
            target=[],
            logit=[],
        )

        for i in seq:
            tokens, last_confidence, logit, update = self._sample_step(tokens, last_confidence, given_mask, tau, i)
            updated_tokens = tokens.clone()
            updated_tokens[~update] = 0  # becomes -1 below: ignored by the loss
            trajectory['target'].append(updated_tokens.long() - 1)
            trajectory['logit'].append(logit)

        return tokens, trajectory

    def _reward(self, gt, action, mask_rep):
        """Per-row mean score over the cells selected by ``mask_rep``.

        Numerical cells score ``1 - |token difference| / 10`` clipped to [0, 1] (soft).
        Categorical cells score 1 on an exact match and 0 otherwise (hard). Rows with no
        selected cell score 0.
        """
        gt_num, gt_cat = gt[:, :self.numerical_dim], gt[:, self.numerical_dim:]
        pred_num, pred_cat = action[:, :self.numerical_dim], action[:, self.numerical_dim:]
        num_reward = torch.clip(1 - torch.abs(gt_num - pred_num) / 10, 0, 1)
        cat_reward = (gt_cat == pred_cat).float()
        reward = torch.cat([num_reward, cat_reward], dim=1)
        reward[~mask_rep] = torch.nan
        return reward.nanmean(dim=1).nan_to_num()

    def training_step(self, x):
        """One GRPO-style update: mask, roll out ``K`` completions, score, and weight log-probs."""
        if not isinstance(x, torch.Tensor):
            x = x[0]
        bsz = len(x)
        gt = self.digitize(x)
        nan_mask = gt.isnan()
        # Per-row masking rate, drawn the same way as MaCoDE.training_step.
        cell_draw = torch.rand(x.size(0), self.encoded_info.num_features, device=self.device)
        row_threshold = torch.rand(len(x), 1, device=self.device)
        mask1 = cell_draw > row_threshold
        mask = mask1 | nan_mask
        loss_mask = mask1 & ~nan_mask

        state = gt.clone()
        state[mask] = torch.nan

        # K copies stacked along dim 0: row k * bsz + b is rollout k of sample b.
        gt_rep = gt.tile(self.K, 1)
        state_rep = state.tile(self.K, 1)
        mask_rep = loss_mask.tile(self.K, 1)

        action, trajectory = self._action(state_rep)
        reward = self._reward(gt_rep, action, mask_rep).view(self.K, -1).T  # (bsz, K)
        advantage = (reward - reward.mean(dim=1, keepdim=True)) / (reward.std(dim=1, keepdim=True) + 1e-8)

        step_log_probs = []
        for logit, target in zip(trajectory['logit'], trajectory['target']):
            per_column = [
                -F.cross_entropy(logit_column, target[:, j], ignore_index=-1, reduction='none')
                for j, logit_column in enumerate(logit)
            ]
            step_log_probs.append(torch.stack(per_column, dim=-1))
        log_probs = torch.concat(step_log_probs, dim=-1).reshape(self.K, bsz, -1).mean(dim=-1).T

        obj = log_probs * advantage
        loss = -obj.mean()
        self.log('t.loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True)
        self.log('t.reward.mean', reward.mean().item(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('t.reward.std', reward.std().item(), on_step=False, on_epoch=True, prog_bar=True)
        return loss
