import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray, Float, Scalar
from jax._src.util import curry
import abc
import jax.tree_util as jtu
from diffusion_crf.base import *
import diffusion_crf.util as util
import warnings
import lineax as lx
from diffusion_crf.gaussian.transition import *
from diffusion_crf.gaussian.dist import StandardGaussian, NaturalGaussian
from diffusion_crf.matrix import *
from diffusion_crf.matrix.matrix_with_inverse import MatrixWithInverse
from diffusion_crf.sde.sde_base import AbstractLinearSDE

__all__ = ['VariancePreserving']

################################################################################################################

class VariancePreserving(AbstractLinearSDE, abc.ABC):
  """Variance-preserving (VP) linear SDE with a linear beta schedule.

  The schedule is beta(t) = beta_min + t*(beta_max - beta_min), and the
  drift/diffusion coefficients returned by `get_params` are
  F = -0.5*beta(t)*I, u = 0, L = sqrt(beta(t))*I, i.e. (assuming the
  AbstractLinearSDE convention dx = (F x + u) dt + L dW -- confirm in sde_base)
  dx_t = -0.5*beta(t)*x_t dt + sqrt(beta(t)) dW_t.

  Attributes:
    beta_min: beta at t = 0. A 0-d array for a single SDE; any leading axes
      are treated as batch axes by `batch_size`.
    beta_max: beta at t = 1.
    dim: Dimensionality of the state; sizes the identity and zero offsets.
  """

  beta_min: Scalar
  beta_max: Scalar
  dim: int

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape implied by `beta_min`.

    Returns None when `beta_min` is 0-d, an int when it has a single axis,
    and the full shape tuple otherwise.
    """
    if self.beta_min.ndim == 0:
      return None
    elif self.beta_min.ndim == 1:
      return self.beta_min.shape[0]
    else:
      # beta_min holds one scalar per SDE, so every axis is a batch axis.
      # (Dropping the trailing axis with shape[:-1] is the rule for
      # vector-valued parameters, whose last axis is a feature axis.)
      return self.beta_min.shape

  def beta(self, t: Scalar) -> Scalar:
    """Noise rate beta(t) = beta_min + t*(beta_max - beta_min)."""
    return self.beta_min + t*(self.beta_max - self.beta_min)

  def T(self, t: Scalar) -> Scalar:
    """Integrated noise rate, T(t) = int_0^t beta(tau) dtau."""
    return t*self.beta_min + 0.5*t**2*(self.beta_max - self.beta_min)

  def get_params(self, t: Scalar) -> Tuple[AbstractMatrix,
                                           Float[Array, 'D'],
                                           AbstractMatrix]:
    """Linear SDE coefficients at time t.

    Returns:
      (F, u, L) with F = -0.5*beta(t)*I, u = zeros(dim), L = sqrt(beta(t))*I,
      where I is the `dim`-dimensional identity as a DiagonalMatrix.
    """
    beta = self.beta(t)
    I = DiagonalMatrix.eye(self.dim)
    F = -0.5*beta*I
    u = jnp.zeros((self.dim,))
    L = jnp.sqrt(beta)*I
    return F, u, L

  def get_transition_distribution(self,
                                  s: Scalar,
                                  t: Scalar) -> GaussianTransition:
    """Closed-form transition of the SDE from time s to time t.

    With dT = T(t) - T(s), builds GaussianTransition(A, u, Sigma) where
      A     = exp(-0.5*dT) * I
      u     = 0
      Sigma = (1 - exp(-dT)) * I
    (presumably read as x_t | x_s ~ N(A x_s + u, Sigma) -- confirm against
    GaussianTransition).
    """
    Tt = self.T(t)
    Ts = self.T(s)
    dT = Tt - Ts
    alpha = jnp.exp(-0.5*dT)
    I = DiagonalMatrix.eye(self.dim)

    A = alpha*I
    u = jnp.zeros((self.dim,))
    # -expm1(-dT) equals 1 - exp(-dT) but avoids catastrophic cancellation
    # when s and t are close (dT ~ 0), where the naive form loses most digits.
    Sigma = -jnp.expm1(-dT)*I

    return GaussianTransition(A, u, Sigma)

################################################################################################################

if __name__ == '__main__':
  # Sanity check: integrate the SDE numerically from t0 to t1 and compare the
  # resulting samples (via W2 distance) against exact samples drawn from the
  # closed-form transition conditioned on the same starting point x0.
  from diffusion_crf.sde.ode_sde_solve import SDESolverParams, sde_sample
  jax.config.update('jax_enable_x64', True)

  n_samples = 4000
  dim = 2
  t0 = 0.5
  t1 = 1.0

  # Separate keys for the initial point, the SDE solver and the exact sampler:
  # reusing one key for several draws would correlate them.
  key = random.PRNGKey(0)
  key_x0, key_sde, key_exact = random.split(key, 3)

  vp = VariancePreserving(beta_min=jnp.array(0.1), beta_max=jnp.array(1.0), dim=dim)
  transition = vp.get_transition_distribution(t0, t1)

  x0 = random.normal(key_x0, (dim,))
  save_times = jnp.linspace(t0, t1, 10)

  sde_solver_params = SDESolverParams(solver='euler_heun',
                                      adjoint='recursive_checkpoint',
                                      n_steps=3000,
                                      max_steps=40000,
                                      stepsize_controller='none',
                                      progress_meter='tqdm')

  def get_samples(key):
    """Integrate the SDE from x0 and return the last saved state.

    Presumably this is the state at save_times[-1] == t1; confirm against
    sde_sample.
    """
    out = sde_sample(vp,
                     x0,
                     key,
                     save_times,
                     params=sde_solver_params)
    return out[-1]

  samples1 = jax.vmap(get_samples)(random.split(key_sde, n_samples))
  dist1 = util.empirical_dist(samples1)

  exact_transition = transition.condition_on_x(x0)
  samples2 = jax.vmap(exact_transition.sample)(random.split(key_exact, n_samples))
  dist2 = util.empirical_dist(samples2)

  w2_dist = util.w2_distance(dist1, dist2)
  print(f'W2 distance between samples and transition: {w2_dist}')
