"""
This module provides tools to simulate Stochastic Differential Equations (SDEs) using explicit, implicit and JKO schemes.
It includes functionality to generate trajectories based on specified models, potentials, internal energies, and interactions.
The simulations are performed using ``JAX``, which enables efficient computations and automatic differentiation.

Classes
-----------
- ``SDESimulator``
    Simulates SDEs using an explicit scheme, allowing for the application of potential, internal energy, and interaction
    components. The class supports forward sampling of trajectories based on the initial condition and a JAX random key.

- ``SDESimulator_implicit_time``
    Simulates SDEs using implicit time-stepping methods. This class is designed to handle time-varying potentials
    and performs fixed-point iterations to account for implicit dynamics. It supports forward sampling similar to ``SDESimulator``.

- ``SDESimulatorWithJKO``
    Simulates SDEs using the Jordan-Kinderlehrer-Otto (JKO) scheme with learnable mappings via neural networks.
    It supports forward sampling similar to ``SDESimulator``.

Functions
-------------
- ``get_SDE_predictions``
    A helper function to choose between explicit and implicit SDE simulators based on the model type and return
    the simulated trajectories.

Usage example
-------------
To simulate trajectories using an explicit SDE simulator:

    >>> import jax.numpy as jnp
    >>> import jax.random as jrandom
    >>> key = jrandom.PRNGKey(0)
    >>> init_pp = jnp.array([[0.0, 0.0]])
    >>> sde_simulator = SDESimulator(dt=0.01, n_timesteps=100, start_timestep=0, potential=False, internal=0.1, interaction=False)
    >>> trajectories = sde_simulator.forward_sampling(key, init_pp)
    >>> print(trajectories.shape)  # Output: (101, 1, 2)

For implicit time-stepping with a time-dependent potential:

    >>> potential_func = lambda x: 0.5 * jnp.sum(jnp.square(x))
    >>> implicit_simulator = SDESimulator_implicit_time(dt=0.01, n_timesteps=100, start_timestep=0, potential=potential_func, internal=False, interaction=False)
    >>> trajectories = implicit_simulator.forward_sampling(key, init_pp)
    >>> print(trajectories.shape)  # Output: (101, 1, 2)

# TODO: add notes

References
----------
- `jax`: https://github.com/google/jax
- Stochastic Differential Equations (SDE): https://en.wikipedia.org/wiki/Stochastic_differential_equation
"""

from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
from flax.training.train_state import TrainState
from jax import lax
from tqdm import trange

from networks.maps import TransportMap


def get_SDE_predictions(
    model: str,
    dt: float,
    n_timesteps: int,
    start_timestep: int,
    potential: bool | Callable[[jnp.ndarray], jnp.ndarray],
    internal: bool | Callable[[jnp.ndarray], jnp.ndarray] | float,
    interaction: bool | Callable[[jnp.ndarray], jnp.ndarray],
    key: jax.random.PRNGKey,
    init_pp: jnp.ndarray,
    simulator: str = "forward",
) -> jnp.ndarray:
    """
    Get predictions from a Stochastic Differential Equation (SDE) simulator based on the specified model type.

    The simulator class is selected as follows:

    - if `model` is one of the time-dependent-potential models ('jkonet-star-time-potential',
      'inverse-jkonet-time-potential', 'inverse-jkonet-multimap-time-potential'),
      `SDESimulator_implicit_time` is used and `simulator` is ignored (it is not even validated);
    - otherwise `simulator` picks the class: 'forward' -> `SDESimulator`,
      'backward' -> `SDEBackwardEulerSampler`.

    The selected class is instantiated with the energy components and its `forward_sampling`
    is called with `key` and `init_pp`.

    Parameters
    ----------
    model : str
        The name of the model to use for simulation (see the selection rules above).
    dt : float
        The timestep size for the SDE simulation.
    n_timesteps : int
        The total number of timesteps to simulate.
    start_timestep : int
        The initial timestep index for the simulation.
    potential : bool | Callable[[jnp.ndarray], jnp.ndarray]
        Potential energy, forwarded unchanged to the selected simulator. If `False`, no potential is applied.
    internal :  bool | Callable[[jnp.ndarray], jnp.ndarray] | float
        Internal energy, forwarded unchanged to the selected simulator. If `False`, no internal
        component is used.
    interaction : bool | Callable[[jnp.ndarray], jnp.ndarray]
        Interaction energy, forwarded unchanged to the selected simulator. If `False`, no
        interaction is applied.
    key : jax.random.PRNGKey
        A JAX random key used for stochastic processes in the SDE simulation.
    init_pp : jnp.ndarray
        The initial state for the simulation, typically a JAX array representing the starting point of the system.
    simulator : str
        Simulator type used for models without a time-dependent potential: 'forward' selects
        `SDESimulator`, 'backward' selects `SDEBackwardEulerSampler`.

    Returns
    -------
    jnp.ndarray
        An array representing the simulated trajectories of the system, with shape (n_timesteps + 1, ...),
        where the first dimension corresponds to the timesteps and the remaining dimensions correspond
        to the state variables of the system.

    Raises
    ------
    ValueError
        If the model has no time-dependent potential and `simulator` is neither 'forward' nor 'backward'.
    """
    time_potential_models = (
        "jkonet-star-time-potential",
        "inverse-jkonet-time-potential",
        "inverse-jkonet-multimap-time-potential",
    )

    if model in time_potential_models:
        # Time-dependent potentials are only supported by the implicit-time simulator.
        sde = SDESimulator_implicit_time
    else:
        simulator_map = {"forward": SDESimulator, "backward": SDEBackwardEulerSampler}
        if simulator not in simulator_map:
            raise ValueError(f"Unknown simulator type: {simulator}!")
        sde = simulator_map[simulator]

    return sde(dt, n_timesteps, start_timestep, potential, internal, interaction).forward_sampling(key, init_pp)


class SDESimulator:
    """
    Simulator for Stochastic Differential Equations (SDEs) with an explicit scheme.

    Parameters
    ----------
    dt : float
        The timestep size for the simulation.

    n_timesteps : int
        The number of timesteps to simulate.

    start_timestep : int
        The initial timestep index for the simulation.

    potential : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If a callable, it should take a JAX array as input and return the potential. If `False`,
        no potential is applied.

    internal : bool | Callable[[jnp.ndarray], jnp.ndarray] | float
        If a callable, it should take a JAX array as input and return the internal component. If `False`,
        no internal component is used.

    interaction : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If a callable, it should take a JAX array as input and return the interaction component. If `False`,
        no interaction is applied.
    Methods
    -------
    forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray) -> jnp.ndarray
        Performs forward sampling of the SDE from the initial condition `init` using the provided random key.

    """

    def __init__(
        self,
        dt: float,
        n_timesteps: int,
        start_timestep: int,
        potential: bool | Callable[[jnp.ndarray], jnp.ndarray],
        internal: bool | Callable[[jnp.ndarray], jnp.ndarray] | float,
        interaction: bool | Callable[[jnp.ndarray], jnp.ndarray],
    ):

        sqrtdt = jnp.sqrt(2 * dt)
        potential_component = lambda pp, key: jnp.zeros(pp.shape)
        internal_component = lambda pp, key: jnp.zeros(pp.shape)
        interaction_component = lambda pp, key: jnp.zeros(pp.shape)

        if potential:
            potential_grad = jax.grad(potential)
            flow = jax.vmap(lambda v: -potential_grad(v))
            potential_component = lambda pp, key: flow(pp) * dt

        if internal:
            # At the moment we use wiener process
            if not isinstance(internal, float):
                raise NotImplementedError("Generic internal energies not implemented yet.")

            internal_component = (
                lambda pp, key: -jnp.sqrt(jnp.abs(internal)) * jrandom.normal(key, shape=pp.shape) * sqrtdt
            )

        if isinstance(interaction, Callable):

            interaction_grad = lambda v: jax.grad(interaction)(v)
            interaction_grad_vmap = jax.vmap(interaction_grad)

            def get_interaction_component(pp):
                def W_fn(p):
                    forw = -interaction_grad_vmap(p - pp)
                    back = interaction_grad_vmap(pp - p)
                    W_biased_sum = jnp.sum(forw + back, axis=0)
                    assert W_biased_sum.shape == p.shape
                    bs = pp.shape[0]
                    return W_biased_sum / (bs - 1.0)

                return W_fn

            interaction_component = lambda pp, _: jax.vmap(get_interaction_component(pp))(pp) * dt

        def forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray) -> jnp.ndarray:
            """
            Performs forward sampling of the SDE from the initial condition.

            Parameters
            ----------
            key : jax.random.PRNGKey
                Random key used for sampling.

            init : jnp.ndarray
                Initial condition for the simulation.

            Returns
            -------
            jnp.ndarray
                The array of simulated trajectories with shape (n_timesteps + 1, ...) where
                the first dimension represents the timestep and the remaining dimensions
                represent the state variables.
            """
            pp = jnp.copy(init)
            trajectories = [pp]
            for i in range(1, n_timesteps + 1):
                key, subkey = jrandom.split(key, 2)
                _pot = potential_component(pp, subkey)
                _int = internal_component(pp, subkey)
                _inter = interaction_component(pp, subkey)
                assert _pot.shape == pp.shape
                assert _int.shape == pp.shape
                assert _inter.shape == pp.shape
                # pp = pp + potential_component(pp, subkey) + internal_component(pp, subkey) + interaction_component(pp, subkey)
                pp = pp + _pot + _int + _inter
                trajectories.append(pp)
            return jnp.asarray(trajectories)

        self.forward_sampling = jax.jit(forward_sampling)


class SDESimulator_implicit_time:
    """
    Simulator for Stochastic Differential Equations (SDEs) using implicit methods.

    Parameters
    ----------
    dt : float
        The timestep size for the simulation.

    n_timesteps : int
        The number of timesteps to simulate.

    start_timestep : int
        The initial timestep index for the simulation. In the case that we are working with time-varying potentials
        the start time of the simulation is necessary.

    potential : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If a callable, it should take a JAX array and a time array
        as input and return the potential. If `False`, no potential is applied.

    internal : bool | Callable[[jnp.ndarray], jnp.ndarray] | float
        If a callable, it should take a JAX array as input and return the internal component. If `False`,
        no internal component is used.

    interaction : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If a callable, it should take a JAX array as input and return the interaction component. If `False`,
        no interaction is applied.

    Methods
    -------
    forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray) -> jnp.ndarray
        Performs forward sampling of the SDE from the initial condition `init` using the provided random key.
    """

    def __init__(
        self,
        dt: float,
        n_timesteps: int,
        start_timestep: int,
        potential: bool | Callable[[jnp.ndarray], jnp.ndarray],
        internal: bool | Callable[[jnp.ndarray], jnp.ndarray] | float,
        interaction: bool | Callable[[jnp.ndarray], jnp.ndarray],
    ):
        self.dt = dt
        self.n_timesteps = n_timesteps
        self.potential = potential
        self.sqrtdt = jnp.sqrt(2 * dt)

        def potential_component_implicit(pp: jnp.ndarray, t_array: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
            """
            Computes the implicit potential component using fixed-point iterations.

            Parameters
            ----------
            pp : jnp.ndarray
                The current state of the simulation.

            t_array : jnp.ndarray
                The time array for the current step.

            key : jax.random.PRNGKey
                Random key used for sampling.

            Returns
            -------
            jnp.ndarray
                The implicit potential component to be added to the state.
            """
            if self.potential:

                def fixed_point_iteration(x, pp, t_array):
                    concat_pos_time = jnp.concatenate([x, t_array], axis=-1)
                    gradient = jax.vmap(jax.grad(potential))(concat_pos_time)
                    assert pp.shape[-1] == gradient.shape[-1] - 1
                    return pp - gradient[..., :-1] * dt

                # Initial guess for implicit method
                x = pp
                for _ in range(50):  # Perform fixed-point iterations
                    x = fixed_point_iteration(x, pp, t_array)

                return x - pp
            else:
                return jnp.zeros(pp.shape)

        def forward_sampling(key, init):
            """
            Performs forward sampling of the SDE from the initial condition.

            Parameters
            ----------
            key : jax.random.PRNGKey
                Random key used for sampling.

            init : jnp.ndarray
                Initial condition for the simulation.

            Returns
            -------
            jnp.ndarray
                The array of simulated trajectories with shape (n_timesteps + 1, ...) where
                the first dimension represents the timestep and the remaining dimensions
                represent the state variables.
            """
            pp = jnp.copy(init)
            trajectories = [pp]
            for i in range(start_timestep, start_timestep + n_timesteps):
                # for i in range(start_timestep, start_timestep + n_timesteps * timestep, timestep):
                key, subkey = jrandom.split(key, 2)
                t_array = (i) * jnp.ones((pp.shape[0], 1), dtype=jnp.float32)  # Create time array for current step
                pp = pp + potential_component_implicit(pp, t_array, subkey)
                trajectories.append(pp)
            return jnp.asarray(trajectories)

        self.forward_sampling = jax.jit(forward_sampling)


class SDEBackwardEulerSampler:

    def __init__(
        self,
        dt: float,
        n_timesteps: int,
        start_timestep: int,
        potential: bool | Callable[[jnp.ndarray], jnp.ndarray],
        internal: bool | Callable[[jnp.ndarray], jnp.ndarray] | float,
        interaction: bool | Callable[[jnp.ndarray], jnp.ndarray],
    ):
        self.potential = potential

        def potential_component_implicit(pp: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:

            if self.potential:

                def fixed_point_iteration(x, pp):
                    gradient = jax.vmap(jax.grad(potential))(x)
                    return pp - gradient * dt

                def run_fixed_point_iterations(pp, max_iter=1000, epsilon=1e-6):
                    def cond_fn(state):
                        x_prev, x_new, iter_count = state
                        # Continue if:
                        # 1. Norm is >= epsilon AND
                        # 2. Iteration count < max_iter
                        return (jnp.max(jnp.linalg.norm(x_new - x_prev, axis=1)) >= epsilon) & (iter_count < max_iter)

                    def body_fn(state):
                        x_prev, x_new, iter_count = state
                        x_next = fixed_point_iteration(x_new, pp)  # Compute next value
                        return x_new, x_next, iter_count + 1  # Update state

                    # Initialize: x_prev, x_new, iter_count
                    x_init = pp  # Starting point
                    initial_state = (x_init, fixed_point_iteration(x_init, pp), 0)
                    # Run loop
                    final_state = lax.while_loop(cond_fn, body_fn, initial_state)
                    _, x_final, _ = final_state
                    return x_final

                x = pp
                return run_fixed_point_iterations(x) - pp
            else:
                return jnp.zeros(pp.shape)

        def forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray):

            pp = jnp.copy(init)
            trajectories = [pp]
            for _ in range(start_timestep, start_timestep + n_timesteps):
                key, subkey = jrandom.split(key, 2)
                pp = pp + potential_component_implicit(pp, subkey)
                trajectories.append(pp)
            return jnp.asarray(trajectories)

        self.forward_sampling = jax.jit(forward_sampling)


class SDESimulatorWithJKO:
    """
    Simulator that moves particles with one learnable transport map per timestep.

    For every timestep the particles are updated as ``x + map(x)``. When no trained state is
    supplied for a timestep, the map is initialised and fitted by minimising
    ``mean(||map(x)||^2) / (2 * dt) + energy(x + map(x))``, and the resulting state is
    stored back into `transport_states` so it is reused by later calls.

    Parameters
    ----------
    dt : float
        The timestep size; it scales the transport cost in the training objective.

    n_timesteps : int
        The number of timesteps to simulate. Must equal the number of transport maps.

    start_timestep : int
        The initial timestep index. Stored on the instance; not used by the sampling loop.

    potential : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If truthy, a per-particle potential; it is vectorised over the batch with `jax.vmap`.

    internal : bool | Callable[[jnp.ndarray], jnp.ndarray] | float
        Internal energy. Any truthy value makes the energy evaluation raise
        `NotImplementedError` when a map has to be trained.

    interaction : bool | Callable[[jnp.ndarray], jnp.ndarray]
        If truthy, a per-difference interaction energy; it is vectorised with `jax.vmap`.

    transport_maps : list[TransportMap]
        One transport map per timestep.

    transport_states : list[TrainState | None]
        One entry per timestep; `None` entries are trained on first use and replaced in place.

    optimizer : optax.GradientTransformation
        Optimizer used to fit maps that have no trained state.

    n_inner_steps : int
        Number of optimisation steps used to fit each untrained map.

    Methods
    -------
    forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray) -> jnp.ndarray
        Performs forward sampling from the initial condition `init`.
    """

    def __init__(
        self,
        dt: float,
        n_timesteps: int,
        start_timestep: int,
        potential: bool | Callable[[jnp.ndarray], jnp.ndarray],
        internal: bool | Callable[[jnp.ndarray], jnp.ndarray] | float,
        interaction: bool | Callable[[jnp.ndarray], jnp.ndarray],
        transport_maps: list[TransportMap],
        transport_states: list[TrainState | None],
        optimizer: optax.GradientTransformation = optax.adam(1e-3),
        n_inner_steps: int = 1000,
    ):
        assert len(transport_maps) == len(transport_states)
        assert len(transport_maps) == n_timesteps

        self.dt = dt
        self.n_timesteps = n_timesteps
        self.start_timestep = start_timestep

        # Energies are defined per particle (or per difference) and batched here.
        self.potential = jax.vmap(potential) if potential else potential
        self.internal = internal
        self.interaction = jax.vmap(interaction) if interaction else interaction

        self.transport_maps = transport_maps
        self.transport_states = transport_states
        self.optimizer = optimizer
        self.n_inner_steps = n_inner_steps

        def _compute_interaction(rho: jnp.ndarray) -> jnp.ndarray:
            """
            Computes a Monte Carlo estimate of the interaction energy from the batch `rho`.
            """
            n_particles = rho.shape[0]
            pairwise = rho[:, None, :] - rho[None, :, :]  # shape [N, N, D]
            # TODO: not tested: taking off-diagonal elements
            rows, cols = jnp.tril_indices(n_particles, k=-1)
            below_diagonal = pairwise[rows, cols]  # shape (N * (N-1) / 2, D)
            # Average the energy of each difference and of its negation.
            symmetrised = 0.5 * (self.interaction(below_diagonal) + self.interaction(-below_diagonal))
            assert symmetrised.shape[0] == n_particles * (n_particles - 1) // 2
            assert len(symmetrised.shape) == 1

            return jnp.mean(symmetrised)

        def _compute_energy(rho: jnp.ndarray) -> jnp.ndarray:
            """Sum the enabled energy terms evaluated on the batch `rho`."""
            total = 0.0
            if self.potential:
                total += jnp.mean(self.potential(rho))
            if self.interaction:
                total += _compute_interaction(rho)
            if self.internal:
                raise NotImplementedError("Sampling with internal energy is not implemented yet!")
            return total

        def _loss(params: dict, model: nn.Module, rho_prev: jnp.ndarray) -> jnp.ndarray:
            """Training objective: transport cost of the displacement plus energy after the move."""
            displacement = model.apply({"params": params}, rho_prev)
            moved = rho_prev + displacement
            transport_cost = jnp.mean(jnp.sum(displacement**2, axis=-1)) / (2.0 * self.dt)
            return transport_cost + _compute_energy(moved)

        def _fit_map(index: int, model: nn.Module, init_key: jnp.ndarray, rho_prev: jnp.ndarray) -> TrainState:
            """Initialise the map of timestep `index` and fit it on `rho_prev`."""
            params = model.init(init_key, rho_prev)["params"]
            state = TrainState.create(apply_fn=model.apply, params=params, tx=self.optimizer)

            @jax.jit
            def step_fn(state: TrainState, rho: jnp.ndarray) -> tuple[TrainState, float]:
                grads = jax.grad(_loss)(state.params, model, rho)
                state = state.apply_gradients(grads=grads)
                # The reported loss is evaluated at the updated parameters.
                return state, _loss(state.params, model, rho)

            progress = trange(self.n_inner_steps, desc=f"Map {index}")
            for iteration in progress:
                state, loss = step_fn(state, rho_prev)
                if iteration % 100 == 0:  # Update display every 100 steps
                    progress.set_postfix(loss=float(loss))
            return state

        def forward_sampling(key: jax.random.PRNGKey, init: jnp.ndarray):
            """
            Performs forward sampling using JKO scheme with separate transport maps per timestep.

            Parameters
            ----------
            key : jax.random.PRNGKey
                Random key, split once per timestep; the subkey initialises untrained maps.

            init : jnp.ndarray
                Initial particles.

            Returns
            -------
            jnp.ndarray : (n_timesteps + 1, N, D)
            """
            particles = init
            trajectories = [init]

            for t in range(self.n_timesteps):
                key, subkey = jrandom.split(key)
                model = self.transport_maps[t]
                state = self.transport_states[t]

                if state is None:
                    # Train on demand and cache the result for subsequent calls.
                    state = _fit_map(t, model, subkey, particles)
                    self.transport_states[t] = state

                particles = particles + model.apply({"params": state.params}, particles)
                trajectories.append(particles)

            return jnp.asarray(trajectories, dtype=init.dtype)

        self.forward_sampling = forward_sampling
