# inspired by https://github.com/NVIDIA-Digital-Bio/proteina/blob/main/proteinfoundation/flow_matching/r3n_fm.py
import torch
from torch import Tensor
from typing import Optional, Callable, Literal, Tuple
import math
from tqdm import tqdm
import einops


class R3NFlowMatcher:
    def __init__(
        self,
        sigma: float = 0.1,
        scale_ref: float = 1.0,
    ):
        """
        Stores the hyper-parameters of the flow matcher.

        Args:
            sigma: Scale of the Gaussian perturbation added to the
                interpolant mean in `interpolate`.
            scale_ref: Scale applied to the standard Gaussian reference
                samples in `sample_noise`; also enters the score in
                `vf_to_score`.
        """
        self.sigma = sigma
        self.scale_ref = scale_ref
        # Size of the last axis of the noise drawn in `sample_noise`.
        self.dim = 3

    def _mean_w_mask(self, a, mask, keepdim=True):
        """
        Computes the mean of point cloud a accounting for the mask.

        Args:
            a: Input point cloud of shape [*, n, d]
            mask: Input mask of shape [*, n] of boolean values
            keepdim: whether to keep the dimension across which we're computing the mean
                like normal pytorch mean

        Returns:
            Masked mean of a across dimension -2 (or n)
        """
        mask = mask[..., None]  # [*, n, 1]
        num_elements = torch.sum(mask, dim=-2, keepdim=True)  # [*, 1, 1]
        num_elements = torch.where(
            num_elements == 0, torch.tensor(1.0), num_elements
        )  # [*, 1, 1]
        a_masked = torch.masked_fill(a, ~mask, 0.0)  # [*, n, d]
        mean = torch.sum(a_masked, dim=-2, keepdim=True) / num_elements  # [*, 1, d]
        mean = torch.masked_fill(mean, num_elements == 0, 0.0)  # [*, 1, d]
        if not keepdim:
            mean = einops.rearrange(mean, "... () d -> ... d")
        return mean

    def _force_zero_com(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Centers tensor over n dimension.

        Args:
            x: Tensor of shape [*, n, 3]
            mask (optional): Binary mask of shape [*, n]

        Returns:
            Centered x = x - mean(x, dim=-2), shape [*, n, 3].
        """
        if mask is None:
            x = x - torch.mean(x, dim=-2, keepdim=True)
        else:
            x = (x - self._mean_w_mask(x, mask, keepdim=True)) * mask[..., None]
        return x

    def _apply_mask(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Applies mask to x. Sets masked elements to zero.

        Args:
            x: Tensor of shape [*, n, 3]
            mask (optional): Binary mask of shape [*, n]

        Returns:
            Masked x of shape [*, n, 3]
        """
        if mask is None:
            return x
        return x * mask[..., None]  # [*, n, 3]

    def _mask_and_zero_com(self, x, mask: Optional[Tensor] = None) -> Tensor:
        """
        Applies mask to and centers x if needed (if zero_com=True).

        Args:
            x: Batch of samples, batch shape *
            mask (optional): Binary mask of shape [*, n]

        Returns:
            Masked (and possibly center) samples.
        """
        x = self._apply_mask(x, mask)
        x = self._force_zero_com(x, mask)
        return x

    def _extend_t(self, n: int, t: Tensor) -> Tensor:
        """
        Extends t shape with n. Needed to use flow matching utils.

        Args:
            n (int): Number of elements per sample (e.g. number of residues)
            t: Float vector, shape [*]

        Returns:
            Extended t vector of shape [*, n] compatible with flow matching utils.
        """
        return t[..., None].expand(t.shape + (n,))

    def sample_noise(
        self,
        n: int,
        b: int,
        device: Optional[torch.device] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Samples noise distribution std Gaussian (possibly centered).

        Args:
            n: number of frames in a single sample
            b: number of samples
            device (optional): torch device used
            mask (optional): Binary mask of shape [*, n]

        Returns:
            Samples from refenrece [N(0, I_3)]^n shape [*shape, n, 3]
        """
        x = (
            torch.randn(
                (b, n, self.dim),
                device=device,
            )
            * self.scale_ref
        )
        return self._mask_and_zero_com(x, mask)

    def interpolate(
        self,
        x0: Tensor,
        x1: Tensor,
        t: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Interpolates between rigids x_0 (base) and x_1 (data) using t.

        Args:
            x0: Tensor sampled from reference, shape [*, n, 3]
            x1: Tensor sampled from target, shape [*, n, 3]
            t: Interpolation times, shape [*]
            mask (optional): Binary mask, shape [*, n]

        Returns:
            x_t: Interpolated tensor, shape [*, n, 3]
        """
        x0, x1 = map(
            lambda args: self._mask_and_zero_com(*args),
            ((x0, mask), (x1, mask)),
        )

        n = x0.shape[-2]
        t = self._extend_t(n, t)  # [*, n]
        t = t[..., None]  # [*, n, 1]
        mu_t = (1.0 - t) * x0 + t * x1
        eps_t = torch.randn_like(mu_t)
        eps_t = self._mask_and_zero_com(eps_t, mask)
        x = mu_t + self.sigma * eps_t
        return x, mu_t, eps_t

    def xt_dot(
        self,
        x_1: Tensor,
        x_t: Tensor,
        t: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Computes dx_t / dt for the interpolation scheme defined
        above. This is the target used in flow matching loss.

        Args:
            x_1: Sample tensor from target, shape [*, n, 3]
            x_t: Interpolated tensor, shape [*, n, 3]
            t: Interpolation times, shape [*]
            mask (optional): Binary mask of shape [*, n]

        Returns:
            dx_t / dt, with shapes [*, n, 3].
        """
        x_1, x_t = map(
            lambda args: self._mask_and_zero_com(*args),
            ((x_1, mask), (x_t, mask)),
        )

        n = x_1.shape[-2]
        t = self._extend_t(n, t)  # [*, n]
        t = t[..., None]  # [*, n, 1]
        x_t_dot = (x_1 - x_t) / (1.0 - t)
        return x_t_dot

    def vf_to_score(
        self,
        x_t: Tensor,
        v: Tensor,
        t: Tensor,
    ):
        """
        Compute score of noisy density given the vector field learned by flow matching. With
        our interpolation scheme these are related by

        v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

        or equivalently,

        s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)).

        Args:
            x_t: Noisy sample, shape [*, dim]
            v: Vector field, shape [*, dim]
            t: Interpolation time, shape [*]

        Returns:
            Score of intermediate density, shape [*, dim].
        """
        assert torch.all(t < 1.0), "vf_to_score requires t < 1 (strict)"
        num = t[..., None] * v - x_t  # [*, n, 3]
        den = (1.0 - t)[..., None] * self.scale_ref**2  # [*, n, 1]
        score = num / den
        return score  # [*, dim]

    def get_gt(
        self,
        t: Tensor,
        eps: float = 1e-2,
    ) -> Tensor:
        """
        Computes gt = 1 / (t + eps).

        Args:
            t: times where we'll evaluate, covers [0, 1), shape [nsteps]
            eps: small value leave as it is

        Returns
        """
        # Numerical reasons for some schedule
        t = torch.clamp(t, 0, 1 - 1e-5)
        num = 1.0
        den = t
        gt = num / (den + eps)
        return gt  # [s]

    def get_schedule(self, mode: str, nsteps: int, p1: float = None):
        if mode == "uniform":
            t = torch.linspace(0, 1, nsteps + 1)
            return t
        elif mode == "log":
            assert p1 is not None, "p1 cannot be none for the log schedule"
            assert p1 > 0, f"p1 must be > 0 for the log schedule, got {p1}"
            t = 1.0 - torch.logspace(-p1, 0, nsteps + 1).flip(0)
            t = t - torch.min(t)
            t = t / torch.max(t)
            return t
        else:
            # Should not get here
            raise NotImplementedError(f"Schedule mode not recognized {mode}")

    def full_simulation(
        self,
        predict_clean_n_v: Callable,
        dt: float,
        nsamples: int,
        n: int,
        device: torch.device,
        mask: Tensor,
        schedule_mode: Literal["uniform", "log"],
        schedule_p: float,
        sampling_mode: str,
        sc_scale_noise: float,
    ) -> Tensor:
        """
        Generates samples by integrating the learned dynamics from t=0
        (reference noise) up to t=1.

        Args:
            predict_clean_n_v: A function that predicts clean sample and vector field
                takes as input a dictionary with keys:
                    - "x_t"
                    - "t"
                    - "mask"
                with values the corresponding tensors, and returns the pair
                (clean sample prediction, vector field). Only the vector
                field is used here.
            dt: nominal step-size, float. Only used to derive the number of
                steps, ceil(1 / dt); the actual per-step size comes from the
                schedule.
            nsamples: number of samples to generate, int
            n: protein length
            device: torch device on which the samples are simulated
            mask: Binary mask of shape [nsamples, n]
            schedule_mode: "uniform", "log"
            schedule_p: parameter of the schedule for the times covering [0, 1]
                uniform: this is ignored, points uniformly spaced
                log: passed as p1 to `get_schedule`, which builds
                     1 - logspace(-schedule_p, 0) (reversed) and rescales it
                     so that it goes exactly from 0 to 1
            sampling_mode: "vf" or "sc", forwarded to `simulation_step`. Steps
                taken at times above 0.99 always use "vf".
            sc_scale_noise: noise scale forwarded to `simulation_step`

        Returns:
            Batch of generated samples [nsamples, n, ...]
        """
        assert mask.shape == (nsamples, n)

        num_steps = math.ceil(1.0 / dt)
        # [num_steps + 1] times, first element is 0 and last element is 1.
        ts = self.get_schedule(
            mode=schedule_mode,
            nsteps=num_steps,
            p1=schedule_p,
        )
        # The vector field is evaluated on ts[:-1] only; the final time 1 is
        # never evaluated, it just defines the size of the last step.
        gt = self.get_gt(t=ts[:-1])  # [num_steps]

        # Accomodate last few steps: once the schedule passes 0.99 the mode
        # is switched to "vf" for that step and every later one.
        mode = sampling_mode

        with torch.no_grad():
            x = self.sample_noise(
                n=n, b=nsamples, device=device, mask=mask
            )  # [nsamples, n, 3]

            for step in tqdm(range(num_steps)):
                t_now = ts[step]
                t = t_now * torch.ones(nsamples, device=device)  # [nsamples]
                step_dt = ts[step + 1] - t_now

                # The clean-sample prediction is not needed for the update.
                _, v = predict_clean_n_v({"x_t": x, "t": t, "mask": mask})

                if t_now > 0.99:
                    mode = "vf"

                x, _ = self.simulation_step(
                    x_t=x,
                    v=v,
                    t=t,
                    dt=step_dt,
                    gt=gt[step],
                    sampling_mode=mode,
                    sc_scale_noise=sc_scale_noise,
                    mask=mask,
                )
            return x

    def simulation_step(
        self,
        x_t: Tensor,
        v: Tensor,
        t: Tensor,
        dt: float,
        gt: float,
        sampling_mode: Literal["vf", "sc"],
        sc_scale_noise: float,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Single integration step of ODE dx_t / dt = v(x_t, t) using Euler integration scheme.

        Args:
            x_t: Current values, shape [*, n, 3]
            v: Vector field of shape [*, n, 3]
            t: Current time, shape [*]
            dt: Step-size, float
            sampling_mode: "vf" of "sc", standing for vector field (normal flow matching, eq. (1)) and
                score (introduces score, eq. (2)).
            sc_scale_noise: scale applied to the noise when simulating eq. (2).
            mask (optional): Binary mask of shape [*, n]

        Returns:
            Updated x_t after Euler integration step, shape [*, n, 3]
            Updated time [*]
        """
        v = self._apply_mask(v, mask)
        n = x_t.shape[-2]

        # Euler step
        t_ext = self._extend_t(n, t)  # [*, n]

        x_t_updated, _ = self.step_euler(
            x_t=x_t,
            v=v,
            t=t_ext,
            dt=dt,
            gt=gt,
            sampling_mode=sampling_mode,
            sc_scale_noise=sc_scale_noise,
        )

        return (
            self._mask_and_zero_com(
                x_t_updated, mask
            ),  # Equivalent to centering the update vector since x_t is centered
            t + dt,
        )

    def step_euler(
        self,
        x_t: Tensor,
        v: Tensor,
        t: Tensor,
        dt: float,
        gt: float,
        sampling_mode: Literal["vf", "sc"],
        sc_scale_noise: float,
    ) -> tuple[Tensor, Tensor]:
        """
        Single integration step of ODE

        eq. (1): d x_t = v(x_t, t) dt

        or SDE

        eq. (2): d x_t = [v(x_t, t) + g(t) s(x_t, t)] dt + \sqrt{2g(t)} dw_t

        using Euler integration scheme.

        For our interpolation scheme (i.e. stochastic interpolant) we can obtain
        the score as a function of the vector field from

        v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

        or equivalently,

        s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)).

        We add a few additional parameters to the SDE to control noise/score scale and
        perform stochastic and low temperature sampling:

        eq. (3): d x_t = [v(x_t, t) + g(t) * sc_score_scale * s(x_t, t) * sc_score_scale] dt + \sqrt{2 * g(t) * sc_noise_scale} dw_t,

        where g(t) = sc_g * min(5, (1-t)/t).

        At the moment we do not scale the vector field v.

        Args:
            x_t: Current value, shape [*, n, 3]
            v: Vector field, shape [*, n, 3]
            t: Current time, shape [*, n]
            dt: Step-size, float
            sampling_mode: "vf" of "sc", standing for vector field (normal flow matching, eq. (1)) and
                score (introduces score, eq. (2)).
            sc_scale_noise: scale applied to the noise when simulating eq. (2),

        Returns:
            Updated values for x_t after an Euler integration step, shape [*, n, 3].
            Updated time [*, n]
        """
        assert sampling_mode in [
            "vf",
            "sc",
        ], f"Invalid sampling mode {sampling_mode}, should be `vf` or `sc`"
        assert (
            sc_scale_noise >= 0
        ), f"Scale noise for sampling should be >= 0, got {sc_scale_noise}"

        assert gt >= 0, f"gt for sampling should be >= 0, got {gt}"
        t_element = t.flatten()[0]
        assert torch.all(
            t_element == t
        ), "Sampling only implemented for same time for all samples"
        # The last few steps are always taken with eq. (1).

        if sampling_mode == "vf" or t_element > 1.0:
            return x_t + v * dt, t + dt

        if sampling_mode == "sc":
            score = self.vf_to_score(x_t, v, t)  # get score from v, [*, dim]
            eps = torch.randn(x_t.shape, dtype=x_t.dtype, device=x_t.device)  # [*, dim]
            std_eps = torch.sqrt(2 * gt * sc_scale_noise * dt)
            delta_x = (v + gt * score) * dt + std_eps * eps
            return x_t + delta_x, t + dt
