from typing import Any
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import time
import wandb
import os

from models.transformer_model import GraphTransformer
from metrics.train_metrics import TrainLossBFN
from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
from src import utils
from src.py_utils import ctxmgr

## For Debug
import pdb
import datetime, pytz
import numpy as np
from tqdm import tqdm
from torch_scatter import scatter_mean
import torch.distributions as dist
from pytorch_lightning.utilities import rank_zero_only
from src.diffusion import diffusion_utils


class BFN_Discrete(pl.LightningModule):
    def __init__(
        self,
        cfg,
        dataset_infos,
        train_metrics,
        sampling_metrics,
        visualization_tools,
        extra_features,
        domain_features,
    ):
        """Set up the denoising network, the loss/metric objects and the prior.

        Args:
            cfg: Experiment configuration. Entries under ``cfg.model`` and
                ``cfg.general`` plus ``cfg.exp_name`` are read here, and
                ``cfg.todict()`` is saved as the module's hyper-parameters.
            dataset_infos: Dataset description exposing ``input_dims`` and
                ``output_dims`` (indexed with "X", "E", "y") and
                ``nodes_dist``; ``node_types`` / ``edge_types`` are also read
                when ``cfg.model.transition == "marginal"``.
            train_metrics: Stored unchanged as ``self.train_metrics``.
            sampling_metrics: Stored unchanged as ``self.sampling_metrics``.
            visualization_tools: Stored unchanged.
            extra_features: Stored unchanged.
            domain_features: Stored unchanged.

        Raises:
            NotImplementedError: If ``cfg.model.transition`` is neither
                ``"uniform"`` nor ``"marginal"``.
        """
        super().__init__()

        input_dims = dataset_infos.input_dims
        output_dims = dataset_infos.output_dims
        # Distribution used to draw the number of nodes of generated graphs.
        nodes_dist = dataset_infos.nodes_dist

        self.train_losses = []

        self.cfg = cfg
        self.name = cfg.exp_name
        self.model_dtype = torch.float32
        self.sample_steps = cfg.model.sample_steps

        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist

        self.num_nodes = None

        self.dataset_info = dataset_infos

        # Per-component loss weights, in (node, edge, y) order.
        self.lambda_list = [
            self.cfg.model.lambda_train_node,
            self.cfg.model.lambda_train_edge,
            self.cfg.model.lambda_train_y,
        ]

        self.visualization_tools = visualization_tools
        self.extra_features = extra_features
        self.domain_features = domain_features

        # How to sample extra features from noisy graphs
        self.extra_mode = self.cfg.model.extra_mode
        self.n_iid = self.cfg.model.n_iid

        # ---- BFN schedule settings ---------------------------------------
        self.beta_node = cfg.model.beta_node
        self.beta_edge = cfg.model.beta_edge

        self.beta_node_init = cfg.model.beta_node_init
        self.beta_edge_init = cfg.model.beta_edge_init

        self.t_min = cfg.model.t_min

        self.node_time_scheduler = cfg.model.node_time_scheduler
        self.edge_time_scheduler = cfg.model.edge_time_scheduler
        self.with_discretised_time = cfg.model.discretised_time

        self.train_loss = TrainLossBFN(
            self.lambda_list,
            sample_steps=self.sample_steps,
            continuous_t=(not self.with_discretised_time),
        )

        # ---- D3PM loss and metrics ---------------------------------------
        self.val_nll = NLL()
        self.val_X_kl = SumExceptBatchKL()
        self.val_E_kl = SumExceptBatchKL()
        self.val_X_logp = SumExceptBatchMetric()
        self.val_E_logp = SumExceptBatchMetric()

        self.test_nll = NLL()
        self.test_X_kl = SumExceptBatchKL()
        self.test_E_kl = SumExceptBatchKL()
        self.test_X_logp = SumExceptBatchMetric()
        self.test_E_logp = SumExceptBatchMetric()

        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

        # ---- Denoising network -------------------------------------------
        self.model = GraphTransformer(
            n_layers=cfg.model.n_layers,
            input_dims=input_dims,
            hidden_mlp_dims=cfg.model.hidden_mlp_dims,
            hidden_dims=cfg.model.hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
        )

        # ---- Prior (limit) distribution over node / edge / y classes -----
        if cfg.model.transition == "uniform":
            x_limit = torch.ones(self.Xdim_output) / self.Xdim_output
            e_limit = torch.ones(self.Edim_output) / self.Edim_output
            y_limit = torch.ones(self.ydim_output) / self.ydim_output
            self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)
        elif cfg.model.transition == "marginal":
            # Normalise the dataset class counts into probabilities.
            node_types = self.dataset_info.node_types.float()
            x_marginals = node_types / torch.sum(node_types)

            edge_types = self.dataset_info.edge_types.float()
            e_marginals = edge_types / torch.sum(edge_types)
            print(
                f"Marginal distribution of the classes: {x_marginals} for nodes, {e_marginals} for edges"
            )

            self.limit_dist = utils.PlaceHolder(
                X=x_marginals,
                E=e_marginals,
                y=torch.ones(self.ydim_output) / self.ydim_output,
            )
        else:
            raise NotImplementedError(
                f"Unknow transition model (prior distribution) {cfg.model.transition}"
            )

        self.print_logit = cfg.general.print_logit

        # Must stay in __init__ itself: it is called directly from the
        # constructor in the original code as well.
        self.save_hyperparameters(cfg.todict())
        self.start_epoch_time = None
        self.train_iterations = None
        self.val_iterations = None
        self.log_every_steps = cfg.general.log_every_steps
        self.number_chain_steps = cfg.general.number_chain_steps
        self.best_val_nll = 1e8
        self.val_counter = 0

    def discrete_var_bayesian_flow(
        self, t, beta1, x, K, time_scheduler, beta_init=None
    ):
        """
        Compute the Baysian Flow distribution (P_F) at time t
        x: (bs, N, K)
        t: (bs, 1)
        beta1: scalor
        """
        theta_0 = None
        if len(x.shape) == 3:
            # node
            if time_scheduler == "quad":
                beta = beta1 * (t[:, None] ** 2).repeat(1, x.shape[1], 1)  # (bs, N, 1)
            elif time_scheduler == "linear":
                beta = beta1 * t[:, None].repeat(1, x.shape[1], 1)
            elif time_scheduler == "third":
                beta = beta1 * (t[:, None] ** 3).repeat(1, x.shape[1], 1)
            # elif self.time_scheduler == "discrete":
            elif time_scheduler == "hybrid":  # essentially hybrid of linear and quad.
                assert beta_init is not None
                beta = (beta1 - beta_init) * (t[:, None] ** 2).repeat(
                    1, x.shape[1], 1
                ) + beta_init * t[:, None].repeat(1, x.shape[1], 1)
            else:
                raise NotImplementedError
            theta_0 = self.limit_dist.type_as(x).X
        else:
            # edge
            if time_scheduler == "quad":
                beta = beta1 * (t[:, None, None] ** 2).repeat(
                    1, x.shape[1], x.shape[2], 1
                )
            elif time_scheduler == "linear":
                beta = beta1 * t[:, None, None].repeat(1, x.shape[1], x.shape[2], 1)
            elif time_scheduler == "third":
                beta = beta1 * (t[:, None, None] ** 3).repeat(
                    1, x.shape[1], x.shape[2], 1
                )
            elif (
                time_scheduler == "hybrid"
            ):  # 2beta1 - beta_init is also we need for the last step accuracy.
                assert beta_init is not None
                beta = (beta1 - beta_init) * (t[:, None, None] ** 2).repeat(
                    1, x.shape[1], x.shape[2], 1
                ) + beta_init * t[:, None, None].repeat(1, x.shape[1], x.shape[2], 1)
            else:
                raise NotImplementedError
            theta_0 = self.limit_dist.type_as(x).E
        # Flatten beta's preceeding dimensions to handle the different shapes of node and edge tensors

        # if self.time_schedule == "linear":

        beta = beta.reshape(-1, 1)

        # Sample y from N(beta_t*(K*e_x - 1), beta_t*K)
        one_hot_x = x.reshape(-1, K)  # (B*N, K)
        mean = beta * (K * one_hot_x - 1)
        std = (beta * K).sqrt()
        eps = torch.randn_like(mean)
        y = mean + std * eps
        # theta = F.softmax(y, dim=-1)
        theta_unnormalized = torch.exp(y) * theta_0
        theta = theta_unnormalized / theta_unnormalized.sum(dim=-1, keepdim=True)

        # Recover node and edge tensors' orignal shape
        if len(x.shape) == 3:
            # for node features
            theta = theta.reshape(x.shape[0], x.shape[1], K)
        elif len(x.shape) == 4:
            # for edge features
            theta = theta.reshape(x.shape[0], x.shape[1], x.shape[2], K)
            if self.cfg.force_symmetric_theta_E:
                # BUG: we need to Force theta_E to be symmetric
                theta = (theta + theta.transpose(1, 2)) / 2
        return theta

    def training_step(self, data, curr_step):
        """Run one BFN training step on a sparse graph batch.

        The batch is densified, a time ``t`` is drawn per graph, node and edge
        parameters ``theta`` are sampled from the Bayesian flow at ``t``, the
        network predicts class probabilities for nodes and edges, and the
        prediction is scored by ``self.train_loss`` with schedule-dependent
        weights.

        Args:
            data: batch exposing ``x``, ``edge_index``, ``edge_attr``,
                ``batch`` and ``y``.
            curr_step: batch index; the loss object is asked to log when it
                is a multiple of ``self.log_every_steps``.

        Returns:
            ``{"loss": loss}``, or ``None`` when the batch has no edges or the
            loss is not finite.

        Raises:
            NotImplementedError: for an unknown node or edge time scheduler
                (continuous-time mode only).
        """
        if data.edge_index.numel() == 0:
            self.print("Found a batch with no edges. Skipping.")
            return None

        dense_data, node_mask = utils.to_dense(
            data.x, data.edge_index, data.edge_attr, data.batch
        )
        dense_data = dense_data.mask(node_mask)
        X, E = dense_data.X, dense_data.E

        if self.cfg.overfit_batches == 1:
            # Overfitting a single batch: rebuild the graph-size distribution
            # from this batch (graph size -> number of graphs of that size).
            from collections import Counter

            from src.diffusion.distributions import DistributionNodes

            _, nodes_per_graph = torch.unique(data.batch, return_counts=True)
            # Passed as a plain dict, exactly like the hand-built histogram
            # this replaces.
            size_histogram = dict(Counter(nodes_per_graph.tolist()))
            self.node_dist = DistributionNodes(size_histogram)

        N = self.sample_steps
        i = None
        if self.with_discretised_time:
            # Discrete time: step index i in {1..N}, t is the step's start time.
            i = torch.randint(
                1, N + 1, size=(X.size(0), 1), device=X.device
            ).float()  # (bs, 1)
            t = (i - 1) / N
        else:
            t = torch.rand(size=(X.size(0), 1), device=X.device).float()  # (bs, 1)

        theta_X_t = self.discrete_var_bayesian_flow(
            t,
            self.beta_node,
            X,
            self.Xdim_output,
            time_scheduler=self.node_time_scheduler,
            beta_init=self.beta_node_init,
        )
        theta_E_t = self.discrete_var_bayesian_flow(
            t,
            self.beta_edge,
            E,
            self.Edim_output,
            time_scheduler=self.edge_time_scheduler,
            beta_init=self.beta_edge_init,
        )

        # BFN paper, page 27: input probabilities (theta) are rescaled to
        # [-1, 1] before being fed to the network.
        rescaled_theta_X_t = theta_X_t * 2 - 1
        rescaled_theta_E_t = theta_E_t * 2 - 1

        rescaled_masked_theta_t = (
            utils.PlaceHolder(X=rescaled_theta_X_t, E=rescaled_theta_E_t, y=data.y)
            .type_as(rescaled_theta_X_t)
            .mask(node_mask)
        )

        model_inputs = {
            "t": t,
            "X_t": rescaled_masked_theta_t.X,
            "E_t": rescaled_masked_theta_t.E,
            "y_t": rescaled_masked_theta_t.y,
            "node_mask": node_mask,
        }

        def _predict(extra):
            # One network pass followed by a softmax over node / edge classes.
            out = self.forward(model_inputs, extra, node_mask)
            out.X = F.softmax(out.X, dim=-1)
            out.E = F.softmax(out.E, dim=-1)
            return out

        # Extra features are first computed from the (unscaled) theta.
        extra_data = self.compute_extra_data_bfn(
            {
                "t": t,
                "X_t": theta_X_t,
                "E_t": theta_E_t,
                "y_t": data.y,
                "node_mask": node_mask,
            }
        )

        if self.cfg.model.output_dist_extra:
            # Two-pass mode: a gradient-free first pass produces an output
            # distribution, and the extra features are recomputed from it.
            with torch.no_grad():
                draft = _predict(extra_data)
                extra_data = self.compute_extra_data_bfn(
                    {
                        "t": t,
                        "X_t": draft.X,
                        "E_t": draft.E,
                        "y_t": data.y,
                        "node_mask": node_mask,
                    }
                )
        pred = _predict(extra_data)

        def _continuous_weight(scheduler, num_classes, beta1, beta_init, t_b):
            # Continuous-time loss weight, intended to be K * alpha(t) / 2
            # with alpha(t) = d beta / dt (e.g. K * beta1 * t for "quad").
            if scheduler == "quad":
                return num_classes * beta1 * t_b
            if scheduler == "linear":
                return num_classes * (beta1 / 2) * torch.ones_like(t_b)
            if scheduler == "third":
                return num_classes * (beta1 * 3 / 2) * (t_b**2)
            if scheduler == "hybrid":
                assert beta_init is not None
                # NOTE(review): this equals K * alpha(t), i.e. twice the
                # K * alpha(t) / 2 used by the other three schedules. Kept
                # unchanged; confirm whether the missing 1/2 is intended.
                return num_classes * (
                    2 * (beta1 - beta_init) * t_b
                    + beta_init * torch.ones_like(t_b)
                )
            raise NotImplementedError

        if self.with_discretised_time:
            # Discrete time: weight is alpha_i = beta1 * (2i - 1) / N^2.
            # NOTE(review): this is the "quad" accuracy increment and is used
            # whatever the configured scheduler is; confirm that is intended.
            weight_x = self.beta_node * (2 * i[:, None] - 1) / (N**2)
            weight_e = self.beta_edge * (2 * i[:, None, None] - 1) / (N**2)
        else:
            weight_x = _continuous_weight(
                self.node_time_scheduler,
                self.Xdim_output,
                self.beta_node,
                self.beta_node_init,
                t[:, None],
            )
            # The edge schedule validates the *edge* beta_init (the original
            # checked the node one here).
            weight_e = _continuous_weight(
                self.edge_time_scheduler,
                self.Edim_output,
                self.beta_edge,
                self.beta_edge_init,
                t[:, None, None],
            )

        # Tile the per-graph weights over every node / edge position.
        weight_x = weight_x.repeat(1, pred.X.shape[1], 1)
        weight_e = weight_e.repeat(1, pred.E.shape[1], pred.E.shape[2], 1)

        loss, to_log = self.train_loss(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            pred_y=pred.y,
            true_X=X,
            true_E=E,
            true_y=data.y,
            weight_x=weight_x,
            weight_e=weight_e,
            log=curr_step % self.log_every_steps == 0,
        )

        self.log_dict(
            to_log,
            on_step=True,
            prog_bar=True,
            batch_size=self.cfg.train.batch_size,
        )
        # Skip the update when the loss is not finite.
        if not torch.isfinite(loss):
            return None
        self.train_losses.append(loss.clone().detach().cpu())

        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.cfg.train.lr,
            amsgrad=True,
            weight_decay=self.cfg.train.weight_decay,
        )

    def on_fit_start(self) -> None:
        self.train_iterations = len(self.trainer.datamodule.train_dataloader())
        self.print("Size of the input features", self.Xdim, self.Edim, self.ydim)

    def on_train_epoch_start(self) -> None:
        """Announce the epoch, start its wall-clock timer and reset the loss."""
        self.print("Starting train epoch...")
        # Read back when the epoch ends to report its duration.
        self.start_epoch_time = time.time()
        epoch_loss = self.train_loss
        epoch_loss.reset()

    def on_train_epoch_end(self) -> None:
        ## Log epoch_loss
        to_log = self.train_loss.log_epoch_metrics()
        self.print(
            f"Epoch {self.current_epoch}: X_MSE: {to_log['train_epoch/X_MSE'] :.3f}"
            f" -- E_MSE: {to_log['train_epoch/E_MSE'] :.3f} --"
            f" y_CE: {to_log['train_epoch/y_CE'] :.3f}"
            f" -- {time.time() - self.start_epoch_time:.1f}s "
        )

        print(torch.cuda.memory_summary())
        self.log_dict(to_log, batch_size=self.cfg.train.batch_size)

    def on_validation_epoch_start(self) -> None:
        self.sampling_metrics.reset()

    def validation_step(self, data, i):
        pass

    def on_validation_epoch_end(self) -> None:
        # Skip the validation at epoch 0
        if self.current_epoch == 0:
            return
        self.print(f"validating")

        # self.val_counter += 1
        # if self.val_counter % self.cfg.general.sample_every_val == 0:
        start = time.time()
        samples_left_to_generate = self.cfg.general.samples_to_generate
        samples_left_to_save = self.cfg.general.samples_to_save
        chains_left_to_save = self.cfg.general.chains_to_save

        samples = []

        ident = 0
        while samples_left_to_generate > 0:
            # bs = 2 * self.cfg.train.batch_size
            bs = self.cfg.general.sampling_bs
            to_generate = min(samples_left_to_generate, bs)
            to_save = min(samples_left_to_save, bs)
            chains_save = min(chains_left_to_save, bs)
            samples.extend(
                # Original
                self.sample_batch(
                    batch_id=ident,
                    batch_size=to_generate,
                    num_nodes=self.num_nodes,
                    save_final=to_save,
                    keep_chain=chains_save,
                    number_chain_steps=self.number_chain_steps,
                )
            )
            ident += to_generate

            samples_left_to_save -= to_save
            samples_left_to_generate -= to_generate
            chains_left_to_save -= chains_save
        self.print("Computing sampling metrics...")

        metrics = self.sampling_metrics.forward(
            samples,
            self.name,
            self.current_epoch,
            val_counter=-1,
            test=False,
            local_rank=self.local_rank,
        )
        if not isinstance(
            metrics, tuple
        ):  # handle the case where self.sampling_metrics.forward() has a single return
            self.log_dict(metrics, batch_size=self.cfg.train.batch_size)
        else:
            for metric in metrics:
                self.log_dict(metric, batch_size=self.cfg.train.batch_size)

        self.print(f"Done. Sampling took {time.time() - start:.2f} seconds\n")
        print("Validation epoch end ends...")

    def on_test_epoch_start(self) -> None:
        """Announce the start of the test epoch."""
        banner = "Starting test..."
        self.print(banner)

    def test_step(self, data, i):
        pass

    def on_test_epoch_end(self) -> None:
        """Measure likelihood on a test set and compute stability metrics."""
        self.print(f"Testing")

        samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
        samples_left_to_save = self.cfg.general.final_model_samples_to_save
        chains_left_to_save = self.cfg.general.final_model_chains_to_save

        samples = []
        id = 0
        while samples_left_to_generate > 0:
            print(
                f"Samples left to generate: {samples_left_to_generate}/"
                f"{self.cfg.general.final_model_samples_to_generate}",
            )
            # bs = 2 * self.cfg.train.batch_size
            bs = self.cfg.general.sampling_bs
            to_generate = min(samples_left_to_generate, bs)
            to_save = min(samples_left_to_save, bs)
            chains_save = min(chains_left_to_save, bs)
            samples.extend(
                self.sample_batch(
                    id,
                    to_generate,
                    num_nodes=self.num_nodes,
                    save_final=to_save,
                    keep_chain=chains_save,
                    number_chain_steps=self.number_chain_steps,
                    local_rank=self.local_rank,
                )
            )

            id += to_generate
            samples_left_to_save -= to_save
            samples_left_to_generate -= to_generate
            chains_left_to_save -= chains_save
        self.print("Saving the generated graphs")
        filename = f"generated_samples1.txt"
        for i in range(2, 10):
            if os.path.exists(filename):
                filename = f"generated_samples{i}.txt"
            else:
                break
        with open(filename, "w") as f:
            for item in samples:
                f.write(f"N={item[0].shape[0]}\n")
                atoms = item[0].tolist()
                f.write("X: \n")
                for at in atoms:
                    f.write(f"{at} ")
                f.write("\n")
                f.write("E: \n")
                for bond_list in item[1]:
                    for bond in bond_list:
                        f.write(f"{bond} ")
                    f.write("\n")
                f.write("\n")
        self.print("Generated graphs Saved. Computing sampling metrics...")

        metrics = self.sampling_metrics.forward(
            samples,
            self.name,
            self.current_epoch,
            val_counter=-1,
            test=False,
            local_rank=self.local_rank,
        )
        if not isinstance(
            metrics, tuple
        ):  # handle the case where self.sampling_metrics.forward() has a single return
            self.log_dict(metrics, batch_size=self.cfg.train.batch_size)
        else:
            for metric in metrics:
                self.log_dict(metric, batch_size=self.cfg.train.batch_size)

        self.print("Done testing.")

    def forward(self, noisy_data, extra_data, node_mask):
        X = torch.cat((noisy_data["X_t"], extra_data.X), dim=2).float()
        E = torch.cat((noisy_data["E_t"], extra_data.E), dim=3).float()
        y = torch.hstack((noisy_data["y_t"], extra_data.y)).float()
        return self.model(X, E, y, node_mask)

    @torch.no_grad()
    def sample_batch(
        self,
        batch_id: int,
        batch_size: int,
        keep_chain: int,
        number_chain_steps: int,
        save_final: int,
        num_nodes=None,
        local_rank=0,
    ):
        """
        :param batch_id: int
        :param batch_size: int
        :param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes
        :param save_final: int: number of predictions to save to file
        :param keep_chain: int: number of chains to save to file
        :param keep_chain_steps: number of timesteps to save for each chain
        :return: molecule_list. Each element of this list is a tuple (atom_types, charges, positions)
        """
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif type(num_nodes) == int:
            n_nodes = num_nodes * torch.ones(
                batch_size, device=self.device, dtype=torch.int
            )
        else:
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes
        n_max = torch.max(n_nodes).item()

        # Build the masks
        arange = (
            torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        )
        node_mask = arange < n_nodes.unsqueeze(1)

        ## Sample Prior noise  -- z has size (n_samples, n_nodes, n_features)
        theta_X = torch.ones(
            (batch_size, n_max, self.Xdim_output), device=self.device
        ) * self.limit_dist.X.to(self.device)
        theta_E = torch.ones(
            (batch_size, n_max, n_max, self.Edim_output), device=self.device
        ) * self.limit_dist.E.to(self.device)
        # For BFN, y will not be used
        y = torch.empty((batch_size, 0), device=self.device)

        assert (theta_E == torch.transpose(theta_E, 1, 2)).all()
        assert number_chain_steps < self.sample_steps
        chain_X_size = torch.Size((number_chain_steps, keep_chain, theta_X.size(1)))
        chain_E_size = torch.Size(
            (number_chain_steps, keep_chain, theta_E.size(1), theta_E.size(2))
        )

        chain_X = torch.zeros(chain_X_size)
        chain_E = torch.zeros(chain_E_size)

        if self.cfg.input_dist_sample:
            chain_input_dist_X = torch.zeros(chain_X_size)
            chain_input_dist_E = torch.zeros(chain_E_size)
        if self.cfg.plot_input_dist_entropy:
            entropy_ = []
            from src.diffusion.diffusion_utils import compute_mean_entropy

            entropy = compute_mean_entropy(theta_E)
            entropy_.append(entropy)

        ## Iterative sampling throuhg t
        for i in tqdm(range(1, self.sample_steps + 1)):  

            rescaled_theta_X_t = theta_X * 2 - 1  
            rescaled_theta_E_t = theta_E * 2 - 1

            t = torch.ones(batch_size, 1).to(self.device) * (i - 1) / self.sample_steps
            t = torch.clamp(t, min=1e-4)
            rescaled_masked_theta_t = (
                utils.PlaceHolder(X=rescaled_theta_X_t, E=rescaled_theta_E_t, y=y)
                .type_as(rescaled_theta_X_t)
                .mask(node_mask)
            )

            model_inputs = {
                "t": t,
                "X_t": rescaled_masked_theta_t.X,
                "E_t": rescaled_masked_theta_t.E,
                "y_t": rescaled_masked_theta_t.y,
                "node_mask": node_mask,
            }

            if self.cfg.model.output_dist_extra and i > 1:
                noisy_para = {
                    "t": t,
                    "X_t": pred.X,
                    "E_t": pred.E,
                    "y_t": y,
                    "node_mask": node_mask,
                }
            else:
                noisy_para = {
                    "t": t,
                    "X_t": theta_X,
                    "E_t": theta_E,
                    "y_t": y,
                    "node_mask": node_mask,
                }

            # noisy_data = self.apply_noise(X, E, data.y, node_mask)
            extra_data = self.compute_extra_data_bfn(noisy_para)
            pred = self.forward(
                model_inputs, extra_data, node_mask
            )  # interdependency modeling
            # softmax

            # 1. Clamp Logits (that are not finite real values)
            if not torch.all(pred.X.isfinite()):
                pred.X = torch.where(
                    pred.X.isfinite(), pred.X, torch.zeros_like(pred.X)
                )
                self.print("WARNING: pred.X is not finite")
            if not torch.all(pred.E.isfinite()):
                pred.E = torch.where(
                    pred.E.isfinite(), pred.E, torch.zeros_like(pred.E)
                )
                self.print("WARNING: pred.E is not finite")

            # Take softmax to get probabilities
            pred.X = torch.nn.functional.softmax(pred.X, dim=-1)
            pred.E = torch.nn.functional.softmax(pred.E, dim=-1)

            # 2. Clamp Prbabilities (that are too small)
            pred.X = torch.clamp(pred.X, min=1e-6)
            pred.E = torch.clamp(pred.E, min=1e-6)

            sample_pred_x = torch.distributions.Categorical(pred.X).sample()

            # print("sample_pred_x",sample_pred_x.shape)
            sample_pred_kx = F.one_hot(
                sample_pred_x, num_classes=self.Xdim_output
            ).float()

            sample_pred_e = torch.distributions.Categorical(pred.E).sample()
            sample_pred_e = torch.triu(sample_pred_e, diagonal=1)
            sample_pred_e = sample_pred_e + torch.transpose(sample_pred_e, 1, 2)
            assert (sample_pred_e == torch.transpose(sample_pred_e, 1, 2)).all()
            ## -------------------
            sample_pred_ke = F.one_hot(
                sample_pred_e, num_classes=self.Edim_output
            ).float()

            if ((i - 1) / self.sample_steps) >= (
                1 - self.cfg.model.alternative_sampling_theta_update_ratio
            ):
                theta_X = self.discrete_var_bayesian_flow(
                    t,
                    self.beta_node,
                    sample_pred_kx,
                    self.Xdim_output,
                    time_scheduler=self.node_time_scheduler,
                    beta_init=self.beta_node_init,
                )
                theta_E = self.discrete_var_bayesian_flow(
                    t,
                    self.beta_edge,
                    sample_pred_ke,
                    self.Edim_output,
                    time_scheduler=self.edge_time_scheduler,
                    beta_init=self.beta_edge_init,
                )
                # if force_symmetric_theta_E is enabled, theta_E is symmetric because self.discrete_var_bayesian_flow ensures it
            else:
                # Baysian Update
                if self.node_time_scheduler == "quad":
                    alpha_h_x = self.beta_node * (2 * i - 1) / (self.sample_steps**2)
                elif self.node_time_scheduler == "linear":
                    alpha_h_x = (
                        self.beta_node / self.sample_steps
                    )  # * torch.ones_like(i)
                elif self.node_time_scheduler == "third":
                    alpha_h_x = (
                        self.beta_node
                        * (i**3 - (i - 1) ** 3)
                        / (self.sample_steps**3)
                    )
                elif self.node_time_scheduler == "hybrid":
                    assert self.beta_node_init is not None
                    alpha_h_x = (self.beta_node - self.beta_node_init) * (2 * i - 1) / (
                        self.sample_steps**2
                    ) + self.beta_node_init / self.sample_steps
                else:
                    raise NotImplementedError

                mean_x = alpha_h_x * (self.Xdim_output * sample_pred_kx - 1)
                var = alpha_h_x * self.Xdim_output
                std = torch.full_like(mean_x, fill_value=var).sqrt()
                y_h = mean_x + std * torch.randn_like(sample_pred_kx)
                theta_prime = torch.exp(y_h) * theta_X  
                theta_X = theta_prime / theta_prime.sum(dim=-1, keepdim=True)


                if self.edge_time_scheduler == "quad":
                    alpha_h_e = self.beta_edge * (2 * i - 1) / (self.sample_steps**2)
                elif self.edge_time_scheduler == "linear":
                    alpha_h_e = self.beta_edge / self.sample_steps
                elif self.edge_time_scheduler == "third":
                    alpha_h_e = (
                        self.beta_edge
                        * (i**3 - (i - 1) ** 3)
                        / (self.sample_steps**3)
                    )
                elif self.edge_time_scheduler == "hybrid":
                    assert self.beta_edge_init is not None
                    alpha_h_e = (self.beta_edge - self.beta_edge_init) * (2 * i - 1) / (
                        self.sample_steps**2
                    ) + self.beta_edge_init / self.sample_steps
                else:
                    raise NotImplementedError

                mean_e = alpha_h_e * (self.Edim_output * sample_pred_ke - 1)
                var = alpha_h_e * self.Edim_output
                std = torch.full_like(mean_e, fill_value=var).sqrt()
                y_h = mean_e + std * torch.randn_like(sample_pred_ke)
                theta_prime = torch.exp(y_h) * theta_E  
                theta_E = theta_prime / theta_prime.sum(dim=-1, keepdim=True)
                # print ("E_after",i, E.sum())
                if self.cfg.force_symmetric_theta_E:
                    # BUG: we need to Force theta_E to be symmetric
                    theta_E = (theta_E + theta_E.transpose(1, 2)) / 2

            write_index = ((i - 1) * number_chain_steps) // self.sample_steps

            intermidiate_ = (
                utils.PlaceHolder(X=sample_pred_kx, E=sample_pred_ke, y=y)
                .type_as(theta_X)
                .mask(node_mask, collapse=True)
            )
            chain_X[write_index] = intermidiate_.X[:keep_chain]
            chain_E[write_index] = intermidiate_.E[:keep_chain]

            if self.cfg.input_dist_sample:

                sample_input_dist_X = torch.distributions.Categorical(theta_X).sample()
                sample_input_dist_X = F.one_hot(
                    sample_input_dist_X, num_classes=self.Xdim_output
                ).float()
                sample_input_dist_E = torch.distributions.Categorical(theta_E).sample()
                sample_input_dist_E = torch.triu(sample_input_dist_E, diagonal=1)
                sample_input_dist_E = sample_input_dist_E + torch.transpose(
                    sample_input_dist_E, 1, 2
                )
                sample_input_dist_E = F.one_hot(
                    sample_input_dist_E, num_classes=self.Edim_output
                ).float()
                assert (
                    sample_input_dist_E == torch.transpose(sample_input_dist_E, 1, 2)
                ).all()
                sample_input_dist = (
                    utils.PlaceHolder(X=sample_input_dist_X, E=sample_input_dist_E, y=y)
                    .type_as(theta_X)
                    .mask(node_mask, collapse=True)
                )
                chain_input_dist_X[write_index] = sample_input_dist.X[:keep_chain]
                chain_input_dist_E[write_index] = sample_input_dist.E[:keep_chain]

            if self.cfg.plot_input_dist_entropy:
                from src.diffusion.diffusion_utils import compute_mean_entropy

                entropy = compute_mean_entropy(theta_E)
                entropy_.append(entropy)

        ## Final sampling step:
        if not self.with_discretised_time:
            # not involve the last step for discretised time.
            t = torch.ones((batch_size, 1)).to(self.device)
            rescaled_theta_X_t = theta_X * 2 - 1
            rescaled_theta_E_t = theta_E * 2 - 1

            rescaled_masked_theta_t = (
                utils.PlaceHolder(X=rescaled_theta_X_t, E=rescaled_theta_E_t, y=y)
                .type_as(rescaled_theta_X_t)
                .mask(node_mask)
            )

            model_inputs = {
                "t": t,
                "X_t": rescaled_masked_theta_t.X,
                "E_t": rescaled_masked_theta_t.E,
                "y_t": rescaled_masked_theta_t.y,
                "node_mask": node_mask,
            }

            if self.cfg.model.output_dist_extra:
                noisy_para = {
                    "t": t,
                    "X_t": pred.X,
                    "E_t": pred.E,
                    "y_t": y,
                    "node_mask": node_mask,
                }
            else:
                noisy_para = {
                    "t": t,
                    "X_t": theta_X,
                    "E_t": theta_E,
                    "y_t": y,
                    "node_mask": node_mask,
                }

            extra_data = self.compute_extra_data_bfn(noisy_para)
            pred = self.forward(
                model_inputs, extra_data, node_mask
            )  # interdependency modeling
        # if with discretised time, we do not involve the last steps.

        # 1. Clamp Logits (that are not finite real values)
        if not torch.all(pred.X.isfinite()):
            pred.X = torch.where(pred.X.isfinite(), pred.X, torch.zeros_like(pred.X))
            self.print("WARNING: pred.X is not finite")
        if not torch.all(pred.E.isfinite()):
            pred.E = torch.where(pred.E.isfinite(), pred.E, torch.zeros_like(pred.E))
            self.print("WARNING: pred.E is not finite")
        # softmax

        pred.X = torch.nn.functional.softmax(pred.X, dim=-1)
        pred.E = torch.nn.functional.softmax(pred.E, dim=-1)

        # 2. Clamp probabilities (that are too small)
        pred.X = torch.clamp(pred.X, min=1e-6)
        pred.E = torch.clamp(pred.E, min=1e-6)

        sample_pred_x = torch.distributions.Categorical(pred.X).mode
        sample_pred_kx = F.one_hot(sample_pred_x, num_classes=self.Xdim_output)

        sample_pred_e = torch.distributions.Categorical(pred.E).mode
        # Symmetric E:
        sample_pred_e = torch.triu(sample_pred_e, diagonal=1)
        sample_pred_e = sample_pred_e + torch.transpose(sample_pred_e, 1, 2)
        assert (sample_pred_e == torch.transpose(sample_pred_e, 1, 2)).all()

        sample_pred_ke = F.one_hot(sample_pred_e, num_classes=self.Edim_output)

        result_ = utils.PlaceHolder(X=sample_pred_kx, E=sample_pred_ke, y=y).type_as(
            sample_pred_x
        )
        result_ = result_.mask(node_mask, collapse=True)
        X, E, y = result_.X, result_.E, result_.y

        # Prepare the chain for saving
        if keep_chain > 0:
            final_X_chain = X[:keep_chain]
            final_E_chain = E[:keep_chain]

            chain_X[-1] = final_X_chain  # Overwrite last frame with the resulting X, E
            chain_E[-1] = final_E_chain

            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1, 1)], dim=0)
            assert chain_X.size(0) == (number_chain_steps + 10)

            if self.cfg.input_dist_sample:
                sample_input_dist_X = torch.distributions.Categorical(theta_X).mode
                sample_input_dist_X = F.one_hot(
                    sample_input_dist_X, num_classes=self.Xdim_output
                ).float()
                sample_input_dist_E = torch.distributions.Categorical(theta_E).mode
                sample_input_dist_E = torch.triu(sample_input_dist_E, diagonal=1)
                sample_input_dist_E = sample_input_dist_E + torch.transpose(
                    sample_input_dist_E, 1, 2
                )
                sample_input_dist_E = F.one_hot(
                    sample_input_dist_E, num_classes=self.Edim_output
                ).float()
                assert (
                    sample_input_dist_E == torch.transpose(sample_input_dist_E, 1, 2)
                ).all()
                sample_input_dist = (
                    utils.PlaceHolder(X=sample_input_dist_X, E=sample_input_dist_E, y=y)
                    .type_as(theta_X)
                    .mask(node_mask, collapse=True)
                )
                chain_input_dist_X[-1] = sample_input_dist.X[:keep_chain]
                chain_input_dist_E[-1] = sample_input_dist.E[:keep_chain]

                # Repeat last frame to see final sample better
                chain_input_dist_X = torch.cat(
                    [chain_input_dist_X, chain_input_dist_X[-1:].repeat(10, 1, 1)],
                    dim=0,
                )
                chain_input_dist_E = torch.cat(
                    [chain_input_dist_E, chain_input_dist_E[-1:].repeat(10, 1, 1, 1)],
                    dim=0,
                )
                assert chain_X.size(0) == (number_chain_steps + 10)

        molecule_list = []
        for i in range(batch_size):
            n = n_nodes[i]
            atom_types = X[i, :n].cpu()
            edge_types = E[i, :n, :n].cpu()
            molecule_list.append([atom_types, edge_types])

        # Visualize chains
        if self.visualization_tools is not None:
            # Visualize the generation process as GIFs and grid images
            print(f"Visualizing chains...")
            current_path = os.getcwd()
            num_molecules = chain_X.size(1)  # number of molecules
            for i in range(num_molecules):
                result_path = os.path.join(
                    current_path,
                    f"chains/{self.name}/"
                    f"epoch{self.current_epoch}/"
                    f"molecule_{local_rank}_{batch_id + i}",  # f"chains/molecule_{batch_id + i}" results in chains/epoch{self.current_epoch}/chains/molecule_{id}, the second "chains" is redundant
                )
                if not os.path.exists(result_path):
                    os.makedirs(result_path)
                    _, gif, column = self.visualization_tools.visualize_chain(
                        result_path,
                        chain_X[:, i, :].numpy(),
                        chain_E[:, i, :].numpy(),
                        epoch=self.current_epoch,
                    )

            # Visualize the generation process of the input-distribution samples as GIFs and grid images
            if self.cfg.input_dist_sample:
                print(f"\nVisualizing samples drawn from the input distribution")
                num_molecules = chain_input_dist_X.size(1)  # number of molecules
                for i in range(num_molecules):
                    result_path = os.path.join(
                        current_path,
                        f"chains/{self.name}/"
                        f"epoch{self.current_epoch}/"
                        f"input_dist/"
                        f"molecule_{local_rank}_{batch_id + i}",  # f"chains/molecule_{batch_id + i}" results in chains/epoch{self.current_epoch}/chains/molecule_{id}, the second "chains" is redundant
                    )
                    if not os.path.exists(result_path):
                        os.makedirs(result_path)
                        _, gif, column = self.visualization_tools.visualize_chain(
                            result_path,
                            chain_input_dist_X[:, i, :].numpy(),
                            chain_input_dist_E[:, i, :].numpy(),
                            epoch=self.current_epoch,
                        )

            # Visualize the comparison between the generation processes of the input and output distributions
            if self.cfg.compare_input_output_dist_samples:
                assert (
                    self.cfg.input_dist_sample
                ), "input_dist_sample must be enabled to draw the comparison"
                print(
                    f"\nVisualizing the comparison between the generatin processes of input and output distributions"
                )
                num_molecules = chain_X.size(1)  # number of molecules
                for i in range(num_molecules):
                    output_sample_path = os.path.join(
                        current_path,
                        f"chains/{self.name}/"
                        f"epoch{self.current_epoch}/"
                        f"molecule_{local_rank}_{batch_id + i}",  # f"chains/molecule_{batch_id + i}" results in chains/epoch{self.current_epoch}/chains/molecule_{id}, the second "chains" is redundant
                    )
                    input_sample_path = os.path.join(
                        current_path,
                        f"chains/{self.name}/"
                        f"epoch{self.current_epoch}/"
                        f"input_dist/"
                        f"molecule_{local_rank}_{batch_id + i}",  # f"chains/molecule_{batch_id + i}" results in chains/epoch{self.current_epoch}/chains/molecule_{id}, the second "chains" is redundant
                    )
                    result_parent_path = result_path = os.path.join(
                        current_path,
                        f"chains/{self.name}/"
                        f"epoch{self.current_epoch}/"
                        f"comparison/",
                    )
                    if not os.path.exists(result_parent_path):
                        os.mkdir(result_parent_path)
                    result_path = os.path.join(
                        result_parent_path, f"molecule_{local_rank}_{batch_id + i}.png"
                    )

                    self.visualization_tools.visualize_input_output_dist(
                        result_path,
                        input_sample_path,
                        output_sample_path,
                        epoch=self.current_epoch,
                    )

            # Visualize the final molecules
            print(f"\nVisualizing molecules...")
            result_path = os.path.join(
                current_path,
                f"graphs/{self.name}/epoch{self.current_epoch}_b{batch_id}_{local_rank}/",
            )
            table = self.visualization_tools.visualize(
                result_path, molecule_list, save_final, epoch=self.current_epoch
            )

            self.print("Done.")

            # Plot input dist entropy
            if self.cfg.plot_input_dist_entropy:
                print(f"\n Plotting input_dist's entropy over t")
                result_path = os.path.join(
                    os.getcwd(),
                    f"chains/{self.name}/" f"epoch{self.current_epoch}_{local_rank}",
                )
                legend = f"{self.beta_edge}_{self.edge_time_scheduler}_{self.cfg.model.transition}"
                self.visualization_tools.plot_entropy(result_path, entropy_, legend)

        return molecule_list

    def compute_extra_data(self, noisy_data):
        """Build the extra network inputs for a noisy graph.

        Called at every training step (after noise has been added) and at every
        sampling step. The outputs of ``self.extra_features`` and
        ``self.domain_features`` are concatenated on the last dimension for
        ``X``, ``E`` and ``y``; ``noisy_data["t"]`` is then appended to ``y``
        along dim 1.

        Args:
            noisy_data: dict handed unchanged to both feature extractors; its
                ``"t"`` entry is read here.

        Returns:
            ``utils.PlaceHolder`` holding the combined ``X``, ``E`` and ``y``.
        """
        graph_feats = self.extra_features(noisy_data)
        domain_feats = self.domain_features(noisy_data)

        # Same per-field concatenation for X, E and y (in that order).
        combined = {
            field: torch.cat(
                (getattr(graph_feats, field), getattr(domain_feats, field)),
                dim=-1,
            )
            for field in ("X", "E", "y")
        }

        # The time step travels with the graph-level features.
        combined["y"] = torch.cat((combined["y"], noisy_data["t"]), dim=1)

        return utils.PlaceHolder(X=combined["X"], E=combined["E"], y=combined["y"])

    def compute_extra_data_bfn(self, noisy_para, mode="prob"):
        """Compute the extra network inputs from the current BFN parameters.

        How ``noisy_para`` is turned into the graph handed to the feature
        extractors is selected by ``self.extra_mode``:

        * ``"prob"`` / ``"max"``: discretise once with
          ``diffusion_utils.sample_discrete_features`` using that mode.
        * ``"iid"``: draw ``self.n_iid`` discretisations with mode ``"prob"``,
          compute the features of each draw and average them over the draws.
        * ``"direct"``: feed the parameters themselves to
          ``self.extra_features``; ``self.domain_features`` still receives a
          ``"max"`` discretisation.

        Args:
            noisy_para: dict with keys ``"t"``, ``"X_t"``, ``"E_t"``, ``"y_t"``
                and ``"node_mask"``. ``"X_t"`` and ``"E_t"`` are passed to the
                sampler as ``probX`` / ``probE``.
            mode: never read by this method; kept only so existing callers
                keep working. ``self.extra_mode`` decides the behaviour.

        Returns:
            ``utils.PlaceHolder`` whose ``X`` / ``E`` / ``y`` are the graph
            features concatenated with the domain features on the last
            dimension, with ``noisy_para["t"]`` appended to ``y`` along dim 1.

        Raises:
            NotImplementedError: if ``self.extra_mode`` is not one of the
                modes listed above.
            AssertionError: if a discretised sample does not have the shape of
                the parameters, or (``"direct"``) if ``E_t`` is not symmetric.
        """
        node_mask = noisy_para["node_mask"]
        bs = noisy_para["X_t"].shape[0]

        def discretise(sample_mode):
            """Sample a one-hot graph from the parameters and wrap it as noisy data."""
            sampled_t = diffusion_utils.sample_discrete_features(
                probX=noisy_para["X_t"],
                probE=noisy_para["E_t"],
                node_mask=node_mask,
                mode=sample_mode,
            )
            X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
            E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
            assert (noisy_para["X_t"].shape == X_t.shape) and (
                noisy_para["E_t"].shape == E_t.shape
            )
            z_t = (
                utils.PlaceHolder(X=X_t, E=E_t, y=noisy_para["y_t"])
                .type_as(X_t)
                .mask(node_mask)
            )
            return {
                "t": noisy_para["t"],
                "X_t": z_t.X,
                "E_t": z_t.E,
                "y_t": z_t.y,
                "node_mask": node_mask,
            }

        if self.extra_mode in ("prob", "max"):
            noisy_data = discretise(self.extra_mode)

        elif self.extra_mode == "iid":
            sampled_t = diffusion_utils.sample_discrete_features(
                probX=noisy_para["X_t"],
                probE=noisy_para["E_t"],
                node_mask=node_mask,
                mode="prob",
                n_iid=self.n_iid,
            )
            X_t = F.one_hot(
                sampled_t.X, num_classes=self.Xdim_output
            )  # (n_iid, bs, n, dx_out)
            E_t = F.one_hot(
                sampled_t.E, num_classes=self.Edim_output
            )  # (n_iid, bs, n, n, de_out)
            if self.n_iid == 0:
                # NOTE(review): with n_iid == 0 the reshape/tile below do not
                # look usable -- confirm whether this branch is reachable.
                assert (noisy_para["X_t"].shape == X_t.shape) and (
                    noisy_para["E_t"].shape == E_t.shape
                )
            else:
                assert (noisy_para["X_t"].shape == X_t[0].shape) and (
                    noisy_para["E_t"].shape == E_t[0].shape
                )
            # Fold the draw axis into the batch axis so the feature extractors
            # see one large batch.
            X_t = X_t.reshape(-1, *X_t.shape[2:])  # (n_iid*bs, n, dx_out)
            E_t = E_t.reshape(-1, *E_t.shape[2:])  # (n_iid*bs, n, n, de_out)

            # Replicate y_t and the node mask to match the enlarged batch.
            y_t = noisy_para["y_t"].tile(self.n_iid, 1)  # (n_iid * bs, y_dim)
            tiled_mask = node_mask.tile(self.n_iid, 1)  # (n_iid * bs, n)
            z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(tiled_mask)
            noisy_data = {
                # "t" is deliberately NOT tiled: it is concatenated at the end
                # to extra_y, which by then has been averaged back to bs rows.
                "t": noisy_para["t"],
                "X_t": z_t.X,
                "E_t": z_t.E,
                "y_t": z_t.y,
                "node_mask": tiled_mask,
            }

        elif self.extra_mode == "direct":
            # Use the parameters themselves as the graph for the graph features.
            X_t = noisy_para["X_t"]
            E_t = noisy_para["E_t"]
            assert (E_t == torch.transpose(E_t, 1, 2)).all()

            z_t = (
                utils.PlaceHolder(X=X_t, E=E_t, y=noisy_para["y_t"])
                .type_as(X_t)
                .mask(node_mask)
            )
            noisy_data = {
                "t": noisy_para["t"],
                "X_t": z_t.X,
                "E_t": z_t.E,
                "y_t": z_t.y,
                "node_mask": node_mask,
            }

        else:
            raise NotImplementedError

        extra_features = self.extra_features(noisy_data)

        if self.extra_mode == "direct":
            # For molecular data, "direct" is not computable because charges and
            # valencies are not defined on the continuous space, so the domain
            # features are computed on a "max"-mode discretisation instead.
            noisy_data = discretise("max")
        extra_molecular_features = self.domain_features(noisy_data)

        # Graph feat. (extra_features.X):
        #     cycles:
        #         dim=3: # of 3,4,5-cycles that X belongs to
        #     eigenvalues:
        #         dim=1: the biggest connected component (using the eigenvectors
        #                associated with eigenvalue 0)
        #         dim=2: the 2 first eigenvectors associated with non-zero eigenvalues
        # Mol feat. (extra_molecular_features.X):
        #     dim=1: charge
        #     dim=1: valency
        # If "all": total_dim = 3+1+2+1+1 = 8
        extra_X = torch.cat((extra_features.X, extra_molecular_features.X), dim=-1)

        # total_dim = 0
        extra_E = torch.cat((extra_features.E, extra_molecular_features.E), dim=-1)

        # Graph feat. (extra_features.y):
        #     graph size:
        #         dim=1: (# of nodes in the graph) / max_n_nodes
        #     cycles:
        #         dim=4: # of 3,4,5,6-cycles contained in the graph
        #     eigenvalues:
        #         dim=1: # of connected components
        #         dim=5: the 5 first non-zero eigenvalues
        # Mol feat. (extra_molecular_features.y):
        #     dim=1: molecular weight
        # If "all": total_dim = 1+4+1+5+1 = 12
        extra_y = torch.cat((extra_features.y, extra_molecular_features.y), dim=-1)

        if self.extra_mode == "iid":

            def average_over_draws(extra):
                """Collapse the (n_iid * bs, ...) features back to (bs, ...)."""
                # (n_iid * bs, ...) -> (n_iid, bs, ...) -> (bs, n_iid, ...)
                extra = extra.reshape(self.n_iid, bs, *extra.shape[1:]).transpose(
                    0, 1
                )
                if 0 in extra.shape:
                    # Some extra feature has a zero-sized dimension: there is
                    # nothing to average, just drop the draw axis.
                    return extra[:, 0]
                # Mean over the draws; the dtype argument converts to float.
                return extra.mean(dim=1, dtype=noisy_para["X_t"].dtype)

            # No need to mask here: each sample's computed extra features are
            # masked, so the mean is masked (i.e. mean(0, ..., 0) = 0).
            extra_X, extra_E, extra_y = (
                average_over_draws(extra) for extra in (extra_X, extra_E, extra_y)
            )

        # Concatenate t to y.
        extra_y = torch.cat((extra_y, noisy_data["t"]), dim=1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)
