# from https://github.com/facebookresearch/flow_matching/blob/main/flow_matching/solver/ode_solver.py
import torch
import torch.nn as nn

from torch import Tensor
from torchdiffeq import odeint_adjoint
from typing import Optional, Union, Sequence, Callable, Tuple
from src.models.fm.r3n_fm import R3NFlowMatcher


from src.utils.data_utils import get_trace_jacobian


class DensityDynamics(nn.Module):
    """ODE right-hand side wrapping a velocity model for torchdiffeq solvers.

    Two modes, selected by ``compute_divergence``:

    * ``False``: the ODE state is a single tensor ``x`` and ``forward``
      returns the post-processed velocity reshaped to ``x.shape``.
    * ``True``: the ODE state is a tuple ``(x, logd, path_length)`` and
      ``forward`` returns ``(velocity, divergence, speed)``. The solver then
      also integrates the divergence of the velocity field and the
      per-sample velocity norm.

    Args:
        velocity_model: called with a dict holding ``"x_t"``, ``"t"`` and the
            entries of ``model_extras``; its output dict must contain
            ``"coors_pred"``.
        fm: flow matcher whose ``_mask_and_zero_com(ut, mask)`` is applied to
            every predicted velocity.
        compute_divergence: whether the state is augmented with the
            divergence and path-length accumulators.
        exact_divergence: forwarded unchanged to ``get_trace_jacobian``.
            Presumably it selects an exact trace over a stochastic estimate
            that uses ``z`` — confirm against that helper.
        z: probe tensor forwarded unchanged to ``get_trace_jacobian``.
        model_extras: extra inputs merged into every model call. Must contain
            ``"mask"``, which is passed to ``fm._mask_and_zero_com``.
    """

    def __init__(
        self,
        velocity_model: nn.Module,
        fm: R3NFlowMatcher,
        compute_divergence=False,
        exact_divergence=False,
        z=None,
        model_extras=None,
    ):
        super().__init__()
        self.velocity_model = velocity_model
        self.fm = fm
        self._compute_divergence = compute_divergence
        self._exact_divergence = exact_divergence
        self._z = z
        self.model_extras = model_extras or {}

    def _velocity(self, t, x):
        """Query the velocity model at ``(x, t)`` and apply ``fm._mask_and_zero_com``."""
        # The solver hands over a scalar time; the model expects one per sample.
        t_repeat = t.repeat_interleave(x.shape[0])
        vt_input = {"x_t": x, "t": t_repeat, **self.model_extras}
        ut = self.velocity_model(vt_input)["coors_pred"]
        return self.fm._mask_and_zero_com(ut, self.model_extras["mask"])

    def forward(self, t, state):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: scalar time supplied by the solver.
            state: ``x`` when ``compute_divergence`` is False, otherwise the
                tuple ``(x, logd, path_length)``. Only ``x`` is read.

        Returns:
            The velocity reshaped to ``x.shape``. In divergence mode, returns
            the tuple ``(velocity, divergence, speed)`` instead, where
            ``speed`` is the Euclidean norm of each sample's full velocity,
            with shape ``(batch,)``.
        """
        if not self._compute_divergence:
            x = state
            ut = self._velocity(t, x)
            return ut.view(x.shape)

        x, _, _ = state
        with torch.set_grad_enabled(True):
            # The divergence differentiates ut w.r.t. x, so x must be tracked
            # even when the caller disabled gradients around the solver.
            x.requires_grad_(True)
            ut = self._velocity(t, x)
            # TODO: could not reproduce the same results for ESS
            # as in the paper EFM
            divergence = get_trace_jacobian(ut, x, self._exact_divergence, self._z)

        # Norm over every non-batch dim so the derivative matches the (batch,)
        # path-length state. torch.norm(ut, dim=-1) only did so for 2-D ut.
        speed = ut.reshape(x.shape[0], -1).norm(dim=-1)
        return ut.reshape(x.shape), divergence, speed


class ODESolver(nn.Module):
    """Solve the flow ODE defined by a velocity model with ``torchdiffeq``.

    Integration always uses ``odeint_adjoint`` with a :class:`DensityDynamics`
    right-hand side, optionally augmented with divergence and path-length
    accumulators.

    Args:
        velocity_model (nn.Module): network queried for velocities.
        fm (nn.Module): flow matcher whose ``_mask_and_zero_com`` is applied
            to every predicted velocity.
    """

    def __init__(self, velocity_model: nn.Module, fm: nn.Module):
        super().__init__()
        self._velocity_model = velocity_model
        self._fm = fm

    @staticmethod
    def _rademacher_like(x: Tensor) -> Tensor:
        """Return random entries in {-1.0, +1.0} with the shape and device of ``x``.

        Used as the probe tensor ``z`` handed to :class:`DensityDynamics`.
        """
        return (torch.randn_like(x) < 0) * 2.0 - 1.0

    def sample(
        self,
        x_init: Tensor,
        compute_divergence: bool = False,
        exact_divergence: bool = False,
        step_size: Optional[float] = None,
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE with the velocity field.

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
            compute_divergence (bool): If True, also integrate the divergence of the velocity field and the per-sample velocity norm alongside the samples. Defaults to False.
            exact_divergence (bool): Forwarded to the divergence computation. When False, a random +/-1 probe shaped like ``x_init`` is drawn and passed along. Defaults to False.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model. Must contain ``"mask"``.

        Returns:
            Union[Tensor, Sequence[Tensor]]: If ``compute_divergence`` is False, the solution at the last time in ``time_grid``, or at every time when ``return_intermediates=True``. Otherwise a tuple ``(x, logd, path_length)``. Here ``logd`` and ``path_length`` are the accumulated divergence and velocity-norm integrals, each of shape [batch_size], with a leading time axis when ``return_intermediates=True``.
        """
        time_grid = time_grid.to(x_init.device)
        ode_opts = {} if step_size is None else {"step_size": step_size}
        # The probe is drawn whenever exact_divergence is False; the dynamics
        # only consume it when compute_divergence is True.
        z = None if exact_divergence else self._rademacher_like(x_init)

        dynamics_func = DensityDynamics(
            self._velocity_model,
            self._fm,
            compute_divergence=compute_divergence,
            exact_divergence=exact_divergence,
            z=z,
            model_extras=model_extras,
        )

        if compute_divergence:
            # Augment the state with (batch,) accumulators for the divergence
            # integral and the path length.
            n_batch, device = x_init.shape[0], x_init.device
            y_init = (
                x_init,
                torch.zeros(n_batch, device=device),
                torch.zeros(n_batch, device=device),
            )
        else:
            y_init = x_init

        with torch.set_grad_enabled(enable_grad):
            result = odeint_adjoint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        if not compute_divergence:
            return result if return_intermediates else result[-1]

        x, logd, path_length = result
        if return_intermediates:
            return x, logd, path_length
        return x[-1], logd[-1], path_length[-1]

    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float] = None,
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for the log likelihood of a target sample given at :math:`t=1`.

        Works similarly to ``sample``, but solves the ODE in reverse (from t=1 to t=0) to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`). It is flattened to [batch_size, -1] before integration, so the returned samples are flattened as well.
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model. Must contain ``"mask"``.

        Returns:
            A tuple ``(samples, log_likelihood, path_length)``. Without ``return_intermediates`` these are the t=0 sample, ``log_p0`` of it plus the accumulated divergence integral, and the accumulated velocity-norm integral. With ``return_intermediates`` each carries a leading time axis over ``time_grid``; the log-likelihood term is then ``log_p0`` of the t=0 sample plus the divergence integral accumulated up to each time. Because integration runs from t=1 down to t=0, ``path_length`` integrates a non-negative speed over a decreasing interval and is therefore non-positive; negate it for the usual length.

        Raises:
            AssertionError: if ``time_grid`` does not start at 1.0 and end at 0.0.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        time_grid = time_grid.to(x_1.device)
        x_1 = x_1.view(x_1.size(0), -1)
        z = None if exact_divergence else self._rademacher_like(x_1)

        # fm is a required argument of DensityDynamics; omitting it (as the
        # previous version did) raised a TypeError on every call.
        dynamics_func = DensityDynamics(
            self._velocity_model,
            self._fm,
            compute_divergence=True,
            exact_divergence=exact_divergence,
            z=z,
            model_extras=model_extras,
        )

        n_batch, device = x_1.shape[0], x_1.device
        y_init = (
            x_1,
            torch.zeros(n_batch, device=device),
            torch.zeros(n_batch, device=device),
        )
        ode_opts = {} if step_size is None else {"step_size": step_size}

        with torch.set_grad_enabled(enable_grad):
            sol, negative_logd, path_length = odeint_adjoint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + negative_logd, path_length
        return x_source, source_log_p + negative_logd[-1], path_length[-1]
