import matplotlib.hatch
import torch
import torch.nn.functional as F
import numpy as np
import math

import didigress.utils as utils
from didigress.diffusion import diffusion_utils

import torch.nn.functional as F
from torch_geometric.utils import to_dense_batch

from torch.profiler import profile, record_function, ProfilerActivity

import time

class NoiseModel:
    def __init__(self, cfg):
        """Set up the per-component noise schedules from a config object.

        Reads ``cfg.model`` (nu, diffusion_noise_schedule, diffusion_steps),
        ``cfg.features`` (use_3d, use_charges, use_ins_del and, when
        insert/delete noise is enabled, zeta_D / zeta_w) and
        ``cfg.guidance.p_uncond``.

        Transition matrices (Px, Pe, Py, Pcharges), class counts and marginals
        are only initialised to None here; they are filled in elsewhere.

        Raises:
            NotImplementedError: if the noise schedule is not 'cosine'.
        """
        # Component order used to index every per-component schedule tensor.
        self.mapping = ['x', 'e', 'y', 'c', 'p']
        self.inverse_mapping = {name: idx for idx, name in enumerate(self.mapping)}
        nu_per_component = cfg.model.nu
        self.nu_arr = [nu_per_component[name] for name in self.mapping]

        feats = cfg.features
        self.use_3d = feats.use_3d
        self.use_charges = feats.use_charges
        self.use_ins_del = feats.use_ins_del
        # Guidance is enabled whenever p_uncond is non-negative.
        self.use_guidance = cfg.guidance.p_uncond >= 0

        # Discrete transition matrices, class counts and marginals (set later).
        self.Px = self.Pe = self.Py = None
        self.X_classes = self.E_classes = self.y_classes = None
        self.X_marginals = self.E_marginals = self.y_marginals = None
        self.Pcharges = self.charges_classes = self.charges_marginals = None

        self.noise_schedule = cfg.model.diffusion_noise_schedule
        self.timesteps = cfg.model.diffusion_steps
        self.T = self.timesteps

        self.x_ins_del_slots = 0
        self.C = None
        self.C_sgn = None

        if self.noise_schedule != 'cosine':
            raise NotImplementedError(self.noise_schedule)
        betas, alphas, a_bar = diffusion_utils.cosine_beta_schedule_discrete(
            self.timesteps, self.nu_arr, compute_full=self.use_ins_del)
        # TODO: the delete-noise schedule probably belongs here as well.

        self._betas = torch.from_numpy(betas)
        self._alphas = torch.from_numpy(alphas)
        self._alphas_bar = torch.from_numpy(a_bar)
        self._log_alpha_bar = torch.log(self._alphas_bar)

        # sigma_bar^2 = 1 - alpha_bar^2, evaluated as -expm1(2 log alpha_bar)
        # for numerical stability near alpha_bar = 1.
        two_log_ab = 2 * self._log_alpha_bar
        self._sigma2_bar = -torch.expm1(two_log_ab)
        self._sigma_bar = torch.sqrt(self._sigma2_bar)
        # gamma = log(sigma_bar^2) - log(alpha_bar^2)
        self._gamma = torch.log(self._sigma2_bar) - two_log_ab

        if self.use_ins_del:
            self.zeta_D = feats.zeta_D
            self.zeta_w = feats.zeta_w
            zeta, d_zeta = diffusion_utils.compute_linear_zeta(
                T=self.T, D=self.zeta_D, w=self.zeta_w)
            zeta_bar = diffusion_utils.compute_zeta_bar_matrix(zeta)
            self._zetas = torch.from_numpy(zeta).unsqueeze(-1)
            self._d_zetas = torch.from_numpy(d_zeta)
            self._zetas_bar = torch.from_numpy(zeta_bar)
            # Presumably the number of extra class slots used by insert/delete noise.
            self.x_ins_del_slots = 2


    def move_P_device(self, tensor):
        """Return the prior transition matrices as float32 on ``tensor``'s device.

        Packs Px, Pe, Py (and Pcharges when charges are used, otherwise None)
        into a diffusion_utils.PlaceHolder with pos=None.
        """
        target = tensor.device

        def on_target(matrix):
            return matrix.float().to(target)

        return diffusion_utils.PlaceHolder(
            X=on_target(self.Px),
            E=on_target(self.Pe),
            y=on_target(self.Py),
            charges=on_target(self.Pcharges) if self.use_charges else None,
            pos=None,
        )

    def get_Qt(self, t_int):
        """Return the one-step transition matrices Q_t (from step t - 1 to t).

        For X, E and (when enabled) charges:
            Q_t = (1 - beta_t) * A + beta_t * B,
        where A is ``self.Ax`` / ``self.Ae`` / ``self.Ac`` and B the matching
        prior matrix from ``move_P_device``. For y:
            Q_t = (1 - beta_t) * I + beta_t * P.y.

        Args:
            t_int: integer timestep tensor; per-component betas are unsqueezed
                at dim 1 so they broadcast against the matrices.

        Returns:
            utils.PlaceHolder with X, E, y and charges (None when charges are
            unused); pos is None.
        """
        dev = t_int.device
        prior = self.move_P_device(t_int)

        def beta_for(component):
            return self.get_beta(t_int=t_int, key=component).unsqueeze(1)

        def blend(beta, keep, prior_matrix):
            return (1 - beta) * keep + beta * prior_matrix

        beta_x, beta_e, beta_y = beta_for('x'), beta_for('e'), beta_for('y')

        q_x = blend(beta_x, self.Ax.to(dev), prior.X)
        q_e = blend(beta_e, self.Ae.to(dev), prior.E)
        # NOTE: insert/delete (zeta) mixing is currently not applied to the
        # one-step matrices.

        eye_y = torch.eye(self.y_classes, device=dev, dtype=torch.float32).unsqueeze(0)
        q_y = beta_y * prior.y + (1 - beta_y) * eye_y

        q_c = None
        if self.use_charges:
            beta_c = beta_for('c')
            q_c = blend(beta_c, self.Ac.to(dev), prior.charges)

        return utils.PlaceHolder(X=q_x, E=q_e, y=q_y, charges=q_c, pos=None)

    def get_Qt_bar(self, t_int, s_int=None):
        """Return the cumulative transition from step s (default 0) to step t.

        Without insert/delete noise, returns a utils.PlaceHolder with
            X = a_x * Ax + (1 - a_x) * P.X
            E = a_e * Ae + (1 - a_e) * P.E
            y = a_y * I  + (1 - a_y) * P.y
            charges = a_c * Ac + (1 - a_c) * P.charges   (None if unused)
        where a_* = alpha_bar(t | s) for each component. Rows of X, E (and
        charges) are asserted to sum to 1 (tolerance 1e-4).

        With insert/delete noise the matrices are NOT combined here; instead the
        tuple (A, B, C, D, alpha_b, z, z_bt, z_btm1) is returned, where
            A, B, C, D: PlaceHolders holding the A*/prior/C*/D* matrices for X and
                        E (and charges when enabled),
            alpha_b:    PlaceHolder of alpha_bar(t | s) for X and E (and charges),
            z:          zeta at step t,
            z_bt:       zeta_bar(t | s),
            z_btm1:     zeta_bar(t - 1 | s).

        Args:
            t_int: integer timestep tensor.
            s_int: integer start-step tensor; defaults to zeros like ``t_int``.
        """
        if s_int is None:
            s_int = torch.zeros_like(t_int)

        # alpha_bar values get an extra axis (dim 1 normally, last dim with
        # insert/delete noise) so they broadcast against the matrices.
        unsqueeze_dim = -1 if self.use_ins_del else 1

        a_x = self.get_alpha_bar(t_int=t_int, key='x', s_int=s_int).unsqueeze(unsqueeze_dim)
        a_e = self.get_alpha_bar(t_int=t_int, key='e', s_int=s_int).unsqueeze(unsqueeze_dim)
        a_y = self.get_alpha_bar(t_int=t_int, key='y', s_int=s_int).unsqueeze(unsqueeze_dim)

        P = self.move_P_device(t_int)
        dev = t_int.device

        Ax = self.Ax.to(dev)
        Ae = self.Ae.to(dev)

        Bx = P.X
        Be = P.E

        if self.use_ins_del:
            Cx = self.Cx.float().to(dev)
            Ce = self.Ce.float().to(dev)

            Dx = self.Dx.float().to(dev)
            De = self.De.float().to(dev)

            z_btm1 = self.get_zeta_bar(t_int=t_int - 1, s_int=s_int)
            z_bt = self.get_zeta_bar(t_int=t_int, s_int=s_int)
            z = self.get_zeta(t_int=t_int)

            A = utils.PlaceHolder(X=Ax, E=Ae, y=None)
            B = utils.PlaceHolder(X=Bx, E=Be, y=None)
            C = utils.PlaceHolder(X=Cx, E=Ce, y=None)
            D = utils.PlaceHolder(X=Dx, E=De, y=None)

            alpha_b = utils.PlaceHolder(X=a_x, E=a_e, y=None)
        else:
            q_x = a_x * Ax + (1 - a_x) * Bx
            q_e = a_e * Ae + (1 - a_e) * Be
            q_y = a_y * torch.eye(self.y_classes, device=dev).unsqueeze(0) + (1 - a_y) * P.y

        q_c = None
        if self.use_charges:
            Ac = self.Ac.to(dev)
            a_c = self.get_alpha_bar(t_int=t_int, key='c', s_int=s_int).unsqueeze(unsqueeze_dim)

            if self.use_ins_del:
                Bc = P.charges
                Cc = self.Cc.float().to(dev)
                Dc = self.Dc.float().to(dev)

                A.charges = Ac
                B.charges = Bc
                C.charges = Cc
                D.charges = Dc
                alpha_b.charges = a_c
            else:
                q_c = a_c * Ac + (1 - a_c) * P.charges

                assert ((q_c.sum(dim=2) - 1.).abs() < 1e-4).all(), q_c.sum(dim=2) - 1

        if self.use_ins_del:
            return A, B, C, D, alpha_b, z, z_bt, z_btm1

        assert ((q_x.sum(dim=2) - 1.).abs() < 1e-4).all(), q_x.sum(dim=2) - 1
        assert ((q_e.sum(dim=2) - 1.).abs() < 1e-4).all()
        return utils.PlaceHolder(X=q_x, E=q_e, y=q_y, charges=q_c, pos=None)

    # get_beta takes no s_int: self._betas is indexed only by step, with shape (T+1, n_components).
    def get_beta(self, t_normalized=None, t_int=None, key=None):
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        b = self._betas.to(t_int.device)[t_int.long()]
        if key is None:
            return b.float()
        else:
            return b[..., self.inverse_mapping[key]].float()

    def get_alpha_bar(self, t_normalized=None, t_int=None, key=None, 
                            s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        # a = self._alphas_bar.to(t_int.device)[t_int.long()]
        a = self._alphas_bar.to(t_int.device)[:, t_int.long(), s_int.long()]
        if key is None:
            return a.float()
        else:
            #return a[..., self.inverse_mapping[key]].float()
            return a[self.inverse_mapping[key], ...].float()

    def get_sigma_bar(self, t_normalized=None, t_int=None, key=None, 
                            s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        s = self._sigma_bar.to(t_int.device)[:, t_int.long(), s_int.long()]
        if key is None:
            return s.float()
        else:
            # return s[..., self.inverse_mapping[key]].float()
            return s[self.inverse_mapping[key], ...].float()

    def get_sigma2_bar(self, t_normalized=None, t_int=None, key=None, 
                             s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        s = self._sigma2_bar.to(t_int.device)[:, t_int.long(), s_int.long()]
        if key is None:
            return s.float()
        else:
            return s[self.inverse_mapping[key], ...].float()

    def get_gamma(self, t_normalized=None, t_int=None, key=None, 
                        s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        g = self._gamma.to(t_int.device)[:, t_int.long(), s_int.long()]
        if key is None:
            return g.float()
        else:
            return g[self.inverse_mapping[key], ...].float()

    def sigma_pos_ts_sq(self, t_int, s_int):
        gamma_s = self.get_gamma(t_int=s_int, key='p')
        gamma_t = self.get_gamma(t_int=t_int, key='p')
        delta_soft = F.softplus(gamma_s) - F.softplus(gamma_t)
        sigma_squared = - torch.expm1(delta_soft)
        return sigma_squared

    # def get_ratio_sigma_ts(self, t_int, s_int):
    #     """ Compute sigma_t_over_s^2 / (1 - sigma_t_over_s^2)"""
    #     delta_soft = F.softplus(self._gamma[t_int]) - F.softplus(self._gamma[s_int])
    #     return torch.expm1(delta_soft)

    def get_alpha_pos_ts(self, t_int, s_int):
        log_a_bar = self._log_alpha_bar[self.inverse_mapping['p'], ...].to(t_int.device)
        ratio = torch.exp(log_a_bar[t_int, 0] - log_a_bar[s_int, 0])
        return ratio.float()

    def get_alpha_pos_ts_sq(self, t_int, s_int):
        log_a_bar = self._log_alpha_bar[self.inverse_mapping['p'], ...].to(t_int.device)
        ratio = torch.exp(2 * log_a_bar[t_int,0] - 2 * log_a_bar[s_int,0])
        return ratio.float()


    def get_sigma_pos_sq_ratio(self, s_int, t_int):
        log_a_bar = self._log_alpha_bar[self.inverse_mapping['p'],...].to(t_int.device)
        s2_s = - torch.expm1(2 * log_a_bar[s_int,0])
        s2_t = - torch.expm1(2 * log_a_bar[t_int,0])
        ratio = torch.exp(torch.log(s2_s) - torch.log(s2_t))
        return ratio.float()

    def get_zeta(self, t_normalized=None, t_int=None):
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        
        z = self._zetas.to(t_int.device)[t_int.long()]
        
        return z.float()

    def get_zeta_d(self, t_normalized=None, t_int=None):
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        
        z = self._d_zetas.to(t_int.device)[t_int.long()]
        
        return z.float()

    def get_zeta_bar(self, t_normalized=None, t_int=None,
                           s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        
        z = self._zetas_bar.to(t_int.device)[t_int.long(), s_int.long()]
        z = z.unsqueeze(-1)
        
        return z.float()
    
    def get_p_del(self, t_normalized=None, t_int=None,
                           s_normalized=None, s_int=torch.tensor(0)):
        assert int(t_normalized is None) + int(t_int is None) == 1
        assert int(s_normalized is None) + int(s_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.T)
        if s_int is None:
            s_int = torch.round(s_normalized * self.T)
        
        p = self.del_probs.to(t_int.device)[t_int.long(), s_int.long()]
        
        return p.float()

    def get_x_pos_prefactor(self, s_int, t_int):
        """ a_s (s_t^2 - a_t_s^2 s_s^2) / s_t^2"""
        a_s = self.get_alpha_bar(t_int=s_int, key='p')
        alpha_ratio_sq = self.get_alpha_pos_ts_sq(t_int=t_int, s_int=s_int)
        sigma_ratio_sq = self.get_sigma_pos_sq_ratio(s_int=s_int, t_int=t_int)
        prefactor = a_s * (1 - alpha_ratio_sq * sigma_ratio_sq)
        return prefactor.float()

    #TODO: add t as an optional starting parameter
    @torch.no_grad()
    def apply_noise(self, dense_data, corruption_step=-1):
        """Sample a noisy graph z_t from the forward process.

        Args:
            dense_data: dense batch object exposing X, E, y, charges, pos,
                node_mask, guidance and delt_mask (plus whatever apply_ins_del
                reads when insert/delete noise is enabled).
            corruption_step: -1 draws t uniformly from {1, ..., T} per sample;
                any other value corrupts every sample to exactly that step. With
                insert/delete noise and corruption_step >= 0, delta_T is forced
                to 0, i.e. a "dry" corruption without insertions/deletions.

        Returns:
            Without insert/delete noise: z_t.
            With insert/delete noise: (z_t, target, delt_sample, delt_target)
                target:      copy of the insert/delete-adjusted clean data,
                delt_sample: copy of z_t with DELt nodes removed from its node
                             mask and carrying the per-node insertion times,
                delt_target: one-hot count of DELt nodes per sample
                             (self.max_n_nodes + 1 classes).
        """
        device = dense_data.X.device
        if corruption_step == -1:
            t_int = torch.randint(1, self.T + 1, size=(dense_data.X.size(0), 1), device=device)
        else:
            t_int = torch.full(size=(dense_data.X.size(0), 1), fill_value=corruption_step).to(device=device)
        t_float = t_int.float() / self.T

        prob_charges = None
        if self.use_ins_del:
            # apply_ins_del removes from x0 the nodes that will eventually be
            # deleted and adds the nodes inserted between 0 and t_int, together
            # with their activation (insertion) times. It does NOT corrupt them
            # to t_int; that happens below.
            delta_T = None
            # A fixed corruption step is used for optimisation, where a dry
            # corruption (no insertions/deletions) works slightly better.
            if corruption_step >= 0:
                delta_T = torch.zeros((dense_data.X.size(0), 1), dtype=torch.int,
                                      device=dense_data.X.device)
            target_data = self.apply_ins_del(dense_data=dense_data, t_int=t_int, delta_T=delta_T)

            t_int = target_data.t_int
            s_int = target_data.insert_time
            batch_sz = target_data.X.size(0)
            max_n_nodes = target_data.X.size(1)

            # Unlike DiGress, alpha_bar depends on t AND on each node's insertion
            # time s. t_int holds one value per sample while s_int holds one per
            # node, so t_int is repeated over the node axis before flattening.
            t_int_ext_flat = t_int.expand(-1, max_n_nodes).reshape(-1)
            s_int_flat = s_int.reshape(-1)

            P = self.move_P_device(target_data.X)

            # ---- Nodes: a * x + (1 - a) * (x @ P.X), kept 2-D as (N, C) -------
            # (squeezing would collapse it to 1-D when N == 1 and break the
            # boolean DELt masking below).
            alphas_ts_x = self.get_alpha_bar(t_int=t_int_ext_flat, s_int=s_int_flat, key='x')
            alphas_ts_x_m1 = 1 - alphas_ts_x
            X_flat = target_data.X.reshape((-1, self.X_classes + self.extra_classes))
            X_term_1 = alphas_ts_x[:, None] * X_flat
            X_term_2 = alphas_ts_x_m1[:, None] * (X_flat @ P.X)
            probX = X_term_1 + X_term_2

            # ---- Edges ---------------------------------------------------------
            # mask_to_e_size is expected to expand the per-node times
            # [s1, s2, s3, s4] into the per-edge layout
            #   [[s1, s2, s3, s4],
            #    [s2, s2, s3, s4],
            #    [s3, s3, s3, s4],
            #    [s4, s4, s4, s4]]
            e_s_map = self.mask_to_e_size(mask=s_int).reshape(-1)

            t_int_ext_flat_e = t_int.expand(-1, max_n_nodes ** 2).reshape(-1)
            alphas_ts_e = self.get_alpha_bar(t_int=t_int_ext_flat_e, s_int=e_s_map, key='e')
            alphas_ts_e_m1 = 1 - alphas_ts_e
            E_flat = target_data.E.reshape((-1, self.E_classes + self.extra_classes))
            E_term_1 = alphas_ts_e[:, None] * E_flat
            E_term_2 = alphas_ts_e_m1[:, None] * (E_flat.float() @ P.E)
            probE = E_term_1 + E_term_2
            # TODO: pos

            # ---- Force DELt nodes (and every edge touching them) to the last class
            # .bool() guarantees boolean masking: apply_ins_del may return
            # delt_mask as an int tensor, which would otherwise be interpreted as
            # row *indices* rather than a mask.
            delt_mask = target_data.delt_mask.bool()

            X_delt_vector = torch.zeros(probX.size(-1), device=device)
            X_delt_vector[-1] = 1
            probX[delt_mask.reshape(-1)] = X_delt_vector

            E_delt_mask = delt_mask.repeat(1, 1, max_n_nodes)
            E_delt_mask = E_delt_mask | E_delt_mask.swapaxes(-1, -2)
            E_delt_mask = E_delt_mask.reshape(-1)
            E_delt_vector = torch.zeros(probE.size(-1), device=device)
            E_delt_vector[-1] = 1
            probE[E_delt_mask] = E_delt_vector

            probX = probX.reshape((batch_sz, max_n_nodes, -1))
            probE = probE.reshape((batch_sz, max_n_nodes, max_n_nodes, -1))

            # ---- Charges (same scheme as nodes) --------------------------------
            if self.use_charges:
                alphas_ts_c = self.get_alpha_bar(t_int=t_int_ext_flat, s_int=s_int_flat, key='c')
                alphas_ts_c_m1 = 1 - alphas_ts_c
                c_flat = target_data.charges.reshape((-1, self.charges_classes + self.extra_classes))
                c_term_1 = alphas_ts_c[:, None] * c_flat
                c_term_2 = alphas_ts_c_m1[:, None] * (c_flat @ P.charges)
                prob_charges = c_term_1 + c_term_2

                c_delt_vector = torch.zeros(prob_charges.size(-1), device=device)
                c_delt_vector[-1] = 1
                prob_charges[delt_mask.reshape(-1)] = c_delt_vector

                prob_charges = prob_charges.reshape((batch_sz, max_n_nodes, -1))

            t_float = target_data.t_int.float() / self.T
            target_data.t = t_float
            node_mask = target_data.node_mask
            # TODO: pos
        else:
            # Qtb holds (bs, d_in, d_out) transition matrices for X and E.
            Qtb = self.get_Qt_bar(t_int=t_int)

            probX = dense_data.X @ Qtb.X  # (bs, n, dx_out)
            probE = dense_data.E @ Qtb.E.unsqueeze(1)  # (bs, n, n, de_out)

            node_mask = dense_data.node_mask

            if self.use_charges:
                prob_charges = dense_data.charges @ Qtb.charges

            target_data = dense_data

        sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, prob_charges=prob_charges,
                                                             node_mask=node_mask)

        X_t = F.one_hot(sampled_t.X, num_classes=self.X_classes + self.extra_classes)
        E_t = F.one_hot(sampled_t.E, num_classes=self.E_classes + self.extra_classes)

        charges_t = None
        if self.use_charges:
            charges_t = F.one_hot(sampled_t.charges, num_classes=self.charges_classes + self.extra_classes)

        # Shapes are only preserved without insert/delete noise, where the
        # number of nodes cannot change.
        if not self.use_ins_del:
            assert (dense_data.X.shape == X_t.shape) and (dense_data.E.shape == E_t.shape)

        pos_t = None
        if self.use_3d:
            # NOTE(review): positions are not yet adapted to insert/delete noise
            # (see the "TODO: pos" markers); this uses dense_data's node layout.
            noise_pos = torch.randn(dense_data.pos.shape, device=dense_data.pos.device)
            noise_pos_masked = noise_pos * dense_data.node_mask.unsqueeze(-1)
            noise_pos_masked = utils.remove_mean_with_mask(noise_pos_masked, dense_data.node_mask)

            a = self.get_alpha_bar(t_int=t_int, key='p').unsqueeze(-1)
            s = self.get_sigma_bar(t_int=t_int, key='p').unsqueeze(-1)
            pos_t = a * dense_data.pos + s * noise_pos_masked

        target = target_data.copy() if self.use_ins_del else None

        z_t = utils.PlaceHolder(X=X_t.detach(), E=E_t.detach(),
                                y=target_data.y.detach() if target_data.y is not None else None,
                                charges=charges_t.detach() if charges_t is not None else None,
                                pos=pos_t.detach() if pos_t is not None else None,
                                t_int=t_int.detach(),
                                t=t_float,
                                insert_time=torch.zeros((X_t.size(0), X_t.size(1), 1), dtype=int, device=X_t.device),
                                node_mask=target_data.node_mask.detach(),
                                guidance=target_data.guidance.detach() if target_data.guidance is not None else None,
                                delt_mask=target_data.delt_mask.detach() if target_data.delt_mask is not None else None).mask()

        if not self.use_ins_del:
            return z_t

        # Copy of z_t without the DELt nodes, plus the per-sample DELt count.
        delt_sample = z_t.copy()
        delt_sample.delt_mask = delt_mask.detach().clone()
        delt_sample.insert_time = s_int.detach().clone()

        delt_target = delt_sample.delt_mask.int().squeeze(-1).sum(-1)
        delt_target = F.one_hot(delt_target, num_classes=self.max_n_nodes + 1)

        delt_sample.node_mask &= ~delt_mask.squeeze(-1)
        delt_sample.mask()
        return z_t, target, delt_sample, delt_target

    def apply_ins_del(self, dense_data, t_int, delta_T=None):
        device      = dense_data.X.device
        n0          = dense_data.n_nodes
        b_sz, n     = dense_data.X.size(0), dense_data.X.size(1)
        pad_mask    = ~dense_data.node_mask

        node_dist   = self._node_dist.to(device)
        n_dist      = node_dist[n0]
        nT          = torch.multinomial(n_dist, 1)  # Node count at T
        if delta_T is None:
            delta_T = n0.unsqueeze(-1) - nT         # Count differente 0->T
        n_events    = torch.abs(delta_T)            # Number of events happening
        ins_mask    = delta_T <= 0                  # True = we insert. False = we delete

        n_ins       = torch.where( ins_mask, -delta_T, 0)
        n_del       = torch.where(~ins_mask,  delta_T, 0)
        max_n_ins   = torch.max(n_ins).item()

        # torch.set_printoptions(2, 1000000, linewidth=150)
        # print("===============================================================")
        # print("n_eve", n_events.squeeze(-1).tolist())
        # print("n_ins", n_ins.squeeze(-1).tolist())
        # print("n_del", n_del.squeeze(-1).tolist())

        # rand_mask is used to assign the events to the nodes. It shuffles the
        # events randomly among the various VALID nodes
        rand_mask   = torch.rand((b_sz, n), device=device)
        rand_mask   = torch.where(pad_mask, self.T+1, rand_mask)            # rand_mask   = torch.where(pad_mask, self.T+1, rand_mask)  # 
        rand_order  = torch.argsort(rand_mask, dim=-1)

        # print("pad_mask:\n", pad_mask.int())
        # print("rand_mask\n", rand_mask)
        # print("rand_order\n",rand_order)

        # Samples the INSERT and DELETE events:
        event_distr = self._d_zetas[1:]             # Event distribution
        tmp_full    = torch.multinomial(event_distr, b_sz*(n + max_n_ins), replacement=True)
        tmp_full    = tmp_full.reshape((b_sz, (n + max_n_ins))).to(device)
        tmp_del     = tmp_full[:, :n]
        tmp_ins     = tmp_full[:, n:]

        # print("t_int:", t_int.squeeze(-1).tolist())
        # print("n_del", n_del.squeeze(-1).tolist())
        # print("tmp_del BEFORE\n", tmp_del)
        # print("n_ins", n_ins.squeeze(-1).tolist())
        # print("tmp_ins BEFORE\n", tmp_ins)

        # Pads out the padding events and the events happening after t_int
        # NOTE: for the ins, tmp (actually, the "event" matrix below will have to
        #       get the entries equal to t_int removed
        # tmp:                  the un-padded matrix of events
        # t_int:                the corruption step for each sample
        # events_per_sample:    the number of events for each sample (be them INS or DEL)
        def mask_events(tmp, t_int, events_per_sample):
            n_events    = tmp.size(-1)
            extd_t_int  = t_int.expand((-1, n_events))
            tmp_mask    = torch.arange(n_events, device=device)[None, :] < events_per_sample.squeeze(-1)[:, None]
            tmp         = torch.where((~tmp_mask) | (tmp > extd_t_int), self.T+1, tmp)  # tmp[tmp > t_int] = self.T + 1
            
            return tmp

        tmp_del     = mask_events(tmp_del, t_int, n_del)
        tmp_ins     = mask_events(tmp_ins, t_int, n_ins)

        # print("t_int:", t_int.squeeze(-1).tolist())
        # print("tmp_del AFTER\n", tmp_del)
        # print("tmp_ins AFTER\n", tmp_ins)

        ########################################################################
        # The insert mask is easy. We just mask out the events in the graphs
        # where ins_mask = False and whose value is == t_int
        # NOTE: they will likely be unsorted, but that's not important considering
        #       how we will eventually remove excess pad anyway.
        ext_ins_mask= ins_mask.expand((-1, max_n_ins))
        # print("ext_ins_mask\n",ext_ins_mask.int())
        ins_times   = torch.where((~ext_ins_mask) | (tmp_ins == t_int) | (tmp_ins == self.T+1), 0, tmp_ins)
        # print("ins_times\n",ins_times)
        node_mask_2 = ins_times != 0
        # print("node_mask_2\n",node_mask_2.int())
        ########################################################################

        # Orders the mask. Now each event is assigned to a valid node.
        # Only the events happening before (OR AT) t_int are kept
        tmp_del     = torch.gather(tmp_del, 1, rand_order)
        # print("tmp_del\n",tmp_del)

        # The delete mask is basically the events < t_int
        delt_event  = tmp_del.detach().clone()
        # print("delt_event\n",delt_event)
        ext_del_mask= ins_mask.expand((-1, n))
        # print("ext_del_mask\n",ext_del_mask.int())
        delt_event  = torch.where(ext_del_mask, self.T+1, delt_event)       # delt_event[ins_mask[:, 0]] = self.T + 1
        # print("delt_event\n",delt_event)
        del_event   = delt_event.detach().clone()   #the DEL event mask is the same
        # print("del_event\n",del_event)

        # print("t_int:", t_int.squeeze(-1).tolist())
        delt_event  = delt_event == t_int
        del_event   =  del_event <  t_int

        # print("delt_event\n", delt_event.int())
        # print("del_event\n", del_event.int())

        # "not delete and not pad <=> not (delete or pad")"
        #NOTE: pad_mask is true if padding (see above) 
        node_mask_1 = ~(del_event | pad_mask)
        # print("node_mask_1\n", node_mask_1.int())

        # Save for later by the way:
        # target[indexes[:, 0], indexes[:, 1], :] = values
        #######################################################################
        # Now we can insert the new nodes
        final_X         = dense_data.X
        final_E         = dense_data.E
        final_c         = dense_data.charges if self.use_charges else None
        final_pos       = None  # TODO

        final_node_mask = node_mask_1
        final_delt_mask = delt_event.unsqueeze(-1)
        final_ins_time  = dense_data.insert_time

        if(max_n_ins > 0):
            sample_X_pbs= dense_data.node_stats.to(device).unsqueeze(1) #(b_sz, 1, X_classes)
            sample_E_pbs= dense_data.edge_stats.to(device).unsqueeze(1) #(b_sz, 1, E_classes)
            
            full_X_probs= sample_X_pbs.expand(-1,max_n_ins,-1)                  #(b_sz, max_n, x_classes)
            full_E_probs= sample_E_pbs.expand(-1,max_n_ins,-1)                  #(b_sz, max_n, x_classes)
            full_E_probs= full_E_probs.unsqueeze(1).expand(-1,n+max_n_ins,-1,-1)#(b_sz, n+max_n, max_n, x_classes)
            
            sampled_X   = torch.multinomial(full_X_probs.flatten(0,1), 1, replacement=True).squeeze(-1)
            sampled_X   = sampled_X.reshape((b_sz, max_n_ins))
            sampled_X_1h= F.one_hot(sampled_X, num_classes=self.X_classes).squeeze(1)
            sampled_X_1h= sampled_X_1h.reshape((b_sz, max_n_ins, -1))

            final_X     = torch.cat((dense_data.X, sampled_X_1h), dim=1)

            sampled_E   = torch.multinomial(full_E_probs.flatten(0,2), 1, replacement=True)
            sampled_E   = sampled_E.squeeze(-1)                                             #(bsz*(n+max_n_ins)*max_n_ins)
            sampled_E_1h= F.one_hot(sampled_E, num_classes=self.E_classes).squeeze(1)       #(bsz*(n+max_n_ins)*max_n_ins, E_classes)
            sampled_E_1h= sampled_E_1h.reshape((b_sz, n+max_n_ins, max_n_ins, self.E_classes))
            
            # Adds the extra slices to E to account for the new events
            final_E     = F.pad(dense_data.E, (0,0,0,0,0,max_n_ins,0,0), value=self.E_classes)
            final_E     = torch.cat((final_E, sampled_E_1h), dim=2)

            if(self.use_charges):
                expanded_sampled_X = sampled_X.unsqueeze(-1).expand(-1,-1,dense_data.charge_types.size(-1))
                sample_c_pbs= torch.gather(input=dense_data.charge_types.to(device), 
                                        dim=1, index=expanded_sampled_X)
                full_c_probs= sample_c_pbs.expand(-1,n,-1)
                sampled_c   = torch.multinomial(full_c_probs.flatten(0,1), 1, replacement=True)
                sampled_c_1h= F.one_hot(sampled_c, num_classes=self.charges_classes).squeeze(1)
                sampled_c_1h= sampled_c_1h.reshape((b_sz, max_n_ins, -1))
                final_c     = torch.cat((dense_data.charges, sampled_c_1h), dim=1)


            # The new node mask is the stack between node_mask_1 and node_mask_2
            final_node_mask = torch.hstack((node_mask_1,  node_mask_2))
            final_delt_mask = torch.hstack((delt_event, torch.zeros_like(ins_times))).unsqueeze(-1).int()
            final_ins_time  = torch.cat((dense_data.insert_time, ins_times.unsqueeze(-1) + 1), dim=1).int()

        to_return       = utils.PlaceHolder(X=final_X, E=final_E, charges=final_c, 
                            y=dense_data.y, pos=final_pos, t_int=t_int,
                            node_mask=final_node_mask, delt_mask=final_delt_mask,
                            guidance=dense_data.guidance, insert_time=final_ins_time,
                            n_nodes=dense_data.n_nodes)
        to_return.remove_elements(~final_node_mask, only_mask=True)

        to_return.X         = F.pad(to_return.X,       (0,2,0,0,0,0))
        to_return.E         = F.pad(to_return.E,   (0,2,0,0,0,0,0,0))
        if self.use_charges:
            to_return.charges   = F.pad(to_return.charges, (0,2,0,0,0,0))

        assert (final_node_mask.sum(dim=-1) <= self.max_n_nodes).all()
        return to_return

    def get_limit_dist(self):
        """Return the limit distribution of the noise process.

        Every stored marginal is shifted by ``1e-7`` (so no class keeps exactly
        zero mass) and renormalized to sum to one. Charge marginals are only
        included when ``use_charges`` is set; ``y`` and ``pos`` are always None.
        """
        eps = 1e-7

        def _renormalize(marginals):
            shifted = marginals + eps
            return shifted / torch.sum(shifted)

        limit_X = _renormalize(self.X_marginals)
        limit_E = _renormalize(self.E_marginals)
        limit_c = _renormalize(self.charges_marginals) if self.use_charges else None

        return utils.PlaceHolder(X=limit_X, E=limit_E, charges=limit_c,
                                 y=None, pos=None)

    def sample_limit_dist(self, node_mask):
        """ Sample from the limit distribution of the diffusion process

        Node classes, edge classes (and charges when ``use_charges``) are drawn
        independently from the stored marginals and one-hot encoded. Edges are
        made symmetric from the strict upper triangle, so diagonal entries end
        up all-zero. When ``use_3d`` is set, positions are Gaussian noise
        multiplied by the node mask and passed through
        ``utils.remove_mean_with_mask``.

        Args:
            node_mask: 2-D node mask, unpacked as ``(bs, n_max)``.

        Returns:
            ``utils.PlaceHolder`` at ``t_int == self.T`` (``t == 1``), masked
            with ``node_mask``. With ``use_ins_del`` the one-hot features get
            ``self.extra_classes`` zero-filled trailing columns, and
            ``delt_mask`` / ``insert_time`` are all-zero ``(bs, n_max, 1)``
            tensors. Otherwise both are None.
        """

        bs, n_max = node_mask.shape
        # One categorical distribution per node / per node pair, all equal to the marginals.
        x_limit = self.X_marginals.expand(bs, n_max, -1)
        e_limit = self.E_marginals[None, None, None, :].expand(bs, n_max, n_max, -1)

        # One draw per row; reshape back to the per-node / per-pair layout.
        U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max).to(node_mask.device)
        U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max).to(node_mask.device)
        # No global features: y has zero columns.
        U_y = torch.zeros((bs, 0), device=node_mask.device)

        U_X = F.one_hot(U_X, num_classes=x_limit.shape[-1]).float()
        U_E = F.one_hot(U_E, num_classes=e_limit.shape[-1]).float()

        U_c = None
        if(self.use_charges):
            charges_limit = self.charges_marginals.expand(bs, n_max, -1)
            U_c = charges_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max).to(node_mask.device)
            U_c = F.one_hot(U_c, num_classes=charges_limit.shape[-1]).float()

        # Get upper triangular part of edge noise, without main diagonal
        upper_triangular_mask = torch.zeros_like(U_E)
        indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1)
        upper_triangular_mask[:, indices[0], indices[1], :] = 1

        # Mirror the strict upper triangle to obtain a symmetric edge tensor.
        U_E = U_E * upper_triangular_mask
        U_E = (U_E + torch.transpose(U_E, 1, 2))
        assert (U_E == torch.transpose(U_E, 1, 2)).all()

        pos = None
        if(self.use_3d):
            pos = torch.randn(node_mask.shape[0], node_mask.shape[1], 3, device=node_mask.device)
            pos = pos * node_mask.unsqueeze(-1)
            pos = utils.remove_mean_with_mask(pos, node_mask)

        
        #t_array = pos.new_ones((pos.shape[0], 1)) #TODO: in case of issues with t, this was the original line
        # U_X is (bs, n_max, k) after one_hot, so this is a (bs, 1) tensor of ones
        # (it does not depend on pos, which is None unless use_3d).
        t_array = U_X.new_ones((U_X.shape[0], 1))
        
        delt_mask = None
        insert_time = None
        if(self.use_ins_del):
            # Zero-filled columns for the extra classes: limit samples never occupy them.
            U_X = F.pad(U_X, (0,self.extra_classes,0,0,0,0))
            U_E = F.pad(U_E, (0,self.extra_classes,0,0,0,0,0,0))
            if self.use_charges: U_c = F.pad(U_c, (0,self.extra_classes,0,0,0,0))
            delt_mask = torch.zeros((U_X.size(0), U_X.size(1), 1), dtype=bool, device=U_X.device)
            insert_time = torch.zeros((U_X.size(0), U_X.size(1), 1), dtype=int, device=U_X.device)
        
        # t_array is all ones, so every graph starts at integer step self.T.
        t_int_array = self.T * t_array.long()
        return utils.PlaceHolder(X=U_X, E=U_E, y=U_y, charges=U_c, pos=pos, t_int=t_int_array, t=t_array,
                                 node_mask=node_mask, delt_mask=delt_mask, insert_time=insert_time).mask(node_mask)

    def sample_zs_from_zt_and_pred(self, z_t, pred, s_int):
        """Samples from zs ~ p(zs | zt). Only used during sampling.

        The network prediction ``pred`` (logits over the clean classes) is
        softmax-normalized and used to weight the posterior returned by
        ``diffusion_utils.compute_batched_over0_posterior_distribution``.
        The resulting categorical distributions are sampled to obtain the
        discrete features of z_s. When ``use_3d`` is set, positions are drawn
        from a Gaussian whose mean combines ``z_t.pos`` and ``pred.pos``.

        Args:
            z_t: noisy state at step ``z_t.t_int`` (a ``utils.PlaceHolder``).
            pred: network output for ``z_t``. NOTE: when ``use_ins_del`` is
                set, the logits of the trailing ``extra_classes`` classes of
                ``pred.X``, ``pred.E`` (and ``pred.charges``) are overwritten
                in place with ``-inf``, so the caller's ``pred`` is modified.
            s_int: integer step to sample. It is ``expand``-ed like
                ``z_t.t_int``, so presumably shaped ``(bs, 1)`` — TODO confirm.

        Returns:
            ``utils.PlaceHolder`` for z_s with one-hot X/E/charges, masked with
            ``z_t.node_mask``. ``insert_time`` is carried over from ``z_t``.
        """
        bs, n, dxs = z_t.X.shape  # dxs is not used below
        node_mask = z_t.node_mask
        t_int = z_t.t_int

        # If this is the first denoising step, we can treat all Zt as 0 and all Z bar as 1
        # NOTE(review): is_first_step is computed but never used in this method.
        is_first_step = (t_int == self.T).all().item()

        # Forbid the trailing `extra_classes` classes: -inf logits become
        # probability 0 after the softmax below. This writes into pred in place.
        if self.use_ins_del:
            pred.X[:,:, -self.extra_classes:] = -float("inf")
            pred.E[:,:,:, -self.extra_classes:] = -float("inf")
            if self.use_charges:
                pred.charges[:,:, -self.extra_classes:] = -float("inf")

        # Normalize predictions for the categorical features
        # TODO: in case you want to use NLL to train (why?) remember to make
        #       a special case where you have to remove the softmax from here
        pred_X = F.softmax(pred.X, dim=-1)               # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)               # bs, n, n, d0

        # Retrieve transitions matrix
        Qt = self.get_Qt(t_int)
        if(self.use_ins_del):
            # Node-level terms: each node uses its own insertion time (a_int)
            # as the s_int argument of get_Qt_bar.
            max_nodes = pred.insert_time.size(1)
            a_int     = z_t.insert_time.squeeze(-1)
            exp_t_int = t_int.expand(-1,max_nodes)
            exp_s_int = s_int.expand(-1,max_nodes)
            delt_mask = z_t.delt_mask
            
            A,B,C,D,alpha_bt,Zt,Ztb,Ztm1b = self.get_Qt_bar(t_int=exp_t_int, s_int=a_int)
            _,_,_,_,alpha_bs,Zs,Zsb,Zsm1b = self.get_Qt_bar(t_int=exp_s_int, s_int=a_int)

            # Zeta factors are kept only where delt_mask is True; all other
            # nodes get a neutral factor of 1.
            Zt      = torch.where(delt_mask, Zt   , 1)
            Zs      = torch.where(delt_mask, Zs   , 1)
            Ztb     = torch.where(delt_mask, Ztb  , 1)
            Zsb     = torch.where(delt_mask, Zsb  , 1)
            Ztm1b   = torch.where(delt_mask, Ztm1b, 1)
            Zsm1b   = torch.where(delt_mask, Zsm1b, 1)

            # Same tuple layout is built for X, E and charges below.
            Qtbx        = (A.X, B.X, C.X, D.X, alpha_bt.X, Zt, Ztb, Ztm1b)
            Qsbx        = (A.X, B.X, C.X, D.X, alpha_bs.X, Zs, Zsb, Zsm1b)

            # the insertion time of an edge is the maximum between the insert time of the
            # incoming and outgoing nodes
            # NOTE(review): edges use pred.insert_time while nodes use z_t.insert_time
            # (a_int above); confirm both hold the same insertion times.
            ins_times_E = self.mask_to_e_size(pred.insert_time).flatten(-2, -1)                         # bsz, n_nodes**2 (if insert_time is (bsz, n_nodes, 1))


            # print("pred.insert_time:\n", pred.insert_time)
            # print("ins_times_E:\n", ins_times_E)

            t_int_E     = t_int.expand(-1,max_nodes**2) # bsz, n_nodes**2
            s_int_E     = s_int.expand(-1,max_nodes**2)

            # print("ins_times_E size:", ins_times_E.size(), 
            #       "t_int_E size:", t_int_E.size(),
            #       "s_int_E size:", s_int_E.size())
            
            # Edge-level alpha bars, evaluated per node pair from the edge insertion times.
            alpha_bt.E  = self.get_alpha_bar(t_int = t_int_E, s_int = ins_times_E, key='e')
            alpha_bs.E  = self.get_alpha_bar(t_int = s_int_E, s_int = ins_times_E, key='e')

            # print("alpha_bt.E size:",  alpha_bt.E.size(), "alpha_bs.E size:",  alpha_bs.E.size())

            alpha_bt.E  = alpha_bt.E.reshape((bs, max_nodes**2, 1))
            alpha_bs.E  = alpha_bs.E.reshape((bs, max_nodes**2, 1))

            # Presumably (bsz, n_nodes**2, 1) like alpha_bt.E above; TODO confirm the
            # output shapes of get_zeta / get_zeta_bar.
            ZtE         = self.get_zeta(t_int = t_int_E)
            ZsE         = self.get_zeta(t_int = s_int_E)
            ZtbE        = self.get_zeta_bar(t_int = t_int_E, s_int = ins_times_E)
            ZsbE        = self.get_zeta_bar(t_int = s_int_E, s_int = ins_times_E)
            Ztm1bE      = self.get_zeta_bar(t_int = t_int_E - 1, s_int = ins_times_E)
            Zsm1bE      = self.get_zeta_bar(t_int = s_int_E - 1, s_int = ins_times_E)
            
            # An edge is flagged when either endpoint is flagged. Assuming delt_mask is
            # (bs, n, 1): expand -> (bs, n, n), OR with transpose, flatten -> (bs, n*n, 1).
            delt_mask_E = delt_mask.expand(-1, -1, delt_mask.size(1))
            delt_mask_E = delt_mask_E | delt_mask_E.swapaxes(-1, -2)
            delt_mask_E = delt_mask_E.flatten(-2,-1).unsqueeze(-1)

            ZtE         = torch.where(delt_mask_E, ZtE   , 1)
            ZsE         = torch.where(delt_mask_E, ZsE   , 1)
            ZtbE        = torch.where(delt_mask_E, ZtbE  , 1)
            ZsbE        = torch.where(delt_mask_E, ZsbE  , 1)
            Ztm1bE      = torch.where(delt_mask_E, Ztm1bE, 1)
            Zsm1bE      = torch.where(delt_mask_E, Zsm1bE, 1)

            Qtbe        = (A.E, B.E, C.E, D.E, alpha_bt.E, ZtE, ZtbE, Ztm1bE)
            Qsbe        = (A.E, B.E, C.E, D.E, alpha_bs.E, ZsE, ZsbE, Zsm1bE)
        else:
            # Without insertion/deletion, the cumulative transition matrices are used directly.
            Qtb = self.get_Qt_bar(t_int=t_int)
            Qsb = self.get_Qt_bar(t_int=s_int)

            Qsbx = Qsb.X
            Qtbx = Qtb.X

            Qtbe = Qtb.E
            Qsbe = Qsb.E

        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=z_t.X,
                                                                                           Qt=Qt.X,
                                                                                           Qsb=Qsbx,
                                                                                           Qtb=Qtbx,
                                                                                           use_ins_del=self.use_ins_del)

        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=z_t.E,
                                                                                           Qt=Qt.E,
                                                                                           Qsb=Qsbe,
                                                                                           Qtb=Qtbe,
                                                                                           use_ins_del=self.use_ins_del)

        # Dim of these two tensors: bs, N, d0, d_t-1
        # Marginalize over the predicted clean class d0.
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X         # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, n, d_t-1
        # Rows with zero total mass get 1e-5 in every class (uniform after normalizing).
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

        # Flatten node pairs to (bs, n*n, d0) for the batched posterior, reshape back afterwards.
        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        prob_c = None
        if(self.use_charges):
            pred_charges = F.softmax(pred.charges, dim=-1)
            if self.use_ins_del:
                # Charges reuse the node-level zeta factors computed above.
                Qsbc = (A.charges, B.charges, C.charges, D.charges, alpha_bs.charges, Zs, Zsb, Zsm1b)
                Qtbc = (A.charges, B.charges, C.charges, D.charges, alpha_bt.charges, Zt, Ztb, Ztm1b)
            else:
                Qsbc = Qsb.charges
                Qtbc = Qtb.charges

            p_s_and_t_given_0_c = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=z_t.charges,
                                                                                            Qt=Qt.charges,
                                                                                            Qsb=Qsbc,
                                                                                            Qtb=Qtbc,
                                                                                            use_ins_del=self.use_ins_del)

            weighted_c = pred_charges.unsqueeze(-1) * p_s_and_t_given_0_c         # bs, n, d0, d_t-1
            unnormalized_prob_c = weighted_c.sum(dim=2)                     # bs, n, d_t-1
            unnormalized_prob_c[torch.sum(unnormalized_prob_c, dim=-1) == 0] = 1e-5
            prob_c = unnormalized_prob_c / torch.sum(unnormalized_prob_c, dim=-1, keepdim=True)  # bs, n, d_t-1

            assert ((prob_c.sum(dim=-1) - 1).abs() < 1e-4).all()

        pos = None
        if(self.use_3d):
            # # Sample the positions
            # Gaussian posterior: mean mu, std noise_prefactor, noise masked and
            # passed through utils.remove_mean_with_mask.
            sigma_sq_ratio = self.get_sigma_pos_sq_ratio(s_int=s_int, t_int=t_int)
            z_t_prefactor = (self.get_alpha_pos_ts(t_int=t_int, s_int=s_int) * sigma_sq_ratio).unsqueeze(-1)
            x_prefactor = self.get_x_pos_prefactor(s_int=s_int, t_int=t_int).unsqueeze(-1)

            mu = z_t_prefactor * z_t.pos + x_prefactor * pred.pos             # bs, n, 3

            sampled_pos = torch.randn(z_t.pos.shape, device=z_t.pos.device) * node_mask.unsqueeze(-1)
            noise = utils.remove_mean_with_mask(sampled_pos, node_mask=node_mask)

            prefactor1 = self.get_sigma2_bar(t_int=t_int, key='p')
            prefactor2 = self.get_sigma2_bar(t_int=s_int, key='p') * self.get_alpha_pos_ts_sq(t_int=t_int, s_int=s_int)
            sigma2_t_s = prefactor1 - prefactor2
            noise_prefactor_sq = sigma2_t_s * sigma_sq_ratio
            noise_prefactor = torch.sqrt(noise_prefactor_sq).unsqueeze(-1)
            pos = mu + noise_prefactor * noise                                # bs, n, 3

        # Categorical draws for X, E (and charges) from the probabilities above.
        sampled_s = diffusion_utils.sample_discrete_features(probX=prob_X, probE=prob_E, prob_charges=prob_c, node_mask=z_t.node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.X_classes + self.extra_classes).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.E_classes + self.extra_classes).float()
        charges_s = None
        if(self.use_charges):
            charges_s = F.one_hot(sampled_s.charges, num_classes=self.charges_classes+ self.extra_classes).float()

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (z_t.X.shape == X_s.shape) and (z_t.E.shape == E_s.shape)

        delt_mask=None
        if(self.use_ins_del):
            # Flag nodes whose sampled class is the last index (X_classes + extra_classes - 1).
            delt_mask = sampled_s.X == self.X_classes + self.extra_classes - 1
            delt_mask = delt_mask.unsqueeze(-1)

        # insert_time is carried over unchanged from z_t.
        z_s = utils.PlaceHolder(X=X_s, charges=charges_s,
                                E=E_s, y=torch.zeros(z_t.y.shape[0], 0, device=X_s.device), pos=pos,
                                t_int=s_int, t=s_int / self.T, node_mask=node_mask, guidance=z_t.guidance,
                                delt_mask=delt_mask, insert_time=z_t.insert_time).mask(node_mask)
        return z_s
    
    def mask_to_e_size(self, mask):
        e_s_map1        = mask                  #(b_sz, n_nodes, 1)
        e_s_map2        = mask.swapaxes(-1, -2) #(b_sz, 1, n_nodes)
        e_s_map         = torch.max(e_s_map1, e_s_map2) #(b_sz, n_nodes, n_nodes)

        return e_s_map

class DiscreteUniformTransition(NoiseModel):
    """Noise model whose transitions and marginals are uniform over every class.

    Each marginal is ``1/k`` for every one of its ``k`` classes, and each
    transition matrix is a ``(1, k, k)`` tensor filled with ``1/k``. This class
    adds no insertion/deletion classes, so ``extra_classes`` is 0.
    """

    def __init__(self, cfg, output_dims):
        super().__init__(cfg=cfg)
        self.X_classes = output_dims.X
        self.E_classes = output_dims.E
        self.y_classes = output_dims.y

        # No extra (insertion/deletion) classes are appended by this model.
        # sample_zs_from_zt_and_pred reads this attribute unconditionally when
        # sizing its one-hot outputs, so it must exist.
        self.extra_classes = 0

        self.X_marginals = torch.ones(self.X_classes) / self.X_classes
        self.E_marginals = torch.ones(self.E_classes) / self.E_classes
        self.y_marginals = torch.ones(self.y_classes) / self.y_classes

        self.Px = torch.ones(1, self.X_classes, self.X_classes) / self.X_classes
        self.Pe = torch.ones(1, self.E_classes, self.E_classes) / self.E_classes
        # Uniform transition for the global y features (must not overwrite Pe).
        self.Py = torch.ones(1, self.y_classes, self.y_classes) / self.y_classes

        if(cfg.features.use_charges):
            self.charges_classes = output_dims.charges
            self.charges_marginals = torch.ones(self.charges_classes) / self.charges_classes
            self.Pcharges = torch.ones(1, self.charges_classes, self.charges_classes) / self.charges_classes

class MarginalUniformTransition(NoiseModel):
    """Noise model whose transition matrices jump to the dataset marginals.

    Every row of each transition matrix ``P`` equals the marginal vector of
    its feature (node types, edge types and, optionally, charges). When
    ``use_ins_del`` is set, two extra classes are appended to each ``P``, and
    the auxiliary ``A``/``C``/``D`` matrices plus a node-count distribution
    are built.
    """

    def __init__(self, cfg, dataset_infos):
        super().__init__(cfg=cfg)

        def _append_ins_del_classes(P):
            # Grow a (k, k) matrix to (k+2, k+2) with zeros, then send both new
            # rows to the first new class (column -2).
            grown = F.pad(P, (0, 2, 0, 2))
            grown[-2, -2] = 1
            grown[-1, -2] = 1
            return grown

        node_marginals = dataset_infos.node_types
        edge_marginals = dataset_infos.edge_types
        n_y_classes = dataset_infos.output_dims.y
        charge_marginals = dataset_infos.charges_marginals

        self.extra_classes = 0

        self.X_classes = len(node_marginals)
        self.E_classes = len(edge_marginals)
        self.y_classes = n_y_classes

        self.X_marginals = node_marginals
        self.E_marginals = edge_marginals
        self.y_marginals = torch.ones(self.y_classes) / self.y_classes

        # Each row of P is the marginal vector itself.
        Px = node_marginals.unsqueeze(0).expand(self.X_classes, -1)
        Pe = edge_marginals.unsqueeze(0).expand(self.E_classes, -1)
        self.Py = torch.ones(1, self.y_classes, self.y_classes) / self.y_classes

        self.Ax = create_A_matrix(self.X_classes, self.use_ins_del).unsqueeze(0)
        self.Ae = create_A_matrix(self.E_classes, self.use_ins_del).unsqueeze(0)

        if self.use_ins_del:
            self.extra_classes = 2
            Px = _append_ins_del_classes(Px)
            Pe = _append_ins_del_classes(Pe)

            self.Cx = create_C_matrix(self.X_classes).unsqueeze(0)
            self.Dx = create_D_matrix(self.X_classes).unsqueeze(0)

            self.Ce = create_C_matrix(self.E_classes).unsqueeze(0)
            self.De = create_D_matrix(self.E_classes).unsqueeze(0)

            # Computed here rather than in NoiseModel because only this class
            # receives dataset_infos. A configured max_n below 1 means "use
            # the dataset maximum".
            configured_max_n = cfg.features.max_n
            self.max_n_nodes = dataset_infos.max_n_nodes if configured_max_n < 1 else configured_max_n

            # Zero-node graphs are disallowed for the eigenvalue extra features.
            allow_zero_nodes = cfg.model.extra_features not in ("eigenvalues", "eigenvalues_strict")
            self.node_dist = diffusion_utils.compute_p_nodes(max_n=self.max_n_nodes,
                                                             p_min=cfg.features.node_p_min,
                                                             p_max=cfg.features.node_p_max,
                                                             allow_zero_nodes=allow_zero_nodes)
            self._node_dist = torch.from_numpy(self.node_dist)

        self.Px = Px.unsqueeze(0)
        self.Pe = Pe.unsqueeze(0)

        if cfg.features.use_charges:
            self.charges_classes = len(charge_marginals)
            self.charges_marginals = charge_marginals
            Pc = charge_marginals.unsqueeze(0).expand(self.charges_classes, -1)
            # NOTE(review): unlike Ax/Ae, Ac keeps no leading batch dimension
            # (no .unsqueeze(0)); preserved as-is, confirm this is intended.
            self.Ac = create_A_matrix(self.charges_classes, self.use_ins_del)
            if self.use_ins_del:
                Pc = _append_ins_del_classes(Pc)
                self.Cc = create_C_matrix(self.charges_classes).unsqueeze(0)
                self.Dc = create_D_matrix(self.charges_classes).unsqueeze(0)

            self.Pcharges = Pc.unsqueeze(0)
            self.charges_types = dataset_infos.charges_types

def create_A_matrix(size, use_ins_del):
    """Build the ``A`` transition matrix.

    Without insertion/deletion this is the ``size x size`` identity. With it,
    two extra classes are appended (``size + 2`` total) and the last class is
    redirected onto the second-to-last one: row ``-1`` is one-hot at column
    ``-2``, and every other row stays identity.
    """
    if not use_ins_del:
        return torch.eye(size)

    A = torch.eye(size + 2)
    A[-1] = 0        # clear the last class's self-transition ...
    A[-1, -2] = 1    # ... and send it to the penultimate class instead
    return A

# The B matrix is self.Px / self.Pe / self.Pcharges, so it needs no separate constructor.

def create_C_matrix(size):
    """Build the ``C`` matrix for ``size`` real classes plus two extra classes.

    The ``size`` real-class rows are one-hot at the last column. The two
    extra-class rows are one-hot at the second-to-last column.
    """
    C = torch.zeros((size + 2, size + 2))
    last_col = C[:, -1]    # views into C: writes below modify C
    penult_col = C[:, -2]
    last_col[:-2] = 1
    penult_col[-2:] = 1
    return C

def create_D_matrix(size):
    """Build the ``D`` matrix: every one of the ``size + 2`` rows is one-hot
    at the second-to-last column."""
    D = torch.zeros((size + 2, size + 2))
    D[:, -2].fill_(1.0)  # in-place on a column view of D
    return D