import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from tqdm import tqdm

from impugen.models.metrics import *
from .diffusion_utils import sample
from .model import Model, MLPDiffusion
from .vae.model import Model_VAE
from ...base import *
from ...callbacks import ReduceHParamsOnPlateau


class TabSynVAE(BaseUnconditionalGeneratorMixIn, Base):

    """
    https://github.com/amazon-science/tabsyn
    Zhang, H., Zhang, J., Shen, Z., Srinivasan, B., Qin, X., Faloutsos, C., ... & Karypis, G.
    Mixed-Type Tabular Data Synthesis with Score-based Diffusion in Latent Space.
    In The Twelfth International Conference on Learning Representations.

    VAE stage of TabSyn. Trains ``Model_VAE`` on the transformed table (numerical columns first,
    categorical codes after; NaN marks a missing cell) and exposes ``encode`` / ``decode`` to move
    between rows and the normalized token latent ``(embedding - latent_mean) / 2`` consumed by the
    diffusion stage. ``latent_mean`` is a buffer of shape ``(1, column_dim, d_model)`` holding the
    mean raw embedding of the training CSV; it is computed in ``on_save_checkpoint``.
    """

    def __init__(self, num_layers, d_token, n_head, factor, bias, lr, max_beta, min_beta, beta_decay,
                 **kwargs):
        super().__init__()
        self.d_model = d_token
        self.model = Model_VAE(num_layers, self.numerical_dim, self.n_categories_per_columns, d_token, n_head=n_head,
                               factor=factor, bias=bias)
        # Zeros on the live module until a checkpoint written by on_save_checkpoint is loaded
        # (e.g. by on_train_end).
        self.register_buffer('latent_mean', torch.zeros(1, self.column_dim, self.d_model))
        self.lr = lr
        # KL weight; starts at max_beta and is presumably lowered towards min_beta by the
        # ReduceHParamsOnPlateau callback configured in configure_callbacks.
        self.beta = max_beta
        self.min_beta = min_beta
        self.beta_decay = beta_decay

    def _embed(self, x):
        """Return the raw (un-normalized) column-token embeddings of ``x`` with the cls token dropped.

        ``x`` is either a DataFrame, which is run through ``tabular_transform`` first, or an
        already transformed tensor with numerical columns first and categorical codes after.
        NaN cells are zero-filled and reported to the model through ``padding_mask``.
        """
        if isinstance(x, pd.DataFrame):
            x = torch.Tensor(np.concatenate(self.tabular_transform.transform(x), axis=1))
        x = x.to(self.device)
        num, cat = x[:, :self.numerical_dim], x[:, self.numerical_dim:]
        mask = x.isnan()
        return self.model.get_embedding(num.nan_to_num(),
                                        cat.nan_to_num().to(torch.int32),
                                        padding_mask=mask)[:, 1:]  # drop cls token

    def encode(self, x):
        """Embed ``x`` and normalize it as ``(embedding - latent_mean) / 2`` (follows original code)."""
        return (self._embed(x) - self.latent_mean) / 2

    def decode(self, z):
        """Invert the ``encode`` normalization and reconstruct ``(num, cat)`` outputs from latent ``z``.

        ``z`` may be flat; it is reshaped to ``(len(z), -1, d_model)`` tokens before decoding.
        """
        z = z.view(len(z), -1, self.d_model).to(self.device, self.dtype)
        z = z * 2 + self.latent_mean  # follows original code
        h = self.model.VAE.decoder(z)
        num, cat = self.model.Reconstructor(h)
        return num, cat

    def forward(self, x, reduction='none'):
        """Run the VAE on a batch and return ``(loss_ce, loss_mse, loss_kld, acc)``.

        ``x`` is a tensor or a tuple/list whose first element is the tensor. Rows whose cells are
        all NaN are dropped. Each loss only counts observed cells (masks built from NaN
        positions). ``reduction`` is forwarded to the metric helpers. Without categorical columns
        ``loss_ce`` is a zero tensor and ``acc`` a NaN tensor; without numerical columns
        ``loss_mse`` is a zero tensor.
        """
        if not isinstance(x, torch.Tensor):
            x = x[0]
        x = x.to(self.device, self.dtype)
        all_masked = x.isnan().all(dim=-1)
        x = x[~all_masked]
        num, cat = x[:, :self.numerical_dim], x[:, self.numerical_dim:]
        mask = x.isnan()
        cat_mask = ~cat.isnan()
        cat = cat.nan_to_num().to(torch.int32)
        num_mask = ~num.isnan()
        num = num.nan_to_num()
        x = x.nan_to_num()

        num, cat, mu_z, std_z = self.model(num, cat, padding_mask=mask)

        target_cat = x[:, self.numerical_dim:].nan_to_num().to(torch.int64)
        target_cat = [F.one_hot(split.squeeze(1), n_vocab).to(self.device, self.dtype) for split, n_vocab in
                      zip(torch.split(target_cat, 1, dim=1), self.n_categories_per_columns)]
        target_num = x[:, :self.numerical_dim]
        loss_ce = cross_entropy(cat, target_cat, cat_mask, reduction=reduction) if self.categorical_dim else (
            torch.zeros(1, device=self.device, dtype=self.dtype))
        loss_mse = mse(num, target_num, num_mask, reduction=reduction) if self.numerical_dim else (
            torch.zeros(1, device=self.device, dtype=self.dtype))
        # The leading all-ones column keeps the cls token in the KL term.
        kld_mask = torch.cat([torch.ones_like(mask[:, :1]), ~mask], dim=1)
        loss_kld = kld(mu_z, std_z, kld_mask, reduction=reduction)
        acc = accuracy(cat, target_cat, cat_mask, reduction=reduction) if self.categorical_dim else (
            torch.full((1,), torch.nan, device=self.device, dtype=self.dtype))
        return loss_ce, loss_mse, loss_kld, acc

    def training_step(self, batch, batch_idx):
        """Optimize ``ce + mse + beta * kld`` and log each term, ``beta`` and accuracy per epoch."""
        loss_ce, loss_mse, loss_kld, acc = self(batch, reduction='mean')
        loss = loss_ce + loss_mse + loss_kld * self.beta
        with torch.no_grad():
            self.log('beta', self.beta, on_step=False, on_epoch=True, prog_bar=True)
            self.log('t.ce', loss_ce.mean(), on_step=False, on_epoch=True, prog_bar=True)
            self.log('t.mse', loss_mse.mean(), on_step=False, on_epoch=True, prog_bar=True)
            self.log('t.kld', loss_kld.mean(), on_step=False, on_epoch=True, prog_bar=True)
            self.log('t.acc', acc.mean(), on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; ``v.loss`` is reconstruction only (``ce + mse``, no KL term)."""
        self.eval()
        loss_ce, loss_mse, loss_kld, acc = self(batch, reduction='mean')
        loss = loss_ce + loss_mse
        self.log('v.ce', loss_ce.mean(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('v.mse', loss_mse.mean(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('v.kld', loss_kld.mean(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('v.acc', acc.mean(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('v.loss', loss.mean(), on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau(factor=0.95, patience=10) stepped per epoch on ``v.loss``."""
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=0)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.95,
                                                                       patience=10)
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": 'v.loss'
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """LR monitor, best-``v.loss`` weights checkpoint, and a plateau-driven ``beta`` schedule."""
        callbacks = [
            LearningRateMonitor(logging_interval='epoch'),
            ModelCheckpoint(
                save_top_k=1, monitor='v.loss', mode='min',
                filename="%s.best" % self.name,
                save_weights_only=True
            ),
            ReduceHParamsOnPlateau(
                param_name='beta',
                monitor='v.loss',
                decay=self.beta_decay,
                min_value=self.min_beta,
                patience=10
            )
        ]
        return callbacks

    def on_save_checkpoint(self, checkpoint) -> None:
        """Write the mean raw embedding of the training CSV into the checkpoint as ``latent_mean``.

        The mean must be taken over *raw* embeddings (``_embed``), not over ``encode`` outputs.
        ``encode`` already subtracts the current ``latent_mean`` and halves the result, so
        averaging it would store half the true mean (and drift whenever ``latent_mean`` is
        non-zero). The pass over the whole file runs without autograd and in eval mode. The
        previous train/eval mode is restored afterwards.
        """
        df = pd.read_csv(self._cfg.dataset.train_path)
        frame = df[self.tabular_transform.columns]
        batch_size = self._cfg.model.batch_size
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                embeddings = [self._embed(frame.iloc[start:start + batch_size])
                              for start in range(0, len(frame), batch_size)]
        finally:
            self.train(was_training)
        checkpoint['state_dict']['latent_mean'] = torch.cat(embeddings).mean(dim=0, keepdim=True)

    def on_train_end(self) -> None:
        """Reload the best checkpoint's weights (including the ``latent_mean`` it stored)."""
        self.load_state_dict(torch.load(self.trainer.checkpoint_callback.best_model_path, weights_only=False)['state_dict'])

    @torch.no_grad()
    def _generate_uncond(self, n, **kwargs):
        """Sample ``n`` rows by decoding standard-normal latents directly with the VAE decoder.

        This bypasses ``decode`` and therefore the ``latent_mean`` shift and scaling.
        """
        # NOTE(review): ``decode`` feeds the decoder one token per column (``encode`` drops the
        # cls token), whereas this uses ``Tokenizer.n_tokens`` tokens. Confirm that count excludes
        # the cls token; otherwise the reconstructed columns may be shifted by one.
        z = torch.randn([n, self.model.VAE.Tokenizer.n_tokens, self.model.VAE.hid_dim], device=self.device,
                        dtype=self.dtype)
        h = self.model.VAE.decoder(z)
        num, cat = self.model.Reconstructor(h)
        cat = torch.stack([e.argmax(dim=-1) for e in cat], dim=1) if self.categorical_dim else (
            torch.zeros(n, 0, device=self.device, dtype=torch.int32))
        return self.tabular_transform.inverse_transform(num, cat)


class TabSyn(
    AutoEncoderManagerMixin,
    BaseImputerMixIn,
    BaseUnconditionalGeneratorMixIn,
    BaseArbitraryConditionalGeneratorMixIn,
    BaseImbalanceMixin,
    Base
):

    """
    https://github.com/amazon-science/tabsyn
    Zhang, H., Zhang, J., Shen, Z., Srinivasan, B., Qin, X., Faloutsos, C., ... & Karypis, G.
    Mixed-Type Tabular Data Synthesis with Score-based Diffusion in Latent Space.
    In The Twelfth International Conference on Learning Representations.

    Diffusion stage of TabSyn. A denoiser ``Model(MLPDiffusion(...))`` is trained on the
    flattened ``column_dim * token_embed_dim`` latent produced by ``self.ae.encode``. Samples are
    mapped back to rows with ``self.ae.decode``. ``self.ae`` is not created in this class; it is
    presumably supplied by ``AutoEncoderManagerMixin`` (confirm).
    """

    def __init__(self, lr, d_model, batch_mul=1, auto_batch_mul=4, **kwargs):
        super().__init__()

        # Width of one column token in the autoencoder latent.
        self.token_embed_dim = self.ae.d_model

        # The denoiser works on all column tokens flattened into a single vector.
        self.in_dim = self.column_dim * self.token_embed_dim
        self.model = Model(MLPDiffusion(self.in_dim, d_model), hid_dim=self.in_dim)
        self.lr = lr
        # Per-column marginals used to pre-fill missing cells in _impute. Populated by fit() and
        # persisted through on_save_checkpoint / on_load_checkpoint.
        self.distribution_info = {}
        # forward() tiles each encoded batch this many times; fit() switches to auto_batch_mul
        # for small training sets.
        self.batch_mul = batch_mul
        self.auto_batch_mul = auto_batch_mul

    @torch.no_grad()
    def _generate_uncond(self, bsz, **kwargs):
        """Sample ``bsz`` latents (50 diffusion steps) and decode them back into table rows."""
        z = sample(self.model.denoise_fn_D, bsz, self.in_dim, device=self.device, num_steps=50)
        z = z.view(bsz, -1, self.token_embed_dim)
        num, cat = self.ae.decode(z)
        if not self.model_flags['onehot']:
            # Per-column category scores -> category indices (empty tensor if no categoricals).
            cat = torch.stack([e.argmax(dim=-1) for e in cat], dim=1) if self.categorical_dim else (
                torch.zeros(bsz, 0, device=self.device, dtype=torch.int32))
        return self.tabular_transform.inverse_transform(num, cat)

    @torch.no_grad()
    def _impute(self, df: pd.DataFrame, n=20, num_steps=50, sigma_min=0.002, sigma_max=80, rho=7, num_average=5,
                progress=False, **kwargs) -> pd.DataFrame:
        """Fill the NaN cells of ``df`` by RePaint-style inpainting in the VAE latent space.

        Latent tokens of observed columns are re-noised from the encoded input at each noise
        level. Tokens of missing columns are produced by a stochastic Heun step of the denoiser.
        Each transition is resampled ``n`` times. The whole procedure runs ``num_average`` times;
        categorical results are combined by mode and numerical ones by median. Only cells that
        were NaN in ``df`` are overwritten, and the original index is restored.

        Args:
            df: rows to complete; NaN marks a missing cell.
            n: resampling iterations per noise level.
            num_steps: length of the Karras-style sigma schedule.
            sigma_min, sigma_max: schedule bounds, clamped to the denoiser's own range.
            rho: schedule curvature exponent.
            num_average: number of independent imputations that are aggregated.
            progress: show a tqdm bar over the ``num_average`` runs.
        """
        index = df.index
        df = df.reset_index(drop=True)
        num_samples = len(df)
        df_mask = df.isna()
        # The inpainting mask must come from the ORIGINAL missingness. Computing it after the
        # placeholder fill below (as before) gave an all-False mask, so nothing was regenerated
        # and the placeholder values were returned after a VAE round trip.
        latent_mask = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device).isnan()
        # Placeholders from the training marginals so the encoder sees complete rows; the latent
        # tokens of these cells are replaced by the sampler.
        syn = self.generate_samples_from_distribution(num_samples)
        df[df_mask] = syn[df_mask]

        # Broadcast the per-column mask over every embedding dimension of that column's token,
        # then flatten to the (num_samples, column_dim * token_embed_dim) latent layout.
        mask = torch.zeros(num_samples, self.column_dim, self.token_embed_dim,
                           device=self.device, dtype=torch.bool)
        mask[latent_mask] = True
        mask = mask.view(num_samples, -1).to(torch.int)

        # Loop-invariant noise schedule (identical for every averaging run).
        net = self.model.denoise_fn_D
        sigma_min = max(sigma_min, net.sigma_min)
        sigma_max = min(sigma_max, net.sigma_max)
        step_indices = torch.arange(num_steps, dtype=torch.float32, device=self.device)
        t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

        def step(net, num_steps, i, t_cur, t_next, x_next):
            """One stochastic Heun step of the denoiser from ``t_cur`` to ``t_next``."""
            x_cur = x_next
            # Increase noise temporarily.
            gamma = min(1 / num_steps, np.sqrt(2) - 1)
            t_hat = net.round_sigma(t_cur + gamma * t_cur)
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * torch.randn_like(x_cur)
            # Euler step.
            denoised = net(x_hat, t_hat).to(torch.float32)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur
            # Apply 2nd order correction.
            if i < num_steps - 1:
                denoised = net(x_next, t_next).to(torch.float32)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
            return x_next

        indices = list(range(num_average))
        if progress:
            indices = tqdm(indices)

        cat_results = []
        num_results = []
        for _ in indices:
            x = self.ae.encode(df).view(num_samples, -1)
            x_t = torch.randn([num_samples, x.shape[1]], device=self.device)
            x_t = x_t.to(torch.float32) * t_steps[0]
            # Only the first num_steps - 1 transitions are run; the final jump to sigma = 0 is
            # skipped, as in the original loop.
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
                for j in range(n):
                    x_known_t_prev = x + torch.randn_like(x) * t_next
                    x_unknown_t_prev = step(net, num_steps, i, t_cur, t_next, x_t)
                    x_t = x_known_t_prev * (1 - mask) + x_unknown_t_prev * mask  # now x_{t-1}
                    if j < n - 1:
                        # RePaint resampling: diffuse back from t_next to t_cur and repeat.
                        x_t = x_t + torch.randn_like(x) * (t_cur.pow(2) - t_next.pow(2)).sqrt()
            z = x_t.view(num_samples, -1, self.token_embed_dim)
            num, cat = self.ae.decode(z)
            if not self.model_flags['onehot']:
                cat = torch.stack([e.argmax(dim=-1) for e in cat], dim=1) if self.categorical_dim else (
                    torch.zeros(num_samples, 0, device=self.device, dtype=torch.int32))
            cat_results.append(cat)
            num_results.append(num)
        cat = torch.stack(cat_results).mode(dim=0).values.cpu().numpy()
        num = torch.stack(num_results).median(dim=0).values.cpu().numpy()
        imputed = self.tabular_transform.inverse_transform(num, cat)
        df[df_mask] = imputed[df_mask]
        df.index = index
        return df

    def _generate_by_class(self, y, **kwargs) -> pd.DataFrame:
        """Complete the partially specified rows ``y`` with a single imputation run."""
        return self._impute(y, num_average=1, **kwargs)

    def _generate_by_condition(self, y, **kwargs) -> pd.DataFrame:
        """Complete the partially specified rows ``y`` with a single imputation run."""
        return self._impute(y, num_average=1, **kwargs)

    def forward(self, x):
        """Return the mean diffusion loss on the (frozen) flattened latent of batch ``x``."""
        if not isinstance(x, torch.Tensor):
            x = x[0]
        with torch.no_grad():
            x = self.ae.encode(x).view(len(x), -1)
        x = torch.tile(x, [self.batch_mul, 1])
        # NOTE(review): assumes self.model returns an element-wise loss shaped like x — confirm.
        return self.model(x)[~x.isnan()].mean()

    def training_step(self, x):
        """Compute and log the diffusion loss (``t.loss_step`` / ``t.loss_epoch``)."""
        loss = self(x)
        self.log('t.loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau(factor=0.9, patience=20) on the epoch training loss."""
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=0)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.9,
                                                                       patience=20)
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": 't.loss_epoch'
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """LR monitor, early stopping (patience 500) and best-``t.loss_epoch`` weights checkpoint."""
        callbacks = [LearningRateMonitor(logging_interval='epoch'),
                     EarlyStopping(monitor='t.loss_epoch', patience=500),
                     ModelCheckpoint(
                         save_top_k=1, monitor='t.loss_epoch', mode='min',
                         filename="%s.best" % self.name,
                         save_weights_only=True
                     )]
        return callbacks

    def on_train_end(self) -> None:
        """Reload the weights of the best checkpoint."""
        self.load_state_dict(torch.load(self.trainer.checkpoint_callback.best_model_path, weights_only=False)['state_dict'])

    def generate_samples_from_distribution(self, n_samples):
        """Draw ``n_samples`` placeholder rows from the stored per-column marginals.

        Categorical columns (dict info) are sampled from their category frequencies; numerical
        columns are filled with their stored mean.
        """
        data = {}
        for col, info in self.distribution_info.items():
            if isinstance(info, dict):
                categories = list(info.keys())
                probabilities = list(info.values())
                data[col] = np.random.choice(categories, size=n_samples, p=probabilities)
            else:
                data[col] = np.full(n_samples, info)

        return pd.DataFrame(data)

    def on_save_checkpoint(self, checkpoint):
        """Persist the marginals used for imputation placeholders."""
        checkpoint['distribution_info'] = self.distribution_info

    def on_load_checkpoint(self, checkpoint):
        """Restore the marginals written by ``on_save_checkpoint``."""
        self.distribution_info = checkpoint['distribution_info']

    def fit(self, scenario=lambda x: x):
        """Record per-column marginals of the (scenario-transformed) training CSV, then train.

        Training sets with fewer than 2048 rows switch ``batch_mul`` to ``auto_batch_mul``.
        Categorical marginals are normalized value frequencies; numerical ones are the mean of
        the non-missing values.
        """
        cfg = self._cfg
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        if len(train_df) < 2048:
            self._cfg.model.batch_mul = self.auto_batch_mul
            self.batch_mul = self.auto_batch_mul
        distribution_info = {}
        for col in self.tabular_transform.categorical_columns:
            distribution_info[col] = train_df[col].value_counts(normalize=True).to_dict()
        for col in self.tabular_transform.numerical_columns:
            distribution_info[col] = train_df[col].dropna().mean()
        self.distribution_info = distribution_info
        return super().fit(scenario)
