import math
from typing import Callable, Optional, Union, Tuple, Sequence
from torchdiffeq import odeint
import torch
from torch import Tensor

from flow_matching.solver.solver import Solver
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import geodesic, Manifold
from flow_matching.utils import gradient

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False


class RiemannianODESolver(Solver):
    r"""Riemannian ODE solver
    Initialize the ``RiemannianODESolver``.

    Args:
        manifold (Manifold): the manifold to solve on.
        velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)`
            and returning :math:`u_t(x)` which is assumed to lie on the tangent plane at `x`.
    """

    def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
        super().__init__()
        self.manifold = manifold
        self.velocity_model = velocity_model

    def sample(
        self,
        x_init: Tensor,
        step_size: float,
        projx: bool = True,
        proju: bool = True,
        method: str = "euler",
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Tensor:
        r"""Solve the ODE with the `velocity_field` on the manifold.

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`).
            step_size (float): The step size.
            projx (bool): Whether to project the point onto the manifold at each step. Defaults to True.
            proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True.
            method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler".
            time_grid (Tensor, optional): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool, optional): Whether to print progress bars. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        step_fns = {
            "euler": _euler_step,
            "midpoint": _midpoint_step,
            "rk4": _rk4_step,
        }
        assert method in step_fns.keys(), f"Unknown method {method}"
        step_fn = step_fns[method]

        def velocity_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        # --- Factor this out.
        time_grid = torch.sort(time_grid.to(device=x_init.device)).values

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is float then t discretization is uniform with step size set by step_size.
            t_init = time_grid[0].item()
            t_final = time_grid[-1].item()
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = math.ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )
        # ---
        t0s = t_discretization[:-1]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            t0s = tqdm(t0s)

        if return_intermediates:
            xts = []
            i_ret = 0

        with torch.set_grad_enabled(enable_grad):
            xt = x_init
            for t0, t1 in zip(t0s, t_discretization[1:]):
                dt = t1 - t0
                xt_next = step_fn(
                    velocity_func,
                    xt,
                    t0,
                    dt,
                    manifold=self.manifold,
                    projx=projx,
                    proju=proju,
                )
                if return_intermediates:
                    while (
                        i_ret < len(time_grid)
                        and t0 <= time_grid[i_ret]
                        and time_grid[i_ret] <= t1
                    ):
                        xts.append(
                            interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
                        )
                        i_ret += 1
                xt = xt_next

        if return_intermediates:
            return torch.stack(xts, dim=0)
        else:
            return xt

    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for log likelihood given a target sample at :math:`t=0`.

        Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        time_grid = time_grid.to(device=x_1.device)

        # Fix the random projection for the Hutchinson divergence estimator
        if not exact_divergence:
            z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0

        def ode_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        def dynamics_func(t, states):
            xt = self.manifold.projx(states[0])
            with torch.set_grad_enabled(True):
                xt.requires_grad_()
                ut = ode_func(xt, t)

                if exact_divergence:
                    # Compute exact divergence
                    div = 0
                    for i in range(ut.flatten(1).shape[1]):
                        div += gradient(ut[:, i], xt, create_graph=True)[:, i]
                else:
                    # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
                    z_tan = self.manifold.proju(xt, z)
                    ut_dot_z = torch.einsum(
                        "bi,bi->b", ut.flatten(start_dim=1), z_tan.flatten(start_dim=1)
                    )
                    grad_ut_dot_z = gradient(ut_dot_z, xt)
                    div = torch.einsum(
                        "bi,bi->b",
                        grad_ut_dot_z.flatten(start_dim=1),
                        z_tan.flatten(start_dim=1),
                    )

            return ut.detach(), div.detach()

        y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            sol, log_det = odeint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + log_det[-1]
        else:
            return sol[-1], source_log_p + log_det[-1]

def interp(manifold, xt, xt_next, t, t_next, t_ret):
    return geodesic(manifold, xt, xt_next)(
        (t_ret - t) / (t_next - t).reshape(1)
    ).reshape_as(xt)


def _euler_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform an Euler step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    vt = velocity_fn(xt, t0)

    xt = xt + dt * vt

    return projx_fn(xt)


def _midpoint_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform a midpoint step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    half_dt = 0.5 * dt
    vt = velocity_fn(xt, t0)
    x_mid = xt + half_dt * vt
    x_mid = projx_fn(x_mid)

    xt = xt + dt * velocity_fn(x_mid, t0 + half_dt)

    return projx_fn(xt)


def _rk4_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform an RK4 step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    k1 = velocity_fn(xt, t0)
    k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
    k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
    k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)

    return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)