import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import time
import os

from models.transformer import Denoiser
from models.transformer_unconditioned import UnconditionedDenoiser
from models.conditions import CategoricalEmbedder, ClusterContinuousEmbedder

from metrics.train_loss import TrainLossDiscrete
from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL, MaskedLossLt

from diffusion import diffusion_utils
from diffusion.transition import MaskedTransition

import utils


class MELD(pl.LightningModule):
    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools, num_train_steps=None, test_scaffold_smiles=None):
        """Assemble the denoiser, noise schedule, transition model and metric trackers.

        Args:
            cfg: Experiment configuration. The ``general``, ``model``, ``train``
                and ``dataset`` groups are read by this constructor.
            dataset_infos: Dataset descriptor. ``input_dims``, ``output_dims``,
                ``nodes_dist``, ``active_index``, ``atom_weight_list``,
                ``max_n_nodes`` and ``task_type`` are read from it.
            train_metrics: Object called on every training step and reset /
                logged at epoch boundaries.
            sampling_metrics: Object called on generated samples.
            visualization_tools: Object used to draw generated molecules.
            num_train_steps: Not read in this constructor.
            test_scaffold_smiles: Optional reference SMILES; when given they are
                passed as ``ref_smiles`` when samples are scored.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])

        input_dims = dataset_infos.input_dims
        output_dims = dataset_infos.output_dims
        self.test_scaffold_smiles = test_scaffold_smiles

        self.cfg = cfg
        self.name = cfg.general.name
        self.T = cfg.model.diffusion_steps
        self.guide_scale = cfg.model.guide_scale

        # Node and edge feature sizes carry one extra class (the training step
        # appends a zero "mask token" column to the node one-hot).
        self.Xdim = input_dims['X'] + 1
        self.Edim = input_dims['E'] + 1
        self.ydim = input_dims['y']
        self.Xdim_output = output_dims['X'] + 1
        self.Edim_output = output_dims['E'] + 1

        self.ydim_output = output_dims['y']
        self.node_dist = dataset_infos.nodes_dist
        self.active_index = dataset_infos.active_index
        self.dataset_info = dataset_infos

        # NOTE: this is the same object as dataset_infos.atom_weight_list, so the
        # item assignment below also changes the dataset descriptor in place.
        self.atom_weight_list = dataset_infos.atom_weight_list
        self.atom_weight_list[-1] = self.atom_weight_list[:-1].mean()

        self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)

        # Validation trackers (the original built val_nll twice; once is enough).
        self.val_nll = NLL()

        self.val_X_kl = MaskedLossLt()
        self.val_E_kl = MaskedLossLt()
        self.test_X_kl = MaskedLossLt()
        self.test_E_kl = MaskedLossLt()

        self.val_X_logp = SumExceptBatchMetric()
        self.val_E_logp = SumExceptBatchMetric()
        self.val_y_collection = []
        self.val_smiles_collection = []

        # Test trackers.
        self.test_nll = NLL()
        self.test_X_logp = SumExceptBatchMetric()
        self.test_E_logp = SumExceptBatchMetric()
        self.test_y_collection = []
        self.test_smiles_collection = []

        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

        self.visualization_tools = visualization_tools
        self.max_n_nodes = dataset_infos.max_n_nodes

        # Keyword arguments shared by both denoiser variants.
        denoiser_kwargs = dict(
            max_n_nodes=self.max_n_nodes,
            hidden_size=cfg.model.hidden_size,
            depth=cfg.model.depth,
            num_heads=cfg.model.num_heads,
            mlp_ratio=cfg.model.mlp_ratio,
            drop_condition=cfg.model.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
            task_type=dataset_infos.task_type,
            transition=cfg.model.transition,
            guidance_target=cfg.dataset.guidance_target,
        )
        if cfg.model.unconditioned:
            self.model = UnconditionedDenoiser(**denoiser_kwargs)
        else:
            self.model = Denoiser(**denoiser_kwargs,
                                  use_provided_drop_ids=cfg.model.use_provided_drop_ids)

        # NOTE(review): any other schedule name leaves self.noise_schedule
        # undefined; every step below this constructor reads it.
        schedule_name = self.cfg.model.diffusion_noise_schedule
        if schedule_name == 'latent-learnable-schedule':
            from diffusion.scheduler.latent_scheduler import LatentSchedule
            self.noise_schedule = LatentSchedule(
                X_dim=self.Xdim,
                E_dim=self.Edim,
                ydim=self.ydim_output,
                context_dim=cfg.model.transition_hidden_dim,
                timesteps=cfg.model.diffusion_steps,
                max_n_nodes=self.max_n_nodes,
            )
        elif schedule_name == 'latent-property-learnable-schedule':
            from diffusion.scheduler.latent_scheduler import LatentPropertySchedule
            self.noise_schedule = LatentPropertySchedule(
                X_dim=self.Xdim,
                E_dim=self.Edim,
                ydim=self.ydim_output,
                context_dim=cfg.model.transition_hidden_dim,
                timesteps=cfg.model.diffusion_steps,
                max_n_nodes=self.max_n_nodes,
                guidance_target=cfg.dataset.guidance_target,
                drop_condition=cfg.model.drop_condition,
                use_provided_drop_ids=cfg.model.use_provided_drop_ids,
            )

        self.transition_model = MaskedTransition(self.Xdim, self.Edim, self.ydim, cfg=cfg)

        # Bookkeeping that is filled in by the Lightning hooks.
        self.start_epoch_time = None
        self.train_iterations = None
        self.val_iterations = None
        self.log_every_steps = cfg.general.log_every_steps
        self.number_chain_steps = cfg.general.number_chain_steps

        self.best_val_nll = 1e8
        self.val_counter = 0
        self.batch_size = self.cfg.train.batch_size
   
    def forward(self, noisy_data, unconditioned=False, force_drop_ids=None):
        x, e, y = noisy_data['X_t'].float(), noisy_data['E_t'].float(), noisy_data['y_t'].float().clone()
        node_mask, t =  noisy_data['node_mask'], noisy_data['t']
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        pred = self.model(x, e, node_mask=node_mask, y=y, t=t, unconditioned=unconditioned, 
                        force_drop_ids=force_drop_ids)
        return pred
        
    def training_step(self, data, i):
        """Noise one batch, predict the clean graph and return the training loss.

        Args:
            data: Sparse graph batch exposing ``x``, ``edge_index``,
                ``edge_attr``, ``batch`` and ``y``.
            i: Batch index; it only decides whether this step is logged.

        Returns:
            ``{'loss': loss}`` with the value produced by
            ``self.train_loss.forward_w_mask``.
        """
        if self.cfg.model.use_provided_drop_ids:
            # One independent draw per row of data.y.
            force_drop_ids = (torch.rand(data.y.size(0), device=data.y.device) < self.cfg.model.drop_condition)
        else:
            force_drop_ids = None

        # One-hot over 118 classes, keeping only the columns in active_index.
        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
        # Add a new (all-zero) class for the mask token. It is allocated
        # directly on data_x's device instead of on CPU followed by a copy.
        data_x = torch.cat([data_x, data_x.new_zeros((data_x.size(0), 1))], dim=-1)
        num_edge_classes = 4 + 1 if self.cfg.dataset.task_name in ["ZINC", "QM9"] else 5 + 1
        data_edge_attr = F.one_hot(data.edge_attr, num_classes=num_edge_classes).float()

        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
        dense_data = dense_data.mask(node_mask)

        X, E = dense_data.X, dense_data.E
        tau = self.cfg.train.temperature
        noisy_data = self.apply_noise(X, E, data.y, node_mask, force_drop_ids=force_drop_ids, tau=tau)

        pred = self.forward(noisy_data, force_drop_ids=force_drop_ids)

        # Loss weight from the schedule at the sampled (normalized) timestep.
        t_float = noisy_data['t_int'].float() / self.T
        bs, n = X.size(0), X.size(1)
        dgamma_times_alpha = self.noise_schedule.dgamma_times_alpha(bs, n, self.device, t_normalized=t_float,
                                                                    prop=data.y, force_drop_ids=force_drop_ids)  # (bs, 1)
        # Positions whose noisy class is the last one (the mask token).
        cls_mask_node = (noisy_data['X_t'].argmax(dim=-1) == self.Xdim_output - 1)  # (bs, n)
        cls_mask_edge = (noisy_data['E_t'].argmax(dim=-1) == self.Edim_output - 1)  # (bs, n, n)

        should_log = i % self.log_every_steps == 0
        loss = self.train_loss.forward_w_mask(
            pred_X=pred.X,
            pred_E=pred.E,
            pred_y=pred.y,
            true_X=X,
            true_E=E,
            true_y=data.y,
            node_mask=node_mask,
            cls_mask_node=cls_mask_node,
            cls_mask_edge=cls_mask_edge,
            dgamma_times_alpha=dgamma_times_alpha,
            log=should_log,
            observation=False,
            noise_schedule=self.cfg.model.diffusion_noise_schedule
        )

        self.train_metrics(pred_X=pred.X, pred_E=pred.E, true_X=X, true_E=E,
                           log=should_log)
        self.log('loss', loss, batch_size=bs, sync_dist=True)
        return {'loss': loss}

    def configure_optimizers(self):
        params = self.parameters()
        optimizer = torch.optim.AdamW(params, lr=self.cfg.train.lr, amsgrad=True,
                                 weight_decay=self.cfg.train.weight_decay)
        return optimizer
    
    def on_fit_start(self) -> None:
        self.train_iterations = self.trainer.datamodule.training_iterations
        print('on fit train iteration:', self.train_iterations)
        print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))

    def on_train_epoch_start(self) -> None:
        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
            print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
        self.start_epoch_time = time.time()
        self.train_loss.reset()
        self.train_metrics.reset()

    def on_train_epoch_end(self) -> None:
        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
            log = True
        else:
            log = False
        self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
        self.train_metrics.log_epoch_metrics(self.current_epoch, log)

    def on_validation_epoch_start(self) -> None:
        self.val_nll.reset()
        self.val_X_kl.reset()
        self.val_E_kl.reset()
        self.val_X_logp.reset()
        self.val_E_logp.reset()
        self.sampling_metrics.reset()
        self.val_y_collection = []
        self.val_smiles_collection = []

    @torch.no_grad()
    def validation_step(self, data, i):
        """Compute the validation NLL estimate for one batch.

        Also stores the batch targets and SMILES so that
        ``on_validation_epoch_end`` can condition sampling on them.

        Args:
            data: Sparse graph batch exposing ``x``, ``edge_index``,
                ``edge_attr``, ``batch``, ``y`` and ``smiles``.
            i: Batch index (unused).

        Returns:
            ``{'loss': nll}`` for this batch.
        """
        # One-hot over 118 classes, keeping only the columns in active_index.
        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
        # Extra all-zero column for the mask-token class, allocated directly on
        # data_x's device instead of on CPU followed by a copy.
        data_x = torch.cat([data_x, data_x.new_zeros((data_x.size(0), 1))], dim=-1)
        num_edge_classes = 4 + 1 if self.cfg.dataset.task_name in ["ZINC", "QM9"] else 5 + 1
        data_edge_attr = F.one_hot(data.edge_attr, num_classes=num_edge_classes).float()

        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
        dense_data = dense_data.mask(node_mask)
        tau = self.cfg.train.temperature
        noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask, tau=tau)

        pred = self.forward(noisy_data)
        # Put -inf on the last (mask) class of the predicted logits.
        pred.X[:, :, -1] = float('-inf')
        pred.E[:, :, :, -1] = float('-inf')
        nll, loss_prior, loss_diffusion, loss_recons = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask,
                                                                            test=False, force_drop_ids=None)

        self.val_y_collection.append(data.y)
        self.val_smiles_collection.extend(data.smiles)

        # Weight the logged values by the number of graphs, as training_step
        # does; data.x.size(0) is the first dimension of the sparse node
        # tensor, not the number of graphs.
        num_graphs = dense_data.X.size(0)
        self.log('valid_nll', nll, batch_size=num_graphs, sync_dist=True)
        self.log('loss_prior', loss_prior, batch_size=num_graphs, sync_dist=True)
        self.log('loss_diffusion', loss_diffusion, batch_size=num_graphs, sync_dist=True)
        self.log('loss_recons', loss_recons, batch_size=num_graphs, sync_dist=True)
        return {'loss': nll}

    def on_validation_epoch_end(self) -> None:
        torch.cuda.empty_cache()
        
        metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
                   self.val_X_logp.compute(), self.val_E_logp.compute()]
        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
            print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
                f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
        self.log("val/NLL",  metrics[0], sync_dist=True)

        if metrics[0] < self.best_val_nll:
            self.best_val_nll = metrics[0]

        self.val_counter += 1
        
        if self.val_counter % self.cfg.general.sample_every_val == 0 and self.val_counter > 1 and self.val_counter >= 500:
            start = time.time()
            samples_left_to_generate = self.cfg.general.samples_to_generate
            samples_left_to_save = self.cfg.general.samples_to_save
            chains_left_to_save = self.cfg.general.chains_to_save

            samples, all_ys = [], []
            all_smiles = []

            val_y_collection = torch.cat(self.val_y_collection, dim=0)

            num_examples = val_y_collection.size(0)
            if samples_left_to_generate > num_examples:
                ratio = samples_left_to_generate // num_examples
                val_y_collection = val_y_collection.repeat(ratio+1, 1)
                self.val_smiles_collection = self.val_smiles_collection * (ratio + 1)
                num_examples = val_y_collection.size(0)
                
            start_index = 0
            while samples_left_to_generate > 0:                
                bs = 1 * 256 # self.cfg.train.batch_size
                to_generate = min(samples_left_to_generate, bs)
                to_save = min(samples_left_to_save, bs)
                chains_save = min(chains_left_to_save, bs)
                
                batch_y = val_y_collection[start_index:start_index + to_generate]   
                batch_smiles = [self.val_smiles_collection[idx] for idx in range(start_index, start_index + to_generate)]

                cur_sample = self.sample_batch(batch_id=start_index, batch_size=to_generate, y=batch_y,
                                                save_final=to_save,
                                                keep_chain=chains_save,
                                                number_chain_steps=self.number_chain_steps,
                                                is_test=False, save_batch=False, batch_smiles=batch_smiles)
                samples.extend(cur_sample)
                all_ys.append(batch_y)
                all_smiles.extend(batch_smiles)
                start_index += to_generate

                samples_left_to_save -= to_save
                samples_left_to_generate -= to_generate
                chains_left_to_save -= chains_save

            print(f"Computing sampling metrics", ' ...')
            
            valid_smiles, total_errors = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False, 
            ref_smiles=self.test_scaffold_smiles if self.test_scaffold_smiles is not None else all_smiles)
            print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
            for task in total_errors:
                self.log(f'val/{task}', total_errors[task], on_epoch=True, sync_dist=True)
            current_path = os.getcwd()
            result_path = os.path.join(current_path,
                                       f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
            self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
            self.sampling_metrics.reset()

    def on_test_epoch_start(self) -> None:
        print("Starting test...")
        self.test_nll.reset()
        self.test_X_kl.reset()
        self.test_E_kl.reset()
        self.test_X_logp.reset()
        self.test_E_logp.reset()
        self.test_y_collection = []
        self.test_smiles_collection = []
    
    @torch.no_grad()
    def test_step(self, data, i):
        """Compute the test NLL estimate for one batch.

        Also stores the batch targets and SMILES so that ``on_test_epoch_end``
        can condition sampling on them.

        Args:
            data: Sparse graph batch exposing ``x``, ``edge_index``,
                ``edge_attr``, ``batch``, ``y`` and ``smiles``.
            i: Batch index (unused).

        Returns:
            ``{'loss': nll}`` for this batch.
        """
        # One-hot over 118 classes, keeping only the columns in active_index.
        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
        # Extra all-zero column for the mask-token class, allocated directly on
        # data_x's device instead of on CPU followed by a copy.
        data_x = torch.cat([data_x, data_x.new_zeros((data_x.size(0), 1))], dim=-1)
        num_edge_classes = 4 + 1 if self.cfg.dataset.task_name in ["ZINC", "QM9"] else 5 + 1
        data_edge_attr = F.one_hot(data.edge_attr, num_classes=num_edge_classes).float()

        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
        dense_data = dense_data.mask(node_mask)
        tau = self.cfg.train.temperature
        noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask, tau=tau)

        pred = self.forward(noisy_data)
        # Put -inf on the last (mask) class of the predicted logits.
        pred.X[:, :, -1] = float('-inf')
        pred.E[:, :, :, -1] = float('-inf')
        nll, _, _, _ = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask,
                                             test=True, force_drop_ids=None)
        self.test_y_collection.append(data.y)
        self.test_smiles_collection.extend(data.smiles)
        return {'loss': nll}

    def on_test_epoch_end(self) -> None:
        """ Measure likelihood on a test set and compute stability metrics. """
        metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
                   self.test_X_logp.compute(), self.test_E_logp.compute()]

        print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
              f"Test Edge type KL: {metrics[2] :.2f}")

        ## final epcoh
        # self.cfg.general.final_model_samples_to_generate = 500
        samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
        samples_left_to_save = self.cfg.general.final_model_samples_to_save
        chains_left_to_save = self.cfg.general.final_model_chains_to_save

        samples, all_ys, batch_id = [], [], 0
        all_smiles = []
        
        test_y_collection = torch.cat(self.test_y_collection, dim=0)
        
        num_examples = test_y_collection.size(0)

        if self.cfg.general.final_model_samples_to_generate > num_examples:
            ratio = self.cfg.general.final_model_samples_to_generate // num_examples
            test_y_collection = test_y_collection.repeat(ratio+1, 1)
            self.test_smiles_collection = self.test_smiles_collection * (ratio + 1)
            num_examples = test_y_collection.size(0)
        
        ttt = 0
        while samples_left_to_generate > 0:
            print(f'samples left to generate: {samples_left_to_generate}/'
                f'{self.cfg.general.final_model_samples_to_generate}', end='', flush=True)

            bs = 1 * 500
            to_generate = min(samples_left_to_generate, bs)
            to_save = min(samples_left_to_save, bs)
            chains_save = min(chains_left_to_save, bs)
            
            batch_y = test_y_collection[batch_id : batch_id + to_generate]
            batch_smiles = [self.test_smiles_collection[idx] for idx in range(batch_id, batch_id + to_generate)]
            
            cur_sample = self.sample_batch(batch_id=batch_id, batch_size=to_generate, y=batch_y, save_final=to_save,
                                            keep_chain=chains_save, number_chain_steps=self.number_chain_steps, 
                                            is_test=True, save_batch=True, batch_smiles=batch_smiles)
            samples = samples + cur_sample
            
            all_ys.append(batch_y)
            all_smiles.extend(batch_smiles)
            batch_id += to_generate

            samples_left_to_save -= to_save
            samples_left_to_generate -= to_generate
            chains_left_to_save -= chains_save

            ttt += 1
            
        print(f"final Computing sampling metrics...")
        self.sampling_metrics.reset()
        self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, self.val_counter, test=True, 
        ref_smiles=self.test_scaffold_smiles if self.test_scaffold_smiles is not None else all_smiles)
        self.sampling_metrics.reset()
        print(f"Done.")
            

    def kl_prior(self, X, E, y, node_mask):
        """Per-graph KL term between the fully-noised distribution and the limit distribution.

        The clean features are pushed through the cumulative transition at
        ``t = T`` and compared, via ``F.kl_div``, with the marginals returned
        by ``transition_model.get_marginals(y)``. Computing this term makes a
        broken noise schedule visible, even if the value is small in practice.

        Returns:
            Node and edge terms, each reduced with
            ``diffusion_utils.sum_except_batch`` and added together.
        """
        batch = X.size(0)
        n_nodes = X.size(1)

        # Schedule value at the last diffusion step for every graph.
        last_step = self.T * torch.ones((batch, 1), device=X.device)
        alpha_T = self.noise_schedule.alpha(batch, n_nodes, self.device, t_int=last_step, prop=y)  # (bs, 1)
        Qtb = self.transition_model.get_Qt_bar(alpha_T, self.device, y)

        bs, n, _ = X.shape
        per_position = Qtb.X.dim() == 4
        if per_position:
            # A 4-D Qtb.X is contracted position-wise with einsum.
            probX = torch.einsum('bnk,bnkj->bnj', X, Qtb.X)
            probE = torch.einsum('bnmk,bnmkj->bnmj', E, Qtb.E)
        else:
            probX = X @ Qtb.X
            probE = E @ Qtb.E.unsqueeze(1)
        assert probX.shape == X.shape

        marg_x, marg_e = self.transition_model.get_marginals(y)
        limit = utils.PlaceHolder(X=marg_x, E=marg_e, y=None)

        # Broadcast the marginals over nodes / node pairs.
        target_X = limit.X.unsqueeze(1).expand(bs, n, -1).type_as(probX)
        target_E = limit.E.unsqueeze(1).unsqueeze(2).expand(bs, n, n, -1).type_as(probE)

        # Make sure that masked rows do not contribute to the loss.
        target_X, target_E, probX, probE = diffusion_utils.mask_distributions(
            true_X=target_X.clone(),
            true_E=target_E.clone(),
            pred_X=probX,
            pred_E=probE,
            node_mask=node_mask,
        )

        kl_nodes = F.kl_div(input=probX.log(), target=target_X, reduction='none')
        kl_edges = F.kl_div(input=probE.log(), target=target_E, reduction='none')
        node_term = diffusion_utils.sum_except_batch(kl_nodes)
        edge_term = diffusion_utils.sum_except_batch(kl_edges)
        return node_term + edge_term

    def compute_Lt(self, X, E, y, pred, noisy_data, node_mask, test, force_drop_ids=None):
        # Diffusion loss
        bs, n, d = X.shape
        cls_mask_node = (noisy_data['X_t'][:, :, -1] == 1).reshape(-1)
        cls_mask_edge = (noisy_data['E_t'][:, :, :, -1] == 1).reshape(-1)
        final_node_mask = node_mask.reshape(-1) & cls_mask_node
        edge_mask = (node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
        diag = torch.eye(n, dtype=torch.bool, device=node_mask.device)
        edge_mask[diag.unsqueeze(0).expand(bs, -1, -1)] = False
        edge_mask = edge_mask.reshape(-1)
        final_edge_mask = edge_mask & cls_mask_edge

        dgamma_times_alpha = self.noise_schedule.dgamma_times_alpha(bs, n, device=X.device, 
                                                                    t_normalized=noisy_data['t'], 
                                                                    prop=y, force_drop_ids=force_drop_ids)  # (bs, 1)
        if isinstance(dgamma_times_alpha, tuple):
            weights_node, weights_edge = dgamma_times_alpha
            weights_node = weights_node.reshape(-1)[final_node_mask]
            weights_edge = weights_edge.reshape(-1)[final_edge_mask]
        else:
            weights = dgamma_times_alpha
            weights_node = weights.expand(-1, n).reshape(-1)[final_node_mask]
            weights_edge = weights.unsqueeze(-1).expand(-1, n, n).reshape(-1)[final_edge_mask]

        true_X = X.reshape(-1, X.size(-1))[final_node_mask]
        true_E = E.reshape(-1, E.size(-1))[final_edge_mask]
        
        pred_X = pred.X.reshape(-1, pred.X.size(-1))[final_node_mask, :]
        pred_E = pred.E.reshape(-1, pred.E.size(-1))[final_edge_mask, :]

        loss_X = F.cross_entropy(pred_X, torch.argmax(true_X, dim=-1), reduction='none') * weights_node.reshape(-1)
        loss_E = F.cross_entropy(pred_E, torch.argmax(true_E, dim=-1), reduction='none') * weights_edge.reshape(-1)
        
        if test:
            kl_x = self.test_X_kl(loss_X)
            kl_e = self.test_E_kl(loss_E)
        else:
            kl_x = self.val_X_kl(loss_X)
            kl_e = self.val_E_kl(loss_E)
        return kl_x + kl_e

    def reconstruction_logp(self, t, X, E, y, node_mask, debug=False):
        # Compute noise values for t = 0.
        bs, n = X.size(0), X.size(1)
        t_zeros = torch.zeros_like(t)
        alpha_0 = self.noise_schedule.alpha(bs, n, self.device, t_normalized=t_zeros, prop=y)
        
        if isinstance(alpha_0, tuple):
            alpha_0_x, alpha_0_e = alpha_0
            recon_x = (1 - alpha_0_x) * np.log(self.Xdim_output-1) # (bs, n)
            recon_x = recon_x.sum()
            recon_e = (1 - alpha_0_e) * np.log(self.Edim_output-1) # (bs, n, n)
            recon_e = recon_e.sum()
        else:
            alpha_0 = alpha_0
            n = X.shape[1]
            recon_x = n * (1 - alpha_0) * np.log(self.Xdim_output-1)
            recon_x = recon_x.sum()
            recon_e = n * n * (1 - alpha_0) * np.log(self.Edim_output-1) # n*(n-1)?
            recon_e = recon_e.sum()
        
        return recon_x + recon_e
    
    def apply_noise(self, X, E, y, node_mask, fully_noisy=False, fixed_time=None, force_drop_ids=None, 
                    specific_T=None, tau=1):
        """Sample a timestep per graph and draw the corresponding noisy graph.

        Timestep selection, in order of precedence:
            1. ``fixed_time`` if given,
            2. ``specific_T`` if given (previously this argument was always
               overwritten and therefore had no effect),
            3. ``self.T`` when ``fully_noisy`` is true,
            4. otherwise a uniform integer in ``[1, self.T]``.

        Args:
            X, E: Dense clean node / edge features.
            y: Conditioning targets, passed to the schedule and transition model.
            node_mask: Mask of real nodes.
            force_drop_ids: Forwarded to the noise schedule.
            tau: Temperature forwarded to the discrete sampler.

        Returns:
            Dict with ``t_int``, ``t``, ``s_int``, ``s`` (``s = t - 1``),
            ``alpha_s``, ``alpha_t``, the noisy ``X_t`` / ``E_t`` / ``y_t``,
            ``node_mask`` and the sampling probabilities ``prob_X`` / ``prob_E``.
        """
        bs, n, _ = X.shape
        lowest_t = 1

        # Timestep tensors are created directly on X's device.
        if fixed_time is not None:
            t_int = torch.ones(bs, 1, device=X.device) * fixed_time
        elif specific_T is not None:
            t_int = torch.ones(bs, 1, device=X.device) * specific_T
        elif fully_noisy:
            t_int = (torch.ones(bs, 1, device=X.device) * self.T).float()
        else:
            t_int = torch.randint(lowest_t, self.T + 1, size=(bs, 1), device=X.device).float()  # (bs, 1)
        s_int = t_int - 1

        # Normalized times handed to the schedule.
        t_float = t_int / self.T
        s_float = s_int / self.T

        alpha_s = self.noise_schedule.alpha(bs, n, self.device, t_normalized=s_float, prop=y,
                                            force_drop_ids=force_drop_ids)
        alpha_t = self.noise_schedule.alpha(bs, n, self.device, t_normalized=t_float, prop=y,
                                            force_drop_ids=force_drop_ids)
        Qtb = self.transition_model.get_Qt_bar(alpha_t, self.device, y)

        if Qtb.X.dim() == 4:
            # A 4-D Qtb.X is contracted position-wise with einsum.
            probX = torch.einsum('bnk,bnkj->bnj', X, Qtb.X)
            probE = torch.einsum('bnmk,bnmkj->bnmj', E, Qtb.E)
        else:
            probX = X @ Qtb.X  # (bs, n, dx)
            probE = E @ Qtb.E.unsqueeze(1)  # (bs, n, n, de)

        sampled_t = diffusion_utils.st_sample_discrete_feature(probX=probX, probE=probE, node_mask=node_mask, hard=False,
                                                               temperature=tau)
        X_t = sampled_t.X
        E_t = sampled_t.E
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        # The conditioning target is passed through unchanged.
        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
        noisy_data = {'t_int': t_int, 't': t_float, 'alpha_s': alpha_s, 's_int': s_int, 's': s_float,
                      'alpha_t': alpha_t, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask,
                      'prob_X': probX, 'prob_E': probE}
        return noisy_data

    def compute_val_loss(self, pred, noisy_data, X, E, y, node_mask, test=False, force_drop_ids=None):
        """Computes an estimator for the variational lower bound.
           pred: (batch_size, n, total_features)
           noisy_data: dict
           X, E, y : (bs, n, dx),  (bs, n, n, de), (bs, dy)
           node_mask : (bs, n)
           Output: nll (size 1)
       """
        t = noisy_data['t']

        # 1.
        N = node_mask.sum(1).long()
        log_pN = self.node_dist.log_prob(N)

        # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes) should be close to zero.
        loss_prior = self.kl_prior(X, E, y, node_mask)

        # 3. Diffusion loss
        loss_diff = self.compute_Lt(X, E, y, pred, noisy_data, node_mask, test, force_drop_ids)

        # 4. Reconstruction loss
        loss_recon = self.reconstruction_logp(t, X, E, y, node_mask)

        # Combine terms
        nlls = loss_diff + loss_prior + loss_recon - log_pN
        assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.'

        # Update NLL metric object and return batch nll
        nll = (self.test_nll if test else self.val_nll)(nlls)        # Average over the batch
        nlls = loss_diff + loss_prior + loss_recon - log_pN
        return nll, loss_prior.mean().item(), loss_diff.mean().item(), loss_recon.mean().item()
    
    @torch.no_grad()
    def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, 
                     num_nodes=None, is_test=False, save_batch=False, batch_smiles=None):
        """Generate ``batch_size`` graphs by running the reverse process from t = T to 0.

        :param batch_id: int (not used in this method)
        :param batch_size: int, number of graphs to sample
        :param y: conditioning targets, one row per graph
        :param keep_chain: int (not used in this method)
        :param number_chain_steps: int (not used in this method)
        :param save_final: int (not used in this method)
        :param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes;
            sampled from ``self.node_dist`` when omitted
        :param is_test, save_batch, batch_smiles: not used in this method
        :return: molecule_list. Each element is a list ``[atom_types, edge_types]``
            of CPU tensors cropped to that graph's node count.
        :raises ValueError: if the limit distribution or the initial sample contains NaN,
            or if a real node / off-diagonal edge is still in the mask class after the last step.
        """
        # Sample (or take) the number of nodes of every graph.
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, int):
            n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
        else:
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes
        n_max = self.max_n_nodes
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Draw the starting graph z_T from the limit distribution.
        limit_x, limit_e = self.transition_model.get_marginals()
        limit_dist = utils.PlaceHolder(X=limit_x, E=limit_e, y=None)

        z_T = diffusion_utils.sample_discrete_features_from_limits(limit_dist=limit_dist, node_mask=node_mask, hard=True,
                                                                    temperature=self.cfg.train.temperature)

        # Validity check.
        if limit_dist.X.isnan().any() or limit_dist.E.isnan().any():
            raise ValueError("limit_dist.X.isnan().any() or limit_dist.E.isnan().any()")
        if z_T.X.isnan().any() or z_T.E.isnan().any():
            raise ValueError("z_T.X.isnan().any() or z_T.E.isnan().any()")

        X, E = z_T.X, z_T.E

        del (limit_dist, limit_x, limit_e, z_T)
        torch.cuda.empty_cache()

        # Reverse process: s runs over T-1, ..., 0 and t = s + 1.
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s; the flag is set only on the last step of a "masked" transition.
            mask_bool = ("masked" in self.cfg.model.transition) and (s_int == 0)
            sampled_s = self.sample_p_zs_given_zt(s_int + 1, s_norm, t_norm, X, E, y,
                                                  node_mask, n_nodes,
                                                  mask_bool=mask_bool)

            X, E = sampled_s.X, sampled_s.E

        # Check that the final X, E do not contain masks. The class index comes
        # from argmax (as in training_step); the previous max(...).values
        # compared one-hot values with a class index. Padded nodes and
        # self-loops are ignored.
        # assumes X is (bs, n_max, dx) and E is (bs, n_max, n_max, de) with
        # classes on the last axis -- TODO confirm against sample_p_zs_given_zt
        real_pairs = node_mask.unsqueeze(1) & node_mask.unsqueeze(2)
        off_diagonal = ~torch.eye(n_max, dtype=torch.bool, device=node_mask.device)
        real_pairs = real_pairs & off_diagonal.unsqueeze(0)
        node_still_masked = (X.argmax(dim=-1) == self.Xdim_output - 1) & node_mask
        edge_still_masked = (E.argmax(dim=-1) == self.Edim_output - 1) & real_pairs
        if node_still_masked.any() or edge_still_masked.any():
            raise ValueError("Final sampled X, E should not contain masks.")

        sampled_s = sampled_s.mask(node_mask, collapse=True)
        sampled_s.y = y
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        # Crop every graph to its own node count and move it to the CPU.
        molecule_list = []
        for i in range(batch_size):
            n = n_nodes[i]
            atom_types = X[i, :n].cpu()
            edge_types = E[i, :n, :n].cpu()
            molecule_list.append([atom_types, edge_types])
        return molecule_list

    def sample_p_zs_given_zt(self, step, s, t, X_t, E_t, y_t, node_mask, n_nodes, mask_bool=False):
        """Sample z_s ~ p(z_s | z_t) for one reverse step. Only used during sampling.

        The denoiser predicts class probabilities for every node and edge. When
        ``self.guide_scale`` is set and the model is not configured as
        unconditioned, a second unconditioned forward pass is combined with the
        conditioned one as ``p_uncond * (p_cond / p_uncond) ** guide_scale`` and
        renormalised.

        A per-graph budget is then derived from the noise schedule at ``t``:
        ``ceil(mean(alpha_t) * n_nodes)`` nodes and
        ``ceil(mean(alpha_t) * n_nodes * (n_nodes - 1) / 2)`` edges. The
        highest-confidence valid positions (confidence = max class probability)
        fill that budget, and only positions that are both selected and
        currently hold the mask class (last class index) are replaced by a
        freshly sampled class. All other positions keep their value from z_t.

        :param step: integer step counter supplied by the caller (not used here)
        :param s: normalised target timestep (not used by the unmasking rule)
        :param t: normalised current timestep, fed to the network and the schedule
        :param X_t: one-hot node features, (bs, n, Xdim_output)
        :param E_t: one-hot edge features, (bs, n, n, Edim_output)
        :param y_t: conditioning properties
        :param node_mask: bool tensor (bs, n), True for real nodes
        :param n_nodes: tensor (bs,), number of real nodes per graph
        :param mask_bool: if True, every currently masked position is resampled
            regardless of the budget, and the flag is forwarded to the sampler
            as ``remove_mask_class`` (presumably forbids drawing the mask class
            -- confirm in ``diffusion_utils.sample_discrete_features``)
        :return: ``utils.PlaceHolder`` holding the one-hot X_s, E_s and y_t,
            passed through ``.mask(node_mask)`` and cast to the type of y_t
        """
        bs, n, _ = X_t.shape

        # Neural net predictions (conditioned).
        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t,
                      'node_mask': node_mask}

        pred = self.forward(noisy_data, unconditioned=False)
        prob_X = F.softmax(pred.X, dim=-1)
        prob_E = F.softmax(pred.E, dim=-1)

        # Guidance: sharpen the conditioned prediction against the unconditioned one.
        if self.guide_scale is not None and not self.cfg.model.unconditioned:
            uncon_pred = self.forward(noisy_data, unconditioned=True)
            uncon_prob_X = F.softmax(uncon_pred.X, dim=-1)
            uncon_prob_E = F.softmax(uncon_pred.E, dim=-1)

            # clamp_min guards the ratio and the renormalisation against division by zero.
            prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
            prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
            del (uncon_pred, uncon_prob_X, uncon_prob_E)
            torch.cuda.empty_cache()

        # The mask token is the last class of each one-hot encoding.
        mask_location_X = X_t.argmax(dim=-1) == self.Xdim_output - 1
        mask_location_E = E_t.argmax(dim=-1) == self.Edim_output - 1

        # Schedule at t: (bs, n) for nodes, (bs, n, n) for edges.
        alpha_t_node, alpha_t_edge = self.noise_schedule.alpha(
            bs, n, self.device, t_normalized=t, prop=y_t)

        # Node budget, averaged per graph over its own valid nodes so that one
        # graph's schedule cannot influence another graph in the same batch.
        num_valid_nodes = node_mask.sum(dim=-1)  # (bs,)
        sum_alpha_node = alpha_t_node.masked_fill(~node_mask, 0.0).sum(dim=-1)
        mean_alpha_node = sum_alpha_node / num_valid_nodes.clamp(min=1)  # prevent div by 0
        unmask_ratio_X = torch.ceil(mean_alpha_node * n_nodes).long()  # (bs,)

        # Edge budget, averaged per graph over valid upper-triangular pairs.
        triu_2d = torch.triu(torch.ones(n, n, device=self.device), diagonal=1).bool()
        edge_mask = node_mask.unsqueeze(1) & node_mask.unsqueeze(2) & triu_2d.unsqueeze(0)  # (bs, n, n)
        sum_alpha_edge = alpha_t_edge.masked_fill(~edge_mask, 0.0).flatten(1).sum(dim=-1)  # (bs,)
        num_valid_edges = edge_mask.flatten(1).sum(dim=-1).clamp(min=1)  # prevent div by 0
        mean_alpha_edge = sum_alpha_edge / num_valid_edges
        unmask_ratio_E = torch.ceil(mean_alpha_edge * n_nodes * (n_nodes - 1) / 2).long()  # (bs,)

        # Confidence = probability of the most likely class; padding is excluded with -inf.
        node_conf = prob_X.max(dim=-1).values.masked_fill(~node_mask, float('-inf'))  # (bs, n)
        edge_conf = prob_E.max(dim=-1).values.masked_fill(~edge_mask, float('-inf'))  # (bs, n, n)

        # The upper-triangular index set is identical for every graph: build it once.
        triu_rows, triu_cols = triu_2d.nonzero(as_tuple=True)
        flat_edge_conf = edge_conf[:, triu_rows, triu_cols]  # (bs, n * (n - 1) / 2)

        # Cap the budgets by what is available and move them to the host in one transfer.
        node_budget = torch.minimum(unmask_ratio_X, num_valid_nodes).tolist()
        edge_budget = unmask_ratio_E.clamp(max=triu_rows.numel()).tolist()

        to_unmask_X = torch.zeros(bs, n, device=self.device, dtype=torch.bool)
        to_unmask_E = torch.zeros(bs, n, n, device=self.device, dtype=torch.bool)
        for b, (kx, ke) in enumerate(zip(node_budget, edge_budget)):
            top_nodes = torch.topk(node_conf[b], kx, dim=-1).indices
            to_unmask_X[b, top_nodes] = True

            top_edges = torch.topk(flat_edge_conf[b], ke, dim=-1).indices
            rows, cols = triu_rows[top_edges], triu_cols[top_edges]
            # Mirror the selection so the edge tensor stays symmetric.
            to_unmask_E[b, rows, cols] = True
            to_unmask_E[b, cols, rows] = True

        if mask_bool:
            # Resample every remaining mask regardless of the budget.
            places_to_predict_x = mask_location_X
            places_to_predict_e = mask_location_E
        else:
            # Only positions that are still masked AND fall within the budget.
            places_to_predict_x = mask_location_X & to_unmask_X
            places_to_predict_e = mask_location_E & to_unmask_E

        places_to_predict_x = places_to_predict_x.unsqueeze(-1).expand(bs, n, self.Xdim_output)
        places_to_predict_e = places_to_predict_e.unsqueeze(-1).expand(bs, n, n, self.Edim_output)

        sampled_s = diffusion_utils.sample_discrete_features(
            probX=prob_X,
            probE=prob_E,
            node_mask=node_mask,
            remove_mask_class=mask_bool
        )
        # Keep z_t wherever no prediction is requested.
        X_s = torch.where(places_to_predict_x, F.one_hot(sampled_s.X, num_classes=self.Xdim_output), X_t.float())
        E_s = torch.where(places_to_predict_e, F.one_hot(sampled_s.E, num_classes=self.Edim_output), E_t.float())

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)

        torch.cuda.empty_cache()
        return out_one_hot.mask(node_mask).type_as(y_t)
            
    