import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import time
import os
import wandb
import math

from models.transformer import Transformer
from diffusion.noise_schedule import NoiseScheduleDiscrete, MarginalTransition

from diffusion import diffusion_utils
from metrics.train_loss import TrainLossDiscrete
from metrics.train_loss import ValidateLoss
from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
import utils
from torch_geometric.utils import to_dense_batch

class GraphDiffusionTransformer(pl.LightningModule):
    def __init__(
        self,
        cfg,
        dataset_infos,
        sampling_metrics,
        visualization_tools,
        tokenizer
    ):
        """Build the context-conditioned discrete graph diffusion model.

        Args:
            cfg: configuration object with ``general``, ``model`` and ``train``
                sections (``cfg.general.get`` is used, so presumably an
                OmegaConf/Hydra config — TODO confirm).
            dataset_infos: dataset statistics; this constructor reads
                ``nodes_dist``, ``max_node_type``, ``max_bond_type``,
                ``max_position_type``, ``token_to_atom_count``, ``max_n_nodes``,
                the count tensors ``node_types`` / ``edge_types`` / ``pos_types``
                and ``co_occur_dist`` (a tensor supporting ``.to_dense()``).
            sampling_metrics: callable used to score generated samples; excluded
                from the saved hyperparameters.
            visualization_tools: object used by the validation hook to render
                generated molecules (``visualize_by_smiles``).
            tokenizer: forwarded to the train and validation loss modules.
        """
        super().__init__()
        # sampling_metrics is kept out of the checkpointed hyperparameters.
        self.save_hyperparameters(ignore=['sampling_metrics'])
        self.test_only = cfg.general.test_only
        
        self.dataset_info = dataset_infos

        nodes_dist = self.dataset_info.nodes_dist
        max_node_type = self.dataset_info.max_node_type
        max_bond_type = self.dataset_info.max_bond_type
        max_position_type = self.dataset_info.max_position_type

        self.cfg = cfg
        self.name = cfg.general.name
        self.T = cfg.model.diffusion_steps  # number of diffusion steps
        self.guide_scale = cfg.model.guide_scale
        self.no_context_train = cfg.train.no_context_train

        self.node_dist = nodes_dist
        self.max_node_type = max_node_type
        self.max_bond_type = max_bond_type
        self.max_position_type = max_position_type

        self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train, tokenizer=tokenizer)
        self.validate_loss = ValidateLoss(tokenizer=tokenizer)

        self.noise_weight = cfg.train.noise_weight
        # Probability of dropping the context in forward() when not forced.
        self.drop_context = cfg.train.drop_context
        self.token_to_atom_count = self.dataset_info.token_to_atom_count

        # Validation accumulators for the variational bound terms.
        self.val_nll = NLL()
        self.val_X_kl = SumExceptBatchKL()
        self.val_E_kl = SumExceptBatchKL()
        self.val_X_logp = SumExceptBatchMetric()
        self.val_E_logp = SumExceptBatchMetric()

        # Test-time counterparts of the validation accumulators.
        self.test_nll = NLL()
        self.test_X_kl = SumExceptBatchKL()
        self.test_E_kl = SumExceptBatchKL()
        self.test_X_logp = SumExceptBatchMetric()
        self.test_E_logp = SumExceptBatchMetric()

        self.visualization_tools = visualization_tools
        # Node capacity is bounded by both the dataset maximum and the training context length.
        self.max_n_nodes = min(self.dataset_info.max_n_nodes, cfg.train.context_length)

        self.sampling_metrics = sampling_metrics
        # Sampling batch size during validation equals the training batch size.
        self.sample_size = int(1 * self.cfg.train.batch_size)
        self.final_sample_size = self.cfg.general.get(
            "final_sample_size", self.sample_size
        )

        # NOTE: submodules are registered in assignment order; reordering them changes the
        # order of self.parameters() and therefore the layout of saved optimizer state.
        self.model = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=cfg.model.hidden_size,
            depth=cfg.model.depth,
            num_heads=cfg.model.num_heads,
            mlp_ratio=cfg.model.mlp_ratio,
            drop_condition=cfg.model.drop_condition,
            X_dim=max_node_type, 
            E_dim=max_bond_type,
            y_dim=None,
            pos_dim=self.max_position_type,
            )
        # dtype captured at construction time; later used to cast inputs.
        self.model_dtype = self.dtype

        self.noise_schedule = NoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
                                                    timesteps=cfg.model.diffusion_steps)


        x_marginals = self.dataset_info.node_types.float()
        e_marginals = self.dataset_info.edge_types.float()
        pos_marginals = self.dataset_info.pos_types.float()

        # Normalize the counts into probability distributions.
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()
        pos_marginals = pos_marginals / pos_marginals.sum()
        self.pos_marginals = pos_marginals
        self.x_marginals = x_marginals
        self.e_marginals = e_marginals

        # Conditional distributions between node and edge types, derived from co_occur_dist.
        xe_conditions = self.dataset_info.co_occur_dist
        xe_conditions = xe_conditions.to_dense().float()
        # if constant
        # xe_conditions = torch.ones_like(xe_conditions)
        
        # Symmetrize over the first two axes, then sum out axis 1; the result is indexed
        # [axis 0, last axis] and its transpose [last axis, axis 0].
        xe_conditions = (xe_conditions + xe_conditions.transpose(0, 1)) / 2
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        # Row-normalize; epsilon avoids division by zero for all-zero rows.
        epsilon = 1e-10
        xe_conditions = xe_conditions / (xe_conditions.sum(dim=-1, keepdim=True) + epsilon)
        ex_conditions = ex_conditions / (ex_conditions.sum(dim=-1, keepdim=True) + epsilon)

    
        self.transition_model = MarginalTransition(x_marginals=x_marginals, 
                                                e_marginals=e_marginals, 
                                                xe_conditions=xe_conditions,
                                                ex_conditions=ex_conditions,
                                                pos_marginals=pos_marginals,
                                                n_nodes=self.max_n_nodes)

        # Limit (prior) distribution of the discrete diffusion: the normalized marginals.
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None, pos=pos_marginals)

        self.start_epoch_time = None
        self.train_iterations = None
        self.val_iterations = None
        self.log_every_steps = cfg.general.log_every_steps
        self.number_chain_steps = cfg.general.number_chain_steps

        # Sentinel so the first validation NLL always becomes the best value.
        self.best_val_nll = 1e8
        self.val_counter = 0
        self.batch_size = self.cfg.train.batch_size

    def forward(self, noisy_target, dense_context_dict=None, unconditioned=None):
        """Run the denoising transformer on a noisy target graph, optionally with context.

        Args:
            noisy_target: dict as produced by ``apply_noise`` (or built by hand, as in
                ``reconstruction_logp``); this method reads ``'X_t'``, ``'E_t'``,
                ``'pos_t'``, ``'node_mask'`` and ``'t'``.
            dense_context_dict: dict as produced by ``batch_data_process``. When given
                and the call is conditioned, the noisy target is written back into the
                dense batch of all molecules, and the molecules of each context are packed
                into one node sequence.
            unconditioned: ``True`` forces the context to be ignored and ``False`` forces
                it to be used. ``None`` drops it when ``no_context_train`` is set, or
                otherwise with probability ``drop_context``. The probabilistic drop is
                presumably for classifier-free guidance (cf. ``cfg.model.guide_scale``).

        Returns:
            utils.PlaceHolder with the raw model outputs for X, E and pos (y is None).
        """
        if unconditioned is None:
            # The random draw is always made, so RNG consumption does not depend on config.
            rand = torch.rand(1).item()
            unconditioned = self.no_context_train or rand < self.drop_context

        if dense_context_dict is not None and not unconditioned:
            relation_to_target = dense_context_dict['relation_to_target']
            indicator_node_to_mol = dense_context_dict['indicator_node_to_mol']
            indicator_mol_to_ctx = dense_context_dict['indicator_mol_to_ctx']
            node_mask_allmol = dense_context_dict['node_mask_allmol']
            node_mask_target = dense_context_dict['node_mask_target']
            dense_data = dense_context_dict['dense_data']

            # Write the noisy target back into the dense batch of all molecules.
            noisy_data = utils.replace_first_indices(dense_data, indicator_mol_to_ctx, noisy_target)
            # Lift molecule-level indices and relation labels to node level.
            indicator_node_to_ctx = indicator_mol_to_ctx[indicator_node_to_mol]
            relation_to_target = relation_to_target[indicator_node_to_mol].view(-1, 1)

            # Pack the real nodes of every molecule in a context into one sequence per
            # context. Each node's edge/position row (n slots of its own molecule) is
            # flattened so it can be packed the same way as node features.
            noisy_X, node_mask_context = to_dense_batch(x=noisy_data.X[node_mask_allmol], batch=indicator_node_to_ctx)
            bs, n, _, _ = noisy_data.E.size()
            noisy_E = noisy_data.E.view(bs, n, n * self.max_bond_type)[node_mask_allmol]
            noisy_E, _ = to_dense_batch(x=noisy_E, batch=indicator_node_to_ctx)
            noisy_pos = noisy_data.pos.view(bs, n, n * self.max_position_type)[node_mask_allmol]
            noisy_pos, _ = to_dense_batch(x=noisy_pos, batch=indicator_node_to_ctx)

            relation_to_target, _ = to_dense_batch(x=relation_to_target, batch=indicator_node_to_ctx)
            # NOTE(review): this branch casts with .float(), while the unconditioned branch
            # uses self.model_dtype; the two agree only for float32 models.
            noisy_X, noisy_E, noisy_pos = noisy_X.float(), noisy_E.float(), noisy_pos.float()
            node_mask_dict = {'target': node_mask_target, 'all': node_mask_context}
        else:
            noisy_X = noisy_target['X_t'].to(self.model_dtype)
            noisy_E = noisy_target['E_t'].to(self.model_dtype)
            noisy_pos = noisy_target['pos_t'].to(self.model_dtype)
            bs, n, _, _ = noisy_E.size()
            # Flatten each node's edge/position row to (n * dim) features.
            noisy_E = noisy_E.view(bs, n, n * self.max_bond_type)
            noisy_pos = noisy_pos.view(bs, n, n * self.max_position_type)
            # Without context the target is the whole sequence.
            node_mask_dict = {'target': noisy_target['node_mask'], 'all': noisy_target['node_mask']}
            relation_to_target = torch.zeros(noisy_X.size(0), noisy_X.size(1), 1).type_as(noisy_X)

        target_n_nodes = noisy_target['X_t'].size(1)
        # Convert normalized timesteps back to the step scale [0, T].
        t = noisy_target['t'] * self.T
        # Padding positions get relation value 1.
        relation_to_target[~node_mask_dict['all']] = 1

        pred_X, pred_E, pred_pos = self.model(noisy_X, noisy_E, noisy_pos, relation_to_target, node_mask_dict, target_n_nodes, t=t)
        return utils.PlaceHolder(X=pred_X, E=pred_E, y=None, pos=pred_pos)

    def configure_optimizers(self):
        """Create an AMSGrad-flavoured AdamW optimizer over all module parameters.

        Learning rate and weight decay come from ``cfg.train``; epsilon is fixed at 1e-6.
        """
        train_cfg = self.cfg.train
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=train_cfg.lr,
            amsgrad=True,
            weight_decay=train_cfg.weight_decay,
            eps=1e-6,
        )
        print('config optimizer', 'lr', train_cfg.lr)
        return optimizer

    def on_fit_start(self) -> None:
        """Cache the datamodule's iteration count, report feature sizes, init wandb on rank 0."""
        datamodule = self.trainer.datamodule
        self.train_iterations = datamodule.training_iterations
        print("on fit train iteration:", self.train_iterations)
        print(
            f"Size of the input features Xdim {self.max_node_type}, Edim {self.max_bond_type}, "
            f"pos_dim {self.max_position_type} self.max_n_nodes {self.max_n_nodes}"
        )
        # Only the rank-0 process sets up the wandb run.
        if self.local_rank != 0:
            return
        utils.setup_wandb(self.cfg)

    def batch_data_process(self, mol_data, indicator_mol_to_ctx, indicator_node_to_mol, relation_to_target, apply_noise=True):
        """Densify a batch of molecular graphs and extract the target of each context.

        Node types, bond types and edge positions are one-hot encoded and moved to the
        module's device. The batch is converted to dense form with ``utils.to_dense``.
        ``utils.select_first_indices`` then picks the target molecules (presumably the
        first molecule of every context — TODO confirm).

        Args:
            mol_data: graph batch exposing ``x``, ``edge_attr``, ``edge_pos``,
                ``edge_index`` and ``y``.
            indicator_mol_to_ctx: context index of every molecule.
            indicator_node_to_mol: molecule index of every node.
            relation_to_target: per-molecule relation labels.
            apply_noise: when True, also noise the targets via ``apply_noise``.

        Returns:
            ``(noisy_target, dense_context_dict)``; ``noisy_target`` is None when
            ``apply_noise`` is False.
        """
        device = self.device
        one_hot_atoms = F.one_hot(mol_data.x, num_classes=self.max_node_type).to(device)
        one_hot_bonds = F.one_hot(mol_data.edge_attr, num_classes=self.max_bond_type).to(device)
        one_hot_edge_pos = F.one_hot(mol_data.edge_pos, num_classes=self.max_position_type).to(device)
        edge_index = mol_data.edge_index.to(device)
        graph_y = mol_data.y.to(device)
        indicator_mol_to_ctx, indicator_node_to_mol, relation_to_target = (
            tensor.to(device) for tensor in (indicator_mol_to_ctx, indicator_node_to_mol, relation_to_target)
        )

        dense_data, node_mask_allmol = utils.to_dense(
            one_hot_atoms, edge_index, one_hot_bonds, indicator_node_to_mol, y=graph_y, edge_pos=one_hot_edge_pos
        )
        dense_data = dense_data.mask(node_mask_allmol)
        target_data, node_mask_target = utils.select_first_indices(dense_data, node_mask_allmol, indicator_mol_to_ctx)

        dense_context_dict = dict(
            indicator_node_to_mol=indicator_node_to_mol,
            indicator_mol_to_ctx=indicator_mol_to_ctx,
            relation_to_target=relation_to_target,
            node_mask_allmol=node_mask_allmol,
            dense_data=dense_data,
            node_mask_target=node_mask_target,
            target_data=target_data,
        )

        noisy_target = None
        if apply_noise:
            noisy_target = self.apply_noise(target_data.X, target_data.E, target_data.y, target_data.pos, node_mask_target)
        return noisy_target, dense_context_dict
        
    def training_step(self, data, i):
        """Compute and log the training loss for one batch.

        Args:
            data: batch dict; reads ``'molecular_graphs'`` (a graph batch with a
                ``.batch`` node-to-molecule vector), ``'context_indicators'`` and
                ``'relations'``.
            i: batch index; detailed loss logging happens every ``log_every_steps`` steps.

        Returns:
            dict with the scalar ``'loss'`` for Lightning to optimize.
        """
        mol_data = data['molecular_graphs']
        indicator_mol_to_ctx = data['context_indicators']
        indicator_node_to_mol = mol_data.batch
        relation_to_target = data['relations']

        noisy_target, dense_context_dict = self.batch_data_process(mol_data, indicator_mol_to_ctx, indicator_node_to_mol, relation_to_target)
        pred = self.forward(noisy_target, dense_context_dict)

        # Optional per-graph loss weighting by the sampled noise level alpha_bar_t.
        alpha_t_bar = noisy_target['alpha_t_bar']
        if self.noise_weight == 'noiser_lower':
            noise_weight = alpha_t_bar
        elif self.noise_weight == 'noiser_higher':
            noise_weight = 1 - alpha_t_bar
        else:
            noise_weight = None

        # These are plain attributes rather than registered buffers, so move them explicitly.
        self.limit_dist = self.limit_dist.to(self.device)
        self.token_to_atom_count = self.token_to_atom_count.to(self.device)

        # The loss is scored on the target molecules only; context molecules are input only.
        target_data = dense_context_dict['target_data']
        node_mask_target = dense_context_dict['node_mask_target']

        loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, masked_pred_pos=pred.pos, masked_pred_y=pred.y,
                               true_X=target_data.X, true_E=target_data.E, true_pos=target_data.pos, true_y=target_data.y,
                               node_mask=node_mask_target, noise_weight=noise_weight, limit_dist=self.limit_dist,
                               token_to_atom_count=self.token_to_atom_count,
                               log=i % self.log_every_steps == 0)

        self.log('loss', loss, batch_size=target_data.X.size(0), sync_dist=True)
        return {'loss': loss}
   
    @torch.no_grad()
    def validation_step(self, data, i):
        """Estimate the variational NLL of one validation batch and log it as ``valid_nll``."""
        mol_data = data['molecular_graphs']
        ctx_index = data['context_indicators']
        noisy_target, context = self.batch_data_process(mol_data, ctx_index, mol_data.batch, data['relations'])
        target = context['target_data']
        target_mask = context['node_mask_target']

        pred = self.forward(noisy_target, context)
        nll = self.compute_val_loss(
            pred, noisy_target, context, target.X, target.E, target.pos, target_mask, test=False
        )
        # Context indices are 0-based, so the last index + 1 is the number of contexts.
        n_contexts = ctx_index[-1] + 1
        self.log('valid_nll', nll, batch_size=n_contexts, sync_dist=True)
        return {'loss': nll}

    @torch.no_grad()
    def test_step(self, data, i):
        """Estimate the variational NLL of one test batch and log it as ``test_nll``."""
        mol_data = data['molecular_graphs']
        ctx_index = data['context_indicators']
        noisy_target, context = self.batch_data_process(mol_data, ctx_index, mol_data.batch, data['relations'])
        target = context['target_data']
        target_mask = context['node_mask_target']

        pred = self.forward(noisy_target, context)
        nll = self.compute_val_loss(
            pred, noisy_target, context, target.X, target.E, target.pos, target_mask, test=True
        )
        # Context indices are 0-based, so the last index + 1 is the number of contexts.
        n_contexts = ctx_index[-1] + 1
        self.log('test_nll', nll, batch_size=n_contexts, sync_dist=True)
        return {'loss': nll}

    def on_train_epoch_start(self) -> None:
        """Print the learning rate, start the epoch timer and reset the train loss."""
        first_group = self.optimizers().param_groups[0]
        lr_now = first_group['lr']
        print(f"Current learning rate: {lr_now:.2e}")

        self.start_epoch_time = time.time()
        self.train_loss.reset()
    
    def on_train_epoch_end(self) -> None:
        """Log the aggregated training-loss metrics of the epoch that just finished."""
        self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, True, finished=True)

    def on_validation_epoch_start(self) -> None:
        """Clear every validation accumulator before a new validation pass."""
        accumulators = (
            self.val_nll,
            self.val_X_kl,
            self.val_E_kl,
            self.val_X_logp,
            self.val_E_logp,
            self.sampling_metrics,
            self.validate_loss,
        )
        for accumulator in accumulators:
            accumulator.reset()

    def on_validation_epoch_end(self) -> None:
        """Aggregate validation metrics, log them, and periodically sample molecules.

        Logs the epoch NLL and the T-scaled atom/edge KL terms to wandb (when a run is
        active). It also logs the NLL through Lightning as ``valid_nll`` so checkpoint
        callbacks can monitor it, and tracks the best validation NLL.

        Every ``cfg.general.sample_every_val`` validation epochs (never on the first), it
        generates ``cfg.general.samples_to_generate`` samples from the validation
        sampling loader, cycling the loader if needed. The samples are scored with
        ``self.sampling_metrics`` and a visualization is saved under
        ``graphs/<name>/epoch<k>_b0/``.
        """
        metrics = [
            self.val_nll.compute(),
            self.val_X_kl.compute() * self.T,
            self.val_E_kl.compute() * self.T,
            self.val_X_logp.compute(),
            self.val_E_logp.compute(),
        ]
        if wandb.run:
            wandb.log(
                {
                    "val/epoch_NLL": metrics[0],
                    "val/X_kl": metrics[1],
                    "val/E_kl": metrics[2],
                },
                commit=True,
            )

        print(
            f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
            f"Val Edge type KL: {metrics[2] :.2f}",
            "Val loss: %.2f \t Best :  %.2f\n" % (metrics[0], self.best_val_nll),
        )
        self.train_loss.log_epoch_metrics(
            self.current_epoch, self.start_epoch_time, False
        )

        # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
        self.log("valid_nll", metrics[0], sync_dist=True)

        if metrics[0] < self.best_val_nll:
            self.best_val_nll = metrics[0]

        self.val_counter += 1

        if (
            self.val_counter % self.cfg.general.sample_every_val == 0
            and self.val_counter > 1
        ):
            start = time.time()
            samples, targets, target_ids = [], [], []
            samples_left_to_generate = self.cfg.general.samples_to_generate
            valid_loader = self.trainer.datamodule.get_sampling_loader(self.sample_size, test=False)
            valid_iter = iter(valid_loader)

            while samples_left_to_generate > 0:
                try:
                    batch_data = next(valid_iter)
                except StopIteration:
                    # Restart the loader when more samples are needed than one pass yields.
                    valid_iter = iter(valid_loader)
                    batch_data = next(valid_iter)
                # Number of contexts (targets) in the batch. Convert to int so the
                # counters below stay plain Python ints instead of 0-dim tensors.
                batch_size = int(batch_data['context_indicators'][-1]) + 1
                to_generate = min(samples_left_to_generate, batch_size)

                current_sample, current_target = self.sample_batch(batch_data)
                samples.extend(current_sample[:to_generate])
                targets.extend(current_target[:to_generate])
                target_ids.extend(batch_data['target_id'][:to_generate])
                samples_left_to_generate -= to_generate

            print(f"Computing sampling metrics", " ...")
            all_generation, all_target = self.sampling_metrics(
                samples,
                targets,
                target_ids,
                self.trainer.datamodule.target_to_similar_smiles,
                self.name,
                self.current_epoch,
                val_counter=-1,
                test=False,
            )
            print(f"Done. Sampling took {time.time() - start:.2f} seconds\n")
            result_path = os.path.join(
                os.getcwd(), f"graphs/{self.name}/epoch{self.current_epoch}_b0/"
            )

            self.visualization_tools.visualize_by_smiles(
                result_path, all_generation, all_target, self.cfg.general.samples_to_save
            )
            self.sampling_metrics.reset()

    def on_test_epoch_start(self) -> None:
        """Announce the test run and clear the test accumulators and validation loss."""
        print("Starting test...")
        accumulators = (
            self.test_nll,
            self.test_X_kl,
            self.test_E_kl,
            self.test_X_logp,
            self.test_E_logp,
            self.validate_loss,
        )
        for accumulator in accumulators:
            accumulator.reset()
    
    @torch.no_grad()
    def on_test_epoch_end(self) -> None:
        """Report test likelihood metrics, then generate and score samples.

        The atom/edge KL diagnostics are multiplied by ``self.T``, as the validation hook
        does for the same quantities (``compute_Lt`` likewise scales its per-timestep KL
        by ``T``), so val and test KL values are comparable.

        It then generates ``cfg.general.final_model_samples_to_generate`` samples
        (falling back to ``cfg.general.samples_to_generate`` when that is unset or 0)
        from the test sampling loader, cycling it if needed, and scores them with
        ``self.sampling_metrics``.
        """
        metrics = [
            self.test_nll.compute(),
            self.test_X_kl.compute() * self.T,
            self.test_E_kl.compute() * self.T,
            self.test_X_logp.compute(),
            self.test_E_logp.compute(),
        ]
        if wandb.run:
            wandb.log(
                {
                    "test/epoch_NLL": metrics[0],
                    "test/X_kl": metrics[1],
                    "test/E_kl": metrics[2],
                },
                commit=True,
            )

        print(
            f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
            f"Test Edge type KL: {metrics[2] :.2f}",
        )

        samples, targets, target_ids = [], [], []
        total_sample = self.cfg.general.final_model_samples_to_generate or self.cfg.general.samples_to_generate
        samples_left_to_generate = total_sample

        test_loader = self.trainer.datamodule.get_sampling_loader(self.final_sample_size, test=True)
        test_iter = iter(test_loader)

        while samples_left_to_generate > 0:
            # '\r' rewrites the progress line in place instead of concatenating updates.
            print(
                f"\rsamples left to generate: {samples_left_to_generate}/"
                f"{total_sample}",
                end="",
                flush=True,
            )

            try:
                batch_data = next(test_iter)
            except StopIteration:
                # Restart the loader when more samples are needed than one pass yields.
                test_iter = iter(test_loader)
                batch_data = next(test_iter)
            # Number of contexts (targets) in the batch, as a plain int.
            batch_size = int(batch_data['context_indicators'][-1]) + 1
            to_generate = min(samples_left_to_generate, batch_size)

            current_sample, current_target = self.sample_batch(batch_data)
            samples.extend(current_sample[:to_generate])
            targets.extend(current_target[:to_generate])
            target_ids.extend(batch_data['target_id'][:to_generate])
            samples_left_to_generate -= to_generate
        print()  # terminate the progress line

        print("final Computing sampling metrics...")
        self.sampling_metrics.reset()
        self.sampling_metrics(
            samples,
            targets,
            target_ids,
            self.trainer.datamodule.target_to_similar_smiles,
            self.name,
            self.current_epoch,
            self.val_counter,
            test=True,
        )
        self.sampling_metrics.reset()
        print("Done.")

    def kl_prior(self, X, E, pos, node_mask):
        """KL term between the fully-noised marginal q(z_T | x) and the limit distribution.

        In this discrete model the prior is ``self.limit_dist``, not a Gaussian. This term
        is usually negligible in the loss, but computing it exposes mistakes in the noise
        schedule.

        ``F.kl_div(input=log q, target=p)`` evaluates KL(p || q). Here p is the limit
        distribution and q is the noised distribution.

        Args:
            X: (bs, n, dx) clean node features.
            E: clean edge features; reshaped here to (bs, n, n * de).
            pos: edge position features; accepted for interface compatibility but not
                included in this KL term.
            node_mask: (bs, n) boolean mask of real nodes.

        Returns:
            Per-graph KL (node term plus edge term), reduced with ``sum_except_batch``.
        """
        # Cumulative transition to the last step T: q(z_T | x) = x @ Q_bar_T.
        Ts = self.T * torch.ones((X.size(0), 1), device=X.device)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_int=Ts)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device, n_nodes=X.size(1))

        bs, n, _ = X.shape
        # Node features and flattened edge rows share one joint transition matrix.
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).to(self.model_dtype)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, :self.max_node_type]
        probE = prob_all[:, :, self.max_node_type:].reshape((bs, n, n, -1))

        assert probX.shape == X.shape

        limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(probX)
        limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(probE)

        # Make sure that masked rows do not contribute to the loss
        limit_dist_X, limit_dist_E, probX, probE = diffusion_utils.mask_distributions(true_X=limit_X.clone(),
                                                                                      true_E=limit_E.clone(),
                                                                                      pred_X=probX,
                                                                                      pred_E=probE,
                                                                                      node_mask=node_mask)

        kl_distance_X = F.kl_div(input=probX.clamp_min(1e-9).log(), target=limit_dist_X, reduction='none')
        kl_distance_E = F.kl_div(input=probE.clamp_min(1e-9).log(), target=limit_dist_E, reduction='none')

        return diffusion_utils.sum_except_batch(kl_distance_X) + \
               diffusion_utils.sum_except_batch(kl_distance_E)

    def compute_Lt(self, X, E, pos, pred, noisy_data, node_mask, test):
        """Diffusion term L_t of the variational bound, scaled by ``self.T``.

        Two posteriors are built with ``diffusion_utils.posterior_distributions`` from
        Qt, Qsb and Qtb: one from the clean graph, and one from the model's predicted
        clean-graph probabilities. Their KL is accumulated in the KL metric objects.

        Args:
            X: (bs, n, dx) clean node features.
            E: clean edge features; reshaped here to (bs, n, n * de).
            pos: edge position features; not part of this term yet.
            pred: model output PlaceHolder; ``pred.X`` / ``pred.E`` are logits.
            noisy_data: dict from ``apply_noise``; reads ``'alpha_t_bar'``,
                ``'alpha_s_bar'``, ``'beta_t'``, ``'X_t'`` and ``'E_t'``.
            node_mask: (bs, n) boolean mask of real nodes.
            test: update the test KL metrics if True, otherwise the validation ones.

        Returns:
            ``T * (kl_x + kl_e)`` using the per-batch values returned by the metrics.
        """
        pred_probs_X = F.softmax(pred.X, dim=-1)
        pred_probs_E = F.softmax(pred.E, dim=-1)

        n_nodes = X.size(1)
        Qtb = self.transition_model.get_Qt_bar(noisy_data['alpha_t_bar'], self.device, n_nodes=n_nodes)
        Qsb = self.transition_model.get_Qt_bar(noisy_data['alpha_s_bar'], self.device, n_nodes=n_nodes)
        Qt = self.transition_model.get_Qt(noisy_data['beta_t'], self.device, n_nodes=n_nodes)

        # Node features and flattened edge rows are handled jointly.
        bs, n, _ = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).to(self.model_dtype)
        Xt_all = torch.cat([noisy_data['X_t'], noisy_data['E_t'].reshape(bs, n, -1)], dim=-1).to(self.model_dtype)
        pred_probs_all = torch.cat([pred_probs_X, pred_probs_E.reshape(bs, n, -1)], dim=-1).to(self.model_dtype)

        prob_true = diffusion_utils.posterior_distributions(X=X_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.max_node_type)
        prob_true.E = prob_true.E.reshape((bs, n, n, -1))
        prob_pred = diffusion_utils.posterior_distributions(X=pred_probs_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.max_node_type)
        prob_pred.E = prob_pred.E.reshape((bs, n, n, -1))

        # Neutralize padded entries. Use the tensors the helper returns rather than
        # relying on it to mask its inputs in place.
        prob_true_X, prob_true_E, prob_pred.X, prob_pred.E = diffusion_utils.mask_distributions(true_X=prob_true.X,
                                                                                                true_E=prob_true.E,
                                                                                                pred_X=prob_pred.X,
                                                                                                pred_E=prob_pred.E,
                                                                                                node_mask=node_mask)
        kl_x = (self.test_X_kl if test else self.val_X_kl)(prob_true_X, torch.log(prob_pred.X))
        kl_e = (self.test_E_kl if test else self.val_E_kl)(prob_true_E, torch.log(prob_pred.E))

        return self.T * (kl_x + kl_e)

    def reconstruction_logp(self, t, X, E, pos, dense_context_dict, node_mask):
        """Model probabilities p(x | z_0) used for the reconstruction (L0) term.

        Samples z_0 by applying the t = 0 transition to the clean graph. It then runs the
        model on z_0 with the same context and returns softmax probabilities for X, E and
        pos. Masked nodes, masked edges and the edge diagonal are set to all-ones rows, so
        that their ``log`` is 0.

        Args:
            t: normalized timesteps of shape (bs, 1); only used to build zeros with the
                same shape, dtype and device.
            X: (bs, n, dx) clean node features.
            E: clean edge features; reshaped here to (bs, n, n * de).
            pos: clean edge position features.
            dense_context_dict: dict from ``batch_data_process``, forwarded to the model.
            node_mask: (bs, n) boolean mask of real nodes.

        Returns:
            utils.PlaceHolder with probabilities for X, E and pos.
        """
        # Compute noise values for t = 0 (float, same dtype as the training timesteps).
        t_zeros = torch.zeros_like(t)
        beta_0 = self.noise_schedule(t_zeros)
        Q0 = self.transition_model.get_Qt(beta_0, self.device, n_nodes=X.size(1))

        bs, n, _ = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).to(self.model_dtype)
        prob_all = X_all @ Q0.X
        probX0 = prob_all[:, :, :self.max_node_type]
        probE0 = prob_all[:, :, self.max_node_type:].reshape((bs, n, n, -1))
        prob_pos = pos.to(probE0.dtype) @ Q0.pos.unsqueeze(1)

        sampled0 = diffusion_utils.sample_discrete_features(probX=probX0, probE=probE0, prob_pos=prob_pos, node_mask=node_mask)

        X0 = F.one_hot(sampled0.X, num_classes=self.max_node_type)
        E0 = F.one_hot(sampled0.E, num_classes=self.max_bond_type)
        pos0 = F.one_hot(sampled0.pos, num_classes=self.max_position_type)

        assert (X.shape == X0.shape) and (E.shape == E0.shape)
        sampled_0 = utils.PlaceHolder(X=X0, E=E0, pos=pos0).mask(node_mask)

        # Predictions. Use the float zero timesteps: one-hot X0 is an integer tensor,
        # so deriving t from it would feed an integer timestep to the model, unlike training.
        noisy_data = {'X_t': sampled_0.X, 'E_t': sampled_0.E, 'y_t': sampled_0.y, 'pos_t': sampled_0.pos, 'node_mask': node_mask,
                      't': t_zeros}
        pred0 = self.forward(noisy_data, dense_context_dict)

        # Normalize predictions
        probX0 = F.softmax(pred0.X, dim=-1)
        probE0 = F.softmax(pred0.E, dim=-1)
        prob_pos0 = F.softmax(pred0.pos, dim=-1)

        # Set masked rows to all-ones so their log-probability is 0 and adds no loss.
        edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
        probX0[~node_mask] = torch.ones(self.max_node_type).type_as(probX0)
        probE0[~edge_mask] = torch.ones(self.max_bond_type).type_as(probE0)
        prob_pos0[~edge_mask] = torch.ones(self.max_position_type).type_as(prob_pos0)

        # Self-loops (diagonal) are likewise excluded.
        diag_mask = torch.eye(probE0.size(1)).type_as(probE0).bool()
        diag_mask = diag_mask.unsqueeze(0).expand(probE0.size(0), -1, -1)
        probE0[diag_mask] = torch.ones(self.max_bond_type).type_as(probE0)
        prob_pos0[diag_mask] = torch.ones(self.max_position_type).type_as(prob_pos0)

        return utils.PlaceHolder(X=probX0, E=probE0, pos=prob_pos0)

    def apply_noise(self, X, E, y, pos, node_mask):
        """Corrupt clean graphs with the forward diffusion process at random timesteps.

        One integer step per graph is drawn uniformly from [0, T] in training and from
        [1, T] otherwise; in evaluation the t = 0 term is scored separately. New node,
        edge and position categories are then sampled from x @ Q_bar_t.

        Args:
            X: (bs, n, dx) clean node features.
            E: clean edge features; reshaped here to (bs, n, n * de).
            y: graph-level features, passed through unchanged.
            pos: clean edge position features.
            node_mask: (bs, n) boolean mask of real nodes.

        Returns:
            dict with the timestep tensors ``'t_int'``, ``'t'``, ``'beta_t'``,
            ``'alpha_s_bar'`` and ``'alpha_t_bar'``, each (bs, 1). It also holds the
            masked noisy features ``'X_t'``, ``'E_t'``, ``'y_t'`` and ``'pos_t'``, and
            ``'node_mask'``.
        """
        min_step = 0 if self.training else 1
        t_int = torch.randint(min_step, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
        s_int = t_int - 1
        # NOTE(review): with t_int == 0 (training only) s is -1/T. In this file
        # alpha_s_bar is consumed only by compute_Lt, which is reached from the
        # validation/test loss where t >= 1.
        t_float, s_float = t_int / self.T, s_int / self.T

        beta_t = self.noise_schedule(t_normalized=t_float)                     # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        cumulative_Q = self.transition_model.get_Qt_bar(alpha_t_bar, self.device, n_nodes=X.size(1))

        batch_size, n_nodes, _ = X.shape
        # Node features and flattened edge rows share one joint transition matrix.
        joint_features = torch.cat([X, E.reshape(batch_size, n_nodes, -1)], dim=-1).to(self.model_dtype)
        joint_probs = joint_features @ cumulative_Q.X
        node_probs = joint_probs[:, :, :self.max_node_type]
        edge_probs = joint_probs[:, :, self.max_node_type:].reshape(batch_size, n_nodes, n_nodes, -1)
        pos_probs = pos.to(edge_probs.dtype) @ cumulative_Q.pos.unsqueeze(1)

        drawn = diffusion_utils.sample_discrete_features(probX=node_probs, probE=edge_probs, prob_pos=pos_probs, node_mask=node_mask)

        X_t = F.one_hot(drawn.X, num_classes=self.max_node_type)
        E_t = F.one_hot(drawn.E, num_classes=self.max_bond_type)
        pos_t = F.one_hot(drawn.pos, num_classes=self.max_position_type)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y, pos=pos_t).type_as(X_t).mask(node_mask)

        return {
            't_int': t_int,
            't': t_float,
            'beta_t': beta_t,
            'alpha_s_bar': alpha_s_bar,
            'alpha_t_bar': alpha_t_bar,
            'X_t': z_t.X,
            'E_t': z_t.E,
            'y_t': z_t.y,
            'pos_t': z_t.pos,
            'node_mask': node_mask,
        }

    def compute_val_loss(self, pred, noisy_data, dense_context_dict, X, E, pos, node_mask, test=False):
        """Estimate the variational upper bound on the NLL for one batch.

        nll = -log p(N) + KL_prior + L_t (already scaled by T) - log p(x | z_0)

        Args:
            pred: model output for ``noisy_data``.
            noisy_data: dict from ``apply_noise``.
            dense_context_dict: dict from ``batch_data_process``; reused for the L0 pass.
            X: (bs, n, dx) clean node features.
            E: clean edge features.
            pos: clean edge position features.
            node_mask: (bs, n) boolean mask of real nodes.
            test: select the test (True) or validation (False) metric accumulators.

        Returns:
            Batch NLL as returned by the NLL metric object (averaged over the batch).
        """
        t = noisy_data['t']

        # 1. Log-probability of the node count under the dataset node distribution.
        N = node_mask.sum(1).long()
        log_pN = self.node_dist.log_prob(N)

        ### TODO: add pos loss
        # 2. KL between q(z_T | x) and the limit distribution p(z_T); should be close to zero.
        kl_prior = self.kl_prior(X, E, pos, node_mask)

        # 3. Diffusion loss
        loss_all_t = self.compute_Lt(X, E, pos, pred, noisy_data, node_mask, test)

        # 4. Reconstruction loss
        # Compute L0 term : -log p (X, E | z_0) = reconstruction loss
        prob0 = self.reconstruction_logp(t, X, E, pos, dense_context_dict, node_mask)

        eps = 1e-8
        # Accumulate into the split's own metrics so test values do not leak into validation.
        X_logp = self.test_X_logp if test else self.val_X_logp
        E_logp = self.test_E_logp if test else self.val_E_logp
        loss_term_0 = X_logp(X * (prob0.X + eps).log()) + E_logp(E * (prob0.E + eps).log())

        # Combine terms
        nlls = - log_pN + kl_prior + loss_all_t - loss_term_0
        assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.'

        # Update NLL metric object and return batch nll (averaged over the batch).
        nll = (self.test_nll if test else self.val_nll)(nlls)
        return nll
    
    @torch.no_grad()
    def sample_batch(self, batch_data, init_node='reconstruction'):
        """Generate one graph per context in ``batch_data`` by running the full reverse diffusion.

        Args:
            batch_data: dict with keys ``'molecular_graphs'`` (a batched graph whose ``.batch``
                attribute maps nodes to molecules), ``'context_indicators'`` (maps molecules to
                contexts; its last entry determines the batch size) and ``'relations'``.
            init_node: how the node count of each generated graph is chosen:
                * ``'reconstruction'`` or ``None``: reuse the target graphs' node mask;
                * ``'random'``: sample node counts with ``self.node_dist.sample_n``;
                * ``int``: the same node count for every graph;
                * ``torch.Tensor``: per-graph node counts (one entry per graph).

        Returns:
            ``(molecule_list, target_list)``: one entry per graph in each list, every entry an
            ``[atom_types, edge_types, positions]`` list of CPU tensors cropped to that graph's
            node count. ``molecule_list`` holds the generated samples, ``target_list`` the
            ground-truth targets. Tensors come from ``PlaceHolder.mask(..., collapse=True)``,
            presumably class indices rather than one-hot — TODO confirm.

        Raises:
            ValueError: if ``init_node`` is not one of the supported options.
        """
        mol_data = batch_data['molecular_graphs']
        indicator_mol_to_ctx = batch_data['context_indicators']
        indicator_node_to_mol = mol_data.batch
        relation_to_target = batch_data['relations']
        _, dense_context_dict = self.batch_data_process(mol_data, indicator_mol_to_ctx, indicator_node_to_mol, relation_to_target, apply_noise=False)
        # Assumes context indices are 0-based and sorted, so the last one is the largest — TODO confirm.
        batch_size = int(indicator_mol_to_ctx[-1]) + 1

        # Capture the ground-truth target mask BEFORE it may be replaced below. The network reads
        # dense_context_dict['node_mask_target'], which must describe the graphs being generated,
        # while the target graphs must still be cropped with their own mask.
        node_mask_true = dense_context_dict['node_mask_target']

        if init_node is None or (isinstance(init_node, str) and init_node == 'reconstruction'):
            # Generate graphs with exactly the target graphs' node layout.
            node_mask_generated = node_mask_true
            n_nodes_generated = node_mask_generated.sum(-1)
        else:
            if isinstance(init_node, str) and init_node == 'random':
                n_nodes_generated = self.node_dist.sample_n(batch_size, self.device)
            elif isinstance(init_node, torch.Tensor):
                # Move to the sampling device so the comparison with `arange` below is valid.
                n_nodes_generated = init_node.to(self.device)
            elif isinstance(init_node, int) and not isinstance(init_node, bool):
                n_nodes_generated = init_node * torch.ones(batch_size, device=self.device, dtype=torch.int)
            else:
                raise ValueError(
                    "init_node must be 'reconstruction', 'random', None, an int or a tensor; "
                    f"got {init_node!r}"
                )
            arange = torch.arange(self.max_n_nodes, device=self.device).unsqueeze(0).expand(batch_size, -1)
            node_mask_generated = arange < n_nodes_generated.unsqueeze(1)
            dense_context_dict['node_mask_target'] = node_mask_generated

        n_nodes_true = node_mask_true.sum(-1)

        # z_T: initial noise drawn from the limit distribution.
        limit_dist = self.limit_dist.to(self.device)
        sampled_s = diffusion_utils.sample_discrete_feature_noise(limit_dist=limit_dist, node_mask=node_mask_generated)
        assert (sampled_s.E == torch.transpose(sampled_s.E, 1, 2)).all()

        # Iteratively sample p(z_s | z_t) for t = T, ..., 1, with s = t - 1.
        for s_int in reversed(range(self.T)):
            s_array = s_int * torch.ones((batch_size, 1), device=self.device)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T
            sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, sampled_s.X, sampled_s.E, sampled_s.pos, node_mask_generated, dense_context_dict)

        def to_cpu_graphs(graphs, n_nodes):
            # One [atom_types, edge_types, positions] entry per graph, cropped to its node count.
            cropped = []
            for i in range(batch_size):
                n = n_nodes[i]
                cropped.append([graphs.X[i, :n].cpu(), graphs.E[i, :n, :n].cpu(), graphs.pos[i, :n, :n].cpu()])
            return cropped

        sampled_s = sampled_s.mask(node_mask_generated, collapse=True)
        molecule_list = to_cpu_graphs(sampled_s, n_nodes_generated)

        target_data = dense_context_dict['target_data'].mask(node_mask_true, collapse=True)
        target_list = to_cpu_graphs(target_data, n_nodes_true)

        return molecule_list, target_list

    def sample_p_zs_given_zt(self, s, t, X_t, E_t, pos_t, node_mask, dense_context_dict):
        """Sample z_s ~ p(z_s | z_t) for one reverse-diffusion step. Only used during sampling.

        The network's clean-graph predictions (softmax over classes) are turned into posterior
        distributions at time ``s`` with ``diffusion_utils.reverse_diffusion`` and the transition
        matrices of ``self.transition_model``. When ``self.guide_scale != 1``, an unconditioned
        prediction ``p_u`` is combined with the conditioned one ``p_c`` as
        ``p_u * (p_c / p_u) ** guide_scale`` and renormalised.

        Args:
            s: normalised time of the step being sampled, shape (bs, 1) as built by ``sample_batch``.
            t: normalised current time, shape (bs, 1).
            X_t: node features at time t, (bs, n, dx).
            E_t: edge features at time t; reshaped to (bs, n, -1) here, presumably (bs, n, n, de).
            pos_t: position features at time t, presumably (bs, n, n, dpos) — TODO confirm.
            node_mask: node mask forwarded to the network and to the sampler.
            dense_context_dict: conditioning context forwarded to ``self.forward``.

        Returns:
            ``utils.PlaceHolder`` with one-hot ``X``, ``E`` and ``pos`` sampled at time s
            (cast to ``self.model_dtype``) and ``y=None``, masked with ``node_mask``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # The transition matrices depend only on the schedule values and n, not on the network
        # output, so build them once and share them between both guidance passes.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device, n_nodes=n)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device, n_nodes=n)
        Qt = self.transition_model.get_Qt(beta_t, self.device, n_nodes=n)

        # Nodes and flattened edge rows share one transition matrix (Qt.X etc.), hence the concat.
        Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
        pos_t_flat = pos_t.flatten(start_dim=1, end_dim=-2)

        noisy_data = {'X_t': X_t, 'E_t': E_t, 'pos_t': pos_t, 't': t, 'node_mask': node_mask}

        def normalize(unnormalized):
            # Rows with zero total mass get a small constant on every class, so they normalise
            # to a uniform distribution instead of NaN.
            unnormalized[torch.sum(unnormalized, dim=-1) == 0] = 1e-5
            return unnormalized / torch.sum(unnormalized, dim=-1, keepdim=True)

        def get_prob(unconditioned=False):
            """Return normalised (prob_X, prob_E, prob_pos) for z_s from one network pass."""
            pred = self.forward(noisy_data, dense_context_dict, unconditioned=unconditioned)
            pred_X = F.softmax(pred.X, dim=-1)
            pred_E = F.softmax(pred.E, dim=-1)
            pred_pos = F.softmax(pred.pos, dim=-1)

            pred_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
            unnormalized_all = diffusion_utils.reverse_diffusion(
                predX_0=pred_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )
            unnormalized_X = unnormalized_all[:, :, :self.max_node_type]
            unnormalized_E = unnormalized_all[:, :, self.max_node_type:].reshape(bs, n * n, -1)
            unnormalized_pos = diffusion_utils.reverse_diffusion(
                predX_0=pred_pos.flatten(start_dim=1, end_dim=-2), X_t=pos_t_flat,
                Qt=Qt.pos, Qsb=Qsb.pos, Qtb=Qtb.pos,
            ).reshape(bs, n, n, pred_pos.shape[-1])

            prob_X = normalize(unnormalized_X)
            prob_E = normalize(unnormalized_E).reshape(bs, n, n, pred_E.shape[-1])
            prob_pos = normalize(unnormalized_pos)
            return prob_X, prob_E, prob_pos

        prob_X, prob_E, prob_pos = get_prob(unconditioned=False)

        # Guidance: sharpen the conditioned distribution relative to the unconditioned one.
        if self.guide_scale != 1:
            uncon_prob_X, uncon_prob_E, uncon_prob_pos = get_prob(unconditioned=True)
            prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
            prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
            prob_pos = uncon_prob_pos * (prob_pos / uncon_prob_pos.clamp_min(1e-10)) ** self.guide_scale
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
            prob_pos = prob_pos / prob_pos.sum(dim=-1, keepdim=True).clamp_min(1e-10)

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_pos.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, prob_pos, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.max_node_type).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.max_bond_type).to(self.model_dtype)
        pos_s = F.one_hot(sampled_s.pos, num_classes=self.max_position_type).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        return utils.PlaceHolder(X=X_s, E=E_s, y=None, pos=pos_s).mask(node_mask)