import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, List
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
import pandas as pd
import numpy as np
from diffusion_crf import TimeSeries
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.matrix import DiagonalMatrix
from Models.models.ode_sde_simulation import AbstractSolverParams, ODESolverParams, ode_solve, SDESolverParams, sde_sample, DiffraxSolverState

class AbstractDynamicalSystemSDE(AbstractSDE):
  """Common base for the dynamical systems defined in this module.

  The concrete subclasses in this file supply ``get_flow`` (the vector
  field) and ``sample_x0`` (a random initial state).  ``simulate`` passes
  the instance itself to ``ode_solve``.
  """

  def get_drift(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Stub: performs no computation and returns ``None``."""
    return None

  def get_diffusion_coefficient(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Stub: performs no computation and returns ``None``."""
    return None

  @property
  def batch_size(self) -> int:
    """Always ``None``; batching is deliberately not tracked for these systems."""
    return None

  def simulate(
    self,
    x0: Float[Array, 'D'],
    save_times: Float[Array, 'T'],
    key: PRNGKeyArray,
  ) -> TimeSeries:
    """Solve the system from ``x0`` via ``ode_solve`` and return its result.

    Args:
      x0: Initial state handed to ``ode_solve``.
      save_times: Times handed to ``ode_solve`` as ``save_times``.
      key: Accepted for interface compatibility; not used by this method.

    Returns:
      Whatever ``ode_solve`` returns (annotated as a ``TimeSeries``).
    """
    # Fixed solver configuration used for every system in this module.
    solver_kwargs = dict(
      rtol=1e-3,
      atol=1e-3,
      solver='dopri5',
      adjoint='recursive_checkpoint',
      stepsize_controller='pid',
      max_steps=20_000,
      throw=True,
      progress_meter=None,
    )
    return ode_solve(
      self,
      x0=x0,
      save_times=save_times,
      params=ODESolverParams(**solver_kwargs),
    )

  @abstractmethod
  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Draw a random initial state; implemented by each concrete system."""


class LorenzSDE(AbstractDynamicalSystemSDE):
  """Three-dimensional Lorenz system with parameters sigma, rho and beta."""
  # Coefficients of the vector field evaluated in ``get_flow``.
  sigma: Scalar
  rho: Scalar
  beta: Scalar

  def __init__(self, sigma, rho, beta):
    """Store the three system coefficients."""
    self.sigma, self.rho, self.beta = sigma, rho, beta

  def get_flow(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the time derivative of the 3-component state ``x``.

    ``t`` is unused: the vector field does not depend on time.
    """
    u, v, w = x
    return jnp.array([
      self.sigma * (v - u),
      u * (self.rho - w) - v,
      u * v - self.beta * w,
    ])

  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Sample an initial state uniformly from ``[0, 10)`` in each of 3 components."""
    return random.uniform(key, shape=(3,), minval=0.0, maxval=10.0)


class FitzHughNagumoSDE(AbstractDynamicalSystemSDE):
  """Two-dimensional FitzHugh-Nagumo system with parameters a, b, tau and I."""
  # Coefficients of the vector field evaluated in ``get_flow``.
  a: Scalar
  b: Scalar
  tau: Scalar
  I: Scalar

  def __init__(self, a, b, tau, I):
    """Store the four system coefficients."""
    self.a, self.b = a, b
    self.tau, self.I = tau, I

  def get_flow(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the time derivative of the 2-component state ``x``.

    ``t`` is unused: the vector field does not depend on time.
    """
    v, w = x
    return jnp.array([
      v - (v**3) / 3 - w + self.I,
      (v + self.a - self.b * w) / self.tau,
    ])

  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Sample an initial state uniformly from ``[-2, 2)`` in each of 2 components."""
    return random.uniform(key, shape=(2,), minval=-2.0, maxval=2.0)

class LotkaVolterraSDE(AbstractDynamicalSystemSDE):
  """Two-dimensional Lotka-Volterra system.

  In this implementation ``alpha`` and ``beta`` drive the first component
  (linear growth and interaction loss), while ``delta`` is the linear decay
  and ``gamma`` the interaction gain of the second component.
  """
  # Coefficients of the vector field evaluated in ``get_flow``.
  alpha: Scalar
  beta: Scalar
  gamma: Scalar
  delta: Scalar

  def __init__(self, alpha, beta, gamma, delta):
    """Store the four system coefficients."""
    self.alpha, self.beta = alpha, beta
    self.gamma, self.delta = gamma, delta

  def get_flow(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the time derivative of the 2-component state ``x``.

    ``t`` is unused: the vector field does not depend on time.
    """
    first, second = x
    return jnp.array([
      self.alpha * first - self.beta * first * second,
      -self.delta * second + self.gamma * first * second,
    ])

  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Sample an initial state uniformly from ``[0, 5)`` in each of 2 components."""
    return random.uniform(key, shape=(2,), minval=0.0, maxval=5.0)

class BrusselatorSDE(AbstractDynamicalSystemSDE):
  """Two-dimensional Brusselator system with parameters A and B."""
  # Coefficients of the vector field evaluated in ``get_flow``.
  A: Scalar
  B: Scalar

  def __init__(self, A, B):
    """Store the two system coefficients."""
    self.A, self.B = A, B

  def get_flow(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the time derivative of the 2-component state ``x``.

    ``t`` is unused: the vector field does not depend on time.
    """
    u, v = x
    # Cubic interaction term shared (with opposite sign) by both equations.
    cubic = u**2 * v
    return jnp.array([
      self.A + cubic - (self.B + 1) * u,
      self.B * u - cubic,
    ])

  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Sample an initial state uniformly from ``[0, 2)`` in each of 2 components."""
    return random.uniform(key, shape=(2,), minval=0.0, maxval=2.0)

class VanDerPolSDE(AbstractDynamicalSystemSDE):
  """Two-dimensional Van der Pol oscillator with damping parameter mu."""
  # Coefficient of the nonlinear damping term in ``get_flow``.
  mu: Scalar

  def __init__(self, mu):
    """Store the damping coefficient."""
    self.mu = mu

  def get_flow(self, t: Scalar, x: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the time derivative of the 2-component state ``x``.

    ``t`` is unused: the vector field does not depend on time.
    """
    position, velocity = x
    acceleration = self.mu * (1 - position**2) * velocity - position
    return jnp.array([velocity, acceleration])

  def sample_x0(self, key: PRNGKeyArray) -> Float[Array, 'D']:
    """Sample an initial state uniformly from ``[-2, 2)`` in each of 2 components."""
    return random.uniform(key, shape=(2,), minval=-2.0, maxval=2.0)


def create_raw_physics_data(name: str,
                            key: PRNGKeyArray,
                            n_points: int = 225,
                            n_trajectories: int = 2400,
                            save_folder: str = 'Data/dynamical_system/physics_data') -> str:
  """Simulate trajectories of a named dynamical system and save them to disk.

  Writes ``{name}_data.npz`` (arrays ``ts`` and ``yts``) and
  ``{name}_trajectories.png`` into ``save_folder``, creating the folder if
  needed, and prints the path of the ``.npz`` file.

  Args:
    name: One of ``'lorenz'``, ``'fitzhugh'``, ``'lotka'``,
      ``'brusselator'`` or ``'van_der_pol'``.
    key: PRNG key; split into one key per trajectory.
    n_points: Number of evenly spaced save times on the system's time span.
    n_trajectories: Number of initial conditions / trajectories to simulate.
    save_folder: Output directory.

  Returns:
    Path of the written ``.npz`` file.

  Raises:
    ValueError: If ``name`` is not one of the supported systems.
  """
  # Imported here rather than relying on a module-level ``plt``: the file
  # only imports matplotlib under its ``__main__`` guard, so calling this
  # function from an importing module used to fail with NameError.
  import matplotlib.pyplot as plt

  # Pick the system and its simulation time span.
  if name == 'lorenz':
    sde = LorenzSDE(sigma=10.0, rho=28.0, beta=8/3)
    t0, t1 = 0.0, 2.0
  elif name == 'fitzhugh':
    sde = FitzHughNagumoSDE(a=0.7, b=0.8, tau=12.5, I=0.5)
    t0, t1 = 0.0, 10.0
  elif name == 'lotka':
    sde = LotkaVolterraSDE(alpha=1.3, beta=0.9, gamma=0.8, delta=1.8)
    t0, t1 = 0.0, 20.0
  elif name == 'brusselator':
    sde = BrusselatorSDE(A=1.0, B=3.0)
    t0, t1 = 0.0, 20.0
  elif name == 'van_der_pol':
    sde = VanDerPolSDE(mu=0.1)
    t0, t1 = 0.0, 20.0
  else:
    raise ValueError(f'Unknown dynamical system: {name}')

  # One key per trajectory, used to draw the initial conditions.
  keys = random.split(key, n_trajectories)
  x0s = jax.vmap(sde.sample_x0)(keys)

  # Simulate the trajectories on a shared time grid.  The same keys are
  # passed on to ``simulate``; the ``simulate`` defined in this file does
  # not consume its key, so no randomness is reused in practice.
  ts = jnp.linspace(t0, t1, n_points)
  trajectories: TimeSeries = jax.vmap(sde.simulate, in_axes=(0, None, 0))(x0s, ts, keys)

  # Convert to NumPy for saving.
  ts = np.array(trajectories.ts)
  yts = np.array(trajectories.yts)

  # Ensure that the save_folder exists
  os.makedirs(save_folder, exist_ok=True)
  save_path = os.path.join(save_folder, f'{name}_data.npz')
  print(f"Saving data to: {save_path}")
  np.savez(save_path, ts=ts, yts=yts)

  # Plot the first 100 entries (presumably trajectories -- the leading axis
  # produced by the vmap above) and save the current figure.
  trajectories[:100].plot_series(title=f'{name} trajectories', show_plot=False)
  plt.savefig(os.path.join(save_folder, f'{name}_trajectories.png'))
  # Release the figure so repeated calls do not accumulate open figures.
  plt.close()

  return save_path

def get_raw_physics_data(name: str,
                         noise_std: Optional[float] = None,
                         key: Optional[PRNGKeyArray] = None,
                         save_folder: str = 'Data/dynamical_system/physics_data') -> Tuple[TimeSeries, TimeSeries]:
  """Load saved trajectories for ``name`` and optionally add Gaussian noise.

  Reads ``{name}_data.npz`` from ``save_folder`` (arrays ``ts`` and ``yts``).

  Args:
    name: System name used in the file name.
    noise_std: If not ``None``, standard deviation of the additive normal
      noise applied to the values.
    key: PRNG key for the noise; required when ``noise_std`` is given.
    save_folder: Directory containing the ``.npz`` file.

  Returns:
    ``(full_data_series, full_original_series)``: the (possibly noisy)
    series and the series as loaded from disk.

  Raises:
    AssertionError: If ``noise_std`` is given without ``key``.
  """
  save_path = os.path.join(save_folder, f'{name}_data.npz')
  # ``np.load`` on an .npz returns a lazily-read archive that keeps the file
  # open; read both arrays inside the context manager so it is closed.
  with np.load(save_path) as data:
    ts = jnp.array(data['ts'])
    yts = jnp.array(data['yts'])
  full_original_series = TimeSeries(ts=ts, yts=yts)

  # Read the fields back from the constructed series so the noisy series is
  # built from exactly what ``TimeSeries`` stores.
  ts = full_original_series.ts
  samples = full_original_series.yts
  # Add noise to the data if noise_std is not None
  if noise_std is not None:
    assert key is not None, 'key must be provided if noise_std is not None'
    noisy_samples = samples + noise_std*random.normal(key, samples.shape)
  else:
    noisy_samples = samples
  full_data_series = TimeSeries(ts=ts, yts=noisy_samples)

  return full_data_series, full_original_series

if __name__ == '__main__':
  import matplotlib.pyplot as plt
  from debug import *
  import tqdm
  # Use 64-bit floats for the simulations.
  jax.config.update('jax_enable_x64', True)

  key = random.PRNGKey(0)
  data_folder = 'Data/dynamical_system/physics_data'

  # Generate and save a dataset for every supported system, all from the
  # same key and with the same grid size and trajectory count.
  for system_name in ('lorenz', 'fitzhugh', 'lotka', 'brusselator', 'van_der_pol'):
    _ = create_raw_physics_data(name=system_name,
                                key=key,
                                n_points=225,
                                n_trajectories=2400,
                                save_folder=data_folder)

  # Load the data back as a check that the written file is readable.
  lorenz_data = get_raw_physics_data(name='lorenz',
                                     save_folder=data_folder)
