import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm

from .layers import CatEmbedding, Timewarp, Timewarp_Logistic, WeightNetwork
from .utils import low_discrepancy_sampler, set_seeds


class CatCDTD(nn.Module):
    """
    This implements the categorical part of the CDTD.
    """

    def __init__(
        self,
        score_model,
        num_classes,
        proportions,
        emb_dim,
        sigma_min=1e-5,
        sigma_max=100,
        sigma_data=1,
        normalize_by_entropy=True,
        weight_low_noise=1.0,
        timewarp_variant="logistic",
        cat_emb_init_sigma=0.001,
    ):
        """Set up the encoder, loss-weight network and timewarp around ``score_model``.

        Args:
            score_model: Network called as ``score_model(x, c_noise)``; its
                output is indexed per feature to obtain class logits.
            num_classes: Sequence with the number of categories of each feature.
            proportions: One tensor of category proportions per feature. Only
                read when ``normalize_by_entropy`` is true.
            emb_dim: Size of each category embedding.
            sigma_min: Lower noise bound handed to the timewarp.
            sigma_max: Upper noise bound handed to the timewarp.
            sigma_data: Data scale used by ``precondition`` for input scaling.
            normalize_by_entropy: If true, the per-feature loss is later divided
                by the entropy of that feature's proportions; otherwise by one.
            weight_low_noise: Forwarded to ``Timewarp_Logistic``.
            timewarp_variant: Either ``"logistic"`` or ``"pwl"``.
            cat_emb_init_sigma: Forwarded to ``CatEmbedding``.

        Raises:
            ValueError: If ``timewarp_variant`` is not a supported variant.
        """
        # Fail fast: an unknown variant used to leave ``self.timewarp`` unset,
        # which only surfaced later as an AttributeError/KeyError in training.
        if timewarp_variant not in ("logistic", "pwl"):
            raise ValueError(
                f"Unknown timewarp_variant {timewarp_variant!r}; expected 'logistic' or 'pwl'."
            )

        super().__init__()
        self.num_features = len(num_classes)
        self.num_classes = num_classes
        self.sigma_data = sigma_data
        self.emb_dim = emb_dim

        if normalize_by_entropy:
            # entr(p) = -p * log(p) with the convention 0 * log(0) = 0, so a
            # category with zero proportion no longer turns the entropy into NaN.
            entropy = torch.tensor([torch.special.entr(p).sum() for p in proportions])
        else:
            entropy = torch.ones(self.num_features)
        self.register_buffer("entropy", entropy)

        self.score_model = score_model
        self.weightnet = WeightNetwork(1024)
        self.encoder = CatEmbedding(emb_dim, num_classes, cat_emb_init_sigma, bias=True, normalize_emb=True)

        self.timewarp_variant = timewarp_variant
        if timewarp_variant == "logistic":
            self.timewarp = Timewarp_Logistic(
                "single",
                self.num_features,
                0,
                torch.tensor(sigma_min),
                torch.tensor(sigma_max),
                weight_low_noise=weight_low_noise,
                decay=0,
            )
        else:  # "pwl", guaranteed by the validation above
            self.timewarp = Timewarp(sigma_min=sigma_min, sigma_max=sigma_max, decay=0.1)

    @property
    def device(self):
        """Device of the first parameter yielded by the score model."""
        first_param = next(self.score_model.parameters())
        return first_param.device

    def loss_fn(self, x, t=None, validation=False):
        """Compute all training losses for one batch of categorical data.

        Args:
            x: Class indices with the batch on dim 0; column ``x[:, i]`` holds
                the targets of feature ``i``.
            t: Optional times, one per sample. Drawn with
                ``low_discrepancy_sampler`` when omitted.
            validation: If true, noise levels are drawn uniformly between the
                timewarp's ``sigma_min`` and ``sigma_max`` and ``t`` is
                recovered from them with ``timewarp.get_t`` (any ``t`` passed
                in is overwritten). Otherwise ``t`` is mapped to noise levels
                with ``timewarp.get_sigmas``.

        Returns:
            Dict with the entries ``"weighted"`` (entropy-normalized
            cross-entropy per sample and feature), ``"weighted_calibrated"``,
            ``"timewarping"``, ``"weightnet"`` and the scalar ``"train_loss"``.
        """
        B = x.shape[0]
        x_emb = self.encoder(x)

        # draw time and convert to standard deviations for noise
        with torch.no_grad():
            if t is None:
                t = low_discrepancy_sampler(B, device=self.device)  # (B,)

            if not validation:
                sigma = self.timewarp.get_sigmas(t)
            else:
                # NOTE(review): this draw looks like it is (B,) while the
                # ``repeat`` below expects a (B, F) input - confirm the shapes
                # produced by low_discrepancy_sampler / get_t for this path.
                sigma = (
                    low_discrepancy_sampler(B, device=self.device).to(torch.float32)
                    * (self.timewarp.sigma_max - self.timewarp.sigma_min)
                    + self.timewarp.sigma_min
                )
                t = self.timewarp.get_t(sigma)
            sigma = repeat(sigma, "B F -> B F G", F=self.num_features, G=1)

            t = t.to(torch.float32)
            assert sigma.shape == (B, self.num_features, 1)

        # add noise
        x_emb_t = x_emb + torch.randn_like(x_emb) * sigma

        # pass to score model
        logits = self.precondition(x_emb_t, t, sigma)

        ce_losses = torch.stack(
            [F.cross_entropy(logits[i], x[:, i], reduction="none") for i in range(self.num_features)],
            dim=1,
        )

        losses = {}
        losses["weighted"] = ce_losses / (self.entropy + 1e-8)
        time_reweight = self.weightnet(t).unsqueeze(1)
        losses["weighted_calibrated"] = losses["weighted"] / time_reweight.exp().detach()

        if self.timewarp_variant == "logistic":
            losses["timewarping"] = self.timewarp.loss_fn(sigma.squeeze(-1).detach(), losses["weighted"].detach())
        elif self.timewarp_variant == "pwl":
            losses["timewarping"] = self.timewarp.loss_fn(
                sigma.detach()[:, 0].squeeze(-1), losses["weighted"].detach().mean(1).squeeze(-1)
            )

        # Keep the feature dim on the target so it lines up with the
        # unsqueezed prediction (presumably (B, 1) - weightnet is assumed to
        # return one value per sample). Without keepdim the (B, 1) prediction
        # and a (B,) target broadcast to a (B, B) matrix of mismatched pairs.
        avg_loss = losses["weighted"].detach().mean(1, keepdim=True)
        losses["weightnet"] = (time_reweight.exp() - avg_loss) ** 2

        train_loss = losses["weighted_calibrated"].mean() + losses["timewarping"].mean() + losses["weightnet"].mean()

        losses["train_loss"] = train_loss

        return losses

    def precondition(self, x_emb_t, t, sigma):
        """Scale the noisy embeddings and the time before calling the score model.

        Follows the preconditioning of "Elucidating the Design Space of
        Diffusion-Based Generative Models" (EDM), adapted to categorical data:
        the input is multiplied by ``1 / sqrt(sigma_data^2 + sigma^2)`` and the
        network is conditioned on ``1000 * 0.25 * log(t + 1e-8)``.

        Returns:
            Whatever ``self.score_model`` returns for the scaled inputs.
        """
        input_scale = 1 / torch.sqrt(self.sigma_data**2 + sigma**2)  # B, num_features, 1
        noise_cond = torch.log(t + 1e-8) * 0.25
        scaled_input = input_scale * x_emb_t

        return self.score_model(scaled_input, noise_cond * 1000)

    def init_score_interpolation(self):
        """Cache the normalized embedding table of every feature for sampling.

        Takes the encoder's embedding weights, adds the per-feature bias when
        the encoder has one, rescales every row to L2 norm ``sqrt(emb_dim)``,
        casts to float64 and stores one table per feature in
        ``self.embeddings_score_interp`` (split according to ``num_classes``).
        """
        # copy data to embedding bag
        full_emb = self.encoder.cat_emb.weight.detach()

        # Add the bias only when the encoder has one. Previously ``bias`` was
        # referenced unconditionally, which raised for a bias-free encoder.
        if self.encoder.bias:
            bias = torch.cat(
                [
                    self.encoder.cat_bias[i].unsqueeze(0).expand(self.num_classes[i], -1)
                    for i in range(self.num_features)
                ],
                dim=0,
            )
            assert bias.shape == full_emb.shape
            full_emb = full_emb + bias

        # before running score interpolation, normalize embedding bag weights once
        full_emb = F.normalize(full_emb, dim=1, eps=1e-20) * torch.tensor(self.emb_dim).sqrt()
        full_emb = full_emb.to(torch.float64)

        self.embeddings_score_interp = torch.split(full_emb, self.num_classes, dim=0)

    def score_interpolation(self, x_emb_t, logits, sigma, return_probs=False):
        """Turn class logits into an interpolated score (or class probabilities).

        The logits of every feature are converted to float64 probabilities.
        With ``return_probs=True`` that list is returned directly. Otherwise
        the probability-weighted average of the cached embeddings is computed
        per feature and plugged into ``(x_emb_t - x_emb_hat) / sigma``.

        Returns:
            The list of probabilities, or the tuple
            ``(interpolated_score, x_emb_hat)``.
        """
        class_probs = [F.softmax(feat_logits.to(torch.float64), dim=1) for feat_logits in logits]
        if return_probs:
            return class_probs

        expected_emb = torch.zeros_like(x_emb_t, device=self.device, dtype=torch.float64)
        for idx, feat_probs in enumerate(class_probs):
            expected_emb[:, idx, :] = feat_probs @ self.embeddings_score_interp[idx]

        # the expected embedding takes the place of the clean data in the score
        score = (x_emb_t - expected_emb) / sigma
        return score, expected_emb

    @torch.inference_mode()
    def sampler(self, latents, num_steps=200):
        """Generate class indices from ``latents`` with an Euler scheme.

        Times run on a uniform grid from 1 to 0 and are mapped to noise levels
        by the timewarp. The state is kept in float64 while the score model is
        evaluated in float32. After the last step the classes are the argmax
        of the model's probabilities, where the model is conditioned on the
        second-to-last grid point.

        Returns:
            CPU tensor with one row per latent and one column per feature.
        """
        batch = latents.shape[0]

        def predict_logits(state, time, sigma_rows):
            # the network itself always runs in single precision
            return self.precondition(
                x_emb_t=state.to(torch.float32),
                t=time.to(torch.float32).repeat((batch,)),
                sigma=sigma_rows.to(torch.float32).unsqueeze(-1),
            )

        # time grid and the matching noise levels
        times = torch.linspace(1, 0, num_steps + 1, device=self.device, dtype=torch.float64)
        sigmas = self.timewarp.get_sigmas(times).to(torch.float64)

        # the grid has to start at sigma_max and end at sigma_min
        sigma_start, sigma_end = sigmas[0], sigmas[-1]
        assert torch.allclose(sigma_start.to(torch.float32), self.timewarp.sigma_max.float())
        assert torch.allclose(sigma_end.to(torch.float32), self.timewarp.sigma_min.float())

        # start from the latents scaled to the largest noise level
        state = latents.to(torch.float64) * sigma_start.unsqueeze(1)

        for sigma_now, sigma_then, time_now in zip(sigmas[:-1], sigmas[1:], times[:-1]):
            sigma_now = sigma_now.repeat((batch, 1))
            sigma_then = sigma_then.repeat((batch, 1))

            logits = predict_logits(state, time_now, sigma_now)
            direction, _ = self.score_interpolation(state, logits, sigma_now.unsqueeze(-1))

            # Euler update along the estimated score
            step = (sigma_then - sigma_now).unsqueeze(-1)
            state = state + step * direction

        # classify the final state, conditioning on the last evaluated grid point
        last_time = times[-2]
        last_sigma = sigmas[-2].repeat(batch, 1)
        logits = predict_logits(state, last_time, last_sigma)
        probs = self.score_interpolation(state, logits, last_sigma.unsqueeze(-1), return_probs=True)

        x_gen = torch.empty(batch, self.num_features, device=self.device)
        for col in range(self.num_features):
            x_gen[:, col] = probs[col].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sample_data(self, num_samples, num_steps=200, batch_size=4096, seed=42, verbose=True):
        """Generate ``num_samples`` rows of class indices in batches.

        Args:
            num_samples: Total number of rows to generate (non-negative).
            num_steps: Number of sampler steps per batch.
            batch_size: Maximum number of rows generated per sampler call.
            seed: Passed to ``set_seeds`` before any latent is drawn.
            verbose: Show a tqdm progress bar over the batches.

        Returns:
            CPU long tensor with ``num_samples`` rows and one column per feature.

        Raises:
            ValueError: If ``num_samples`` is negative or ``batch_size`` is
                smaller than one.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        if num_samples == 0:
            # torch.cat cannot concatenate an empty list, so answer directly
            return torch.empty((0, self.num_features), dtype=torch.long)

        # init required data
        self.init_score_interpolation()

        set_seeds(seed, cuda_deterministic=True)
        n_batches, remainder = divmod(num_samples, batch_size)
        sample_sizes = n_batches * [batch_size]
        if remainder != 0:
            sample_sizes.append(remainder)

        # use a dedicated loop name so the ``num_samples`` argument is not shadowed
        batches = []
        for size in tqdm(sample_sizes, disable=(not verbose)):
            latents = torch.randn((size, self.num_features, self.emb_dim), device=self.device)
            batches.append(self.sampler(latents, num_steps=num_steps))

        return torch.cat(batches).cpu().long()


class CatFlow(nn.Module):
    def __init__(
        self,
        score_model,
        num_classes,
        proportions,
        emb_dim,
        sigma_min=1e-5,
        sigma_max=100,
        sigma_data=1,
        normalize_by_entropy=True,
        weight_low_noise=1.0,
        timewarp_variant="logistic",
        cat_emb_init_sigma=0.001,
        learn_noise_schedule=False,
        init_embs_zero=False,
        learn_latents=False,
        gamma_emb_size=16,
        norm_dim=None,
        time_reweight=False,
    ):
        """Set up the encoder and the optional schedule/weighting sub-modules.

        Args:
            score_model: Network called as ``score_model(x, c_noise)``; its
                output is indexed per feature to obtain class logits.
            num_classes: Sequence with the number of categories of each feature.
            proportions: One tensor of category proportions per feature. Only
                read when ``normalize_by_entropy`` is true.
            emb_dim: Size of each category embedding.
            sigma_min, sigma_max, weight_low_noise, timewarp_variant: Accepted
                but not used by this constructor.
            sigma_data: Stored on the module.
            normalize_by_entropy: If true, the per-feature loss is later divided
                by the entropy of that feature's proportions; otherwise by one.
            cat_emb_init_sigma: Forwarded to ``CatEmbedding``.
            learn_noise_schedule: Create a ``PolyNoiseSchedule`` and a learnable
                schedule embedding.
            init_embs_zero: Zero the embedding weights after construction.
            learn_latents: Together with ``learn_noise_schedule``, create the
                encoder that maps flattened embeddings to schedule latents.
            gamma_emb_size: Size of the schedule embedding / latent.
            norm_dim: Forwarded to ``CatEmbedding``; also used by
                ``init_embeddings`` in place of ``emb_dim`` when not ``None``.
            time_reweight: Create a ``WeightNetwork`` for time-dependent loss
                weighting.
        """
        super().__init__()
        self.num_features = len(num_classes)
        self.num_classes = num_classes
        self.sigma_data = sigma_data
        self.emb_dim = emb_dim

        if normalize_by_entropy:
            # entr(p) = -p * log(p) with the convention 0 * log(0) = 0, so a
            # category with zero proportion no longer turns the entropy into NaN.
            entropy = torch.tensor([torch.special.entr(p).sum() for p in proportions])
        else:
            entropy = torch.ones(self.num_features)
        self.register_buffer("entropy", entropy)

        self.score_model = score_model
        self.norm_dim = norm_dim
        self.encoder = CatEmbedding(
            emb_dim, num_classes, cat_emb_init_sigma, bias=True, normalize_emb=True, norm_dim=norm_dim
        )
        if init_embs_zero:
            nn.init.zeros_(self.encoder.cat_emb.weight)

        self.learn_latents = learn_latents
        self.learn_noise_schedule = learn_noise_schedule
        self.gamma_emb_size = gamma_emb_size
        if self.learn_noise_schedule:
            self.gamma = PolyNoiseSchedule(gamma_emb_size, self.num_features)
            self.gamma_emb = nn.Parameter(torch.randn(1, gamma_emb_size))

            if self.learn_latents:
                # maps flattened embeddings to mean and pre-activation variance
                self.z_encoder = nn.Sequential(
                    nn.Linear(self.num_features * self.emb_dim, 4 * gamma_emb_size),
                    nn.SiLU(),
                    nn.Linear(4 * gamma_emb_size, 2 * gamma_emb_size),
                )
        self.time_reweight = time_reweight
        if self.time_reweight:
            self.weightnet = WeightNetwork(1000)

    @property
    def device(self):
        """Device of the first parameter yielded by the score model."""
        first_param = next(self.score_model.parameters())
        return first_param.device

    def loss_fn(self, x, t=None, validation=False):
        """Compute the training losses for one batch of categorical data.

        The encoded data ``x_1`` is blended with Gaussian noise ``x_0`` as
        ``w * x_1 + (1 - w) * x_0``, where ``w`` is either ``t`` itself or the
        learned schedule evaluated at ``t``. The score model's logits are
        scored with cross-entropy, normalized by the per-feature entropy.

        Args:
            x: Class indices with the batch on dim 0; column ``x[:, i]`` holds
                the targets of feature ``i``.
            t: Optional times, one per sample. Drawn with
                ``low_discrepancy_sampler`` when omitted.
            validation: Not used by this method.

        Returns:
            Dict with the scalar entries ``"weighted"`` and ``"train_loss"``.
        """
        batch = x.shape[0]
        data_emb = self.encoder(x)

        if t is None:
            with torch.no_grad():
                t = low_discrepancy_sampler(batch, device=self.device).to(torch.float32)  # (B,)

        use_schedule = self.learn_noise_schedule
        use_latents = use_schedule and self.learn_latents

        # interpolation weight between data (weight) and noise (1 - weight)
        if use_latents:
            z, kl_div = self.get_z(data_emb.flatten(1))
            mix = self.gamma(z, t).unsqueeze(-1)
        elif use_schedule:
            mix = self.gamma(self.gamma_emb.repeat(t.shape[0], 1), t).unsqueeze(-1)
        else:
            mix = t[:, None, None]

        noise = torch.randn_like(data_emb)
        x_t = mix * data_emb + (1 - mix) * noise

        # entropy-normalized cross-entropy per sample and feature
        logits = self.precondition(x_t, t)
        per_feature = [
            F.cross_entropy(logits[i], x[:, i], reduction="none") for i in range(self.num_features)
        ]
        normalized = torch.stack(per_feature, dim=1) / (self.entropy + 1e-8)

        losses = {"weighted": normalized.mean()}
        if self.time_reweight:
            weight = self.weightnet(t).exp()
            weightnet_loss = self.weightnet.loss_fn(weight, normalized.detach().mean(1)).mean()
            # the weight only rescales the loss; it is trained via weightnet_loss
            losses["train_loss"] = (normalized.mean(1) / weight.detach()).mean() + weightnet_loss
        else:
            losses["train_loss"] = losses["weighted"]

        if use_latents:
            losses["train_loss"] = losses["train_loss"] + kl_div.mean()

        return losses

    def precondition(self, x_t, t):
        """Call the score model with a log-scaled time conditioning.

        Only the noise conditioning in the style of "Elucidating the Design
        Space of Diffusion-Based Generative Models" (EDM) is applied here: the
        network receives ``1000 * 0.25 * log(t + 1e-8)``. The input ``x_t`` is
        passed through unscaled.

        Returns:
            Whatever ``self.score_model`` returns.
        """
        log_time = torch.log(t + 1e-8)
        noise_cond = log_time * 0.25

        return self.score_model(x_t, noise_cond * 1000)

    def init_embeddings(self):
        """Cache the normalized embedding table of every feature for sampling.

        Takes the encoder's embedding weights, adds the per-feature bias when
        the encoder has one, rescales every row to L2 norm ``sqrt(norm_dim)``
        (``sqrt(emb_dim)`` when ``norm_dim`` is ``None``), casts to float64 and
        stores one table per feature in ``self.prepped_embs`` (split according
        to ``num_classes``).
        """
        # copy data to embedding bag
        full_emb = self.encoder.cat_emb.weight.detach()

        # Add the bias only when the encoder has one. Previously ``bias`` was
        # referenced unconditionally, which raised for a bias-free encoder.
        if self.encoder.bias:
            bias = torch.cat(
                [
                    self.encoder.cat_bias[i].unsqueeze(0).expand(self.num_classes[i], -1)
                    for i in range(self.num_features)
                ],
                dim=0,
            )
            assert bias.shape == full_emb.shape
            full_emb = full_emb + bias

        # normalize the embedding rows once before sampling
        norm_dim = self.norm_dim if self.norm_dim is not None else self.emb_dim
        full_emb = F.normalize(full_emb, dim=1, eps=1e-20) * torch.tensor(norm_dim).sqrt()
        full_emb = full_emb.to(torch.float64)

        self.prepped_embs = torch.split(full_emb, self.num_classes, dim=0)

    @torch.inference_mode()
    def learned_velocity(self, x_t, logits, t, g_grads=None, return_probs=False):
        """Turn class logits into a velocity (or class probabilities).

        The logits of every feature are converted to float64 probabilities.
        With ``return_probs=True`` that list is returned directly. Otherwise
        ``mu``, the probability-weighted average of the cached embeddings, is
        computed per feature and the velocity is

        * ``(mu - x_t) / (1 - t + 1e-8)`` without a learned noise schedule, or
        * ``g_grads * (mu - x_t) / (1 - t + 1e-8)`` with one, in which case
          ``t`` and ``g_grads`` carry one value per sample and feature.

        Raises:
            AssertionError: If a learned schedule is used without ``g_grads``.
        """
        class_probs = [F.softmax(feat_logits.to(torch.float64), dim=1) for feat_logits in logits]
        if return_probs:
            return class_probs

        mu = torch.zeros_like(x_t, device=self.device, dtype=torch.float64)
        for idx, feat_probs in enumerate(class_probs):
            mu[:, idx, :] = feat_probs @ self.prepped_embs[idx]

        # plug the estimated expectation into the velocity
        if self.learn_noise_schedule:
            assert g_grads is not None
            return g_grads.unsqueeze(2) * (mu - x_t) / (1 - t.unsqueeze(2) + 1e-8)

        return (mu - x_t) / (1 - t.unsqueeze(1).unsqueeze(2) + 1e-8)

    @torch.inference_mode()
    def sampler(self, latents, num_steps=200):
        """Generate class indices from ``latents`` with an Euler scheme.

        Times run on a uniform grid from 0 to 1. The state is kept in float64
        while the score model is evaluated in float32. After the last step the
        classes are the argmax of the model's probabilities, where the model
        is conditioned on the second-to-last grid point.

        Returns:
            CPU tensor with one row per latent and one column per feature.
        """
        batch = latents.shape[0]

        grid = torch.linspace(0, 1, num_steps + 1, device=self.device, dtype=torch.float64)
        # the latents themselves are the starting state
        state = latents.to(torch.float64)

        for t_now, t_then in zip(grid[:-1], grid[1:]):
            t_now = t_now.repeat((batch,))
            t_then = t_then.repeat((batch,))

            logits = self.precondition(
                x_t=state.to(torch.float32),
                t=t_now.to(torch.float32),
            )
            velocity = self.learned_velocity(state, logits, t_now)

            # Euler update along the estimated velocity
            step = (t_then - t_now)[:, None, None]
            state = state + step * velocity

        # classify the final state, conditioning on the last evaluated grid point
        t_last = grid[-2].repeat((batch,))
        logits = self.precondition(
            x_t=state.to(torch.float32),
            t=t_last.to(torch.float32),
        )
        probs = self.learned_velocity(state, logits, t_last, return_probs=True)

        x_gen = torch.empty(batch, self.num_features, device=self.device)
        for col in range(self.num_features):
            x_gen[:, col] = probs[col].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sampler_feat_spec_noise(self, latents, num_steps=200):
        """Generate class indices with the learned, feature-specific schedule.

        Times run on a uniform grid from 0 to 1; the learned schedule and its
        time derivative are evaluated once on that grid and shared by the whole
        batch. The state is advanced with Euler steps and finally classified
        by the argmax of the model's probabilities, where the model is
        conditioned on the second-to-last grid point.

        Args:
            latents: Starting noise with the batch on dim 0.
            num_steps: Number of Euler steps; must be at least one.

        Returns:
            CPU tensor with one row per latent and one column per feature.

        Raises:
            ValueError: If ``num_steps`` is smaller than one.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be at least 1, got {num_steps}")

        B = latents.shape[0]

        # construct time steps
        t_steps = torch.linspace(0, 1, num_steps + 1, device=self.device, dtype=torch.float64)
        if self.learn_latents:
            # Draw ONE schedule latent and reuse it for every time step (as
            # plot_gamma does). Drawing a fresh latent per step evaluated a
            # different schedule at every grid point, so the values along the
            # trajectory did not belong to one coherent schedule.
            emb = torch.randn(1, self.gamma_emb_size, device=self.device).repeat(t_steps.shape[0], 1)
        else:
            emb = self.gamma_emb.repeat(t_steps.shape[0], 1)

        g_steps = self.gamma(emb, t_steps.to(torch.float32)).to(torch.float64).detach()
        g_grads = self.gamma.get_grads(emb, t_steps.to(torch.float32)).to(torch.float64).detach()

        # initialize latents at maximum noise level
        x_next = latents.to(torch.float64)

        for i, (t_cur, t_next, g_cur, g_grad) in enumerate(zip(t_steps[:-1], t_steps[1:], g_steps[:-1], g_grads[:-1])):
            t_cur = t_cur.repeat((B,))
            t_next = t_next.repeat((B,))
            g_cur = g_cur.repeat((B, 1))
            g_grad = g_grad.repeat((B, 1))

            # get score model output
            logits = self.precondition(
                x_t=x_next.to(torch.float32),
                t=t_cur.to(torch.float32),
            )

            # estimate velocity; the schedule value takes the role of the time
            v_cur = self.learned_velocity(x_next, logits, g_cur, g_grads=g_grad)

            # adjust data samples
            h = t_next - t_cur
            x_next = x_next + h[:, None, None] * v_cur

            if x_next.isnan().any():
                print(f"nan in step {i}")

        # Final prediction of classes. The time is taken from the grid
        # explicitly instead of relying on the loop variable leaking out of
        # the loop; it is the same second-to-last grid point as before.
        t_final = t_steps[-2].repeat((B,))
        logits = self.precondition(
            x_t=x_next.to(torch.float32),
            t=t_final.to(torch.float32),
        )

        # get probabilities for each category and derive generated classes
        probs = [F.softmax(l.to(torch.float64), dim=1) for l in logits]
        x_gen = torch.empty(B, self.num_features, device=self.device)
        for i in range(self.num_features):
            x_gen[:, i] = probs[i].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sample_data(self, num_samples, num_steps=200, batch_size=4096, seed=42, verbose=True):
        """Generate ``num_samples`` rows of class indices in batches.

        Uses ``sampler`` without a learned noise schedule and
        ``sampler_feat_spec_noise`` with one.

        Args:
            num_samples: Total number of rows to generate (non-negative).
            num_steps: Number of sampler steps per batch.
            batch_size: Maximum number of rows generated per sampler call.
            seed: Passed to ``set_seeds`` before any latent is drawn.
            verbose: Show a tqdm progress bar over the batches.

        Returns:
            CPU long tensor with ``num_samples`` rows and one column per feature.

        Raises:
            ValueError: If ``num_samples`` is negative or ``batch_size`` is
                smaller than one.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        if num_samples == 0:
            # torch.cat cannot concatenate an empty list, so answer directly
            return torch.empty((0, self.num_features), dtype=torch.long)

        # init required data
        self.init_embeddings()

        set_seeds(seed, cuda_deterministic=True)
        n_batches, remainder = divmod(num_samples, batch_size)
        sample_sizes = n_batches * [batch_size]
        if remainder != 0:
            sample_sizes.append(remainder)

        # pick the sampler once; the choice does not change between batches
        draw = self.sampler_feat_spec_noise if self.learn_noise_schedule else self.sampler

        # use a dedicated loop name so the ``num_samples`` argument is not shadowed
        batches = []
        for size in tqdm(sample_sizes, disable=(not verbose)):
            latents = torch.randn((size, self.num_features, self.emb_dim), device=self.device)
            batches.append(draw(latents, num_steps=num_steps))

        return torch.cat(batches).cpu().long()

    def plot_gamma(self):
        """Evaluate the learned schedule on 200 evenly spaced times in [0, 1].

        Despite its name this method draws nothing; it returns the schedule
        values so that the caller can plot them. With ``learn_latents`` a
        single random latent is drawn and shared by all times, otherwise the
        learned schedule embedding is used.
        """
        grid = torch.linspace(0, 1, 200, device=self.device, dtype=torch.float32)
        n_points = grid.shape[0]

        if self.learn_latents:
            base = torch.randn(1, self.gamma_emb_size, device=self.device)
        else:
            base = self.gamma_emb

        return self.gamma(base.repeat(n_points, 1), grid)

    def get_z(self, x):
        """Sample a schedule latent for ``x`` and compute its KL penalty.

        The output of ``z_encoder`` is split in two halves along the last dim:
        the mean and a pre-activation that softplus maps to a variance.

        Returns:
            Tuple ``(z, kl_div)``: a reparameterized sample and the KL
            divergence to a standard normal, summed over dim 1.
        """
        mean, raw_var = torch.chunk(self.z_encoder(x), 2, dim=-1)
        variance = F.softplus(raw_var)

        # reparameterization trick: noise is drawn independently of the parameters
        noise = torch.randn_like(mean)
        z = mean + variance.sqrt() * noise

        kl_terms = mean.pow(2) + variance - torch.log(variance) - 1.0
        kl_div = 0.5 * torch.sum(kl_terms, dim=1)

        return z, kl_div


class PolyNoiseSchedule(nn.Module):
    """Monotone polynomial schedule in ``t``, conditioned on an embedding.

    For coefficients ``a, b, c`` predicted from the embedding (one set per
    feature), define

        P(t) = integral_0^t (a s^2 + b s + c)^2 ds + grad_min_epsilon * t.

    The schedule is ``gamma_min + (gamma_max - gamma_min) * P(t) / P(1)``, so
    it equals ``gamma_min`` at ``t = 0`` and ``gamma_max`` at ``t = 1``. Its
    integrand is a square plus ``grad_min_epsilon``, hence non-negative.
    """

    def __init__(self, emb_dim, num_features, gamma_min=0.0, gamma_max=1.0, grad_min_epsilon=0):
        """Create the coefficient networks.

        Args:
            emb_dim: Size of the conditioning embedding.
            num_features: Number of schedules (one per feature) to output.
            gamma_min: Schedule value at ``t = 0``.
            gamma_max: Schedule value at ``t = 1``.
            grad_min_epsilon: Constant added to the integrand of ``P``.
        """
        super().__init__()
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.gamma_range = self.gamma_max - self.gamma_min
        self.grad_min_epsilon = grad_min_epsilon

        self.h_net = nn.Sequential(
            nn.Linear(emb_dim, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features),
            nn.SiLU(),
        )
        # the quadratic coefficient starts at exactly zero
        self.l_a = nn.Linear(num_features, num_features)
        nn.init.zeros_(self.l_a.weight)
        nn.init.zeros_(self.l_a.bias)
        self.l_b = nn.Linear(num_features, num_features)
        self.l_c = nn.Linear(num_features, num_features)

    def forward(self, emb, t):
        """Evaluate the schedule.

        Args:
            emb: 2-D embedding tensor with one row per evaluation.
            t: Either a single time (shared by all rows) or one time per row.

        Returns:
            Tensor with one row per embedding row and one column per feature.
        """
        t = self._as_column(emb, t)

        assert len(emb.shape) == 2
        assert emb.shape[0] == t.shape[0]

        a, b, c = self.get_params(emb)
        return self._eval_poly(t, a, b, c)

    def get_grads(self, emb, t):
        """Evaluate the time derivative of the schedule.

        Accepts the same ``emb`` / ``t`` combinations as ``forward``.
        """
        t = self._as_column(emb, t)
        assert len(emb.shape) == 2
        assert emb.shape[0] == t.shape[0]
        a, b, c = self.get_params(emb)
        return self._grad_t(t, a, b, c)

    def get_params(self, emb):
        """Predict the polynomial coefficients ``(a, b, c)`` from ``emb``."""
        h = self.h_net(emb)
        a = self.l_a(h)
        b = self.l_b(h)
        # softplus plus offset keeps the constant coefficient strictly positive
        c = 1e-3 + F.softplus(self.l_c(h))
        return a, b, c

    def _as_column(self, emb, t):
        """Bring ``t`` into column shape, broadcasting a single time to all rows."""
        if t.numel() == 1:
            return t * torch.ones((emb.shape[0], 1), device=emb.device)
        return t.unsqueeze(-1)

    def _scale(self, a, b, c):
        """Return ``P(1)``, the normalizer shared by the value and its derivative."""
        return (a**2) / 5.0 + (b**2 + 2 * a * c) / 3.0 + a * b / 2.0 + b * c + (c**2 + self.grad_min_epsilon)

    def _eval_poly(self, t, a, b, c):
        """Return the normalized schedule value at ``t``."""
        polynomial = (
            (a**2) * (t**5) / 5.0
            + (b**2 + 2 * a * c) * (t**3) / 3.0
            + a * b * (t**4) / 2.0
            + b * c * (t**2)
            + (c**2 + self.grad_min_epsilon) * t
        )

        return self.gamma_min + self.gamma_range * polynomial / self._scale(a, b, c)

    def _grad_t(self, t, a, b, c):
        """Return the derivative of ``_eval_poly`` with respect to ``t``."""
        # grad_min_epsilon has to appear both in the integrand and in the
        # normalizer; otherwise this is not the derivative of _eval_poly
        # whenever the epsilon is nonzero.
        polynomial = (
            (a**2) * (t**4)
            + (b**2 + 2 * a * c) * (t**2)
            + a * b * (t**3) * 2.0
            + b * c * t * 2
            + (c**2 + self.grad_min_epsilon)
        )

        return self.gamma_range * polynomial / self._scale(a, b, c)
