import os
import pathlib
import logging
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from src.neuralnet.transformer_model_with_y import GraphTransformerWithY, GraphTransformerWithYAtomMapPosEmb, PositionalEmbedding
from src.neuralnet.transformer_model_stacked import GraphTransformerWithYStacked
from src.diffusion.noise_schedule import *
from src.utils import graph, mol
from src.neuralnet.ema_pytorch import EMA
from src.utils.diffusion import helpers

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*') # Disable rdkit warnings

# Third ancestor of this file's resolved path (parents[0] is the containing directory).
parent_path = pathlib.Path(os.path.realpath(__file__)).parents[2]

class DiscreteDenoisingDiffusion(nn.Module):
    def __init__(self, cfg, dataset_infos, node_type_counts_unnormalized=None, edge_type_counts_unnormalized=None, use_data_parallel=None):
        """Build the denoising network, the optional EMA wrapper and the transition models.

        Args:
            cfg: Configuration object. The ``neuralnet``, ``diffusion``, ``train``
                and ``general`` sections are read here.
            dataset_infos: Dataset description providing ``input_dims``,
                ``output_dims``, ``nodes_dist``, ``atom_decoder`` and ``bond_decoder``.
            node_type_counts_unnormalized: Not used by this constructor.
            edge_type_counts_unnormalized: Not used by this constructor.
            use_data_parallel: If truthy, the network is wrapped in
                ``torch.nn.DataParallel``.

        Raises:
            ValueError: If ``cfg.diffusion.transition`` is not a supported value.
        """
        super().__init__()
        input_dims = dataset_infos.input_dims
        output_dims = dataset_infos.output_dims
        nodes_dist = dataset_infos.nodes_dist
        if cfg.neuralnet.extra_features:
            # With extra features enabled the node input grows by 8 and the global input by 12.
            input_dims = {'X': input_dims['X'] + 8, 'E': input_dims['E'], 'y': input_dims['y'] + 12}

        self.cfg = cfg
        self.T = cfg.diffusion.diffusion_steps

        self.Xdim_output = output_dims['X']
        self.Edim_output = output_dims['E']
        if self.cfg.neuralnet.increase_y_dim_for_multigpu: # Fix necessary for multi-gpu training due to a corner-case
            # NOTE(review): this updates the mapping obtained from dataset_infos.output_dims
            # in place, so building a second model from the same dataset_infos would
            # increment it again -- confirm this is intended.
            output_dims['y'] += 1
        self.ydim_output = output_dims['y']
        self.node_dist = nodes_dist

        self.pos_emb_module = PositionalEmbedding(cfg.neuralnet.hidden_dims['dx'])

        self.dataset_info = dataset_infos
        self.log_to_wandb = cfg.train.log_to_wandb

        self.eps = 1e-6
        self.log_every_steps= cfg.general.log_every_steps

        # NOTE(review): `device` is not defined in this method; presumably it is supplied
        # by the star-import from src.diffusion.noise_schedule -- confirm.
        node_idx_to_mask, edge_idx_to_mask = graph.get_index_from_states(atom_decoder=self.dataset_info.atom_decoder,
                                                                         bond_decoder=self.dataset_info.bond_decoder,
                                                                         node_states_to_mask=cfg.diffusion.node_states_to_mask,
                                                                         edge_states_to_mask=cfg.diffusion.edge_states_to_mask,
                                                                         device=device)

        # Indices of the absorbing states: edge class 0 and the 'Au' entry of the atom decoder.
        abs_state_position_e = 0
        abs_state_position_x = self.dataset_info.atom_decoder.index('Au')

        # NOTE(review): an architecture outside the three handled below leaves
        # self.model unset; there is no fallback branch.
        if self.cfg.neuralnet.architecture=='with_y': # permutation equivariant model
            self.model = GraphTransformerWithY(n_layers=cfg.neuralnet.n_layers, input_dims=input_dims,
                                               hidden_mlp_dims=cfg.neuralnet.hidden_mlp_dims,
                                               hidden_dims=cfg.neuralnet.hidden_dims,
                                               output_dims=output_dims, act_fn_in=nn.ReLU(), act_fn_out=nn.ReLU(), dropout=cfg.neuralnet.dropout)
        elif self.cfg.neuralnet.architecture=='with_y_atommap_number_pos_enc': # aligned model with large graph where we join the product and reactant graphs into a large graph
            self.model = GraphTransformerWithYAtomMapPosEmb(n_layers=cfg.neuralnet.n_layers, input_dims=input_dims,
                                               hidden_mlp_dims=cfg.neuralnet.hidden_mlp_dims,
                                               hidden_dims=cfg.neuralnet.hidden_dims,
                                               output_dims=output_dims, act_fn_in=nn.ReLU(), act_fn_out=nn.ReLU(),
                                               pos_emb_permutations=cfg.neuralnet.pos_emb_permutations, dropout=cfg.neuralnet.dropout,
                                               p_to_r_skip_connection=cfg.neuralnet.p_to_r_skip_connection,
                                               p_to_r_init=cfg.neuralnet.p_to_r_init,
                                               input_alignment=cfg.neuralnet.input_alignment)
        elif self.cfg.neuralnet.architecture=='with_y_stacked': # aligned model with the 'input' alignment where the product conditioning is given by overlaying it on the reactant graph
            self.model = GraphTransformerWithYStacked(n_layers=cfg.neuralnet.n_layers, input_dims=input_dims,
                                               hidden_mlp_dims=cfg.neuralnet.hidden_mlp_dims,
                                               hidden_dims=cfg.neuralnet.hidden_dims,
                                               output_dims=output_dims, act_fn_in=nn.ReLU(), act_fn_out=nn.ReLU(),
                                               pos_emb_permutations=cfg.neuralnet.pos_emb_permutations, dropout=cfg.neuralnet.dropout,
                                               p_to_r_skip_connection=cfg.neuralnet.p_to_r_skip_connection,
                                               p_to_r_init=cfg.neuralnet.p_to_r_init)
        if use_data_parallel:
            log.info(f"Using {torch.cuda.device_count()} GPUs for training")
            self.model = torch.nn.DataParallel(self.model)

        self.ema = None
        if cfg.neuralnet.use_ema:
            self.ema = EMA(self.model, beta=cfg.neuralnet.ema_decay, power=1)

        if cfg.diffusion.transition=='absorbing_masknoedge':
            # One transition model per step count: the first for training, the second for evaluation.
            self.transition_model, self.transition_model_eval = (AbsorbingStateTransitionMaskNoEdge(x_classes=self.Xdim_output, e_classes=self.Edim_output,
                                                                       y_classes=self.ydim_output, timesteps=T_,
                                                                       diffuse_edges=cfg.diffusion.diffuse_edges,
                                                                       abs_state_position_e=abs_state_position_e, abs_state_position_x=abs_state_position_x,
                                                                       node_idx_to_mask=node_idx_to_mask, edge_idx_to_mask=edge_idx_to_mask)
                                                                       for T_ in [cfg.diffusion.diffusion_steps, cfg.diffusion.diffusion_steps_eval])
            self.limit_dist = self.transition_model.get_limit_dist()
        else:
            # The previous `assert '<message>'` asserted a non-empty string and could never
            # fail, so unsupported settings went unnoticed until a later attribute error.
            raise ValueError(f'Transition model undefined. Got {cfg.diffusion.transition}.')

    def training_step(self, data, i, device):
        """Run one training iteration on a batch and return its losses.

        The batch is densified, noised to a random timestep in [1, T], masked,
        passed through the denoiser, and scored with the loss selected by
        ``cfg.train.loss``.

        input:
            data: batch of graphs accepted by graph.to_dense.
            i: step counter; together with log_every_steps it decides whether the
               'ce' loss helper is asked to log.
            device: device on which the dense batch and the sampled timesteps live.
        return:
            (loss_X, loss_E, loss). For the 'vb' loss, loss_X and loss_E are
            zero-filled placeholders.
        """
        diffusion_cfg = self.cfg.diffusion
        decoders = dict(atom_decoder=self.dataset_info.atom_decoder,
                        bond_decoder=self.dataset_info.bond_decoder)

        clean = graph.to_dense(data=data).to_device(device)
        timesteps = torch.randint(1, self.T+1, size=(len(data),1), device=device)
        noisy, _ = self.apply_noise(clean, t_int=timesteps, transition_model=self.transition_model)
        noisy = graph.apply_mask(orig=clean, z_t=noisy,
                                 mask_nodes=diffusion_cfg.mask_nodes,
                                 mask_edges=diffusion_cfg.mask_edges,
                                 return_masks=False, **decoders)

        # Snapshot taken via get_new_object() before the non-diffused parts are
        # overwritten below; it is what the 'vb' loss receives as z_t.
        noisy_snapshot = noisy.get_new_object()
        if not diffusion_cfg.diffuse_edges:
            noisy.E = clean.E.clone()
        if not diffusion_cfg.diffuse_nodes:
            noisy.X = clean.X.clone()

        use_amp = torch.cuda.is_available() and self.cfg.train.use_mixed_precision
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                pred = self.forward(z_t=noisy)
        else:
            pred = self.forward(z_t=noisy)

        pred, mask_X, mask_E = graph.apply_mask(orig=clean, z_t=pred,
                                                mask_nodes=diffusion_cfg.mask_nodes,
                                                mask_edges=diffusion_cfg.mask_edges,
                                                as_logits=True, return_masks=True,
                                                **decoders)
        if not diffusion_cfg.diffuse_edges:
            pred.E = clean.E.clone()
        if not diffusion_cfg.diffuse_nodes:
            pred.X = clean.X.clone()

        # identify nodes / edges to ignore in masking
        mask_nodes = mask_X.max(-1).values.flatten(0, -1)
        mask_edges = mask_E.max(-1).values.flatten(0, -1)

        loss_name = self.cfg.train.loss
        if loss_name not in ('ce', 'vb'):
            raise ValueError(f'Loss function {self.cfg.train.loss} not recognized.')

        if loss_name == 'vb':
            loss = self.elbo_batch_quick(clean, pred=pred, z_t=noisy_snapshot,
                                         lambda_E=diffusion_cfg.lambda_train[0])
            return torch.zeros((1,)), torch.zeros((1,)), loss

        should_log = (i % self.log_every_steps == 0) and self.log_to_wandb
        loss_X, loss_E, loss = helpers.ce(pred=pred, discrete_dense_true=clean,
                                          diffuse_edges=diffusion_cfg.diffuse_edges,
                                          diffuse_nodes=diffusion_cfg.diffuse_nodes,
                                          lambda_E=diffusion_cfg.lambda_train[0],
                                          log=should_log,
                                          mask_nodes=mask_nodes, mask_edges=mask_edges)
        return loss_X, loss_E, loss

    def apply_noise(self, dense_data, t_int, transition_model):
        """
            Noise a clean batch to the given timesteps and sample the result.

            input:
                dense_data: batch graph object with X, E and node_mask; X must be batched (3 dims).
                t_int: time step for noise, one row per batch element.
                transition_model: object providing get_Qt_bar(t, device=...).
            return:
                (z_t, prob_t): the sampled noisy graph and the probabilities it was sampled from.
        """
        clean_X, clean_E = dense_data.X, dense_data.E
        device = clean_X.device

        assert clean_X.dim()==3, 'Expected X in batch format.'+\
               f' Got X.dim={clean_X.dim()}, If using one example, add batch dimension with: X.unsqueeze(dim=0).'

        # cumulative transition matrices: (bs, dx_in, dx_out), (bs, de_in, de_out)
        Qtb = transition_model.get_Qt_bar(t_int.cpu(), device=device)

        # both matrices need a batch dimension ...
        assert Qtb.X.dim()==3 and Qtb.E.dim()==3, 'Expected Qtb.X and Qtb.E to have ndim=3 ((bs, dx/de, dx/de)) respectively. '+\
                                                  f'Got Qtb.X.dim={Qtb.X.dim()} and Qtb.E.dim={Qtb.E.dim()}.'
        # ... and every row must sum to one
        for cumulative in (Qtb.X, Qtb.E):
            assert torch.all((cumulative.sum(dim=-1) - 1.).abs() < 1e-4)

        # transition probabilities: X -> (bs, n, dx_out), E -> (bs, n, n, de_out)
        transition_probs = dense_data.get_new_object(X=clean_X @ Qtb.X,
                                                     E=clean_E @ Qtb.E.unsqueeze(1),
                                                     y=t_int.float())
        prob_t = transition_probs.mask(dense_data.node_mask)

        z_t = helpers.sample_discrete_features(prob=prob_t)

        same_shape = (clean_X.shape==z_t.X.shape) and (clean_E.shape==z_t.E.shape)
        assert same_shape, 'Noisy and original data do not have the same shape.'

        return z_t, prob_t
    
    def get_pos_encodings(self, z_t):
        if self.cfg.neuralnet.architecture == 'with_y_atommap_number_pos_enc':
            if self.cfg.neuralnet.pos_encoding_type == 'laplacian_pos_enc':
                pos_encodings = self.pos_emb_module.matched_positional_encodings_laplacian(z_t.E.argmax(-1), z_t.atom_map_numbers, z_t.mol_assignments, self.cfg.neuralnet.num_lap_eig_vectors)
            elif self.cfg.neuralnet.pos_encoding_type == 'laplacian_pos_enc_gpu':
                pos_encodings = self.pos_emb_module.matched_positional_encodings_laplacian_gpu(z_t.E.argmax(-1), z_t.atom_map_numbers, z_t.mol_assignments, self.cfg.neuralnet.num_lap_eig_vectors)
            else:
                raise ValueError(f'pos_encoding_type {self.cfg.neuralnet.pos_encoding_type} not recognized')
        elif self.cfg.neuralnet.architecture == 'with_y_stacked':
            model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
            
            suno_number = self.cfg.dataset.atom_types.index("SuNo")
            reaction_side_separation_index = (z_t.X.argmax(-1) == suno_number).nonzero(as_tuple=True)[1]
            
            if self.cfg.neuralnet.pos_encoding_type == 'laplacian_pos_enc_gpu':
                pos_encodings = model.pos_emb_module.matched_positional_encodings_laplacian(z_t.E.argmax(-1), z_t.atom_map_numbers, z_t.mol_assignments, self.cfg.neuralnet.num_lap_eig_vectors)
                pos_encodings, _ = model.cut_reaction_reactant_part_X_only(pos_encodings, reaction_side_separation_index) # ... this could be done in forward as well, more efficient with multiple GPUs. Actually both parts hmm
            elif self.cfg.neuralnet.pos_encoding_type == 'laplacian_pos_enc':
                pos_encodings = model.pos_emb_module.matched_positional_encodings_laplacian_scipy(z_t.E.argmax(-1), z_t.atom_map_numbers, z_t.mol_assignments, self.cfg.neuralnet.num_lap_eig_vectors)
                pos_encodings, _ = model.cut_reaction_reactant_part_X_only(pos_encodings, reaction_side_separation_index)
            else:
                pos_encodings = torch.zeros(z_t.X.shape[0], z_t.X.shape[1], model.input_dim_X, device=z_t.X.device)
                pos_encodings, _ = model.cut_reaction_reactant_part_X_only(pos_encodings, reaction_side_separation_index)
        else:
            pos_encodings = None
        return pos_encodings
    
    def get_pos_encodings_if_relevant(self, z_t):
        if self.cfg.neuralnet.architecture == 'with_y_atommap_number_pos_enc' or 'with_y_stacked': # Precalculate the pos encs, since they are the same for each step in the loop
            pos_encodings = self.get_pos_encodings(z_t)
        else:
            pos_encodings = None
        return pos_encodings

    def forward(self, z_t, pos_encodings=None):
        device = z_t.X.device

        # randomly permute the atom mappings to make sure we don't use them wrongly
        perm = torch.arange(z_t.atom_map_numbers.max().item()+1, device=device)[1:]
        perm = perm[torch.randperm(len(perm))]
        perm = torch.cat([torch.zeros(1, dtype=torch.long, device=device), perm])
        z_t.atom_map_numbers = perm[z_t.atom_map_numbers]
        
        if self.cfg.neuralnet.extra_features:
            with autocast(enabled=False):
                z_t = self.compute_extra_data(z_t)
        if self.cfg.neuralnet.architecture=='with_y_atommap_number_pos_enc':
            assert z_t.mol_assignments is not None, 'molecule_assigments is None in forward()'
            if pos_encodings == None: 
                with autocast(enabled=False):
                    # if pos encs weren't precalculated in the sampling loop
                    pos_encodings = self.get_pos_encodings(z_t)

            if self.cfg.neuralnet.use_ema and not self.training:
                res = self.ema(z_t.X, z_t.E, z_t.y, z_t.node_mask, z_t.atom_map_numbers, pos_encodings, z_t.mol_assignments)
            else:
                res = self.model(z_t.X, z_t.E, z_t.y, z_t.node_mask, z_t.atom_map_numbers, pos_encodings, z_t.mol_assignments)
        elif self.cfg.neuralnet.architecture=='with_y_stacked':
            if self.cfg.neuralnet.use_ema and not self.training:
                res = self.ema(z_t.X, z_t.E, z_t.y, z_t.node_mask, z_t.atom_map_numbers, pos_encodings, z_t.mol_assignments, use_pos_encoding_if_applicable, self.cfg.neuralnet.pos_encoding_type, self.cfg.neuralnet.num_lap_eig_vectors, self.cfg.dataset.atom_types)
            else:
                res = self.model(z_t.X, z_t.E, z_t.y, z_t.node_mask, z_t.atom_map_numbers, pos_encodings, z_t.mol_assignments, use_pos_encoding_if_applicable, self.cfg.neuralnet.pos_encoding_type, self.cfg.neuralnet.num_lap_eig_vectors, self.cfg.dataset.atom_types)
        else:
            if self.cfg.neuralnet.use_ema and not self.training:
                res = self.ema(z_t.X, z_t.E, z_t.y, z_t.node_mask)
            else:
                res = self.model(z_t.X, z_t.E, z_t.y, z_t.node_mask)
        if isinstance(res, tuple):
            X, E, y, node_mask = res
            res = z_t.get_new_object(X=X, E=E, y=y, node_mask=node_mask)
        return res
    
    def compute_Lt_all(self, dense_true):
        '''
            Compute the L_s terms E_{q(x_t|x)} KL[q(x_s|x_t,x_0)||p(x_s|x_t)], with s = t-1,
            for every evaluation step t = 2, ..., diffusion_steps_eval.

            For each step the clean batch is noised with the evaluation transition
            model, masked, passed through the network, and scored with compute_Lt.
            Returns the list of per-step results, each moved to the CPU.
        '''
        device = dense_true.X.device
        batch_size = dense_true.X.shape[0]
        n_eval_steps = self.cfg.diffusion.diffusion_steps_eval

        assert self.T % n_eval_steps == 0, 'diffusion_steps_eval should be divisible by diffusion_steps'
        # ratio between the training and evaluation step counts
        eval_step_size = self.T // n_eval_steps

        # computed once from the clean data and reused for every step
        pos_encodings = self.get_pos_encodings_if_relevant(dense_true)

        Lts = []
        # steps 0 and 1 are skipped here
        for t in range(2, n_eval_steps + 1):
            t_int = torch.ones((batch_size, 1)).to(device) * t
            noisy, _ = self.apply_noise(dense_true, t_int=t_int, transition_model=self.transition_model_eval)
            noisy.y *= eval_step_size # Adjust the neural net input to the correct range

            noisy = graph.apply_mask(orig=dense_true, z_t=noisy,
                                     atom_decoder=self.dataset_info.atom_decoder,
                                     bond_decoder=self.dataset_info.bond_decoder,
                                     mask_nodes=self.cfg.diffusion.mask_nodes,
                                     mask_edges=self.cfg.diffusion.mask_edges,
                                     return_masks=False)

            prediction = self.forward(z_t=noisy, pos_encodings=pos_encodings)

            # KL between q(x_{t-1}|x_t, x_0) and the model posterior, for X and E
            step_loss = self.compute_Lt(dense_true=dense_true, z_t=noisy, t=t_int,
                                        x_0_tilde_logit=prediction,
                                        transition_model=self.transition_model_eval)
            step_loss.to_device('cpu')
            Lts.append(step_loss)

        return Lts

    def compute_Lt(self, dense_true, z_t, t, x_0_tilde_logit, transition_model, log=False):
        """Compute the KL term between the true posterior and the model posterior at step t.

        The true posterior q(x_s | x_t, x_0) with s = t-1 is built from ``dense_true``
        and ``z_t``; the model posterior p(x_s | x_t) is built the same way but with
        the network prediction in place of the clean data.

        input:
            dense_true: clean batch; X is unpacked as (bs, n, v) and E has e classes on its last axis.
            z_t: noisy batch at step t.
            t: timesteps, must have shape (bs, 1).
            x_0_tilde_logit: network output with logits in .X and .E.
            transition_model: object providing get_Qt and get_Qt_bar.
            log: when False, the logits are softmaxed first and self.eps is added
                 before taking the log for the KL; when True, the posterior helper
                 is called with log=True and its output is passed through log_softmax.
        return:
            Graph object created with z_t.get_new_object whose X and E hold the KL
            summed over the class axis.
        """
        assert t.shape[1]==1, 't should be a tensor of shape (bs, 1)'
        bs, n, v = dense_true.X.shape
        e = dense_true.E.shape[-1]
        device = z_t.X.device
        s = t - 1

        # One-step matrix at t, and cumulative matrices at t and at s = t-1.
        Qt = transition_model.get_Qt(t, device=device)
        Qtb = transition_model.get_Qt_bar(t, device=device) 
        Qsb = transition_model.get_Qt_bar(s, device=device)
        
        # compute q(x_{t-1}|x_t) = q(x_{t-1}|x_t, x_0)
        q_s_given_t_0_X = helpers.compute_posterior_distribution(M=dense_true.X, M_t=z_t.X, Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X)
        q_s_given_t_0_X = q_s_given_t_0_X.reshape(bs, n, v)
        q_s_given_t_0_E = helpers.compute_posterior_distribution(M=dense_true.E, M_t=z_t.E, Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E)
        q_s_given_t_0_E = q_s_given_t_0_E.reshape(bs, n, n, e)
        q_s_given_t_0 = z_t.get_new_object(X=q_s_given_t_0_X, E=q_s_given_t_0_E)

        if log==False:
            # Probability-space path: turn logits into probabilities, then form the model posterior.
            x_0_tilde = z_t.get_new_object(X=F.softmax(x_0_tilde_logit.X, dim=-1), E=F.softmax(x_0_tilde_logit.E, dim=-1))
            p_s_given_t_X = helpers.compute_posterior_distribution(M=x_0_tilde.X, M_t=z_t.X, Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X)        
            p_s_given_t_X = p_s_given_t_X.reshape(bs, n, v)
            p_s_given_t_E = helpers.compute_posterior_distribution(M=x_0_tilde.E, M_t=z_t.E, Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E)
            p_s_given_t_E = p_s_given_t_E.reshape(bs, n, n, e)
            p_s_given_t = z_t.get_new_object(X=p_s_given_t_X, E=p_s_given_t_E)
            p_s_given_t = graph.apply_mask(orig=dense_true, z_t=p_s_given_t, 
                                            atom_decoder=self.dataset_info.atom_decoder,
                                            bond_decoder=self.dataset_info.bond_decoder, 
                                            mask_nodes=self.cfg.diffusion.mask_nodes, 
                                            mask_edges=self.cfg.diffusion.mask_edges,
                                            node_states_to_mask=self.cfg.diffusion.node_states_to_mask, 
                                            edge_states_to_mask=self.cfg.diffusion.edge_states_to_mask)
        else:
            # Log-space path: the raw logits go to the helper with log=True and the mask is applied as logits.
            log_p_s_given_t_X = helpers.compute_posterior_distribution(M=x_0_tilde_logit.X, M_t=z_t.X, Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X, log=True)
            log_p_s_given_t_X = log_p_s_given_t_X.reshape(bs, n, v)
            log_p_s_given_t_E = helpers.compute_posterior_distribution(M=x_0_tilde_logit.E, M_t=z_t.E, Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E, log=True)
            log_p_s_given_t_E = log_p_s_given_t_E.reshape(bs, n, n, e)
            log_p_s_given_t = z_t.get_new_object(X=log_p_s_given_t_X, E=log_p_s_given_t_E)
            log_p_s_given_t = graph.apply_mask(orig=dense_true, z_t=log_p_s_given_t, 
                                                atom_decoder=self.dataset_info.atom_decoder,
                                                bond_decoder=self.dataset_info.bond_decoder, 
                                                mask_nodes=self.cfg.diffusion.mask_nodes, 
                                                mask_edges=self.cfg.diffusion.mask_edges,
                                                node_states_to_mask=self.cfg.diffusion.node_states_to_mask, 
                                                edge_states_to_mask=self.cfg.diffusion.edge_states_to_mask, 
                                                as_logits=True)

        # The true posterior gets the same masking as the model posterior.
        q_s_given_t_0 = graph.apply_mask(orig=dense_true, z_t=q_s_given_t_0, 
                                         atom_decoder=self.dataset_info.atom_decoder,
                                         bond_decoder=self.dataset_info.bond_decoder, 
                                         mask_nodes=self.cfg.diffusion.mask_nodes, 
                                         mask_edges=self.cfg.diffusion.mask_edges,
                                         node_states_to_mask=self.cfg.diffusion.node_states_to_mask, 
                                         edge_states_to_mask=self.cfg.diffusion.edge_states_to_mask)
        
        # compute KL(true||pred) = KL(target||input)
        # F.kl_div expects `input` as log-probabilities and `target` as probabilities.
        if log == False:
            kl_x = F.kl_div(input=(p_s_given_t.X+self.eps).log(), target=q_s_given_t_0.X, reduction='none').sum(-1)
            kl_e = F.kl_div(input=(p_s_given_t.E+self.eps).log(), target=q_s_given_t_0.E, reduction='none').sum(-1)
        else:
            kl_x = F.kl_div(input=torch.log_softmax(log_p_s_given_t.X, -1), target=q_s_given_t_0.X, reduction='none').sum(-1)
            kl_e = F.kl_div(input=torch.log_softmax(log_p_s_given_t.E, -1), target=q_s_given_t_0.E, reduction='none').sum(-1)

        Lt = z_t.get_new_object(X=kl_x, E=kl_e)
        
        return Lt

    def compute_L1(self, dense_true, pos_encodings=None):
        """Reconstruction term: score the clean data under the prediction made from a t=1 sample.

        input:
            dense_true: clean batch.
            pos_encodings: optional precomputed positional encodings, forwarded to forward().
        return:
            The result of helpers.reconstruction_logp on the clean data and the
            masked, log-softmaxed prediction.
        """
        dataset_info = self.dataset_info
        diffusion_cfg = self.cfg.diffusion

        batch_size = dense_true.X.shape[0]
        first_step = torch.ones((batch_size, 1), device=dense_true.X.device)

        noisy, _ = self.apply_noise(dense_true, t_int=first_step, transition_model=self.transition_model_eval)
        noisy = graph.apply_mask(orig=dense_true, z_t=noisy,
                                 atom_decoder=dataset_info.atom_decoder,
                                 bond_decoder=dataset_info.bond_decoder,
                                 mask_nodes=diffusion_cfg.mask_nodes,
                                 mask_edges=diffusion_cfg.mask_edges,
                                 return_masks=False)

        n_eval_steps = diffusion_cfg.diffusion_steps_eval
        assert self.T % n_eval_steps == 0, 'diffusion_steps_eval should be divisible by diffusion_steps'
        # Adjust the neural net input to the correct range
        noisy.y *= self.T // n_eval_steps

        logits = self.forward(z_t=noisy, pos_encodings=pos_encodings)
        logits = graph.apply_mask(orig=dense_true, z_t=logits,
                                  atom_decoder=dataset_info.atom_decoder,
                                  bond_decoder=dataset_info.bond_decoder,
                                  mask_nodes=diffusion_cfg.mask_nodes,
                                  mask_edges=diffusion_cfg.mask_edges,
                                  node_states_to_mask=diffusion_cfg.node_states_to_mask,
                                  edge_states_to_mask=diffusion_cfg.edge_states_to_mask,
                                  as_logits=True)
        # normalise the masked logits into log-probabilities over the class axis
        logits.X = F.log_softmax(logits.X, -1)
        logits.E = F.log_softmax(logits.E, -1)

        return helpers.reconstruction_logp(orig=dense_true, pred_t0=logits)

    def kl_prior(self, dense_true):
        """Prior term: compare the data noised all the way to step T with the limit distribution.

        input:
            dense_true: clean batch with X and E.
        return:
            The result of helpers.kl_prior on the masked step-T probabilities and
            the masked, broadcast limit distribution.
        """
        X, E = dense_true.X, dense_true.E
        bs, n = X.shape[0], X.shape[1]
        device = X.device

        def _masked(graph_obj):
            # identical masking for both distributions
            return graph.apply_mask(orig=dense_true, z_t=graph_obj,
                                    atom_decoder=self.dataset_info.atom_decoder,
                                    bond_decoder=self.dataset_info.bond_decoder,
                                    mask_nodes=self.cfg.diffusion.mask_nodes,
                                    mask_edges=self.cfg.diffusion.mask_edges,
                                    node_states_to_mask=self.cfg.diffusion.node_states_to_mask,
                                    edge_states_to_mask=self.cfg.diffusion.edge_states_to_mask)

        # probabilities of the data after T noising steps
        final_step = self.T * torch.ones((bs, 1), device=device)
        Qtb = self.transition_model.get_Qt_bar(final_step, device)
        prob_X = X @ Qtb.X  # (bs, n, dx_out)
        prob_E = E @ Qtb.E.unsqueeze(1)  # (bs, n, n, de_out)
        assert prob_X.shape == X.shape
        prob = dense_true.get_new_object(X=prob_X, E=prob_E)

        # limit distribution broadcast to every node and every edge
        limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(X)
        limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(E)
        limit = dense_true.get_new_object(X=limit_X, E=limit_E)

        limit = _masked(limit)
        prob = _masked(prob)

        return helpers.kl_prior(prior=prob, limit=limit, eps=self.eps)
    
    @torch.no_grad()
    def get_elbo_of_data(self, dataloader, n_samples, device):
        """
        Average the first value returned by ``self.elbo`` over enough batches to cover ``n_samples``.

        Args:
            dataloader (torch.utils.data.DataLoader): A PyTorch DataLoader object that provides batches of data.
            n_samples (int): The number of samples to use for the estimate; rounded up
                to whole batches, with at least one batch evaluated.
            device: Device each batch is moved to before evaluation.

        Returns:
            The per-batch value averaged over the evaluated batches. Per the original
            note, this is the negative ELBO, i.e. an upper bound on the NLL.
        """
        batch_size = graph.get_batch_size_of_dataloader(dataloader)

        # ceil(n_samples / batch_size), but never fewer than one batch
        full_batches, leftover = divmod(n_samples, batch_size)
        num_batches = max(full_batches + int(leftover > 0), 1)
        assert num_batches<=len(dataloader), 'ELBO: testing more batches than is available in the dataset.'
        log.info(f"Num of batches needed for ELBO estimation: {num_batches}")

        total_elbo = 0
        batches = iter(dataloader)
        for _ in range(num_batches):
            batch = next(batches).to(device)
            dense_true = graph.to_dense(data=batch).to_device(device)
            batch_elbo, _, _ = self.elbo(dense_true)
            total_elbo += batch_elbo

        total_elbo /= num_batches

        # returning negative elbo as an upper bound on NLL
        return total_elbo
    
    def elbo_batch_quick(self, dense_true, pred, z_t, lambda_E=1.0, avg_over_batch=True):
        """
        Computes an estimator for the variational lower bound, but sampled such that we only
        get a single timestep t for each batch element. This makes it possible to train the model
        as well.

        Batch elements with t == 1 contribute the (negated) reconstruction term, all
        others contribute the KL term from compute_Lt. The combined term is scaled by
        cfg.diffusion.diffusion_steps and the prior KL is added.

        Side effect: pred.X and pred.E are overwritten in place with their log-softmax.

        input:
            dense_true: a batch of clean data.
            pred: the prediction of the model for the given z_t (logits in .X and .E).
            z_t: sampled data at some timestep t (contains that in z_t.y).
            lambda_E: weight for the E term in the loss.
            avg_over_batch: forwarded to helpers.mean_without_masked; whether to average over the batch or not.

        output:
            the ELBO estimate for the given data batch, as returned by combining
            the outputs of helpers.mean_without_masked.
        """
        t = z_t.y
        device = dense_true.X.device

        # Prior term
        # pred = self.forward(z_t=z_t)
        # If the transition matrix goes to identity as t->0, then this works for all steps, including t=1
        term_t = self.compute_Lt(dense_true=dense_true, z_t=z_t, t=t, x_0_tilde_logit=pred,
                               transition_model=self.transition_model, log=True) # Placeholder object, with, e.g., X of shape (batch_size, n, total_features)
        #loss *= self.cfg.diffusion.diffusion_steps # scale the estimator to the full ELBO
        kl_prior = self.kl_prior(dense_true) # Should be zero

        # TODO: This is not really quite right... the z_t is sampled from noise, the reconstruction loss is now just another CE loss
        # NOTE: this mutates the caller's `pred` object (logits become log-probabilities).
        pred.X, pred.E = F.log_softmax(pred.X, dim=-1), F.log_softmax(pred.E, dim=-1)
        term_1 = helpers.reconstruction_logp(orig=dense_true, pred_t0=pred)
        # Negate the log-probabilities so the term is a loss like the KL terms.
        term_1.X, term_1.E = -term_1.X, -term_1.E

        # Manually weight the E term for training
        term_t.E, term_1.E, kl_prior.E = term_t.E * lambda_E, term_1.E * lambda_E, kl_prior.E * lambda_E

        # Combine the loss 1 term and the t terms
        terms_t_1 = term_t.get_new_object(X=torch.zeros_like(term_1.X), E=torch.zeros_like(term_1.E))
        #terms_t_1 = graph.PlaceHolder(X = torch.zeros_like(term_1.X), E = torch.zeros_like(term_1.E), y = term_t.y, node_mask = term_t.node_mask)
        # Repeat t along the node axis so it can be used as a boolean index into the per-node terms.
        expanded_t = t.repeat([1, dense_true.X.shape[1]])
        # Where t == 1 take the reconstruction term, everywhere else the KL term.
        terms_t_1.X[expanded_t==1], terms_t_1.E[expanded_t==1] = term_1.X[expanded_t==1], term_1.E[expanded_t==1]
        terms_t_1.X[expanded_t!=1], terms_t_1.E[expanded_t!=1] = term_t.X[expanded_t!=1], term_t.E[expanded_t!=1]

        # Ignore padding & conditioning nodes in the averaging of the losses
        _, mask_X, mask_E = graph.apply_mask(orig=dense_true, z_t=z_t,
                                             atom_decoder=self.dataset_info.atom_decoder,
                                             bond_decoder=self.dataset_info.bond_decoder,
                                             mask_nodes=self.cfg.diffusion.mask_nodes,
                                             mask_edges=self.cfg.diffusion.mask_edges, 
                                             return_masks=True)
        
        terms_t_1 = helpers.mean_without_masked(graph_obj=terms_t_1, mask_X=mask_X, mask_E=mask_E,
                                                diffuse_edges=self.cfg.diffusion.diffuse_edges, 
                                                diffuse_nodes=self.cfg.diffusion.diffuse_nodes,
                                                avg_over_batch=avg_over_batch)
        kl_prior = helpers.mean_without_masked(graph_obj=kl_prior, mask_X=mask_X, mask_E=mask_E,
                                               diffuse_edges=self.cfg.diffusion.diffuse_edges, 
                                               diffuse_nodes=self.cfg.diffusion.diffuse_nodes, 
                                               avg_over_batch=avg_over_batch)

        # Scale the single-timestep term by the number of diffusion steps, then add the prior term.
        elbo = terms_t_1 * self.cfg.diffusion.diffusion_steps + kl_prior

        return elbo

    def elbo(self, dense_true, avg_over_batch=True):
        """
        Computes an estimator for the variational lower bound.

        input:
            dense_true: a batch of data in dense graph format (an object exposing
                        at least the .X tensor and accepted by graph.apply_mask).
            avg_over_batch: forwarded unchanged to helpers.mean_without_masked.

        output:
            a tuple (vb, loss_t_per_dim, loss_0_per_dim) where
                vb             = kl_prior + sum_t L_t - L_0,
                loss_t_per_dim = the summed diffusion terms (zeros if there are none),
                loss_0_per_dim = the masked mean of the compute_L1 term.
        """

        # The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero.
        kl_prior = self.kl_prior(dense_true)
        kl_prior.to_device('cpu') # move everything to CPU to avoid memory issues when computing all the Lt terms

        # All the KL(q(z_{t-1}|z_t, z_0) || p(z_{t-1}|z_t)) terms in the diffusion model
        loss_all_t = self.compute_Lt_all(dense_true)
        for loss_t in loss_all_t:
            # Edge terms are re-weighted with the test-time lambda
            loss_t.E = loss_t.E * self.cfg.diffusion.lambda_test

        # Masks of the entries that must be ignored when averaging the losses
        _, mask_X, mask_E = graph.apply_mask(orig=dense_true, z_t=dense_true,
                                             atom_decoder=self.dataset_info.atom_decoder,
                                             bond_decoder=self.dataset_info.bond_decoder,
                                             mask_nodes=self.cfg.diffusion.mask_nodes,
                                             mask_edges=self.cfg.diffusion.mask_edges,
                                             return_masks=True)
        mask_X = mask_X.to('cpu')
        mask_E = mask_E.to('cpu')

        # Reconstruction term, obtained from compute_L1.
        # NOTE(review): it is subtracted below, so compute_L1 presumably returns a
        # log-likelihood rather than a negative log-likelihood -- confirm.
        pos_encodings = self.get_pos_encodings_if_relevant(dense_true)
        loss_term_0 = self.compute_L1(dense_true=dense_true, pos_encodings=pos_encodings)
        loss_term_0.to_device('cpu')
        loss_term_0.E *= self.cfg.diffusion.lambda_test

        # Every term is reduced with the same masks and the same averaging options
        reduce_kwargs = dict(mask_X=mask_X, mask_E=mask_E,
                             diffuse_edges=self.cfg.diffusion.diffuse_edges,
                             diffuse_nodes=self.cfg.diffusion.diffuse_nodes,
                             avg_over_batch=avg_over_batch)

        loss_0_per_dim = helpers.mean_without_masked(graph_obj=loss_term_0, **reduce_kwargs)
        kl_prior_per_dim = helpers.mean_without_masked(graph_obj=kl_prior, **reduce_kwargs)
        per_step_losses = [helpers.mean_without_masked(graph_obj=loss_t, **reduce_kwargs)
                           for loss_t in loss_all_t]

        if per_step_losses:
            loss_t_per_dim = sum(per_step_losses)
        else:
            # No diffusion terms: contribute an explicit zero. (torch.empty_like would
            # inject uninitialised memory into the bound.)
            loss_t_per_dim = torch.zeros_like(kl_prior_per_dim, dtype=torch.float)

        vb = kl_prior_per_dim + loss_t_per_dim - loss_0_per_dim

        return vb, loss_t_per_dim, loss_0_per_dim

    @torch.no_grad()
    def sample_one_batch(self, device=None, n_samples=None, data=None, get_chains=False, get_true_rxns=False, inpaint_node_idx=None, inpaint_edge_idx=None):
        """
        Runs the reverse (denoising) chain once and returns the sampled batch.

        input:
            device: device on which the node mask and the time-step tensors are placed.
            n_samples: number of graphs to sample when no data is given; graph sizes are
                       then drawn from self.node_dist.
            data: optional dense batch used for conditioning (masking, inpainting and
                  non-diffused parts are copied from it).
            get_chains: if True, intermediate samples, p(z_s | z_t) and the network
                        predictions are collected and returned as well.
            get_true_rxns: if True, the conditioning data is returned as a fifth element.
                           Requires data.
            inpaint_node_idx, inpaint_edge_idx: forwarded to graph.fix_nodes_and_edges_by_idx
                           to select entries that are copied from data at every step.

        output:
            z_t                                                   if neither flag is set,
            (sample, sample_chains, prob_s_chains, pred_0_chains) if get_chains,
            the same tuple with data appended                     if get_true_rxns
            (the chain lists are empty when get_chains is False).
        """
        has_data = data is not None
        assert has_data or n_samples is not None, 'You need to give either data or n_samples.'
        assert has_data or self.cfg.diffusion.mask_nodes is None, 'You need to give data if the model is using a mask.'
        # true_rxns can only be returned when data was given (the original check was inverted)
        assert has_data or not get_true_rxns, 'You need to give data if you want to return true_rxns.'

        if has_data:
            dense_data = data
            node_mask = dense_data.node_mask.to(device)
        else:
            n_nodes = self.node_dist.sample_n(n_samples, device)
            n_max = torch.max(n_nodes).item()
            arange = torch.arange(n_max, device=device).unsqueeze(0).expand(n_samples, -1)
            node_mask = arange < n_nodes.unsqueeze(1)
            dense_data = None

        diffuse_edges = self.cfg.diffusion.diffuse_edges
        diffuse_nodes = self.cfg.diffusion.diffuse_nodes

        def _condition(obj, current_z):
            """Re-impose on obj everything that must stay fixed: the configured masks,
            the inpainted entries and the non-diffused parts of the data."""
            obj = graph.apply_mask(orig=dense_data if has_data else current_z, z_t=obj,
                                   atom_decoder=self.dataset_info.atom_decoder,
                                   bond_decoder=self.dataset_info.bond_decoder,
                                   mask_nodes=self.cfg.diffusion.mask_nodes,
                                   mask_edges=self.cfg.diffusion.mask_edges)
            if has_data:
                # Inpainting and clamping copy values from the data, so they only
                # make sense (and only work) when data was given.
                mask_X, mask_E = graph.fix_nodes_and_edges_by_idx(data=dense_data, node_idx=inpaint_node_idx,
                                                                  edge_idx=inpaint_edge_idx)
                obj.X[mask_X] = dense_data.X[mask_X]
                obj.E[mask_E] = dense_data.E[mask_E]
                if not diffuse_edges: obj.E = dense_data.E.clone()
                if not diffuse_nodes: obj.X = dense_data.X.clone()
            return obj

        pos_encodings = self.get_pos_encodings_if_relevant(dense_data) # precalculate the pos encodings, since they are the same at each step

        z_t = helpers.sample_from_noise(limit_dist=self.limit_dist, node_mask=node_mask, T=self.T)
        if has_data: z_t = dense_data.get_new_object(X=z_t.X, E=z_t.E, y=z_t.y)
        z_t = _condition(z_t, z_t)

        # Always defined, so that get_true_rxns without get_chains returns empty chains
        # instead of failing on unbound names.
        sample_chains, prob_s_chains, pred_0_chains = [], [], []

        n_eval_steps = self.cfg.diffusion.diffusion_steps_eval
        log.debug('self.T %s', self.T)
        log.debug('self.cfg.diffusion.diffusion_steps_eval %s', n_eval_steps)
        assert self.T % n_eval_steps == 0, 'diffusion_steps should be divisible by diffusion_steps_eval'
        eval_step_size = self.T // n_eval_steps

        # Walk the chain backwards: (t, s) = (n, n-1), ..., (1, 0)
        for s_int in reversed(range(n_eval_steps)):
            t_int = s_int + 1
            t_array = t_int * torch.ones((z_t.X.shape[0], 1)).long().to(device)

            # The network is conditioned on the time step, rescaled to the training resolution
            z_t.y = t_array.clone().float() * eval_step_size

            # compute p(x | z_t)
            pred = self.forward(z_t=z_t, pos_encodings=pos_encodings)

            # compute p(z_s | z_t) (denoiser)
            prob_s = helpers.get_p_zs_given_zt(transition_model=self.transition_model_eval, t_array=t_array, pred=pred, z_t=z_t, return_prob=True)

            if has_data:
                if not diffuse_edges: prob_s.E = dense_data.E.clone()
                if not diffuse_nodes: prob_s.X = dense_data.X.clone()

            # save chains if relevant
            # NOTE(review): s_int never exceeds diffusion_steps_eval-1, so the second
            # clause only fires when diffusion_steps_eval == T -- confirm the intent.
            if get_chains and (s_int % self.cfg.train.log_every_t == 0 or s_int == self.T - 1):
                # turn pred from logits to probabilities for plotting
                pred.X = F.softmax(pred.X, dim=-1)
                pred.E = F.softmax(pred.E, dim=-1)
                pred.X[..., -1] = 0.
                pred.X /= pred.X.sum(-1).unsqueeze(-1)

                pred = _condition(pred, z_t)
                # save p(z_s | z_t)
                prob_s = _condition(prob_s, z_t)

                sample_chains.append((s_int+1, z_t.mask(z_t.node_mask, collapse=True)))
                prob_s_chains.append((s_int, prob_s))
                pred_0_chains.append((s_int, pred))

            # sample from p(z_s | z_t)
            z_t = helpers.sample_discrete_features(prob=prob_s)

            # sanity check
            assert (z_t.E==torch.transpose(z_t.E, 1, 2)).all(), 'E is not symmetric.'

            z_t = _condition(z_t, z_t)

        final_node_mask = None
        if get_chains:
            sample = copy.deepcopy(z_t)
            sample_chains.append((0, sample.mask(sample.node_mask, collapse=True)))
            final_node_mask = sample.node_mask
        elif get_true_rxns:
            final_node_mask = z_t.node_mask

        if get_true_rxns:
            return (z_t.mask(final_node_mask, collapse=True), sample_chains, prob_s_chains, pred_0_chains, dense_data)

        if get_chains:
            return (z_t.mask(final_node_mask, collapse=True), sample_chains, prob_s_chains, pred_0_chains)

        return z_t

    def compute_extra_data(self, noisy_data):
        """ Build the network input for a noisy graph: the graph's own X / E / y with the
            extra features and the extra molecular features concatenated on the last
            dimension. Meant to be called at every training step (after adding noise)
            and at every sampling step. """

        from src.neuralnet.extra_features import ExtraFeatures
        from src.neuralnet.extra_features_molecular import ExtraMolecularFeatures

        # NOTE(review): both calculators are rebuilt and re-assigned on every call.
        self.extra_features_calculator = ExtraFeatures('all', self.dataset_info)
        self.extra_molecular_features_calculator = ExtraMolecularFeatures(self.dataset_info)

        target_device = noisy_data.X.device

        def _as_graph(features):
            # Move a (X, E, y) triple back to the input's device and wrap it like noisy_data
            feat_X, feat_E, feat_y = features
            return noisy_data.get_new_object(X=feat_X.to(target_device),
                                             E=feat_E.to(target_device),
                                             y=feat_y.to(target_device))

        # Doing the calculations on CPU to avoid numerical issues with AMD GPUs
        structural = _as_graph(self.extra_features_calculator(noisy_data.E.to('cpu'),
                                                              noisy_data.node_mask.to('cpu')))
        molecular = _as_graph(self.extra_molecular_features_calculator(noisy_data.X.to('cpu'),
                                                                       noisy_data.E.to('cpu')))

        # Original features first, then the two groups of extra features
        sources = (noisy_data, structural, molecular)
        return noisy_data.get_new_object(X=torch.cat([g.X for g in sources], dim=-1),
                                         E=torch.cat([g.E for g in sources], dim=-1),
                                         y=torch.cat([g.y for g in sources], dim=-1))