import dgl.sparse as dglsp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from .gnn import *
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

__all__ = ["ModelAsync"]

class PlaceHolder:
    """Mutable bundle of the ``X``, ``E`` and ``y`` tensors of a batch."""

    def __init__(self, X, E, y):
        self.X, self.E, self.y = X, E, y

    def type_as(self, x: torch.Tensor):
        """Cast ``X``, ``E`` and ``y`` with ``Tensor.type_as(x)`` and return ``self``."""
        for field in ("X", "E", "y"):
            setattr(self, field, getattr(self, field).type_as(x))
        return self

    def mask(self, node_mask, argmax=False, collapse=False, diag=True):
        """Mask out the entries of ``X`` and ``E`` that belong to padded nodes.

        Parameters
        ----------
        node_mask : torch.Tensor
            Per-node mask; entries equal to 0 mark nodes to drop.
            Presumably of shape (bs, n) -- TODO confirm against callers.
        argmax : bool
            Only used when ``collapse`` is set: replace ``X`` and ``E`` by
            their argmax over the last dimension before masking.
        collapse : bool
            If set, masked entries are overwritten in place with -1.
            Otherwise ``X`` and ``E`` are multiplied by the mask.
        diag : bool
            Only used when ``collapse`` is not set: zero out ``E[b, i, i]``.

        Returns
        -------
        PlaceHolder
            ``self``, with ``X`` and ``E`` updated.
        """
        node_keep = node_mask.unsqueeze(-1)     # trailing feature axis
        row_keep = node_keep.unsqueeze(2)       # broadcasts over columns
        col_keep = node_keep.unsqueeze(1)       # broadcasts over rows

        if not collapse:
            self.X = self.X * node_keep
            self.E = self.E * row_keep * col_keep
            if diag:
                batch_size, num_nodes = self.E.shape[0], self.E.shape[1]
                eye = torch.eye(num_nodes, dtype=torch.bool)
                self.E[eye.unsqueeze(0).expand(batch_size, -1, -1)] = 0
            # The masked edge tensor is expected to stay symmetric.
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
            return self

        if argmax:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

        pair_keep = (row_keep * col_keep).squeeze(-1)
        self.X[node_mask == 0] = -1
        self.E[pair_keep == 0] = -1
        return self

class RFF(nn.Module):
    """Random Fourier feature map with a fixed Gaussian projection.

    Parameters
    ----------
    in_dim : int
        Size of the last dimension of the input.
    out_dim : int
        Nominal output size. The projection has ``round(out_dim / 2)``
        columns, and the cosine and sine halves are concatenated.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.out_dim = out_dim
        # Sample on CPU first (same RNG stream as before), then place the
        # projection on the GPU when one exists. Calling ``.cuda()``
        # unconditionally crashed on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights = torch.randn([in_dim, round(out_dim / 2)]).to(device)
        # A non-persistent buffer follows ``module.to(...)`` but stays out of
        # ``state_dict()``, so existing checkpoints keep loading unchanged.
        self.register_buffer("weights", weights, persistent=False)

    def forward(self, x):
        """Return ``[cos(xW), sin(xW)] / sqrt(out_dim)`` for input ``x``."""
        weights = self.weights
        if weights.device != x.device:
            # Keep the projection usable when the input lives elsewhere.
            weights = weights.to(x.device)
        projected = x.matmul(weights)
        features = torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1)
        return features / math.sqrt(self.out_dim)

class GCKLayer_E(nn.Module):
    """One propagation step followed by a random Fourier feature map.

    Parameters
    ----------
    gamma : float
        Bandwidth; propagated features are divided by ``sqrt(gamma)``.
    in_dim : int
        Input feature size handed to ``RFF``.
    out_dim : int
        Output feature size handed to ``RFF``.
    """

    def __init__(self, gamma, in_dim, out_dim):
        super().__init__()
        self.gamma = gamma
        self.map = RFF(in_dim, out_dim)

    def forward(self, adj, inputs):
        """Return ``map((adj @ inputs) / sqrt(gamma))``."""
        propagated = adj @ inputs
        scale = math.sqrt(self.gamma)
        return self.map(propagated / scale)

class GCKM_E(nn.Module):
    """Stack of ``GCKLayer_E`` layers.

    Parameters
    ----------
    gamma_list : list
        One bandwidth per layer.
    dim_list : list
        Feature sizes; layer ``i`` maps ``dim_list[i]`` to ``dim_list[i + 1]``.
    mode : int
        1 returns the Gram matrix ``Z @ Z.t()``, 2 returns the features ``Z``.
        Any other value makes ``forward`` return ``None``.
    """

    def __init__(self, gamma_list, dim_list, mode):
        super().__init__()
        self.gamma_list = gamma_list
        self.dim_list = dim_list
        self.model = self._make_layers()
        self.mode = mode

    def _make_layers(self):
        """Build a fresh ``nn.ModuleList`` with one layer per entry of ``gamma_list``."""
        return nn.ModuleList([
            GCKLayer_E(gamma, self.dim_list[idx], self.dim_list[idx + 1])
            for idx, gamma in enumerate(self.gamma_list)
        ])

    def reset(self):
        """Replace all layers with newly constructed ones."""
        self.model = self._make_layers()

    def forward(self, adj, X, do_reset=False):
        """Run ``X`` through every layer and return the output selected by ``mode``."""
        # Equality test (not truthiness) kept on purpose to match callers.
        if do_reset == True:
            self.reset()

        features = X
        for layer in self.model:
            features = layer(adj, features)

        if self.mode == 1:
            return features @ features.t()
        if self.mode == 2:
            return features
        return None

class MarginalTransition(nn.Module):
    """Transition matrices that interpolate between identity and the marginals.

    Parameters
    ----------
    device : torch.device
        Device on which the identity matrices are created.
    X_marginal : torch.Tensor of shape (F, 2)
        X_marginal[f, :] is the marginal distribution of the f-th node attribute.
    E_marginal : torch.Tensor of shape (2)
        Marginal distribution of the edge existence.
    X_cond_Y_marginals : torch.Tensor
        Stored as ``tran_X`` after inserting an axis at position 2 and
        repeating it twice. Presumably of shape (C, F, 2) -- TODO confirm.
    num_classes_E : int
        Number of edge classes.
    """
    def __init__(self,
                 device,
                 X_marginal,
                 E_marginal,
                 X_cond_Y_marginals,
                 num_classes_E):
        super().__init__()

        def frozen(tensor):
            # Registered with the module, but never trained.
            return nn.Parameter(tensor, requires_grad=False)

        num_attrs_X, num_classes_X = X_marginal.shape

        # One identity matrix per node attribute.
        eye_X = torch.eye(num_classes_X, device=device)
        eye_X = eye_X.unsqueeze(0).expand(
            num_attrs_X, num_classes_X, num_classes_X).clone()
        eye_E = torch.eye(num_classes_E, device=device)

        # Every row of these matrices is the corresponding marginal.
        rows_X = X_marginal.unsqueeze(1).expand(
            num_attrs_X, num_classes_X, -1).clone()
        rows_E = E_marginal.unsqueeze(0).expand(num_classes_E, -1).clone()

        # Registration order is kept: tran_X, I_X, I_E, m_X, m_E.
        self.tran_X = frozen(X_cond_Y_marginals.unsqueeze(2).repeat(1, 1, 2, 1))
        self.I_X = frozen(eye_X)
        self.I_E = frozen(eye_E)
        self.m_X = frozen(rows_X)
        self.m_E = frozen(rows_E)

    def get_Q_bar_E(self, alpha_bar_t):
        """Return the transition matrix used to corrupt the graph structure.

        Parameters
        ----------
        alpha_bar_t : torch.Tensor of shape (1)
            A value in [0, 1].

        Returns
        -------
        Q_bar_t_E : torch.Tensor of shape (2, 2)
            ``alpha_bar_t * I_E + (1 - alpha_bar_t) * m_E``.
        """
        keep = alpha_bar_t
        return keep * self.I_E + (1 - keep) * self.m_E

    def get_Q_bar_X(self, alpha_bar_t):
        """Return the transition matrices used to corrupt the node attributes.

        Parameters
        ----------
        alpha_bar_t : torch.Tensor of shape (1)
            A value in [0, 1].

        Returns
        -------
        Q_bar_t_X : torch.Tensor of shape (F, 2, 2)
            ``alpha_bar_t * I_X + (1 - alpha_bar_t) * m_X``.
        """
        keep = alpha_bar_t
        return keep * self.I_X + (1 - keep) * self.m_X

class NoiseSchedule(nn.Module):
    """Cosine noise schedule plus a sigmoid schedule for beta diffusion.

    Parameters
    ----------
    T : int
        Number of diffusion time steps.
    device : torch.device
        Device that receives the schedule tensors.
    s : float
        Offset used inside the cosine schedule.
    """
    def __init__(self, T, device, s=0.008):
        super().__init__()

        # Cosine schedule as proposed in
        # https://arxiv.org/abs/2102.09672
        num_steps = T + 2
        grid = np.linspace(0, num_steps, num_steps)
        # Cumulative products \bar{alpha}_t, rescaled so the first entry is 1.
        cosine_bars = np.cos(0.5 * np.pi * ((grid / num_steps) + s) / (1 + s)) ** 2
        cosine_bars = cosine_bars / cosine_bars[0]
        step_alphas = cosine_bars[1:] / cosine_bars[:-1]

        betas = torch.from_numpy(1 - step_alphas).float().to(device)
        alphas = 1 - torch.clamp(betas, min=0, max=0.9999)
        # Accumulate in log space, then map back.
        alpha_bars = torch.exp(torch.cumsum(torch.log(alphas), dim=0))

        self.betas = nn.Parameter(betas, requires_grad=False)
        self.alphas = nn.Parameter(alphas, requires_grad=False)
        self.alpha_bars = nn.Parameter(alpha_bars, requires_grad=False)

        # Hyper-parameters for beta diffusion.
        self.sigmoid_start = 9
        self.sigmoid_end = -9
        self.sigmoid_power = 0.5
        self.beta_max = 20
        self.beta_min = 0.1
        self.Scale = 0.39
        self.Shift = 0.60
        self.eta = torch.tensor(1432, dtype=torch.float32)
        self.use_fea = True

        t_norm = torch.arange(num_steps) / (num_steps - 1) * (1. - 1e-5)
        s_norm = t_norm * 0.95
        logit_span = self.sigmoid_end - self.sigmoid_start
        logit_alpha_t = self.sigmoid_start + logit_span * (t_norm ** self.sigmoid_power)
        logit_alpha_s = self.sigmoid_start + logit_span * (s_norm ** self.sigmoid_power)

        # The difference is formed in float64 and only then cast down.
        sigmoid_s64 = logit_alpha_s.to(torch.float64).sigmoid()
        sigmoid_t64 = logit_alpha_t.to(torch.float64).sigmoid()
        delta = (sigmoid_s64 - sigmoid_t64).to(torch.float32)

        self.betadiff_alpha_t = torch.sigmoid(logit_alpha_t).to(device)
        self.betadiff_alpha_s = torch.sigmoid(logit_alpha_s).to(device)
        self.betadiff_delta = delta.to(device)

    def log_gamma(self, alpha):
        """Return the log of a standard-gamma draw with concentration ``alpha``.

        The draw is clamped from below at the smallest positive normal
        float32 before taking the logarithm.
        """
        tiny = torch.finfo(torch.float32).tiny
        draws = torch._standard_gamma(alpha.to(torch.float32))
        return torch.log(draws.clamp(tiny))

class LossE(nn.Module):
    """Cross-entropy loss for edge existence."""

    def forward(self, true_E, logit_E):
        """
        Parameters
        ----------
        true_E : torch.Tensor of shape (B, 2)
            One-hot encoding of the edge existence for a batch of node pairs.
        logit_E : torch.Tensor of shape (B, 2)
            Predicted logits for the edge existence.

        Returns
        -------
        loss_E : torch.Tensor
            Scalar cross-entropy between the logits and the class indices
            recovered from ``true_E`` by argmax.
        """
        target_classes = true_E.argmax(dim=-1)    # (B)
        return F.cross_entropy(logit_E, target_classes)

class BaseModel(nn.Module):
    """
    Parameters
    ----------
    T : int
        Number of diffusion time steps - 1.
    X_marginal : torch.Tensor of shape (F, 2)
        X_marginal[f, :] is the marginal distribution of the f-th node attribute.
    Y_marginal : torch.Tensor of shape (C)
        Marginal distribution of the node labels.
    E_marginal : torch.Tensor of shape (2)
        Marginal distribution of the edge existence.
    num_nodes : int
        Number of nodes in the original graph.
    """
    def __init__(self,
                 T,
                 X_marginal,
                 Y_marginal,
                 E_marginal,
                 X_cond_Y_marginals,
                 data_train_mask,
                 num_nodes):
        """Store the marginals and build the transition, schedule and edge loss.

        ``data_train_mask`` is accepted but not used by this constructor.
        All tensors created here live on ``X_marginal.device``.
        """
        super().__init__()

        self.device = X_marginal.device
        # 2 for if edge exists or not.
        self.num_classes_E = 2
        self.num_attrs_X, self.num_classes_X = X_marginal.shape
        self.num_classes_Y = len(Y_marginal)

        # Submodules are registered in the order transition, noise_schedule, loss_E.
        self.transition = MarginalTransition(
            self.device, X_marginal, E_marginal, X_cond_Y_marginals,
            self.num_classes_E)

        self.T = T
        # Number of intermediate time steps to use for validation.
        self.num_denoise_match_samples = T
        self.noise_schedule = NoiseSchedule(T, self.device)

        self.num_nodes = num_nodes

        self.X_marginal, self.Y_marginal, self.E_marginal = (
            X_marginal, Y_marginal, E_marginal)

        self.loss_E = LossE()

    def sample_E(self, prob_E):
        """Sample a graph structure from prob_E.

        Parameters
        ----------
        prob_E : torch.Tensor of shape (|V|, |V|, 2)
            Probability distribution for edge existence.

        Returns
        -------
        E_t : torch.LongTensor of shape (|V|, |V|)
            Sampled symmetric adjacency matrix.
        """
        # (|V|^2, 1)
#        E_t = prob_E.reshape(-1, prob_E.size(-1)).multinomial(1)

        device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        tmp1 = prob_E[:6000,:,:].to(device).reshape(-1, prob_E.size(-1)).multinomial(1)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        tmp2 = prob_E[6000:12000,:,:].to(device).reshape(-1, prob_E.size(-1)).multinomial(1)
        device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        tmp3 = prob_E[12000:,:,:].to(device).reshape(-1, prob_E.size(-1)).multinomial(1)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        E_t = torch.cat([tmp1.to(device), tmp2.to(device), tmp3.to(device)], 0)

        # (|V|, |V|)
        num_nodes = prob_E.size(0)
        del prob_E
        E_t = E_t.reshape(num_nodes, num_nodes)
        # Make it symmetric for undirected graphs.
        src, dst = torch.triu_indices(
            num_nodes, num_nodes, device=E_t.device)
        E_t[dst, src] = E_t[src, dst]
        return E_t

    def sample_X(self, prob_X):
        """Sample node attributes from prob_X.

        Parameters
        ----------
        prob_X : torch.Tensor of shape (F, |V|, 2)
            Probability distributions for node attributes.

        Returns
        -------
        X_t_one_hot : torch.Tensor of shape (|V|, 2 * F)
            One-hot encoding of the sampled node attributes.
        """
        # (F * |V|)
        X_t = prob_X.reshape(-1, prob_X.size(-1)).multinomial(1)
        # (F, |V|)
        X_t = X_t.reshape(self.num_attrs_X, -1)
        # (|V|, 2 * F)
        X_t_one_hot = torch.cat([
            F.one_hot(X_t[i], num_classes=self.num_classes_X)
            for i in range(self.num_attrs_X)
        ], dim=1).float()

        del X_t, prob_X
        return X_t_one_hot  

    def get_adj(self, E_t):
        """
        Parameters
        ----------
        E_t : torch.LongTensor of shape (|V|, |V|)
            Sampled symmetric adjacency matrix.

        Returns
        -------
        dglsp.SparseMatrix
            Row-normalized adjacency matrix.
        """
        # Row normalization.
        edges_t = E_t.nonzero().T
        num_nodes = E_t.size(0)
        A_t = torch.zeros(num_nodes, num_nodes).to(self.device)
        A_t[edges_t[0], edges_t[1]] = 1
        D_t = torch.diag(A_t.sum(1) ** -1)
        from numpy import inf
        D_t[D_t == inf] = 0
#        A_t = dglsp.spmatrix(edges_t, shape=(num_nodes, num_nodes))
#        D_t = dglsp.diag(A_t.sum(1)) ** -1
        return torch.spmm(D_t, A_t)

    def denoise_match_Z(self,
                        Z_t_one_hot,
                        Q_t_Z,
                        Z_one_hot,
                        Q_bar_s_Z,
                        pred_Z):
        """Compute the denoising match term for Z given a
        sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and
        q(D^{t-1}| hat{p}^{D}, D^t).

        Parameters
        ----------
        Z_t_one_hot : torch.Tensor of shape (B, C) or (A, B, C)
            One-hot encoding of the data sampled at time step t.
        Q_t_Z : torch.Tensor of shape (C, C) or (A, C, C)
            Transition matrix from time step t - 1 to t.
        Z_one_hot : torch.Tensor of shape (B, C) or (A, B, C)
            One-hot encoding of the original data.
        Q_bar_s_Z : torch.Tensor of shape (C, C) or (A, C, C)
            Transition matrix from timestep 0 to t-1.
        pred_Z : torch.Tensor of shape (B, C) or (A, B, C)
            Predicted probs for the original data.

        Returns
        -------
        float
            KL value.
        """
        # q(Z^{t-1}| Z, Z^t)
        left_term = Z_t_one_hot @ torch.transpose(Q_t_Z, -1, -2) # (B, C) or (A, B, C)
        right_term = Z_one_hot @ Q_bar_s_Z                       # (B, C) or (A, B, C)
        product = left_term * right_term                         # (B, C) or (A, B, C)
        denom = product.sum(dim=-1)                              # (B,) or (A, B)
        denom[denom == 0.] = 1
        prob_true = product / denom.unsqueeze(-1)                # (B, C) or (A, B, C)

        # q(Z^{t-1}| hat{p}^{Z}, Z^t)
        right_term = pred_Z @ Q_bar_s_Z                          # (B, C) or (A, B, C)
        product = left_term * right_term                         # (B, C) or (A, B, C)
        denom = product.sum(dim=-1)                              # (B,) or (A, B)
        denom[denom == 0.] = 1
        prob_pred = product / denom.unsqueeze(-1)                # (B, C) or (A, B, C)

        # KL(q(Z^{t-1}| hat{p}^{Z}, Z^t) || q(Z^{t-1}| Z, Z^t))
        kl = F.kl_div(input=prob_pred.log(), target=prob_true, reduction='none')
        return kl.clamp(min=0).mean().item()

    def denoise_match_E(self,
                        t_float,
                        logit_E,
                        E_t_one_hot,
                        E_one_hot):
        """Compute the denoising match term for edge prediction given a
        sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and
        q(D^{t-1}| hat{p}^{D}, D^t).

        Parameters
        ----------
        t_float : torch.Tensor of shape (1)
            Sampled timestep divided by self.T.
        logit_E : torch.Tensor of shape (B, 2)
            Predicted logits for the edge existence of a batch of node pairs.
        E_t_one_hot : torch.Tensor of shape (B, 2)
            One-hot encoding of sampled edge existence for the batch of
            node pairs.
        E_one_hot : torch.Tensor of shape (B, 2)
            One-hot encoding of the original edge existence for the batch of
            node pairs.

        Returns
        -------
        float
            KL value.
        """
        t = int(t_float.item() * self.T)
        s = t - 1

        alpha_bar_s = self.noise_schedule.alpha_bars[s]
        alpha_t = self.noise_schedule.alphas[t]

        Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s)
        # Note that computing Q_bar_t from alpha_bar_t is the same
        # as computing Q_t from alpha_t.
        Q_t_E = self.transition.get_Q_bar_E(alpha_t)

        pred_E = logit_E.softmax(dim=-1)

        return self.denoise_match_Z(E_t_one_hot,
                                    Q_t_E,
                                    E_one_hot,
                                    Q_bar_s_E,
                                    pred_E)

    def posterior(self,
                  Z_t,
                  Q_t,
                  Q_bar_s,
                  Q_bar_t,
                  prior):
        """Compute the posterior distribution for time step s, i.e., t - 1.

        Parameters
        ----------
        Z_t : torch.Tensor of shape (B, 2) or (F, |V|, 2)
            One-hot encoding of the sampled data at timestep t.
            B for batch size, C for number of classes, D for number
            of features.
        Q_t : torch.Tensor of shape (2, 2) or (F, 2, 2)
            The transition matrix from timestep t-1 to t.
        Q_bar_s : torch.Tensor of shape (2, 2) or (F, 2, 2)
            The transition matrix from timestep 0 to t-1.
        Q_bar_t : torch.Tensor of shape (2, 2) or (F, 2, 2)
            The transition matrix from timestep 0 to t.
        prior : torch.Tensor of shape (B, 2) or (F, |V|, 2)
            Reconstructed prior distribution.

        Returns
        -------
        prob : torch.Tensor of shape (B, 2) or (D, B, C)
            Posterior distribution.
        """
        # (B, 2) or (F, |V|, 2)
        left_term = Z_t @ torch.transpose(Q_t, -1, -2)
        # (B, 1, 2) or (F, |V|, 1, 2)
        left_term = left_term.unsqueeze(dim=-2)
        # (1, 2, 2) or (F, 1, 2, 2)
        right_term = Q_bar_s.unsqueeze(dim=-3)
        # (B, 2, 2) or (F, |V|, 2, 2)
        # Different from denoise_match_z, this function does not
        # compute (Z_t @ Q_t.T) * (Z_0 @ Q_bar_s) for a specific
        # Z_0, but compute for all possible values of Z_0.
        numerator = left_term * right_term

        # (2, B) or (F, 2, |V|)
        prod = Q_bar_t @ torch.transpose(Z_t, -1, -2)
        # (B, 2) or (F, |V|, 2)
        prod = torch.transpose(prod, -1, -2)
        # (B, 2, 1) or (F, |V|, 2, 1)
        denominator = prod.unsqueeze(-1)
        denominator[denominator == 0.] = 1.

        # (B, 2, 2) or (F, |V|, 2, 2)
        out = numerator / denominator

        # (B, 2, 2) or (F, |V|, 2, 2)
        prob = prior.unsqueeze(-1) * out
        # (B, 2) or (F, |V|, C)
        prob = prob.sum(dim=-2)

        return prob

    def sample_E_infer(self, prob_E):
        """Draw a sample from prob_E

        Parameters
        ----------
        prob_E : torch.Tensor of shape (B, 2)
            Probability distributions for edge existence.

        Returns
        -------
        E_t : torch.LongTensor of shape (|V|, |V|)
            Sampled adjacency matrix.
        """
        E_t_ = prob_E.multinomial(1).squeeze(-1)
        E_t = torch.zeros(self.num_nodes, self.num_nodes).long().to(E_t_.device)
        E_t[self.dst, self.src] = E_t_
        E_t[self.src, self.dst] = E_t_

        return E_t

    def get_E_t(self,
                device,
                edge_data_loader,
                pred_E_func,
                t_float,
                X_t_one_hot,
                Y_0,
                E_t,
                Q_t_E,
                Q_bar_s_E,
                Q_bar_t_E,
                batch_size
                ):
        A_t = self.get_adj(E_t)
        E_prob = torch.zeros(len(self.src), self.num_classes_E).to(device)

        start = 0
        for batch_edge_index in edge_data_loader:
            # (B, 2)
            batch_edge_index = batch_edge_index.to(device)
            batch_dst, batch_src = batch_edge_index.T
            # Reconstruct the edges.
            # (B, 2)
            batch_pred_E = pred_E_func(t_float,
                                       X_t_one_hot,
                                       Y_0,
                                       A_t,
                                       batch_src,
                                       batch_dst)

            batch_pred_E = batch_pred_E.softmax(dim=-1)

            # (B, 2)
            batch_E_t_one_hot = F.one_hot(
                E_t[batch_src, batch_dst],
                num_classes=self.num_classes_E).float()

            batch_E_prob_ = self.posterior(batch_E_t_one_hot, Q_t_E,
                                           Q_bar_s_E, Q_bar_t_E, batch_pred_E)

            end = start + batch_size
            E_prob[start: end] = batch_E_prob_
            start = end

        E_prob[:,1] = (E_prob[:,1] * self.transition.m_E[0,1] / E_prob[:,1].mean()).clamp_max(1)
        E_prob[:,0] = 1 - E_prob[:,1]
        E_prob = E_prob + 1e-8
        E_t = self.sample_E_infer(E_prob)
        return A_t, E_t
    

class LossX(nn.Module):
    """Cross-entropy loss over all (node, attribute) pairs.

    Parameters
    ----------
    num_attrs_X : int
        Number of node attributes.
    num_classes_X : int
        Number of classes for each node attribute.
    """
    def __init__(self,
                 num_attrs_X,
                 num_classes_X):
        super().__init__()

        self.num_attrs_X, self.num_classes_X = num_attrs_X, num_classes_X

    def forward(self, true_X, logit_X):
        """
        Parameters
        ----------
        true_X : torch.Tensor of shape (F, |V|, 2)
            X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute.
        logit_X : torch.Tensor of shape (|V|, F, 2)
            Predicted logits for the node attributes.

        Returns
        -------
        loss_X : torch.Tensor
            Scalar representing the loss for node attributes.
        """
        num_classes = true_X.size(-1)
        # Node-major order: v1x1, v1x2, ..., v1xd, v2x1, ...
        targets = (true_X.transpose(0, 1)            # (|V|, F, 2)
                   .reshape(-1, num_classes)         # (|V| * F, 2)
                   .argmax(dim=-1))                  # (|V| * F)
        # Logits are flattened to the same node-major order.
        flat_logits = logit_X.reshape(targets.size(0), -1)   # (|V| * F, 2)
        return F.cross_entropy(flat_logits, targets)

class ModelAsync(BaseModel):
    """
    Parameters
    ----------
    T_X : int
        Number of diffusion time steps - 1 for node attributes.
    T_E : int
        Number of diffusion time steps - 1 for edges.
    X_marginal : torch.Tensor of shape (F, 2)
        X_marginal[f, :] is the marginal distribution of the f-th node attribute.
    Y_marginal : torch.Tensor of shape (C)
        Marginal distribution of the node labels.
    E_marginal : torch.Tensor of shape (2)
        Marginal distribution of the edge existence.
    mlp_X_config : dict
        Configuration of the MLP for reconstructing node attributes.
    gnn_E_config : dict
        Configuration of the GNN for reconstructing edges.
    num_nodes : int
        Number of nodes in the original graph.
    """
    def __init__(self,
                 T_X,
                 T_E,
                 X_marginal,
                 Y_marginal,
                 E_marginal,
                 data_train_mask,
                 X_cond_Y_marginals,
                 mlp_X_config,
                 gnn_E_config,
                 pretrained_model,
                 num_nodes):
        """Build separate noise schedules for X and E plus the graph encoder.

        In addition to the parameters listed in the class docstring:

        data_train_mask
            Forwarded to ``BaseModel.__init__`` and to ``GNNAsymm``.
        X_cond_Y_marginals
            Forwarded to ``BaseModel.__init__`` and kept as an attribute.
        pretrained_model
            Optional teacher model used by ``log_p_t``; ``None`` disables
            the teacher branch there.
        """
        # The base class is initialised with T_X; the single-schedule
        # attributes it creates are removed again right below.
        super().__init__(T=T_X,
                         X_marginal=X_marginal,
                         Y_marginal=Y_marginal,
                         E_marginal=E_marginal,
                         X_cond_Y_marginals=X_cond_Y_marginals,
                         data_train_mask=data_train_mask,
                         num_nodes=num_nodes)

        # This model keeps one schedule per modality, so drop the shared
        # ``T`` / ``noise_schedule`` set by BaseModel.__init__.
        del self.T
        del self.noise_schedule

        self.T_X = T_X
        self.T_E = T_E

        device = X_marginal.device
        self.noise_schedule_X = NoiseSchedule(T_X, device)
        self.noise_schedule_E = NoiseSchedule(T_E, device)
        self.X_cond_Y_marginals = X_cond_Y_marginals

        # GNNAsymm is not defined in this file; presumably it comes from
        # the ``.gnn`` star import -- TODO confirm.
        self.graph_encoder = GNNAsymm(num_attrs_X=self.num_attrs_X,
                                      num_classes_X=self.num_classes_X,
                                      num_classes_Y=self.num_classes_Y,
                                      num_classes_E=self.num_classes_E,
                                      data_train_mask=data_train_mask,
                                      mlp_X_config=mlp_X_config,
                                      gnn_E_config=gnn_E_config)

        self.loss_X = LossX(self.num_attrs_X, self.num_classes_X)
        self.prop_X = None
        # Maps the flattened one-hot attributes (2 * F) to 64 features.
        self.linear = nn.Linear(self.num_attrs_X*2, 64)
        # Two kernel layers (gammas 2 and 1, widths F -> 64 -> 64); mode 2
        # makes the module return features rather than a Gram matrix.
        self.GCKM_E_model = GCKM_E([2, 1], [self.num_attrs_X, 64, 64], 2).to(device)
        self.pretrained_model = pretrained_model

    def apply_noise_X(self, X_one_hot_3d, t=None):
        """Corrupt the node attributes and sample X^t.

        Parameters
        ----------
        X_one_hot_3d : torch.Tensor of shape (F, |V|, 2)
            X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute
            in the real graph.
        t : torch.LongTensor of shape (1), optional
            If specified, a time step will be enforced rather than sampled.

        Returns
        -------
        t_float_X : torch.Tensor of shape (1)
            Sampled timestep divided by self.T_X.
        X_t_one_hot : torch.Tensor of shape (|V|, 2 * F)
            One-hot encoding of the sampled node attributes.
        """
        if t is None:
            # Uniform over {0, ..., T_X}. The notation is slightly off from
            # the paper: t=0 here is t=1 there, i.e. already corrupted.
            t = torch.randint(low=0, high=self.T_X + 1, size=(1,),
                              device=X_one_hot_3d.device)

        # (F, 2, 2)
        Q_bar_t_X = self.transition.get_Q_bar_X(
            self.noise_schedule_X.alpha_bars[t])
        # Batched matmul over the attribute dimension: (F, |V|, 2).
        noisy_prob = torch.bmm(X_one_hot_3d, Q_bar_t_X)

        X_t_one_hot = self.sample_X(noisy_prob)
        return t / self.T_X, X_t_one_hot
    
    def apply_noise_E(self, E_one_hot, t=None):
        """Corrupt the graph structure and sample A^t.

        Parameters
        ----------
        E_one_hot : torch.Tensor of shape (|V|, |V|, 2)
            - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph.
            - E_one_hot[:, :, 1] is the adjacency matrix of the real graph.
        t : torch.LongTensor of shape (1), optional
            If specified, a time step will be enforced rather than sampled.

        Returns
        -------
        t_float_E : torch.Tensor of shape (1)
            Sampled timestep divided by self.T_E.
        E_t : torch.LongTensor of shape (|V|, |V|)
            Sampled symmetric adjacency matrix.
        """
        if t is None:
            # Uniform over {0, ..., T_E}. The notation is slightly off from
            # the paper: t=0 here is t=1 there, i.e. already corrupted.
            t = torch.randint(low=0, high=self.T_E + 1, size=(1,),
                              device=E_one_hot.device)

        alpha_bar_t = self.noise_schedule_E.alpha_bars[t]
        # (2, 2)
        Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t)
        # (|V|, |V|, 2)
        noisy_prob = E_one_hot @ Q_bar_t_E

        E_t = self.sample_E(noisy_prob)
        return t / self.T_E, E_t

    def log_p_t(self,
                X_one_hot_3d,
                E_one_hot,
                Y_distribution,
                Y,
                X_one_hot_2d,
                batch_src,
                batch_dst,
                batch_E_one_hot,
                t_X=None,
                t_E=None):
        """Obtain G^t and compute the reconstruction losses for G.

        Parameters
        ----------
        X_one_hot_3d : torch.Tensor of shape (F, |V|, 2)
            X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute
            in the real graph.
        E_one_hot : torch.Tensor of shape (|V|, |V|, 2)
            - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph.
            - E_one_hot[:, :, 1] is the adjacency matrix of the real graph.
        Y_distribution : torch.Tensor
            Per-node label scores. This is what the encoder is conditioned
            on, and its softmax over dim 1 is the target of the teacher
            loss. Presumably of shape (|V|, C) -- TODO confirm.
        Y : torch.Tensor of shape (|V|)
            Categorical node labels of the real graph. Accepted for
            interface compatibility; ``Y_distribution`` is used instead.
        X_one_hot_2d : torch.Tensor of shape (|V|, 2 * F)
            Flattened one-hot encoding of the node attributes in the real graph.
        batch_src : torch.LongTensor of shape (B)
            Source node IDs for a batch of candidate edges (node pairs).
        batch_dst : torch.LongTensor of shape (B)
            Destination node IDs for a batch of candidate edges (node pairs).
        batch_E_one_hot : torch.Tensor of shape (B, 2)
            E_one_hot[batch_dst, batch_src].
        t_X : torch.LongTensor of shape (1), optional
            If specified, a time step will be enforced rather than sampled for
            node attributes.
        t_E : torch.LongTensor of shape (1), optional
            If specified, a time step will be enforced rather than sampled for
            edges.

        Returns
        -------
        loss_X : torch.Tensor
            Scalar representing the loss for node attributes.
        loss_E : torch.Tensor
            Scalar representing the loss for edge existence.
        classification_loss : torch.Tensor or int
            ``F.kl_div`` between the teacher's log-softmax prediction on the
            noisy graph and ``Y_distribution.softmax(1)`` when a pretrained
            model is set, otherwise the integer 0.
        """
        t_float_X, X_t_one_hot = self.apply_noise_X(X_one_hot_3d, t_X)
        t_float_E, E_t = self.apply_noise_E(E_one_hot, t_E)
        A_t = self.get_adj(E_t)

        # The third encoder output is discarded: the classification loss
        # returned by this method comes from the teacher branch below.
        logit_X, logit_E, _ = self.graph_encoder(t_float_X,
                                                 t_float_E,
                                                 X_t_one_hot,
                                                 Y_distribution,
                                                 X_one_hot_2d,
                                                 A_t,
                                                 batch_src,
                                                 batch_dst)

        if self.pretrained_model is not None:
            # NOTE(review): the teacher inputs are moved to a hard-coded
            # 'cuda:1'; presumably that is where the pretrained model
            # lives -- confirm before running on another device layout.
            teacher_features = F.normalize(
                X_t_one_hot[:, 1::2].to('cuda:1'), p=1, dim=1)
            teacher_edges = E_t.nonzero().T.to('cuda:1')
            teacher_prediction = self.pretrained_model(
                teacher_features, teacher_edges, None)
            teacher_prediction = F.log_softmax(teacher_prediction, dim=-1)
            classification_loss = F.kl_div(teacher_prediction,
                                           Y_distribution.softmax(1))
        else:
            classification_loss = 0

        loss_X = self.loss_X(X_one_hot_3d, logit_X)
        loss_E = self.loss_E(batch_E_one_hot, logit_E)
        return loss_X, loss_E, classification_loss

    def sys_normalized_adjacency(self, adj):
        """Return the symmetrically normalized adjacency D^{-1/2} A D^{-1/2}.

        Parameters
        ----------
        adj : torch.Tensor of shape (N, N)
            Adjacency matrix; assumed to be a dense tensor -- TODO confirm.

        Returns
        -------
        torch.Tensor of shape (N, N)
            ``adj`` with entry (i, j) scaled by ``d_i^{-1/2} * d_j^{-1/2}``,
            where ``d`` holds the row sums and zero row sums are replaced by 1.
        """
        row_sum = adj.sum(1)
        # Rows without any mass get degree 1 so the power below stays finite.
        row_sum = (row_sum == 0) * 1 + row_sum
        d_inv_sqrt = row_sum ** (-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        # Broadcasting scales rows and columns directly; this avoids building
        # a dense N x N diagonal matrix and two N^3 matrix products.
        return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)
        
    def denoise_match_X(self,
                        t_float,
                        logit_X,
                        X_t_one_hot,
                        X_one_hot_3d):
        """Compute the denoising match term for node attribute prediction given a
        sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and
        q(D^{t-1}| hat{p}^{D}, D^t).

        Parameters
        ----------
        t_float : torch.Tensor of shape (1)
            Sampled timestep divided by self.T_X.
        logit_X : torch.Tensor of shape (|V|, F, 2)
            Predicted logits for the node attributes.
        X_t_one_hot : torch.Tensor of shape (|V|, 2 * F)
            One-hot encoding of the node attributes sampled at time step t.
        X_one_hot_3d : torch.Tensor of shape (F, |V|, 2)
            X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute.

        Returns
        -------
        float
            KL value for node attributes.
        """
        # Recover the integer step with round(): t_float is a float32 ratio,
        # so e.g. 7 / 10 * 10 evaluates to 6.9999998 and int() would
        # silently select step 6 instead of 7.
        t = round(t_float.item() * self.T_X)
        # For t == 0 this is -1, which indexes the last schedule entry.
        s = t - 1

        alpha_bar_s = self.noise_schedule_X.alpha_bars[s]
        alpha_t = self.noise_schedule_X.alphas[t]

        Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s)
        # Note that computing Q_bar_t from alpha_bar_t is the same
        # as computing Q_t from alpha_t.
        Q_t_X = self.transition.get_Q_bar_X(alpha_t)

        # (|V|, F, 2) -> (F, |V|, 2)
        pred_X = torch.transpose(logit_X.softmax(dim=-1), 0, 1)

        # (|V|, 2 * F) -> (|V|, F, 2) -> (F, |V|, 2)
        num_nodes = X_t_one_hot.size(0)
        X_t_one_hot = X_t_one_hot.reshape(num_nodes, self.num_attrs_X, -1)
        X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1)

        return self.denoise_match_Z(X_t_one_hot, Q_t_X, X_one_hot_3d, Q_bar_s_X, pred_X)

    def denoise_match_E(self,
                        t_float,
                        logit_E,
                        E_t_one_hot,
                        E_one_hot):
        """Compute the denoising match term for edge prediction given a
        sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and
        q(D^{t-1}| hat{p}^{D}, D^t).

        Parameters
        ----------
        t_float : torch.Tensor of shape (1)
            Sampled timestep divided by self.T_E.
        logit_E : torch.Tensor of shape (B, 2)
            Predicted logits for the edge existence of a batch of node pairs.
        E_t_one_hot : torch.Tensor of shape (B, 2)
            One-hot encoding of sampled edge existence for the batch of
            node pairs.
        E_one_hot : torch.Tensor of shape (B, 2)
            One-hot encoding of the original edge existence for the batch of
            node pairs.

        Returns
        -------
        float
            KL value for edges, as returned by self.denoise_match_Z.
        """
        # t_float stores t / T_E in floating point, so the product can land
        # just below the integer t (e.g. float32(0.7) * 10 == 6.9999998...).
        # Round instead of truncating to recover the exact timestep.
        t = round(t_float.item() * self.T_E)
        s = t - 1

        alpha_bar_s = self.noise_schedule_E.alpha_bars[s]
        alpha_t = self.noise_schedule_E.alphas[t]

        Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s)
        # Note that computing Q_bar_t from alpha_bar_t is the same
        # as computing Q_t from alpha_t.
        Q_t_E = self.transition.get_Q_bar_E(alpha_t)

        # Turn the logits into predicted probabilities of edge existence.
        pred_E = logit_E.softmax(dim=-1)

        return self.denoise_match_Z(E_t_one_hot,
                                    Q_t_E,
                                    E_one_hot,
                                    Q_bar_s_E,
                                    pred_E)

    @torch.no_grad()
    def val_step(self,
                 X_one_hot_3d,
                 E_one_hot,
                 Y_distribution,
                 Y,
                 X_one_hot_2d,
                 batch_src,
                 batch_dst,
                 batch_E_one_hot):
        """Perform a validation step.

        Parameters
        ----------
        X_one_hot_3d : torch.Tensor of shape (F, |V|, 2)
            X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute
            in the real graph.
        E_one_hot : torch.Tensor of shape (|V|, |V|, 2)
            - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph.
            - E_one_hot[:, :, 1] is the adjacency matrix of the real graph.
        Y : torch.Tensor of shape (|V|)
            Categorical node labels of the real graph.
        X_one_hot_2d : torch.Tensor of shape (|V|, 2 * F)
            Flattened one-hot encoding of the node attributes in the real graph.
        batch_src : torch.LongTensor of shape (B)
            Source node IDs for a batch of candidate edges (node pairs).
        batch_dst : torch.LongTensor of shape (B)
            Destination node IDs for a batch of candidate edges (node pairs).
        batch_E_one_hot : torch.Tensor of shape (B, 2)
            E_one_hot[batch_dst, batch_src].

        Returns
        -------
        denoise_match_E : float
            Denoising matching term for edge.
        denoise_match_X : float
            Denoising matching term for node attributes.
        log_p_0_E : float
            Reconstruction term for edge.
        log_p_0_X : float
            Reconstruction term for node attributes.
        """
        device = E_one_hot.device
        self.device = device
        Y = Y_distribution

        # Case1: X
        denoise_match_X = []
        # t=0 is handled separately.
        for t_sample_X in range(1, self.T_X + 1):
            t_X = torch.LongTensor([t_sample_X]).to(device)
            t_float_X, X_t_one_hot = self.apply_noise_X(X_one_hot_3d, t_X)
            logit_X = self.graph_encoder.pred_X(t_float_X,
                                                X_t_one_hot,
                                                Y)

            denoise_match_X_t = self.denoise_match_X(t_float_X,
                                                     logit_X,
                                                     X_t_one_hot,
                                                     X_one_hot_3d)
            denoise_match_X.append(denoise_match_X_t)
        denoise_match_X = float(np.mean(denoise_match_X)) * self.T_X

        # Case2: E
        denoise_match_E = []
        # t=0 is handled separately.
        for t_sample_E in range(1, self.T_E + 1):
            t_E = torch.LongTensor([t_sample_E]).to(device)
            t_float_E, E_t = self.apply_noise_E(E_one_hot, t_E)
            A_t = self.get_adj(E_t)
            logit_E = self.graph_encoder.pred_E(t_float_E,
                                                X_one_hot_2d,
                                                Y,
                                                A_t,
                                                batch_src,
                                                batch_dst)

            E_t_one_hot = F.one_hot(E_t[batch_src, batch_dst],
                                    num_classes=self.num_classes_E).float()
            denoise_match_E_t = self.denoise_match_E(t_float_E,
                                                     logit_E,
                                                     E_t_one_hot.to('cuda:1'),
                                                     batch_E_one_hot)
            denoise_match_E.append(denoise_match_E_t)
        denoise_match_E = float(np.mean(denoise_match_E)) * self.T_E

        # t=0
        t_0 = torch.LongTensor([0]).to(device)
        loss_X, loss_E, _ = self.log_p_t(X_one_hot_3d,
                                      E_one_hot,
                                      Y_distribution,
                                      Y,
                                      X_one_hot_2d,
                                      batch_src,
                                      batch_dst,
                                      batch_E_one_hot,
                                      t_X=t_0,
                                      t_E=t_0)
        log_p_0_E = loss_E.item()
        log_p_0_X = loss_X.item()

        return denoise_match_E, denoise_match_X,\
            log_p_0_E, log_p_0_X

    @torch.no_grad()
    def sample(self, Y_dist, batch_size=32768, num_workers=4):
        """Sample a graph.

        The node labels are built by resampling rows of ``Y_dist``, the node
        attributes are then denoised for t = T_X, ..., 1, and finally the
        edges are denoised for t = T_E, ..., 1, both conditioned on the
        resampled labels.

        Parameters
        ----------
        Y_dist : torch.Tensor
            Per-node label scores; indexed as a 2-D tensor whose last axis is
            arg-maxed to obtain a class index.
            NOTE(review): presumably of shape (|V|, C) with |V| equal to
            self.num_nodes -- confirm against the caller.
        batch_size : int
            Batch size for edge prediction.
        num_workers : int
            Number of subprocesses for data loading in edge prediction.

        Returns
        -------
        X_t_one_hot : torch.Tensor of shape (F, |V|, 2)
            One-hot encoding of the generated node attributes.
        Y_0_one_hot : torch.Tensor of shape (|V|, C)
            One-hot encoding of the generated node labels.
        E_t : torch.LongTensor of shape (|V|, |V|)
            Adjacency matrix of the generated graph.
        """
        device = self.X_marginal.device
        # The device is also stored on the instance.
        self.device = device

        # Sample Y_0
        # (|V|, C)
        # NOTE(review): Y_prior is not read again in this method (only the
        # commented-out multinomial sampling below referenced it).
        Y_prior = self.Y_marginal[None, :].expand(self.num_nodes, -1)
        Y_prior = torch.ones_like(Y_prior) / Y_prior.size(1)
        # NOTE(review): the target device is hard-coded to "cuda:1" here,
        # whereas the rest of the method uses `device` -- confirm this is
        # intended before running on a different device layout.
        Y_0 = (Y_dist * 1).to("cuda:1")

        # Build a label matrix with (roughly) the same number of rows per
        # class: for each class index y in 0..max_Y, take the rows of Y_0
        # whose arg-max is y, tile them 100 times, and randomly pick
        # N // (max_Y + 1) of the tiled rows into the next contiguous slice.
        max_Y = Y_dist.argmax(-1).max()
        new_Y_dist = torch.ones_like(Y_dist).to("cuda:1")
        # Write offset into new_Y_dist.
        tmp = 0
        for y in range(max_Y+1):
            # NOTE(review): if no row has arg-max y, Y_tmp has zero rows and
            # the slice assignment below presumably fails -- verify.
            Y_tmp = Y_0[Y_dist.argmax(-1) == y].repeat(100, 1)
            idx = torch.randperm(Y_tmp.size()[0])[:Y_dist.size()[0] // (max_Y+1)]
            new_Y_dist[tmp : tmp + Y_dist.size()[0] // (max_Y+1)] = Y_tmp[idx] #+ torch.randn_like(Y_tmp[idx]) * 0.1
            tmp += Y_dist.size()[0] // (max_Y+1)
        # Rows left over by the integer division are filled from the tiled
        # rows of the last class processed in the loop above.
        idx = torch.randperm(Y_tmp.size()[0])[:Y_dist.size()[0] - tmp]
        new_Y_dist[tmp : Y_dist.size()[0]+1] = Y_tmp[idx] #+ torch.randn_like(Y_tmp[idx]) * 0.1
        Y_0 = new_Y_dist
    #    Y_0 = Y_prior.multinomial(1).reshape(-1)
    #    Y_0 = F.one_hot(Y_0).float()
    #    Y_0 = Y_00 * 1.3 + torch.randn_like(Y_00)
        
        # Sample X^T from prior distribution.
        # (F, |V|, 2)
        X_prior = self.X_marginal[:, None, :].expand(-1, self.num_nodes, -1)
        # (|V|, 2F)
        X_t_one_hot = self.sample_X(X_prior)

        # Iteratively sample p(X^s | X^t) for t = 1, ..., T_X, with s = t - 1.
        # The reversed range is materialized as a list before being handed to tqdm.
        for s_X in tqdm(list(reversed(range(0, self.T_X)))):
            t_X = s_X + 1
            t_float_X = torch.tensor([t_X / self.T_X]).to(device)
            pred_X = self.graph_encoder.pred_X(t_float_X,
                                               X_t_one_hot,
                                               Y_0)
            pred_X = pred_X.softmax(dim=-1)
            # (F, |V|, 2)
            pred_X = torch.transpose(pred_X, 0, 1)      # x0_hat

            # (|V|, F, 2)
            X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1)
            # (F, |V|, 2)
            X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1)

            # Note that computing Q_bar_t from alpha_bar_t is the same
            # as computing Q_t from alpha_t.
            alpha_t_X = self.noise_schedule_X.alphas[t_X]
            alpha_bar_s_X = self.noise_schedule_X.alpha_bars[s_X]
            alpha_bar_t_X = self.noise_schedule_X.alpha_bars[t_X]

            # (F, |V|, 2)
            Q_t_X = self.transition.get_Q_bar_X(alpha_t_X)
            Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s_X)
            Q_bar_t_X = self.transition.get_Q_bar_X(alpha_bar_t_X)

            # Posterior over X^s given X^t and the predicted clean attributes.
            X_prob = self.posterior(X_t_one_hot, Q_t_X,
                                    Q_bar_s_X, Q_bar_t_X, pred_X)
    
            # The sampled X^s becomes X^t for the next iteration.
            X_t_one_hot = self.sample_X(X_prob)

        # Sample E^T from prior distribution.
        # (|V|, |V|, 2)
        E_prior = self.E_marginal[None, None, :].expand(
            self.num_nodes, self.num_nodes, -1)
        # (|V|, |V|)
        E_t = self.sample_E(E_prior)
        # Node pairs strictly above the diagonal (offset=1 excludes self-loops).
        dst, src = torch.triu_indices(self.num_nodes, self.num_nodes,
                                      offset=1, device=device)
        # The pair indices are also stored on the instance.
        # (|E|)
        self.dst = dst
        # (|E|)
        self.src = src
        # (|E|, 2)
        # Moved to CPU before being wrapped in the DataLoader.
        edge_index = torch.stack([dst, src], dim=1).to("cpu")
        data_loader = DataLoader(edge_index, batch_size=batch_size,
                                 num_workers=num_workers)

        # Iteratively sample p(A^s | A^t) for t = 1, ..., T_E, with s = t - 1.
        for s_E in tqdm(list(reversed(range(0, self.T_E)))):
            t_E = s_E + 1
            alpha_t_E = self.noise_schedule_E.alphas[t_E]
            alpha_bar_s_E = self.noise_schedule_E.alpha_bars[s_E]
            alpha_bar_t_E = self.noise_schedule_E.alpha_bars[t_E]

            Q_t_E = self.transition.get_Q_bar_E(alpha_t_E)
            Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s_E)
            Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t_E)

            t_float_E = torch.tensor([t_E / self.T_E]).to(device)

            # Only the second element returned by get_E_t is kept; it replaces
            # E_t for the next iteration.
            _, E_t = self.get_E_t(device,
                                  data_loader,
                                  self.graph_encoder.pred_E,
                                  t_float_E,
                                  X_t_one_hot,
                                  Y_0,
                                  E_t.to(device),
                                  Q_t_E,
                                  Q_bar_s_E,
                                  Q_bar_t_E,
                                  batch_size)


        # (|V|, F, 2)
        X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1)
        # (F, |V|, 2)
        X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1)

        # Collapse the resampled label scores to class indices, then re-encode
        # them as one-hot vectors with self.num_classes_Y classes.
        Y_0 = Y_0.argmax(1)
    #    Y_0_one_hot = torch.where(Y_0 == -100, torch.tensor(0.0).cuda(1), Y_0)
        Y_0_one_hot = F.one_hot(Y_0, num_classes=self.num_classes_Y).float()

        return X_t_one_hot, Y_0_one_hot, E_t