import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm

# from experiments.models.cdtd.utils import LinearScheduler, cycle, low_discrepancy_sampler, set_seeds
from .layers import CatEmbedding, CondMLP, Timewarp_Logistic, WeightNetwork
from .utils import (
    LinearScheduler,
    cycle,
    low_discrepancy_sampler,
    set_seeds,
)


class CondMixedTypeDiffusion(nn.Module):
    """Conditional diffusion wrapper for mixed categorical/continuous features.

    Categorical features are embedded with ``CatEmbedding`` and noised in
    embedding space; continuous features are noised directly. Per-feature
    noise levels are obtained by inverting a ``Timewarp_Logistic`` CDF, whose
    own loss term is added to the training loss. The wrapped ``model`` is
    called as ``model(x_cat_emb, x_cont, cond, c_noise)`` and must return
    ``(cat_logits, cont_preds)``.

    Throughout the class, feature-wise tensors such as ``sigma`` are laid out
    with the categorical features first, followed by the continuous ones.
    """

    def __init__(
        self,
        model,
        dim,
        categories,
        proportions,
        num_features,
        sigma_data_cat,
        sigma_data_cont,
        sigma_min_cat,
        sigma_max_cat,
        sigma_min_cont,
        sigma_max_cont,
        cat_emb_init_sigma,
        timewarp_type="bytype",
        timewarp_weight_low_noise=1.0,
        rel_cat_weight=1.0,
    ):
        """Build the embedding, loss-calibration and noise-schedule components.

        Args:
            model: Score network invoked in :meth:`precondition`.
            dim: Embedding dimension forwarded to ``CatEmbedding``.
            categories: Sequence with one entry per categorical feature
                (forwarded to ``CatEmbedding``); its length defines the
                number of categorical features.
            proportions: Optional iterable with one class-probability tensor
                per categorical feature. Their entropies calibrate the
                cross-entropy losses; ``None`` disables the calibration.
            num_features: Total number of features (categorical + continuous).
            sigma_data_cat: Data scale used when preconditioning embeddings.
            sigma_data_cont: Data scale used when preconditioning and
                weighting continuous features.
            sigma_min_cat: Lower noise bound for categorical features.
            sigma_max_cat: Upper noise bound for categorical features.
            sigma_min_cont: Lower noise bound for continuous features.
            sigma_max_cont: Upper noise bound for continuous features.
            cat_emb_init_sigma: Forwarded to ``CatEmbedding``.
            timewarp_type: Forwarded to ``Timewarp_Logistic``.
            timewarp_weight_low_noise: Forwarded to ``Timewarp_Logistic`` as
                ``weight_low_noise``.
            rel_cat_weight: Multiplier applied to the cross-entropy losses in
                the weighted training loss.
        """
        super().__init__()

        self.dim = dim
        self.num_features = num_features
        self.num_cat_features = len(categories)
        self.cat_exist = self.num_cat_features > 0
        self.num_cont_features = num_features - self.num_cat_features
        self.cont_exist = self.num_cont_features > 0
        self.categories = categories
        self.model = model

        self.cat_emb = CatEmbedding(dim, categories, cat_emb_init_sigma, bias=True)
        self.register_buffer("sigma_data_cat", torch.tensor(sigma_data_cat))
        self.register_buffer("sigma_data_cont", torch.tensor(sigma_data_cont))

        # per-feature constants every loss column is divided by
        self.register_buffer(
            "normal_const",
            self._calibration_constants(proportions, self.num_cat_features, self.num_cont_features),
        )
        self.rel_cat_weight = rel_cat_weight

        self.weight_network = WeightNetwork(1024)

        # timewarping
        self.timewarp_type = timewarp_type
        self.sigma_min_cat = torch.tensor(sigma_min_cat)
        self.sigma_max_cat = torch.tensor(sigma_max_cat)
        self.sigma_min_cont = torch.tensor(sigma_min_cont)
        self.sigma_max_cont = torch.tensor(sigma_max_cont)

        # combine sigma boundaries for transforming sigmas to [0,1]
        sigma_min = torch.cat(
            (
                torch.tensor(sigma_min_cat).repeat(self.num_cat_features),
                torch.tensor(sigma_min_cont).repeat(self.num_cont_features),
            ),
            dim=0,
        )
        sigma_max = torch.cat(
            (
                torch.tensor(sigma_max_cat).repeat(self.num_cat_features),
                torch.tensor(sigma_max_cont).repeat(self.num_cont_features),
            ),
            dim=0,
        )
        self.register_buffer("sigma_max", sigma_max)
        self.register_buffer("sigma_min", sigma_min)

        self.timewarp_cdf = Timewarp_Logistic(
            self.timewarp_type,
            self.num_cat_features,
            self.num_cont_features,
            sigma_min,
            sigma_max,
            weight_low_noise=timewarp_weight_low_noise,
            decay=0.0,
        )

    @staticmethod
    def _calibration_constants(proportions, num_cat_features, num_cont_features):
        """Return the 1-D tensor of per-feature loss normalisation constants.

        Categorical features use the entropy of their class proportions when
        ``proportions`` is given and 1 otherwise; continuous features always
        use 1. The result is ordered categorical first, continuous second.
        """
        parts = []
        if num_cat_features > 0:
            if proportions is not None:
                # entr(p) = -p * log(p) with entr(0) = 0, so zero-probability
                # classes contribute nothing instead of producing NaN
                entropy = torch.tensor([torch.special.entr(p).sum() for p in proportions])
                # A single-class feature has zero entropy; dividing its
                # (zero) loss by it would give NaN and poison the training
                # loss. Fall back to the uncalibrated constant in that case
                # (this also catches non-finite entropies).
                entropy = torch.where(entropy > 0, entropy, torch.ones_like(entropy))
                parts.append(entropy)
            else:
                parts.append(torch.ones((num_cat_features,)))
        if num_cont_features > 0:
            parts.append(torch.ones((num_cont_features,)))
        return torch.cat(parts)

    @property
    def device(self):
        """Device of the first parameter of the wrapped score model."""
        return next(self.model.parameters()).device

    def diffusion_loss(self, x_cat_0, x_cont_0, cat_logits, cont_preds):
        """Compute element-wise (unreduced) losses for both feature types.

        Args:
            x_cat_0: Ground-truth class indices, indexed as ``x_cat_0[:, i]``
                for categorical feature ``i``.
            x_cont_0: Ground-truth continuous features.
            cat_logits: Sequence with one logits tensor per categorical feature.
            cont_preds: Predictions with the same shape as ``x_cont_0``.

        Returns:
            Tuple ``(ce_losses, mse_losses)``: per-sample cross-entropy with
            one column per categorical feature, and squared errors shaped like
            ``x_cont_0``. An entry is ``None`` if that feature type is absent.
        """
        if self.cat_exist:
            assert len(cat_logits) == self.num_cat_features

            # cross entropy over categorical features for each individual
            ce_losses = torch.stack(
                [F.cross_entropy(cat_logits[i], x_cat_0[:, i], reduction="none") for i in range(self.num_cat_features)],
                dim=1,
            )
        else:
            ce_losses = None

        if self.cont_exist:
            assert cont_preds.shape == x_cont_0.shape

            # MSE loss over numerical features
            mse_losses = (cont_preds - x_cont_0) ** 2
        else:
            mse_losses = None

        return ce_losses, mse_losses

    def add_noise(self, x_cat_emb_0, x_cont_0, sigma):
        """Add Gaussian noise scaled by the per-feature ``sigma``.

        Args:
            x_cat_emb_0: Clean categorical embeddings; the noise level of each
                feature is broadcast over the trailing embedding axis.
            x_cont_0: Clean continuous features.
            sigma: Noise levels with categorical columns first.

        Returns:
            Tuple ``(x_cat_emb_t, x_cont_t)`` of noised inputs; an entry is
            ``None`` if that feature type is absent.
        """
        if self.cat_exist:
            sigma_cat = sigma[:, : self.num_cat_features]
            x_cat_emb_t = x_cat_emb_0 + torch.randn_like(x_cat_emb_0) * sigma_cat.unsqueeze(2)
        else:
            x_cat_emb_t = None

        if self.cont_exist:
            sigma_cont = sigma[:, self.num_cat_features :]
            x_cont_t = x_cont_0 + torch.randn_like(x_cont_0) * sigma_cont
        else:
            x_cont_t = None

        return x_cat_emb_t, x_cont_t

    def loss_fn(self, x_cat, x_cont, cond, u=None):
        """Compute all training losses for one batch.

        Args:
            x_cat: Categorical class indices, or unused when the model has no
                categorical features.
            x_cont: Continuous features, or unused when the model has no
                continuous features.
            cond: Conditioning input forwarded unchanged to the score model.
            u: Optional per-sample times; drawn with
                ``low_discrepancy_sampler`` when ``None``.

        Returns:
            Tuple ``(losses, sigma)``. ``losses`` is a dict with the keys
            ``"unweighted"``, ``"unweighted_calibrated"``, ``"timewarping"``,
            ``"weighted_calibrated"`` and ``"train_loss"`` (the scalar to
            back-propagate). ``sigma`` has shape ``(batch, num_features)``.
        """
        # get ground truth data
        if self.cat_exist:
            batch = x_cat.shape[0]
            x_cat_emb_0 = self.cat_emb(x_cat)
        else:
            batch = x_cont.shape[0]
            x_cat_emb_0 = None

        x_cont_0 = x_cont
        x_cat_0 = x_cat

        # draw u and convert to standard deviations for noise
        with torch.no_grad():
            if u is None:
                u = low_discrepancy_sampler(batch, device=self.device)  # (B,)
            sigma = self.timewarp_cdf(u, invert=True).detach().to(torch.float32)
            u = u.to(torch.float32)
            assert sigma.shape == (batch, self.num_features)

        x_cat_emb_t, x_cont_t = self.add_noise(x_cat_emb_0, x_cont_0, sigma)
        cat_logits, cont_preds = self.precondition(x_cat_emb_t, x_cont_t, cond, u, sigma)
        ce_losses, mse_losses = self.diffusion_loss(
            x_cat_0,
            x_cont_0,
            cat_logits,
            cont_preds,
        )

        if self.cont_exist:
            # compute EDM weight
            sigma_cont = sigma[:, self.num_cat_features :]
            cont_weight = (sigma_cont**2 + self.sigma_data_cont**2) / ((sigma_cont * self.sigma_data_cont) ** 2 + 1e-7)

        losses = {}

        # collect the raw and the weighted loss columns, categorical first
        loss_list = []
        weight_loss_list = []
        if self.cat_exist:
            loss_list.append(ce_losses)
            weight_loss_list.append(ce_losses * self.rel_cat_weight)
        if self.cont_exist:
            loss_list.append(mse_losses)
            weight_loss_list.append(cont_weight * mse_losses)
        losses["unweighted"] = torch.cat(loss_list, dim=1)
        losses["unweighted_calibrated"] = losses["unweighted"] / self.normal_const

        weighted_calibrated = torch.cat(weight_loss_list, dim=1) / self.normal_const
        c_noise = torch.log(u + 1e-8) * 0.25
        time_reweight = self.weight_network(c_noise).unsqueeze(1)

        # the timewarping and weight-network terms are fitted to detached
        # losses, so they do not send gradients into the score model
        losses["timewarping"] = self.timewarp_cdf.loss_fn(sigma.detach(), losses["unweighted_calibrated"].detach())
        weightnet_loss = (time_reweight.exp() - weighted_calibrated.detach().mean(1)) ** 2
        losses["weighted_calibrated"] = weighted_calibrated / time_reweight.exp().detach()

        train_loss = losses["weighted_calibrated"].mean() + losses["timewarping"].mean() + weightnet_loss.mean()

        losses["train_loss"] = train_loss

        return losses, sigma

    def precondition(self, x_cat_emb_t, x_cont_t, cond, u, sigma):
        """
        Improved preconditioning proposed in the paper "Elucidating the Design
        Space of Diffusion-Based Generative Models" (EDM) adjusted for categorical data

        Both input types are scaled by ``1 / sqrt(sigma_data**2 + sigma**2)``
        before the score model is called. Only the continuous output is
        recombined with the noisy input through ``c_skip`` and ``c_out``.

        Args:
            x_cat_emb_t: Noisy categorical embeddings (ignored without
                categorical features).
            x_cont_t: Noisy continuous features (ignored without continuous
                features).
            cond: Conditioning input forwarded unchanged to the score model.
            u: Per-sample times used to build the noise conditioning signal.
            sigma: Per-feature noise levels with categorical columns first.

        Returns:
            Tuple ``(cat_logits, D_x)``: the score model's categorical output
            and the denoised continuous estimate (``None`` without continuous
            features).
        """

        if self.cat_exist:
            sigma_cat = sigma[:, : self.num_cat_features]

            c_in_cat = 1 / (self.sigma_data_cat**2 + sigma_cat.unsqueeze(2) ** 2).sqrt()  # batch, num_features, 1

        if self.cont_exist:
            sigma_cont = sigma[:, self.num_cat_features :]
            c_in_cont = 1 / (self.sigma_data_cont**2 + sigma_cont**2).sqrt()

        # c_noise = u.log() / 4
        c_noise = torch.log(u + 1e-8) * 0.25 * 1000

        cat_logits, cont_preds = self.model(
            c_in_cat * x_cat_emb_t if self.cat_exist else None,
            c_in_cont * x_cont_t if self.cont_exist else None,
            cond,
            c_noise,
        )

        if self.cont_exist:
            # apply preconditioning to continuous features
            c_skip = self.sigma_data_cont**2 / (sigma_cont**2 + self.sigma_data_cont**2)
            c_out = sigma_cont * self.sigma_data_cont / (sigma_cont**2 + self.sigma_data_cont**2).sqrt()
            D_x = c_skip * x_cont_t + c_out * cont_preds
        else:
            D_x = None

        return cat_logits, D_x

    def score_interpolation(self, x_cat_emb_t, cat_logits, sigma, return_probs=False):
        """Turn categorical logits into probabilities or an update direction.

        Args:
            x_cat_emb_t: Current noisy categorical embeddings.
            cat_logits: Sequence with one logits tensor per categorical feature.
            sigma: Per-feature noise levels with categorical columns first.
            return_probs: If true, only return the class probabilities.

        Returns:
            With ``return_probs=True``, a list of float64 probability tensors
            (one per feature). Otherwise a tuple ``(interpolated_score,
            x_cat_emb_0_hat)`` where ``x_cat_emb_0_hat`` is the
            probability-weighted average of each feature's class embeddings
            and ``interpolated_score`` is
            ``(x_cat_emb_t - x_cat_emb_0_hat) / sigma_cat``.
        """
        if return_probs:
            # transform logits for categorical features to probabilities
            return [F.softmax(logits.to(torch.float64), dim=1) for logits in cat_logits]

        def interpolate_emb(i):
            p = F.softmax(cat_logits[i].to(torch.float64), dim=1)
            true_emb = self.cat_emb.get_all_feat_emb(i).to(torch.float64)
            return torch.matmul(p, true_emb)

        # take prob-weighted average of normalized ground truth embeddings
        x_cat_emb_0_hat = torch.zeros_like(x_cat_emb_t, device=self.device, dtype=torch.float64)
        for i in range(self.num_cat_features):
            x_cat_emb_0_hat[:, i, :] = interpolate_emb(i)

        # plug interpolated embedding into score function to interpolate score
        sigma_cat = sigma[:, : self.num_cat_features]
        interpolated_score = (x_cat_emb_t - x_cat_emb_0_hat) / sigma_cat.unsqueeze(2)

        return interpolated_score, x_cat_emb_0_hat

    @torch.inference_mode()
    def sampler(self, cat_latents, cont_latents, cond, num_steps=200):
        """Generate samples by stepping from maximum to minimum noise.

        The noise levels come from inverting the timewarping CDF on an
        equally spaced grid running from 1 to 0. Each iteration performs one
        Euler update ``x + (t_next - t_cur) * d`` in float64.

        Args:
            cat_latents: Noise latents for the categorical embeddings,
                indexed as ``(batch, feature, embedding)``; scaled here by
                the initial noise level.
            cont_latents: Noise latents for the continuous features; scaled
                here by the initial noise level.
            cond: Conditioning input, cast to float32 for the score model.
            num_steps: Number of update steps.

        Returns:
            Tuple ``(x_cat_gen, x_cont)`` of CPU tensors. ``x_cat_gen`` holds
            the arg-max class index of every categorical feature, stored in a
            default-dtype (floating point) tensor; ``x_cont`` is the final
            float64 continuous state. An entry is ``None`` if that feature
            type is absent.
        """
        B = cont_latents.shape[0] if self.cont_exist else cat_latents.shape[0]

        # construct time steps
        u_steps = torch.linspace(1, 0, num_steps + 1, device=self.device, dtype=torch.float64)
        t_steps = self.timewarp_cdf(u_steps, invert=True)

        assert torch.allclose(t_steps[0].to(torch.float32), self.sigma_max.float())
        assert torch.allclose(t_steps[-1].to(torch.float32), self.sigma_min.float())
        # the final step goes onto t = 0, i.e., sigma = sigma_min = 0

        # initialize latents at maximum noise level
        if self.cat_exist:
            t_cat_next = t_steps[0, : self.num_cat_features]
            x_cat_next = cat_latents.to(torch.float64) * t_cat_next.unsqueeze(1)

        if self.cont_exist:
            t_cont_next = t_steps[0, self.num_cat_features :]
            x_cont_next = cont_latents.to(torch.float64) * t_cont_next

        # the conditioning input does not change between steps; cast it once
        cond = cond.to(torch.float32)

        for t_cur, t_next, u_cur in zip(t_steps[:-1], t_steps[1:], u_steps[:-1]):
            t_cur = t_cur.repeat((B, 1))
            t_next = t_next.repeat((B, 1))

            # get score model output
            cat_logits, x_cont_denoised = self.precondition(
                x_cat_emb_t=x_cat_next.to(torch.float32) if self.cat_exist else None,
                x_cont_t=x_cont_next.to(torch.float32) if self.cont_exist else None,
                cond=cond,
                u=u_cur.to(torch.float32).repeat((B,)),
                sigma=t_cur.to(torch.float32),
            )

            h = t_next - t_cur

            # estimate scores and adjust samples
            if self.cat_exist:
                d_cat_cur, _ = self.score_interpolation(x_cat_next, cat_logits, t_cur)
                x_cat_next = x_cat_next + h[:, : self.num_cat_features].unsqueeze(2) * d_cat_cur

            if self.cont_exist:
                t_cont_cur = t_cur[:, self.num_cat_features :]
                d_cont_cur = (x_cont_next - x_cont_denoised.to(torch.float64)) / t_cont_cur
                x_cont_next = x_cont_next + h[:, self.num_cat_features :] * d_cont_cur

        if self.cat_exist:
            # final prediction of classes for categorical feature; this
            # reuses the second-to-last grid point, i.e. the noise level of
            # the last update step, rather than the terminal one
            u_final = u_steps[-2]
            t_final = t_steps[-2].repeat(B, 1)

            cat_logits, _ = self.precondition(
                x_cat_emb_t=x_cat_next.to(torch.float32),
                x_cont_t=x_cont_next.to(torch.float32) if self.cont_exist else None,
                cond=cond,
                u=u_final.to(torch.float32).repeat((B,)),
                sigma=t_final.to(torch.float32),
            )

            # get probabilities for each category and derive generated classes
            probs = self.score_interpolation(x_cat_next, cat_logits, t_final, return_probs=True)
            x_cat_gen = torch.empty(B, self.num_cat_features, device=self.device)
            for i in range(self.num_cat_features):
                x_cat_gen[:, i] = probs[i].argmax(1)
            x_cat_gen = x_cat_gen.cpu()
        else:
            x_cat_gen = None

        if self.cont_exist:
            x_cont_next = x_cont_next.cpu()
        else:
            x_cont_next = None

        return x_cat_gen, x_cont_next


class CondFastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Adapted from: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6

    Each iteration yields a tuple ``(X_cat, X_cont, cond)`` of row-aligned
    batches; an entry is ``None`` when the corresponding tensor was not given.
    """

    def __init__(self, X_cat, X_cont, cond, batch_size=32, shuffle=False, drop_last=False):
        """Store the tensors and precompute the number of batches per epoch.

        Args:
            X_cat: Categorical tensor or ``None``.
            X_cont: Continuous tensor or ``None`` (at least one of ``X_cat``
                and ``X_cont`` must be given).
            cond: Conditioning tensor or ``None``.
            batch_size: Number of rows per batch.
            shuffle: Draw a fresh row permutation at the start of each epoch.
            drop_last: Drop the incomplete final batch so that every batch
                has exactly ``batch_size`` rows.
        """
        n_rows = X_cat.shape[0] if X_cat is not None else X_cont.shape[0]
        assert all(t.shape[0] == n_rows for t in (X_cat, X_cont, cond) if t is not None)
        self.X_cat = X_cat
        self.X_cont = X_cont
        self.cond = cond

        self.batch_size = batch_size
        self.shuffle = shuffle

        # total rows available vs. rows actually served per epoch
        self._n_rows = n_rows
        self.dataset_len = n_rows
        if drop_last:
            self.dataset_len = (n_rows // self.batch_size) * self.batch_size

        # Calculate # batches
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        """Start a new epoch, reshuffling the row order if requested."""
        if self.shuffle:
            # Permute over *all* rows and keep the first ``dataset_len`` of
            # them. Permuting only ``range(dataset_len)`` would permanently
            # exclude the trailing rows whenever ``drop_last`` shortens the
            # epoch.
            self.indices = torch.randperm(self._n_rows)[: self.dataset_len]
        else:
            self.indices = None
        self.i = 0
        return self

    def __next__(self):
        """Return the next ``(X_cat, X_cont, cond)`` batch of the epoch."""
        if self.i >= self.dataset_len:
            raise StopIteration

        start, stop = self.i, self.i + self.batch_size
        sources = (self.X_cat, self.X_cont, self.cond)
        if self.indices is not None:
            rows = self.indices[start:stop]
            batch = tuple(torch.index_select(t, 0, rows) if t is not None else None for t in sources)
        else:
            batch = tuple(t[start:stop] if t is not None else None for t in sources)

        self.i = stop
        return batch

    def __len__(self):
        """Number of batches per epoch."""
        return self.n_batches


class CondCDTD:
    """High-level wrapper that builds, trains and samples a conditional
    ``CondMixedTypeDiffusion`` model with a ``CondMLP`` score network."""

    def __init__(
        self,
        X_cat_train,
        X_cont_train,
        cond,
        cat_emb_dim=16,
        mlp_emb_dim=256,
        mlp_n_layers=5,
        mlp_n_units=1024,
        sigma_data_cat=1.0,
        sigma_data_cont=1.0,
        sigma_min_cat=0.0,
        sigma_min_cont=0.0,
        sigma_max_cat=100.0,
        sigma_max_cont=80.0,
        cat_emb_init_sigma=0.001,
        timewarp_type="bytype",  # 'single', 'bytype', or 'all'
        timewarp_weight_low_noise=1.0,
        rel_cat_weight=1.0,
        calibrate_loss=True,
    ):
        """
        Same as CDTD but allows for conditioning on extra variables.
        These are simply concatenated to the numerical variables but not noisified.

        Args:
            X_cat_train: Tensor of categorical training data (one column per
                feature) or ``None``.
            X_cont_train: Tensor of continuous training data or ``None``.
            cond: Conditioning data; only its second dimension is read here
                (passed to ``CondMLP`` as ``cond_dim``).
            cat_emb_dim: Embedding dimension of the categorical features.
            mlp_emb_dim: Forwarded to ``CondMLP``.
            mlp_n_layers: Forwarded to ``CondMLP``.
            mlp_n_units: Forwarded to ``CondMLP``.
            sigma_data_cat: Forwarded to ``CondMixedTypeDiffusion``.
            sigma_data_cont: Forwarded to ``CondMixedTypeDiffusion``.
            sigma_min_cat: Forwarded to ``CondMixedTypeDiffusion``.
            sigma_min_cont: Forwarded to ``CondMixedTypeDiffusion``.
            sigma_max_cat: Forwarded to ``CondMixedTypeDiffusion``.
            sigma_max_cont: Forwarded to ``CondMixedTypeDiffusion``.
            cat_emb_init_sigma: Forwarded to ``CondMixedTypeDiffusion``.
            timewarp_type: Noise schedule type; overridden with ``"single"``
                when only one feature type is present.
            timewarp_weight_low_noise: Forwarded to ``CondMixedTypeDiffusion``.
            rel_cat_weight: Forwarded to ``CondMixedTypeDiffusion``.
            calibrate_loss: If true, the empirical class proportions are
                passed on for loss calibration; otherwise ``None`` is passed.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_cat_features = X_cat_train.shape[1] if X_cat_train is not None else 0
        self.num_cont_features = X_cont_train.shape[1] if X_cont_train is not None else 0
        self.num_features = self.num_cat_features + self.num_cont_features
        self.cat_emb_dim = cat_emb_dim

        # if only one feature type is observed, default to 'single' noise schedule
        if self.num_cat_features == 0 or self.num_cont_features == 0:
            timewarp_type = "single"

        # The number of categories is needed by the model regardless of
        # ``calibrate_loss``; only the proportions are optional.
        # NOTE(review): with calibrate_loss=False and categorical features,
        # CondMLP receives proportions=None — confirm it supports that.
        self.categories, self.proportions = self._summarize_categories(X_cat_train, calibrate_loss)

        score_model = CondMLP(
            self.num_cont_features,
            self.cat_emb_dim,
            self.categories,
            self.proportions,
            mlp_emb_dim,
            mlp_n_layers,
            mlp_n_units,
            cond_dim=cond.shape[1],
        )

        self.diff_model = CondMixedTypeDiffusion(
            model=score_model,
            dim=self.cat_emb_dim,
            categories=self.categories,
            num_features=self.num_features,
            sigma_data_cat=sigma_data_cat,
            sigma_data_cont=sigma_data_cont,
            sigma_min_cat=sigma_min_cat,
            sigma_max_cat=sigma_max_cat,
            sigma_min_cont=sigma_min_cont,
            sigma_max_cont=sigma_max_cont,
            proportions=self.proportions,
            cat_emb_init_sigma=cat_emb_init_sigma,
            timewarp_type=timewarp_type,
            timewarp_weight_low_noise=timewarp_weight_low_noise,
            rel_cat_weight=rel_cat_weight,
        )

    @staticmethod
    def _summarize_categories(X_cat_train, calibrate_loss):
        """Return ``(categories, proportions)`` for the categorical columns.

        ``categories`` lists the number of distinct values per column.
        ``proportions`` lists the relative frequency of each distinct value
        per column (used for max CE losses at t = 1 for normalization), or is
        ``None`` when ``calibrate_loss`` is false or there are no categorical
        columns.
        """
        if X_cat_train is None or X_cat_train.shape[1] == 0:
            return [], None

        n_sample = X_cat_train.shape[0]
        categories = []
        proportions = []
        for i in range(X_cat_train.shape[1]):
            uniq_vals, counts = X_cat_train[:, i].unique(return_counts=True)
            categories.append(len(uniq_vals))
            proportions.append(counts / n_sample)

        return categories, (proportions if calibrate_loss else None)

    def fit(
        self,
        X_cat_train,
        X_cont_train,
        cond,
        num_steps_train=30_000,
        num_steps_warmup=1000,
        batch_size=4096,
        lr=1e-3,
        seed=42,
        ema_decay=0.999,
        log_steps=100,
    ):
        """Train the diffusion model and load the EMA weights into it.

        Args:
            X_cat_train: Categorical training tensor or ``None``.
            X_cont_train: Continuous training tensor or ``None``.
            cond: Conditioning tensor aligned row-wise with the training data.
            num_steps_train: Number of optimiser steps.
            num_steps_warmup: Warm-up steps passed to ``LinearScheduler``.
            batch_size: Rows per training batch (incomplete batches dropped).
            lr: Base learning rate.
            seed: Seed passed to ``set_seeds``.
            ema_decay: Decay of the exponential moving average of the weights.
            log_steps: Interval (in steps) at which the running mean loss is
                shown and appended to ``self.loss_hist``.

        Returns:
            The trained ``self.diff_model`` in eval mode, holding the EMA
            parameters.
        """
        torch.set_float32_matmul_precision("high")

        train_loader = CondFastTensorDataLoader(
            X_cat_train,
            X_cont_train,
            cond,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        train_iter = cycle(train_loader)

        set_seeds(seed, cuda_deterministic=True)
        self.diff_model = self.diff_model.to(self.device)
        self.diff_model.train()

        self.ema_diff_model = ExponentialMovingAverage(self.diff_model.parameters(), decay=ema_decay)

        self.optimizer = torch.optim.AdamW(self.diff_model.parameters(), lr=lr, weight_decay=0)
        self.scheduler = LinearScheduler(
            num_steps_train,
            base_lr=lr,
            final_lr=1e-6,
            warmup_steps=num_steps_warmup,
            warmup_begin_lr=1e-6,
            anneal_lr=True,
        )

        self.current_step = 0
        n_obs = sum_loss = 0
        train_start = time.time()
        self.loss_hist = []

        with tqdm(
            initial=self.current_step,
            total=num_steps_train,
        ) as pbar:
            while self.current_step < num_steps_train:
                self.optimizer.zero_grad()

                # move the batch to the training device, keeping absent
                # feature types as None
                x_cat, x_cont, cond_batch = (
                    part.to(self.device) if part is not None else None for part in next(train_iter)
                )

                losses, _ = self.diff_model.loss_fn(x_cat, x_cont, cond_batch, None)
                losses["train_loss"].backward()

                # update parameters
                self.optimizer.step()
                self.diff_model.timewarp_cdf.update_ema()
                self.ema_diff_model.update()

                B = losses["unweighted"].shape[0]
                sum_loss += losses["train_loss"].detach().mean().item() * B
                n_obs += B
                self.current_step += 1
                pbar.update(1)

                if self.current_step % log_steps == 0:
                    pbar.set_description(f"Loss (last {log_steps} steps): {(sum_loss / n_obs):.3f}")
                    self.loss_hist.append(sum_loss / n_obs)
                    n_obs = sum_loss = 0

                # anneal learning rate
                # NOTE(review): the schedule is applied after the update, so
                # the very first step runs with the constructor ``lr``.
                if self.scheduler:
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = self.scheduler(self.current_step)

        # compute training duration
        train_duration = time.time() - train_start
        print(f"Training took {(train_duration / 60):.2f} min.")

        # take EMA of model parameters
        self.ema_diff_model.copy_to()
        self.diff_model.eval()

        return self.diff_model

    def sample(self, cond, num_steps=200, batch_size=4096, seed=42, verbose=False):
        """Generate one sample per row of ``cond``.

        Args:
            cond: Conditioning tensor; processed in chunks of ``batch_size``.
            num_steps: Number of sampler steps.
            batch_size: Maximum number of rows generated at once.
            seed: Seed passed to ``set_seeds``.
            verbose: Show a progress bar over the chunks.

        Returns:
            Tuple ``(x_cat, x_cont)`` of NumPy arrays: integer class indices
            and continuous values. An entry is ``None`` if the model has no
            features of that type.
        """
        set_seeds(seed, cuda_deterministic=True)
        n_batches, remainder = divmod(cond.shape[0], batch_size)
        sample_sizes = n_batches * [batch_size] + [remainder] if remainder != 0 else n_batches * [batch_size]

        # split cond into batches
        cond_parts = torch.split(cond, sample_sizes, dim=0)

        x_cat_list = []
        x_cont_list = []

        for i, num_samples in enumerate(tqdm(sample_sizes, disable=not verbose)):
            cat_latents = torch.randn(
                (num_samples, self.num_cat_features, self.cat_emb_dim),
                device=self.device,
            )
            cont_latents = torch.randn((num_samples, self.num_cont_features), device=self.device)
            cond_batch = cond_parts[i].to(self.device)
            x_cat_gen, x_cont_gen = self.diff_model.sampler(cat_latents, cont_latents, cond_batch, num_steps)
            x_cat_list.append(x_cat_gen)
            x_cont_list.append(x_cont_gen)

        if self.num_cat_features > 0:
            x_cat = torch.cat(x_cat_list).cpu().long().numpy()
        else:
            x_cat = None
        if self.num_cont_features > 0:
            x_cont = torch.cat(x_cont_list).cpu().numpy()
        else:
            x_cont = None

        return x_cat, x_cont
