import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from rdkit.Chem import Draw
from rdkit import Chem
from rdkit.Chem import rdChemReactions as Reactions

import seaborn as sns
import numpy as np
import imageio
import os
import wandb
import logging

log = logging.getLogger(__name__)

from src.utils.graph import PlaceHolder
from src.utils import graph
from src.utils import mol

import torch
from torch.nn import functional as F
from torch_geometric.utils import to_dense_batch

# Module-wide default device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default offset added to probabilities before taking their log (default of kl_prior's eps).
EPS = 1e-6

def accumulate_rxn_scores(acc_scores, new_scores, total_iterations):
    '''
        Folds the scores of one batch (new_scores) into the running state (acc_scores).

        List-valued metrics (e.g. plots) are concatenated. Every other metric must expose
        .mean() and contributes mean()/total_iterations, so that after total_iterations calls
        the accumulated value is the average over the batches.

        input:
            acc_scores: dict with the accumulated scores. It must already hold every key of
                        new_scores; a list-valued metric may start at 0.
            new_scores: dict with the scores of the current batch.
            total_iterations: total number of batches considered.
        output:
            acc_scores: the same dict, updated in place, with the new_scores added.
    '''
    for metric, value in new_scores.items():
        if isinstance(value, list): # accumulates the plots
            if acc_scores[metric] == 0:
                # store a copy: keeping a reference would make the later extend()
                # calls grow the list owned by the caller of the first batch
                acc_scores[metric] = list(value)
            else:
                acc_scores[metric].extend(value)
        else:
            acc_scores[metric] += value.mean()/total_iterations

    return acc_scores
        
def mean_without_masked(graph_obj, mask_X, mask_E, diffuse_nodes=True, diffuse_edges=True, avg_over_batch=False):
    '''
        Mean of the X and/or E values of graph_obj, ignoring the masked elements.

        The masks are reduced with a max over their last axis; E and the reduced edge mask are
        flattened from axis 1 onwards. graph_obj itself is left untouched.

        input:
            graph_obj: object exposing X and E tensors (e.g. a PlaceHolder). X must be
                       broadcast-compatible with mask_X reduced over its last axis, and the
                       flattened E with the flattened, reduced mask_E.
            mask_X: node mask, non-zero for the elements to keep.
            mask_E: edge mask, non-zero for the elements to keep.
            diffuse_nodes: include the node mean in the result.
            diffuse_edges: include the edge mean in the result.
            avg_over_batch: if True, a single mean over all elements is computed; otherwise the
                            sums are taken over the last axis only.
        output:
            res: mean_X + mean_E, mean_E or mean_X depending on the diffuse_* flags.
        raises:
            ValueError: if both diffuse_nodes and diffuse_edges are False.
    '''
    if not (diffuse_nodes or diffuse_edges):
        raise ValueError('At least one of diffuse_nodes and diffuse_edges must be True.')

    mask_X = mask_X.max(-1)[0]
    mask_E = mask_E.max(-1)[0].flatten(1,-1)

    # work on local tensors: writing the results back onto graph_obj would leave the
    # caller with a flattened, masked E
    masked_X = graph_obj.X*mask_X
    masked_E = graph_obj.E.flatten(1,-1)*mask_E

    if avg_over_batch:
        mean_X = masked_X.sum()/mask_X.sum()
        mean_E = masked_E.sum()/mask_E.sum()
    else:
        mean_X = masked_X.sum(-1)/mask_X.sum(-1)
        mean_E = masked_E.sum(-1)/mask_E.sum(-1)

    if diffuse_edges and diffuse_nodes:
        return mean_X+mean_E
    if diffuse_edges:
        return mean_E
    return mean_X

def save_as_smiles(data, atom_types, bond_types, output_filename='default'):
    '''
        Converts data to a string with mol.rxn_from_graph_supernode and appends it, followed by
        a newline, to '<output_filename>_output.gen' in the current working directory.

        input:
            data: graph data forwarded to mol.rxn_from_graph_supernode.
            atom_types: atom types forwarded to mol.rxn_from_graph_supernode.
            bond_types: bond types forwarded to mol.rxn_from_graph_supernode.
            output_filename: prefix of the output file name.
    '''
    smiles = mol.rxn_from_graph_supernode(data=data, atom_types=atom_types, bond_types=bond_types)

    file_path = os.path.join(os.getcwd(), f'{output_filename}_output.gen')
    # the context manager closes the handle even if the write fails
    with open(file_path, 'a') as out_file:
        out_file.write(smiles+'\n')
    
def kl_prior(prior, limit, eps=EPS):
    """
        Computes a KL term between prior and limit for the nodes (X) and the edges (E).

        For each of X and E, F.kl_div is called with log(prior + eps) as input and limit as
        target, without reduction, and the result is summed over the last axis.

        input:
            prior: PlaceHolder graph object
            limit: PlaceHolder graph object
            eps: small offset added to prior before taking the log
        output:
            kl_prior_: PlaceHolder graph object holding the summed KL terms in X and E, the
                       node_mask of prior and a zero tensor of shape (1,) as y
    """
    def _summed_kl(prior_probs, limit_probs):
        # point-wise KL, then sum over the last axis
        log_prior = torch.log(prior_probs + eps)
        return F.kl_div(input=log_prior, target=limit_probs, reduction='none').sum(dim=-1)

    kl_nodes = _summed_kl(prior.X, limit.X)
    kl_edges = _summed_kl(prior.E, limit.E)

    return graph.PlaceHolder(X=kl_nodes, E=kl_edges, node_mask=prior.node_mask,
                             y=torch.zeros(1, dtype=torch.float))

def reconstruction_logp(orig, pred_t0):
    '''
        Reconstruction term: multiplies orig and pred_t0 element-wise and sums over the last
        axis, for both X and E.

        Intended use (per the original notes, orig is assumed to be one-hot - TODO confirm):
        E_{q(x_0)} E_{q(x_1|x_0)} log p(x_0|x_1), with x_0 drawn from the dataset, x_1 obtained
        by noising x_0 and p(x_0|x_1) given by the denoiser. Multiplying by the one-hot x_0
        selects the predicted value of the true category.

        input:
            orig: PlaceHolder graph object with the true graph
            pred_t0: PlaceHolder graph object with the prediction
        output:
            loss_term_0: PlaceHolder graph object with the selected values in X and E, the
                         node_mask of orig and a ones tensor of shape (1,) as y
    '''
    selected_X = torch.sum(orig.X * pred_t0.X, dim=-1)
    selected_E = torch.sum(orig.E * pred_t0.E, dim=-1)

    return graph.PlaceHolder(X=selected_X, E=selected_E, node_mask=orig.node_mask,
                             y=torch.ones(1, dtype=torch.float))

def ce(pred, discrete_dense_true, diffuse_edges, lambda_E, log=False, mask_nodes=None, mask_edges=None, diffuse_nodes=True):
    '''
        Cross-entropy between the predicted logits and the true categories of nodes and edges.

        X and E of both inputs are flattened to 2D (rows, classes). The true class of a row is
        the argmax of the corresponding row of discrete_dense_true.

        input:
            pred: object exposing X and E with the predicted logits.
            discrete_dense_true: object exposing X and E with the true distributions.
            diffuse_edges: weight (or bool) of the edge term in batch_ce.
            lambda_E: additional weight of the edge term in batch_ce.
            log: unused, kept for backward compatibility.
            mask_nodes: optional index/mask over the flattened node rows to keep.
                        None keeps every row.
            mask_edges: optional index/mask over the flattened edge rows to keep.
                        None keeps every row.
            diffuse_nodes: weight (or bool) of the node term in batch_ce.
        output:
            loss_X: per-row node cross-entropy.
            loss_E: per-row edge cross-entropy.
            batch_ce: diffuse_nodes * mean(loss_X) + diffuse_edges * lambda_E * mean(loss_E).
    '''
    true_X, true_E = discrete_dense_true.X, discrete_dense_true.E
    pred_X, pred_E = pred.X, pred.E

    flat_true_X = true_X.reshape(-1,true_X.size(-1))  # (bs * n, dx)
    flat_true_E = true_E.reshape(-1,true_E.size(-1))  # (bs * n * n, de)
    flat_pred_X = pred_X.reshape(-1,pred_X.size(-1))  # (bs * n, dx)
    flat_pred_E = pred_E.reshape(-1,pred_E.size(-1))  # (bs * n * n, de)

    # Remove other masked nodes or edges. Without a mask all rows are used: previously the
    # flat_* tensors were only defined when a mask was given, so the defaults crashed.
    if mask_nodes is not None:
        flat_true_X = flat_true_X[mask_nodes,:]
        flat_pred_X = flat_pred_X[mask_nodes,:]
    if mask_edges is not None:
        flat_true_E = flat_true_E[mask_edges,:]
        flat_pred_E = flat_pred_E[mask_edges,:]

    flat_true_X_discrete = flat_true_X.argmax(dim=-1) # (nb kept nodes,)
    flat_true_E_discrete = flat_true_E.argmax(dim=-1) # (nb kept edges,)

    loss_X = F.cross_entropy(flat_pred_X, flat_true_X_discrete, reduction='none')
    loss_E = F.cross_entropy(flat_pred_E, flat_true_E_discrete, reduction='none')

    batch_ce = diffuse_nodes * loss_X.mean() + diffuse_edges * lambda_E * loss_E.mean()

    return loss_X, loss_E, batch_ce
    
def get_p_zs_given_zt(transition_model, t_array, pred, z_t, return_prob=False,
                      temperature_scaling=1.0):
    """
        Computes p(zs | zt) from the network prediction and, unless return_prob is set,
        samples zs from it. Only used during sampling.

        input:
            transition_model: object providing get_Qt_bar(t, device) and get_Qt(t, device).
            t_array: time steps t; the previous step is taken as t_array - 1.
            pred: object exposing X and E with the predicted logits (softmax is applied here).
            z_t: noisy graph at time t (object with X, E, y, node_mask, get_new_object).
            return_prob: if True, return the masked probabilities instead of a sample.
            temperature_scaling: unused, kept for backward compatibility.
        output:
            if return_prob: graph object with the probabilities of zs, masked with z_t.node_mask.
            otherwise: tuple (one-hot sample, collapsed sample), both masked with z_t.node_mask
            and cast with type_as(z_t.y).
    """

    # sample z_s
    bs, n, _ = z_t.X.shape
    device = z_t.X.device

    # Retrieve transition matrices
    Qtb = transition_model.get_Qt_bar(t_array, device)
    Qsb = transition_model.get_Qt_bar(t_array - 1, device)
    Qt = transition_model.get_Qt(t_array, device)

    # Normalize predictions
    pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
    pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
    p_s_and_t_given_0_X = compute_batched_over0_posterior_distribution(X_t=z_t.X, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X) # shape (bs, n, [x_{0}], [x_{t-1}])
    p_s_and_t_given_0_E = compute_batched_over0_posterior_distribution(X_t=z_t.E, Qt=Qt.E, Qsb=Qsb.E, Qtb=Qtb.E)

    # Dim of these two tensors: bs, N, d0, d_t-1
    weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X         # bs, n, d0, d_t-1 # p(x_{t-1}|x_t) = q(x_{t-1},x_t| x_0)p(x_0|x_t)
    unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, n, d_t-1
    # rows summing to 0 cannot be normalized: give them a small constant value instead
    unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
    prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

    pred_E_ = pred_E.reshape((bs, -1, pred_E.shape[-1]))
    weighted_E = pred_E_.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
    unnormalized_prob_E = weighted_E.sum(dim=-2)
    unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
    prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
    prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

    prob_s = z_t.get_new_object(X=prob_X, E=prob_E).mask(z_t.node_mask)

    if return_prob: return prob_s

    # sample_discrete_features takes one graph object (it resets the masked rows to a uniform
    # distribution itself) and already returns one-hot float tensors, so no F.one_hot here.
    sampled_s = sample_discrete_features(prob=prob_s)
    X_s, E_s = sampled_s.X, sampled_s.E

    assert (E_s == torch.transpose(E_s, 1, 2)).all()
    assert (z_t.X.shape == X_s.shape) and (z_t.E.shape == E_s.shape)

    out_one_hot = z_t.get_new_object(X=X_s, E=E_s, y=torch.zeros(z_t.y.shape[0], 0).to(device))
    out_discrete = z_t.get_new_object(X=X_s, E=E_s, y=torch.zeros(z_t.y.shape[0], 0).to(device))

    return out_one_hot.mask(z_t.node_mask).type_as(z_t.y), out_discrete.mask(z_t.node_mask, collapse=True).type_as(z_t.y)

def mol_diagnostic_plots(sample, atom_types, bond_types, name='mol.png', show=False, return_mol=False):
    '''
        Plotting diagnostics of node and edge distributions of a single (unbatched) graph.

        The figure holds a heatmap of the atom type distribution, one heatmap per edge channel
        1, 2 and 3 (titled single, double and triple bonds) and the picture of one molecule
        sampled from the distributions. The figure is saved to name and then closed.

        Note: sample.X, sample.E and sample.node_mask are replaced by their batched
        (unsqueeze(0)) versions as a side effect.

        sample: graph object with X (n, dx), E (n, n, de), node_mask (n,) and y.
        atom_types: labels of the atom types, also used to build the molecule.
        bond_types: bond types used to build the molecule.
        name: the name of the file where to save the figure.
        show: Boolean to decide whether to show the matplot plot or not.
        return_mol: Boolean to decide whether to also return the sampled molecule.

        return:
            plt fig object (already closed), or (fig, rdkit molecule) if return_mol is True.
    '''
    # remove padding nodes
    X_no_padding = sample.X[sample.node_mask].cpu()
    E_no_padding = sample.E[sample.node_mask,...][:,sample.node_mask,...].cpu()

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(20, 7),
                             gridspec_kw={'width_ratios': [1.5, 1.5, 1]}) # x, y
    # unused axes
    for row, col in ((0, 0), (2, 0), (0, 2), (2, 2)):
        axes[row, col].axis('off')

    sns_kargs = {"vmin": 0, "vmax": 1.0, "cmap": sns.cm.rocket_r}

    ## plot distribution over atom types for each atom
    sns.heatmap(X_no_padding, xticklabels=atom_types,
                ax=axes[1,0], cbar_kws={'pad': 0.01}, **sns_kargs)
    axes[1,0].set_title('atom types')
    axes[1,0].tick_params(axis='x', rotation=90)
    axes[1,0].yaxis.set_tick_params(labelleft=False)
    axes[1,0].set_yticks([])

    ## plot one adjacency heatmap per bond channel, stacked in the middle column
    node_labels = [f"{i}" for i in range(X_no_padding.shape[0])]
    bond_panels = (('single bonds', 1), ('double bonds', 2), ('triple bonds', 3))
    for row, (title, channel) in enumerate(bond_panels):
        bond_ax = axes[row, 1]
        sns.heatmap(E_no_padding[..., channel],
                    xticklabels=node_labels, yticklabels=node_labels,
                    cbar_kws={'pad': 0.01}, ax=bond_ax, **sns_kargs)
        bond_ax.set_title(title)
        bond_ax.yaxis.set_tick_params(labelleft=False)
        bond_ax.set_yticks([])
        bond_ax.xaxis.set_tick_params(labelbottom=False)
        bond_ax.set_xticks([])

    ## plot graphs as molecules
    # sample_discrete_features works on batched graphs: add a batch axis of size 1
    sample.X = sample.X.unsqueeze(0)
    sample.E = sample.E.unsqueeze(0)
    sample.node_mask = sample.node_mask.unsqueeze(0)
    mol_g = sample_discrete_features(prob=sample)
    mol_g = mol_g.mask(mol_g.node_mask, collapse=True)
    mol_rdkit = mol.mol_from_graph(node_list=mol_g.X[0,...], adjacency_matrix=mol_g.E[0,...],
                                   atom_types=atom_types, bond_types=bond_types)
    img = Draw.MolToImage(mol_rdkit, size=(300, 300))
    axes[1,2].imshow(img)
    axes[1,2].set_title('sample molecule')
    axes[1,2].axis('off')
    axes[1,2].yaxis.set_tick_params(labelleft=False)
    axes[1,2].set_yticks([])
    axes[1,2].xaxis.set_tick_params(labelbottom=False)
    axes[1,2].set_xticks([])

    fig.suptitle(name.split('.png')[0])
    fig.tight_layout()

    # save before showing: once the window opened by plt.show() is closed the figure may be
    # gone, and saving afterwards can produce an empty image
    fig.savefig(name)
    if show:
        plt.show()
    plt.close(fig)

    if return_mol:
        return fig, mol_rdkit
    return fig

def rxn_diagnostic_chains(chains, atom_types, bond_types, chain_name='default'):
    '''
        Visualize chains of a process as mp4 videos, one video per chain and per molecule.

        Each sample is split at the nodes whose most likely type is 'SuNo'; the first and the
        last resulting sets are ignored and every remaining set is plotted as one molecule
        with mol_diagnostic_plots (its first node, the 'SuNo' node itself, is dropped).

        chains: list of (t, PlaceHolder) pairs, the PlaceHolder being a batch of graphs at time
        step t. len(chains)==nb_time_steps.
        atom_types: list of atom types, must contain 'SuNo'.
        bond_types: list of bond types.
        chain_name: prefix of the names of the generated png and mp4 files.

        Returns:
            list of (path of the mp4 video, path of a png, SMILES of the last plotted molecule)
            tuples, one per chain and molecule.
    '''
    nb_of_chains = chains[0][1].X.shape[0] # number of graph chains to plot
    imgio_kargs = {'fps': 1, 'quality': 10, 'macro_block_size': None, 'codec': 'h264',
                   'ffmpeg_params': ['-vf', 'crop=trunc(iw/2)*2:trunc(ih/2)*2']}
    suno_idx = atom_types.index('SuNo')

    writers = {}       # writers[c][m]: video writer of molecule m of chain c
    sampled_mols = {}  # sampled_mols[c][m]: molecules sampled at each time step
    try:
        for t, samples_t in chains:
            for c in range(nb_of_chains):
                suno_indices = (samples_t.X[c,...].argmax(-1)==suno_idx).nonzero(as_tuple=True)[0].cpu()
                mols_atoms = torch.tensor_split(samples_t.X[c,...], suno_indices, dim=0)[1:-1] # ignore first set (SuNo) and last set (product)
                mols_edges = torch.tensor_split(samples_t.E[c,...], suno_indices, dim=0)[1:-1]
                node_masks = torch.tensor_split(samples_t.node_mask[c,...], suno_indices, dim=-1)[1:-1]

                for m, mol_atoms in enumerate(mols_atoms): # for each mol in sample
                    chain_pic_name = f'{chain_name}_sample_t{t}_chain{c}_mol{m}.png'

                    # lazily create one writer per (chain, molecule)
                    chain_writers = writers.setdefault(c, {})
                    if m not in chain_writers:
                        chain_writers[m] = imageio.get_writer(f'{chain_name}_chain{c}_mol{m}.mp4', **imgio_kargs)
                    writer = chain_writers[m]

                    # keep only the edges between the atoms of this molecule
                    mol_edges_to_all = mols_edges[m]
                    mol_edges_t = torch.tensor_split(mol_edges_to_all, suno_indices, dim=1)[1:] # ignore first because empty SuNo set
                    mol_edges = mol_edges_t[m]
                    mol_edges = mol_edges[1:,:][:,1:] # (n-1, n-1)
                    mol_atoms = mol_atoms[1:] # (n-1)
                    node_mask = node_masks[m][1:]

                    one_sample = PlaceHolder(X=mol_atoms, E=mol_edges, node_mask=node_mask, y=torch.tensor([t], device=device).unsqueeze(-1))

                    # not named `mol`, which would shadow the imported mol module
                    _, mol_rdkit = mol_diagnostic_plots(sample=one_sample, atom_types=atom_types, bond_types=bond_types,
                                                        name=chain_pic_name, show=False, return_mol=True)
                    sampled_mols.setdefault(c, {}).setdefault(m, []).append(mol_rdkit)

                    img = imageio.v2.imread(os.path.join(os.getcwd(), chain_pic_name))
                    writer.append_data(img)

                    # repeat the last frame 10 times for a nicer video. Done per molecule:
                    # outside of this loop only the last molecule would be extended, and m/img
                    # would be stale (or undefined) for a sample without molecules.
                    if t==0:
                        for _ in range(10):
                            writer.append_data(img)
    finally:
        # close the writers even if a plot fails, so that no video file is left open
        for chain_writers in writers.values():
            for writer in chain_writers.values():
                writer.close()

    # NOTE(review): the png path below has no chain_name prefix and no time step, unlike the
    # pictures saved above - confirm which file the callers expect.
    return [(os.path.join(os.getcwd(), f'{chain_name}_chain{c}_mol{m}.mp4'), os.path.join(os.getcwd(), f'chain{c}_mol{m}.png'), Chem.MolToSmiles(sampled_mols[c][m][-1])) for c in writers.keys() for m in range(len(writers[c]))]

def rxn_vs_sample_plot(true_rxns, sampled_rxns, atom_types, bond_types, chain_name='default', plot_dummy_nodes=False, rxn_offset_nb=0):
    '''
       Visualize the true rxn vs a rxn being sampled to compare the reactants more easily.

       For every time step and every chain, a png with the sampled rxn (top) and the true rxn
       (bottom) is saved and appended as a frame to the mp4 video of the chain.

       true_rxns: batched graph object with one true rxn per chain.
       sampled_rxns: list of (t, batched graph object) pairs, one per time step.
       atom_types: list of atom types, forwarded to mol.rxn_plot.
       bond_types: list of bond types, forwarded to mol.rxn_plot.
       chain_name: prefix of the names of the generated png and mp4 files.
       plot_dummy_nodes: forwarded to mol.rxn_plot.
       rxn_offset_nb: where to start the count for naming the rxn plot (file).

       Returns:
           list of the paths of the mp4 videos, one per chain.
    '''

    assert true_rxns.X.shape[0]==sampled_rxns[0][1].X.shape[0], 'You need to give as many true_rxns as there are chains.'+\
            f' Currently there are {true_rxns.X.shape[0]} true rxns and {sampled_rxns[0][1].X.shape[0]} chains.'

    # initialize the params of the video writer
    nb_of_chains = true_rxns.X.shape[0] # number of graph chains to plot
    imgio_kargs = {'fps': 1, 'quality': 10, 'macro_block_size': None, 'codec': 'h264', 'ffmpeg_params': ['-vf', 'crop=trunc(iw/2)*2:trunc(ih/2)*2']}

    # create a frame for each time step t
    writers = []
    try:
        for t, samples_t in sampled_rxns:
            for c in range(nb_of_chains):
                chain_pic_name = f'{chain_name}_t{t}_rxn{c+rxn_offset_nb}.png'
                # get image of the true rxn, to be added to each plot at time t
                true_rxn = graph.PlaceHolder(X=true_rxns.X[c,...].unsqueeze(0), E=true_rxns.E[c,...].unsqueeze(0), node_mask=true_rxns.node_mask[c,...].unsqueeze(0), y=true_rxns.y)
                true_img = mol.rxn_plot(rxn=true_rxn, atom_types=atom_types, bond_types=bond_types, plot_dummy_nodes=plot_dummy_nodes)

                # get image of the sample rxn at time t
                one_sample_t = graph.PlaceHolder(X=samples_t.X[c,...].unsqueeze(0), E=samples_t.E[c,...].unsqueeze(0), y=samples_t.y, node_mask=samples_t.node_mask[c,...].unsqueeze(0))
                sampled_img = mol.rxn_plot(rxn=one_sample_t, atom_types=atom_types, bond_types=bond_types, plot_dummy_nodes=plot_dummy_nodes)

                # plot sampled and true rxn in the same fig
                fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(20, 7)) # x, y
                axes[0].axis('off')
                axes[1].axis('off')

                axes[0].set_title('sampled')
                axes[1].set_title('true')

                axes[0].imshow(sampled_img)
                axes[1].imshow(true_img)
                fig.suptitle(chain_pic_name.split('.png')[0])
                # address the figure explicitly rather than through pyplot's current figure
                fig.savefig(chain_pic_name)
                plt.close(fig)

                # the writer of chain c is created the first time the chain is seen
                if c >= len(writers):
                    writers.append(imageio.get_writer(f'{chain_name}_rxn{c+rxn_offset_nb}.mp4', **imgio_kargs))

                img = imageio.v2.imread(os.path.join(os.getcwd(), chain_pic_name))
                writers[c].append_data(img)

                # repeat the last frame 10 times for a nicer video
                if t==0:
                    for _ in range(10):
                        writers[c].append_data(img)
    finally:
        # close the writers even if a plot fails, so that no video file is left open
        for writer in writers:
            writer.close()

    return [os.path.join(os.getcwd(), f'{chain_name}_rxn{c+rxn_offset_nb}.mp4') for  c in range(nb_of_chains)]
    
def assert_correctly_masked(variable, node_mask):
    """
    Asserts that variable is (numerically) zero wherever node_mask is False.

    The check fails when the largest absolute value of variable over the positions
    excluded by node_mask reaches 1e-4.
    """
    excluded = 1 - node_mask.long()
    largest_leak = (variable * excluded).abs().max().item()
    assert largest_leak < 1e-4, 'Variables not masked properly.'

def inflate_batch_array(array, target_shape):
    """
    Views array as (array.size(0), 1, ..., 1), with as many axes as target_shape.

    array is expected to hold one value per batch element, i.e. to have shape (batch_size,)
    or (batch_size, 1, ..., 1). Only the length of target_shape is used.
    """
    nb_singleton_axes = len(target_shape) - 1
    inflated_shape = [array.size(0)] + [1] * nb_singleton_axes
    return array.view(*inflated_shape)

def sigma(gamma, target_shape):
    """Computes sigma = sqrt(sigmoid(gamma)), inflated to the number of axes of target_shape."""
    sigma_per_sample = torch.sigmoid(gamma).sqrt()
    return inflate_batch_array(sigma_per_sample, target_shape)

def sample_discrete_features(prob):
    '''
        Sample node and edge types from the categorical distributions stored in prob.

        Padding nodes, edges touching a padding node and self edges are sampled from a uniform
        distribution (so that every row is a valid distribution) and are then set to class 0.
        Only the strict upper triangle of the sampled edges is kept and mirrored, which makes
        the result symmetric. prob itself is not modified.

        input:
            prob: graph object with
                X: node probabilities. (bs, n, dx_out)
                E: edge probabilities. (bs, n, n, de_out)
                y: the y feature of the object (often the time step), copied to the output.
                node_mask: boolean mask of the real nodes. (bs, n)
                get_new_object: method used to build the output from X, E and y.
        output:
            z_t: prob.get_new_object(X, E, y) with one-hot float X (bs, n, dx_out) and
                 E (bs, n, n, de_out).
    '''
    node_mask = prob.node_mask
    bs, n = node_mask.size(0), node_mask.size(1)
    probX, probE, y = prob.X.clone(), prob.E.clone(), prob.y.clone()

    # Noise X
    # The masked rows should define probability distributions as well
    probX[~node_mask] = 1 / probX.shape[-1] # masked is ignored
    probX = probX.reshape(probX.size(0) * probX.size(1), -1)       # (bs * n, dx_out)

    assert (abs(probX.sum(dim=-1) - 1) < 1e-4).all()

    # Sample X
    X_t = probX.multinomial(1)                                  # (bs * n, 1)
    X_t = X_t.reshape(bs, n)                                    # (bs, n)
    X_t = X_t * node_mask

    # Noise E
    # The masked rows should define probability distributions as well
    inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
    # built directly as a boolean mask on the device of probE (as in mask_distributions)
    diag_mask = torch.eye(probE.size(1), probE.size(2), dtype=torch.bool, device=probE.device)
    diag_mask = diag_mask.unsqueeze(0).expand(probE.size(0), -1, -1)
    probE[inverse_edge_mask] = 1 / probE.shape[-1]
    probE[diag_mask] = 1 / probE.shape[-1] # self edges only need a valid dist: they are dropped by triu below
    probE = probE.reshape(probE.size(0) * probE.size(1) * probE.size(2), -1) # (bs * n * n, de_out)

    # Sample E
    E_t = probE.multinomial(1).reshape(bs, n, n)   # (bs, n, n)
    E_t = torch.triu(E_t, diagonal=1)
    E_t = (E_t + torch.transpose(E_t, 1, 2))
    E_t = E_t * node_mask.unsqueeze(dim=1) * node_mask.unsqueeze(dim=2)

    X_t = F.one_hot(X_t, num_classes=probX.shape[-1]).float()
    E_t = F.one_hot(E_t, num_classes=probE.shape[-1]).float()

    z_t = prob.get_new_object(X=X_t, E=E_t, y=y)
    return z_t

def compute_posterior_distribution(M, M_t, Qt_M, Qsb_M, Qtb_M, log=False):
    ''' M: X or E
        Compute (xt @ Qt.T) * (x0 @ Qsb), normalized over the last axis.

        M and M_t are flattened to (bs, N, d), with N = n or n * n, and cast to float32.
        With log=False the normalized probabilities are returned. Otherwise M is first passed
        through a softmax and the result is returned in log space. Qtb_M is not used.
    '''
    flat_0 = M.flatten(start_dim=1, end_dim=-2).to(torch.float32) # (bs, N, d)
    flat_t = M_t.flatten(start_dim=1, end_dim=-2).to(torch.float32) # (bs, N, d)
    from_t = flat_t @ Qt_M.transpose(-2, -1) # (bs, N, d)

    # kept as an equality test so that non-boolean flags behave as they always did
    if log == False:
        unnormalized = from_t * (flat_0 @ Qsb_M) # (bs, N, d)
        return unnormalized / unnormalized.sum(-1, keepdim=True)

    eps = 1e-30
    probs_0 = torch.exp(torch.log_softmax(flat_0, dim=-1))
    log_unnormalized = torch.log(from_t + eps) + torch.log(probs_0 @ Qsb_M + eps) # (bs, N, d)
    return log_unnormalized - torch.logsumexp(log_unnormalized, dim=-1, keepdim=True)

def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb):
    """
        Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0

        X_t: bs, n, dt          or bs, n, n, dt (X or E)
        Qt: bs, d_t-1, dt
        Qsb: bs, d0, d_t-1
        Qtb: bs, d0, dt

        Returns a tensor of shape bs, N, d0, d_t-1 with N=n or N=n*n. Denominators equal to 0
        are replaced by 1e-6 before the division.
    """
    # Flatten feature tensors
    flat_t = X_t.flatten(start_dim=1, end_dim=-2).to(torch.float32)   # bs, N, dt with N=n or N=n*n

    # The rows of Qsb play the role of the x_0 axis; the last axis is the x_{t-1} axis
    one_step = flat_t @ Qt.transpose(-1, -2)                           # bs, N, d_t-1
    numerator = one_step.unsqueeze(dim=2) * Qsb.unsqueeze(1)           # bs, N, d0, d_t-1

    from_0 = Qtb @ flat_t.transpose(-1, -2)                            # bs, d0, N
    denominator = from_0.transpose(-1, -2).unsqueeze(-1)               # bs, N, d0, 1
    denominator[denominator==0] = 1e-6

    return numerator / denominator

def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask):
    '''
        Overwrites, in place, the masked rows of the true and predicted distributions with the
        same arbitrary distribution (all the mass on class 0), so that they do not contribute
        to the loss.

        Masked rows are the padding nodes and, for the edges, every pair that involves a
        padding node as well as the diagonal.

        Returns the four (modified) input tensors true_X, true_E, pred_X, pred_E.
    '''
    def _class0_row(like):
        # one-hot row on class 0, as wide as the last axis of `like`
        row = torch.zeros(like.size(-1), dtype=torch.float, device=like.device)
        row[0] = 1.
        return row

    row_X = _class0_row(true_X)
    row_E = _class0_row(true_E)

    off_diagonal = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0)
    padding_nodes = ~node_mask
    ignored_edges = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * off_diagonal)

    for nodes, edges in ((true_X, true_E), (pred_X, pred_E)):
        nodes[padding_nodes] = row_X
        edges[ignored_edges, :] = row_E

    return true_X, true_E, pred_X, pred_E

def posterior_distributions(X, E, y, X_t, E_t, y_t, Qt, Qsb, Qtb):
    '''
        Posterior distributions of the nodes and of the edges, each obtained with
        compute_posterior_distribution from the clean (X, E) and noisy (X_t, E_t) features and
        the X / E attributes of the transition matrices Qt, Qsb and Qtb.

        Returns a PlaceHolder with the node posterior (bs, n, dx) as X, the edge posterior
        (bs, n * n, de) as E and y_t as y. The argument y is not used.
    '''
    posteriors = {}
    for attr, clean, noisy in (('X', X, X_t), ('E', E, E_t)):
        posteriors[attr] = compute_posterior_distribution(M=clean, M_t=noisy,
                                                          Qt_M=getattr(Qt, attr),
                                                          Qsb_M=getattr(Qsb, attr),
                                                          Qtb_M=getattr(Qtb, attr))

    return PlaceHolder(X=posteriors['X'], E=posteriors['E'], y=y_t)

def sample_from_noise(limit_dist, node_mask, T):
    """
        Sample from the limit distribution of the diffusion process.

        Node and edge classes are drawn independently from limit_dist.X and limit_dist.E and
        one-hot encoded. Only the strict upper triangle of the edges is kept and mirrored, so
        the edge tensor is symmetric and all-zero on the diagonal. All tensors are moved to
        the module-level device.

        input:
            limit_dist: stationary distribution of the diffusion process (X and E attributes).
            node_mask: masking used by PyG for batching, of shape (bs, n_max).
            T: value used to fill y (shape (bs, 1)).
        output:
            z_T: PlaceHolder with the sampled node and edge features, masked with node_mask.
    """
    bs, n_max = node_mask.shape

    # one copy of the limit distribution per node / per pair of nodes
    x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1)
    e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1)

    node_classes = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max).long()
    edge_classes = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max).long()

    z_X = F.one_hot(node_classes, num_classes=x_limit.shape[-1]).float().to(device)
    z_E = F.one_hot(edge_classes, num_classes=e_limit.shape[-1]).float().to(device)
    z_y = T*torch.ones((bs,1)).to(device)

    # Get upper triangular part of edges, without main diagonal
    rows, cols = torch.triu_indices(row=z_E.size(1), col=z_E.size(2), offset=1)
    keep_upper = torch.zeros_like(z_E)
    keep_upper[:, rows, cols, :] = 1

    # make sure adjacency matrix is symmetric over the diagonal
    z_E = z_E * keep_upper
    z_E = z_E + z_E.transpose(1, 2)
    assert (z_E == torch.transpose(z_E, 1, 2)).all()

    node_mask = node_mask.to(device)
    return PlaceHolder(X=z_X, E=z_E, y=z_y, node_mask=node_mask).mask(node_mask)