
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from einops import repeat
from torch_ema import ExponentialMovingAverage
import matplotlib.pyplot as plt
        
from .layers import CatEmbedding, WeightNetwork, Timewarp, Timewarp_Logistic
from .utils import low_discrepancy_sampler, set_seeds

# from experiments.models.lowres.utils import low_discrepancy_sampler
# from experiments.models.lowres.layers import WeightNetwork

    
class CatCDTD(nn.Module):
    """
    This implements the categorical part of the CDTD.

    Categorical features are embedded with ``CatEmbedding``, perturbed with
    Gaussian noise whose standard deviation is obtained from a time-warping
    module, and the score model is trained with a per-feature cross-entropy
    loss to predict the original categories from the noisy embeddings.
    """

    def __init__(self, score_model, num_classes, proportions, emb_dim, sigma_min=1e-5, sigma_max=100, sigma_data=1, normalize_by_entropy=True, weight_low_noise=1.0, timewarp_variant='logistic', cat_emb_init_sigma=0.001):
        """
        Args:
            score_model: module called as ``score_model(x, c_noise)`` that
                returns an indexable collection of per-feature logits.
            num_classes: sequence with the number of categories per feature.
            proportions: iterable of per-feature tensors holding category
                proportions; only read when ``normalize_by_entropy`` is set.
            emb_dim: embedding dimension per feature.
            sigma_min: smallest noise level, forwarded to the time warp.
            sigma_max: largest noise level, forwarded to the time warp.
            sigma_data: data scale used for the input preconditioning.
            normalize_by_entropy: divide each feature's loss by the entropy
                of its category proportions (otherwise by one).
            weight_low_noise: forwarded to ``Timewarp_Logistic``.
            timewarp_variant: ``'logistic'`` or ``'pwl'``.
            cat_emb_init_sigma: forwarded to ``CatEmbedding``.

        Raises:
            ValueError: if ``timewarp_variant`` is not a supported variant.
        """
        super(CatCDTD, self).__init__()
        self.num_features = len(num_classes)
        self.num_classes = num_classes
        self.sigma_data = sigma_data
        self.emb_dim = emb_dim

        if normalize_by_entropy:
            # entr(p) = -p * log(p) with entr(0) = 0, so a category with zero
            # proportion does not turn the whole entropy (and loss) into NaN
            entropy = torch.tensor([torch.special.entr(p).sum() for p in proportions])
        else:
            entropy = torch.ones(self.num_features)
        self.register_buffer("entropy", entropy)

        self.score_model = score_model
        self.weightnet = WeightNetwork(1024)
        self.encoder = CatEmbedding(emb_dim, num_classes, cat_emb_init_sigma, bias=True, normalize_emb=True)

        self.timewarp_variant = timewarp_variant
        if timewarp_variant == 'logistic':
            self.timewarp = Timewarp_Logistic('single', self.num_features, 0, torch.tensor(sigma_min), torch.tensor(sigma_max), weight_low_noise=weight_low_noise, decay=0)
        elif timewarp_variant == 'pwl':
            self.timewarp = Timewarp(sigma_min=sigma_min, sigma_max=sigma_max, decay=0.1)
        else:
            # fail early instead of leaving self.timewarp undefined
            raise ValueError(
                f"unknown timewarp_variant {timewarp_variant!r}; expected 'logistic' or 'pwl'"
            )

    @property
    def device(self):
        """Device of the score model's first parameter."""
        return next(self.score_model.parameters()).device

    def loss_fn(self, x, t=None, validation=False):
        """
        Compute all loss terms for a batch of categorical features.

        Args:
            x: class indices, read per feature as ``x[:, i]``.
            t: optional time points of shape (B,); drawn with
                ``low_discrepancy_sampler`` when omitted.
            validation: if True, noise levels are obtained by rescaling
                ``low_discrepancy_sampler`` output to
                [sigma_min, sigma_max] and ``t`` is recomputed from them.

        Returns:
            dict with the unreduced terms ``'weighted'``,
            ``'weighted_calibrated'``, ``'timewarping'``, ``'weightnet'`` and
            the scalar ``'train_loss'`` that sums the means of the last three.
        """
        B = x.shape[0]
        x_emb = self.encoder(x)

        # draw time and convert to standard deviations for noise
        with torch.no_grad():
            if t is None:
                t = low_discrepancy_sampler(B, device=self.device)  # (B,)

            if not validation:
                sigma = self.timewarp.get_sigmas(t)
            else:
                # NOTE(review): the pattern below expects a 2-D sigma; whether
                # this branch yields one depends on the shape of
                # timewarp.sigma_min / sigma_max -- confirm.
                sigma = low_discrepancy_sampler(B, device=self.device).to(torch.float32) * (self.timewarp.sigma_max - self.timewarp.sigma_min) + self.timewarp.sigma_min
                t = self.timewarp.get_t(sigma)
            sigma = repeat(sigma, 'B F -> B F G', F=self.num_features, G=1)

            t = t.to(torch.float32)
            assert sigma.shape == (B, self.num_features, 1)

        # add noise
        x_emb_t = x_emb + torch.randn_like(x_emb) * sigma

        # pass to score model
        logits = self.precondition(x_emb_t, t, sigma)

        # per-sample, per-feature cross entropy -> (B, num_features)
        ce_losses = torch.stack(
            [
                F.cross_entropy(logits[i], x[:, i], reduction="none")
                for i in range(self.num_features)
            ],
            dim=1,
        )

        losses = {}
        losses['weighted'] = ce_losses / (self.entropy + 1e-8)

        # predicted (log) average loss at time t, made broadcastable over features
        time_reweight = self.weightnet(t).unsqueeze(1)
        predicted_loss = time_reweight.exp()
        losses["weighted_calibrated"] = losses['weighted'] / predicted_loss.detach()

        if self.timewarp_variant == 'logistic':
            losses["timewarping"] = self.timewarp.loss_fn(
                sigma.squeeze(-1).detach(), losses['weighted'].detach()
            )
        elif self.timewarp_variant == 'pwl':
            losses["timewarping"] = self.timewarp.loss_fn(
                sigma.detach()[:, 0].squeeze(-1), losses['weighted'].detach().mean(1).squeeze(-1)
            )

        # keepdim=True keeps the target aligned with predicted_loss along the
        # batch axis; without it the difference broadcasts to a (B, B) matrix
        # and every prediction is regressed onto every sample's loss
        losses['weightnet'] = (
            predicted_loss - losses['weighted'].detach().mean(1, keepdim=True)
        ) ** 2

        train_loss = (
            losses["weighted_calibrated"].mean()
            + losses["timewarping"].mean()
            + losses['weightnet'].mean()
        )

        losses["train_loss"] = train_loss

        return losses

    def precondition(self, x_emb_t, t, sigma):
        """
        Improved preconditioning proposed in the paper "Elucidating the Design
        Space of Diffusion-Based Generative Models" (EDM) adjusted for categorical data

        The noisy embeddings are scaled by ``1 / sqrt(sigma_data^2 + sigma^2)``
        and the score model is conditioned on ``1000 * 0.25 * log(t + 1e-8)``.
        Returns the score model's output (per-feature logits).
        """
        c_in = 1 / (self.sigma_data**2 + sigma ** 2).sqrt()  # B, num_features, 1
        c_noise = torch.log(t + 1e-8) * 0.25

        return self.score_model(c_in * x_emb_t, c_noise * 1000)

    def init_score_interpolation(self):
        """
        Cache the normalized class embeddings used for score interpolation.

        Stores ``self.embeddings_score_interp``: a tuple with one float64
        tensor per feature (``num_classes[i]`` rows each), holding the
        embedding weights (plus the per-feature bias when the encoder uses
        one), L2-normalized per row and scaled by ``sqrt(emb_dim)``.
        """
        # copy data to embedding bag
        full_emb = self.encoder.cat_emb.weight.data.detach()

        # add bias to embedding bag; skipped entirely for a bias-free encoder
        # (previously this path failed on an unbound `bias`)
        if self.encoder.bias:
            bias = torch.cat(
                [
                    self.encoder.cat_bias[i].unsqueeze(0).expand(self.num_classes[i], -1)
                    for i in range(self.num_features)
                ],
                dim=0,
            )
            assert bias.shape == full_emb.shape
            full_emb = full_emb + bias

        # before running score interpolation, normalize embedding bag weights once
        full_emb = F.normalize(full_emb, dim=1, eps=1e-20) * torch.tensor(self.emb_dim).sqrt()
        full_emb = full_emb.to(torch.float64)

        self.embeddings_score_interp = torch.split(full_emb, self.num_classes, dim=0)

    def score_interpolation(self, x_emb_t, logits, sigma, return_probs=False):
        """
        Turn predicted logits into an interpolated score.

        Requires ``init_score_interpolation`` to have been called.

        Args:
            x_emb_t: noisy embeddings.
            logits: iterable of per-feature logits.
            sigma: noise level, broadcastable against ``x_emb_t``.
            return_probs: if True, only return the list of float64 class
                probabilities per feature.

        Returns:
            ``probs`` (list) if ``return_probs`` else the tuple
            ``(interpolated_score, x_emb_hat)`` where ``x_emb_hat`` is the
            probability-weighted average of the class embeddings and the
            score is ``(x_emb_t - x_emb_hat) / sigma``.
        """
        if return_probs:
            # transform logits for categorical features to probabilities
            probs = [F.softmax(l.to(torch.float64), dim=1) for l in logits]
            return probs

        x_emb_hat = torch.zeros_like(
            x_emb_t, device=self.device, dtype=torch.float64
        )

        for i, logs in enumerate(logits):
            probs = F.softmax(logs.to(torch.float64), dim=1)
            x_emb_hat[:, i, :] = torch.matmul(probs, self.embeddings_score_interp[i])

        # plug interpolated embedding into score function to interpolate score
        interpolated_score = (x_emb_t - x_emb_hat) / sigma

        return interpolated_score, x_emb_hat

    @torch.inference_mode()
    def sampler(self, latents, num_steps=200):
        """
        Generate classes from standard-normal latents with Euler steps.

        Requires ``init_score_interpolation`` to have been called.

        Args:
            latents: tensor of shape (B, num_features, emb_dim).
            num_steps: number of discretization steps from t=1 to t=0.

        Returns:
            CPU tensor of shape (B, num_features) holding the argmax class
            per feature (stored in the default floating dtype).
        """
        B = latents.shape[0]

        # construct time steps
        t_steps = torch.linspace(
            1, 0, num_steps + 1, device=self.device, dtype=torch.float64
        )
        s_steps = self.timewarp.get_sigmas(t_steps).to(torch.float64)

        # the schedule has to start at sigma_max (t = 1) and end at sigma_min (t = 0)
        assert torch.allclose(s_steps[0].to(torch.float32), self.timewarp.sigma_max.float())
        assert torch.allclose(s_steps[-1].to(torch.float32), self.timewarp.sigma_min.float())

        # initialize latents at maximum noise level
        x_next = latents.to(torch.float64) * s_steps[0].unsqueeze(1)

        for s_cur, s_next, t_cur in zip(s_steps[:-1], s_steps[1:], t_steps[:-1]):
            s_cur = s_cur.repeat((B, 1))
            s_next = s_next.repeat((B, 1))

            # get score model output
            logits = self.precondition(
                x_emb_t=x_next.to(torch.float32),
                t=t_cur.to(torch.float32).repeat((B,)),
                sigma=s_cur.to(torch.float32).unsqueeze(-1),
            )

            # estimate scores
            d_cur, _ = self.score_interpolation(x_next, logits, s_cur.unsqueeze(-1))

            # adjust data samples
            h = s_next - s_cur
            x_next = x_next + h.unsqueeze(-1) * d_cur

        # final prediction of classes for categorical feature, conditioned on
        # the last time / noise level that was used inside the loop
        t_final = t_steps[:-1][-1]
        s_final = s_steps[:-1][-1].repeat(B, 1)

        logits = self.precondition(
            x_emb_t=x_next.to(torch.float32),
            t=t_final.to(torch.float32).repeat((B,)),
            sigma=s_final.to(torch.float32).unsqueeze(-1),
        )

        # get probabilities for each category and derive generated classes
        probs = self.score_interpolation(x_next, logits, s_final.unsqueeze(-1), return_probs=True)
        x_gen = torch.empty(B, self.num_features, device=self.device)
        for i in range(self.num_features):
            x_gen[:, i] = probs[i].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sample_data(self, num_samples, num_steps=200, batch_size=4096, seed=42, verbose=True):
        """
        Generate ``num_samples`` rows of categorical data in batches.

        Args:
            num_samples: number of rows to generate (non-negative).
            num_steps: sampler steps per batch.
            batch_size: maximum number of rows generated at once.
            seed: forwarded to ``set_seeds`` before sampling.
            verbose: show a progress bar over batches.

        Returns:
            CPU long tensor of shape (num_samples, num_features).

        Raises:
            ValueError: if ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        # init required data
        self.init_score_interpolation()

        set_seeds(seed, cuda_deterministic=True)
        n_batches, remainder = divmod(num_samples, batch_size)
        sample_sizes = n_batches * [batch_size]
        if remainder != 0:
            sample_sizes.append(remainder)

        if not sample_sizes:
            # nothing requested: torch.cat would fail on an empty list
            return torch.empty((0, self.num_features), dtype=torch.long)

        x = []
        for size in tqdm(sample_sizes, disable=(not verbose)):
            latents = torch.randn(
                (size, self.num_features, self.emb_dim), device=self.device
            )
            x.append(self.sampler(latents, num_steps=num_steps))

        return torch.cat(x).cpu().long()


class CatFlow(nn.Module):
    """
    Flow-style generative model for categorical features.

    Class embeddings ``x_1`` are mixed with Gaussian noise ``x_0`` as
    ``x_t = g * x_1 + (1 - g) * x_0`` where ``g`` is either the time ``t``
    itself or a learned ``PolyNoiseSchedule``; the score model is trained
    with a per-feature cross-entropy loss to predict the classes from
    ``x_t``.
    """

    def __init__(self, score_model, num_classes, proportions, emb_dim, sigma_min=1e-5, sigma_max=100, sigma_data=1, normalize_by_entropy=True, weight_low_noise=1.0, timewarp_variant='logistic', cat_emb_init_sigma=0.001, learn_noise_schedule=False, init_embs_zero=False, learn_latents=False, gamma_emb_size=16, norm_dim=None,
                 time_reweight=False):
        """
        Args:
            score_model: module called as ``score_model(x_t, c_noise)`` that
                returns an indexable collection of per-feature logits.
            num_classes: sequence with the number of categories per feature.
            proportions: iterable of per-feature tensors holding category
                proportions; only read when ``normalize_by_entropy`` is set.
            emb_dim: embedding dimension per feature.
            sigma_min, sigma_max, weight_low_noise, timewarp_variant:
                accepted for signature compatibility; not used by this class.
            sigma_data: stored on the instance.
            normalize_by_entropy: divide each feature's loss by the entropy
                of its category proportions (otherwise by one).
            cat_emb_init_sigma: forwarded to ``CatEmbedding``.
            learn_noise_schedule: use a learned ``PolyNoiseSchedule``.
            init_embs_zero: zero-initialize the embedding weights.
            learn_latents: condition the learned schedule on a latent that
                is encoded from the data embeddings (only takes effect
                together with ``learn_noise_schedule``).
            gamma_emb_size: size of the schedule's conditioning vector.
            norm_dim: forwarded to ``CatEmbedding``; also sets the
                embedding scale in ``init_embeddings`` (``emb_dim`` if None).
            time_reweight: reweight the loss over time with a
                ``WeightNetwork``.
        """
        super(CatFlow, self).__init__()
        self.num_features = len(num_classes)
        self.num_classes = num_classes
        self.sigma_data = sigma_data
        self.emb_dim = emb_dim

        if normalize_by_entropy:
            # entr(p) = -p * log(p) with entr(0) = 0, so a category with zero
            # proportion does not turn the whole entropy (and loss) into NaN
            entropy = torch.tensor([torch.special.entr(p).sum() for p in proportions])
        else:
            entropy = torch.ones(self.num_features)
        self.register_buffer("entropy", entropy)

        self.score_model = score_model
        self.norm_dim = norm_dim
        self.encoder = CatEmbedding(emb_dim, num_classes, cat_emb_init_sigma, bias=True, normalize_emb=True, norm_dim=norm_dim)
        if init_embs_zero:
            nn.init.zeros_(self.encoder.cat_emb.weight)

        self.learn_latents = learn_latents
        self.learn_noise_schedule = learn_noise_schedule
        self.gamma_emb_size = gamma_emb_size
        if self.learn_noise_schedule:
            self.gamma = PolyNoiseSchedule(gamma_emb_size, self.num_features)
            self.gamma_emb = nn.Parameter(torch.randn(1, gamma_emb_size))

            if self.learn_latents:
                # maps flattened embeddings to mean and pre-activation variance
                # of the latent (hence the 2 * gamma_emb_size outputs)
                self.z_encoder = nn.Sequential(
                    nn.Linear(self.num_features * self.emb_dim, 4*gamma_emb_size),
                    nn.SiLU(),
                    nn.Linear(4*gamma_emb_size, 2*gamma_emb_size),
                )
        self.time_reweight = time_reweight
        if self.time_reweight:
            self.weightnet = WeightNetwork(1000)

    @property
    def device(self):
        """Device of the score model's first parameter."""
        return next(self.score_model.parameters()).device

    def loss_fn(self, x, t=None, validation=False):
        """
        Compute the training loss for a batch of categorical features.

        Args:
            x: class indices, read per feature as ``x[:, i]``.
            t: optional time points of shape (B,); drawn with
                ``low_discrepancy_sampler`` when omitted.
            validation: accepted for interface compatibility; unused.

        Returns:
            dict with the scalars ``'weighted'`` (entropy-normalized mean
            cross entropy) and ``'train_loss'`` (the optimized objective,
            including the weight-network and KL terms when enabled).
        """
        B = x.shape[0]
        x_1 = self.encoder(x)

        # draw time
        with torch.no_grad():
            if t is None:
                t = low_discrepancy_sampler(B, device=self.device).to(torch.float32)  # (B,)

        if self.learn_noise_schedule:
            if self.learn_latents:
                z, kl_div = self.get_z(x_1.flatten(1))
                g = self.gamma(z, t).unsqueeze(-1)
            else:
                g = self.gamma(self.gamma_emb.repeat(t.shape[0], 1), t).unsqueeze(-1)

        x_0 = torch.randn_like(x_1)

        # interpolate between noise (x_0) and data embeddings (x_1)
        if self.learn_noise_schedule:
            x_t = g * x_1 + (1 - g) * x_0
        else:
            x_t = t[:, None, None] * x_1 + (1 - t[:, None, None]) * x_0

        # get predictions and compute loss
        logits = self.precondition(x_t, t)
        ce_losses = torch.stack(
            [
                F.cross_entropy(logits[i], x[:, i], reduction="none")
                for i in range(self.num_features)
            ],
            dim=1,
        )

        losses = {}
        if not self.time_reweight:
            losses['weighted'] = (ce_losses / (self.entropy + 1e-8)).mean()
            losses['train_loss'] = losses['weighted']
        else:
            time_reweight = self.weightnet(t).exp()
            loss = (ce_losses / (self.entropy + 1e-8))
            losses['weighted'] = loss.mean()
            # the weight network is fit to the per-sample average loss ...
            weightnet_loss = self.weightnet.loss_fn(time_reweight, loss.detach().mean(1)).mean()
            # ... and its (detached) prediction rescales the actual loss
            loss = loss.mean(1) / time_reweight.detach()
            losses['train_loss'] = loss.mean() + weightnet_loss

        if self.learn_noise_schedule and self.learn_latents:
            losses['train_loss'] = losses['train_loss'] + kl_div.mean()

        return losses

    def precondition(self, x_t, t):
        """
        Call the score model with a log-time conditioning signal.

        The time conditioning ``0.25 * log(t + 1e-8) * 1000`` follows the
        noise conditioning of the paper "Elucidating the Design Space of
        Diffusion-Based Generative Models" (EDM); no input scaling is
        applied to ``x_t`` here. Returns the score model's output
        (per-feature logits).
        """
        c_noise = torch.log(t + 1e-8) * 0.25

        return self.score_model(x_t, c_noise * 1000)

    def init_embeddings(self):
        """
        Cache the normalized class embeddings used by ``learned_velocity``.

        Stores ``self.prepped_embs``: a tuple with one float64 tensor per
        feature (``num_classes[i]`` rows each), holding the embedding
        weights (plus the per-feature bias when the encoder uses one),
        L2-normalized per row and scaled by ``sqrt(norm_dim)``
        (``sqrt(emb_dim)`` when ``norm_dim`` is None).
        """
        # copy data to embedding bag
        full_emb = self.encoder.cat_emb.weight.data.detach()

        # add bias to embedding bag; skipped entirely for a bias-free encoder
        # (previously this path failed on an unbound `bias`)
        if self.encoder.bias:
            bias = torch.cat(
                [
                    self.encoder.cat_bias[i].unsqueeze(0).expand(self.num_classes[i], -1)
                    for i in range(self.num_features)
                ],
                dim=0,
            )
            assert bias.shape == full_emb.shape
            full_emb = full_emb + bias

        # normalize embedding bag weights once before sampling
        norm_dim = self.norm_dim if self.norm_dim is not None else self.emb_dim
        full_emb = F.normalize(full_emb, dim=1, eps=1e-20) * torch.tensor(norm_dim).sqrt()
        full_emb = full_emb.to(torch.float64)

        self.prepped_embs = torch.split(full_emb, self.num_classes, dim=0)

    @torch.inference_mode()
    def learned_velocity(self, x_t, logits, t, g_grads=None, return_probs=False):
        """
        Turn predicted logits into a velocity for the sampler.

        Requires ``init_embeddings`` to have been called.

        Args:
            x_t: current state.
            logits: iterable of per-feature logits.
            t: time of shape (B,) for the fixed linear schedule; for a
                learned schedule the *schedule value* g(t), indexed here as
                a 2-D tensor.
            g_grads: dg/dt, required for a learned schedule.
            return_probs: if True, only return the list of float64 class
                probabilities per feature.

        Returns:
            ``probs`` (list) if ``return_probs`` else the velocity
            ``(mu - x_t) / (1 - g)`` (times ``dg/dt`` for a learned
            schedule), where ``mu`` is the probability-weighted average of
            the class embeddings.
        """
        if return_probs:
            # transform logits for categorical features to probabilities
            probs = [F.softmax(l.to(torch.float64), dim=1) for l in logits]
            return probs

        mu = torch.zeros_like(
            x_t, device=self.device, dtype=torch.float64
        )

        for i, logs in enumerate(logits):
            probs = F.softmax(logs.to(torch.float64), dim=1)
            mu[:, i, :] = torch.matmul(probs, self.prepped_embs[i])

        # plug estimated expectation into velocity; 1e-8 guards the division
        if not self.learn_noise_schedule:
            velocity = (mu - x_t) / (1 - t.unsqueeze(1).unsqueeze(2) + 1e-8)
        else:
            assert g_grads is not None
            # here `t` carries g(t): d/dt [g x_1 + (1 - g) x_0] with x_0
            # eliminated gives g'(t) * (x_1 - x_t) / (1 - g)
            velocity = g_grads.unsqueeze(2) * (mu - x_t) / (1 - t.unsqueeze(2) + 1e-8)

        return velocity

    @torch.inference_mode()
    def sampler(self, latents, num_steps=200):
        """
        Generate classes from latents with Euler steps from t=0 to t=1
        under the fixed linear schedule.

        Requires ``init_embeddings`` to have been called.

        Args:
            latents: tensor of shape (B, num_features, emb_dim).
            num_steps: number of discretization steps.

        Returns:
            CPU tensor of shape (B, num_features) holding the argmax class
            per feature (stored in the default floating dtype).
        """
        B = latents.shape[0]

        # construct time steps
        t_steps = torch.linspace(
            0, 1, num_steps + 1, device=self.device, dtype=torch.float64
        )
        # start from pure noise
        x_next = latents.to(torch.float64)

        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            t_cur = t_cur.repeat((B,))
            t_next = t_next.repeat((B,))

            # get score model output
            logits = self.precondition(
                x_t=x_next.to(torch.float32),
                t=t_cur.to(torch.float32),
            )

            # estimate velocity
            v_cur = self.learned_velocity(x_next, logits, t_cur)

            # adjust data samples
            h = t_next - t_cur
            x_next = x_next + h[:, None, None] * v_cur

        # final prediction of classes for categorical feature, conditioned on
        # the last time that was used inside the loop
        t_final = t_steps[-2].repeat((B,))
        logits = self.precondition(
            x_t=x_next.to(torch.float32),
            t=t_final.to(torch.float32),
        )

        # get probabilities for each category and derive generated classes
        probs = self.learned_velocity(x_next, logits, t_final, return_probs=True)
        x_gen = torch.empty(B, self.num_features, device=self.device)
        for i in range(self.num_features):
            x_gen[:, i] = probs[i].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sampler_feat_spec_noise(self, latents, num_steps=200):
        """
        Generate classes from latents with Euler steps from t=0 to t=1
        under the learned noise schedule.

        Requires ``init_embeddings`` to have been called and
        ``learn_noise_schedule`` to be enabled.

        Args:
            latents: tensor of shape (B, num_features, emb_dim).
            num_steps: number of discretization steps.

        Returns:
            CPU tensor of shape (B, num_features) holding the argmax class
            per feature (stored in the default floating dtype).
        """
        B = latents.shape[0]

        # construct time steps
        t_steps = torch.linspace(
            0, 1, num_steps + 1, device=self.device, dtype=torch.float64
        )
        if self.learn_latents:
            # draw ONE latent and reuse it for every step (as plot_gamma
            # does): a fresh latent per step would evaluate g(t_i) and
            # g'(t_i) on a different schedule at each step, so the steps
            # would not follow a single coherent trajectory
            emb = torch.randn(1, self.gamma_emb_size, device=self.device).repeat(t_steps.shape[0], 1)
        else:
            emb = self.gamma_emb.repeat(t_steps.shape[0], 1)

        g_steps = self.gamma(emb, t_steps.to(torch.float32)).to(torch.float64).detach()
        g_grads = self.gamma.get_grads(emb, t_steps.to(torch.float32)).to(torch.float64).detach()

        # start from pure noise
        x_next = latents.to(torch.float64)

        for i, (t_cur, t_next, g_cur, g_grad) in enumerate(
            zip(t_steps[:-1], t_steps[1:], g_steps[:-1], g_grads[:-1])
        ):
            t_cur = t_cur.repeat((B,))
            t_next = t_next.repeat((B,))
            g_cur = g_cur.repeat((B, 1))
            g_grad = g_grad.repeat((B, 1))

            # get score model output
            logits = self.precondition(
                x_t=x_next.to(torch.float32),
                t=t_cur.to(torch.float32),
            )

            # estimate velocity
            v_cur = self.learned_velocity(x_next, logits, g_cur, g_grads=g_grad)

            # adjust data samples
            h = t_next - t_cur
            x_next = x_next + h[:, None, None] * v_cur

            if x_next.isnan().any():
                print(f"nan in step {i}")

        # final prediction of classes for categorical feature, conditioned on
        # the last time that was used inside the loop
        t_final = t_steps[-2].repeat((B,))
        logits = self.precondition(
            x_t=x_next.to(torch.float32),
            t=t_final.to(torch.float32),
        )

        # get probabilities for each category and derive generated classes
        probs = [F.softmax(l.to(torch.float64), dim=1) for l in logits]
        x_gen = torch.empty(B, self.num_features, device=self.device)
        for i in range(self.num_features):
            x_gen[:, i] = probs[i].argmax(1)

        return x_gen.cpu()

    @torch.inference_mode()
    def sample_data(self, num_samples, num_steps=200, batch_size=4096, seed=42, verbose=True):
        """
        Generate ``num_samples`` rows of categorical data in batches.

        Args:
            num_samples: number of rows to generate (non-negative).
            num_steps: sampler steps per batch.
            batch_size: maximum number of rows generated at once.
            seed: forwarded to ``set_seeds`` before sampling.
            verbose: show a progress bar over batches.

        Returns:
            CPU long tensor of shape (num_samples, num_features).

        Raises:
            ValueError: if ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        # init required data
        self.init_embeddings()

        set_seeds(seed, cuda_deterministic=True)
        n_batches, remainder = divmod(num_samples, batch_size)
        sample_sizes = n_batches * [batch_size]
        if remainder != 0:
            sample_sizes.append(remainder)

        if not sample_sizes:
            # nothing requested: torch.cat would fail on an empty list
            return torch.empty((0, self.num_features), dtype=torch.long)

        # pick the sampler matching the schedule the model was trained with
        sample_batch = (
            self.sampler_feat_spec_noise if self.learn_noise_schedule else self.sampler
        )

        x = []
        for size in tqdm(sample_sizes, disable=(not verbose)):
            latents = torch.randn(
                (size, self.num_features, self.emb_dim), device=self.device
            )
            x.append(sample_batch(latents, num_steps=num_steps))

        return torch.cat(x).cpu().long()

    def plot_gamma(self):
        """
        Evaluate the learned noise schedule on 200 equally spaced times in
        [0, 1] and return the values (nothing is plotted here).

        With ``learn_latents`` a single random latent is drawn and shared
        across all time points.
        """
        t_steps = torch.linspace(
            0, 1, 200, device=self.device, dtype=torch.float32
        )

        if self.learn_latents:
            emb = torch.randn(1, self.gamma_emb_size, device=self.device).repeat(t_steps.shape[0], 1)
        else:
            emb = self.gamma_emb.repeat(t_steps.shape[0], 1)

        g_steps = self.gamma(emb, t_steps)

        return g_steps

    def get_z(self, x):
        """
        Encode flattened embeddings into a Gaussian latent.

        Args:
            x: input to ``z_encoder`` (flattened data embeddings).

        Returns:
            Tuple ``(z, kl_div)``: a reparameterized sample from
            N(mu, var) and, per row, the KL divergence of N(mu, var) from a
            standard normal.
        """
        out = self.z_encoder(x)
        mu, h = torch.chunk(out, 2, dim=-1)

        # reparameterization trick; softplus keeps the variance positive
        var = F.softplus(h)
        z = mu + var.sqrt() * torch.randn_like(mu)

        # compute KL divergence
        kl_div = 0.5 * torch.sum(mu.pow(2) + var - torch.log(var) - 1., dim=1)

        return z, kl_div

        
        
        
        
    
    
class PolyNoiseSchedule(nn.Module):
    """
    Monotone polynomial noise schedule conditioned on an embedding.

    For coefficients ``a, b, c`` predicted from the embedding (one triple
    per feature), the schedule is the normalized integral

        gamma(t) = gamma_min + gamma_range * I(t) / I(1),
        I(t) = int_0^t [(a s^2 + b s + c)^2 + grad_min_epsilon] ds,

    so ``gamma(0) = gamma_min`` and ``gamma(1) = gamma_max``. ``l_a`` is
    zero-initialized, i.e. the quadratic coefficient starts at zero.
    """

    def __init__(self, emb_dim, num_features, gamma_min=0.0, gamma_max=1.0, grad_min_epsilon=0):
        """
        Args:
            emb_dim: size of the conditioning embedding.
            num_features: number of schedules (output columns).
            gamma_min: schedule value at t = 0.
            gamma_max: schedule value at t = 1.
            grad_min_epsilon: constant added to the integrand, i.e. a lower
                bound on the unnormalized slope.
        """
        super().__init__()
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.gamma_range = self.gamma_max - self.gamma_min
        self.grad_min_epsilon = grad_min_epsilon

        self.h_net = nn.Sequential(
            nn.Linear(emb_dim, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features),
            nn.SiLU(),
        )
        self.l_a = nn.Linear(num_features, num_features)
        nn.init.zeros_(self.l_a.weight)
        nn.init.zeros_(self.l_a.bias)
        self.l_b = nn.Linear(num_features, num_features)
        self.l_c = nn.Linear(num_features, num_features)

    def forward(self, emb, t):
        """
        Evaluate gamma(t).

        Args:
            emb: conditioning embeddings of shape (N, emb_dim).
            t: times of shape (N,), or a single value that is shared by all
                rows of ``emb``.

        Returns:
            Tensor of shape (N, num_features).
        """
        t = self._as_column(emb, t)
        a, b, c = self.get_params(emb)
        return self._eval_poly(t, a, b, c)

    def get_grads(self, emb, t):
        """
        Evaluate d gamma / d t; same arguments and output shape as
        ``forward``.
        """
        t = self._as_column(emb, t)
        a, b, c = self.get_params(emb)
        return self._grad_t(t, a, b, c)

    def get_params(self, emb):
        """Predict the polynomial coefficients ``(a, b, c)``; ``c`` is kept
        at least 1e-3 via a softplus."""
        h = self.h_net(emb)
        a = self.l_a(h)
        b = self.l_b(h)
        c = 1e-3 + F.softplus(self.l_c(h))
        return a, b, c

    def _as_column(self, emb, t):
        """Bring ``t`` to shape (N, 1) so it broadcasts against (N, F) coefficients."""
        if t.numel() == 1:
            # scalar
            t = t * torch.ones((emb.shape[0], 1), device=emb.device)
        else:
            t = t.unsqueeze(-1)

        assert len(emb.shape) == 2
        assert emb.shape[0] == t.shape[0]
        return t

    def _scale(self, a, b, c):
        """Normalizer I(1): the unnormalized integral evaluated at t = 1."""
        return (
            (a ** 2) / 5.0
            + (b ** 2 + 2 * a * c) / 3.0
            + a * b / 2.0
            + b * c
            + (c ** 2 + self.grad_min_epsilon)
        )

    def _eval_poly(self, t, a, b, c):
        """gamma(t) from the closed-form integral of the integrand."""
        polynomial = (
            (a ** 2) * (t ** 5) / 5.0
            + (b ** 2 + 2 * a * c) * (t ** 3) / 3.0
            + a * b * (t ** 4) / 2.0
            + b * c * (t ** 2)
            + (c ** 2 + self.grad_min_epsilon) * t
        )

        return self.gamma_min + self.gamma_range * polynomial / self._scale(a, b, c)

    def _grad_t(self, t, a, b, c):
        """
        Derivative of ``_eval_poly`` with respect to ``t``.

        ``grad_min_epsilon`` enters both the integrand and the normalizer;
        leaving it out (as was done before) makes this differ from the true
        slope of ``forward`` whenever ``grad_min_epsilon != 0``.
        """
        polynomial = (
            (a ** 2) * (t ** 4)
            + (b ** 2 + 2 * a * c) * (t ** 2)
            + a * b * (t ** 3) * 2.0
            + b * c * t * 2
            + (c ** 2 + self.grad_min_epsilon)
        )

        return self.gamma_range * polynomial / self._scale(a, b, c)
    
    