import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
import plum
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, PyTree

__all__ = ['ode_solve',
           'ODESolverParams',
           'SDESolverParams',
           'sde_sample']

class ODESolverParams(eqx.Module):
  """Configuration read by `ode_solve`.

  **Fields**:

  - rtol, atol: Tolerances handed to `diffrax.PIDController` when
                `stepsize_controller` is 'pid'.
  - n_steps: The initial step size is `(t1 - t0)/n_steps`; it is also used as a
             lower bound on the step budget.
  - solver: One of 'dopri5', 'euler', 'kvaerno5'.
  - adjoint: One of 'recursive_checkpoint', 'direct'.
  - stepsize_controller: 'pid' selects `diffrax.PIDController`; any other value
                         selects `diffrax.ConstantStepSize`.
  - max_steps: Step budget for the solve (raised to `n_steps` if smaller).
  - throw: Forwarded to `diffrax.diffeqsolve`.
  - progress_meter: 'tqdm' or 'text'; any other value disables the meter.
  """
  rtol: float = 1.0e-8
  atol: float = 1.0e-8
  n_steps: int = 1_000
  solver: str = 'dopri5'
  adjoint: str = 'recursive_checkpoint'
  stepsize_controller: str = 'pid'
  max_steps: int = 8_192
  throw: bool = True
  progress_meter: str = 'tqdm'

# Stand-in for `diffusion_crf.timeseries.TimeSeries`, used in the return
# annotations below.  The class itself is imported inside the function bodies,
# at call time.
# NOTE(review): presumably deferred to avoid a circular import -- confirm.
TimeSeries = plum.ModuleType('diffusion_crf.timeseries', 'TimeSeries')
def ode_solve(f: Callable,
              x0: PyTree,
              save_times: Array,
              params: ODESolverParams = ODESolverParams(),
              _only_return_yts: bool = False) -> TimeSeries:
  """Integrate dx/dt = f(t, x) from save_times[0] to save_times[-1], starting at x0.

  **Arguments**:

  - f: Right-hand side of the ODE.  Called as f(t, x) and should return dx/dt.
  - x0: The state at time save_times[0].
  - save_times: The times at which the solution is recorded.  The first and
                last entries are used as the integration limits.
  - params: The solver configuration.
  - _only_return_yts: If True, return the saved states directly instead of
                      wrapping them in a TimeSeries.

  **Returns**:

  - TimeSeries built from (save_times, saved states), or only the saved states
    when _only_return_yts is True.

  **Raises**:

  - ValueError: If params.solver or params.adjoint is not a recognised name.
  """
  # diffrax calls vector fields as (t, y, args); f only takes (t, y).
  def rhs(t, xts, args):
    return f(t, xts)
  term = diffrax.ODETerm(rhs)

  # Option name -> diffrax class name.  Resolved with getattr so that only the
  # selected class is ever looked up on the diffrax module.
  solver_names = {'dopri5': 'Dopri5',
                  'euler': 'Euler',
                  'kvaerno5': 'Kvaerno5'}
  if params.solver not in solver_names:
    raise ValueError(f"Unknown solver: {params.solver}")
  solver = getattr(diffrax, solver_names[params.solver])()

  adjoint_names = {'recursive_checkpoint': 'RecursiveCheckpointAdjoint',
                   'direct': 'DirectAdjoint'}
  if params.adjoint not in adjoint_names:
    raise ValueError(f"Unknown adjoint: {params.adjoint}")
  adjoint = getattr(diffrax, adjoint_names[params.adjoint])()

  # Anything other than 'pid' means fixed-size steps.
  if params.stepsize_controller != 'pid':
    stepsize_controller = diffrax.ConstantStepSize()
  else:
    stepsize_controller = diffrax.PIDController(rtol=params.rtol, atol=params.atol)

  saveat = diffrax.SaveAt(ts=save_times)
  t0, t1 = save_times[0], save_times[-1]
  dt0 = (t1 - t0)/params.n_steps

  # Never give the solver a smaller step budget than the requested step count.
  max_steps = max(params.max_steps, params.n_steps)

  # Unrecognised values fall through to no progress meter.
  meter_names = {'tqdm': 'TqdmProgressMeter',
                 'text': 'TextProgressMeter'}
  meter_name = meter_names.get(params.progress_meter, 'NoProgressMeter')
  progress_meter = getattr(diffrax, meter_name)()

  sol = diffrax.diffeqsolve(term,
                            solver,
                            t0=t0,
                            t1=t1,
                            dt0=dt0,
                            y0=x0,
                            args=(),
                            saveat=saveat,
                            adjoint=adjoint,
                            stepsize_controller=stepsize_controller,
                            max_steps=max_steps,
                            throw=params.throw,
                            progress_meter=progress_meter)
  if _only_return_yts:
    return sol.ys

  # Imported here rather than at module level (see the module-level placeholder).
  from diffusion_crf.timeseries import TimeSeries
  return TimeSeries(save_times, sol.ys)

################################################################################################################

class SDESolverParams(eqx.Module):
  """Configuration read by `sde_sample`.

  **Fields**:

  - rtol, atol: Tolerances handed to `diffrax.PIDController` when
                `stepsize_controller` is 'pid'.
  - n_steps: The initial step size is `(t1 - t0)/n_steps`.
  - solver: One of 'euler_heun', 'shark', 'reversible_heun', 'ito_milstein',
            'stratonovich_milstein', 'spark'.
  - adjoint: One of 'recursive_checkpoint', 'direct'.
  - stepsize_controller: 'pid' selects `diffrax.PIDController`; any other value
                         (including the default 'none') selects
                         `diffrax.ConstantStepSize`.
  - max_steps: Step budget used for `diffrax.diffeqsolve`.
  - throw: Forwarded to `diffrax.diffeqsolve`.
  - progress_meter: 'tqdm' or 'text' to enable a progress meter; any other
                    value (including the default False) disables it.
  """
  rtol: float = 1e-8
  atol: float = 1e-8
  n_steps: int = 1000
  solver: str = 'shark'
  adjoint: str = 'recursive_checkpoint'
  stepsize_controller: str = 'none'
  max_steps: int = 8192
  throw: bool = True
  # `sde_sample` compares this against the strings 'tqdm' and 'text', so a
  # string is the meaningful value.  The False default is kept so existing
  # callers (and anything comparing against False) keep working.
  progress_meter: Union[str, bool] = False

# Stand-in for `diffusion_crf.sde.sde_base.AbstractSDE`, used below to annotate
# the `sde` parameter of `sde_sample` without importing that module here.
AbstractSDE = plum.ModuleType('diffusion_crf.sde.sde_base', 'AbstractSDE')
def sde_sample(sde: AbstractSDE,
               x0: Array,
               key: PRNGKeyArray,
               save_times: Array,
               params: SDESolverParams = SDESolverParams()) -> TimeSeries:
  """Sample a path of the SDE dx = f(t, x) dt + g(t, x) dW_t.

  The drift is `sde.get_drift(t, x)` and the diffusion is
  `sde.get_diffusion_coefficient(t, x).as_matrix()`.

  Batched inputs are handled by vmapping this function over itself:

  - If `key.ndim > 1`, the call is mapped over the leading axis of `key` with
    all other arguments shared.
  - Otherwise, if `x0.ndim == 2`, `key` is split into `x0.shape[0]` keys and the
    call is mapped over the rows of `x0`, one key per row.

  **Arguments**:

  - sde: The SDE to sample from.
  - x0: The initial condition at time save_times[0].
  - key: The JAX random key (or a batch of keys).
  - save_times: The times at which to save the solution.  The first and last
                entries are used as the integration limits.
  - params: The parameters for the SDE solver.

  **Returns**:

  - TimeSeries built from (save_times, saved states).

  **Raises**:

  - ValueError: If `x0` is neither one- nor two-dimensional, or if
                `params.solver` / `params.adjoint` is not a recognised name.
  """
  # NOTE(review): this rank test matches raw uint32 keys, where a single key
  # has shape (2,).  New-style typed keys are rank 0 when unbatched -- confirm
  # which kind callers pass.
  if key.ndim > 1:
    keys = key
    return jax.vmap(sde_sample, in_axes=(None, None, 0, None, None))(sde,
                                                               x0,
                                                               keys,
                                                               save_times,
                                                               params)
  if x0.ndim == 2:
    # One independent key per row of x0.
    keys = random.split(key, x0.shape[0])
    return jax.vmap(sde_sample, in_axes=(None, 0, 0, None, None))(sde,
                                                            x0,
                                                            keys,
                                                            save_times,
                                                            params)

  if x0.ndim != 1:
    raise ValueError("Can only call this with unbatched data!  We need a unique key for every data point.")

  # Only the inexact-array leaves of the SDE travel through `args`; the rest
  # is closed over and recombined inside the vector fields.
  sde_params, sde_static = eqx.partition(sde, eqx.is_inexact_array)

  def drift_fn(t, xt, args):
    return eqx.combine(args, sde_static).get_drift(t, xt)

  def diffusion_fn(t, xt, args):
    return eqx.combine(args, sde_static).get_diffusion_coefficient(t, xt).as_matrix()

  # Option name -> diffrax class name.  Resolved with getattr so that only the
  # selected class is ever looked up on the diffrax module.
  solver_names = {'euler_heun': 'EulerHeun',
                  'shark': 'ShARK',
                  'reversible_heun': 'ReversibleHeun',
                  'ito_milstein': 'ItoMilstein',
                  'stratonovich_milstein': 'StratonovichMilstein',
                  'spark': 'SPaRK'}
  if params.solver not in solver_names:
    raise ValueError(f"Unknown solver: {params.solver}")
  solver = getattr(diffrax, solver_names[params.solver])()

  adjoint_names = {'recursive_checkpoint': 'RecursiveCheckpointAdjoint',
                   'direct': 'DirectAdjoint'}
  if params.adjoint not in adjoint_names:
    raise ValueError(f"Unknown adjoint: {params.adjoint}")
  adjoint = getattr(diffrax, adjoint_names[params.adjoint])()

  # Anything other than 'pid' (including the default 'none') means fixed steps.
  if params.stepsize_controller == 'pid':
    stepsize_controller = diffrax.PIDController(rtol=params.rtol, atol=params.atol)
  else:
    stepsize_controller = diffrax.ConstantStepSize()

  saveat = diffrax.SaveAt(ts=save_times)
  t0 = save_times[0]
  t1 = save_times[-1]
  dt0 = (t1 - t0)/params.n_steps

  # With fixed steps of size dt0 the solve needs about n_steps steps, so the
  # budget must not be smaller than that (same rule as ode_solve).
  max_steps = max(params.max_steps, params.n_steps)

  # NOTE(review): diffrax documents `UnsafeBrownianPath` as restricted to fixed
  # step sizes -- confirm that stepsize_controller='pid' is usable here, or
  # switch to `diffrax.VirtualBrownianTree` for that case.
  bm = diffrax.UnsafeBrownianPath(shape=x0.shape,
                                   key=key,
                                   levy_area=diffrax.SpaceTimeLevyArea)

  terms = diffrax.MultiTerm(diffrax.ODETerm(drift_fn),
                            diffrax.ControlTerm(diffusion_fn, bm))

  # Unrecognised values (including the default) fall through to no meter.
  if params.progress_meter == 'tqdm':
    progress_meter = diffrax.TqdmProgressMeter()
  elif params.progress_meter == 'text':
    progress_meter = diffrax.TextProgressMeter()
  else:
    progress_meter = diffrax.NoProgressMeter()

  sol = diffrax.diffeqsolve(terms,
                            solver,
                            t0,
                            t1,
                            dt0=dt0,
                            y0=x0,
                            args=sde_params,
                            saveat=saveat,
                            adjoint=adjoint,
                            stepsize_controller=stepsize_controller,
                            max_steps=max_steps,
                            throw=params.throw,
                            progress_meter=progress_meter)

  # Imported here rather than at module level (see the module-level placeholder).
  from diffusion_crf.timeseries import TimeSeries
  return TimeSeries(save_times, sol.ys)
