import time
import os
import math
import numpy as np

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb

from didigress.models.transformer_model import GraphTransformer
from didigress.models.transformer_model3d import GraphTransformer3D
# from didigress.models.egnn_ablation import GraphTransformer
# print("RUNNING ABLATION")
from didigress.diffusion import diffusion_utils
from didigress.diffusion.diffusion_utils import mask_distributions, sum_except_batch
from didigress import utils
from didigress.diffusion.extra_features import ExtraFeatures
from didigress.analysis.rdkit_functions import Molecule
from didigress.datasets.adaptive_loader import effective_batch_size

from rdkit import Chem
from omegaconf import OmegaConf

class DiGress(pl.LightningModule):
    """Lightning module holding the graph denoising network(s) together with
    the training, validation, test and sampling logic of the diffusion model.
    """
    # Class-level defaults; instances shadow them by assigning on `self`.
    model_dtype = torch.float32
    # Lowest validation NLL seen so far; lowered in `on_validation_epoch_end`.
    best_val_nll = 1e8
    # Number of finished validation epochs; incremented in `on_validation_epoch_end`.
    val_counter = 0
    # NOTE(review): the three attributes below are not read or written in the
    # visible part of this file -- presumably timing/bookkeeping set elsewhere; confirm.
    start_epoch_time = None
    train_iterations = None
    val_iterations = None

    def __init__(self, cfg, dataset_infos, losses, noise_model, visualizer=None):
        """Build the denoising network(s) and store configuration, losses and metrics.

        Args:
            cfg: experiment configuration; the `general`, `model`, `features`
                and `guidance` sections are read here.
            dataset_infos: dataset description providing `nodes_dist`,
                `input_dims`, `output_dims` and `max_n_nodes`.
            losses: dict with the loss/metric objects under the keys read below,
                plus `graph_list_maker` and `ignore_hyperparameters`.
            noise_model: object implementing the forward (noising) process.
            visualizer: optional visualizer used while sampling.

        Raises:
            AssertionError: if `cfg.features.charges_policy` is not one of
                'no', 'partial' or 'dictionary'.
        """
        super().__init__()
        nodes_dist = dataset_infos.nodes_dist

        self.cfg = cfg
        self.name = cfg.general.name
        self.T = cfg.model.diffusion_steps

        # Feature switches
        self.use_charges = self.cfg.features.use_charges
        self.use_3d      = self.cfg.features.use_3d
        self.use_ins_del = self.cfg.features.use_ins_del

        self.freeze_n_nodes_at_sampling = self.cfg.features.freeze_n_nodes_at_sampling
        self.validation_loss_type = self.cfg.features.validation_loss_type

        self.node_dist = nodes_dist
        self.dataset_infos = dataset_infos
        self.extra_features = ExtraFeatures(cfg, dataset_info=dataset_infos)
        self.input_dims = self.extra_features.update_input_dims(dataset_infos.input_dims)
        self.output_dims = dataset_infos.output_dims

        # Train metrics
        self.train_loss             = losses['train_loss']
        self.train_metrics          = losses['train_metrics']

        # Validation metrics
        self.val_loss               = losses['val_loss']
        self.val_metrics            = losses['val_metrics']
        self.val_nll                = losses['val_nll']
        self.val_sampling_metrics   = losses['val_sampling_metrics']

        # Test metrics
        self.test_metrics           = losses['test_metrics']
        self.test_nll               = losses['test_nll']
        self.test_sampling_metrics  = losses['test_sampling_metrics']

        self.graph_list_maker       = losses['graph_list_maker']

        self.ignore_hyperparameters = losses['ignore_hyperparameters']

        self.visualizer = visualizer

        self.save_hyperparameters(ignore=self.ignore_hyperparameters)

        # The 3D variant of the transformer is used when positions are modelled.
        model_class = GraphTransformer3D if self.use_3d else GraphTransformer

        self.max_n_nodes = dataset_infos.max_n_nodes
        self.model = model_class(input_dims=self.input_dims,
                                 n_layers=cfg.model.n_layers,
                                 hidden_mlp_dims=cfg.model.hidden_mlp_dims,
                                 hidden_dims=cfg.model.hidden_dims,
                                 output_dims=self.output_dims, cfg=cfg,
                                 max_n_nodes=self.max_n_nodes)
        if self.use_ins_del:
            # Second network, built with `predict_delt=True` and its own size settings.
            self.delt_model = model_class(input_dims=self.input_dims,
                                          n_layers=cfg.model.n_layers_delt,
                                          hidden_mlp_dims=cfg.model.hidden_mlp_dims_delt,
                                          hidden_dims=cfg.model.hidden_dims_delt,
                                          output_dims=self.output_dims, cfg=cfg,
                                          predict_delt=True, max_n_nodes=self.max_n_nodes)
            # A non-positive `features.max_n` falls back to the dataset maximum.
            self.max_n = self.cfg.features.max_n if self.cfg.features.max_n > 0 \
                                                 else self.dataset_infos.max_n_nodes

            self.improvement_threshold      = OmegaConf.select(cfg, "guidance.improvement_threshold")
            self.corruption_step            = OmegaConf.select(cfg, "guidance.corruption_step")
            self.similarity_threshold       = OmegaConf.select(cfg, "guidance.similarity_threshold")

            self.delt_model_path            = OmegaConf.select(cfg, "features.delt_model_path")
            # If a path is configured the checkpoint is loaded later (see
            # `sample_zs_from_zt`), which needs `losses` again to rebuild the module.
            self.delt_loaded                = self.delt_model_path is None
            if not self.delt_loaded:
                self.losses = losses

        self.experiment_type = OmegaConf.select(cfg, "guidance.experiment_type")

        self.sampling_nT = self.cfg.features.sampling_nT
        self.noise_model = noise_model

        self.log_every_steps = cfg.general.log_every_steps
        self.number_chain_steps = cfg.general.number_chain_steps

        # Identity check instead of `!= None`, which would go through any
        # custom `__eq__`/`__ne__` defined on the dims object.
        if self.input_dims is not None:
            self.gdim = self.input_dims.guidance
        else:
            self.gdim = 0

        # This will be used as the guidance when we need to calculate
        # p(G_{t-1}|G_t) without the guidance.
        if self.cfg.guidance.trainable_cf == True and self.cfg.guidance.p_uncond >= 0:
            self.cf_null_token = torch.nn.parameter.Parameter(torch.randn(size = (1, self.gdim)))
        else:
            self.cf_null_token = torch.zeros(size = (1, self.gdim))

        self.charges_policy = cfg.features.charges_policy
        assert self.charges_policy in ['no', 'partial', 'dictionary'], "unimplemented charges policy"

        # One counter per diffusion step; printed at the end of the test epoch.
        self.split_check = np.zeros(self.T)

    def corrupt_data(self, dense_data, train_step=False, corruption_step = -1):
        z_t = self.noise_model.apply_noise(dense_data, corruption_step)

        if(self.use_ins_del):
            # Unpacks the data
            z_t, masked_true, delt_sample, delt_target = z_t
            
            extra_data_delt = self.extra_features(delt_sample)
            pred_delt = self.forward(delt_sample, extra_data_delt, self.delt_model, train_step=train_step)

            #TODO: check empty z_t, skip this loop if empty
            return z_t, masked_true, pred_delt, delt_target
        else:
            return z_t, dense_data, None, None

    def training_step(self, data, i):
        if data.edge_index.numel() == 0:
            print("Found a batch with no edges. Skipping.")
            return
        dense_data = utils.to_dense(data, self.dataset_infos)
        z_t, target_data, pred_delt, target_data_delt = self.corrupt_data(dense_data, train_step=True)
        extra_data = self.extra_features(z_t)
        pred = self.forward(z_t, extra_data, self.model, train_step = True)
        loss, tl_log_dict = self.train_loss(masked_pred=pred, masked_true=target_data,
                                            log=i % self.log_every_steps == 0,
                                            masked_pred_delt=pred_delt, 
                                            masked_true_delt=target_data_delt)

        # if self.local_rank == 0:
        tm_log_dict = self.train_metrics(masked_pred=pred, masked_true=target_data,
                                         log=i % self.log_every_steps == 0)
        if tl_log_dict is not None:
            self.log_dict(tl_log_dict, batch_size=self.BS)
        if tm_log_dict is not None:
            self.log_dict(tm_log_dict, batch_size=self.BS)
        return loss

    def on_validation_epoch_start(self) -> None:
        if self.validation_loss_type == 'NLL':
            self.val_nll.reset()
        else:
            self.val_loss.reset()

        self.val_metrics.reset()

    def validation_step(self, data, i):
        """Evaluate one validation batch.

        Returns `({'loss': value}, log_dict)` where the value is the NLL
        estimate when `validation_loss_type` is 'NLL' and the output of
        `self.val_loss` otherwise.
        """
        self.validation_step_preops(data, i)
        dense_data = utils.to_dense(data, self.dataset_infos)
        z_t, clean_data, pred_delt, target_data_delt = self.corrupt_data(dense_data)
        pred = self.forward(z_t, self.extra_features(z_t), self.model)

        if self.validation_loss_type != 'NLL':
            loss, tl_log_dict = self.val_loss(masked_pred=pred, masked_true=clean_data,
                                              log=i % self.log_every_steps == 0,
                                              masked_pred_delt=pred_delt,
                                              masked_true_delt=target_data_delt)
            return {'loss': loss}, tl_log_dict

        nll, log_dict = self.compute_val_loss(pred, z_t, clean_data=clean_data, test=False)
        return {'loss': nll}, log_dict

    def validation_step_preops(self, data, i):
        pass

    def on_validation_epoch_end(self) -> None:
        """Log the epoch-level validation results and occasionally sample graphs.

        The logged quantities depend on `validation_loss_type`. Sampling is
        triggered for the "debug" run name, or every
        `cfg.general.sample_every_val` validation epochs after epoch 0 when
        `cfg.general.samples_to_generate` is positive.
        """
        if self.validation_loss_type == 'NLL':
            val_nll = self.val_nll.compute()
            kl_metrics = self.val_metrics.compute()

            # Per-step metrics are scaled by the number of diffusion steps.
            log_dict = {"val/epoch_NLL": val_nll,
                        "val/X_kl": kl_metrics['XKl'] * self.T,
                        "val/E_kl": kl_metrics['EKl'] * self.T}
            if self.use_charges:
                log_dict["val/charges_kl"] = kl_metrics['ChargesKl'] * self.T
            if self.use_3d:
                log_dict["val/pos_mse"] = kl_metrics['PosMSE'] * self.T

            self.log_dict(log_dict, on_epoch=True, on_step=False, sync_dist=True)
            if wandb.run:
                wandb.log(log_dict)

            summary = ''.join(f"{key}: {val:.2f} -- " for key, val in log_dict.items())
            # The slice drops the trailing "-- ." of the assembled line.
            print(f"Epoch {self.current_epoch}: {summary}."[:-4])

            # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
            self.log("val/epoch_NLL", val_nll, sync_dist=True)

            if val_nll < self.best_val_nll:
                self.best_val_nll = val_nll
            print('Val loss: %.4f \t Best val loss:  %.4f\n' % (val_nll, self.best_val_nll))
        else:
            log_dict = self.val_loss.log_epoch_metrics()

            charges_str = f" charges: {log_dict['val_epoch/charges_CE']:.2f} --" if self.use_charges else ""
            pos_str = f"pos: {log_dict['val_epoch/pos_mse'] :.2f} -- " if self.use_3d else ""
            if self.use_ins_del:
                ins_del_str = (f"s: {log_dict['val_epoch/s_CE'] :.2f} -- "
                               f"delt: {log_dict['val_epoch/delt_CE'] :.2f} -- ")
            else:
                ins_del_str = ""

            pieces = [f"Epoch {self.current_epoch} finished: ",
                      f"batch loss: {log_dict['val_epoch/overall_loss'] :.2f} --",
                      f"X: {log_dict['val_epoch/x_CE'] :.2f} --",
                      f" E: {log_dict['val_epoch/E_CE'] :.2f} --",
                      f" y: {log_dict['val_epoch/y_CE'] :.2f} --",
                      charges_str, pos_str, ins_del_str]
            self.print(*pieces)

            self.log_dict(log_dict, on_epoch=True, on_step=False, sync_dist=True)
            if wandb.run:
                wandb.log(log_dict)

        self.val_counter += 1
        general = self.cfg.general
        periodic_sampling = (self.val_counter % general.sample_every_val == 0 and
                             self.current_epoch > 0 and general.samples_to_generate > 0)
        if self.name == "debug" or periodic_sampling:
            self.perform_validation_sampling()
        self.print(f"Val epoch {self.current_epoch} ends")

    def perform_validation_sampling(self):
        self.print(f"Sampling start")
        start = time.time()
        gen = self.cfg.general
        samples = self.sample_n_graphs(samples_to_generate=math.ceil(gen.samples_to_generate / max(len(gen.gpus), 1)),
                                    chains_to_save=gen.chains_to_save if self.local_rank == 0 else 0,
                                    samples_to_save=gen.samples_to_save if self.local_rank == 0 else 0,
                                    test=False)
        print(f'Done on {self.local_rank}. Sampling took {time.time() - start:.2f} seconds\n')
        print(f"Computing sampling metrics on {self.local_rank}...")
        self.val_sampling_metrics(samples, self.name, self.current_epoch, self.local_rank)

    def on_test_epoch_start(self):
        if self.local_rank == 0:
            utils.setup_wandb(self.cfg)
        self.test_nll.reset()
        self.test_metrics.reset()

    def test_step(self, data, i):
        """
        dense_data = utils.to_dense(data, self.dataset_infos)
        z_t, clean_data, delt_sample, delt_target = self.corrupt_data(dense_data)
        extra_data = self.extra_features(z_t)
        pred = self.forward(z_t, extra_data, self.model)
        """
        if self.validation_loss_type == 'NLL':
            nll, log_dict = self.compute_val_loss(pred, z_t, clean_data=clean_data, test=False)
        else:
            nll = torch.Tensor([0.1])
            log_dict = {"kl prior": nll,
                    "Estimator loss terms": nll,
                    "log_pn": nll,
                    'test_nll': nll}
        return {'loss': nll}, log_dict

    def on_test_epoch_end(self) -> None:
        """ Measure likelihood on a test set and compute stability metrics. """
        #TODO: fix
        """
        metrics = [self.test_nll.compute(), self.test_metrics.compute()]
        test_nll = metrics[0]
        print(f'Test loss: {test_nll :.4f}')
        log_dict = {"test/epoch_NLL": metrics[0],
                    "test/X_kl": metrics[1]['XKl'] * self.T,
                    "test/E_kl": metrics[1]['EKl'] * self.T}
        
        if(self.use_charges):
            log_dict.update({"test/charges_kl": metrics[1]['ChargesKl'] * self.T})
        if(self.use_3d):
            log_dict.update({"test/pos_mse": metrics[1]['PosMSE'] * self.T})

        self.log_dict(log_dict, sync_dist=True)

        print_str = []
        for key, val in log_dict.items():
            new_val = f"{val:.4f}"
            print_str.append(f"{key}: {new_val} -- ")
        print_str = ''.join(print_str)
        print(f"Epoch {self.current_epoch}: {print_str}."[:-4])

        if wandb.run:
            wandb.log(log_dict)
        """
        print(f"Sampling start on GR{self.global_rank}")
        start = time.time()
        print(f"Samples to generate: {self.cfg.general.final_model_samples_to_generate}")
        print(f"Samples to save: {self.cfg.general.final_model_samples_to_save}")
        samples = self.sample_n_graphs(samples_to_generate=self.cfg.general.final_model_samples_to_generate,
                                       chains_to_save=self.cfg.general.final_model_chains_to_save,
                                       samples_to_save=self.cfg.general.final_model_samples_to_save,
                                       test=True)
        print("Saving the generated graphs")
        filename = f'generated_samples1.txt'
        for i in range(2, 10):
            if os.path.exists(filename):
                filename = f'generated_samples{i}.txt'
            else:
                break
        with open(filename, 'w') as f:
            for graph in samples:
                f.write(f"N={graph.num_nodes}\n")
                # X:
                nodes = graph.node_types.tolist()
                f.write("X: \n")
                for at in nodes:
                    f.write(f"{at} ")
                f.write("\n")

                f.write("E: \n")
                for bond_list in graph.edge_types:
                    for bond in bond_list:
                        f.write(f"{bond} ")
                    f.write("\n")

                # Charges
                if(self.use_charges):
                    charges = graph.charges
                    f.write("charges: \n")
                    for c in charges:
                        f.write(f"{c} ")
                    f.write("\n")

                # Pos
                if(self.use_3d):
                    positions = graph.positions
                    f.write("pos: \n")
                    for line in positions:
                        f.write(f"{line[0]} {line[1]} {line[2]} \n")
                f.write("\n")
        print("Saved.")
        print("Computing sampling metrics...")
        self.test_sampling_metrics(samples, self.name, self.current_epoch, self.local_rank)
        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
        print(f"Test ends.")

        print("split_check:\n", self.split_check)

    def kl_prior(self, clean_data, node_mask):
        """Compare the distribution reached at the last step, q(z_T | x), with
        the limit distribution of the noise model.

        This is essentially a lot of work for something that is in practice
        negligible in the loss. However, it is computed so that a mistake in the
        noise schedule becomes visible.

        Returns one value per graph: the pointwise `F.kl_div` terms of X and E
        (plus charges and positions when enabled) summed over everything but
        the batch axis.
        """
        def _renormalise(prob):
            # Rescale so the last axis sums to one again.
            return prob / prob.sum(dim=-1, keepdims=True)

        # Transition matrices at the last step, t = T, for every graph.
        Ts = self.T * torch.ones((clean_data.X.size(0), 1), dtype=torch.long, device=clean_data.X.device)
        Qtb = self.noise_model.get_Qt_bar(t_int=Ts)

        # Transition probabilities; the small constant keeps the log finite.
        probX = _renormalise(clean_data.X @ Qtb.X + 1e-7)  # (bs, n, dx_out)
        probE = _renormalise(clean_data.E @ Qtb.E.unsqueeze(1) + 1e-7)  # (bs, n, n, de_out)
        assert probX.shape == clean_data.X.shape

        bs, _, _ = probX.shape
        limit_dist = self.noise_model.get_limit_dist().device_as(probX)

        # Masked entries are set to the limit distribution so that they do not
        # contribute to the loss.
        probX[~node_mask] = limit_dist.X.float()
        off_diagonal = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0)
        valid_edges = node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * off_diagonal
        probE[~valid_edges, :] = limit_dist.E.float()

        #TODO: in case of NaNs, try with this (and add the same code for the charges):
        # probX[probX == .0] = 1e-6
        # probE[probE == .0] = 1e-6
        # limit_dist_X[limit_dist_X == .0] = 1e-6
        # limit_dist_E[limit_dist_E == .0] = 1e-6

        kl_X = F.kl_div(input=probX.log(), target=limit_dist.X[None, None, :], reduction='none')
        kl_E = F.kl_div(input=probE.log(), target=limit_dist.E[None, None, None, :], reduction='none')

        summed_data = sum_except_batch(kl_X) + sum_except_batch(kl_E)

        if self.use_charges:
            probc = _renormalise(clean_data.charges @ Qtb.charges + 1e-7)
            probc[~node_mask] = limit_dist.charges.float()
            kl_c = F.kl_div(input=probc.log(), target=limit_dist.charges[None, None, :], reduction='none')
            summed_data += sum_except_batch(kl_c)

        # Compute the kl on the positions
        if self.use_3d:
            last = self.T * torch.ones((bs, 1), device=clean_data.pos.device, dtype=torch.long)
            mu_T = self.noise_model.get_alpha_bar(t_int=last, key='p')[:, :, None] * clean_data.pos
            sigma_T = self.noise_model.get_sigma_bar(t_int=last, key='p')[:, :, None]
            # 3 coordinates per unmasked node, minus 3.
            subspace_d = 3 * node_mask.long().sum(dim=1)[:, None, None] - 3
            kl_pos = subspace_d * diffusion_utils.gaussian_KL(mu_T, sigma_T)
            summed_data += sum_except_batch(kl_pos)

        return (summed_data)

    def compute_Lt(self, clean_data, pred, z_t, s_int, node_mask, test):
        """Diffusion loss term comparing the true and the predicted posteriors.

        The posterior given the clean data and the posterior given the softmaxed
        network output are masked and passed to the test or validation metrics
        (depending on `test`). The summed metrics are returned scaled by `self.T`.
        """
        # TODO: ideally all probabilities should be computed in log space
        t_int = z_t.t_int

        #TODO: we may have to generalize if pred.pos and pred.charges are None
        charge_probs = F.softmax(pred.charges, dim=-1) if self.use_charges else None
        pred_pos = pred.pos if self.use_3d else None
        pred_probs = utils.PlaceHolder(X=F.softmax(pred.X, dim=-1), charges=charge_probs,
                                       guidance=clean_data.guidance,
                                       E=F.softmax(pred.E, dim=-1), pos=pred_pos,
                                       node_mask=clean_data.node_mask, y=None)

        noise = self.noise_model
        Qtb = noise.get_Qt_bar(z_t.t_int)
        Qsb = noise.get_Qt_bar(s_int)
        Qt = noise.get_Qt(t_int)

        # Compute distributions to compare with KL
        bs, n, _ = clean_data.X.shape
        edge_shape = (bs, n, n, -1)
        prob_true = diffusion_utils.posterior_distributions(clean_data=clean_data, noisy_data=z_t,
                                                            Qt=Qt, Qsb=Qsb, Qtb=Qtb)
        prob_true.E = prob_true.E.reshape(edge_shape)
        prob_pred = diffusion_utils.posterior_distributions(clean_data=pred_probs, noisy_data=z_t,
                                                            Qt=Qt, Qsb=Qsb, Qtb=Qtb)
        prob_pred.E = prob_pred.E.reshape(edge_shape)

        # Reshape and filter masked rows
        prob_true = diffusion_utils.mask_distributions(prob_true, node_mask)
        prob_pred = diffusion_utils.mask_distributions(prob_pred, node_mask)

        #TODO: in case of NaN, try this (add charges code)
        # prob_true.X[prob_true.X == 0.0] = 1e-6
        # prob_pred.X[prob_pred.X == 0.0] = 1e-6
        # prob_true.E[prob_true.E == 0.0] = 1e-6
        # prob_pred.E[prob_pred.E == 0.0] = 1e-6

        # Compute the prefactor for KL on the positions
        if self.use_3d:
            def _ratio_sq(step):
                # (alpha_bar / sigma_bar)^2 at `step`, with a guard on the division.
                return (noise.get_alpha_bar(t_int=step, key='p') /
                        (noise.get_sigma_bar(t_int=step, key='p') + 1e-6)) ** 2

            prefactor = _ratio_sq(s_int) - _ratio_sq(t_int)
            prefactor[torch.isnan(prefactor)] = 1
            prefactor = torch.sqrt(0.5 * prefactor).unsqueeze(-1)
            prob_true.pos = prefactor * clean_data.pos
            prob_pred.pos = prefactor * pred_probs.pos

        metric_fn = self.test_metrics if test else self.val_metrics
        metrics = metric_fn(prob_pred, prob_true)

        total = metrics['XKl'] + metrics['EKl']
        if self.use_charges:
            total += metrics['ChargesKl']
        if self.use_3d:
            total += metrics['PosMSE']
        return self.T * (total)

    def compute_val_loss(self, pred, z_t, clean_data, test=False):
        """Computes an estimator for the variational lower bound, or the simple loss (MSE).
           pred: (batch_size, n, total_features)
           noisy_data: dict
           X, E, y : (bs, n, dx),  (bs, n, n, de), (bs, dy)
           node_mask : (bs, n)
           Output: nll (size 1)
        """
        node_mask = z_t.node_mask
        t_int = z_t.t_int
        s_int = t_int - 1

        #Problematic since this computes the log prob of a graph to
        #have a certain amount of nodes and this cannot be computed
        #in didigress where such amount may change...
        # 1.
        N = node_mask.sum(1).long()
        log_pN = self.node_dist.log_prob(N)

        # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero.
        kl_prior = self.kl_prior(clean_data, node_mask)

        # 3. Diffusion loss
        loss_all_t = self.compute_Lt(clean_data, pred, z_t, s_int, node_mask, test)

        # Combine terms
        nlls = - log_pN + kl_prior + loss_all_t
        # Update NLL metric object and return batch nll
        nll = (self.test_nll if test else self.val_nll)(nlls)        # Average over the batch

        log_dict = {"kl prior": kl_prior.mean(),
                  "Estimator loss terms": loss_all_t.mean(),
                  "log_pn": log_pN.mean(),
                  'test_nll' if test else 'val_nll': nll}
        return nll, log_dict

    @torch.no_grad()
    def sample_batch(self, n_nodes: list, number_chain_steps: int = 50, batch_id: int = 0, keep_chain: int = 0,
                     save_final: int = 0, test=True, guidance=None, z_t=None, start_t=-1):
        """
        :param batch_id: int
        :param n_nodes: list of int containing the number of nodes to sample for each graph
        :param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes
        :param save_final: int: number of predictions to save to file
        :param keep_chain: int: number of chains to save to file
        :param keep_chain_steps: number of timesteps to save for each chain
        :return: molecule_list. Each element of this list is a tuple (node_types, charges, positions)
        """
        print(f"Sampling a batch with {len(n_nodes)} graphs. Saving {save_final} visualization and {keep_chain} full chains.")
        assert keep_chain >= 0
        assert save_final >= 0
        n_nodes = torch.Tensor(n_nodes).long().to(self.device)
        if self.sampling_nT != -1:
            print(f"Replacing all n_nodes to start with {self.sampling_nT} node(s).")
            n_nodes = torch.ones_like(n_nodes).long().to(self.device)*self.sampling_nT

        batch_size = len(n_nodes)
        n_max = torch.max(n_nodes).item()
        # Build the masks
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)
        # Sample noise  -- z has size (n_samples, n_nodes, n_features)
        if z_t is None: z_T = self.noise_model.sample_limit_dist(node_mask=node_mask)
        else:          
            z_T = z_t
            # start_guidance_val = z_t.guidance[0,:].item()
            # end_guidance_val = guidance[0,:].item()
            # print("start_guidance_val:", start_guidance_val,
            #       "end_guidance_val", end_guidance_val)
            
            # end_T = 75
            # lin_guid = np.linspace(start_guidance_val, end_guidance_val, start_t-end_T)
            # lin_guid = np.append(lin_guid, [end_guidance_val]*end_T)
            # lin_guid = np.flip(lin_guid)

        if(guidance == None):
            guidance = self.cf_null_token.repeat((batch_size, 1))
        z_T.guidance = guidance

        print("number_chain_steps: ", number_chain_steps)
        print("starting n_nodes: ", n_nodes)

        assert (z_T.E == torch.transpose(z_T.E, 1, 2)).all()
        assert number_chain_steps < self.T

        use_charges = self.use_charges
        use_3d      = self.use_3d

        if self.use_ins_del: n_max = self.max_n
        else:                n_max = z_T.X.size(1)
        if(use_charges): charges=torch.zeros((number_chain_steps, keep_chain, n_max), dtype=torch.long)
        else: charges = None

        if(use_3d): pos=torch.zeros((number_chain_steps, keep_chain, n_max, 3))
        else: pos = None

        chains = None
        if keep_chain > 0:
            chains = utils.PlaceHolder(X=torch.zeros((number_chain_steps, keep_chain, n_max), dtype=torch.long),
                                E=torch.zeros((number_chain_steps, keep_chain, n_max, n_max)),
                                charges=charges,
                                pos=pos,
                                y=None,
                                guidance=guidance)
        
        z_t = z_T
        if start_t < 0:
            start_t = self.T

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s_int in reversed(range(0, start_t, 1 if test else self.cfg.general.faster_sampling)):
            s_array = s_int * torch.ones((batch_size, 1), dtype=torch.long, device=z_t.X.device)
            # torch.set_printoptions(2, 10000)
            # print("s int: ", s_int)

            # if start_t > 0:
            #     guidance = torch.full_like(z_t.guidance, fill_value = lin_guid[s_int])
            #     z_t.guidance = guidance
            #     print("guidance, ", guidance.reshape(-1))

            # print("s_int:", s_int, " z_t.X.size: ", z_t.X.size())
            z_s = self.sample_zs_from_zt(z_t=z_t, s_int=s_array)

            """
            if(z_s.delt_mask.any()):
                torch.set_printoptions(2, 10000)
                print("detected delt at s=", s_int)
                print(f"node mask: \n{z_s.node_mask.int()}")
                print(f"delt mask (compressed): \n{z_s.delt_mask.squeeze(-1).int()}")
                print(f"X_data: {z_s.X.size()}\n{torch.argmax(F.pad(z_s.X, (1,0,0,0,0,0)), dim=-1)}")
                print(f"E_data: {z_s.E.size()}\n{torch.argmax(F.pad(z_s.E, (1,0,0,0,0,0,0,0)), dim=-1)}")
                print(f"c_data: {z_s.charges.size()}\n{torch.argmax(F.pad(z_s.charges, (1,0,0,0,0,0)), dim=-1)}")
            """
            # Save the first keep_chain graphs
            if (s_int * number_chain_steps) % start_t == 0 and keep_chain > 0:
                write_index = number_chain_steps - 1 - ((s_int * number_chain_steps) // start_t)
                print("chain saver. write index: ", write_index)
                discrete_z_s = z_s.collapse(self.dataset_infos.collapse_charges, self.use_charges)

                X_data = discrete_z_s.X[:keep_chain]
                E_data = discrete_z_s.E[:keep_chain]

                pad_size = n_max - X_data.size(1)
                pad_condition = pad_size > 0 and self.use_ins_del
                # print("pad_size: ", pad_size, ", n_max: ", n_max, ", X_data.size() (size 1 is the one that matters): ", X_data.size())
                if pad_condition:
                    X_data = F.pad(X_data, (0, pad_size, 0, 0), value=-1)
                    E_data = F.pad(E_data, (0, pad_size, 0, pad_size, 0, 0), value=-1)

                chains.X[write_index] = X_data
                chains.E[write_index] = E_data

                if(self.use_charges):
                    c_data = discrete_z_s.charges[:keep_chain]
                    if pad_condition:
                        c_data = F.pad(c_data, (0, pad_size, 0, 0), value=-1)
                    chains.charges[write_index] = c_data
                if(self.use_3d):
                    pos_data = discrete_z_s.pos[:keep_chain]
                    if pad_condition:
                        pos_data = F.pad(pos_data, (0, pad_size, 0, 0), value=-1)
                    chains.pos[write_index] = pos_data

            z_t = z_s
            
            """
            z_t_n_nodes = z_t.node_mask.sum(-1).squeeze(-1)

            discrete_z_t = z_t.collapse(self.dataset_infos.collapse_charges, self.use_charges)
            print("z_t AFTER:\n", discrete_z_t.X)
            print("===================================")

            
            graphs_list_tmp = self.graph_list_maker(sampled=discrete_z_t, chains=None,
                                                    batch_size=z_t.X.size(0),
                                                    atom_decoder=self.dataset_infos.atom_decoder, 
                                                    keep_chain=None, n_nodes=z_t_n_nodes, 
                                                    use_charges=use_charges, use_3d=use_3d,
                                                    charges_policy=self.charges_policy)
            
            n_splitted = 0
            for i, mol in enumerate(graphs_list_tmp):
                smile = Chem.MolToSmiles(mol.rdkit_mol)
                print(i, ": ", smile)

                mol_frags = Chem.rdmolops.GetMolFrags(mol.rdkit_mol, asMols=True, sanitizeFrags=False)
                if(len(mol_frags) > 1):
                    n_splitted += 1

            self.split_check[s_int] += n_splitted
            print("===================================")
            """

        # Sample final data
        sampled = z_t.collapse(collapse_charges=self.dataset_infos.collapse_charges, use_charges=use_charges)
        
        final_n_nodes = z_t.node_mask.sum(dim=-1).unsqueeze(-1)
        # print("z_t.node_mask:", z_t.node_mask.int())
        print("final_n_nodes:", final_n_nodes.squeeze(-1))

        graphs_list = self.graph_list_maker(sampled=sampled, chains=chains,
                                            batch_size=batch_size,
                                            atom_decoder=self.dataset_infos.atom_decoder, 
                                            keep_chain=keep_chain, n_nodes=final_n_nodes, 
                                            use_charges=use_charges, use_3d=use_3d,
                                            charges_policy=self.charges_policy)
        
        if(self.visualizer != None):
            # Visualize chains
            if keep_chain > 0:
                self.print('Batch sampled. Visualizing chains starts!')
                chains_path = os.path.join(os.getcwd(), f'chains/epoch{self.current_epoch}/',
                                        f'batch{batch_id}_GR{self.global_rank}')
                os.makedirs(chains_path, exist_ok=True)

                self.visualizer.visualize_chains(chains_path, chains,
                                            num_nodes=final_n_nodes[:keep_chain],
                                            atom_decoder=self.dataset_infos.atom_decoder,
                                            use_ins_del=self.use_ins_del)
                
            if save_final > 0:
                self.print(f'Visualizing {save_final} individual molecules...')

            # Visualize the final molecules
            current_path = os.getcwd()
            result_path = os.path.join(current_path, f'graphs/epoch{self.current_epoch}_b{batch_id}/')
            _ = self.visualizer.visualize(result_path, graphs_list, num_molecules_to_visualize=save_final)
            self.print("Visualizing done.")
        return graphs_list

    def sample_zs_from_zt(self, z_t, s_int):
        """Sample z_s ~ p(z_s | z_t), i.e. one reverse step. Only used during sampling.

        When ``self.use_ins_del`` is set, the step has three phases:

        1. optionally insert new nodes into ``z_t``, their number being sampled
           from the output of the auxiliary ``self.delt_model``;
        2. run the main model on the (possibly grown) ``z_t``;
        3. sample an insertion time for every node and, when ``zeta_s > 0``,
           remove the nodes whose sampled time equals the current step.

        Otherwise only the main model is run.

        Args:
            z_t: current noisy batch. It is modified in place (node insertion,
                ``insert_time`` update, node removal).
            s_int: target timestep, read as ``s_int[0, 0]``; all entries are
                taken to be equal (see the comment in the body).

        Returns:
            The result of ``self.noise_model.sample_zs_from_zt_and_pred`` for the
            updated ``z_t``, the model prediction and ``s_int``. Only ``z_s`` is
            returned; no separate graph prediction is produced here.
        """
        if self.use_ins_del:
            # Lazily load the auxiliary model the first time it is needed.
            if not self.delt_loaded:
                model_kwargs  = {'cfg': self.cfg, 'dataset_infos': self.dataset_infos, 'losses': self.losses,
                                 'noise_model': self.noise_model, 'visualizer': self.visualizer}
                # Load a full DiGress module from the checkpoint ...
                self.delt_model = DiGress.load_from_checkpoint(self.delt_model_path, **model_kwargs)
                self.delt_model.cfg.general.gpus = self.cfg.general.gpus

                # ... then keep only its `delt_model` attribute, moved to the first configured GPU.
                self.delt_model = self.delt_model.delt_model
                self.delt_model = self.delt_model.to(self.cfg.general.gpus[0])
                self.delt_loaded = True

            # (debug) discrete_z_t = z_t.collapse(self.dataset_infos.collapse_charges, self.use_charges)
            # (debug) print("z_t before ins:\n", discrete_z_t.X)
            # Current timestep. s_int is an array where each element is the same,
            # so we can just extract the first element.
            t_int = s_int[0,0] + 1

            zeta_t = self.noise_model.get_zeta_d(t_int=t_int) 

            # Number of active nodes per batch element, taken from the node mask.
            batch_n_nodes = z_t.node_mask.sum(dim=-1)
            # A DELt can be inserted if and only if it is possible that a node can go
            # from t_int - 1 to t_int and turn into a DELt, i.e. zeta_t > 0.
            if zeta_t > 0 and not self.freeze_n_nodes_at_sampling:
                extra_data = self.extra_features(z_t)
                pred_delt = self.forward(z_t, extra_data, self.delt_model)
                # Draw one category index per row from softmax(pred_delt.y); the index is
                # used below as the number of nodes to insert for that batch element.
                n_delt = torch.multinomial(F.softmax(pred_delt.y, dim=-1), 
                                        1, replacement=True).reshape(-1)

                # We avoid overcrowding by putting a limit on how many nodes
                # a batch sample can have: where the new size would exceed self.max_n,
                # insert only as many nodes as still fit (never a negative amount).
                batch_new_n_nodes = batch_n_nodes + n_delt
                n_delt = torch.where(batch_new_n_nodes > self.max_n,
                                    torch.clip(self.max_n - batch_n_nodes, min=0),
                                    n_delt)

                # Only touch z_t if at least one element in the batch needs to insert a node.
                if n_delt.any():
                    z_t.insert_delt(n_delt)
 
            # (debug) discrete_z_t = z_t.collapse(self.dataset_infos.collapse_charges, self.use_charges)
            # (debug) print("z_t after ins:\n", discrete_z_t.X)

        # Extra features must be recomputed here because z_t may have grown above.
        extra_data = self.extra_features(z_t)
        pred = self.forward(z_t, extra_data, self.model)

        # Delete elements predicted to have been inserted at step s_int
        if self.use_ins_del:
            # We don't need it as a tensor anymore
            t_int = t_int.item()
            
            zeta_s = self.noise_model.get_zeta_d(t_int=s_int[0,0])
            # Removing a node requires that its activation time is == t_int.
            # For that, a node must have been inserted at step s_int, which
            # is possible only if zeta_s > 0.
                
            # When the node count is frozen, or when no insertion is possible at s_int and
            # t_int has not passed train_loss.right_interval_start, every node gets
            # activation time 0 instead of a sampled one.
            if self.freeze_n_nodes_at_sampling or \
                (zeta_s <= 0 and t_int <= self.train_loss.right_interval_start):
                predicted_activation_t = torch.zeros_like(z_t.insert_time, device = z_t.X.device)
            else:
                # Merge the first two dims so that each row holds the insertion-time
                # logits of one node; sz1/sz2 are kept to restore the layout afterwards.
                # NOTE(review): flatten may return a view, in which case the in-place masking
                # below also writes into pred.insert_time; that tensor is overwritten further
                # down anyway.
                ins_times = pred.insert_time.flatten(0, 1)
                sz1, sz2  = pred.insert_time.size(0), pred.insert_time.size(1)
                # Forbid insertion times later than the current step.
                # NOTE(review): the trailing note on the next line reads oddly for an active
                # line -- confirm whether it was meant as "do not deactivate".
                ins_times[:, (t_int + 1):]   = -float("inf") # NOTE: DO NOT reactivate this line 
                # Forbid times 1 .. left_interval_end - 1 (index 0 stays allowed).
                ins_times[:, 1:self.train_loss.left_interval_end] = -float("inf")
                ins_times = F.softmax(ins_times, dim=-1)
                # Disabled alternative: fold the probability mass of times >= t_int into t_int.
                # if t_int < self.T - 2:
                #     ins_times_residual = ins_times[:, t_int:].sum(dim=-1)
                #     ins_times[:, t_int+1:] = 0
                #     ins_times[:, t_int] += ins_times_residual
                # One sampled time index per node, reshaped back to (sz1, sz2, 1).
                predicted_activation_t = torch.multinomial(ins_times, num_samples=1)
                predicted_activation_t = predicted_activation_t.reshape((sz1, sz2, 1))

            # Updates the activation time on both the state and the prediction
            z_t.insert_time = predicted_activation_t.int()
            pred.insert_time = predicted_activation_t.int() 

            if zeta_s > 0:
                # If the activation time is the current time step, we can proceed and remove
                # those nodes. NOTE: no need to check if t_int == 0 since it can never be equal
                # to 0, as that would require s_int = -1 and that's not allowed by the code.
                del_mask = predicted_activation_t == t_int

                # Remove the same nodes from state and prediction.
                z_t.remove_elements(del_mask.squeeze(-1))
                pred.remove_elements(del_mask.squeeze(-1))

                # TODO: leave at least one node for each graph in case they do
                # all get removed.

        # Final denoising step from the updated z_t and prediction.
        z_s = self.noise_model.sample_zs_from_zt_and_pred(z_t=z_t, pred=pred, s_int=s_int)
        return z_s

    def sample_n_graphs(self, samples_to_generate: int, chains_to_save: int, samples_to_save: int, test: bool, guidance=None, z_t=None):
        """Generate graphs by repeatedly calling ``self.sample_batch``.

        Sampling happens in up to two phases:

        * a first, unsorted phase of ``min(samples_to_generate,
          max(samples_to_save, chains_to_save))`` graphs whose batches are
          saved/visualized (skipped entirely when ``guidance`` is given);
        * a second phase for the remaining graphs, with sizes sorted in
          decreasing order when ``cfg.dataset.adaptive_loader`` is enabled.

        In both phases graph sizes are drawn from ``self.node_dist`` and packed
        greedily into batches until the batch capacity is reached.

        Args:
            samples_to_generate: total number of graphs; ``<= 0`` returns ``[]``.
            chains_to_save: total number of chains to keep across the first phase.
            samples_to_save: number of samples to save; only used to size the
                first phase.
            test: forwarded to ``sample_batch``.
            guidance: forwarded to ``sample_batch``; when not ``None`` all sizes of
                the second phase end up in a single batch.
            z_t: optional starting state. When given, one ``sample_batch`` call is
                made from ``z_t`` at ``self.corruption_step`` and its result is
                returned directly.

        Returns:
            list: concatenation of everything returned by ``sample_batch``.
        """
        if samples_to_generate <= 0:
            return []

        generated = []
        chains_budget = chains_to_save
        guided = guidance is not None

        # The first graphs are sampled without sorting the sizes, so that the
        # visualizations are not biased. Guided sampling skips this phase.
        n_unsorted = 0 if guided else min(samples_to_generate, max(samples_to_save, chains_to_save))

        if z_t is not None:
            # Continue from a given state: sizes come from its node mask.
            sizes = z_t.node_mask.sum(dim=-1).tolist()
            generated.extend(self.sample_batch(n_nodes=sizes, batch_id=0, save_final=0, keep_chain=0,
                                               number_chain_steps=0, test=test, guidance=guidance,
                                               z_t=z_t, start_t=self.corruption_step))
            return generated

        if n_unsorted > 0:
            pending = []
            pending_max = 0
            for idx, size in enumerate(self.node_dist.sample_n(n_unsorted, self.device)):
                candidate_max = max(pending_max, size)
                if self.cfg.dataset.adaptive_loader:
                    capacity = effective_batch_size(candidate_max, self.cfg.train.reference_batch_size,
                                                    sampling=True)
                else:
                    # No need to make a backward pass, so a larger batch fits.
                    capacity = int(1.8 * self.cfg.train.batch_size)

                if capacity > len(pending) or not pending:
                    pending.append(size)
                    pending_max = candidate_max
                    continue

                # Batch is full: sample it, then start a new one with this size.
                n_chains = max(min(chains_budget, len(pending)), 0)
                generated.extend(self.sample_batch(n_nodes=pending, batch_id=idx,
                                                   save_final=len(pending), keep_chain=n_chains,
                                                   number_chain_steps=self.number_chain_steps,
                                                   test=test, guidance=guidance))
                chains_budget -= n_chains
                pending = [size]
                pending_max = size

            # Flush the last (possibly partial) batch of the first phase.
            n_chains = max(min(chains_budget, len(pending)), 0)
            generated.extend(self.sample_batch(n_nodes=pending, batch_id=idx + 1,
                                               save_final=len(pending), keep_chain=n_chains,
                                               number_chain_steps=self.number_chain_steps,
                                               test=test, guidance=guidance))
            if samples_to_generate - n_unsorted <= 0:
                return generated

        # The remaining graphs are sampled in decreasing graph size
        remaining = self.node_dist.sample_n(samples_to_generate - n_unsorted, self.device)
        if self.cfg.dataset.adaptive_loader:
            remaining = torch.sort(remaining, descending=True)[0]

        pending = []
        pending_max = 0
        for size in remaining:
            pending_max = max(pending_max, size)
            if self.cfg.dataset.adaptive_loader:
                capacity = effective_batch_size(pending_max, self.cfg.train.reference_batch_size, sampling=True)
            else:
                capacity = 1.8 * self.cfg.train.batch_size

            if capacity > len(pending) or len(pending) == 0 or guided:
                pending.append(size)
            else:
                generated.extend(self.sample_batch(n_nodes=pending, test=test, guidance=guidance))
                pending = [size]
                pending_max = size
        generated.extend(self.sample_batch(n_nodes=pending, test=test, guidance=guidance))

        return generated

    @property
    def BS(self):
        """Training batch size, read from ``cfg.train.batch_size``."""
        train_cfg = self.cfg.train
        return train_cfg.batch_size

    def forward(self, z_t, extra_data, model_ref, train_step=False):
        """Assemble the network input from ``z_t`` and ``extra_data`` and run ``model_ref``.

        Node features are concatenated along dim 2, edge features along dim 3,
        and the global features ``(z_t.y, extra_data.y, z_t.t)`` are stacked
        horizontally; all three are cast to float. The concatenated tensors are
        written onto a copy of ``z_t``, which is then passed to ``model_ref``.

        Args:
            z_t: noisy batch; must carry a non-``None`` ``node_mask``.
            extra_data: extra features with ``X``, ``E`` and ``y`` attributes.
            model_ref: callable applied to the assembled input.
            train_step: accepted for interface compatibility; not used here.

        Returns:
            The output of ``model_ref``.
        """
        assert z_t.node_mask is not None

        node_feats = torch.cat((z_t.X, extra_data.X), dim=2)
        edge_feats = torch.cat((z_t.E, extra_data.E), dim=3)
        global_feats = torch.hstack((z_t.y, extra_data.y, z_t.t))

        net_input = z_t.copy()
        net_input.X = node_feats.float()
        net_input.E = edge_feats.float()
        net_input.y = global_feats.float()
        return model_ref(net_input)

    def on_train_epoch_end(self) -> None:
        """Print and log the metrics accumulated over the training epoch.

        Prints the per-epoch losses returned by ``self.train_loss`` (including
        the optional charges / position / insertion-deletion terms when those
        features are enabled) and the time elapsed since
        ``self.start_epoch_time``. Both the loss dictionary and, if available,
        the train-metrics dictionary are forwarded to ``self.log_dict``; the
        epoch index is also sent to wandb when a run is active.
        """
        self.print(f"Train epoch {self.current_epoch} ends")
        epoch_losses = self.train_loss.log_epoch_metrics()

        # Optional fragments stay empty when the corresponding feature is off.
        charges_str = ""
        if self.use_charges:
            charges_str = f" charges: {epoch_losses['train_epoch/charges_CE']:.2f} --"

        pos_str = ""
        if self.use_3d:
            pos_str = f"pos: {epoch_losses['train_epoch/pos_mse']:.2f} -- "

        ins_del_str = ""
        if self.use_ins_del:
            ins_del_str = (f"s: {epoch_losses['train_epoch/s_CE']:.2f} -- "
                           f"delt: {epoch_losses['train_epoch/delt_CE']:.2f} -- ")

        summary_parts = (
            f"Epoch {self.current_epoch} finished: ",
            f"X: {epoch_losses['train_epoch/x_CE']:.2f} --",
            f" E: {epoch_losses['train_epoch/E_CE']:.2f} --",
            f" y: {epoch_losses['train_epoch/y_CE']:.2f} --",
            charges_str,
            pos_str,
            ins_del_str,
        )
        self.print(*summary_parts)

        elapsed = time.time() - self.start_epoch_time
        self.print(f" {elapsed:.1f}s ")

        self.log_dict(epoch_losses, batch_size=self.BS)
        # Metrics are requested on every rank; only non-None results get logged.
        metrics_log = self.train_metrics.log_epoch_metrics(self.current_epoch, self.local_rank)
        if metrics_log is not None:
            self.log_dict(metrics_log, batch_size=self.BS)
        if wandb.run:
            wandb.log({"epoch": self.current_epoch}, commit=False)

    def on_train_epoch_start(self) -> None:
        """Announce the epoch, start its timer and reset the training accumulators."""
        self.print("Starting epoch", self.current_epoch)
        # Reference point for the elapsed-time message printed at epoch end.
        self.start_epoch_time = time.time()
        for tracker in (self.train_loss, self.train_metrics):
            tracker.reset()

    def on_fit_start(self) -> None:
        """Set the iteration count placeholder and, on local rank 0, set up wandb."""
        # TODO: fix -- previously was len(self.trainer.datamodule.train_dataloader())
        self.train_iterations = 100

        is_main_process = self.local_rank == 0
        if not is_main_process:
            return
        utils.setup_wandb(self.cfg)

    def configure_optimizers(self):
        """Create the AdamW optimizer (with ``amsgrad=True``) over ``self.parameters()``.

        Learning rate and weight decay are read from ``cfg.train``.
        """
        params = self.parameters()
        train_cfg = self.cfg.train
        optimizer_kwargs = {
            "lr": train_cfg.lr,
            "amsgrad": True,
            "weight_decay": train_cfg.weight_decay,
        }
        return torch.optim.AdamW(params, **optimizer_kwargs)
