import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from omegaconf import open_dict, OmegaConf
import os

from models.transformer_model import GraphTransformer
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalUniformTransition
from diffusion import diffusion_utils
from utils.losses import TrainLoss
import utils.utils as utils
from target_predict import MoleculeTargetPrediction



def add_new_key(cfg):
    """Back-fill config keys that older config files may not define.

    The following defaults are inserted in place when the key is absent:

    * ``cfg.model.level_ar``    -> ``False``
    * ``cfg.train.dropout``     -> ``0.1``
    * ``cfg.general.sample_num`` -> ``1``

    As a side effect ``cfg.model`` and ``cfg.general`` are switched to struct
    mode (``cfg.train`` is left untouched).

    Args:
        cfg: OmegaConf config with ``model``, ``train`` and ``general`` nodes.
    """
    OmegaConf.set_struct(cfg.model, True)
    OmegaConf.set_struct(cfg.general, True)

    # (section, key, default) triples; each is filled only if the key is missing.
    defaults = (
        (cfg.model, "level_ar", False),
        (cfg.train, "dropout", 0.1),
        (cfg.general, "sample_num", 1),
    )
    for section, key, value in defaults:
        # A membership test is used instead of hasattr(): cfg.train is not in
        # struct mode, and on OmegaConf releases where reading a missing key of
        # a non-struct node yields None, hasattr() would always be True and the
        # default would silently never be written.
        if key not in section:
            with open_dict(section):
                section[key] = value


class JTDiffusion(pl.LightningModule):
    def __init__(self, cfg, dataset_infos, extra_features):
        """Build the denoising network, noise schedule and loss trackers.

        Args:
            cfg: experiment config; missing optional keys are back-filled by
                ``add_new_key``.
            dataset_infos: object exposing ``input_dims``, ``output_dims``,
                ``nodes_dist``, ``node_types`` and ``edge_types``.
            extra_features: callable applied to the noisy data in
                ``compute_extra_data``.
        """
        super().__init__()

        input_dims = dataset_infos.input_dims
        output_dims = dataset_infos.output_dims
        nodes_dist = dataset_infos.nodes_dist

        add_new_key(cfg)
        self.cfg = cfg
        self.name = cfg.general.name
        self.model_dtype = torch.float32
        self.T = cfg.model.diffusion_steps
        self.cond = cfg.general.target
        if self.cond is not None:
            # The frozen property predictor is kept inside a plain list.
            predictor = MoleculeTargetPrediction.load_from_checkpoint(cfg.train.cond_model_path).eval()
            self.cond_model = [predictor]
            for weight in self.cond_model[0].parameters():
                weight.requires_grad = False

        self.Xdim, self.Edim, self.ydim = input_dims['X'], input_dims['E'], input_dims['y']
        self.Xdim_output, self.Edim_output, self.ydim_output = (
            output_dims['X'], output_dims['E'], output_dims['y']
        )
        self.node_dist = nodes_dist

        self.dataset_info = dataset_infos

        # One independent loss tracker per split.
        for split in ("train", "val", "test"):
            setattr(self, f"{split}_loss",
                    TrainLoss(self.cfg.model.lambda_train, self.cfg.model.class_weight))

        self.extra_features = extra_features

        self.model = GraphTransformer(
            n_layers=cfg.model.n_layers,
            input_dims=input_dims,
            hidden_mlp_dims=cfg.model.hidden_mlp_dims,
            hidden_dims=cfg.model.hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
            dropout=cfg.train.dropout,
            node_vocab_size=dataset_infos.node_types.shape[0],
        )

        self.noise_schedule = PredefinedNoiseScheduleDiscrete(
            cfg.model.diffusion_noise_schedule,
            timesteps=cfg.model.diffusion_steps,
        )

        # Empirical class frequencies, normalised to sum to one.
        node_types = self.dataset_info.node_types.float()
        x_marginals = node_types / node_types.sum()
        edge_types = self.dataset_info.edge_types.float()
        e_marginals = edge_types / edge_types.sum()
        print(f"Marginal distribution of the classes: {x_marginals} for nodes, {e_marginals} for edges")

        self.transition_model = MarginalUniformTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            y_classes=self.ydim_output,
        )
        uniform_y = torch.ones(self.ydim_output) / self.ydim_output
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=uniform_y)

        self.log_every_steps = cfg.general.log_every_steps
        self.best_loss = {"epoch": -1, "loss": np.inf}
        self.level_ar = cfg.model.level_ar

    def configure_optimizers(self):
        """Return an AdamW (amsgrad) optimizer over the denoising network's parameters."""
        train_cfg = self.cfg.train
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            amsgrad=True,
        )

    def get_pred_cond(self, data, level, node_mask):
        """Predict the conditioning target from the denoiser's output.

        ``data.X`` / ``data.E`` are discretised with a hard Gumbel-softmax and
        fed to the frozen predictor stored in ``self.cond_model[0]``.

        Args:
            data: prediction holder with ``X`` and ``E``.
            level: per-node level values; passed on as ``node_level``.
            node_mask: node mask forwarded to the predictor.

        Returns:
            The predictor's output for the discretised graph.
        """
        predictor = self.cond_model[0]

        # Nodes are sampled before edges.
        X_one_hot = F.gumbel_softmax(data.X, hard=True, dim=-1)
        E_one_hot = F.gumbel_softmax(data.E, hard=True, dim=-1)
        predictor.to(X_one_hot.device)

        # now use gt batch's level; TODO: use pred batch's true level
        cond_inputs = {
            'X_t': X_one_hot,
            'E_t': E_one_hot,
            'y_t': E_one_hot.new_zeros(E_one_hot.shape[0], 0),
            'node_mask': node_mask,
            "node_level": level.unsqueeze(-1),
        }
        with torch.no_grad():
            extra_data = predictor.compute_extra_data(cond_inputs)

        # The predictor must stay in eval mode.
        assert not (predictor.training or predictor.model.training)
        return predictor.forward(cond_inputs, extra_data, node_mask)

    def forward(self, noisy_data, extra_data, node_mask):
        """Concatenate noisy inputs with the extra features and run the network.

        Node features are joined along dim 2, edge features along dim 3 and the
        global features with ``torch.hstack``; all are cast to float.
        """
        node_in = torch.cat((noisy_data['X_t'], extra_data.X), dim=2).float()
        edge_in = torch.cat((noisy_data['E_t'], extra_data.E), dim=3).float()
        global_in = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(node_in, edge_in, global_in, node_mask)
        
    def shared_step(self, data, split = "train"):
        """Run one noising + denoising pass and compute the loss for ``split``.

        Args:
            data: sparse batch with ``x``, ``edge_index``, ``edge_attr``,
                ``batch``, ``level`` and ``y``.
            split: ``"train"``, ``"val"`` or anything else (treated as test);
                selects which loss tracker is used.

        Returns:
            ``({'loss': loss}, pred, dense_data)``.
        """
        dense_data, node_mask, level = utils.to_dense(
            data.x, data.edge_index, data.edge_attr, data.batch, data.level
        )
        dense_data = dense_data.mask(node_mask)
        X, E = dense_data.X, dense_data.E

        # Normalise levels by the per-graph maximum (epsilon avoids 0 / 0).
        level_max = level.max(dim=-1, keepdim=True)[0]
        level = level / (level_max + 1e-7)

        if not self.level_ar:
            noisy_data = self.apply_noise(X, E, data.y, node_mask)
        else:
            noisy_data = self.apply_noise_level_wise(X, E, data.y, node_mask, level)
        extra_data = self.compute_extra_data(noisy_data)

        pred = self.forward(noisy_data, extra_data, node_mask)
        if self.cond is not None:
            pred.y = self.get_pred_cond(pred, level, node_mask)

        # Any split other than train/val uses the test tracker.
        if split == "train":
            loss_func = self.train_loss
        elif split == "val":
            loss_func = self.val_loss
        else:
            loss_func = self.test_loss

        loss, _ = loss_func(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            pred_y=pred.y,
            true_X=X,
            true_E=E,
            true_y=data.y,
        )
        return {'loss': loss}, pred, dense_data
    
    def training_step(self, data, i):
        """Return the loss dict of ``shared_step`` for the training split."""
        loss_dict, _, _ = self.shared_step(data, "train")
        return loss_dict
    
    def validation_step(self, data, i):
        """Return the loss dict of ``shared_step`` for the validation split."""
        loss_dict, _, _ = self.shared_step(data, "val")
        return loss_dict

    def on_train_epoch_start(self):
        """Reset the training loss tracker at the start of a training epoch."""
        self.train_loss.reset()

    def on_train_epoch_end(self):
        """Log the epoch-level metrics collected by the training loss tracker."""
        self.log_dict(self.train_loss.log_epoch_metrics(), on_epoch=True)
    
    def on_validation_epoch_start(self):
        """Reset the validation loss tracker at the start of a validation epoch."""
        self.val_loss.reset()

    def on_validation_epoch_end(self):
        """Log the validation epoch metrics under ``val/*`` names.

        The loss tracker reports its values under ``train_epoch/*`` keys, so
        they are re-keyed before logging.
        """
        epoch_metrics = self.val_loss.log_epoch_metrics()
        key_map = (
            ("val/X_CE", 'train_epoch/X_CE'),
            ("val/E_CE", 'train_epoch/E_CE'),
            ("val/y_MSE", 'train_epoch/y_MSE'),
            ("val/loss", 'train_epoch/loss'),
        )
        renamed = {dst: epoch_metrics[src] for dst, src in key_map}
        self.log_dict(renamed, on_epoch=True)

    def on_test_epoch_start(self):
        """Reset the test loss tracker at the start of a test epoch."""
        self.test_loss.reset()

    def sample_level(self, bs):
        """Sample a per-node level layout for ``bs`` graphs.

        For each graph a number of levels is drawn from
        ``dataset_info.level_num``; for each level a node count is drawn from
        ``dataset_info.level_size[level]`` and those nodes get the value
        ``level / num_level``. A single extra node with level ``1.`` is appended
        last. Shorter graphs are right-padded with ``0``.

        Returns:
            ``(node_mask, n_nodes, level)`` where ``node_mask`` marks the
            non-padded positions, ``n_nodes`` holds per-graph node counts and
            ``level`` holds the padded per-node level values.
        """
        info = self.dataset_info

        per_graph = []
        for _ in range(bs):
            num_level = torch.multinomial(info.level_num, 1).item()
            levels = []
            for depth in range(num_level):
                width = torch.multinomial(info.level_size[depth], 1).item()
                levels.extend([depth / num_level] * width)
            levels.append(1.)   # last level (root) only has one node
            per_graph.append(levels)

        sizes = [len(levels) for levels in per_graph]
        max_n_nodes = max(sizes, default=0)

        n_nodes = torch.tensor(sizes, device=self.device)
        padded = [
            torch.tensor(levels + [0] * (max_n_nodes - len(levels)), device=self.device)
            for levels in per_graph
        ]
        level = torch.stack(padded, dim=0)
        positions = torch.arange(max_n_nodes, device=self.device).unsqueeze(0)
        node_mask = positions < n_nodes.unsqueeze(1)

        return node_mask, n_nodes, level
        
    def test_step(self, data, batch_idx):
        """Evaluate one test batch, sample new graphs and save both to disk.

        Prints node/edge reconstruction accuracy of the one-step prediction,
        then samples a fresh batch of the same size and writes
        ``(dense_data, data.n_nodes, pred, n_nodes)`` to
        ``saved_samples_{batch_idx}.pth`` in the directory of ``self.log_path``.

        Returns:
            The loss dict produced by ``shared_step``.
        """
        loss_dict, pred, dense_data = self.shared_step(data, "test")

        true_X, true_E = dense_data.X, dense_data.E
        true_X_idx = true_X.argmax(dim=-1)
        true_E_idx = true_E.argmax(dim=-1)

        # Accuracy is measured on entries with a non-zero feature vector;
        # for edges only the strict upper triangle is counted.
        mask_X = (true_X != 0.).any(dim=-1)
        node_accu = (pred.X.argmax(dim=-1) == true_X_idx).float()[mask_X].mean()
        upper = torch.triu(torch.ones(true_E.shape[:-1], device=self.device), diagonal=1).bool()
        mask_E = upper & (true_E != 0.).any(dim=-1)
        edge_accu = (pred.E.argmax(dim=-1) == true_E_idx).float()[mask_E].mean()
        print(f"Node accuracy: {node_accu:.3f}, Edge accuracy: {edge_accu:.3f}\n")

        bs = true_X.shape[0]
        dense_data.y = data.y
        node_mask, n_nodes, level = self.sample_level(bs)
        X, E, _, _ = self.sample(node_mask, cond=data.y, level=level)
        pred = utils.PlaceHolder(X=X, E=E, y=torch.zeros(X.shape[0], 0))

        # Targets other than plogp / qed are mapped back with the datamodule's
        # target_std / target_mu before saving.
        if self.cond is not None and self.cond not in ["plogp", "qed"]:
            datamodule = self.trainer.datamodule
            dense_data.y = dense_data.y * datamodule.target_std + datamodule.target_mu

        dense_data.X, dense_data.E = true_X_idx, true_E_idx
        pred.X, pred.E = pred.X.argmax(dim=-1), pred.E.argmax(dim=-1)
        torch.save((dense_data, data.n_nodes, pred, n_nodes), f"{os.path.dirname(self.log_path)}/saved_samples_{batch_idx}.pth")

        return loss_dict
    

    def graph_t_to_level_t(self, t, level, N = 16):
        '''Map a graph-level timestep to per-node timesteps in [0, self.T].

        Each node gets an offset ``2 * N * level`` (top-down ordering); the
        graph timestep is shifted by that offset, rescaled by
        ``self.T / (self.T - offset)``, rounded and clipped to ``[0, self.T]``.

        t: graph timestep
        level: node levels, should be normalized into [0, 1]

        Returns (node_t, node_t / self.T), both with a trailing unit dimension.
        NOTE: an offset equal to self.T makes the rescale divide by zero.
        '''
        total_steps = self.T
        # offset = 2 * N * (1 - level)      # bottom up
        offset = 2 * N * level              # top down

        rescaled = total_steps / (total_steps - offset) * (t - offset)
        node_t = torch.round(rescaled).clamp(0, total_steps).int().float().unsqueeze(-1)
        return node_t, node_t / total_steps
        
    @torch.no_grad()
    def sample(self, node_mask, cond, level):
        """Run the reverse diffusion chain from the limit distribution.

        Args:
            node_mask: mask of valid nodes; its first dimension is the batch size.
            cond: conditioning tensor, used as ``y`` at every step.
            level: per-node level values forwarded to the step sampler.

        Returns:
            ``(X, E, im_x_list, im_adj_list)``: the final masked sample and the
            intermediate ``X`` / ``E`` recorded every 10th step.
        """
        batch_size = node_mask.shape[0]

        # Sample noise  -- z has size (n_samples, n_nodes, n_features)
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask)
        X, E, y = z_T.X, z_T.E, cond

        im_x_list, im_adj_list = [], []
        progress = tqdm.tqdm(range(self.T - 1, -1, -1), leave=False, total=self.T, desc="Sample step")
        for s_int in progress:
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            s_norm = s_array / self.T
            t_norm = (s_array + 1) / self.T

            # Sample z_s
            if self.level_ar:
                sampled_s, _ = self.sample_p_zs_given_zt_node_wise(s_norm, t_norm, X, E, y, node_mask, level)
            else:
                sampled_s, _ = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask, level)
            X, E = sampled_s.X, sampled_s.E

            if s_int % 10 == 0:
                im_x_list.append(X)
                im_adj_list.append(E)

        # Final masking of the last sampled state.
        sampled_s = sampled_s.mask(node_mask, collapse=False)
        return sampled_s.X, sampled_s.E, im_x_list, im_adj_list

    def apply_noise(self, X, E, y, node_mask):
        """Draw one timestep per graph and sample the corresponding noisy graph.

        Returns:
            dict with the sampled timestep (``t_int``, ``t``), schedule values
            (``beta_t``, ``alpha_s_bar``, ``alpha_t_bar``), the masked noisy
            one-hot features (``X_t``, ``E_t``, ``y_t``) and ``node_mask``.
        """
        # t = 0 is only drawn in training mode.
        lowest_t = 1 if not self.training else 0
        batch_size = X.size(0)
        t_int = torch.randint(lowest_t, self.T + 1, size=(batch_size, 1), device=X.device).float()  # (bs, 1)
        s_int = t_int - 1

        t_float, s_float = t_int / self.T, s_int / self.T

        # Schedule quantities at t and at the previous step s.
        schedule = self.noise_schedule
        beta_t = schedule(t_normalized=t_float)                      # (bs, 1)
        alpha_s_bar = schedule.get_alpha_bar(t_normalized=s_float)   # (bs, 1)
        alpha_t_bar = schedule.get_alpha_bar(t_normalized=t_float)   # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device=self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out)
        row_sum_err_X = Qtb.X.sum(dim=2) - 1.
        assert (abs(row_sum_err_X) < 1e-4).all(), Qtb.X.sum(dim=2) - 1
        assert (abs(Qtb.E.sum(dim=2) - 1.) < 1e-4).all()

        # Class probabilities at step t.
        probX = X @ Qtb.X                 # (bs, n, dx_out)
        probE = E @ Qtb.E.unsqueeze(1)    # (bs, n, n, de_out)
        sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

        return {
            't_int': t_int,
            't': t_float,
            'beta_t': beta_t,
            'alpha_s_bar': alpha_s_bar,
            'alpha_t_bar': alpha_t_bar,
            'X_t': z_t.X,
            'E_t': z_t.E,
            'y_t': z_t.y,
            'node_mask': node_mask,
        }
    
    def apply_noise_level_wise(self, X, E, y, node_mask, level_float):
        """Sample a noisy graph where every node has its own timestep.

        One graph-level timestep is drawn per batch element and mapped to a
        per-node timestep with ``graph_t_to_level_t``. Node ``i`` is noised with
        the transition matrix of its own timestep. The two loop assignments to
        ``probE`` mean that each edge ends up with the probabilities computed in
        the iteration of its higher-indexed endpoint (later iterations overwrite
        earlier ones).

        Args:
            X: clean one-hot node features.
            E: clean one-hot edge features.
            y: global features, passed through unchanged.
            node_mask: mask of valid nodes.
            level_float: normalised per-node levels (see ``graph_t_to_level_t``).

        Returns:
            dict with the graph-level timestep (``t_int``, ``t``), per-node
            schedule values (``beta_t``, ``alpha_s_bar``, ``alpha_t_bar``), the
            masked noisy features (``X_t``, ``E_t``, ``y_t``), ``node_mask`` and
            the per-node ``node_t_int``, ``node_t_float``, ``node_level``.
        """
        # Sample a graph level timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1), graph level timestep
        s_int = t_int - 1
        t_float = t_int / self.T

        node_t, node_t_float = self.graph_t_to_level_t(t_int, level_float)    # [bs, n, 1], [bs, n, 1]
        _, node_s_float = self.graph_t_to_level_t(s_int, level_float)         # [bs, n, 1]

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=node_t_float)                     # [bs, n, 1]
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=node_s_float)  # [bs, n, 1]
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=node_t_float)  # [bs, n, 1]

        probX = torch.zeros_like(X).float()
        probE = torch.zeros_like(E).float()

        # Upper triangle (diagonal included) of the edge matrix. It does not
        # depend on the node being processed, so it is built once; the trailing
        # unit dim broadcasts over the edge classes.
        triu_mask = torch.triu(
            torch.ones(E.shape[:3], device=X.device, dtype=torch.bool), diagonal=0
        ).unsqueeze(-1)

        # no batch operation due to limited GPU memory
        for node_idx in range(X.shape[1]):
            node_t_float_i = node_t_float[:, node_idx]  # [bs, 1]

            alpha_t_bar_i = self.noise_schedule.get_alpha_bar(t_normalized=node_t_float_i)  # [bs, 1]
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar_i, device=self.device)    # [bs, dx_in, dx_out], [bs, de_in, de_out]
            assert (abs(Qtb.X.sum(dim=2) - 1.) < 1e-4).all(), Qtb.X.sum(dim=2) - 1
            assert (abs(Qtb.E.sum(dim=2) - 1.) < 1e-4).all()

            # Only this node's row is kept, so only that row is multiplied.
            probX[:, node_idx] = (X[:, node_idx:node_idx + 1] @ Qtb.X)[:, 0]

            # Keep the strict lower triangle and mirror it, so the result is
            # symmetric with a zero diagonal.
            probE_i = (E @ Qtb.E.unsqueeze(1)).masked_fill(triu_mask, 0.)   # (bs, n, n, de_out)
            probE_i = probE_i + probE_i.transpose(1, 2)

            probE[:, node_idx] = probE_i[:, node_idx]
            probE[:, :, node_idx] = probE_i[:, :, node_idx]

        sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

        noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
                      'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask,
                      "node_t_int": node_t, "node_t_float": node_t_float, "node_level": level_float.unsqueeze(-1)}
        return noisy_data
    
    def apply_noise_level_wise_batch(self, X, E, y, node_mask, level_float):
        """Batched variant of ``apply_noise_level_wise`` (no per-node Python loop).

        All per-node transition matrices are requested from the transition
        model in a single call, using the per-node ``alpha_t_bar``.
        NOTE(review): this relies on ``get_Qt_bar`` accepting a [bs, n, 1]
        input and returning 4-D matrices - confirm against the transition
        model. This method is not called anywhere in the visible part of this
        file.

        Returns the same dictionary layout as ``apply_noise_level_wise``.
        """

        # Sample a graph level timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1), graph level timestep
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T
        
        # Map the graph-level timesteps t and s = t - 1 to per-node timesteps.
        node_t, node_t_float = self.graph_t_to_level_t(t_int, level_float)            # [bs, n, 1], [bs, n, 1]
        node_s, node_s_float = self.graph_t_to_level_t(s_int, level_float)            # [bs, n, 1], [bs, n, 1]

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized = node_t_float)                    # [bs, n, 1]
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized = node_s_float) # [bs, n, 1]
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized = node_t_float) # [bs, n, 1]

        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device = self.device)   # [bs, n, dx_in, dx_out], [bs, n, de_in, de_out]
        # Every row of a transition matrix must sum to one.
        assert (abs(Qtb.X.sum(dim=3) - 1.) < 1e-4).all(), Qtb.X.sum(dim=3) - 1
        assert (abs(Qtb.E.sum(dim=3) - 1.) < 1e-4).all()

        # Compute transition probabilities
        probX = (X.unsqueeze(2) @ Qtb.X).squeeze(2)  # (bs, n, dx_out)
        probE = E @ Qtb.E                            # (bs, n, n, de_out)
        # Zero the upper triangle (diagonal included) and mirror the strict
        # lower triangle, so probE is symmetric with a zero diagonal.
        triu_mask = torch.triu(torch.ones(probE.shape[:3], device = X.device, dtype = torch.bool), diagonal = 0)
        probE[triu_mask.unsqueeze(-1).repeat(1, 1, 1, probE.shape[-1])] = 0.
        probE = probE + probE.transpose(1, 2)

        sampled_t = diffusion_utils.sample_discrete_features(probX = probX, probE = probE, node_mask = node_mask)

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes = self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

        noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
                      'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask, "node_t_int": node_t, "node_t_float": node_t_float, "node_level": level_float.unsqueeze(-1)}
        return noisy_data


    def get_Qt_skip(self, s, t, device):
        ''' Returns multi-step transition matrices for X and E, from step s to step t.

        The single-step matrix at t is left-multiplied by the single-step
        matrices of steps t-1, t-2, ..., s+1 in turn. When there are no
        intermediate steps, the matrix at t is returned unchanged.

        s: normalized start timestep (single element)
        t: normalized end timestep (single element)
        '''
        start_step = int(torch.round(s * self.T).item())
        end_step = int(torch.round(t * self.T).item())

        Qt = self.transition_model.get_Qt(self.noise_schedule(t_normalized=t), device)   # [1, 3, 3], t-1, t
        Qt_x, Qt_e = Qt.X, Qt.E

        # Walk backwards over the intermediate steps t-1 ... s+1.
        for step in reversed(range(start_step + 1, end_step)):
            beta_step = self.noise_schedule(t_int=torch.tensor([[step]], device=device))
            Q_step = self.transition_model.get_Qt(beta_step, device)
            Qt_x = Q_step.X @ Qt_x
            Qt_e = Q_step.E @ Qt_e

        return utils.PlaceHolder(X=Qt_x, E=Qt_e, y=Qt.y)

    def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask, level):
        """Sample z_s ~ p(z_s | z_t) with one shared timestep per graph.

        Only used during sampling. The network's class prediction for the clean
        graph is combined with the posterior terms returned by
        ``compute_batched_over0_posterior_distribution`` and marginalised over
        the clean class to get p(z_s | z_t), from which nodes and edges are
        sampled.

        ``level`` is accepted for signature parity with
        ``sample_p_zs_given_zt_node_wise`` and is not used here.

        Returns:
            A pair of holders built from the same sampled one-hot tensors: the
            first masked with ``mask(node_mask)``, the second with
            ``mask(node_mask, collapse=True)``. Both carry an empty ``y``.
        """
        global_t = t

        bs, n, dxs = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # [bs, 1]; if level_ar: [bs, n, 1]
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Retrieve transitions matrix
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)    # [bs, 3, 3]
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
        Qt = self.transition_model.get_Qt(beta_t, self.device)

        # Neural net predictions
        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': global_t, 'node_mask': node_mask}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)
        pred = utils.PlaceHolder(X = pred.X, E = pred.E, y = pred.y).mask(node_mask)

        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)               # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)               # bs, n, n, d0

        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t,
                                                                                           Qt=Qt.X,
                                                                                           Qsb=Qsb.X,
                                                                                           Qtb=Qtb.X)

        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t,
                                                                                           Qt=Qt.E,
                                                                                           Qsb=Qsb.E,
                                                                                           Qtb=Qtb.E)
        # Dim of these two tensors: bs, N, d0, d_t-1
        # left:         bs, N, 3, 1, p(x_0 | x_t)
        # right:        bs, N, 3, 3, p(x_t-1 | x_t, x_0) or p(x_t-n | x_t, x_0)
        # weighted_X:   bs, N, 3, 3, p(x_0, x_t-1 | x_t) or p(x_0, x_t-n | x_t)
        # sum(dim = 2): bs, N,    3, p(x_t-1 | x_t), p(x_t-n | x_t)
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X         # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, n, d_t-1
        # Rows that sum to zero get a small constant so the division below is defined.
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

        # Edges are flattened to (bs, n * n, d0) to match the posterior layout,
        # then reshaped back after normalisation.
        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        # The sampled edge tensor must be symmetric.
        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))

        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)

    def sample_p_zs_given_zt_node_wise(self, s, t, X_t, E_t, y_t, node_mask, level):
        """Sample z_s ~ p(z_s | z_t) when every node has its own timestep.

        Only used during sampling. The graph-level steps ``s`` and ``t`` are
        mapped to per-node steps with ``graph_t_to_level_t``. The network is
        evaluated once; the posterior is then computed node by node, because
        each node needs transition matrices for its own (s_i, t_i) pair.
        Nodes whose per-node step does not move (``s_i == t_i``) or is already
        zero keep their current value, as do all edges touching such a node.

        Returns:
            A pair of holders built from the same sampled one-hot tensors: the
            first masked with ``mask(node_mask)``, the second with
            ``mask(node_mask, collapse=True)``. Both carry an empty ``y``.
        """
        global_t = t
        # From here on, t and s are the normalised per-node timesteps.
        _, t = self.graph_t_to_level_t(torch.round(t * self.T), level)      # [bs, n, 1]
        _, s = self.graph_t_to_level_t(torch.round(s * self.T), level)      # [bs, n, 1]

        # Neural net predictions
        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': global_t, 'node_mask': node_mask, "node_t_float": t, "node_level": level.unsqueeze(-1)}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)
        pred = utils.PlaceHolder(X = pred.X, E = pred.E, y = pred.y).mask(node_mask)

        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)               # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)               # bs, n, n, d0

        bs, n, dxs = X_t.shape

        prob_X = torch.zeros_like(X_t).float()
        prob_E = torch.zeros_like(E_t).float()
        no_change_mask = torch.zeros_like(node_mask).bool()
        for node_idx in range(X_t.shape[1]):
            t_i = t[:, node_idx]              # [bs, 1]
            s_i = s[:, node_idx]              # [bs, 1]
            # Select the row and the column of this node in the edge tensor.
            edge_mask_i = torch.zeros(bs, n, n, device = self.device).bool()
            edge_mask_i[:, node_idx] = 1
            edge_mask_i[:, :, node_idx] = 1
            alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_i)
            alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_i)

            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)    # [bs, 3, 3]
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)    # [bs, 3, 3]

            # Nodes whose step does not advance, or is already 0, are copied from z_t below.
            no_change_mask[:, node_idx] = (s_i == t_i).squeeze(-1) | (t_i == 0).squeeze(-1)
            # Build the s_i -> t_i transition once per distinct (s_i, t_i) pair
            # in the batch and scatter it to the matching batch elements.
            from_to_pair = torch.cat((s_i, t_i), dim = 1)   # [bs, 2]
            all_pairs = torch.unique(from_to_pair, dim = 0)
            Qt_x = torch.zeros_like(Qtb.X)
            Qt_e = torch.zeros_like(Qtb.E)
            for pair in all_pairs:
                qt = self.get_Qt_skip(pair[0].unsqueeze(0), pair[1].unsqueeze(0), self.device)
                Qt_x[(from_to_pair == pair).all(dim = -1)] = qt.X
                Qt_e[(from_to_pair == pair).all(dim = -1)] = qt.E
            Qt = utils.PlaceHolder(X = Qt_x, E = Qt_e, y = Qtb.y)        # [bs, 3, 3]

            p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t[:, node_idx : node_idx + 1],
                                                                                            Qt=Qt.X,
                                                                                            Qsb=Qsb.X,
                                                                                            Qtb=Qtb.X)

            p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t[edge_mask_i].reshape(bs, -1, E_t.shape[-1]),
                                                                                            Qt=Qt.E,
                                                                                            Qsb=Qsb.E,
                                                                                            Qtb=Qtb.E)
            # Dim of these two tensors: bs, N, d0, d_t-1
            # left:         bs, 1, 3, 1, p(x_0 | x_t)
            # right:        bs, 1, 3, 3, p(x_t-1 | x_t, x_0) or p(x_t-n | x_t, x_0)
            # weighted_X:   bs, 1, 3, 3, p(x_0, x_t-1 | x_t) or p(x_0, x_t-n | x_t)
            # sum(dim = 2): bs, 1,    3, p(x_t-1 | x_t), p(x_t-n | x_t)
            weighted_X = pred_X[:, node_idx : node_idx + 1].unsqueeze(-1) * p_s_and_t_given_0_X         # bs, 1, d0, d_t-1
            unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, 1, d_t-1
            # Rows that sum to zero get a small constant so the division below is defined.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            prob_X[:, node_idx:node_idx + 1] = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, 1, d_t-1

            pred_E_i = pred_E[edge_mask_i].reshape((bs, -1, pred_E.shape[-1]))
            weighted_E = pred_E_i.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
            unnormalized_prob_E = weighted_E.sum(dim=-2)
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
            # Later iterations overwrite edge entries written by earlier ones.
            prob_E[edge_mask_i] = (unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)).view(-1, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        # An edge is frozen if either endpoint is frozen (and both endpoints are valid nodes).
        no_change_mask_e = (no_change_mask.unsqueeze(1) | no_change_mask.unsqueeze(2)) * node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
        X_s[no_change_mask] = X_t[no_change_mask].to(torch.float32)
        E_s[no_change_mask_e] = E_t[no_change_mask_e].to(torch.float32)

        # The resulting edge tensor must be symmetric.
        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))

        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)

    def compute_extra_data(self, noisy_data):
        """ Build the additional network inputs for a noisy graph.

            Called at every training step (after adding noise) and at every
            sampling step. The extra global features are the output of
            ``self.extra_features`` with the timestep ``t`` appended; the extra
            node features additionally receive ``node_t_float`` and
            ``node_level`` (in that order) when those keys are present. """
        extra_features = self.extra_features(noisy_data)
        extra_y = torch.cat((extra_features.y, noisy_data['t']), dim=1)

        # Optional per-node inputs, appended along the feature dimension.
        extra_X = extra_features.X
        for optional_key in ("node_t_float", "node_level"):
            if optional_key in noisy_data:
                extra_X = torch.cat([extra_X, noisy_data[optional_key]], dim=-1)

        return utils.PlaceHolder(X=extra_X, E=extra_features.E, y=extra_y)
