import itertools
import math
import os
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric
import torchmetrics
import wandb
from rdkit import Chem
from scipy.optimize import linear_sum_assignment
from torch import Tensor
from torch_geometric.utils import to_dense_batch
from shepherd_score.score.gaussian_overlap import get_overlap, get_overlap_batch

from src import dataloader, models, noise_schedule
from src.constants import (
    COORDS_STD,
    DATASET_PRIORS,
    MAX_ATOMS,
    N_BUILDING_BLOCKS,
    N_CENTERS,
    N_REACTIONS,
    TRAIN_SMILES,
)
from src.models.utils import to_dense
from src.utils.builder_utils import build_molecule, is_valid_smiles
from src.utils.indexing_utils import (
    get_compatibility_masks,
    node_to_atom_padding_mask,
    onehot_to_node_indices,
    onehot_to_reaction_type_and_centers,
    padding_mask,
    perfrag_atom_padding_mask,
    remove_coords_padding,
    remove_graph_padding,
)
from src.utils.misc_utils import get_logger, print_nans
from src.utils.spatial_utils import (
    align_and_permute,
    augment_coordinates,
    bond_length_loss,
    calc_energy,
    center_atom_coords,
    check_batch_coord_means,
    coordinates_to_mol,
    get_bonds,
    pairwise_distance_loss,
    smooth_lddt_loss,
    xyz_to_coordinates,
    sdf_to_coordinates,
    select_conformer,
)
from src.utils.viz import save_partial_as_sdf

# Module-level logger for this file.
LOGGER = get_logger(__name__)
# Natural logarithm of 2 (ln 2).
LOG2 = math.log(2)
# NOTE(review): presumably a cap used when clamping logits — confirm against its call sites.
MAX_LOGIT = 15


def _get_unmasked_indices(tensor, mask_index=-1):
    """
    Determines which elements are unmasked in a one-hot encoded tensor.
    """
    is_masked = tensor[..., mask_index] == 1
    unmasked = torch.ones(tensor.shape[:-1], dtype=torch.bool, device=tensor.device)
    unmasked[is_masked] = False
    # Expand unmasked to match the size of the input tensor
    unmasked = unmasked.unsqueeze(-1).expand_as(tensor)

    return unmasked


def _sample_categorical(categorical_probs, temperature=1.0):
    """
    Sample from a categorical distribution with temperature scaling.

    Args:
        categorical_probs: Unnormalized probabilities
        temperature: Temperature parameter for scaling logits. Higher values produce more uniform samples.

    Returns:
        samples: Sampled indices from the categorical distribution
        normalized_probs: The normalized probabilities used for sampling
    """
    original_shape = categorical_probs.size()[:-1]
    num_categories = categorical_probs.size(-1)
    flat_probs = categorical_probs.reshape(-1, num_categories)

    # Apply temperature scaling and normalize
    if temperature != 1.0:
        flat_probs = flat_probs.pow(1.0 / temperature)
    normalized_probs = F.normalize(flat_probs, p=1, dim=-1)

    # Sample from the normalized distribution
    samples = torch.multinomial(normalized_probs, num_samples=1).squeeze(-1)

    return samples.view(original_shape), normalized_probs.view(*original_shape, num_categories)


def _to_onehot(matrix, num_classes):
    return F.one_hot(matrix, num_classes=num_classes).to(matrix.dtype)


@dataclass
class Loss:
    """Bundle of per-step loss terms (presumably produced by ``Diffusion._loss``).

    ``Diffusion._compute_loss`` returns ``total_loss`` as the step's loss and
    feeds every field into the metric of the same name (``NLL``,
    ``Perplexity`` or ``MSE``) of the current split's metric collection.
    """

    # Step objective returned by _compute_loss; aggregated by an NLL (weighted-mean) metric.
    total_loss: torch.FloatTensor
    # Aggregated by NLL (weighted-mean) metrics.
    node_nll: torch.FloatTensor
    edge_nll: torch.FloatTensor
    # Aggregated by Perplexity metrics, which report exp(mean of the values fed in);
    # presumably these hold NLL values — TODO confirm against _loss.
    ppl: torch.FloatTensor
    node_ppl: torch.FloatTensor
    edge_ppl: torch.FloatTensor
    # Aggregated by MSE (weighted-mean) metrics; presumably coordinate/geometry
    # losses — TODO confirm against _loss.
    fm_mse: torch.FloatTensor
    weighted_bond_loss: torch.FloatTensor
    weighted_mse: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
    """Running weighted mean of negative log-likelihood values.

    Adds no behavior over ``MeanMetric``; the subclass only gives NLL-type
    entries (and ``Perplexity``) a distinct type.
    """


class MSE(torchmetrics.aggregation.MeanMetric):
    """Running weighted mean of squared-error style losses.

    Adds no behavior over ``MeanMetric``; the subclass only gives these
    entries a distinct type.
    """


class Perplexity(NLL):
    """Exponential of the running weighted-mean NLL."""

    def compute(self) -> Tensor:
        """Return ``exp`` of the accumulated weighted mean.

        Returns:
         Perplexity
        """
        mean_nll = self.mean_value / self.weight
        return mean_nll.exp()


class Diffusion(L.LightningModule):
    def __init__(self, config):
        """Build the diffusion model from an experiment config.

        Selects the backbone named by ``config.backbone``, creates one metric
        collection per split, builds the discrete and spatial noise schedules,
        and (when ``config.training.ema > 0``) an EMA over the backbone and
        noise-schedule parameters. The configuration is checked at the end by
        ``_validate_configuration``.

        Args:
            config: experiment configuration with attribute-style access
                (presumably an OmegaConf/Hydra DictConfig — TODO confirm).

        Raises:
            ValueError: if ``config.backbone`` is not a known backbone name.
        """
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        # One class per building block plus one extra class (the mask token,
        # see node_mask_index below).
        self.n_node_features = N_BUILDING_BLOCKS + 1
        # One class per (reaction, center, center) combination plus two extra
        # classes; _subs_parameterization treats index -2 as "no edge" and the
        # last index as the mask token.
        self.n_edge_features = N_REACTIONS * N_CENTERS * N_CENTERS + 2
        self.sampler = self.config.sampling.predictor
        self.antithetic_sampling = self.config.training.antithetic_sampling
        self.importance_sampling = self.config.training.importance_sampling
        self.change_of_variables = self.config.training.change_of_variables
        self.parameterization = self.config.parameterization
        # Pick the denoising network. For "semla", the pharmacophore-conditioned
        # variant is used when spatial.pharmacophore_conditioning is enabled.
        if self.config.backbone == "graph_transformer_base":
            self.backbone = models.graph_transformer_base.GraphTransformer(self.config)
        elif self.config.backbone == "graph_transformer_large":
            self.backbone = models.graph_transformer_large.GraphTransformer(self.config)
        elif self.config.backbone == "graph_transformer_equivariant":
            self.backbone = models.graph_transformer_equivariant.GraphTransformer(self.config)
        elif self.config.backbone == "graph_transformer_fused":
            self.backbone = models.graph_transformer_fused.GraphTransformer(self.config)
        elif self.config.backbone == "semla":
            if self.config.spatial.pharmacophore_conditioning:
                self.backbone = models.semla_pharm.SemlaGenerator(self.config)
            else:
                self.backbone = models.semla.SemlaGenerator(self.config)
        elif self.config.backbone == "semla_atom":
            self.backbone = models.semla_atom.SemlaGenerator(self.config)
        elif self.config.backbone == "semla_attn":
            self.backbone = models.semla_attn.SemlaGenerator(self.config)
        else:
            raise ValueError(f"Unknown backbone: {self.config.backbone}")

        # self.backbone = torch.compile(self.backbone)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking

        self.softplus = torch.nn.Softplus()
        # metrics are automatically reset at end of epoch
        metrics = torchmetrics.MetricCollection(
            {
                "total_loss": NLL(),
                "node_nll": NLL(),
                "edge_nll": NLL(),
                "ppl": Perplexity(),
                "node_ppl": Perplexity(),
                "edge_ppl": Perplexity(),
                "fm_mse": MSE(),
                "weighted_bond_loss": MSE(),
                "weighted_mse": MSE(),
            }
        )
        metrics.set_dtype(torch.float64)
        # Independent copy per split; logged keys are prefixed with the split name.
        self.train_metrics = metrics.clone(prefix="train/")
        self.valid_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # Separate noise schedules for the discrete graph and the coordinates.
        self.discrete_noise = noise_schedule.get_noise(
            self.config, field="discrete_noise", dtype=self.dtype
        )
        self.spatial_noise = noise_schedule.get_noise(
            self.config, field="spatial_noise", dtype=self.dtype
        )

        # The EMA tracks the same parameter set that the optimizer updates.
        if self.config.training.ema > 0:
            self.ema = models.ema.ExponentialMovingAverage(
                itertools.chain(
                    self.backbone.parameters(),
                    self.discrete_noise.parameters(),
                    self.spatial_noise.parameters(),
                ),
                decay=self.config.training.ema,
            )
        else:
            self.ema = None

        self.lr = self.config.optim.lr
        self.sampling_eps = self.config.training.sampling_eps
        self.time_conditioning = self.config.time_conditioning
        # Large, finite negative value used in place of -inf on logits.
        self.neg_infinity = -1000000.0
        # Filled in by on_load_checkpoint; on_train_start uses them to
        # fast-forward the sampler when resuming a distributed run.
        self.fast_forward_epochs = None
        self.fast_forward_batches = None
        # Mask-token class indices (the last class of each one-hot encoding).
        self.node_mask_index = N_BUILDING_BLOCKS
        self.edge_mask_index = N_REACTIONS * N_CENTERS * N_CENTERS + 1
        self.churn = 1
        self.inference_dataloader = None
        self._validate_configuration()

    def _validate_configuration(self):
        """Reject configuration combinations this module does not support.

        Uses explicit raises rather than ``assert`` so the checks still run
        under ``python -O``.

        Raises:
            ValueError: if both ``training.change_of_variables`` and
                ``training.importance_sampling`` are enabled, or if
                ``parameterization`` is anything other than ``"subs"``.
        """
        if self.change_of_variables and self.importance_sampling:
            raise ValueError(
                "training.change_of_variables and training.importance_sampling "
                "cannot both be enabled"
            )
        # Other parameterizations are not supported yet.
        if self.parameterization != "subs":
            raise ValueError(
                f"Unsupported parameterization: {self.parameterization!r} "
                "(only 'subs' is supported)"
            )

    def on_load_checkpoint(self, checkpoint):
        """Restore EMA state and remember how far the interrupted run had progressed.

        The completed-epoch and completed-batch counters from Lightning's fit
        loop are stored on ``self``. ``on_train_start`` uses them to fast-forward
        the sampler in distributed runs.
        """
        if self.ema:
            self.ema.load_state_dict(checkpoint["ema"])
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
        fit_loop_state = checkpoint["loops"]["fit_loop"]
        epoch_progress = fit_loop_state["epoch_progress"]
        self.fast_forward_epochs = epoch_progress["current"]["completed"]
        batch_progress = fit_loop_state["epoch_loop.batch_progress"]
        self.fast_forward_batches = batch_progress["current"]["completed"]

    def on_save_checkpoint(self, checkpoint):
        """Add EMA state, patch Lightning's progress counters, and record sampler state.

        The batch-progress counters are overwritten from the optimizer step
        counters (scaled by gradient accumulation). The train sampler's
        ``random_state`` (or ``None``) is stored under ``checkpoint["sampler"]``.
        """
        if self.ema:
            checkpoint["ema"] = self.ema.state_dict()
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
        # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
        # behind, so we're using the optimizer's progress.
        fit_loop_state = checkpoint["loops"]["fit_loop"]
        optim_progress = fit_loop_state["epoch_loop.automatic_optimization.optim_progress"]
        step_counts = optim_progress["optimizer"]["step"]
        accumulation = self.trainer.accumulate_grad_batches

        batch_progress = fit_loop_state["epoch_loop.batch_progress"]
        batch_progress["total"]["completed"] = step_counts["total"]["completed"] * accumulation
        batch_progress["current"]["completed"] = step_counts["current"]["completed"] * accumulation
        # _batches_that_stepped tracks the number of global steps, not the number
        # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
        fit_loop_state["epoch_loop.state_dict"]["_batches_that_stepped"] = step_counts["total"][
            "completed"
        ]

        sampler_entry = checkpoint.setdefault("sampler", {})
        train_sampler = self.trainer.train_dataloader.sampler
        if hasattr(train_sampler, "state_dict"):
            sampler_entry["random_state"] = train_sampler.state_dict().get("random_state", None)
        else:
            sampler_entry["random_state"] = None

    def on_train_start(self):
        """Move EMA buffers to the device and rebuild train loaders around fault-tolerant samplers.

        Each train DataLoader held by the fit loop is replaced with a PyG
        DataLoader over the same dataset. The new loader uses a
        ``FaultTolerantDistributedSampler`` (distributed runs) or a
        ``RandomFaultTolerantSampler`` (otherwise), preserving the old sampler's
        ``shuffle`` flag when it has one. In distributed runs resumed through
        ``on_load_checkpoint``, the sampler is fast-forwarded to the saved
        epoch / sample counter.
        """
        if self.ema:
            self.ema.move_shadow_params_to_device(self.device)
        # Adapted from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
        connector = self.trainer._accelerator_connector
        distributed = connector.use_distributed_sampler and connector.is_distributed
        sampler_cls = (
            dataloader.FaultTolerantDistributedSampler
            if distributed
            else dataloader.RandomFaultTolerantSampler
        )
        resume_position = (
            distributed
            and self.fast_forward_epochs is not None
            and self.fast_forward_batches is not None
        )
        loader_cfg = self.config.loader
        combined_loader = self.trainer.fit_loop._combined_loader

        rebuilt_loaders = []
        for old_loader in combined_loader.flattened:
            old_sampler = old_loader.sampler
            if hasattr(old_sampler, "shuffle"):
                new_sampler = sampler_cls(old_loader.dataset, shuffle=old_sampler.shuffle)
            else:
                new_sampler = sampler_cls(old_loader.dataset)
            if resume_position:
                new_sampler.load_state_dict(
                    {
                        "epoch": self.fast_forward_epochs,
                        "counter": (self.fast_forward_batches * loader_cfg.batch_size),
                    }
                )
            rebuilt_loaders.append(
                torch_geometric.loader.DataLoader(
                    old_loader.dataset,
                    batch_size=loader_cfg.batch_size,
                    num_workers=loader_cfg.num_workers,
                    pin_memory=loader_cfg.pin_memory,
                    sampler=new_sampler,
                    shuffle=False,
                    # persistent_workers requires at least one worker process.
                    persistent_workers=loader_cfg.num_workers > 0,
                )
            )
        combined_loader.flattened = rebuilt_loaders

    def optimizer_step(self, *args, **kwargs):
        """Run Lightning's optimizer step, then fold the updated weights into the EMA (if enabled)."""
        super().optimizer_step(*args, **kwargs)
        if not self.ema:
            return
        tracked_params = itertools.chain(
            self.backbone.parameters(),
            self.discrete_noise.parameters(),
            self.spatial_noise.parameters(),
        )
        self.ema.update(tracked_params)

    def _subs_parameterization(
        self, logits_X, logits_E, xt, et, node_padding_mask, edge_padding_mask
    ):
        """Turn raw backbone logits into SUBS-parameterized log-probabilities.

        Steps, in order:
          1. Graphs whose count of active (reaction-class) edges already reaches
             ``n_nodes - 1`` may only predict the no-edge class (index -2).
          2. The mask class (last index) gets ``neg_infinity`` added so it is
             never predicted.
          3. If ``config.rgfn.compatibility_mask`` is set, classes flagged as
             incompatible by ``get_compatibility_masks`` get ``neg_infinity`` added.
          4. Logits are normalized with ``logsumexp`` into log-probabilities.
          5. Positions already unmasked in ``xt`` / ``et`` are carried over:
             log-prob 0 on their current class, ``neg_infinity`` elsewhere.
          6. Every diagonal edge ``(i, i)`` is forced to the no-edge class.

        Note: steps 1-3 write into the ``logits_X`` / ``logits_E`` tensors
        passed in (in-place indexed assignment). Later steps work on new tensors.

        Args:
            logits_X: raw node logits; class axis last.
            logits_E: raw edge logits, 4-D ``(batch, n, n, classes)`` (from the unpacking below).
            xt: current one-hot node state; last class is the mask token.
            et: current one-hot edge state; last class is mask, second-to-last is no-edge.
            node_padding_mask: boolean mask of real (non-padded) nodes.
            edge_padding_mask: boolean mask of real (non-padded) edges.

        Returns:
            ``(logits_X, logits_E)`` as normalized log-probabilities.

        Raises:
            AssertionError: if a real position's probabilities do not sum to 1
                (within 1e-4) or a real node has a log-probability above 0.
        """
        unmasked_indices_X = _get_unmasked_indices(xt)
        unmasked_indices_E = _get_unmasked_indices(et)

        # An edge is active when it is one-hot on a reaction class, i.e. any
        # class except the trailing no-edge and mask classes.
        active_edges = et[..., :-2].sum(dim=-1)  # Sum over all features except last 2
        num_active_edges = (active_edges == 1).sum(dim=(1, 2))  # Sum over nodes for each batch item
        n_nodes = node_padding_mask.sum(dim=1)

        # If all the edges are denoised, only allow no-edge predictions.
        # `// 2` presumably because `et` stores each edge in both directions — TODO confirm.
        fully_connected = (num_active_edges // 2) >= (n_nodes - 1)
        logits_E[fully_connected] = torch.full_like(logits_E[fully_connected], self.neg_infinity)
        logits_E[fully_connected, :, :, -2] = 0  # Allow only no-edge token

        # log prob at the mask index = - infinity
        logits_X[:, :, -1] += self.neg_infinity
        logits_E[:, :, :, -1] += self.neg_infinity

        if self.config.rgfn.compatibility_mask:
            compatibility_mask_X, compatibility_mask_E = get_compatibility_masks(xt, et)
            logits_X[~compatibility_mask_X] += self.neg_infinity
            logits_E[~compatibility_mask_E] += self.neg_infinity

        # Normalize the logits such that x.exp() is
        # a probability distribution over vocab_size.
        logits_X = logits_X - torch.logsumexp(logits_X, dim=-1, keepdim=True)
        logits_E = logits_E - torch.logsumexp(logits_E, dim=-1, keepdim=True)

        # Carry over unmasked tokens: for positions that are already unmasked,
        # set every class to -infinity except the one the token currently holds.
        logits_X[unmasked_indices_X] = self.neg_infinity
        logits_X[(xt == 1) & unmasked_indices_X] = 0

        logits_E[unmasked_indices_E] = self.neg_infinity
        logits_E[(et == 1) & unmasked_indices_E] = 0

        # Hard-code the diagonal for ALL nodes (padded ones included):
        # log-prob 0 on the no-edge class (index -2), -infinity elsewhere.
        batch_size, n, _, _ = logits_E.shape
        diag_indices = torch.arange(n, device=logits_E.device)
        diag_values = torch.full((logits_E.shape[-1],), self.neg_infinity, device=logits_E.device)
        diag_values[-2] = 0
        diag_values = diag_values.expand(batch_size, n, -1)
        logits_E[:, diag_indices, diag_indices, :] = diag_values

        # Check if the sums are close to 1 for non-padded tokens
        tolerance = 1e-4
        sum_exp_X = torch.exp(logits_X).sum(dim=-1)
        sum_exp_E = torch.exp(logits_E).sum(dim=-1)

        assert torch.all(
            logits_X[node_padding_mask] <= 0
        ), "Found log probabilities > 0 in non-padded logits_X"
        assert torch.all(torch.abs(sum_exp_X[node_padding_mask] - 1) < tolerance), (
            "Sum of exp(logits_X) is not close to 1 for non-padded tokens: "
            f"{torch.abs(sum_exp_X[node_padding_mask] - 1)}"
        )
        assert torch.all(torch.abs(sum_exp_E[edge_padding_mask] - 1) < tolerance), (
            "Sum of exp(logits_E) is not close to 1 for non-padded tokens: "
            f"{torch.abs(sum_exp_E[edge_padding_mask] - 1)}"
        )

        return logits_X, logits_E

    def _process_sigma(self, sigma):
        """Normalize the noise level before it is passed to the backbone.

        ``None`` is only accepted for the ``"ar"`` parameterization and is
        returned unchanged. Otherwise a trailing axis is squeezed from
        multi-dimensional input, and the values are replaced by zeros when time
        conditioning is disabled. The result must be 1-D.
        """
        if sigma is None:
            assert self.parameterization == "ar"
            return None
        if sigma.ndim > 1:
            sigma = sigma.squeeze(-1)
        processed = sigma if self.time_conditioning else torch.zeros_like(sigma)
        assert processed.ndim == 1, processed.shape
        return processed

    def forward(self, X, E, C, node_padding_mask, edge_padding_mask, sigma, cond=None):
        """Returns log score.

        Runs the backbone under autocast, then applies ``_subs_parameterization``
        to its node / edge logits.

        Args:
            X, E, C: node, edge and coordinate inputs for the backbone.
            node_padding_mask, edge_padding_mask: masks of real nodes / edges.
            sigma: noise level, normalized by ``_process_sigma``.
            cond: with pharmacophore conditioning enabled, a
                ``(pharm_types, pharm_pos, pharm_padding_mask)`` triple; otherwise unused.

        Returns:
            ``(log_probs_X, log_probs_E, C0_pred)``. ``C0_pred`` is the backbone's
            third output (presumably predicted clean coordinates — TODO confirm).
        """
        sigma = self._process_sigma(sigma)
        backbone_kwargs = {"node_padding_mask": node_padding_mask, "sigma": sigma}
        with torch.cuda.amp.autocast(dtype=self.dtype):
            if self.config.spatial.pharmacophore_conditioning:
                pharm_types, pharm_pos, pharm_padding_mask = cond
                backbone_kwargs.update(
                    pharm_types=pharm_types,
                    pharm_pos=pharm_pos,
                    pharm_padding_mask=pharm_padding_mask,
                )
            logits_X, logits_E, C0_pred = self.backbone(X, E, C, **backbone_kwargs)

        if self.config.self_conditioning:
            # Only the first half of the last axis is the original X / E; that
            # part is what subs_parameterization needs.
            X = X[..., : X.shape[-1] // 2]
            E = E[..., : E.shape[-1] // 2]

        log_probs_X, log_probs_E = self._subs_parameterization(
            logits_X=logits_X,
            logits_E=logits_E,
            xt=X,
            et=E,
            node_padding_mask=node_padding_mask,
            edge_padding_mask=edge_padding_mask,
        )
        return log_probs_X, log_probs_E, C0_pred

    def _compute_loss(self, batch, prefix):
        """Densify a PyG batch, compute the diffusion losses, and update and log metrics.

        Steps:
          1. Convert the sparse batch to dense tensors with ``to_dense``, plus
             pharmacophore conditioning tensors when that is enabled.
          2. Augment and center the coordinates with ``augment_coordinates``.
          3. Cast X / E / C to the configured precision and call ``self._loss``.
          4. Feed every loss term into the metric of the same name in the
             split's collection, then ``log_dict`` that collection.

        Args:
            batch: PyG batch with ``x``, ``edge_index``, ``edge_attr``, ``batch``,
                ``coordinates``, ``coords_mask``, ``smiles`` (and ``pharm_*``
                fields when pharmacophore conditioning is enabled).
            prefix: one of ``"train"``, ``"val"``, ``"test"``.

        Returns:
            ``losses.total_loss`` from ``self._loss``.

        Raises:
            ValueError: if ``prefix`` is not a known split (checked after the loss is computed).
        """
        cond_X, cond_C, cond_padding_mask = None, None, None
        dense_inputs = (
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
            self.config.model.length,
            batch.coordinates,
            batch.coords_mask,
        )
        if self.config.spatial.pharmacophore_conditioning:
            (
                dense_data,
                node_padding_mask,
                edge_padding_mask,
                coords_mask,
                cond_X,
                cond_C,
                cond_padding_mask,
            ) = to_dense(
                *dense_inputs,
                batch.pharm_types,
                batch.pharm_pos,
                batch.pharm_padding_mask,
                self.config.spatial.pharmacophore_subset,
            )
        else:
            dense_data, node_padding_mask, edge_padding_mask, coords_mask, _, _, _ = to_dense(
                *dense_inputs
            )

        batch_smiles = batch.smiles
        if self.config.paths.use_lmdb:
            conformers_path = os.path.join(self.config.data.cache_dir, "conformers.lmdb")
        else:
            conformers_path = os.path.join(
                self.config.data.cache_dir, "conformers/final_conformers"
            )
        X, E, C = dense_data.X, dense_data.E, dense_data.C

        # Flatten (batch, nodes, atoms, 3) -> (batch, nodes * atoms, 3) for augmentation.
        C = augment_coordinates(
            coords=C.reshape(C.shape[0], C.shape[1] * C.shape[2], 3),
            coords_masks=coords_mask.reshape(
                coords_mask.shape[0], coords_mask.shape[1] * coords_mask.shape[2]
            ),
            pharm_coords=cond_C,
            pharm_masks=cond_padding_mask,
            center=self.config.spatial.center,
            normalize=self.config.spatial.normalize,
            align=self.config.spatial.align,
            rotate=self.config.spatial.rotate,
            translate=self.config.spatial.translate,
            reference_coords=None,
            conf_dir=conformers_path,
        )

        # With pharmacophore conditioning, augment_coordinates also returns the
        # transformed pharmacophore coordinates.
        if self.config.spatial.pharmacophore_conditioning:
            C, cond_C = C

        assert check_batch_coord_means(
            C, coords_mask.reshape(C.shape[0], -1)
        ), "Coordinates must be centered (mean=0)"

        C = C.reshape(C.shape[0], self.config.model.length, MAX_ATOMS, 3)

        X, E, C = self._cast_batch(X), self._cast_batch(E), self._cast_batch(C)
        losses = self._loss(
            X,
            E,
            C,
            batch_smiles,
            node_padding_mask,
            edge_padding_mask,
            coords_mask,
            prefix,
            cond_X=cond_X,
            cond_C=cond_C,
            cond_padding_mask=cond_padding_mask,
        )

        split_metrics = {
            "train": self.train_metrics,
            "val": self.valid_metrics,
            "test": self.test_metrics,
        }
        if prefix not in split_metrics:
            raise ValueError(f"Invalid prefix: {prefix}")
        metrics = split_metrics[prefix]
        # Each Loss field feeds the metric registered under the same name in __init__.
        for name in (
            "total_loss",
            "node_nll",
            "edge_nll",
            "ppl",
            "node_ppl",
            "edge_ppl",
            "fm_mse",
            "weighted_bond_loss",
            "weighted_mse",
        ):
            metrics[name].update(getattr(losses, name))

        self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True)
        return losses.total_loss

    def on_train_epoch_start(self):
        """Put the backbone and both noise schedules back into training mode."""
        for module in (self.backbone, self.spatial_noise, self.discrete_noise):
            module.train()

    def _get_dtype_from_config(self):
        """Map ``config.trainer.precision`` to the torch dtype used for input batches.

        Accepts Lightning's string spellings as well as a bare integer, which is
        what YAML/Hydra yields for an unquoted ``precision: 16``:

        * ``16`` / ``"16"`` / ``"16-mixed"`` / ``"16-true"``  -> ``torch.float16``
        * ``"bf16"`` / ``"bf16-mixed"`` / ``"bf16-true"``      -> ``torch.bfloat16``
        * anything else                                        -> ``torch.float32``
        """
        # str() so that an integer precision such as 16 matches the string forms.
        precision = str(self.config.trainer.precision)
        if precision in ("16", "16-mixed", "16-true"):
            return torch.float16
        if precision in ("bf16", "bf16-mixed", "bf16-true"):
            return torch.bfloat16
        return torch.float32

    def _cast_batch(self, batch):
        """Cast a tensor to the dtype implied by ``config.trainer.precision``.

        Raises:
            ValueError: if ``batch`` is not a ``torch.Tensor``.
        """
        target_dtype = self._get_dtype_from_config()
        if not isinstance(batch, torch.Tensor):
            raise ValueError(f"Unsupported batch type: {type(batch)}")
        return batch.to(dtype=target_dtype)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it per step as ``trainer/loss``, and return it."""
        train_loss = self._compute_loss(batch, prefix="train")
        self.log(
            name="trainer/loss",
            value=train_loss.item(),
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )
        return train_loss

    def on_validation_epoch_start(self):
        """Swap in EMA weights (if enabled), switch to eval mode, and check the val metrics are reset."""
        if self.ema:

            def ema_tracked_params():
                # A fresh iterator per call: each EMA operation consumes one.
                return itertools.chain(
                    self.backbone.parameters(),
                    self.discrete_noise.parameters(),
                    self.spatial_noise.parameters(),
                )

            # Stash the live weights (restored in on_validation_epoch_end),
            # then load the EMA shadow weights into the model.
            self.ema.store(ema_tracked_params())
            self.ema.copy_to(ema_tracked_params())
        for module in (self.backbone, self.spatial_noise, self.discrete_noise):
            module.eval()
        total_loss_metric = self.valid_metrics.total_loss
        assert total_loss_metric.mean_value == 0
        assert total_loss_metric.weight == 0

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and, when denoising is enabled, cache validation graphs.

        While ``denoise_discrete`` or ``denoise_coordinates`` is set, validation
        graphs are accumulated in ``self.cached_val_batch`` until it holds
        ``loader.eval_batch_size`` graphs. The first batch seen is stored as-is.
        The cache is never cleared here, so it persists across validation epochs.

        Returns:
            The loss from ``_compute_loss(batch, prefix="val")``.
        """
        if self.config.denoise_discrete or self.config.denoise_coordinates:
            if not hasattr(self, "cached_val_batch"):
                self.cached_val_batch = batch
            else:
                # Count graphs, not nodes: `batch.x` has one row per node, so
                # `x.shape[0]` would compare a node count against a graph budget
                # and then slice the list of graphs by that node-based remainder.
                remaining = (
                    self.config.loader.eval_batch_size - self.cached_val_batch.num_graphs
                )
                if remaining > 0:
                    # Slicing is a no-op when the whole incoming batch fits.
                    self.cached_val_batch = torch_geometric.data.Batch.from_data_list(
                        self.cached_val_batch.to_data_list()
                        + batch.to_data_list()[:remaining]
                    )
        return self._compute_loss(batch, prefix="val")

    def on_validation_epoch_end(self):
        """Generate and log samples (rank 0), then restore the non-EMA weights.

        When ``eval.generate_samples`` is set (outside the sanity check),
        ``sampling.num_sample_batches`` batches are drawn with ``self._sample()``.
        On global rank 0, and only if the logger has ``log_table``, the first
        ``sampling.num_sample_log`` graphs of the first batch are decoded into
        molecules. Those with a valid SMILES are logged as wandb molecules
        (plus the reference coordinates when available), and their energies
        are collected. With ``rgfn.reassembly_logging`` enabled, the following
        are also logged:

        * the fragment-index distribution,
        * the validity and novelty rates,
        * the median energy,
        * a samples table.

        Finally, if EMA is enabled, the weights stashed by
        ``on_validation_epoch_start`` are restored.
        """
        if (
            not self.trainer.sanity_checking
            and self.config.eval.generate_samples
            and not self.parameterization == "ar"
        ):
            graph_samples = []
            for _ in range(self.config.sampling.num_sample_batches):
                X, E, C, C_true, node_padding_mask, _ = self._sample()
                graph_samples.append((X, E, C, C_true, node_padding_mask))

            if self.trainer.global_rank == 0 and hasattr(self.trainer.logger, "log_table"):
                # Log the first `num_sample_log` graphs of the first sampled batch.
                (
                    sample_X,
                    sample_E,
                    sample_C,
                    sample_C_true,
                    sample_node_padding_mask,
                ) = graph_samples[0]
                single_graphs = []
                for i in range(self.config.sampling.num_sample_log):
                    if sample_C_true is not None:
                        sample_C_true_i = sample_C_true[i]
                    else:
                        sample_C_true_i = None
                    single_graphs.append(
                        (
                            sample_X[i],
                            sample_E[i],
                            sample_C[i],
                            sample_C_true_i,
                            sample_node_padding_mask[i],
                        )
                    )

                columns = ["Generated Samples"]

                smiles_samples = []
                frag_indices = []
                all_indices = []
                energies = []
                for X, E, C, C_true, node_padding_mask in single_graphs:
                    length = node_padding_mask.sum()
                    X, E = remove_graph_padding(X, E, length)
                    C = remove_coords_padding(C, length)
                    if C_true is not None:
                        C_true = remove_coords_padding(C_true, length)
                    ohx = onehot_to_node_indices(X)
                    ohe = onehot_to_reaction_type_and_centers(E)
                    ohx_np = ohx.cpu().numpy()
                    frag_indices.append(ohx_np)
                    built_mol = build_molecule(ohx, ohe, smiles=False)
                    smiles_sequence = None
                    if built_mol:
                        smiles_sequence = Chem.MolToSmiles(built_mol)
                    if is_valid_smiles(smiles_sequence):
                        if self.config.spatial.normalize:
                            C = C * COORDS_STD
                        smiles_samples.append(smiles_sequence)
                        mol = coordinates_to_mol(X, E, C)
                        if C_true is not None:
                            # NOTE(review): unlike `C`, the reference coordinates are
                            # rescaled even when spatial.normalize is False — confirm intended.
                            C_true = C_true * COORDS_STD
                            mol_true = coordinates_to_mol(X, E, C_true)
                        energy = calc_energy(mol)
                        if energy is not None:
                            energies.append(energy)
                        # Log molecules with wandb.Molecule.from_rdkit; best-effort only.
                        try:
                            self.logger.experiment.log(
                                {
                                    f"molecule_images/step_{self.global_step}": wandb.Molecule.from_rdkit(
                                        mol,
                                        convert_to_3d_and_optimize=False,
                                        caption=smiles_sequence,
                                    )
                                }
                            )
                            if C_true is not None:
                                self.logger.experiment.log(
                                    {
                                        f"molecule_images/true_step_{self.global_step}": wandb.Molecule.from_rdkit(
                                            mol_true,
                                            convert_to_3d_and_optimize=False,
                                            caption=smiles_sequence,
                                        )
                                    }
                                )
                        except Exception as e:
                            print(f"Error logging molecule images: {e}")
                    else:
                        smiles_samples.append(None)
                    all_indices.extend(ohx_np)

                if self.config.rgfn.reassembly_logging:
                    # Exact count per distinct fragment index. (A fixed-width
                    # np.histogram does not line up with these indices, and it
                    # fails on an empty input.)
                    unique_indices, counts = np.unique(all_indices, return_counts=True)
                    current_table_data = [
                        [idx, count] for idx, count in zip(unique_indices, counts)
                    ]
                    current_table = wandb.Table(
                        data=current_table_data, columns=["Fragment Index", "Count"]
                    )

                    self.logger.experiment.log(
                        {
                            f"fragment_distributions/step_{self.global_step}": wandb.plot.bar(
                                current_table,
                                "Fragment Index",
                                "Count",
                                title=f"Fragment Distribution at Step {self.global_step}",
                            )
                        }
                    )

                    columns.append("Reassembled Samples")
                    valid_smiles = [sample for sample in smiles_samples if sample is not None]
                    # Avoid ZeroDivisionError when num_sample_log == 0 (rates become 0.0).
                    n_logged = max(len(smiles_samples), 1)
                    p_valid = len(valid_smiles) / n_logged

                    # Novelty: fraction of logged samples that are valid and not in the training set.
                    num_novel = len([smile for smile in valid_smiles if smile not in TRAIN_SMILES])
                    self.log(
                        "val/novelty",
                        num_novel / n_logged,
                        on_epoch=True,
                        on_step=False,
                        sync_dist=False,
                    )

                    # Format each row to have consistent length
                    formatted_data = []
                    for frags, smile in zip(frag_indices, smiles_samples):
                        frag_str = ",".join(map(str, frags))
                        formatted_data.append([frag_str, smile if smile is not None else "None"])

                    self.log("val/p_valid", p_valid, on_epoch=True, on_step=False, sync_dist=False)
                    if len(energies) > 0:  # Only log if there are valid samples
                        median_energy = torch.tensor(energies).median()
                        print("Median energy: ", median_energy)
                        self.log(
                            "val/median_energy",
                            median_energy,
                            on_epoch=True,
                            on_step=False,
                            sync_dist=False,
                        )

                    self.trainer.logger.log_table(
                        key=f"samples@global_step{self.global_step}",
                        columns=columns,
                        data=formatted_data,
                    )

        if self.ema:
            self.ema.restore(
                itertools.chain(
                    self.backbone.parameters(),
                    self.discrete_noise.parameters(),
                    self.spatial_noise.parameters(),
                )
            )

    def configure_optimizers(self):
        """Create the AdamW optimizer and a per-step learning-rate scheduler.

        The backbone and both noise schedules are optimised together. The
        scheduler is built with Hydra from ``self.config.lr_scheduler``. Lightning
        steps it on every optimizer step.

        Returns:
            ``([optimizer], [scheduler_dict])`` in Lightning's expected format.
        """
        # TODO(yair): Lightning currently giving this warning when using `fp16`:
        #  "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
        #  Not clear if this is a problem or not.
        #  See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
        optim_cfg = self.config.optim
        trainable = itertools.chain(
            self.backbone.parameters(),
            self.discrete_noise.parameters(),
            self.spatial_noise.parameters(),
        )
        optimizer = torch.optim.AdamW(
            trainable,
            lr=optim_cfg.lr,
            betas=(optim_cfg.beta1, optim_cfg.beta2),
            eps=optim_cfg.eps,
            weight_decay=optim_cfg.weight_decay,
        )

        lr_schedule = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=optimizer)
        return [optimizer], [
            dict(
                scheduler=lr_schedule,
                interval="step",
                monitor="val/loss",
                name="trainer/lr",
            )
        ]

    def q_xt_et(self, X, E, move_chance):
        """Sample the masked forward process for nodes and edges.

        Each node is replaced by the mask token independently, with its graph's
        ``move_chance``. The mask token is one-hot on the last feature channel.
        Each unordered node pair is treated the same way, and both (i, j) and
        (j, i) are masked together.

        Args:
            X: one-hot node tensor of shape (batch_size, n_nodes, n_node_features).
            E: one-hot edge tensor of shape
                (batch_size, n_nodes, n_nodes, n_edge_features).
            move_chance: float tensor of shape (batch_size, 1) holding the
                per-graph masking probability.

        Returns:
            Tuple ``(Xt, Et)`` with the same shapes and dtypes as ``X`` and ``E``.
            ``Et`` is symmetric in its two node dimensions whenever ``E`` is.
        """
        n_nodes = E.shape[1]

        # Broadcastable per-graph masking probabilities.
        node_move_chance = move_chance.unsqueeze(2).expand(-1, -1, X.shape[2])
        edge_move_chance = move_chance.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, E.shape[3])

        # One uniform draw per node, shared across the feature dim.
        node_rand = torch.rand(X.shape[0], X.shape[1], device=X.device)
        node_rand = node_rand.unsqueeze(-1).expand_as(X)

        # Symmetric edge noise: keep the upper triangle (incl. diagonal) of a
        # uniform draw and mirror it into the lower triangle. Each unordered pair
        # therefore gets exactly one U[0, 1) sample. Averaging U with U^T would
        # give a triangular distribution, which masks off-diagonal edges with
        # probability ~2p^2 instead of p.
        edge_rand = torch.rand(E.shape[0], n_nodes, E.shape[2], device=E.device)
        upper = torch.ones(n_nodes, n_nodes, dtype=torch.bool, device=E.device).triu()
        edge_rand = torch.where(upper, edge_rand, edge_rand.transpose(1, 2))
        edge_rand = edge_rand.unsqueeze(-1).expand_as(E)

        node_move_indices = node_rand < node_move_chance
        edge_move_indices = edge_rand < edge_move_chance

        # One-hot encodings of the mask token (last feature channel).
        mask_onehot_X = torch.zeros_like(X)
        mask_onehot_X[..., -1] = 1
        mask_onehot_E = torch.zeros_like(E)
        mask_onehot_E[..., -1] = 1

        Xt = torch.where(node_move_indices, mask_onehot_X, X)
        Et = torch.where(edge_move_indices, mask_onehot_E, E)

        return Xt, Et

    def C_t(self, C0, t, current_coords_mask=None, true_coords_mask=None, prefix=None, cond_C=None, cond_padding_mask=None):
        """Sample coordinates at time ``t`` on the linear data-to-prior path.

        ``C1`` is drawn with ``self._sample_coord_prior`` over the atoms in
        ``current_coords_mask``. When ``prefix == "train"`` and
        ``config.spatial.equivariant_ot`` are set, it is also aligned/permuted
        onto ``C0``. The true atoms of ``C0`` are then shifted by the centre of
        ``C1`` over those atoms. Atoms that are current but not true take their
        ``C1`` position. The result is ``(1 - t) * C0 + t * C1``, plus
        ``total_noise(t) * z`` when ``config.spatial.do_noise`` is set.

        Args:
            C0: clean coordinates, reshaped to (batch, -1, 3) internally and
                asserted to be centred over ``true_coords_mask``.
            t: per-sample times, viewed as (batch, 1, 1, 1).
            current_coords_mask: atoms present in the current (noisy) graph.
                Required.
            true_coords_mask: atoms present in the clean data. Required.
            prefix: ``"train"`` enables the equivariant-OT alignment.
            cond_C: optional pharmacophore positions, shifted by the same centre.
            cond_padding_mask: mask of valid entries in ``cond_C``.

        Returns:
            ``(C0_shifted, C_t)``, or ``(C0_shifted, C_t, cond_C_shifted)`` when
            ``cond_C`` is given. ``C_t`` has shape
            (batch, config.model.length, MAX_ATOMS, 3).

        Raises:
            ValueError: if either coordinate mask is missing.
        """
        # Both masks are used unconditionally below. Fail early with a clear
        # message rather than an AttributeError inside the prior sampler.
        if current_coords_mask is None or true_coords_mask is None:
            raise ValueError("current_coords_mask and true_coords_mask are required")

        t = t.view(-1, 1, 1, 1)
        z = torch.randn_like(C0)
        C1 = self._sample_coord_prior(C0.shape, coords_mask=current_coords_mask)
        current_coords_mask_flat = current_coords_mask.reshape(current_coords_mask.shape[0], -1)
        true_coords_mask_flat = true_coords_mask.reshape(true_coords_mask.shape[0], -1)

        if self.config.spatial.equivariant_ot and prefix == "train":
            # Couple the prior sample to the data by aligning/permuting it onto
            # C0, then re-centre it on the current atom set.
            C1_aligned = align_and_permute(
                C0.reshape(C0.shape[0], -1, 3),
                C1.reshape(C1.shape[0], -1, 3),
                coords_mask=true_coords_mask_flat,
                sqrd=True,
            )
            C1_aligned = center_atom_coords(C1_aligned, current_coords_mask_flat)
            C1 = C1_aligned.reshape(C0.shape)

        # Work on (batch_size, total_atoms, 3) for easier masking.
        C0_flat = C0.reshape(C0.shape[0], -1, 3)
        C1_flat = C1.reshape(C1.shape[0], -1, 3)
        assert check_batch_coord_means(C0_flat, true_coords_mask_flat), "C0 must be centered (mean=0)"
        assert check_batch_coord_means(C1_flat, current_coords_mask_flat), "C1 must be centered (mean=0)"

        # Centre of C1 over the TRUE atoms. C1 is centred over the current atoms,
        # so this is generally non-zero when the two atom sets differ.
        expanded_true_mask = true_coords_mask_flat.unsqueeze(-1).expand(-1, -1, 3)
        expanded_current_mask = current_coords_mask_flat.unsqueeze(-1).expand(-1, -1, 3)
        C1_valid = torch.where(expanded_true_mask, C1_flat, torch.zeros_like(C1_flat))
        valid_counts = true_coords_mask_flat.sum(dim=1, keepdim=True).unsqueeze(-1)
        C1_center = C1_valid.sum(dim=1, keepdim=True) / valid_counts  # (batch_size, 1, 3)

        # Shift the true atoms of C0 (and any pharmacophore positions) by that centre.
        C0_flat = torch.where(expanded_true_mask, C0_flat + C1_center, C0_flat)
        if cond_C is not None:
            expanded_cond_mask = cond_padding_mask.unsqueeze(-1).expand(-1, -1, 3)
            cond_C = torch.where(expanded_cond_mask.bool(), cond_C + C1_center, cond_C)

        # Atoms that exist now but not in the data take their prior position.
        C0_flat = torch.where(
            (expanded_current_mask & ~expanded_true_mask).bool(), C1_flat, C0_flat
        )

        C0 = C0_flat.reshape(C0.shape)
        C1 = C1_flat.reshape(C1.shape)

        if self.config.spatial.do_noise:
            spatial_sigma_t = self.spatial_noise.total_noise(t).view(-1, 1, 1, 1)
            interpolated = (1 - t) * C0 + t * C1 + spatial_sigma_t * z
        else:
            interpolated = (1 - t) * C0 + t * C1

        flat_current_mask = current_coords_mask.reshape(interpolated.shape[0], -1)
        try:
            assert check_batch_coord_means(
                interpolated.reshape(interpolated.shape[0], -1, 3),
                flat_current_mask,
            ), "Coordinates must be centered (mean=0)"
        except AssertionError:
            LOGGER.error("current_coords_mask.reshape: %s", flat_current_mask)
            raise

        interpolated = interpolated.reshape(
            interpolated.shape[0], self.config.model.length, MAX_ATOMS, 3
        )
        if cond_C is None:
            return C0, interpolated
        return C0, interpolated, cond_C

    def create_inference_dataloader(self, data_source=None, batch_size=32, shuffle=True):
        """
        Create a dataloader for inference-time sampling.

        Args:
            data_source: Path (``str`` or ``os.PathLike``) to a ``torch.save``d
                list of graphs, or an already-loaded list of graphs. If None, a
                default cluster path is used.
            batch_size: Batch size for the dataloader.
            shuffle: Whether to shuffle the data.

        Returns:
            torch_geometric.loader.DataLoader over a ``GraphDataset`` that samples
            a random conformer per graph.
        """
        # Local import mirrors the original; the module-level name `dataloader`
        # refers to the `src.dataloader` package.
        from src.dataloader import GraphDataset

        if data_source is None:
            data_source = "/hpf/projects/XXXXX/XXXXX/SynCoGen/data/all_steps_clean/dataset_list_full.pt"

        if isinstance(data_source, (str, os.PathLike)):
            # NOTE(review): torch.load unpickles arbitrary objects; only point
            # this at trusted dataset files.
            graph_data_list = torch.load(data_source)
        else:
            # Assume data_source is already a list of graphs.
            graph_data_list = data_source

        dataset = GraphDataset(
            config=self.config,
            cache_dir=self.config.data.cache_dir,
            data_list=graph_data_list,
            sample_conformer=True,  # Random conformer selection
            coord_mask_value=self.config.spatial.coord_mask_value,
        )

        # Named `loader` so it does not shadow the module-level `dataloader` import.
        loader = torch_geometric.loader.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,  # Keep simple for inference
            pin_memory=False,
        )
        return loader

    def get_pharm_cond(self, batch):
        """Extract pharmacophore conditioning tensors from a graph batch.

        The batch is densified with ``to_dense``. The molecule coordinates and
        the pharmacophore positions then go through ``augment_coordinates``
        together, using the spatial settings in ``self.config.spatial``. This
        gives the pharmacophores the same centering/normalisation/rotation as
        the molecule. Only the conditioning tensors are returned.

        Args:
            batch: torch_geometric batch providing ``x``, ``edge_index``,
                ``edge_attr``, ``batch``, ``coordinates``, ``coords_mask``,
                ``pharm_types``, ``pharm_pos`` and ``pharm_padding_mask``.

        Returns:
            Tuple ``(cond_X, cond_C, cond_padding_mask)`` on ``self.device``.
        """
        spatial = self.config.spatial
        dense_data, _, _, coords_mask, cond_X, cond_C, cond_padding_mask = to_dense(
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
            self.config.model.length,
            batch.coordinates,
            batch.coords_mask,
            batch.pharm_types,
            batch.pharm_pos,
            batch.pharm_padding_mask,
            spatial.pharmacophore_subset
        )

        # Conformer store location handed to augment_coordinates.
        conf_subdir = "conformers.lmdb" if self.config.paths.use_lmdb else "conformers/final_conformers"
        conf_dir = os.path.join(self.config.data.cache_dir, conf_subdir)

        mol_coords = dense_data.C
        n_graphs, n_frags, n_atoms = mol_coords.shape[0], mol_coords.shape[1], mol_coords.shape[2]
        mask_shape = coords_mask.shape

        # Only the transformed pharmacophore positions are kept; the molecule
        # coordinates are passed in just so both share one transform.
        _, cond_C = augment_coordinates(
            coords=mol_coords.reshape(n_graphs, n_frags * n_atoms, 3),
            coords_masks=coords_mask.reshape(mask_shape[0], mask_shape[1] * mask_shape[2]),
            pharm_coords=cond_C,
            pharm_masks=cond_padding_mask,
            center=spatial.center,
            normalize=spatial.normalize,
            align=spatial.align,
            rotate=spatial.rotate,
            translate=spatial.translate,
            reference_coords=None,
            conf_dir=conf_dir,
        )

        device = self.device
        return cond_X.to(device), cond_C.to(device), cond_padding_mask.to(device)

    def generate_data_pairs(self, batch):
        """Build (data, prior) coordinate pairs from a batch of graphs.

        The batch is densified with ``to_dense``. Its coordinates go through
        ``augment_coordinates`` with the spatial settings from
        ``self.config.spatial`` and are asserted to be centred. Finally
        ``self.C_t`` is evaluated at ``t = 1`` with every atom position marked
        as current (``prefix="sample"``).

        Args:
            batch: torch_geometric.data.Batch object.

        Returns:
            Tuple ``(C0, C1)`` as returned by ``self.C_t``: the shifted data
            coordinates and the ``t = 1`` coordinates, each shaped
            [num_pairs, config.model.length, MAX_ATOMS, 3].
        """
        dense_data, _, _, coords_mask = to_dense(
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
            self.config.model.length,
            batch.coordinates,
            batch.coords_mask,
        )

        spatial = self.config.spatial
        conf_subdir = "conformers.lmdb" if self.config.paths.use_lmdb else "conformers/final_conformers"
        conf_dir = os.path.join(self.config.data.cache_dir, conf_subdir)

        raw = dense_data.C
        mask_shape = coords_mask.shape
        # Same preprocessing as training: flatten fragments x atoms, then augment.
        data_coords = augment_coordinates(
            coords=raw.reshape(raw.shape[0], raw.shape[1] * raw.shape[2], 3),
            coords_masks=coords_mask.reshape(mask_shape[0], mask_shape[1] * mask_shape[2]),
            center=spatial.center,
            normalize=spatial.normalize,
            align=spatial.align,
            rotate=spatial.rotate,
            translate=spatial.translate,
            reference_coords=None,
            conf_dir=conf_dir,
        )

        n_graphs = data_coords.shape[0]
        assert check_batch_coord_means(
            data_coords,
            coords_mask.reshape(n_graphs, -1),
        ), "Coordinates must be centered (mean=0)"

        data_coords = data_coords.reshape(n_graphs, self.config.model.length, MAX_ATOMS, 3)
        # Every atom slot counts as "current", so padding positions get prior samples.
        all_atoms_mask = torch.ones_like(coords_mask, dtype=torch.bool)
        t_one = torch.ones(n_graphs)

        device = self.device
        C0_modified, C1 = self.C_t(
            data_coords.to(device),
            t_one.to(device),
            current_coords_mask=all_atoms_mask.to(device),
            true_coords_mask=coords_mask.to(device),
            prefix="sample"
        )
        return C0_modified, C1

    def v_t_pred(self, C, C0_pred, t):
        """Turn a clean-coordinate prediction into a velocity.

        Returns ``(C0_pred - C) / t``: the displacement from the current
        coordinates ``C`` toward the predicted clean coordinates ``C0_pred``,
        divided by the current time ``t``. ``t == 0`` is not guarded.
        """
        displacement = C0_pred - C
        return displacement / t

    def _sample_graph_prior(self, shape, onehot=True):
        """
        Creates a tensor for the mask token, either one-hot encoded or as indices.

        Args:
        shape: Tuple representing the shape of the tensor to create.
        onehot: If True, returns one-hot encoded tensor. If False, returns tensor of indices.

        Returns:
        A tensor of the given shape with the mask token encoded.
        """
        # Create one-hot encoding with mask token
        mask_onehot = torch.zeros(shape, dtype=self.dtype)
        mask_onehot[..., -1] = 1

        if onehot:
            return mask_onehot
        else:
            return torch.argmax(mask_onehot, dim=-1)

    def _sample_coord_prior(self, shape, coords_mask=None):
        """Sample prior coordinates with one independent draw per batch item.

        The distribution comes from ``self.config.spatial.prior``: ``"ones"``
        (constant 1), ``"uniform"`` (values in [-1, 1)) or ``"gaussian"``
        (standard normal). Positions outside ``coords_mask`` are zeroed. When
        ``config.spatial.scale_noise`` is set, the sample is multiplied by
        ``log(num_atoms) * scale_noise_factor``. It is then re-centred on the
        valid atoms.

        Args:
            shape: Tuple (bs, n, max_atoms, 3) giving the output shape.
            coords_mask: Boolean tensor of shape (bs, n, max_atoms) marking valid
                atoms. ``None`` treats every position as valid; previously this
                crashed on ``coords_mask.reshape``.

        Returns:
            Float tensor of shape ``shape`` on ``self.device``.

        Raises:
            ValueError: if the configured prior is unknown.
        """
        prior = self.config.spatial.prior
        samplers = {
            "ones": lambda s: torch.ones(s, device=self.device),
            "uniform": lambda s: torch.rand(s, device=self.device) * 2 - 1,
            "gaussian": lambda s: torch.randn(s, device=self.device),
        }
        # Validate up front so an unknown prior is reported even for an empty batch.
        if prior not in samplers:
            raise ValueError(f"Unknown spatial prior: {prior}")
        draw = samplers[prior]

        batch_size = shape[0]
        if coords_mask is None:
            coords_mask = torch.ones(tuple(shape[:-1]), dtype=torch.bool, device=self.device)

        prior_choices = torch.zeros(shape, device=self.device)
        # Per-item draws keep the RNG stream identical to the previous version.
        for b in range(batch_size):
            sample = draw(shape[1:])
            mask = coords_mask[b].unsqueeze(-1).expand(shape[1], shape[2], 3)
            prior_choices[b] = sample * mask

        coords_mask_flat = coords_mask.reshape(coords_mask.shape[0], -1)
        prior_choices_flat = prior_choices.reshape(prior_choices.shape[0], -1, 3)
        if self.config.spatial.scale_noise:
            # NOTE(review): log(1) == 0 collapses single-atom samples and log(0)
            # gives -inf for empty graphs.
            num_atoms = coords_mask_flat.sum(dim=-1)
            scale = torch.log(num_atoms).view(-1, 1, 1)
            prior_choices_flat = prior_choices_flat * scale * self.config.spatial.scale_noise_factor
        prior_choices_flat = center_atom_coords(prior_choices_flat, coords_mask_flat)
        prior_choices = prior_choices_flat.reshape(prior_choices.shape)
        assert check_batch_coord_means(
            prior_choices_flat, coords_mask_flat
        ), "Coordinates must be centered (mean=0)"

        return prior_choices

    def reward_g(self, C: torch.Tensor, cond_mol: torch.Tensor, coef: float, t: float = None) -> torch.Tensor:
        """Naive reward-gradient guidance.

        The reward is ``coef`` times the batch mean of
        ``get_overlap(c * COORDS_STD, cond_mol)``. Here ``c`` ranges over the
        batch items of ``C``, each flattened to (-1, 3) and rescaled by
        ``COORDS_STD``.

        Args:
            C: current coordinates; the first dimension is the batch.
            cond_mol: reference coordinates passed to ``get_overlap``.
            coef: scale applied to the mean reward, and hence to the gradient.
            t: unused; kept for interface compatibility.

        Returns:
            Tensor shaped like ``C`` holding d(coef * mean reward) / dC.
        """
        with torch.enable_grad():
            C = C.detach().requires_grad_(True)
            per_sample = [
                get_overlap(c * COORDS_STD, cond_mol) for c in C.reshape(C.shape[0], -1, 3)
            ]
            batch_reward = torch.mean(torch.stack(per_sample))
            LOGGER.debug("Batch reward: %s", batch_reward.detach())
            (reward_grad,) = torch.autograd.grad(outputs=batch_reward * coef, inputs=C)
        return reward_grad

    def mc_guidance(
        self,
        C_t: torch.Tensor,             # [B, N, 3]  current state at time t
        C0_mc: torch.Tensor,           # [B, N_mc, N, 3]  sampled data endpoints x0^{(i)}
        C1_mc: torch.Tensor,           # [B, N_mc, N, 3]  sampled noise endpoints x1^{(i)}
        cond_mol: torch.Tensor,        # reference passed unchanged to get_overlap
        t: torch.Tensor,               # scalar, [B] or [B, 1] in [0,1]
        beta: float = 1.0,             # temperature for reward
        sigma_t: float = 0.5,          # small path noise for likelihood
    ) -> torch.Tensor:
        """
        Monte-Carlo guidance for the linear path with x0 = data and x1 = noise.

        For each batch item, the N_mc proposal pairs are weighted by a Gaussian
        likelihood of ``C_t`` under each proposal's path mean. The weights are
        then reweighted by ``exp(beta * reward(x0))`` relative to their weighted
        mean, and combined with the per-proposal directions ``x0 - x1``.

        Returns:
            Guidance term of shape [B, N, 3].
        """
        B, _, _ = C_t.shape
        device = C_t.device
        dtype = C_t.dtype

        # Normalise t to [B, 1, 1, 1]. Accepts a python scalar, 0-d, [1], [B] or [B, 1].
        if not torch.is_tensor(t):
            t = torch.tensor(t, device=device, dtype=dtype)
        t = t.reshape(-1)
        if t.numel() == 1:
            t = t.expand(B)
        t_b = t.reshape(-1, 1, 1, 1)

        # 1) Conditional-path mean per proposal: mu_i = (1 - t) x0_i + t x1_i
        MU = (1.0 - t_b) * C0_mc + t_b * C1_mc                        # [B,N_mc,N,3]

        # 2) Posterior weights over proposals from N(C_t; mu_i, sigma_t^2 I),
        #    normalised across N_mc per batch item.
        diff = (C_t[:, None, :, :] - MU) / (sigma_t + 1e-8)
        log_w = -0.5 * (diff ** 2).sum(dim=(2, 3))                    # [B,N_mc]
        wbar = torch.softmax(log_w, dim=1)

        # 3) Reward factors on the data endpoints x0^{(i)}.
        R = torch.stack([
            torch.stack([
                get_overlap(c * COORDS_STD, cond_mol)
                for c in C0_mc[b].reshape(C0_mc.shape[1], -1, 3)
            ])
            for b in range(C0_mc.shape[0])
        ])                                                            # [B,N_mc]
        u = torch.exp(beta * R).clamp_max(1e6)                        # [B,N_mc]

        # 4) Partition-function estimate Z_hat = sum_i wbar_i * u_i
        Z_hat = (wbar * u).sum(dim=1, keepdim=True) + 1e-12           # [B,1]
        coef = (u / Z_hat - 1.0)[:, :, None, None]                    # [B,N_mc,1,1]

        # 5) Per-proposal direction from noise to data.
        V = C0_mc - C1_mc                                             # [B,N_mc,N,3]

        # 6) Guidance: sum_i wbar_i * ((u_i / Z_hat) - 1) * v_i
        return (wbar[:, :, None, None] * coef * V).sum(dim=1)         # [B,N,3]

    def _euler_step(self, C, C0_pred, t, dt):
        guidance_term = torch.Tensor([0]).to(C.device).to(C.dtype)
        inference_annealing_term = torch.Tensor([1]).to(C.device).to(C.dtype)
        if self.config.sampling.spatial.guidance == "reward":
            cond_mol = sdf_to_coordinates(self.config.sampling.spatial.cond_mol_path).to(C.device).to(C.dtype)
            cond_mol = cond_mol - cond_mol.mean(dim=0, keepdim=True)
            guidance_term = self.reward_g(C, cond_mol, 1e4, t)
        elif self.config.sampling.spatial.guidance == "mc":
            cond_mol = sdf_to_coordinates(self.config.sampling.spatial.cond_mol_path).to(C.device).to(C.dtype)
            cond_mol = cond_mol - cond_mol.mean(dim=0, keepdim=True)
            C0_mc_list = []
            C1_mc_list = []
            for i in range(C.shape[0]):
                batch = next(iter(self.inference_dataloader))
                C0_rand, C1_rand = self.generate_data_pairs(batch)
                C0_mc_list.append(C0_rand.reshape(C0_rand.shape[0], -1, 3))
                C1_mc_list.append(C1_rand.reshape(C1_rand.shape[0], -1, 3))
            C0_mc = torch.stack(C0_mc_list, dim=0) # [B, N_mc, N, 3]
            C1_mc = torch.stack(C1_mc_list, dim=0) # [B, N_mc, N, 3]
            print("C0_mc", C0_mc.shape, "C1_mc", C1_mc.shape)
            guidance_term = self.mc_guidance(C.reshape(C.shape[0], -1, 3), C0_mc, C1_mc, cond_mol, t).reshape(C.shape)
            print("MC GUIDANCE MEAN", guidance_term.mean())
        if self.config.sampling.spatial.inference_annealing:
            inference_annealing_term = self.config.sampling.spatial.inference_annealing_coef * t
        
        vtheta = self.v_t_pred(C, C0_pred, t)
        #print("VELOCITY AND GUIDANCE MEAN", torch.mean(torch.abs(vtheta * dt)), torch.mean(torch.abs(guidance_term * dt)))
        C_next = C + (vtheta * dt + guidance_term * dt) * inference_annealing_term
        return C_next

    def topk_lowest_masking(self, scores, cutoff_len):
        """Mask the ``cutoff_len`` lowest-scoring entries along the last dim.

        Args:
            scores: float tensor (..., L). Padded entries are expected to hold
                +inf so they are never selected.
            cutoff_len: integer tensor (..., 1) giving how many entries to mask
                per row. Values are clamped to [0, L].

        Returns:
            Bool tensor shaped like ``scores``. It is True where ``scores`` is
            strictly below the ``cutoff_len``-th smallest score, so ties at the
            threshold stay unmasked and fewer than ``cutoff_len`` entries may be
            selected.
        """
        sorted_scores, _ = scores.sort(dim=-1)
        # Append a +inf sentinel so cutoff_len == L selects every finite score.
        # Without it, that case gathers out of bounds.
        sentinel = torch.full_like(sorted_scores[..., :1], float("inf"))
        sorted_scores = torch.cat([sorted_scores, sentinel], dim=-1)
        index = cutoff_len.clamp(0, scores.shape[-1])
        threshold = sorted_scores.gather(dim=-1, index=index)
        return scores < threshold

    def sample_edges(
        self, E: torch.Tensor, p_e0: torch.Tensor, lengths: torch.Tensor, argmax: bool = True
    ) -> torch.Tensor:
        """
        Build a one-hot edge tensor in which every node j > 0 has exactly one
        incoming edge from a lower-indexed parent i < j.

        For each column j, an incoming edge already denoised in ``E`` is copied.
        Otherwise exactly ONE (parent, reaction type) pair is chosen jointly from
        the scores in ``p_e0``.

        Args
        ----
        E       : (B, n, n, R)  current (partly-denoised) one-hot edges
        p_e0    : (B, n, n, R)  edge scores; the last 2 channels are
                  [no-edge, masked] and are ignored when choosing an edge
        lengths : (B,) or None  number of real nodes per graph (padding follows)
        argmax  : bool          if True take the argmax, otherwise sample with
                  torch.multinomial. NOTE(review): multinomial needs non-negative
                  weights, so sampling assumes p_e0 holds probabilities or other
                  non-negative scores rather than raw logits -- confirm.

        Returns
        -------
        E_out   : (B, n, n, R)  float one-hot tensor, symmetric in (i, j):
                  - for j > 0 within a graph's length: one reaction-type channel
                    set at exactly one parent i < j
                  - channel R-2 ("no-edge") set everywhere else, including the
                    diagonal and all rows/columns at or beyond ``lengths[b]``
                  - channel R-1 ("masked") is never set
        """
        B, n, _, R = p_e0.shape
        dev = p_e0.device
        real_R = R - 2  # usable reaction types
        E_out = torch.zeros_like(p_e0)  # initialise
        # Boolean: is there already a denoised edge of a real reaction type at (i, j)?
        existing_mask = E[..., :real_R].sum(-1) > 0  # (B,n,n)
        scores = p_e0[..., :real_R]  # drop no-edge & masked channels
        for j in range(1, n):
            # Graphs that already have an incoming edge to j from some i < j.
            has_edge = existing_mask[:, :j, j].any(1)  # (B,)
            need_edge = ~has_edge
            # -- (1) copy the existing edge ------------------------------------
            if has_edge.any():
                rows_h = torch.arange(B, device=dev)[has_edge]
                # argmax over a 0/1 mask returns the first parent with an edge.
                parent_h = existing_mask[has_edge, :j, j].float().argmax(1)  # (B')
                # Reaction type = argmax over the real channels at that parent.
                edge_h = (E[has_edge, :j, j, :real_R].float().argmax(-1))[
                    torch.arange(parent_h.size(0), device=dev), parent_h
                ]  # (B')
                E_out[rows_h, parent_h, j, edge_h] = 1
            # -- (2) choose a new (parent, type) pair where needed -------------
            if need_edge.any():
                rows_n = torch.arange(B, device=dev)[need_edge]
                cand = scores[need_edge, :j, j, :]  # (B', j, real_R)
                # Flatten so parent and type are chosen jointly; index = i * real_R + k.
                flat = cand.reshape(rows_n.size(0), -1)  # (B', j*real_R)
                zero_row = flat.sum(1) == 0  # avoid all-zero rows (multinomial would fail)
                flat[zero_row] = 1.0  # fall back to uniform weights for those rows
                if argmax:
                    idx = flat.argmax(dim=1)  # (B')
                else:
                    idx = torch.multinomial(flat, 1).squeeze(1)  # (B')
                k_new = idx.remainder(real_R)
                i_new = idx // real_R
                E_out[rows_n, i_new, j, k_new] = 1
        # -- (3) mark every untouched position (incl. lower triangle and diagonal) as "no-edge"
        none_mask = E_out.sum(-1, keepdim=True) == 0
        E_out[..., -2] = none_mask.squeeze(-1).float()  # channel R-2 = no-edge
        # -- (4) clear rows/columns beyond the graph length and set them to no-edge
        if lengths is not None:
            for b in range(B):
                l = lengths[b].item()
                if l < n:
                    E_out[b, l:, :, :] = 0
                    E_out[b, :, l:, :] = 0
                    E_out[b, l:, :, -2] = 1  # no-edge
                    E_out[b, :, l:, -2] = 1
        # -- (5) keep the upper triangle (incl. diagonal) per channel and mirror it
        upper = torch.triu(E_out.permute(0, 3, 1, 2))  # (B,R,n,n)
        upper = upper.bool()
        full = upper | upper.transpose(-1, -2)  # per-channel OR
        E_out = full.permute(0, 2, 3, 1).float()  # back to (B,n,n,R)
        return E_out

    def _update_X_E_C_path_planning(self, X, E, C, node_padding_mask, edge_padding_mask, t, dt):
        """One reverse step: path-planning unmasking for the graph, Euler for coordinates.

        The model runs once on the current state. Coordinates are advanced with
        ``self._euler_step``, optionally after aligning the prediction to ``C``
        when ``config.sampling.boltz_inference_align`` is set.

        For nodes and edges, candidate clean tokens are sampled from the model
        logits. A per-position score decides which positions are (re)masked;
        it is either the model log-probability or random, per
        ``config.sampling.path_planning.score_type``. The number masked follows
        the discrete noise schedule at ``t``. Positions that are currently
        masked and not re-masked are revealed with their sampled token.

        Args:
            X: one-hot node tensor. With self-conditioning, the last dim is
                doubled and only the first half is the state.
            E: one-hot edge tensor (B, n, n, F_e), with the same
                self-conditioning rule.
            C: coordinates. With self-conditioning, the last dim is doubled.
            node_padding_mask: bool (B, n), True for real nodes.
            edge_padding_mask: bool (B, n, n), True for real edges.
            t: time tensor; ``t[0]`` is used for the coordinate step.
            dt: step size forwarded to ``self._euler_step``.

        Returns:
            ``(X_next, E_next, C_next)``. With self-conditioning, each is
            concatenated with the corresponding model output.

        Raises:
            ValueError: unknown ``path_planning.score_type`` or integrator.
        """
        # Validate before the forward pass; previously an unknown value surfaced
        # later as an UnboundLocalError on score_X.
        score_type = self.config.sampling.path_planning.score_type
        if score_type not in ("confidence", "random"):
            raise ValueError(f"Unknown path-planning score_type: {score_type}")

        # Calculate unmasking rate according to schedule
        discrete_sigma_t, _ = self.discrete_noise(t)
        move_chance_t = 1 - torch.exp(-discrete_sigma_t)

        p_x0, p_e0, C0_pred = self.forward(
            X, E, C, node_padding_mask, edge_padding_mask, discrete_sigma_t
        )

        if self.config.self_conditioning:
            # Inputs carry [state, previous prediction]; keep only the state half.
            X, E, C = (
                X[..., : X.shape[-1] // 2],
                E[..., : E.shape[-1] // 2],
                C[..., : C.shape[-1] // 2],
            )

        X_indices = X.argmax(dim=-1)

        ### UPDATE COORDINATES ###

        ### OPTIONALLY ALIGN PREDICTION TO CURRENT NOISED COORDINATES ###
        if self.config.sampling.boltz_inference_align:
            atom_mask = node_to_atom_padding_mask(node_padding_mask) & perfrag_atom_padding_mask(
                X_indices
            )
            # Flatten per-fragment atom masks to match the flattened coordinates.
            atom_mask_flat = atom_mask.reshape(atom_mask.shape[0], -1)
            C_flat = C.reshape(C.shape[0], -1, 3)
            C0_pred_flat = C0_pred.reshape(C0_pred.shape[0], -1, 3)
            aligned_C0_pred = align_and_permute(
                C_flat, C0_pred_flat, coords_mask=atom_mask_flat, sqrd=True
            )
            aligned_C0_pred = aligned_C0_pred.reshape(C.shape[0], C.shape[1], C.shape[2], 3)
        else:
            aligned_C0_pred = C0_pred

        if self.config.sampling.spatial.integrator == "euler":
            C_next = self._euler_step(C, aligned_C0_pred, t[0], dt)
        else:
            raise ValueError(f"Unknown integrator: {self.config.sampling.spatial.integrator}")

        ### PATH PLANNING FOR GRAPH NODES ###
        last_mask_X = X_indices == self.node_mask_index
        unmask_candidates_X = ~last_mask_X

        # Sample node predictions and scores
        X_logits = torch.clamp(p_x0, max=MAX_LOGIT).exp()
        x0_indices, logp_X = _sample_categorical(
            X_logits, temperature=self.config.sampling.path_planning.tau
        )
        logp_X = logp_X.max(dim=-1)[0]
        if score_type == "confidence":
            score_X = logp_X
        else:  # "random"
            score_X = torch.rand_like(logp_X).log()

        # Padding gets +inf so it is never selected for masking.
        score_X = score_X.masked_fill(~node_padding_mask, float("inf"))
        # Apply eta scaling to already-unmasked positions.
        score_X[unmask_candidates_X] *= self.config.sampling.path_planning.eta

        # Number of nodes to (re)mask, per the schedule.
        num_to_mask_X = (
            node_padding_mask.sum(dim=1, keepdim=True).float() * (move_chance_t)
        ).long()
        mask_X = self.topk_lowest_masking(score_X, num_to_mask_X)

        X_next_indices = X_indices.clone()
        X_next_indices[mask_X] = self.node_mask_index
        # Reveal nodes that were masked and are not re-masked.
        mask_to_x0_X = last_mask_X & ~mask_X
        X_next_indices[mask_to_x0_X] = x0_indices[mask_to_x0_X]
        X_next = _to_onehot(X_next_indices, self.n_node_features).to(X.dtype)

        ### PATH PLANNING FOR GRAPH EDGES ###
        mask_token_id_E = self.edge_mask_index
        E_indices = E.argmax(dim=-1)
        last_mask_E = E_indices == mask_token_id_E
        unmask_candidates_E = ~last_mask_E

        # Sample edge predictions and scores
        E_logits = torch.clamp(p_e0, max=MAX_LOGIT).exp()
        e0_indices, logp_E = _sample_categorical(
            E_logits, temperature=self.config.sampling.path_planning.tau
        )

        # Symmetrize from the strict upper triangle; the diagonal gets the
        # "no edge" index (second to last).
        e0_indices_upper = torch.triu(e0_indices, diagonal=1)
        e0_indices = e0_indices_upper + e0_indices_upper.transpose(1, 2)
        diag_indices = torch.arange(e0_indices.shape[1], device=e0_indices.device)
        e0_indices[:, diag_indices, diag_indices] = self.edge_mask_index - 1

        # Model output logits are already symmetric
        logp_E = logp_E.max(dim=-1)[0]
        if score_type == "confidence":
            score_E = logp_E
        else:  # "random"
            score_E = torch.rand_like(logp_E).log()

        score_E = score_E.masked_fill(~edge_padding_mask, float("inf"))
        score_E[unmask_candidates_E] *= self.config.sampling.path_planning.eta

        # 2 * (number of real edges in the strict upper triangle) * move chance.
        num_to_mask_E = (
            2
            * (
                edge_padding_mask.triu(diagonal=1)
                .sum(dim=(1, 2), keepdim=True)
                .squeeze(dim=2)
                .float()
                * (move_chance_t)
            ).long()
        )
        mask_E = self.topk_lowest_masking(score_E.view(E.shape[0], -1), num_to_mask_E).view(
            E.shape[0], E.shape[1], E.shape[2]
        )

        assert torch.allclose(mask_E, mask_E.transpose(1, 2)), "mask_E is not symmetric"

        E_next_indices = E_indices.clone()
        E_next_indices[mask_E] = self.edge_mask_index
        # Reveal edges that were masked and are not re-masked.
        mask_to_e0_E = last_mask_E & ~mask_E
        E_next_indices[mask_to_e0_E] = e0_indices[mask_to_e0_E]
        E_next = _to_onehot(E_next_indices, self.n_edge_features).to(E.dtype)

        if self.config.self_conditioning:
            return (
                torch.cat([X_next, X_logits], dim=-1),
                torch.cat([E_next, E_logits], dim=-1),
                torch.cat([C_next, C0_pred], dim=-1),
            )
        return X_next, E_next, C_next

    def _update_X_E_C(self, X, E, C, node_padding_mask, edge_padding_mask, t, dt, cond=None):
        """Take one ancestral (DDPM-style) reverse step for nodes, edges and coordinates.

        Runs the backbone once on the current state, advances the coordinates with
        the configured integrator, and samples new node/edge tokens from the
        masked-diffusion reverse distribution between ``t`` and ``t - dt``. Tokens
        that are already unmasked in ``X``/``E`` are copied through unchanged.

        Args:
            X, E, C: Current node, edge and coordinate tensors. When
                ``config.self_conditioning`` is enabled each carries its
                self-conditioning channels concatenated on the last axis (the first
                half is the actual state).
            node_padding_mask, edge_padding_mask: Padding masks for nodes / edges.
            t: Current time. Only ``t[0]`` is used for the churn check and the
                coordinate step, so all batch elements are assumed to share it.
            dt: Positive step size; the discrete update targets time ``t - dt``.
            cond: Optional conditioning tuple forwarded to ``self.forward``.

        Returns:
            ``(X_next, E_next, C_next)``. With self-conditioning, each has the model
            prediction (node/edge probabilities, predicted coordinates) concatenated
            on the last axis.

        Raises:
            ValueError: If ``config.sampling.spatial.integrator`` is unknown.
        """
        ### FORWARD PASS ###
        lengths = node_padding_mask.sum(dim=-1)
        discrete_sigma_t, _ = self.discrete_noise(t)
        discrete_sigma_s, _ = self.discrete_noise(t - dt)
        if discrete_sigma_t.ndim > 1:
            discrete_sigma_t = discrete_sigma_t.squeeze(-1)
        if discrete_sigma_s.ndim > 1:
            discrete_sigma_s = discrete_sigma_s.squeeze(-1)
        assert discrete_sigma_t.ndim == 1, discrete_sigma_t.shape
        assert discrete_sigma_s.ndim == 1, discrete_sigma_s.shape

        spatial_cfg = self.config.sampling.spatial
        if spatial_cfg.stochastic and (spatial_cfg.tmin <= t[0] <= spatial_cfg.tmax):
            # Stochastic churn: perturb the coordinates to a higher time t_hat
            # before predicting, then integrate from t_hat instead of t.
            # NOTE(review): with self-conditioning the noise is also added to the
            # self-conditioning channels of C; confirm this is intended.
            gamma = spatial_cfg.churn / self.config.sampling.steps
            eps = torch.randn_like(C)
            dt_hat = gamma * t[0]
            t_hat = torch.clamp(t[0] + dt_hat, max=1)
            C = C + dt_hat * torch.sqrt(t_hat**2 - t[0] ** 2) * eps
            p_x0, p_e0, C0_pred = self.forward(
                X, E, C, node_padding_mask, edge_padding_mask, discrete_sigma_t, cond
            )
            # Extend the step by the time actually added. t_hat may have been
            # clamped to 1, in which case this is smaller than dt_hat. This keeps
            # the step's end point at t - dt, assuming _euler_step moves from its
            # time argument down by dt.
            dt = dt + (t_hat - t[0])
            t = t_hat
        else:
            p_x0, p_e0, C0_pred = self.forward(
                X, E, C, node_padding_mask, edge_padding_mask, discrete_sigma_t, cond
            )

        if self.config.self_conditioning:
            # Drop the self-conditioning channels; only the state is updated below.
            X = X[..., : X.shape[-1] // 2]
            E = E[..., : E.shape[-1] // 2]
            C = C[..., : C.shape[-1] // 2]

        ### UPDATE COORDINATES ###

        ### OPTIONALLY ALIGN PREDICTION TO CURRENT NOISED COORDINATES ###
        if self.config.sampling.boltz_inference_align:
            X_indices = X.argmax(dim=-1)
            atom_mask = node_to_atom_padding_mask(node_padding_mask) & perfrag_atom_padding_mask(
                X_indices
            )
            # align_and_permute works on flattened (batch, atoms, 3) coordinates,
            # so the atom mask is flattened to (batch, atoms) as well.
            atom_mask_flat = atom_mask.reshape(atom_mask.shape[0], -1)
            C_flat = C.reshape(C.shape[0], -1, 3)
            C0_pred_flat = C0_pred.reshape(C0_pred.shape[0], -1, 3)
            aligned_C0_pred = align_and_permute(
                C_flat, C0_pred_flat, coords_mask=atom_mask_flat, sqrd=True
            )
            aligned_C0_pred = aligned_C0_pred.reshape(C.shape[0], C.shape[1], C.shape[2], 3)
        else:
            aligned_C0_pred = C0_pred

        if spatial_cfg.integrator == "euler":
            C_next = self._euler_step(C, aligned_C0_pred, t[0], dt)
        else:
            raise ValueError(f"Unknown integrator: {self.config.sampling.spatial.integrator}")

        ### UPDATE GRAPH NODES AND EDGES ###

        # Probability of a token being masked at t and at s = t - dt, broadcast to
        # the node (3-D) and edge (4-D) layouts of the model outputs.
        move_chance_t = 1 - torch.exp(-discrete_sigma_t)
        move_chance_s = 1 - torch.exp(-discrete_sigma_s)
        move_chance_t_X = move_chance_t[:, None, None]
        move_chance_t_E = move_chance_t[:, None, None, None]
        move_chance_s_X = move_chance_s[:, None, None]
        move_chance_s_E = move_chance_s[:, None, None, None]

        # Clamp logits to prevent overflow, then convert to probabilities.
        p_x0 = torch.clamp(p_x0, max=MAX_LOGIT).exp()
        p_e0 = torch.clamp(p_e0, max=MAX_LOGIT).exp()

        assert move_chance_t_X.ndim == p_x0.ndim
        assert move_chance_t_E.ndim == p_e0.ndim

        # Keep the unconstrained edge probabilities for the self-conditioning input.
        p_e0_unconstrained = p_e0.clone()
        if self.config.sampling.constrain_edge_sampling:
            p_e0 = self.sample_edges(E, p_e0, lengths)

        # Unnormalised reverse step: every class gets p0 * (move_chance_t -
        # move_chance_s); the mask class (last index) gets move_chance_s.
        q_xs = p_x0 * (move_chance_t_X - move_chance_s_X)
        q_xs[:, :, -1] = move_chance_s_X[:, :, 0]
        _X, _ = _sample_categorical(q_xs)

        q_es = p_e0 * (move_chance_t_E - move_chance_s_E)
        q_es[:, :, :, -1] = move_chance_s_E[:, :, :, 0]
        _E, _ = _sample_categorical(q_es)

        # Symmetrize and set diagonals to "no edge" index (second to last index)
        _E_upper = torch.triu(_E, diagonal=1)  # Upper triangle, excluding the diagonal
        _E = _E_upper + _E_upper.transpose(1, 2)  # Mirror to make it symmetric
        diag_indices = torch.arange(_E.shape[1], device=_E.device)
        _E[:, diag_indices, diag_indices] = self.edge_mask_index - 1

        # Expand X and E into one-hot encodings. The mask index is the last class here.
        _X = _to_onehot(_X, self.n_node_features)
        _E = _to_onehot(_E, self.n_edge_features)

        # Already-unmasked tokens are carried over; only masked ones take the sample.
        copy_flag_X = _get_unmasked_indices(X).to(X.dtype)
        copy_flag_E = _get_unmasked_indices(E).to(E.dtype)
        X_next = copy_flag_X * X + (1 - copy_flag_X) * _X
        E_next = copy_flag_E * E + (1 - copy_flag_E) * _E

        if self.config.self_conditioning:
            return (
                torch.cat([X_next, p_x0], dim=-1),
                torch.cat([E_next, p_e0_unconstrained], dim=-1),
                torch.cat([C_next, C0_pred], dim=-1),
            )
        return X_next, E_next, C_next

    @torch.no_grad()
    def _sample(self, num_steps=None, eps=1e-5, cond=None):
        """Generate a batch of samples by running the reverse diffusion process.

        Starting state:
          * ``config.denoise_discrete`` / ``config.denoise_coordinates``: node/edge
            types and/or coordinates are taken from ``self.cached_val_batch``.
          * Otherwise nodes and edges come from ``_sample_graph_prior``. When
            coordinates must also be sampled, per-molecule node counts are drawn
            from ``DATASET_PRIORS`` and coordinates from ``_sample_coord_prior``.

        Each of the ``num_steps`` steps:
          * re-centres the coordinates, jointly with the pharmacophore positions
            when pharmacophore conditioning is on;
          * calls the configured predictor (``ddpm`` or ``path_planning``).

        With ``config.sampling.noise_removal``, a final model call replaces X/E by
        their argmax predictions and optionally refines the coordinates.

        Args:
            num_steps: Number of reverse steps; defaults to ``config.sampling.steps``.
            eps: Final time of the linear schedule running from 1 down to ``eps``.
            cond: Optional conditioning tuple. With pharmacophore conditioning and
                ``cond=None``, it is drawn from the inference dataloader.

        Returns:
            ``(X, E, C, C_true, node_padding_mask, edge_padding_mask)``.
            ``X``/``E``/``C`` never carry self-conditioning channels. ``C_true`` is
            the reference coordinates (divided by ``COORDS_STD``) when
            ``config.denoise_discrete`` is set, otherwise ``None``.

        Raises:
            ValueError: If ``config.sampling.predictor`` is unknown.
        """
        batch_size_per_gpu = self.config.loader.eval_batch_size
        self_cond = self.config.self_conditioning
        X, E, C, C_true = None, None, None, None
        node_padding_mask, edge_padding_mask, coords_mask = None, None, None

        # Send conditioning info to the model's device.
        if cond is not None:
            cond = tuple(item.to(self.device) for item in cond)

        # Build the inference dataloader once when MC guidance or pharmacophore
        # conditioning needs reference batches (skipped when cond was supplied).
        needs_dataloader = (
            self.config.sampling.spatial.guidance == "mc"
            or self.config.spatial.pharmacophore_conditioning
        )
        if needs_dataloader and self.inference_dataloader is None and cond is None:
            self.inference_dataloader = self.create_inference_dataloader(
                data_source=None,
                batch_size=self.config.loader.eval_batch_size,
                shuffle=True,
            )

        if self.config.spatial.pharmacophore_conditioning and cond is None:
            batch = next(iter(self.inference_dataloader))
            pharm_types, pharm_pos, pharm_padding_mask = self.get_pharm_cond(batch)
            cond = (pharm_types, pharm_pos, pharm_padding_mask)

        # Take the starting state from the cached validation batch if denoising anything.
        if self.config.denoise_discrete or self.config.denoise_coordinates:
            batch = self.cached_val_batch
            dense_data, node_padding_mask, edge_padding_mask, coords_mask = to_dense(
                batch.x,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
                self.config.model.length,
                batch.coordinates,
                batch.coords_mask,
            )
            if self.config.denoise_discrete:
                X, E = dense_data.X, dense_data.E
                C_true = dense_data.C / COORDS_STD
            if self.config.denoise_coordinates:
                C = dense_data.C / COORDS_STD

        # Nodes/edges not taken from data are sampled from the graph prior.
        if X is None or E is None:
            X = self._sample_graph_prior(
                (batch_size_per_gpu, self.config.model.length, self.n_node_features)
            ).to(self.device)
            E = self._sample_graph_prior(
                (
                    batch_size_per_gpu,
                    self.config.model.length,
                    self.config.model.length,
                    self.n_edge_features,
                )
            ).to(self.device)

            if C is None:
                # Nothing came from data: draw molecule sizes from the dataset prior
                # to build the padding masks, then sample coordinates.
                priors = DATASET_PRIORS[self.config.data.train]
                values = torch.tensor([int(k) for k in priors.keys()], device=self.device)
                probabilities = torch.tensor(list(priors.values()), device=self.device)
                num_nodes = values[
                    torch.multinomial(probabilities, num_samples=X.shape[0], replacement=True)
                ]
                X_indices = X.argmax(dim=-1)
                node_padding_mask, edge_padding_mask = padding_mask(X, E, num_nodes)
                coords_mask = node_to_atom_padding_mask(
                    node_padding_mask
                ) & perfrag_atom_padding_mask(X_indices)
                C = self._sample_coord_prior(
                    (batch_size_per_gpu, self.config.model.length, MAX_ATOMS, 3),
                    coords_mask=coords_mask.reshape(
                        coords_mask.shape[0], self.config.model.length, MAX_ATOMS
                    ),
                )

        if C is None:  # X and E came from data; only coordinates need sampling.
            C = self._sample_coord_prior(
                (batch_size_per_gpu, self.config.model.length, MAX_ATOMS, 3),
                coords_mask=coords_mask.reshape(
                    coords_mask.shape[0], self.config.model.length, MAX_ATOMS
                ),
            )

        assert batch_size_per_gpu == X.shape[0], (
            f"eval batch size {batch_size_per_gpu} != sampled batch size {X.shape[0]}"
        )

        if num_steps is None:
            num_steps = self.config.sampling.steps

        predictor = self.config.sampling.predictor
        if predictor == "ddpm":
            update_fn = self._update_X_E_C
        elif predictor == "path_planning":
            update_fn = self._update_X_E_C_path_planning
        else:
            raise ValueError(f"Unknown predictor: {predictor}")

        if self_cond:
            # Self-conditioning channels start as zeros and sit after the state.
            X = torch.cat([X, torch.zeros_like(X)], dim=-1)
            E = torch.cat([E, torch.zeros_like(E)], dim=-1)
            C = torch.cat([C, torch.zeros_like(C)], dim=-1)

        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = (1 - eps) / num_steps

        for i in range(num_steps):
            # Split off the self-conditioning channels so only the state is centred.
            if self_cond:
                X, X_self_cond = X[:, :, : X.shape[-1] // 2], X[:, :, X.shape[-1] // 2 :]
                C, C_self_cond = C[:, :, :, : C.shape[-1] // 2], C[:, :, :, C.shape[-1] // 2 :]

            X_indices = X.argmax(dim=-1)
            coords_mask = node_to_atom_padding_mask(node_padding_mask) & perfrag_atom_padding_mask(
                X_indices
            )
            C = C.reshape(batch_size_per_gpu, -1, 3)
            if self.config.spatial.pharmacophore_conditioning:
                # Centre atoms and pharmacophores in one frame. Pharmacophore
                # points get a zero mask (presumably excluded from the centroid)
                # but are translated together with the atoms.
                n_atoms = C.shape[1]
                C_and_pharm = torch.cat([C, cond[1]], dim=1)
                coords_and_pharm_mask = torch.cat([coords_mask, torch.zeros_like(cond[2])], dim=1)
                C_and_pharm = center_atom_coords(C_and_pharm, coords_and_pharm_mask)
                C = C_and_pharm[:, :n_atoms, :]
                cond = (cond[0], C_and_pharm[:, n_atoms:, :], cond[2])
            else:
                C = center_atom_coords(C, coords_mask)
            assert check_batch_coord_means(C, coords_mask), "Coordinates must be centered (mean=0)"
            C = C.reshape(batch_size_per_gpu, self.config.model.length, MAX_ATOMS, -1)

            # Re-attach the self-conditioning channels before the model call.
            if self_cond:
                X = torch.cat([X, X_self_cond], dim=-1)
                C = torch.cat([C, C_self_cond], dim=-1)

            t = timesteps[i] * torch.ones(X.shape[0], 1, device=self.device)
            X, E, C = update_fn(X, E, C, node_padding_mask, edge_padding_mask, t, dt, cond)

            if self.config.denoise_coordinates:
                # Coordinates stay pinned to the reference. Keep the
                # self-conditioning half produced by the update so the channel
                # layout stays valid for the next step.
                C_fixed = dense_data.C / COORDS_STD
                if self_cond:
                    C = torch.cat([C_fixed, C[..., C.shape[-1] // 2 :]], dim=-1)
                else:
                    C = C_fixed

        if self.config.sampling.noise_removal:
            t = timesteps[-1] * torch.ones(X.shape[0], 1, device=self.device)
            sigma_0 = self.discrete_noise(t)[0]
            X_logits, E_logits, C0_pred = self.forward(
                X, E, C, node_padding_mask, edge_padding_mask, sigma_0, cond
            )
            X = _to_onehot(X_logits.argmax(dim=-1), self.n_node_features).float()
            E = _to_onehot(E_logits.argmax(dim=-1), self.n_edge_features).float()
            if self.config.sampling.refine_coordinates_steps > 0:
                # Extra coordinate-only passes using the (state, self-conditioning)
                # channel layout: the newest prediction fills the state slot and the
                # previous one fills the conditioning slot.
                X_final_cond = torch.cat([X, X], dim=-1)
                E_final_cond = torch.cat([E, E], dim=-1)
                C = torch.cat([C0_pred, C[:, :, :, : C.shape[-1] // 2]], dim=-1)
                for _ in range(self.config.sampling.refine_coordinates_steps):
                    C_cond = C[:, :, :, : C.shape[-1] // 2]
                    _, _, C_refine = self.forward(
                        X_final_cond, E_final_cond, C, node_padding_mask, edge_padding_mask, sigma_0, cond
                    )
                    C = torch.cat([C_refine, C_cond], dim=-1)
            else:
                # No refinement passes: use the final coordinate prediction directly.
                C_refine = C0_pred
        else:
            # No final denoising pass: return the last state, with any
            # self-conditioning channels stripped to match the branch above.
            if self_cond:
                X = X[..., : X.shape[-1] // 2]
                E = E[..., : E.shape[-1] // 2]
                C = C[..., : C.shape[-1] // 2]
            C_refine = C

        return X, E, C_refine, C_true, node_padding_mask, edge_padding_mask

    def restore_model_and_sample(self, num_steps, eps=1e-5, cond=None):
        """Sample with EMA weights (if enabled) and the networks in eval mode.

        The live parameters of the backbone and both noise modules are stored,
        replaced by their EMA copies for the duration of sampling, and restored
        afterwards. The modules are switched to eval mode while sampling and back
        to train mode afterwards. Restoration and the switch back to train mode
        happen in a ``finally`` block. An exception raised during sampling
        therefore does not leave the model holding EMA weights or stuck in eval
        mode.

        Args:
            num_steps: Number of reverse-diffusion steps passed to ``_sample``.
            eps: Final diffusion time passed to ``_sample``.
            cond: Optional conditioning tuple passed to ``_sample``.

        Returns:
            Whatever ``self._sample`` returns.
        """
        # Lightning auto-casting is not working in this method for some reason
        modules = (self.backbone, self.discrete_noise, self.spatial_noise)

        def ema_params():
            # Build a fresh iterator on every call: a chain can only be consumed once.
            return itertools.chain.from_iterable(module.parameters() for module in modules)

        if self.ema:
            self.ema.store(ema_params())
            self.ema.copy_to(ema_params())
        for module in modules:
            module.eval()
        try:
            return self._sample(num_steps=num_steps, eps=eps, cond=cond)
        finally:
            if self.ema:
                self.ema.restore(ema_params())
            for module in modules:
                module.train()

    def _sample_t(self, n, device):
        _eps_t = torch.rand(n, device=device)
        if self.antithetic_sampling:
            offset = torch.arange(n, device=device) / n
            _eps_t = (_eps_t / n + offset) % 1
        t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
        if self.importance_sampling:
            # do importance sampling by discrete noise schedule
            return self.discrete_noise.importance_sampling_transformation(t)
        return t

    def _forward_pass_diffusion(
        self, X0, E0, C0, batch_smiles, node_padding_mask, edge_padding_mask, coords_mask, prefix, cond_X=None, cond_C=None, cond_padding_mask=None
    ):
        """Noise a clean batch, run the model, and return the per-element loss terms.

        Steps:
          * Draw a time ``t`` per sample and store it in ``self.sampled_t``.
          * Mask nodes/edges (unless ``config.denoise_discrete``) and noise the
            coordinates (unless ``config.denoise_coordinates``).
          * Run the backbone, optionally with a self-conditioning pass first.
          * Compute node/edge log-likelihood terms, squared coordinate error and
            the configured bond-geometry losses.

        Args:
            X0, E0, C0: Clean one-hot nodes, one-hot edges and coordinates.
            batch_smiles: Unused here; kept for interface compatibility.
            node_padding_mask, edge_padding_mask: Padding masks.
            coords_mask: Mask of real atoms in the clean data.
            prefix: Forwarded to ``self.C_t``.
            cond_X, cond_C, cond_padding_mask: Optional conditioning forwarded to
                ``self.forward`` (``cond_C`` is also noised by ``C_t`` under
                pharmacophore conditioning).

        Returns:
            A 5-tuple ``(loss_X, loss_E, squared_error, bond_loss, current_coords_mask)``:
              * weighted node and edge NLL terms;
              * the element-wise squared coordinate error;
              * the combined, coefficient-weighted bond loss;
              * the atom mask implied by the noised node types, reshaped to
                ``(batch, length, MAX_ATOMS)``.

        Raises:
            ValueError: If ``config.spatial.bond_loss`` selects no known bond loss.
        """
        t = self._sample_t(X0.shape[0], X0.device)
        if self.T > 0:
            # Discretize t onto a grid of multiples of 1/T (shifted up one step).
            t = (t * self.T).to(torch.int)
            t = t / self.T
            t += 1 / self.T

        self.sampled_t = t

        if self.change_of_variables:
            # The model is conditioned on t directly. Keep it 1-D like the
            # discrete_noise output, so the ``[:, None]`` below gives the same
            # (batch, 1) shape in both branches.
            discrete_sigma = t
            f_T = torch.log1p(-torch.exp(-self.discrete_noise.sigma_max))
            f_0 = torch.log1p(-torch.exp(-self.discrete_noise.sigma_min))
            move_chance = torch.exp(f_0 + t * (f_T - f_0))[:, None]
        else:
            discrete_sigma, discrete_dsigma = self.discrete_noise(t)
            move_chance = 1 - torch.exp(-discrete_sigma[:, None])

        if self.config.denoise_discrete:
            Xt, Et = X0, E0
        else:
            Xt, Et = self.q_xt_et(X0, E0, move_chance)

        # Atom mask implied by the (possibly masked) node types. It is returned for
        # the coordinate loss, so it must exist whichever coordinate branch runs.
        current_coords_mask = node_to_atom_padding_mask(
            node_padding_mask
        ) & perfrag_atom_padding_mask(Xt.argmax(dim=-1))
        current_coords_mask = current_coords_mask.reshape(
            current_coords_mask.shape[0], self.config.model.length, MAX_ATOMS
        )

        if self.config.denoise_coordinates:
            Ct = C0
        elif self.config.spatial.pharmacophore_conditioning:
            C0, Ct, cond_C = self.C_t(
                C0,
                t,
                current_coords_mask=current_coords_mask,
                true_coords_mask=coords_mask,
                prefix=prefix,
                cond_C=cond_C,
                cond_padding_mask=cond_padding_mask,
            )
        else:
            C0, Ct = self.C_t(
                C0,
                t,
                current_coords_mask=current_coords_mask,
                true_coords_mask=coords_mask,
                prefix=prefix,
            )

        model_cond = (cond_X, cond_C, cond_padding_mask)
        sigma_in = discrete_sigma[:, None]

        if self.config.self_conditioning:
            x_cond = torch.zeros_like(Xt)
            e_cond = torch.zeros_like(Et)
            c_cond = torch.zeros_like(Ct)

            # With probability 0.5, feed the model its own first-pass prediction.
            # NOTE(review): this first pass is not detached, so gradients flow
            # through both forward passes; confirm that is intended.
            if torch.rand(1).item() < 0.5:
                x_cond, e_cond, c_cond = self.forward(
                    torch.cat([Xt, x_cond], dim=-1),
                    torch.cat([Et, e_cond], dim=-1),
                    torch.cat([Ct, c_cond], dim=-1),
                    node_padding_mask=node_padding_mask,
                    edge_padding_mask=edge_padding_mask,
                    sigma=sigma_in,
                    cond=model_cond,
                )
                x_cond = torch.clamp(x_cond, max=MAX_LOGIT).exp()
                e_cond = torch.clamp(e_cond, max=MAX_LOGIT).exp()

            model_output_Xt, model_output_Et, model_output_Ct = self.forward(
                torch.cat([Xt, x_cond], dim=-1),
                torch.cat([Et, e_cond], dim=-1),
                torch.cat([Ct, c_cond], dim=-1),
                node_padding_mask=node_padding_mask,
                edge_padding_mask=edge_padding_mask,
                sigma=sigma_in,
                cond=model_cond,
            )
        else:
            model_output_Xt, model_output_Et, model_output_Ct = self.forward(
                Xt,
                Et,
                Ct,
                node_padding_mask=node_padding_mask,
                edge_padding_mask=edge_padding_mask,
                sigma=sigma_in,
                cond=model_cond,
            )

        print_nans(model_output_Xt, "model_output_Xt")
        print_nans(model_output_Et, "model_output_Et")

        # SUBS parameterization: the model outputs are log-probabilities, so the
        # NLL terms are the entries gathered at the true token indices.
        X0_index = X0.argmax(dim=-1)
        log_p_theta_X = torch.gather(
            input=model_output_Xt, dim=-1, index=X0_index[:, :, None]
        ).squeeze(-1)

        # Diagonal entries of E0 carry no real edge, but no node has an edge to
        # itself, so the target there is forced to the "no edge" class
        # (second-to-last index).
        E0_index = E0.argmax(dim=-1)
        no_edge_idx = self.n_edge_features - 2
        E0_index.diagonal(dim1=-2, dim2=-1).fill_(no_edge_idx)

        log_p_theta_E = torch.gather(
            input=model_output_Et, dim=-1, index=E0_index[:, :, :, None]
        ).squeeze(-1)

        batch_size = C0.shape[0]

        if self.config.spatial.center:
            coords_mask_flat = coords_mask.reshape(coords_mask.shape[0], -1)
            model_output_Ct_flat = model_output_Ct.reshape(model_output_Ct.shape[0], -1, 3)
            model_output_Ct_flat = center_atom_coords(model_output_Ct_flat, coords_mask_flat)
            assert check_batch_coord_means(
                model_output_Ct_flat, coords_mask_flat
            ), "Coordinates must be centered (mean=0)"
            model_output_Ct = model_output_Ct_flat.reshape(
                model_output_Ct_flat.shape[0], self.config.model.length, MAX_ATOMS, 3
            )

        # CALCULATE COORDINATE MSE
        squared_error = (model_output_Ct - C0) ** 2

        # Bond-geometry losses: each configured term is scaled by its coefficient
        # and summed into total_bond_loss; unscaled means are logged.
        bond_loss_terms = self.config.spatial.bond_loss
        total_bond_loss = 0.0
        individual_losses = {}

        if "bond_length" in bond_loss_terms:
            bl_losses = []
            for i in range(batch_size):
                length = node_padding_mask[i].sum()
                X_trimmed, E_trimmed = remove_graph_padding(X0[i], E0[i], length)
                bonds = get_bonds(X_trimmed, E_trimmed, reindex=False)
                bl_losses.append(
                    bond_length_loss(
                        C0[i], model_output_Ct[i], bonds, sqrd=self.config.spatial.square_bond_loss
                    )
                )
            bl_losses = torch.stack(bl_losses)
            individual_losses["bl"] = bl_losses.mean()
            total_bond_loss += bl_losses * self.config.spatial.bond_length_coef

        if "pairwise_distance" in bond_loss_terms:
            threshold = self.config.spatial.pairwise_threshold
            if self.config.spatial.normalize:
                threshold = threshold / COORDS_STD
            pd_loss = pairwise_distance_loss(
                C0,
                model_output_Ct,
                coords_mask,
                threshold=threshold,
                sqrd=self.config.spatial.square_bond_loss,
            )
            individual_losses["pd"] = pd_loss.mean()
            total_bond_loss += pd_loss.squeeze(-1) * self.config.spatial.pairwise_distance_coef

        if "smooth_lddt" in bond_loss_terms:
            slddt_loss = smooth_lddt_loss(
                C0, model_output_Ct, coords_mask, sqrd=self.config.spatial.square_bond_loss
            )
            individual_losses["slddt"] = slddt_loss.mean()
            total_bond_loss += slddt_loss * self.config.spatial.smooth_lddt_coef

        if not individual_losses:
            raise ValueError(
                "Bond loss list is empty. Must contain at least one of: bond_length, pairwise_distance, smooth_lddt"
            )

        LOGGER.debug(" ".join(f"{k}: {v:.4f}" for k, v in individual_losses.items()))

        if self.change_of_variables or self.importance_sampling:
            # Here the NLL is scaled by the constant log1p(-exp(-sigma_min))
            # instead of dsigma / expm1(sigma). The tuple layout matches the
            # default path so ``_loss`` can unpack it.
            nll_scale = torch.log1p(-torch.exp(-self.discrete_noise.sigma_min))
            return (
                log_p_theta_X * nll_scale,
                log_p_theta_E * nll_scale,
                squared_error,
                total_bond_loss,
                current_coords_mask,
            )

        # Up-weight the loss on positions that hold a real bond (any class
        # before the last two).
        true_edges = E0[:, :, :, :-2].sum(dim=-1) > 0
        edge_loss_weights = torch.ones_like(log_p_theta_E)
        edge_loss_weights[true_edges] *= self.config.true_edge_weight

        nll_weight = discrete_dsigma / torch.expm1(discrete_sigma)
        return (
            -log_p_theta_X * nll_weight[:, None],
            -log_p_theta_E * nll_weight[:, None, None] * edge_loss_weights,
            squared_error,
            total_bond_loss,
            current_coords_mask,
        )

    def _loss(
        self, X, E, C, batch_smiles, node_padding_mask, edge_padding_mask, coords_mask, prefix, cond_X=None, cond_C=None, cond_padding_mask=None
    ):
        """Compute the combined discrete + coordinate training loss for a batch.

        All arguments are forwarded to ``_forward_pass_diffusion``. The total loss
        is the sum of three terms:
          * the batch mean of (masked node NLL + masked edge NLL);
          * the time-weighted, masked coordinate MSE scaled by
            ``config.spatial.mse_coef``;
          * the time-weighted bond loss. This term is only active for samples with
            ``t <= config.spatial.bond_loss_time_threshold``.

        Time weights are ``1/t``, or ``1/t**2`` when
        ``config.spatial.square_time_weight`` is set, computed per sample from
        ``self.sampled_t``.

        Returns:
            A ``Loss`` record with the total loss and its logged components.
        """
        loss_X, loss_E, loss_v, loss_b, coords_mask = self._forward_pass_diffusion(
            X, E, C, batch_smiles, node_padding_mask, edge_padding_mask, coords_mask, prefix, cond_X=cond_X, cond_C=cond_C, cond_padding_mask=cond_padding_mask
        )
        assert loss_X.shape == node_padding_mask.shape
        assert loss_E.shape == edge_padding_mask.shape

        # Mask both padded fragments and absent atoms in the coordinate loss:
        # (bs, n_fragments, max_atoms, 3).
        node_mask_4d = node_padding_mask[:, :, None, None].expand(-1, -1, MAX_ATOMS, 3)
        atom_mask_4d = coords_mask[..., None].expand(-1, -1, -1, 3)
        combined_coord_mask = node_mask_4d * atom_mask_4d
        # Guard against an all-masked batch (would otherwise divide by zero).
        n_coord_entries = combined_coord_mask.sum().clamp_min(1)

        # Per-sample time weights, shape (bs,).
        t = self.sampled_t.reshape(-1)
        time_weight = 1.0 / t
        if self.config.spatial.square_time_weight:
            time_weight = time_weight**2

        weighted_loss_v = loss_v * time_weight.reshape(-1, 1, 1, 1) * self.config.spatial.mse_coef

        # The bond loss has one entry per sample along its leading dimension. Align
        # the per-sample weight with that dimension instead of broadcasting a
        # (bs, 1, 1, 1) tensor against it, which would cross-multiply every
        # sample's loss with every sample's weight.
        bond_gate = (t <= self.config.spatial.bond_loss_time_threshold).float()
        bond_weight_shape = (-1,) + (1,) * max(loss_b.dim() - 1, 0)
        weighted_loss_b = loss_b * (time_weight * bond_gate).reshape(bond_weight_shape)

        mse_loss_v = (loss_v * combined_coord_mask).sum() / n_coord_entries
        weighted_mse_loss_v = (weighted_loss_v * combined_coord_mask).sum() / n_coord_entries

        weighted_bond_loss = weighted_loss_b.mean()

        # Mask padded nodes in the node loss.
        avg_loss_X = (loss_X * node_padding_mask).sum(dim=1) / node_padding_mask.sum(dim=1)

        # Mask padded entries in the edge loss, reducing over every non-batch dim.
        edge_dims = tuple(range(1, loss_E.dim()))
        avg_loss_E = (loss_E * edge_padding_mask).sum(dim=edge_dims) / edge_padding_mask.sum(
            dim=edge_dims
        )

        combined_discrete_loss = avg_loss_X + avg_loss_E
        total_loss = combined_discrete_loss.mean() + weighted_mse_loss_v + weighted_bond_loss

        # Calculate perplexity
        node_nll = avg_loss_X.mean()
        edge_nll = avg_loss_E.mean()
        node_ppl = torch.exp(node_nll)
        edge_ppl = torch.exp(edge_nll)
        total_ppl = torch.exp(node_nll + edge_nll)

        LOGGER.debug(
            f"Average Node Loss: {node_nll.item():.4f} | Average Edge Loss: {edge_nll.item():.4f} | "
            f"Node PPL: {node_ppl.item():.4f} | Edge PPL: {edge_ppl.item():.4f} | Total PPL: {total_ppl.item():.4f} | "
            f"Weighted Bond Loss: {weighted_bond_loss.item():.4f} | MSE Loss: {mse_loss_v.item():.4f} | Weighted MSE Loss: {weighted_mse_loss_v.item():.4f}"
        )

        return Loss(
            total_loss=total_loss,
            node_nll=node_nll,
            edge_nll=edge_nll,
            ppl=total_ppl,
            node_ppl=node_ppl,
            edge_ppl=edge_ppl,
            fm_mse=mse_loss_v,
            weighted_bond_loss=weighted_bond_loss,
            weighted_mse=weighted_mse_loss_v,
        )