# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for perturbing protein structure with noise.

This module contains PyTorch layers for perturbing protein structure with noise,
which can be useful for data augmentation, benchmarking, and denoising-based
training.
"""
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from tqdm.auto import tqdm

from src.chroma.constants import AA20
from src.chroma.data.xcs import validate_XC
from src.chroma.layers import basic, sde
from src.chroma.layers.structure import backbone, hbonds, mvn, rmsd


class GaussianNoiseSchedule:
    """
    A general noise schedule for the General Gaussian Forward Path, where noise is added
    to the input signal.

    The noise is modeled as Gaussian noise with mean `alpha_t x_0` and variance
    `sigma_t^2`, with `x_0 ~ p(x_0)`. The time range of the noise schedule is
    parameterized with a user-specified logarithmic signal-to-noise ratio (SNR) range,
    where `snr_t = alpha_t^2 / sigma_t^2` is the SNR at time `t`.

    In addition, the object defines a quantity called the scaled signal-to-noise ratio
    (`ssnr_t`), which is given by `ssnr_t = alpha_t^2 / (alpha_t^2 + sigma_t^2)`
    and is a helpful quantity for analyzing the performance of signal processing
    algorithms under different noise conditions.

    This object implements a few standard noise schedules:

        'log_snr': variance-preserving process with linear log SNR schedule
        (https://arxiv.org/abs/2107.00630)

        'ot_linear': OT schedule
        (https://arxiv.org/abs/2210.02747)

        've_log_snr': variance-exploding process with linear log SNR schedule
        (https://arxiv.org/abs/2011.13456 with log SNR noise schedule)

    `alpha`, `sigma`, `alpha_deriv` and `sigma_deriv` take a time `t` in [0, 1]
    and map it linearly onto `[t_min, t_max]` (see `t_map`), the interval on
    which `alpha_func` / `sigma_func` span exactly `log_snr_range`. For
    'log_snr' and 've_log_snr' that interval is [0, 1], so the map is the
    identity.

    Users can also implement their own schedules by overriding alpha_func,
    sigma_func and compute_t_range.

    Args:
        log_snr_range (Tuple[float, float]): `(l_min, l_max)` log SNR range.
            Defaults to (-7.0, 13.5).
        kind (str): one of 'log_snr', 'ot_linear', 've_log_snr'.
            Defaults to 'log_snr'.

    Raises:
        NotImplementedError: if `kind` is not a supported schedule.
    """

    _SUPPORTED_KINDS = ("log_snr", "ot_linear", "ve_log_snr")

    def __init__(
        self, log_snr_range: Tuple[float, float] = (-7.0, 13.5), kind: str = "log_snr",
    ) -> None:
        super().__init__()

        if kind not in self._SUPPORTED_KINDS:
            raise NotImplementedError(
                f"noise type {kind} is not implemented, only"
                f" {', '.join(self._SUPPORTED_KINDS)} are supported"
            )
        self.kind = kind
        self.log_snr_range = log_snr_range

        l_min, l_max = self.log_snr_range

        # Map t in [0, 1] onto [t_min, t_max] so the schedule spans the prescribed
        # log SNR range; log SNR decreases with time, hence l_max <-> t_min.
        self.t_max = self.compute_t_range(l_min)
        self.t_min = self.compute_t_range(l_max)
        self._eps = 1e-5

    def _linear_log_snr(self, t: torch.Tensor) -> torch.Tensor:
        """Linearly interpolate log SNR from `l_max` at t = 0 to `l_min` at t = 1."""
        l_min, l_max = self.log_snr_range
        return (1 - t) * l_max + t * l_min

    def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Map t in [0, 1] linearly to [t_min, t_max].

        Args:
            t (Union[float, torch.Tensor]): time

        Returns:
            torch.Tensor: mapped time, with the dtype of `t`
        """
        t = self.tensor_check(t)

        t_max = self.t_max.to(t.device)
        t_min = self.t_min.to(t.device)
        t_tilde = t_min + (t_max - t_min) * t

        return t_tilde.to(t.dtype)

    def derivative(self, t: Union[float, torch.Tensor], func: Callable) -> torch.Tensor:
        """Compute the derivative of a batched single-variable function.

        Assumes `func` acts elementwise on `t`, so the gradient of
        `func(t).sum()` with respect to `t` is the per-element derivative.

        Args:
            t (Union[float, torch.Tensor]): point(s) at which derivatives are taken
            func (Callable): differentiable function of `t`

        Returns:
            torch.Tensor: derivative with the shape of `t`, detached from the
            computational graph
        """
        # Differentiate w.r.t. a fresh leaf copy: this works for non-leaf inputs
        # and never changes the caller's `requires_grad` flag.
        t_leaf = self.tensor_check(t).detach().clone().requires_grad_(True)
        with torch.enable_grad():
            (deriv,) = grad(func(t_leaf).sum(), t_leaf, create_graph=False)
        return deriv.detach()

    def tensor_check(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Convert input to a 1-element float torch.Tensor if it is not a tensor.

        Args:
            t (Union[float, torch.Tensor]): input

        Returns:
            torch.Tensor: `t` itself if already a tensor, otherwise `tensor([t])`
        """
        if not isinstance(t, torch.Tensor):
            t = torch.Tensor([t]).float()
        return t

    def alpha_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Alpha function that scales the mean, usually goes from 1. to 0.

        Args:
            t (Union[float, torch.Tensor]): already-mapped time in [t_min, t_max]

        Returns:
            torch.Tensor: alpha value
        """
        t = self.tensor_check(t)

        if self.kind == "log_snr":
            # alpha^2 = sigmoid(log_snr), evaluated in log-space for stability
            log_snr = self._linear_log_snr(t)
            log_alpha = 0.5 * (log_snr - F.softplus(log_snr))
            return log_alpha.exp()
        elif self.kind == "ve_log_snr":
            # Equal to 1 for t >= 0, but written in terms of t so autograd can
            # differentiate it (a constant would not be part of the graph).
            return 1 - torch.relu(-t)
        elif self.kind == "ot_linear":
            return 1 - t
        raise NotImplementedError(f"alpha_func is not implemented for {self.kind}")

    def sigma_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Sigma function that scales the standard deviation, usually goes from 0. to 1.

        Args:
            t (Union[float, torch.Tensor]): already-mapped time in [t_min, t_max]

        Returns:
            torch.Tensor: sigma value
        """
        t = self.tensor_check(t)

        if self.kind == "log_snr":
            # Variance preserving: alpha^2 + sigma^2 = 1. `t` is already mapped,
            # so use alpha_func (alpha() would apply t_map a second time).
            alpha = self.alpha_func(t)
            return (1 - alpha.pow(2)).sqrt()
        elif self.kind == "ve_log_snr":
            # alpha == 1, so SNR = sigma^-2 and sigma = exp(-log_snr / 2)
            return torch.exp(-self._linear_log_snr(t) / 2)
        elif self.kind == "ot_linear":
            return t
        raise NotImplementedError(f"sigma_func is not implemented for {self.kind}")

    def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute alpha value for the mapped time in [t_min, t_max].

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: alpha value
        """
        return self.alpha_func(self.t_map(t))

    def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute sigma value for the mapped time in [t_min, t_max].

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: sigma value
        """
        return self.sigma_func(self.t_map(t))

    def alpha_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute the alpha derivative for the mapped time in [t_min, t_max].

        The derivative is taken with respect to the mapped time `t_tilde`.

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: d alpha_func / d t_tilde evaluated at t_map(t)
        """
        t_tilde = self.t_map(t)
        return self.derivative(t_tilde, self.alpha_func)

    def sigma_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute the sigma derivative for the mapped time in [t_min, t_max].

        The derivative is taken with respect to the mapped time `t_tilde`.

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: d sigma_func / d t_tilde evaluated at t_map(t)
        """
        t_tilde = self.t_map(t)
        return self.derivative(t_tilde, self.sigma_func)

    def beta(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        r"""Compute the drift coefficient for the OU process of the form
        $dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$, i.e.
        `beta = -2 alpha' / alpha`.

        Args:
            t (Union[float, torch.Tensor]): t in [0, 1]

        Returns:
            torch.Tensor: beta(t)
        """
        alpha = self.alpha(t).detach()
        alpha_deriv_t = self.alpha_deriv(t)
        return -2.0 * alpha_deriv_t / alpha

    def g(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        r"""Compute the diffusion coefficient for the OU process:
        $dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$

        Args:
            t (Union[float, torch.Tensor]): t in [0, 1]

        Returns:
            torch.Tensor: g(t)
        """
        if self.kind == "log_snr":
            # Variance preserving: g^2 = beta. beta() maps t itself.
            return self.beta(t).sqrt()

        # General case: g^2 = 2 (sigma' - (alpha' / alpha) sigma) sigma
        alpha_prime_div_alpha = self.alpha_deriv(t) / self.alpha(t)
        sigma = self.sigma(t)
        sigma_deriv = self.sigma_deriv(t)
        g_sq = 2 * (sigma_deriv - alpha_prime_div_alpha * sigma) * sigma
        return g_sq.sqrt()

    def SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Signal-to-Noise Ratio (SNR) mapped into the allowed log_SNR range.

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: SNR value
        """
        t = self.tensor_check(t)

        if self.kind == "log_snr":
            return self.log_SNR(t).exp()
        return self.alpha(t).pow(2) / (self.sigma(t).pow(2))

    def log_SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """log SNR value.

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: log SNR value
        """
        t = self.tensor_check(t)

        if self.kind in ("log_snr", "ve_log_snr"):
            # Both schedules are linear in log SNR by construction (and their
            # t_map is the identity, so t needs no mapping here).
            return self._linear_log_snr(t)
        return self.SNR(t).log()

    def compute_t_range(self, log_snr: Union[float, torch.Tensor]) -> torch.Tensor:
        """Given a log(SNR) value, compute the corresponding (mapped) time.

        Hand-derivation is required for specific noise schedules.
        This function is essentially the inverse of logSNR(t).

        Args:
            log_snr (Union[float, torch.Tensor]): logSNR value

        Returns:
            torch.Tensor: the inverse logSNR
        """
        log_snr = self.tensor_check(log_snr)
        l_min, l_max = self.log_snr_range

        if self.kind in ("log_snr", "ve_log_snr"):
            # invert the linear log SNR interpolation
            return (1 / (l_min - l_max)) * (log_snr - l_max)
        elif self.kind == "ot_linear":
            # SNR = ((1 - t) / t)^2  =>  t = 1 / (sqrt(SNR) + 1)
            return ((0.5 * log_snr).exp() + 1).reciprocal()
        raise NotImplementedError(f"compute_t_range is not implemented for {self.kind}")

    def SNR_derivative(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """The derivative of SNR(t) with respect to t.

        Args:
            t (Union[float, torch.Tensor]): t in [0, 1]

        Returns:
            torch.Tensor: SNR derivative
        """
        t = self.tensor_check(t)

        if self.kind in ("log_snr", "ve_log_snr"):
            # SNR = exp(l_max + t (l_min - l_max))  =>  SNR' = SNR (l_min - l_max)
            l_min, l_max = self.log_snr_range
            return self.SNR(t) * (l_min - l_max)
        return self.derivative(t, self.SNR)

    def SSNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Signal to Signal+Noise Ratio (SSNR) = alpha^2 / (alpha^2 + sigma^2).

        SSNR monotonically goes from 1 to 0 as t goes from 0 to 1.

        Args:
            t (Union[float, torch.Tensor]): time in [0, 1]

        Returns:
            torch.Tensor: SSNR value
        """
        t = self.tensor_check(t)
        # Equivalent to SNR / (SNR + 1), evaluating SNR only once.
        return 1 / (1 + 1 / self.SNR(t))

    def SSNR_inv(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor:
        """The inverse of SSNR.

        Args:
            ssnr (Union[float, torch.Tensor]): ssnr in [0, 1]

        Returns:
            torch.Tensor: the mapped time `t_tilde` in [t_min, t_max] whose SSNR
            is `ssnr` (equal to t in [0, 1] for 'log_snr' and 've_log_snr')

        Raises:
            NotImplementedError: for schedules without an analytic inverse.
        """
        ssnr = self.tensor_check(ssnr)
        l_min, l_max = self.log_snr_range

        if self.kind in ("log_snr", "ve_log_snr"):
            # SSNR = sigmoid(log SNR): invert the logit, then the linear map
            return ((ssnr / (1 - ssnr)).log() - l_max) / (l_min - l_max)
        elif self.kind == "ot_linear":
            # The value of SSNR_inv at ssnr = 0.5 has to be determined with
            # L'Hopital's rule; the inverse SSNR function is solved analytically,
            # see the Wolfram Alpha result: https://tinyurl.com/bdh4es5a
            singularity_check = (ssnr - 0.5).abs() < self._eps
            ssnr_mask = singularity_check.float()
            ssnr = ssnr_mask * (0.5 + self._eps) + (1.0 - ssnr_mask) * ssnr

            return (ssnr + (-ssnr * (ssnr - 1)).sqrt() - 1) / (2 * ssnr - 1)
        raise NotImplementedError(f"SSNR_inv is not implemented for {self.kind}")

    def SSNR_inv_deriv(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor:
        """SSNR_inv derivative.

        SSNR_inv is a CDF-like quantity, so its derivative is a PDF-like quantity.

        Args:
            ssnr (Union[float, torch.Tensor]): SSNR in [0, 1]

        Returns:
            torch.Tensor: derivative of SSNR_inv
        """
        ssnr = self.tensor_check(ssnr)
        return self.derivative(ssnr, self.SSNR_inv)

    def prob_SSNR(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute Prob(SSNR(t)).

        The minus sign accounts for the inversion of the integration range
        (SSNR decreases as t increases).

        Args:
            ssnr (Union[float, torch.Tensor]): SSNR value

        Returns:
            torch.Tensor: Prob(SSNR)
        """
        return -self.SSNR_inv_deriv(ssnr)

    def linear_logsnr_grid(self, N: int, tspan: Tuple[float, float]) -> torch.Tensor:
        """Map a uniform time grid so that it respects the linear logSNR schedule.

        Args:
            N (int): number of steps (the grid has N + 1 points)
            tspan (Tuple[float, float]): time span (t_start, t_end)

        Returns:
            torch.Tensor: time grid as torch.Tensor
        """
        logsnr_noise = GaussianNoiseSchedule(
            kind="log_snr", log_snr_range=self.log_snr_range
        )

        ts = torch.linspace(tspan[0], tspan[1], N + 1)
        SSNR_vp = logsnr_noise.SSNR(ts)
        grid = self.SSNR_inv(SSNR_vp)

        # map from t_tilde back to t
        return (grid - self.t_min) / (self.t_max - self.t_min)


class NoiseTimeEmbedding(nn.Module):
    """Random-Fourier-feature embedding of the diffusion time.

    Args:
        dim_embedding (int): Dimension of the output embedding (passed to the
            featurizer as `d_model`).
        noise_schedule (GaussianNoiseSchedule): Schedule used to turn time into
            log SNR when `feature_type == "log_snr"`.
        rff_scale (float, optional): Scale of the non-trainable random Fourier
            features. Default is 0.8.
        feature_type (str, optional): Which quantity is featurized: "t" (raw
            time) or "log_snr". Default is "log_snr".

    Inputs:
        t (float or torch.Tensor): time in [0, 1]. A Python number is wrapped
            into a 1-element tensor on the featurizer's device; a 0-dim tensor
            is promoted to shape `(1,)`.
        log_alpha (torch.Tensor, optional): Accepted for interface
            compatibility; not used by `forward`.

    Outputs:
        time_h (torch.Tensor): Fourier features of the selected time feature,
            which is fed to the featurizer with shape `(batch_size, 1, 1)`.
            Documented elsewhere as `(batch_size, dim_embedding)` — TODO confirm
            against `basic.FourierFeaturization`.
    """

    def __init__(
        self,
        dim_embedding: int,
        noise_schedule: GaussianNoiseSchedule,
        rff_scale: float = 0.8,
        feature_type: str = "log_snr",
    ) -> None:
        super().__init__()
        self.noise_schedule = noise_schedule
        self.feature_type = feature_type
        self.fourier_features = basic.FourierFeaturization(
            d_input=1, d_model=dim_embedding, trainable=False, scale=rff_scale
        )

    def forward(
        self, t: torch.Tensor, log_alpha: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Promote Python numbers / 0-dim tensors to a 1-D batch of times.
        if not isinstance(t, torch.Tensor):
            target_device = self.fourier_features.B.device
            t = torch.Tensor([t]).float().to(target_device)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        if self.feature_type == "t":
            feature = t
        elif self.feature_type == "log_snr":
            feature = self.noise_schedule.log_SNR(t)
        else:
            raise KeyError(self.feature_type)

        return self.fourier_features(feature[:, None, None])


class DiffusionChainCov(nn.Module):
    def __init__(
        self,
        log_snr_range: Tuple[float, float] = (-7.0, 13.5),
        noise_schedule: str = "log_snr",
        sigma_translation: float = 1.0,
        covariance_model: str = "brownian",
        complex_scaling: bool = False,
        **kwargs,
    ) -> None:
        """Diffusion backbone noise, with chain-structured covariance.

        This class implements a diffusion backbone noise model. The model uses a
        chain-structured covariance matrix capturing the spatial correlations between
        residues along the backbone. The model also supports different noise schedules
        and integration schemes for the stochastic differential equation (SDE) that
        defines the diffusion process. This class also implements various inference
        algorithms by reversing the forward diffusion with a user-specified
        conditioner program.

        Args:
            log_snr_range (tuple, optional): log SNR range. Defaults to (-7.0, 13.5).
            noise_schedule (str, optional): noise schedule type, one of "log_snr",
                "ot_linear", "ve_log_snr". Defaults to "log_snr".
            sigma_translation (float, optional): Scaling factor for the translation
                component of the covariance matrix. Only used by the "brownian" and
                "globular" covariance models. Defaults to 1.0.
            covariance_model (str, optional): covariance model, one of "brownian",
                "globular" or "residue_gas". Defaults to "brownian".
            complex_scaling (bool, optional): Whether to scale the complex component
                of the covariance matrix by the translation component. Only used by
                the "brownian" and "globular" covariance models. Defaults to False.
            **kwargs: Extra keyword arguments; accepted but not used by this
                constructor.

        Raises:
            NotImplementedError: If `noise_schedule` is not supported.
            ValueError: If `covariance_model` is not supported.
        """
        super().__init__()

        self.noise_schedule = GaussianNoiseSchedule(
            log_snr_range=log_snr_range, kind=noise_schedule,
        )

        if covariance_model in ["brownian", "globular"]:
            self.base_gaussian = mvn.BackboneMVNGlobular(
                sigma_translation=sigma_translation,
                covariance_model=covariance_model,
                complex_scaling=complex_scaling,
            )
        elif covariance_model == "residue_gas":
            self.base_gaussian = mvn.BackboneMVNResidueGas()
        else:
            # Fail fast: otherwise `base_gaussian` would be missing and the error
            # would only surface later as an AttributeError.
            raise ValueError(
                f"Unknown covariance_model {covariance_model!r}; expected one of"
                " 'brownian', 'globular', 'residue_gas'"
            )

        self.loss_rmsd = rmsd.BackboneRMSD()
        self._eps = 1e-5
        # Name -> method returning (drift, diffusion) for reverse-time sampling
        self.sde_funcs = {
            "langevin": self.langevin,
            "reverse_sde": self.reverse_sde,
            "ode": self.ode,
        }
        # Name -> SDE integrator
        self.integrate_funcs = {
            "euler_maruyama": sde.sde_integrate,
            "heun": sde.sde_integrate_heun,
        }

    def sample_t(
        self,
        C: torch.LongTensor,
        t: Optional[torch.Tensor] = None,
        inverse_CDF: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Sample a random diffusion time for each batch element.

        Inputs:
            C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`.
                Only its batch size and device are used.
            t (torch.Tensor or float, optional): Explicit time(s). If given, it is
                returned directly without applying `inverse_CDF`; a Python number is
                wrapped into a 1-element float tensor on `C.device`. Defaults to None.
            inverse_CDF (Callable, optional): Function applied to the uniform samples
                in [0, 1) to obtain a non-uniform time distribution. Defaults to None.

        Outputs:
            t (torch.Tensor): Times; shape `(batch_size,)` on `C.device` when sampled.
        """
        if t is not None:
            if not isinstance(t, torch.Tensor):
                # Place scalar times on the same device as sampled ones.
                t = torch.Tensor([t]).float().to(C.device)
            return t

        num_batch = C.size(0)
        if self.training:
            # Sample correlated but marginally uniform t
            # for variance reduction (Kingma et al 2021)
            u = torch.rand([])
            ix = torch.arange(num_batch) / num_batch
            t = torch.remainder(u + ix, 1)
        else:
            t = torch.rand([num_batch])
        if inverse_CDF is not None:
            t = inverse_CDF(t)
        t = t.to(C.device)
        return t

    def sde_forward(self, X, C, t, Z=None):
        """Sample an Euler-Maruyama step on the forward SDE.

        That is to say, Euler-Maruyama integration would correspond to the update
            `X_new = X + dt * f + sqrt(dt) * gZ`
        with drift `f = -beta(t) X_centered / 2` and diffusion `gZ = g(t) R Z`,
        where `R Z` is `base_gaussian._multiply_R` applied to the noise.

        Args:
            X (torch.Tensor): Backbone coordinates with the batch dimension first
                (elsewhere in this class documented as
                `(batch_size, num_residues, 4, 3)`).
            C (torch.LongTensor): Chain map with shape `(batch_size, num_residues)`.
            t (torch.Tensor or float): Time(s) in [0, 1]: either a single time
                shared by the batch, or one time per batch element with shape
                `(batch_size,)`.
            Z (torch.Tensor, optional): Noise reshapable to `(batch_size, -1, 3)`.
                Defaults to `torch.randn_like(X)`.

        Returns:
            f (Tensor): Drift term with the shape of `X`.
            gZ (Tensor): Diffusion term with the shape of `X`.
        """

        # Sample random perturbation
        if Z is None:
            Z = torch.randn_like(X)
        Z = Z.reshape(X.shape[0], -1, 3)
        R_Z = self.base_gaussian._multiply_R(Z, C).reshape(X.shape)

        X = backbone.center_X(X, C)

        # Reshape per-batch coefficients to (batch, 1, ..., 1) so they broadcast
        # over every non-batch axis. A bare `(batch,)` tensor would otherwise be
        # aligned with the trailing xyz axis instead of the batch axis.
        coeff_shape = (-1,) + (1,) * (R_Z.dim() - 1)
        beta = self.noise_schedule.beta(t).reshape(coeff_shape)
        g = self.noise_schedule.g(t).reshape(coeff_shape)

        f = -beta * X / 2.0
        gZ = g * R_Z

        return f, gZ

    def _schedule_coefficients(
        self,
        t: torch.Tensor,
        inverse_temperature: float = 1.0,
        langevin_isothermal: bool = True,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Compute the schedule coefficients for sampling in reverse time.

        Args:
            t (torch.Tensor): time(s) in [0, 1]. `t.device` is read, so a tensor
                is expected here (TODO confirm callers never pass a bare float).
            inverse_temperature (float, optional): Inverse temperature parameter of
                the Langevin dynamics. Default is 1.0.
            langevin_isothermal (bool, optional): If True, the Langevin weight is
                `inverse_temperature`; otherwise it is `lambda_t`. Default is True.

        Returns:
            A tuple `(alpha, sigma, beta, g, lambda_t, lambda_langevin)`:
                alpha, sigma, beta, g: schedule values with two trailing singleton
                    axes, i.e. `(batch_size, 1, 1)` for a `(batch_size,)` time.
                lambda_t: `T * (sigma^2 + alpha^2) / (T * sigma^2 + alpha^2)` with
                    `T = inverse_temperature`, same shape as alpha.
                lambda_langevin: `inverse_temperature` (a float) when isothermal,
                    else `lambda_t`.
        """

        def _as_column(values):
            # (n,) -> (n, 1, 1) on the device of `t`
            return values[:, None, None].to(t.device)

        schedule = self.noise_schedule
        alpha = _as_column(schedule.alpha(t))
        sigma = _as_column(schedule.sigma(t))
        beta = _as_column(schedule.beta(t))
        g = _as_column(schedule.g(t))

        # Temperature coefficients
        alpha_sq = alpha.pow(2)
        sigma_sq = sigma.pow(2)
        lambda_t = (
            inverse_temperature
            * (sigma_sq + alpha_sq)
            / (inverse_temperature * sigma_sq + alpha_sq)
        )
        if langevin_isothermal:
            lambda_langevin = inverse_temperature
        else:
            lambda_langevin = lambda_t
        return alpha, sigma, beta, g, lambda_t, lambda_langevin

    @validate_XC()
    def langevin(
        self,
        X: torch.Tensor,
        X0_func: Callable,
        C: torch.LongTensor,
        t: Union[torch.Tensor, float],
        conditioner: Callable = None,
        Z: Union[torch.Tensor, None] = None,
        inverse_temperature: float = 1.0,
        langevin_factor: float = 0.0,
        langevin_isothermal: bool = True,
        align_X0: bool = True,
    ):
        """Return the drift and diffusion components of the Langevin dynamics for the
        reverse process.

        The drift is `-g^2 * lambda_langevin * langevin_factor / 2` times the score
        transformed by `base_gaussian.multiply_covariance`; the diffusion is
        `g * sqrt(langevin_factor)` times `base_gaussian._multiply_R(Z, C)`.
        With the default `langevin_factor=0.0` both terms are zero.

        Args:
            X (torch.Tensor): A tensor of protein backbone structure with shape
                `(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone geometry.
            C (torch.LongTensor): A chain map tensor with shape `(batch_size, num_residues)`.
            t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
                transformation (see examples in chroma.layers.structure.conditioners).
            Z (torch.Tensor, optional): A tensor of random noise with
                shape `(batch_size, num_residues, 4, 3)`. Sampled with
                `torch.randn_like(X)` when None. Default is None.
            inverse_temperature (float, optional): The inverse temperature parameter
                for the Langevin dynamics. Default is 1.0.
            langevin_factor (float, optional): The scaling factor for the Langevin
                noise. Default is 0.0.
            langevin_isothermal (bool, optional): A flag that indicates whether to use
                isothermal or non-isothermal Langevin dynamics. Default is True.
            align_X0 (bool, optional): A flag that indicates whether to align the noised
                X and denoised X for score function calculation.

        Returns:
            f (torch.Tensor): A tensor of drift terms with shape
                `(batch_size, num_residues, 4, 3)`.
            gZ (torch.Tensor): A tensor of diffusion terms with shape
                `(batch_size, num_residues, 4, 3)`.
        """
        coefficients = self._schedule_coefficients(
            t,
            inverse_temperature=inverse_temperature,
            langevin_isothermal=langevin_isothermal,
        )
        g, lambda_langevin = coefficients[3], coefficients[5]

        if Z is None:
            Z = torch.randn_like(X)

        score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0)
        score_transformed = self.base_gaussian.multiply_covariance(score, C)

        drift_scale = -g.pow(2) * lambda_langevin * langevin_factor / 2.0
        noise_scale = g * np.sqrt(langevin_factor)
        f = drift_scale * score_transformed
        gZ = noise_scale * self.base_gaussian._multiply_R(Z, C)
        return f, gZ

    @validate_XC()
    def reverse_sde(
        self,
        X: torch.Tensor,
        X0_func: Callable,
        C: torch.LongTensor,
        t: Union[torch.Tensor, float],
        conditioner: Callable = None,
        Z: Union[torch.Tensor, None] = None,
        inverse_temperature: float = 1.0,
        langevin_factor: float = 0.0,
        langevin_isothermal: bool = True,
        align_X0: bool = True,
    ):
        """Return the drift and diffusion components of the reverse SDE.

        The drift combines the linear term `-beta / 2` applied to the centered
        structure with `-g^2 * (lambda_t + lambda_langevin * langevin_factor / 2)`
        times the covariance-transformed score. The diffusion is
        `g * sqrt(1 + langevin_factor)` times `base_gaussian._multiply_R(Z, C)`.

        Args:
            X (torch.Tensor): A tensor of protein backbone structure with shape
                `(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for the protein backbone
                geometry.
            C (torch.LongTensor): A chain map tensor with shape
                `(batch_size, num_residues)`.
            t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
                transformation (see examples in chroma.layers.structure.conditioners).
            Z (torch.Tensor, optional): A tensor of random noise with shape
                `(batch_size, num_residues, 4, 3)`. Sampled with
                `torch.randn_like(X)` when None. Default is None.
            inverse_temperature (float, optional): The inverse temperature parameter
                for the Langevin dynamics. Default is 1.0.
            langevin_factor (float, optional): The scaling factor for the Langevin
                noise. Default is 0.0.
            langevin_isothermal (bool, optional): A flag that indicates whether to use
                isothermal or non-isothermal Langevin dynamics. Default is True.
            align_X0 (bool, optional): A flag that indicates whether to align the noised
                X and denoised X for score function calculation.

        Returns:
            f (torch.Tensor): A tensor of drift terms with shape
                `(batch_size, num_residues, 4, 3)`.
            gZ (torch.Tensor): A tensor of diffusion terms with shape
                `(batch_size, num_residues, 4, 3)`.
        """
        # Schedule management: only beta, g and the temperature weights are needed
        beta, g, lambda_t, lambda_langevin = self._schedule_coefficients(
            t,
            inverse_temperature=inverse_temperature,
            langevin_isothermal=langevin_isothermal,
        )[2:]
        score_scale_t = lambda_t + lambda_langevin * langevin_factor / 2.0

        if Z is None:
            Z = torch.randn_like(X)

        # The score is evaluated on X as given; only the linear drift uses the
        # centered structure.
        score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0)
        score_transformed = self.base_gaussian.multiply_covariance(score, C)

        linear_drift = beta * (-1 / 2) * backbone.center_X(X, C)
        f = linear_drift - g.pow(2) * score_scale_t * score_transformed
        gZ = g * np.sqrt(1.0 + langevin_factor) * self.base_gaussian._multiply_R(Z, C)
        return f, gZ

    @validate_XC()
    def ode(
        self,
        X: torch.Tensor,
        X0_func: Callable,
        C: torch.LongTensor,
        t: Union[torch.Tensor, float],
        conditioner: Callable = None,
        Z: Union[torch.Tensor, None] = None,
        inverse_temperature: float = 1.0,
        langevin_factor: float = 0.0,
        langevin_isothermal: bool = True,
        align_X0: bool = True,
        detach_X0: bool = True,
    ):
        """Return the drift and diffusion components of the probability flow ODE.

        The structure is centered first; the drift is `-beta / 2` times the
        centered structure minus `lambda_langevin * g^2 / 2` times the
        covariance-transformed score. The diffusion term is identically zero.

        Args:
            X (torch.Tensor): A tensor of protein backbone structure with shape
                `(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function that returns a protein backbone
                geometry `(batch_size, num_residues, 4, 3)`.
            C (torch.LongTensor): A chain map tensor with shape
                `(batch_size, num_residues)`.
            t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
                transformation (see examples in chroma.layers.structure.conditioners).
            Z (torch.Tensor, optional): Accepted for interface compatibility with the
                SDE samplers; not used by the ODE. Default is None.
            inverse_temperature (float, optional): The inverse temperature parameter
                for the Langevin dynamics. Default is 1.0.
            langevin_factor (float, optional): Accepted for interface compatibility;
                not used by the ODE drift. Default is 0.0.
            langevin_isothermal (bool, optional): A flag that indicates whether to use
                isothermal or non-isothermal Langevin dynamics. Default is True.
            align_X0 (bool, optional): A flag that indicates whether to align
                the noised X and denoised X for score function calculation.
            detach_X0 (bool, optional): Forwarded to `score`; whether the denoised X
                is detached for score evaluation. Default is True.

        Returns:
            f (torch.Tensor): A tensor of drift terms with shape
                `(batch_size, num_residues, 4, 3)`.
            gZ (torch.Tensor): A tensor of zeros with the shape of `f`.
        """
        beta, g, _lambda_t, lambda_langevin = self._schedule_coefficients(
            t,
            inverse_temperature=inverse_temperature,
            langevin_isothermal=langevin_isothermal,
        )[2:]

        X_centered = backbone.center_X(X, C)
        score = self.score(
            X_centered,
            X0_func,
            C,
            t,
            conditioner,
            align_X0=align_X0,
            detach_X0=detach_X0,
        )
        score_transformed = self.base_gaussian.multiply_covariance(score, C)

        linear_drift = (-1 / 2) * beta * X_centered
        score_drift = 0.5 * lambda_langevin * g.pow(2) * score_transformed
        f = linear_drift - score_drift
        return f, torch.zeros_like(f)

    @validate_XC()
    def energy(
        self,
        X: torch.Tensor,
        X0_func: Callable,
        C: torch.Tensor,
        t: torch.Tensor,
        detach_X0: bool = True,
        align_X0: bool = True,
    ) -> torch.Tensor:
        """Compute the diffusion energy as a function of denoised X.

        The energy is `0.5 * ||Z||^2` summed over all non-batch axes, where `Z`
        is produced by `_X_to_Z` from the (imputed) noised X, the denoised X0 and
        the schedule values alpha(t), sigma(t).

        Args:
            X (torch.Tensor): A tensor of protein backbone coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone geometry.
            C (torch.LongTensor): A chain map tensor with shape
                `(batch_size, num_residues)`.
            t (float): time in (1.0, 0.0).
            detach_X0 (bool, optional): If True, the denoiser runs without building
                a graph and the (aligned) X0 is detached before computing the energy.
            align_X0 (bool, optional): A flag that indicates whether to align the
                denoised X0 onto X before computing the energy.

        Returns:
            U_diffusion (torch.Tensor): A tensor of diffusion energy values with
                shape `(batch_size,)`.
        """
        X = backbone.impute_masked_X(X, C)
        schedule = self.noise_schedule
        alpha = schedule.alpha(t).to(X.device)
        sigma = schedule.sigma(t).to(X.device)

        # With detach_X0 the denoiser runs with grad disabled; otherwise it runs
        # under whatever grad mode the caller has active.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not detach_X0):
            X0 = X0_func(X, C, t=t)

        if align_X0:
            X0, _ = self.loss_rmsd.align(X0, X, C, align_unmasked=True)
        if detach_X0:
            # Alignment uses X, which may carry gradients, so detach again.
            X0 = X0.detach()

        Z = self._X_to_Z(X, X0, C, alpha, sigma)
        return Z.pow(2).mul(0.5).sum([1, 2, 3])

    @validate_XC()
    def score(
        self,
        X: torch.Tensor,
        X0_func: Callable,
        C: torch.Tensor,
        t: Union[torch.Tensor, float],
        conditioner: Optional[Callable] = None,
        detach_X0: bool = True,
        align_X0: bool = True,
        U_traj: Optional[List] = None,
    ) -> torch.Tensor:
        """Compute the score function as the negative gradient of the total energy.

        The total energy is the diffusion energy (``self.energy``) of the
        (optionally conditioned) state plus any conditioner energy. Its gradient
        is taken with respect to the imputed input ``X``.

        Args:
            X (torch.Tensor): A tensor of protein backbone geometry with shape
                 `(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone
                 geometry, called as ``X0_func(X, C, t=t)``.
            C (torch.LongTensor): A tensor of chain map with shape
                `(batch_size, num_residues)`.
            t (Union[torch.Tensor, float]): time in (0.0, 1.0).
            conditioner (Callable, optional): A conditioner that performs a
                constrained transformation (see examples in
                chroma.layers.structure.conditioners). Called as
                ``conditioner(X, C, O, U, t)`` and expected to return a 5-tuple
                whose first, second and fourth entries are the transformed
                state, chain map and energy.
            detach_X0 (bool, optional): A flag that indicates whether to detach the
                 denoised X for score function evaluation.
            align_X0 (bool, optional): A flag that indicates whether to align the
                 noised X and denoised X for score function calculation.
            U_traj (List, optional): If given, a detached CPU copy of this call's
                diffusion energy is appended to it. The default ``None`` uses a
                fresh list per call, so nothing is retained between calls.

        Returns:
            score (torch.Tensor): A tensor of score values with shape
                 `(batch_size, num_residues, 4, 3)`, zeroed where ``C <= 0``.
        """
        if U_traj is None:
            # Fresh list per call. A shared `[]` default would accumulate every
            # call's energy tensor for the lifetime of the process.
            U_traj = []

        X = backbone.impute_masked_X(X, C)
        with torch.enable_grad():
            X = X.detach().clone()
            X.requires_grad = True

            # Apply optional conditioner transformations to state and energy.
            # The conditioner receives a placeholder one-hot sequence (index 0).
            Xt, Ct, U_conditioner = X, C, 0.0
            St = torch.zeros(Ct.shape, device=Xt.device).long()
            Ot = F.one_hot(St, len(AA20)).float()
            if conditioner is not None:
                Xt, Ct, _, U_conditioner, _ = conditioner(X, C, Ot, U_conditioner, t)
            U_conditioner = torch.as_tensor(U_conditioner)

            # Compute system energy
            U_diffusion = self.energy(
                Xt, X0_func, Ct, t, detach_X0=detach_X0, align_X0=align_X0
            )

            U_traj.append(U_diffusion.detach().cpu())

            # Score is the negative energy gradient w.r.t. the imputed input X
            U_total = U_diffusion.sum() + U_conditioner.sum()
            U_total.backward()
            score = -X.grad
            score = score.masked_fill((C <= 0)[..., None, None], 0.0)
        return score

    def elbo(self, X0_pred, X0, C, t):
        """ITD ELBO as a weighted average of denoising error,
        inspired by https://arxiv.org/abs/2302.03792

        Args:
            X0_pred (torch.Tensor): Denoised coordinates, same shape as ``X0``.
            X0 (torch.Tensor): Reference coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
            C (torch.LongTensor): Chain map with shape `(batch_size, num_residues)`.
            t (torch.Tensor or float): Diffusion time. A float is promoted to a
                1-element tensor on ``X0.device``.

        Returns:
            elbo (torch.Tensor): Per-complex ELBO normalized per atom with
                ``C != 0``, shape `(batch_size,)`.
            elbo_batch (torch.Tensor): Atom-count weighted batch average of
                ``elbo``.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.Tensor([t]).float().to(X0.device)

        # Interpolate missing data with Brownian Bridge posterior
        X0 = backbone.impute_masked_X(X0, C)
        X0_pred = backbone.impute_masked_X(X0_pred, C)

        # Whiten the residual X0 - X0_pred with R^{-1}
        num_batch = X0.shape[0]
        residual = (X0 - X0_pred).reshape([num_batch, -1, 3])
        residual_white = self.base_gaussian._multiply_R_inverse(residual, C)

        # Per-atom ITD term (computed everywhere, including imputed positions)
        snr_weight = 0.5 * self.noise_schedule.SNR_derivative(t)[:, None, None, None]
        snr = self.noise_schedule.SNR(t)[:, None, None, None]
        log_norm = 0.5 * np.log(np.pi * 2.0 * np.e)
        per_atom = snr_weight * (residual_white.pow(2) - 1 / (1 + snr)) - log_norm
        per_atom = per_atom.reshape(X0.shape)

        # Atom mask for residues with C != 0, broadcast over the xyz axis
        atom_mask = (C != 0).float()[..., None, None].expand([-1, -1, 4, 1])
        atom_count = atom_mask.sum([1, 2, 3])

        # Per-complex ELBO, then normalized per atom
        logdet = self.base_gaussian.log_determinant(C)
        elbo_unnormalized = (atom_mask * per_atom).sum([1, 2, 3]) - logdet
        elbo = elbo_unnormalized / (atom_count + self._eps)

        # Atom-count weighted batch average
        elbo_batch = (atom_count * elbo).sum() / (atom_count.sum() + self._eps)
        return elbo, elbo_batch

    def pseudoelbo(self, loss_per_residue, C, t):
        """Compute pseudo-ELBOs as SNR-rate weighted averages of a per-residue loss.

        Args:
            loss_per_residue (torch.Tensor): Loss values, broadcastable with
                ``C``; presumably `(batch_size, num_residues)`.
            C (torch.LongTensor): Chain map; only residues with ``C > 0`` are
                averaged.
            t (torch.Tensor or float): Diffusion time, one per batch member. A
                float is promoted to a 1-element tensor on ``C.device``.

        Returns:
            pseudoelbo (torch.Tensor): Masked per-complex average, `(batch_size,)`.
            pseudoelbo_batch (torch.Tensor): Masked average over the whole batch.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.Tensor([t]).float().to(C.device)

        # Weight each complex's residues by half the SNR time-derivative
        half_rate = 0.5 * self.noise_schedule.SNR_derivative(t)
        weighted = half_rate[:, None] * loss_per_residue

        # Average only over present residues
        present = (C > 0).float()
        masked = present * weighted
        per_complex = masked.sum(-1) / (present.sum(-1) + self._eps)
        whole_batch = masked.sum() / (present.sum() + self._eps)
        return per_complex, whole_batch

    def _baoab_sample_step(
        self,
        _x,
        p,
        C,
        t,
        dt,
        score_func,
        gamma=2.0,
        kT=1.0,
        n_equil=1,
        ode_boost=True,
        langevin_isothermal=False,
    ):
        """Advance positions and momenta with BAOAB-split Langevin steps.

        Runs ``n_equil`` BAOAB steps at the fixed time ``t``. Each step is a
        momentum half-kick from the score (B), a position half-drift (A), an
        Ornstein-Uhlenbeck momentum refresh (O), another A and another B.
        Afterwards, one deterministic ``ode_step`` position update is
        optionally applied.

        Args:
            _x (torch.Tensor): Positions. ``torch.randn_like(_x)`` supplies the
                O-step noise.
            p (torch.Tensor): Momenta; added to score terms, so presumably the
                same shape as ``_x``.
            C (torch.LongTensor): Chain map passed to ``score_func`` and to the
                base Gaussian covariance operations.
            t: Diffusion time for the schedule coefficients and score calls.
            dt: Step size used in every sub-step.
            score_func (Callable): Called as ``score_func(t, C, x)``.
            gamma (float): Friction coefficient of the O step.
            kT (float): Temperature; ``1 / kT`` is passed as the inverse
                temperature to ``_schedule_coefficients``.
            n_equil (int): Number of BAOAB steps to run.
            ode_boost (bool): If True, apply ``ode_step`` after the BAOAB steps.
            langevin_isothermal (bool): Forwarded to ``_schedule_coefficients``.

        Returns:
            Tuple of the updated ``(_x, p)``.
        """
        gamma = torch.Tensor([gamma]).to(_x.device)
        (
            alpha,
            sigma,
            beta,
            g,
            lambda_t,
            lambda_langevin,
        ) = self._schedule_coefficients(
            t, inverse_temperature=1 / kT, langevin_isothermal=langevin_isothermal,
        )

        def baoab_step(_x, p, t):
            Z = torch.randn_like(_x)
            # O-step momentum decay factor
            c1 = torch.exp(-gamma * dt)
            # O-step noise amplitude (scaled by 1 / lambda_t)
            c3 = torch.sqrt((1 / lambda_t) * (1 - c1 ** 2))

            # BAOAB scheme
            p_half = p + score_func(t, C, _x) * dt / 2  # B: half-kick from score
            _x_half = (
                _x
                + g.pow(2) * self.base_gaussian.multiply_covariance(p_half, C) * dt / 2
            )  # A: half-drift along covariance-preconditioned momentum
            p_half2 = c1 * p_half + c3 * (
                1 / g
            ) * self.base_gaussian._multiply_R_inverse_transpose(
                Z, C
            )  # O: damp momentum and inject noise mapped through R^{-T}
            _x = (
                _x_half
                + g.pow(2) * self.base_gaussian.multiply_covariance(p_half2, C) * dt / 2
            )  # A: second half-drift
            p = p_half2 + score_func(t, C, _x) * dt / 2  # B: half-kick at new _x

            return _x, p

        def ode_step(t, _x):
            # Deterministic position update using the covariance-transformed score
            score = score_func(t, C, _x)
            score_transformed = self.base_gaussian.multiply_covariance(score, C)
            _x = _x + 0.5 * (_x + score_transformed) * g.pow(2) * dt
            return _x

        for i in range(n_equil):
            _x, p = baoab_step(_x, p, t)

        if ode_boost:
            _x = ode_step(t, _x)

        return _x, p

    @torch.no_grad()
    def sample_sde(
        self,
        X0_func: Callable,
        h_V: torch.LongTensor,
        C: torch.LongTensor,
        X_init: Optional[torch.Tensor] = None,
        conditioner: Optional[Callable] = None,
        N: int = 100,
        tspan: Tuple[float, float] = (1.0, 0.001),
        inverse_temperature: float = 1.0,
        langevin_factor: float = 0.0,
        langevin_isothermal: bool = True,
        sde_func: str = "reverse_sde",
        integrate_func: str = "euler_maruyama",
        initialize_noise: bool = True,
        remap_time: bool = False,
        remove_drift_translate: bool = False,
        remove_noise_translate: bool = False,
        align_X0: bool = True,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """Sample from the SDE using a numerical integration scheme.

        This function samples from the stochastic differential equation (SDE)
        defined by the model using a numerical integration scheme such as
        Euler-Maruyama or Heun. The SDE can run in either the forward or the
        reverse direction. The function also supports optional conditioning on
        external variables and adding Langevin noise to the SDE dynamics.

        Args:
            X0_func (Callable): A denoising function called as
                ``X0_func(X, h_V, C, t)`` that returns the denoised ``X0``.
            h_V (torch.LongTensor): Extra features forwarded unchanged as the
                second argument of every ``X0_func`` call.
            C (torch.LongTensor): Chain map tensor with shape
                `(num_batch, num_residues)`.
            X_init (torch.Tensor, optional): Initial state with shape
                `(num_batch, num_residues, 4, 3)`, or None. If None, a zero
                tensor is noised to time ``tspan[0]`` with ``self.forward``
                (regardless of ``initialize_noise``).
            conditioner (Callable, optional): A function called as
                ``conditioner(X, C, O, U, t)`` that transforms the state and
                energy. If None, no conditioning is applied.
            N (int): Number of integration steps.
            tspan (Tuple[float,float]): Time span for integration.
            inverse_temperature (float): Inverse temperature parameter for SDE.
            langevin_factor (float): Langevin factor for adding noise to SDE.
            langevin_isothermal (bool): Whether to use isothermal or adiabatic
                Langevin dynamics.
            sde_func (str): Key into ``self.sde_funcs`` (e.g. 'reverse_sde',
                'langevin' or 'ode').
            integrate_func (str): Key into ``self.integrate_funcs`` (e.g.
                'euler_maruyama' or 'heun').
            initialize_noise (bool): When ``X_init`` is given, whether to noise
                it to time ``tspan[0]`` before integrating.
            remap_time (bool): Whether to remap the time grid according to the
                noise schedule.
            remove_drift_translate (bool): Whether to remove the net
                translational component from the drift term.
            remove_noise_translate (bool): Whether to remove the net
                translational component from the noise term.
            align_X0 (bool): Whether to Kabsch align X0 with X before computing
                SDE terms.

        Returns:
            outputs (Dict): A dictionary with the following keys:
                - 'C': The (possibly conditioner-transformed) chain map with
                    shape `(num_batch, num_residues)`.
                - 'X_sample': The last state passed to ``X0_func``, with shape
                    `(num_batch, num_residues, 4, 3)`.
                - 'X_trajectory': List made of the final state followed by every
                    state passed to ``X0_func``.
                - 'Xhat_trajectory': List of every ``X0_func`` output.
                - 'Xunc_trajectory': The raw (unconstrained) output of the
                    integrator.
        """

        # Setup SDE integration
        integrator = self.integrate_funcs[integrate_func]
        drift_and_diffusion = self.sde_funcs[sde_func]
        if remap_time:
            T_grid = self.noise_schedule.linear_logsnr_grid(N=N, tspan=tspan).to(
                C.device
            )
        else:
            T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(C.device)

        # Intercept the X0 function for tracking Xt and Xhat
        Xhat_trajectory = []
        Xt_trajectory = []

        def _X0_func(_X, _C, t):
            _X0 = X0_func(_X, h_V, _C, t)
            Xt_trajectory.append(_X.detach())
            Xhat_trajectory.append(_X0.detach())
            return _X0

        def sdefun(_t, _X):
            f, gZ = drift_and_diffusion(
                _X,
                _X0_func,
                C,
                _t,
                conditioner=conditioner,
                inverse_temperature=inverse_temperature,
                langevin_factor=langevin_factor,
                langevin_isothermal=langevin_isothermal,
                align_X0=align_X0,
            )
            # Remove net translational component
            if remove_drift_translate:
                f = backbone.center_X(f, C)
            if remove_noise_translate:
                gZ = backbone.center_X(gZ, C)
            return f, gZ

        # Initialization
        if initialize_noise and X_init is not None:
            X_init = self.forward(X_init, C, t=tspan[0]).detach()
        elif X_init is None:
            X_init = torch.zeros(list(C.shape) + [4, 3], device=C.device)
            X_init = self.forward(X_init, C, t=tspan[0]).detach()

        # Determine the output chain map via a test pass through the conditioner.
        # `is not None` matches how `score` decides whether to apply it.
        if conditioner is not None:
            with torch.enable_grad():
                X_init_test = X_init.clone()
                X_init_test.requires_grad = True
                S_test = torch.zeros(C.shape, device=X_init.device).long()
                O_test = F.one_hot(S_test, len(AA20)).float()
                U_test = 0.0
                t_test = torch.tensor([0.0], device=X_init.device)
                _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
        else:
            Ct = C

        # Integrate
        X_trajectory = integrator(sdefun, X_init, tspan, N=N, T_grid=T_grid)

        # Return constrained coordinates
        outputs = {
            "C": Ct,
            "X_sample": Xt_trajectory[-1],
            "X_trajectory": [Xt_trajectory[-1]] + Xt_trajectory,
            "Xhat_trajectory": Xhat_trajectory,
            "Xunc_trajectory": X_trajectory,
        }
        return outputs

    @torch.no_grad()
    def estimate_pseudoelbo_X(
        self,
        X0_func,
        X,
        C,
        num_samples=50,
        deterministic_seed=0,
        return_elbo_t=False,
        noise=True,
    ):
        """Estimate a pseudo-ELBO from the coordinate denoising error.

        For each time on a uniform grid over [1e-4, 1.0], this computes
        ``-SNR'(t)`` times the masked mean squared error between
        ``X0_func(X_t, C, t)`` and ``X``. Coordinates are divided by 10
        (presumably Angstrom -> nm). Randomness runs under a forked RNG seeded
        with ``deterministic_seed``, so the caller's RNG state is left untouched.

        Args:
            X0_func (Callable): Denoiser called as ``X0_func(X, C, t)``.
            X (torch.Tensor): Coordinates, `(batch_size, num_residues, 4, 3)`.
            C (torch.LongTensor): Chain map; residues with ``C > 0`` are scored.
            num_samples (int): Number of grid times.
            deterministic_seed (int): Seed for the forked RNG.
            return_elbo_t (bool): If True, return the per-time values stacked
                along dim 0; otherwise return their mean over time.
            noise (bool): If True, noise ``X`` to time ``t`` with
                ``self.forward`` before denoising; otherwise denoise ``X``
                directly.

        Returns:
            torch.Tensor: `(num_samples, batch_size)` if ``return_elbo_t``,
            else `(batch_size,)`.
        """
        with torch.random.fork_rng():
            torch.random.manual_seed(deterministic_seed)

            atom_mask = (C > 0).float()[..., None, None].expand([-1, -1, 4, 1])
            atom_count = atom_mask.sum([1, 2, 3])

            per_time = []
            for t in tqdm(np.linspace(1e-4, 1.0, num_samples).tolist()):
                X_input = self.forward(X, C, t=t) if noise else X
                X_pred = X0_func(X_input, C, t)

                sq_err = ((atom_mask * (X_pred - X) / 10.0) ** 2).sum([1, 2, 3])
                snr_rate = self.noise_schedule.SNR_derivative(t).to(X.device)
                per_time.append(-snr_rate * (sq_err / atom_count))

            result = torch.stack(per_time, 0)
            if not return_elbo_t:
                result = result.mean(0)
        return result

    def _score_direct(
        self, Xt, X0_func, C, t, align_X0=True,
    ):
        X0 = X0_func(Xt, C, t)

        """Compute the score function directly. (Sometimes numerically unstable)"""

        alpha = self.noise_schedule.alpha(t).to(Xt.device)
        sigma = self.noise_schedule.sigma(t).to(Xt.device)

        # Impute sensibly behaved values in masked regions for numerical stability
        # X0 = backbone.impute_masked_X(X0, C)
        Xt = backbone.impute_masked_X(Xt, C)

        if align_X0:
            X0, _ = self.loss_rmsd.align(X0, Xt, C, align_unmasked=True)

        # Compute mean
        X_mu = self._mean(X0, C, alpha)
        X_mu = backbone.impute_masked_X(X_mu, C)
        dX = Xt - X_mu

        Ci_dX = self.base_gaussian.multiply_inverse_covariance(dX, C)
        score = -Ci_dX / sigma.pow(2)[:, None, None, None]

        # Mask
        score = score.masked_fill((C <= 0)[..., None, None], 0.0)

        return score

    def estimate_logp(
        self,
        X0_func: Callable,
        X_sample: torch.Tensor,
        C: torch.LongTensor,
        N: int,
        return_trace_t: bool = False,
    ):
        """Estimate the model log-likelihood of protein backbones with the
        continuous normalizing flow (probability flow ODE) formalism.

        References:
            https://arxiv.org/abs/1810.01367
            https://arxiv.org/abs/1806.07366

        Args:
            X0_func (Callable): Denoiser with signature ``X0_func(X, C, t)``.
            X_sample (torch.Tensor): Backbone coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
            C (torch.LongTensor): Chain map with shape
                `(batch_size, num_residues)`.
            N (int): Number of ODE integration steps.
            return_trace_t (bool, optional): If True, return the per-state
                stochastic trace estimates instead of the log-likelihood.
                Default is False.

        Returns:
            If ``return_trace_t`` is False, a float: the estimated log-likelihood
            normalized by the number of present atoms (``4 * sum(C > 0)``).
            Otherwise, a numpy array with one trace estimate per recorded ODE
            state, under the same normalization.

        Note:
            The ``.item()`` calls below require a single scalar per step, so
            this effectively supports ``batch_size == 1``.
        """

        def divergence(fn, x, t):
            """Stochastic trace estimate (eps^T J eps, Gaussian eps) of d fn / d x."""
            vec_eps = torch.randn_like(x)
            fn_out, eps_J_prod = torch.autograd.functional.vjp(
                fn, (t, x), vec_eps, create_graph=False
            )
            # eps^T J eps reduced to one value per batch member: (batch, 1)
            eps_J_eps = (
                (eps_J_prod[1] * vec_eps).reshape(x.shape[0], -1).sum(-1).unsqueeze(-1)
            )
            return fn_out, eps_J_eps

        def flow_gradient(
            X, X0_func, C, t,
        ):
            """Compute the time gradient from the probability flow ODE."""
            # NOTE(review): the score here is not multiplied by the base
            # covariance, unlike the covariance-transformed score used in the
            # SDE drift elsewhere in this class. Confirm this is intended.
            _, _, beta, g, _, _ = self._schedule_coefficients(t)
            score = self._score_direct(X, X0_func, C, t)
            dXdt = (-1 / 2) * beta * X - 0.5 * g.pow(2) * score

            return dXdt

        def odefun(_t, _X):
            _t = _t.detach()
            f = flow_gradient(_X, X0_func, C, _t,)
            return f

        def X0_func_with_h_V(_X, _h_V, _C, _t):
            # sample_sde calls its denoiser as X0_func(X, h_V, C, t), whereas this
            # method's X0_func (also used by _score_direct) takes (X, C, t).
            return X0_func(_X, _C, _t)

        # Forward integration from data to noise
        X_sample = backbone.center_X(X_sample, C)
        X_sample = backbone.impute_masked_X(X_sample, C)
        C = C.abs()

        out = self.sample_sde(
            X0_func=X0_func_with_h_V,
            h_V=None,
            C=C,
            X_init=X_sample,
            N=N,
            sde_func="ode",
            tspan=(0, 1.0),
            inverse_temperature=1.0,
            langevin_factor=0.0,
            initialize_noise=False,
            align_X0=False,
        )

        # Drop the leading copy of the final state that sample_sde prepends
        X_flow = out["X_trajectory"][1:]

        # Stochastic trace estimate along the recorded flow
        ddlogp = []
        T_div = torch.linspace(1e-2, 1.0, len(X_flow))
        for X_t, t in zip(X_flow, tqdm(T_div)):
            with torch.enable_grad():
                dlogP = divergence(odefun, X_t, t[None].to(C.device))[1]
            ddlogp.append(dlogP.item())

        logp_x1 = self.base_gaussian.log_prob(X_flow[-1], C).item()
        num_atoms = (C > 0).float().sum().item() * 4

        if return_trace_t:
            return np.array(ddlogp) / num_atoms
        else:
            return (logp_x1 + np.array(ddlogp).mean()) / num_atoms

    @torch.no_grad()
    @validate_XC(all_atom=False)
    def estimate_elbo(
        self,
        X0_func: Callable,
        X: torch.Tensor,
        C: torch.LongTensor,
        num_samples: int = 50,
        deterministic_seed: int = 0,
        return_elbo_t: bool = False,
        grad_logprob_Y_func: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Estimate the evidence lower bound (ELBO) for given protein backbones
            (num_batch, num_residues, 4, 3) and condition.

        For each time on a uniform grid over [1e-4, 1.0], ``X`` is noised with
        ``self.forward``, denoised with ``X0_func`` (optionally corrected by an
        auxiliary gradient), and scored with ``self.elbo``. All randomness runs
        under a forked RNG seeded with ``deterministic_seed``.

        Args:
            X0_func (Callable): Denoiser called as ``X0_func(X_noise, C, t)``.
            X (torch.Tensor): Protein backbone coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
            C (torch.Tensor): Chain map with shape `(batch_size, num_residues)`.
            num_samples (int, optional): The number of time steps to sample for
                estimating the ELBO. Default is 50.
            deterministic_seed (int, optional): The seed for generating random
                noise. Default is 0.
            return_elbo_t (bool, optional): If True, return the ELBO for each
                time step instead of the average. Default is False.
            grad_logprob_Y_func (Optional[Callable], optional): Called as
                ``grad_logprob_Y_func(t, X_noise)``; returns the gradient of the
                log probability of an observation given the noisy structure. It
                is used to shift the denoised estimate by
                ``sigma_t^2 * Sigma @ grad`` (centered). Default is None.

        Returns:
            elbo (torch.Tensor): A tensor of ELBO values with shape
            `(batch_size,)` if return_elbo_t is False, or
            `(num_samples, batch_size)` if return_elbo_t is True.
        """
        X = backbone.impute_masked_X(X, C)

        with torch.random.fork_rng():
            torch.random.manual_seed(deterministic_seed)

            elbo = []
            T = np.linspace(1e-4, 1.0, num_samples)
            for t in tqdm(T.tolist()):
                X_noise = self.forward(X, C, t=t)
                X_denoise = X0_func(X_noise, C, t)

                # Adjust X-hat estimate with aux-grad
                if grad_logprob_Y_func is not None:
                    # Fork again so RNG consumed by the auxiliary function does
                    # not perturb the noise drawn for later time steps.
                    with torch.random.fork_rng():
                        # Named grad_Y to avoid shadowing torch.autograd.grad
                        grad_Y = grad_logprob_Y_func(t, X_noise)
                        sigma_square = (
                            self.noise_schedule.sigma(t).square().to(X.device)
                        )
                        dXhat = sigma_square * self.base_gaussian.multiply_covariance(
                            grad_Y, C
                        )
                        dXhat = backbone.center_X(dXhat, C)
                        X_denoise = X_denoise + dXhat

                elbo_t, _ = self.elbo(X_denoise, X, C, t)
                elbo.append(elbo_t)

            elbo_stacked = torch.stack(elbo, 0)

        if return_elbo_t:
            return elbo_stacked
        return elbo_stacked.mean(0)

    def conditional_X0(
        self, X0: torch.Tensor, score: torch.Tensor, C: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Use Bayes theorem and Tweedie formula to obtain a conditional X0 given
        prior X0 and a conditional score \\nabla_x p( y | x)
        X0 <- X0 + \\frac{sigma_t**2}{alpha_t} \\Sigma score

        Args:
            X0 (torch.Tensor): backbone coordinates of size (batch, num_residues, 4, 3)
            score (torch.Tensor): of size (batch, num_residues, 4, 3)
            C (torch.Tensor): of size (batch, num_residues)
            t (torch.Tensor): of size (batch,) (a single-element tensor is
                broadcast over the batch)

        Returns:
            X0 (torch.Tensor): updated conditional X0 of size (batch, num_residues, 4, 3)
        """
        alpha, sigma, _, _, _, _ = self._schedule_coefficients(t)
        # One scale per batch member, broadcast over (num_residues, 4, 3).
        # The previous `[..., None]` produced (batch, 1), which aligned the
        # per-batch scale with the atom/xyz axes instead of the batch axis.
        scale = sigma.pow(2).div(alpha).reshape(-1, 1, 1, 1)
        X_update = scale * self.base_gaussian.multiply_covariance(score, C)
        return X0 + X_update

    def _mean(self, X, C, alpha):
        """Mean of the diffusion kernel for clean coordinates ``X`` at scale ``alpha``.

        Thin wrapper that delegates to ``backbone.scale_around_mean``.
        """
        return backbone.scale_around_mean(X, C, alpha)

    def _X_to_Z(self, X_sample, X, C, alpha, sigma):
        """Convert a sample from output space to standardized (whitened) space.

        Undoes the steps of ``_Z_to_X`` in reverse order. It subtracts the
        kernel mean ``mu(X, alpha)`` (with masked positions imputed), divides
        by ``sigma`` (one value per batch member), and multiplies by ``R^{-1}``,
        the inverse square root of the base covariance.

        Returns:
            The whitened noise ``Z`` produced by
            ``self.base_gaussian._multiply_R_inverse``.
        """
        # Impute missing data with conditional means
        X_clean = backbone.impute_masked_X(X, C)
        X_noisy = backbone.impute_masked_X(X_sample, C)

        # Remove the (imputed) kernel mean and flatten atoms to (batch, -1, 3)
        X_mu = backbone.impute_masked_X(self._mean(X_clean, C, alpha), C)
        residual = (X_noisy - X_mu).reshape(X_clean.shape[0], -1, 3)

        # Undo the noise scale, then whiten
        residual = residual / sigma[:, None, None]
        return self.base_gaussian._multiply_R_inverse(residual, C)

    def _Z_to_X(self, Z, X, C, alpha, sigma):
        """Convert from standardized space to output space"""

        # Step 1 & 2. Multiply Z by square root of covariance
        dX = self.base_gaussian._multiply_R(Z, C)

        # Step 3. Scale noise by alpha
        dX = sigma[:, None, None, None] * dX.reshape(X.shape)

        # Step 4. Add mean
        X_mu = self._mean(X, C, alpha)
        X_sample = X_mu + dX

        return X_sample

    def sample_conditional(
        self, X: torch.Tensor, C: torch.LongTensor, t: torch.Tensor, s: torch.Tensor
    ) -> torch.Tensor:
        """
        Samples from the forward process q(x_{t} | x_{s}) for t > s.
        See appendix A.1 in [https://arxiv.org/pdf/2107.00630.pdf]. `forward` does this for s = 0.

        Uses ``alpha_{t|s} = alpha_t / alpha_s`` and
        ``sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2``.

        Args:
            X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues,
                4, 3)` at time `s`.
            C (torch.Tensor): Chain tensor with shape `(batch_size, num_residues)`.
            t (torch.Tensor): Time index with shape `(batch_size,)`.
            s (torch.Tensor): Time index with shape `(batch_size,)`.

        Returns:
            X_sample (torch.Tensor): Sampled coordinates from the forward diffusion
                marginals with shape `(batch_size, num_residues, 4, 3)`.

        Raises:
            AssertionError: If any ``t <= s``.
        """
        assert (t > s).all()
        X = backbone.impute_masked_X(X, C)
        # Center the input (kept from the original implementation)
        X = backbone.center_X(X, C)

        alpha_ts = self.noise_schedule.alpha(t) / self.noise_schedule.alpha(s)
        sigma_ts = (
            self.noise_schedule.sigma(t).pow(2)
            - alpha_ts.pow(2) * self.noise_schedule.sigma(s).pow(2)
        ).sqrt()

        # One coefficient per batch member, broadcast over (num_residues, 4, 3),
        # as in `forward`. Without this, a (batch,) tensor broadcasts against
        # the trailing xyz axis.
        alpha_ts = alpha_ts.reshape(-1, 1, 1, 1).to(X.device)
        sigma_ts = sigma_ts.reshape(-1, 1, 1, 1).to(X.device)

        X_sample = alpha_ts * X + sigma_ts * self.base_gaussian.sample(C)
        # Remove net translation of the displacement relative to X
        X_sample = backbone.center_X(X_sample - X, C) + X
        return X_sample

    @validate_XC(all_atom=False)
    def forward(
        self, X: torch.Tensor, C: torch.LongTensor, t: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Sample from the forwards diffusion marginals at time t

        Computes ``alpha_t * X + sigma_t * X_prior`` with ``X_prior`` drawn
        from the base Gaussian. The net translation of the added displacement is
        then removed with ``backbone.center_X``.

        Inputs:
            X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues,
                4, 3)`.
            C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`.
            t (torch.Tensor or float, optional): Time index with shape
                `(batch_size,)` (or a float). If not given, a random time index
                is sampled by ``self.sample_t``. Defaults to None.

        Outputs:
            X_sample (torch.Tensor): Sampled coordinates from the forward diffusion
                marginals with shape `(batch_size, num_residues, 4, 3)`.
            t (torch.Tensor, optional): Time index with shape `(batch_size,)`. Only
                returned (as ``(X_sample, t)``) if t is not given as input.
        """

        # Draw a sample from the prior
        X_prior = self.base_gaussian.sample(C)

        # Sample time if not given
        t_input = t
        t = self.sample_t(C, t)

        alpha = self.noise_schedule.alpha(t)[:, None, None, None].to(X.device)
        sigma = self.noise_schedule.sigma(t)[:, None, None, None].to(X.device)

        X_sample = alpha * X + sigma * X_prior
        # Remove net translation of the displacement relative to X
        X_sample = backbone.center_X(X_sample - X, C) + X

        if t_input is None:
            return X_sample, t
        return X_sample


class ReconstructionLosses(nn.Module):
    """Compute diffusion reconstruction losses for protein backbones.

    Args:
        diffusion (DiffusionChainCov): Diffusion object parameterizing a
            forwards diffusion over protein backbones.
        loss_scale (float): Length scale parameter used for setting loss error
            scaling in units of Angstroms. Default is 10, which corresponds to
            using units of nanometers.
        rmsd_method (str): Method used for computing RMSD superpositions. Can
            be "symeig" (default) or "power" for power iteration.

    Inputs:
        X0_pred (torch.Tensor): Denoised coordinates with shape
            `(num_batch, num_residues, 4, 3)`.
        X (torch.Tensor): Unperturbed coordinates with shape
            `(num_batch, num_residues, 4, 3)`.
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        t (torch.Tensor): Diffusion time with shape `(batch_size,)`.
            Should be on [0,1].

    Outputs:
        losses (dict): Dictionary of reconstructions computed across different
            metrics. Metrics prefixed with `batch_` will be batch-averaged scalars
            while other metrics should be per batch member with shape
            `(num_batch, ...)`.
    """

    def __init__(
        self,
        diffusion: DiffusionChainCov,
        loss_scale: float = 10.0,
        rmsd_method: str = "symeig",
    ):
        super().__init__()
        self.noise_perturb = diffusion
        self.loss_scale = loss_scale
        # Guards the ratio-style losses against division by zero.
        self._loss_eps = 1e-5

        # Auxiliary losses
        self.loss_rmsd = rmsd.BackboneRMSD(method=rmsd_method)
        self.loss_fragment = rmsd.LossFragmentRMSD(method=rmsd_method)
        self.loss_fragment_pair = rmsd.LossFragmentPairRMSD(method=rmsd_method)
        self.loss_neighborhood = rmsd.LossNeighborhoodRMSD(method=rmsd_method)
        self.loss_hbond = hbonds.LossBackboneHBonds()
        self.loss_distance = backbone.LossBackboneResidueDistance()

        # Every entry is invoked by `forward`, in this order, and writes its
        # results into the shared `losses` dict.
        self.loss_functions = {
            "elbo": self._loss_elbo,
            "rmsd": self._loss_rmsd,
            "pseudoelbo": self._loss_pseudoelbo,
            "fragment": self._loss_fragment,
            "pair": self._loss_pair,
            "neighborhood": self._loss_neighborhood,
            "distance": self._loss_distance,
            "hbonds": self._loss_hbonds,
        }

    def _batch_average(self, loss, C):
        """Average per-item `loss` weighted by each item's count of `C > 0` residues."""
        weights = (C > 0).float().sum(-1)
        return (weights * loss).sum() / (weights.sum() + self._loss_eps)

    def _loss_elbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        losses["elbo"], losses["batch_elbo"] = self.noise_perturb.elbo(X0_pred, X, C, t)

    def _loss_rmsd(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Global RMSD of the prediction, normalized by the RMSD of a fresh
        # noised sample `X_t_2` at the same time.
        _, rmsd_denoise = self.loss_rmsd.align(X, X0_pred, C)
        _, rmsd_noise = self.loss_rmsd.align(X, X_t_2, C)
        rmsd_ratio_per_item = w * rmsd_denoise / (rmsd_noise + self._loss_eps)
        global_mse_normalized = (
            w
            * self.loss_scale
            * rmsd_denoise.square()
            / (rmsd_noise.square() + self._loss_eps)
        )
        losses["rmsd_ratio"] = self._batch_average(rmsd_ratio_per_item, C)
        losses["global_mse"] = global_mse_normalized
        losses["batch_global_mse"] = self._batch_average(global_mse_normalized, C)

    def _loss_pseudoelbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Unaligned residual pseudoELBO
        unaligned_mse = ((X - X0_pred) / self.loss_scale).square().sum(-1).mean(-1)
        losses["elbo_X"], losses["batch_pseudoelbo_X"] = self.noise_perturb.pseudoelbo(
            unaligned_mse, C, t
        )

    def _loss_fragment(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Aligned Fragment MSE loss
        mask = (C > 0).float()
        rmsd_fragment = self.loss_fragment(X0_pred, X, C)
        rmsd_fragment_noise = self.loss_fragment(X_t_2, X, C)
        fragment_mse_normalized = (
            self.loss_scale
            * w
            * (
                (mask * rmsd_fragment.square()).sum(1)
                / ((mask * rmsd_fragment_noise.square()).sum(1) + self._loss_eps)
            )
        )
        losses["fragment_mse"] = fragment_mse_normalized
        losses["batch_fragment_mse"] = self._batch_average(fragment_mse_normalized, C)

    def _loss_pair(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Aligned Pair MSE loss
        rmsd_pair, mask_ij_pair = self.loss_fragment_pair(X0_pred, X, C)
        rmsd_pair_noise, mask_ij_pair = self.loss_fragment_pair(X_t_2, X, C)
        pair_mse_normalized = (
            self.loss_scale
            * w
            * (
                (mask_ij_pair * rmsd_pair.square()).sum([1, 2])
                / (
                    (mask_ij_pair * rmsd_pair_noise.square()).sum([1, 2])
                    + self._loss_eps
                )
            )
        )
        losses["pair_mse"] = pair_mse_normalized
        losses["batch_pair_mse"] = self._batch_average(pair_mse_normalized, C)

    def _loss_neighborhood(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Neighborhood MSE
        rmsd_neighborhood, mask = self.loss_neighborhood(X0_pred, X, C)
        rmsd_neighborhood_noise, mask = self.loss_neighborhood(X_t_2, X, C)
        neighborhood_mse_normalized = (
            self.loss_scale
            * w
            * (
                (mask * rmsd_neighborhood.square()).sum(1)
                / ((mask * rmsd_neighborhood_noise.square()).sum(1) + self._loss_eps)
            )
        )
        losses["neighborhood_mse"] = neighborhood_mse_normalized
        losses["batch_neighborhood_mse"] = self._batch_average(
            neighborhood_mse_normalized, C
        )

    def _loss_distance(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # Distance MSE
        mask = (C > 0).float()
        distance_mse = self.loss_distance(X0_pred, X, C)
        distance_mse_noise = self.loss_distance(X_t_2, X, C)
        distance_mse_normalized = self.loss_scale * (
            w
            * (mask * distance_mse).sum(1)
            / ((mask * distance_mse_noise).sum(1) + self._loss_eps)
        )
        losses["distance_mse"] = distance_mse_normalized
        losses["batch_distance_mse"] = self._batch_average(distance_mse_normalized, C)

    def _loss_hbonds(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
        # HBond recovery: the hbond loss yields three per-item terms, each
        # scaled by the importance weights.
        outs = self.loss_hbond(X0_pred, X, C)
        hb_local, hb_nonlocal, error_co = [w * o for o in outs]

        losses["batch_hb_local"] = self._batch_average(hb_local, C)
        losses["hb_local"] = hb_local
        losses["batch_hb_nonlocal"] = self._batch_average(hb_nonlocal, C)
        losses["hb_nonlocal"] = hb_nonlocal
        losses["batch_hb_contact_order"] = self._batch_average(error_co, C)

    @torch.no_grad()
    @validate_XC(all_atom=False)
    def estimate_metrics(
        self,
        X0_func: Callable,
        X: torch.Tensor,
        C: torch.LongTensor,
        num_samples: int = 50,
        deterministic_seed: int = 0,
        use_noise: bool = True,
        return_samples: bool = False,
        tspan: Tuple[float, float] = (1e-4, 1.0),
    ):
        """Estimate time-averaged reconstruction losses of protein backbones.

        Args:
            X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`.
            X (torch.Tensor): A tensor of protein backbone features with shape
                `(batch_size, num_residues, 4, 3)`.
            C (torch.Tensor): A tensor of condition features with shape `(batch_size,
                num_residues)`.
            num_samples (int, optional): The number of evenly spaced time points
                in `tspan` at which to evaluate the metrics. Must be at least 1.
                Default is 50.
            deterministic_seed (int, optional): The seed for generating random noise.
                The global RNG state is restored afterwards. Default is 0.
            use_noise (bool): If True, add noise to each structure before denoising.
                Default is True. When False this can be used for estimating if
                structures are fixed points of the denoiser across time.
            return_samples (bool): If True, also return the per-time-point
                values for each metric. Default is False.
            tspan (tuple[float, float]): `(t_min, t_max)` diffusion times
                between which to integrate (both endpoints included). Default is
                `(1e-4, 1.0)`.

        Returns:
            metrics (dict): A dictionary of reconstruction metrics averaged over
                time. Only tensor-valued entries are averaged.
            metrics_samples (dict, optional): A dictionary mapping each metric
                name to its list of per-time-point values (one per sampled
                time). Only returned when `return_samples` is True.

        Raises:
            ValueError: If `num_samples` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {num_samples}")
        t_min, t_max = tspan

        # Fill in masked coordinates so noising/denoising sees a complete backbone.
        X = backbone.impute_masked_X(X, C)
        with torch.random.fork_rng():
            torch.random.manual_seed(deterministic_seed)
            T = np.linspace(t_min, t_max, num_samples)
            losses = []
            for t in tqdm(T.tolist(), desc="Integrating diffusion metrics"):
                X_noise = self.noise_perturb(X, C, t=t) if use_noise else X
                X_denoise = X0_func(X_noise, C, t)
                losses_t = self.forward(X_denoise, X, C, t)

                # Discard batch estimated objects
                losses_t = {
                    k: v
                    for k, v in losses_t.items()
                    if not k.startswith("batch_") and k != "rmsd_ratio"
                }
                losses.append(losses_t)

            # Transpose list of dicts to a dict of lists
            metrics_samples = {k: [d[k] for d in losses] for k in losses[0].keys()}

            # Average final metrics across time (non-tensor entries such as a
            # float `t` are skipped).
            metrics = {
                k: torch.stack(v, 0).mean(0)
                for k, v in metrics_samples.items()
                if isinstance(v[0], torch.Tensor)
            }
        if return_samples:
            return metrics, metrics_samples
        else:
            return metrics

    @validate_XC()
    def forward(
        self,
        X0_pred: torch.Tensor,
        X: torch.Tensor,
        C: torch.LongTensor,
        t: torch.Tensor,
    ):
        """Evaluate every registered loss in `self.loss_functions`.

        A fresh forward-diffusion sample `X_t_2` of `X` at time `t` is drawn
        and used as the noise baseline that normalizes the ratio losses. Each
        loss is weighted by `1 / prob_SSNR(SSNR(t))` from the diffusion's noise
        schedule.

        Returns:
            dict: `t`, `X`, `X0_pred` plus all entries written by the losses.
        """
        # Collect all losses and tensors for metric tracking
        losses = {"t": t, "X": X, "X0_pred": X0_pred}
        X_t_2 = self.noise_perturb(X, C, t=t)

        # Per complex weights
        ssnr = self.noise_perturb.noise_schedule.SSNR(t).to(X.device)
        prob_ssnr = self.noise_perturb.noise_schedule.prob_SSNR(ssnr)
        importance_weights = 1 / prob_ssnr

        for _loss in self.loss_functions.values():
            _loss(losses, X0_pred, X, C, t, w=importance_weights, X_t_2=X_t_2)
        return losses


def _debug_viz_gradients(
    pml_file, X_list, dX_list, C, S, arrow_length=2.0, name="gradient", color="red"
):
    """ """
    lines = [
        "from pymol.cgo import *",
        "from pymol import cmd",
        f'color_1 = list(pymol.cmd.get_color_tuple("{color}"))',
        'color_2 = list(pymol.cmd.get_color_tuple("blue"))',
    ]

    with open(pml_file, "w") as f:
        for model_ix, X in enumerate(X_list):
            print(model_ix)
            lines = lines + ["obj_1 = []"]

            dX = dX_list[model_ix]
            scale = dX.norm(dim=-1).mean().item()
            X_i = X
            X_j = X + arrow_length * dX / scale

            for a_ix in range(4):
                for i in range(X.size(1)):
                    x_i = X_i[0, i, a_ix, :].tolist()
                    x_j = X_j[0, i, a_ix, :].tolist()
                    lines = lines + [
                        f"obj_1 = obj_1 + [CYLINDER] + {x_i} + {x_j} + [0.15]"
                        " + color_1 + color_1"
                    ]
            lines = lines + [f'cmd.load_cgo(obj_1, "{name}", {model_ix+1})']
            f.write("\n" + "\n".join(lines))
            lines = []

def _debug_viz_XZC(X, Z, C, rgb=True):
    """Plot coordinate traces of `X` and `Z`, save them to `xzc.pdf`, then exit.

    Debug helper: only batch member 0 is plotted. The top panel shows `X`,
    with positions where `C < 0` shaded grey. The bottom panel shows `Z`
    together with the chain map `C`.

    Args:
        X: Coordinates. Tensors with more than 3 dims are flattened to
            `(batch, -1, 3)`.
        Z: Second set of coordinates, flattened the same way as `X`.
        C: Chain map indexed as `[batch, position]`. If its length differs from
            the flattened `X`, each entry is repeated 4 times (presumably one
            per backbone atom — confirm against callers).
        rgb: If True, plot the x/y/z channels in red/green/blue; otherwise plot
            only the x channel in black.

    Note:
        This function terminates the process via `sys.exit()` after saving.
    """
    import sys

    from matplotlib import pyplot as plt

    def _to_numpy(tensor):
        # `.detach().cpu()` also works for CUDA tensors, unlike `.data.numpy()`.
        return tensor.detach().cpu().numpy()

    if len(X.shape) > 3:
        X = X.reshape(X.shape[0], -1, 3)
    if len(Z.shape) > 3:
        Z = Z.reshape(Z.shape[0], -1, 3)
    if C.shape[1] != X.shape[1]:
        C_expand = C.unsqueeze(-1).expand(-1, -1, 4)
        C = C_expand.reshape(C.shape[0], -1)

    # (channel index, matplotlib format string) pairs to draw in each panel.
    channels = [(0, "r"), (1, "g"), (2, "b")] if rgb else [(0, "k")]

    N = X.shape[1]
    Ymax = torch.max(X[0, :, 0]).item()
    plt.figure(figsize=[12, 4])

    # Top panel: X with masked (C < 0) positions shaded.
    plt.subplot(2, 1, 1)
    plt.bar(
        np.arange(0, N),
        (_to_numpy(C[0, :]) < 0) * Ymax,
        width=1.0,
        edgecolor=None,
        color="lightgrey",
    )
    for channel, fmt in channels:
        plt.plot(_to_numpy(X[0, :, channel]), fmt, linewidth=0.5)
    plt.xlim([0, N])
    plt.grid()
    plt.title("X")
    plt.xticks([])

    # Bottom panel: Z overlaid with the chain map.
    plt.subplot(2, 1, 2)
    for channel, fmt in channels:
        plt.plot(_to_numpy(Z[0, :, channel]), fmt, linewidth=0.5)
    plt.plot(_to_numpy(C[0, :]), "orange")
    plt.xlim([0, N])
    plt.grid()
    plt.title("RInverse @ [X]" if rgb else "Inverse[X]")
    plt.xticks([])
    plt.savefig("xzc.pdf")

    # `exit()` is a site-module convenience meant for the interactive shell;
    # sys.exit() is the supported way for a program to terminate.
    sys.exit()
