import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from utils import PredefinedNoiseScheduleDiscrete, get_entropy, fuse_logits_by_log_probs, sin_mask_ratio_adapter


class DiscreteUniformTransition:
    """Uniform-noise transition matrices for a K-class discrete diffusion.

    Q_t    = (1 - beta_t) * I + beta_t * U
    Qbar_t = alpha_bar_t * I + (1 - alpha_bar_t) * U
    where U is the K x K matrix whose entries are all 1 / K.
    """

    def __init__(self, x_classes: int):
        self.X_classes = x_classes

        # (1, K, K) matrix of 1/K; the leading dim broadcasts over the batch.
        self.u_x = torch.ones(1, self.X_classes, self.X_classes)
        if self.X_classes > 0:
            self.u_x = self.u_x / self.X_classes

    @staticmethod
    def _as_batch_scalar(value, device):
        """Return a per-sample scalar of shape (bs,) or (bs, 1) as (bs, 1, 1) on ``device``."""
        return torch.as_tensor(value, device=device).reshape(-1, 1, 1)

    def get_Qt(self, beta_t, device):
        """Return one-step transition matrices for X, from step t - 1 to step t.

        Qt = (1 - beta_t) * I + beta_t / K

        Args:
            beta_t: noise level between 0 and 1, one per sample; shape (bs,) or (bs, 1).
            device: device on which to build the result.

        Returns:
            q_x of shape (bs, K, K).
        """
        beta_t = self._as_batch_scalar(beta_t, device)
        # Cache the uniform matrix on the requested device.
        self.u_x = self.u_x.to(device)
        eye = torch.eye(self.X_classes, device=device).unsqueeze(0)
        return beta_t * self.u_x + (1 - beta_t) * eye

    def get_Qt_bar(self, alpha_bar_t, device):
        """Return t-step transition matrices for X, from step 0 to step t.

        Qt_bar = alpha_bar_t * I + (1 - alpha_bar_t) / K

        Args:
            alpha_bar_t: product of (1 - beta) over steps 0..t, one per sample;
                shape (bs,) or (bs, 1).
            device: device on which to build the result.

        Returns:
            q_x of shape (bs, K, K).
        """
        alpha_bar_t = self._as_batch_scalar(alpha_bar_t, device)
        self.u_x = self.u_x.to(device)
        eye = torch.eye(self.X_classes, device=device).unsqueeze(0)
        return alpha_bar_t * eye + (1 - alpha_bar_t) * self.u_x


class RagDiff(nn.Module):
    """Discrete diffusion over node classes with prior-model logit fusion at sampling time.

    ``model`` is called as ``model(noise_data, t)`` and returns per-node logits over
    ``NUM_CLASSES`` classes, used as a prediction of x_0. ``prior_model`` is frozen in
    ``__init__`` and is only queried inside ``sample_p_zs_given_zt``, where its
    log-probabilities are fused with the base model's via ``fuse_logits_by_log_probs``.

    Args:
        model: denoiser network.
        prior_model: auxiliary network; its parameters get ``requires_grad = False``.
        timesteps: number of diffusion steps T.
        loss_type: 'l1', 'l2' or 'CE' (see ``loss_fn``).
        objective: must be 'pred_noise' or 'pred_x0'; only 'pred_x0' is handled by ``forward``.
        sample_method: 'ddpm' or 'ddim', forwarded to ``sample_p_zs_given_zt`` by
            ``mc_ddim_sample``.
        min_mask_ratio, dev_mask_ratio: centre and maximum deviation passed to
            ``sin_mask_ratio_adapter`` when choosing which nodes to mask for the prior.
        ensemble_num: stored on the instance; not used by the methods of this class.
        noise_type: 'uniform', 'marginal' or 'blosum'; selects the limit distribution and how
            cumulative transitions are built during sampling. Defaults to 'uniform', the only
            transition model this class constructs.
    """

    # Number of node classes; the first NUM_CLASSES columns of ``data.x`` hold x_0.
    NUM_CLASSES = 20

    def __init__(self, model, prior_model, timesteps=500, loss_type='CE', objective='pred_x0',
                 sample_method='ddim', min_mask_ratio=0.4, dev_mask_ratio=0.1, ensemble_num=50,
                 noise_type='uniform'):
        super().__init__()
        self.model = model
        self.prior_model = prior_model
        # The prior model is only used at inference time; keep it out of the optimizer.
        for param in self.prior_model.parameters():
            param.requires_grad = False

        self.objective = objective
        self.timesteps = timesteps
        self.loss_type = loss_type
        # Previously read an undefined global, so construction always raised NameError.
        self.noise_type = noise_type
        self.sample_method = sample_method
        self.min_mask_ratio = min_mask_ratio
        self.dev_mask_ratio = dev_mask_ratio
        self.ensemble_num = ensemble_num
        self.transition_model = DiscreteUniformTransition(x_classes=self.NUM_CLASSES)

        assert objective in {'pred_noise', 'pred_x0'}

        self.noise_schedule = PredefinedNoiseScheduleDiscrete(noise_schedule='cosine', timesteps=self.timesteps)

    @property
    def loss_fn(self):
        """Return the torch functional loss selected by ``self.loss_type``.

        Raises:
            ValueError: if ``loss_type`` is not 'l1', 'l2' or 'CE'.
        """
        losses = {'l1': F.l1_loss, 'l2': F.mse_loss, 'CE': F.cross_entropy}
        try:
            return losses[self.loss_type]
        except KeyError:
            raise ValueError(f'unknown loss_type {self.loss_type!r}') from None

    def apply_noise(self, data, t_int):
        """Sample x_t ~ q(x_t | x_0) for every node.

        Args:
            data: graph batch with ``x`` (first NUM_CLASSES columns are x_0) and
                ``batch`` (node -> graph index).
            t_int: integer time steps, one per graph, shape (bs, 1).

        Returns:
            (noise_data, alpha_t_bar): a clone of ``data`` whose ``x`` is the one-hot x_t
            (cast to ``data.x``'s dtype), and alpha_bar at t.
        """
        k = self.NUM_CLASSES
        t_float = t_int / self.timesteps

        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device=data.x.device)

        # Row vector x_0 @ Qtb[graph] gives each node's categorical over x_t.
        prob_X = (data.x[:, :k].unsqueeze(1) @ Qtb[data.batch]).squeeze(1)  # [N, k]
        X_t = prob_X.multinomial(1).squeeze(-1)  # [N]
        # Same dtype as diffusion_loss and the samplers feed to the model.
        noise_X = F.one_hot(X_t, num_classes=k).type_as(data.x)
        noise_data = data.clone()
        noise_data.x = noise_X
        return noise_data, alpha_t_bar

    def sample_discrete_feature_noise(self, limit_dist, num_node):
        """Draw ``num_node`` one-hot samples from the 1-D distribution ``limit_dist``.

        Returns:
            Float tensor of shape [num_node, len(limit_dist)].
        """
        x_limit = limit_dist[None, :].expand(num_node, -1)  # [num_node, K]
        U_X = x_limit.multinomial(1).squeeze(-1)  # [num_node]; keeps the dim when num_node == 1
        return F.one_hot(U_X, num_classes=x_limit.shape[-1]).float()

    def diffusion_loss(self, data, t_int):
        """Divergence between q(x_{t-1} | x_t, x_0) and p_theta(x_{t-1} | x_t).

        Args:
            data: graph batch (see ``apply_noise``).
            t_int: time steps per graph, shape (bs, 1); expected >= 1 since s = t - 1.

        Returns:
            Scalar loss from ``self.loss_fn`` with ``reduction='mean'``.
        """
        k = self.NUM_CLASSES
        device = data.x.device

        # q(x_{t-1} | x_t, x_0)
        s_int = t_int - 1
        t_float = t_int / self.timesteps
        s_float = s_int / self.timesteps
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device=device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device=device)
        Qt = self.transition_model.get_Qt(beta_t, device)
        prob_X = (data.x[:, :k].unsqueeze(1) @ Qtb[data.batch]).squeeze(1)  # [N, k]
        X_t = prob_X.multinomial(1).squeeze(-1)  # [N]
        noise_X = F.one_hot(X_t, num_classes=k).type_as(data.x)
        prob_true = self.compute_posterior_distribution(noise_X, Qt, Qsb, Qtb, data)  # [N, d_{t-1}]

        # p_theta(x_{t-1} | x_t) = sum_{x0} q(x_{t-1} | x_t, x_0) p_theta(x_0 | x_t)
        noise_data = data.clone()
        noise_data.x = noise_X  # x_t
        t = t_int * torch.ones(size=(data.batch[-1] + 1, 1), device=device).float()
        pred = self.model(noise_data, t)
        pred_X = F.softmax(pred, dim=-1)  # predicted p(x_0)
        p_s_and_t_given_0_X = self.compute_batched_over0_posterior_distribution(
            X_t=noise_X, Q_t=Qt, Qsb=Qsb, Qtb=Qtb, data=data)  # [N, d0, d_{t-1}]
        unnormalized_prob_X = (pred_X.unsqueeze(-1) * p_s_and_t_given_0_X).sum(dim=1)  # [N, d_{t-1}]
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_pred = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)
        if self.loss_type == 'CE':
            # F.cross_entropy applies log_softmax to its input, so give it log-probabilities:
            # for a normalized p, log_softmax(log p) == log p.
            prob_pred = torch.log(prob_pred.clamp_min(1e-12))
        return self.loss_fn(prob_pred, prob_true, reduction='mean')

    def compute_val_loss(self, data):
        """Validation loss: ``diffusion_loss`` at one random step t in [1, T] per graph."""
        # t = 0 would make s = t - 1 = -1, a normalized time outside [0, 1].
        t_int = torch.randint(1, self.timesteps + 1, size=(data.batch[-1] + 1, 1),
                              device=data.x.device).float()
        return self.diffusion_loss(data, t_int)

    def compute_batched_over0_posterior_distribution(self, X_t, Q_t, Qsb, Qtb, data):
        """For every candidate x_0, compute q(x_{t-1} | x_t, x_0).

        out[n, i, j] = (x_t Qt^T)_j * Qsb[i, j] / (Qtb x_t^T)_i

        Args:
            X_t: one-hot x_t per node, [N, dt].
            Q_t: transition from t-1 (or s) to t, [bs, d_{t-1}, dt].
            Qsb: cumulative transition 0 -> s, [bs, d0, d_{t-1}].
            Qtb: cumulative transition 0 -> t, [bs, d0, dt].
            data: provides ``batch`` (node -> graph index).

        Returns:
            Tensor [N, d0, d_{t-1}]; zero denominators are replaced by 1e-6.
        """
        Qt_T = Q_t.transpose(-1, -2)
        left_term = X_t.unsqueeze(-2) @ Qt_T[data.batch]  # [N, 1, d_{t-1}]
        right_term = Qsb[data.batch]  # [N, d0, d_{t-1}]
        numerator = left_term * right_term  # [N, d0, d_{t-1}]

        denominator = Qtb[data.batch] @ X_t.unsqueeze(-1)  # [N, d0, 1]
        denominator[denominator == 0] = 1e-6
        return numerator / denominator

    def compute_posterior_distribution(self, M_t, Qt_M, Qsb_M, Qtb_M, data):
        """Compute q(x_{t-1} | x_t, x_0) for the observed x_0.

        q = (x_t Qt^T) * (x_0 Qsb) / (x_0 Qtb x_t^T), with x_0 the first NUM_CLASSES
        columns of ``data.x``.

        Args:
            M_t: one-hot x_t per node, [N, dt].
            Qt_M, Qsb_M, Qtb_M: per-graph transitions (see
                ``compute_batched_over0_posterior_distribution``).
            data: provides ``x`` and ``batch``.

        Returns:
            Tensor [N, d_{t-1}] (the node dim is kept even when N == 1).
        """
        k = self.NUM_CLASSES
        Qt_T = Qt_M.transpose(-1, -2)
        left_term = M_t.unsqueeze(-2) @ Qt_T[data.batch]  # [N, 1, d_{t-1}]

        # Slice like apply_noise / diffusion_loss so extra feature columns are ignored.
        M_0 = data.x[:, :k].unsqueeze(-2)  # [N, 1, d0]
        right_term = M_0 @ Qsb_M[data.batch]  # [N, 1, d_{t-1}]
        numerator = (left_term * right_term).squeeze(-2)  # [N, d_{t-1}]

        denominator = (M_0 @ Qtb_M[data.batch] @ M_t.unsqueeze(-1)).squeeze(-1)  # [N, 1]
        denominator[denominator == 0] = 1e-6
        return numerator / denominator

    def sample_p_zs_given_zt(self, t, s, zt, g_data, ipa_data, cond, diverse, sample_type, last_step):
        """Sample z_s ~ p(z_s | z_t) for one reverse step.

        The base model's logits are fused with the frozen prior model's logits. The prior
        receives the base argmax prediction, with ``x_mask`` set on nodes whose predictive
        entropy is above the per-graph (1 - mask_ratio) quantile.

        Args:
            t, s: normalized times, shape (bs, 1).
            zt: current one-hot state, [N, NUM_CLASSES].
            g_data: graph batch for ``self.model``.
            ipa_data: batch for ``self.prior_model`` (uses x, atom_pos, x_mask, x_pad);
                presumably the rows with x_pad == 1 line up with g_data's nodes — verify.
            cond: unused.
            diverse: sample from the posterior if True, else take its argmax.
            sample_type: 'ddpm' (one-step Qt) or 'ddim' (Qt approximated from Qsb / Qtb).
            last_step: if True, return the fused logits and their one-hot argmax.

        Returns:
            (logits if last_step else None, one-hot z_s as float [N, NUM_CLASSES]).

        Raises:
            NotImplementedError: for an unknown ``sample_type``.
        """
        k = self.NUM_CLASSES
        device = g_data.x.device
        beta_t = self.noise_schedule(t_normalized=t)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
        if self.noise_type in ('uniform', 'marginal'):
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
        else:
            # Other noise types hand the normalized time itself to the transition model.
            Qtb = self.transition_model.get_Qt_bar(t, device)
            Qsb = self.transition_model.get_Qt_bar(s, device)

        if sample_type == 'ddpm':
            Qt = self.transition_model.get_Qt(beta_t, device)
        elif sample_type == 'ddim':
            # Approximate the s -> t transition by row-normalizing Qsb / Qtb.
            ratio = Qsb / Qtb
            Qt = ratio / ratio.sum(dim=-1).unsqueeze(dim=2)
        else:
            raise NotImplementedError

        noise_data = g_data.clone()
        noise_data.x = zt
        ipa_noise_data = ipa_data.clone()

        base_logits = self.model(noise_data, t * self.timesteps)
        log_probs = F.log_softmax(base_logits, dim=-1)
        base_pred_x = F.one_hot(log_probs.argmax(dim=-1), num_classes=k).float()

        # Per graph, mask the nodes whose entropy exceeds the (1 - mask_ratio) quantile.
        entropy = get_entropy(log_probs)
        mask_entropy = torch.zeros_like(entropy, dtype=torch.bool)
        unique_batches = noise_data.batch.unique()
        mask_ratios = sin_mask_ratio_adapter(1 - alpha_t_bar, max_deviation=self.dev_mask_ratio,
                                             center=self.min_mask_ratio)
        for mask_ratio, b in zip(mask_ratios, unique_batches):
            in_graph = noise_data.batch == b
            graph_entropy = entropy[in_graph]
            mask_entropy[in_graph] = graph_entropy > torch.quantile(graph_entropy, 1 - mask_ratio)

        pad = ipa_noise_data.x_pad == 1
        ipa_noise_data.x_mask[pad] = mask_entropy.long()
        ipa_noise_data.x[pad] = base_pred_x

        prior_logits = self.prior_model(ipa_noise_data.x, ipa_noise_data.atom_pos, ipa_noise_data.x_mask,
                                        ipa_noise_data.x_pad)
        prior_log_probs = F.log_softmax(prior_logits, dim=-1)[pad]
        prior_logits = prior_logits[pad]

        logits = fuse_logits_by_log_probs([log_probs, prior_log_probs], [base_logits, prior_logits])
        pred_X = F.softmax(logits, dim=-1)
        if last_step:
            final_predicted_X = F.one_hot(pred_X.argmax(dim=1), num_classes=k).float()
            return logits, final_predicted_X

        p_s_and_t_given_0_X = self.compute_batched_over0_posterior_distribution(
            X_t=zt, Q_t=Qt, Qsb=Qsb, Qtb=Qtb, data=g_data)  # [N, d0, d_{t-1}]
        unnormalized_prob_X = (pred_X.unsqueeze(-1) * p_s_and_t_given_0_X).sum(dim=1)  # [N, d_{t-1}]
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)
        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        if diverse:
            sample_s = prob_X.multinomial(1).squeeze(-1)
        else:
            sample_s = prob_X.argmax(dim=1)

        return None, F.one_hot(sample_s, num_classes=k).float()

    def mc_ddim_sample(self, g_data, ipa_data=None, cond=False, diverse=True, stop=0, step=50):
        """Run the reverse chain for s in reversed(range(stop, timesteps, step)), with t = s + step.

        Returns:
            (logits, zt): ``logits`` are the fused logits of the step with s == 0, or None
            if that step is never reached; ``zt`` is the final one-hot state [N, NUM_CLASSES].

        Raises:
            ValueError: if ``self.noise_type`` has no known limit distribution.
        """
        k = self.NUM_CLASSES
        if self.noise_type in ('uniform', 'blosum'):
            limit_dist = torch.ones(k) / k
        elif self.noise_type == 'marginal':
            limit_dist = self.transition_model.x_marginal
        else:
            raise ValueError(f'unknown noise_type {self.noise_type!r}')
        zt = self.sample_discrete_feature_noise(limit_dist=limit_dist, num_node=g_data.x.shape[0])
        zt = zt.to(g_data.x.device)
        logits = None  # stays None when the loop body never runs
        for s_int in reversed(range(stop, self.timesteps, step)):
            # z_s ~ p(z_s | z_t)
            s_array = s_int * torch.ones((g_data.batch[-1] + 1, 1)).type_as(g_data.x)
            t_array = s_array + step
            s_norm = s_array / self.timesteps
            t_norm = t_array / self.timesteps
            logits, zt = self.sample_p_zs_given_zt(t_norm, s_norm, zt, g_data, ipa_data, cond, diverse,
                                                   self.sample_method,
                                                   last_step=s_int == 0)
        return logits, zt

    def forward(self, g_data):
        """Training loss: predict x_0 from x_t at one random step t in [0, T] per graph."""
        t_int = torch.randint(0, self.timesteps + 1, size=(g_data.batch[-1] + 1, 1),
                              device=g_data.x.device).float()
        noise_data, _ = self.apply_noise(g_data, t_int)

        if self.objective == 'pred_x0':
            # Same class columns that apply_noise treats as x_0.
            target = g_data.x[:, :self.NUM_CLASSES]
        else:
            raise ValueError(f'unknown objective {self.objective}')

        logits = self.model(noise_data, t_int)
        return self.loss_fn(logits, target, reduction='mean')
