import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray, Float, Scalar
from jax._src.util import curry
import abc
import jax.tree_util as jtu
from diffusion_crf.base import *
import diffusion_crf.util as util
import warnings
import lineax as lx
from diffusion_crf.gaussian.transition import *
from diffusion_crf.gaussian.dist import StandardGaussian, NaturalGaussian
from diffusion_crf.matrix import *
from diffusion_crf.matrix.matrix_with_inverse import MatrixWithInverse
from diffusion_crf.sde.sde_base import AbstractLinearTimeInvariantSDE, AbstractLinearSDE

# Names exported by a star-import of this module: the four SDE classes
# defined below.
__all__ = ['BrownianMotion',
           'OrnsteinUhlenbeck',
           'DiagonalDiffusionBrownianMotion',
           'DenseLTISDE']

################################################################################################################

class DenseLTISDE(AbstractLinearTimeInvariantSDE):
  """Linear time-invariant SDE with dense, randomly initialised F and L.

  F is filled with Gaussian noise scaled by 0.001. L is the identity matrix
  plus Gaussian noise of the same scale. Both are stored as `DenseMatrix`
  with `TAGS.no_tags`.
  """

  F: DenseMatrix
  L: DenseMatrix

  def __init__(
    self,
    key: PRNGKeyArray,
    dim: int
  ):
    """Initialise F and L from a single random draw.

    Args:
      key: PRNG key consumed by the one `random.normal` call below.
      dim: Number of rows and columns of both F and L.
    """
    # A single draw of shape (2, dim, dim): slice 0 becomes F, slice 1 is the
    # perturbation added to the identity to form L.
    scaled_noise = random.normal(key, (2, dim, dim))*0.001
    F_elements = scaled_noise[0]
    L_elements = scaled_noise[1] + jnp.eye(dim)

    self.F = DenseMatrix(F_elements, tags=TAGS.no_tags)
    self.L = DenseMatrix(L_elements, tags=TAGS.no_tags)

################################################################################################################

class BrownianMotion(AbstractLinearTimeInvariantSDE):
  """Linear time-invariant SDE with F = 0 and L = sigma * identity.

  The transition between two times has an identity transition matrix, a zero
  offset, and a diagonal covariance `sigma**2 * (t - s)`.
  """

  # Scalar diffusion strength; L is the identity scaled by this value.
  sigma: Scalar
  # All-zero diagonal matrix of size `dim`.
  F: DiagonalMatrix
  # `sigma` times the `dim`-dimensional identity.
  L: DiagonalMatrix

  def __init__(
    self,
    sigma: Scalar,
    dim: int
  ):
    """Build the SDE matrices.

    Args:
      sigma: Scalar that multiplies the identity to form L.
      dim: Size of the diagonal matrices F and L.
    """
    self.sigma = sigma
    self.F = DiagonalMatrix.zeros(dim)
    self.L = DiagonalMatrix.eye(dim)*sigma

  def get_transition_distribution(
      self,
      s: Scalar,
      t: Scalar
  ) -> GaussianTransition:
    """Return the closed-form transition from time `s` to time `t`.

    Default implementation has same performance, this is just here for clarity.

    Args:
      s: Scalar start time.
      t: Scalar end time.

    Returns:
      `GaussianTransition(A, u, Sigma)` where A is the identity (paired with
      its inverse), u is a zero vector of length `self.dim`, and Sigma is
      diagonal with entries `sigma**2 * (t - s)`. When the mean absolute
      entry of Sigma is below 1e-8, Sigma is replaced by `Sigma.set_zero()`.

    Raises:
      AssertionError: If `s` or `t` is not a scalar (shape `()`).
    """
    s = jnp.array(s)
    t = jnp.array(t)
    assert t.shape == s.shape == ()

    dt = t - s

    # The transition matrix is the identity, so its inverse transpose is too.
    A = DiagonalMatrix.eye(self.dim)
    AinvT = DiagonalMatrix.eye(self.dim)
    A = MatrixWithInverse(A, AinvT.T)

    Sigma = jnp.ones(self.dim)*self.sigma**2*dt
    Sigma = DiagonalMatrix(Sigma, tags=TAGS.symmetric_tags)

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    # This must happen before the transition is constructed; otherwise the
    # returned object keeps the un-zeroed covariance and the check is dead code.
    # NOTE(review): `set_zero()` presumably marks the matrix as zero through
    # its tags -- confirm against diffusion_crf.matrix.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)

    u = jnp.zeros(self.dim)
    return GaussianTransition(A, u, Sigma)

################################################################################################################

class OrnsteinUhlenbeck(AbstractLinearTimeInvariantSDE):
  """Linear time-invariant SDE with F = -lambda_ * identity and L = sigma * identity.

  The transition over an interval `dt = t - s` scales the state by
  `exp(-lambda_ * dt)` and has diagonal covariance
  `0.5 * sigma**2 / lambda_ * (1 - exp(-2 * lambda_ * dt))`.
  """

  # Scalar that multiplies the identity to form L.
  sigma: Scalar
  # Scalar rate; F is the identity scaled by its negation.
  lambda_: Scalar

  F: DiagonalMatrix
  L: DiagonalMatrix

  def __init__(
    self,
    sigma: Scalar,
    lambda_: Scalar,
    dim: int
  ):
    """Store the scalars and build the diagonal F and L.

    Args:
      sigma: Scalar multiplying the identity in L.
      lambda_: Scalar whose negation multiplies the identity in F.
      dim: Size of the diagonal matrices F and L.
    """
    self.sigma = sigma
    self.lambda_ = lambda_

    identity = DiagonalMatrix.eye(dim)
    self.F = identity*(-lambda_)
    self.L = identity*sigma

  def get_transition_distribution(
    self,
    s: Scalar,
    t: Scalar
  ) -> GaussianTransition:
    """Return the closed-form transition from time `s` to time `t`.

    Default implementation has same performance, this is just here for clarity.

    Args:
      s: Scalar start time.
      t: Scalar end time.

    Returns:
      `GaussianTransition(A, u, Sigma)` with A the identity scaled by
      `exp(-lambda_ * (t - s))` (paired with its inverse), u a zero vector of
      length `self.dim`, and Sigma the diagonal covariance described on the
      class. When the mean absolute entry of Sigma is below 1e-8, Sigma is
      replaced by `Sigma.set_zero()`.
    """
    s, t = jnp.array(s), jnp.array(t)
    assert t.shape == s.shape == ()

    elapsed = t - s
    rate = self.lambda_

    # Forward scaling and its reciprocal, both as diagonal matrices.
    forward = DiagonalMatrix.eye(self.dim)*jnp.exp(-rate*elapsed)
    backward = DiagonalMatrix.eye(self.dim)*jnp.exp(rate*elapsed)
    A = MatrixWithInverse(forward, backward.T)

    # Same multiplication order as the one-line formula: scalar * identity,
    # then the time-dependent factor.
    variance_scale = 0.5*(self.sigma**2/rate)
    time_factor = 1 - jnp.exp(-2*rate*elapsed)
    Sigma = variance_scale*DiagonalMatrix.eye(self.dim)*time_factor

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    nearly_zero = jnp.abs(Sigma.elements).mean() < 1e-8
    Sigma = util.where(nearly_zero, Sigma.set_zero(), Sigma)

    u = jnp.zeros(self.dim)
    return GaussianTransition(A, u, Sigma)

################################################################################################################

class DiagonalDiffusionBrownianMotion(AbstractLinearTimeInvariantSDE):
  """Linear time-invariant SDE with F = 0 and a caller-supplied diagonal L."""

  F: DiagonalMatrix
  L: DiagonalMatrix

  def __init__(self, sigma_diag: Float[Array, 'D']):
    """Build F and L from the diagonal of L.

    Args:
      sigma_diag: Diagonal entries of L; its leading axis length sets the
        size of the all-zero F.
    """
    self.L = DiagonalMatrix(sigma_diag, tags=TAGS.symmetric_tags)
    self.F = DiagonalMatrix.zeros(sigma_diag.shape[0])

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc smoke test: build one instance of every SDE in this module, check
  # that gradients flow through a transition log-probability, then run
  # `linear_sde_test` on each instance.
  import matplotlib.pyplot as plt
  # NOTE(review): `debug` is not defined in this file; presumably a local
  # helper module -- confirm it is still needed.
  from debug import *
  from diffusion_crf.sde.sde_base import linear_sde_test
  from diffusion_crf.gaussian.dist import MixedGaussian
  jax.config.update('jax_enable_x64', True)

  key = random.PRNGKey(0)
  # Use an independent subkey for each random draw; reusing one key would
  # make the prior, the SDE parameters and the sample points correlated.
  prior_key, sde_key, sample_key = random.split(key, 3)
  sigma = 0.1
  dim = 2

  # Create a prior
  x0 = jnp.ones((dim,))

  # Flip to True to use a dense covariance for the prior instead of a
  # diagonal one.
  use_dense_prior = False
  if use_dense_prior:
    Sigma_mat = random.normal(prior_key, (dim, dim))
    Sigma0 = Sigma_mat@Sigma_mat.T
    Sigma0 = DenseMatrix(Sigma0, tags=TAGS.symmetric_tags)
  else:
    Sigma_mat = random.normal(prior_key, (dim,))
    Sigma0 = Sigma_mat*Sigma_mat
    Sigma0 = DiagonalMatrix(Sigma0, tags=TAGS.symmetric_tags)
  # fwd_prior = StandardGaussian(x0, Sigma0).make_deterministic()

  # Create the SDE
  fwd_prior = MixedGaussian(x0, Sigma0.get_inverse())#.make_deterministic()
  sde1 = DenseLTISDE(sde_key, dim)

  s, t = jnp.array(0.0), jnp.array(1.0)

  def get_log_prob(sde, s, t, x, y):
    """Log-probability of `y` under the s->t transition conditioned on `x`."""
    transition = sde.get_transition_distribution(s, t)
    return transition.condition_on_x(x).log_prob(y)

  # Differentiating with respect to the SDE's parameters checks that the
  # transition is differentiable end to end. The result is not inspected
  # further; the interactive breakpoint that used to follow was removed so
  # the tests below run unattended.
  x, y = random.normal(sample_key, (2, dim))
  grads = eqx.filter_grad(get_log_prob)(sde1, s, t, x, y)

  sde2 = BrownianMotion(sigma, dim)
  sde3 = OrnsteinUhlenbeck(sigma, 0.2, dim)
  sde4 = DiagonalDiffusionBrownianMotion(jnp.ones(dim)*sigma)
  print('Starting soft prior tests')
  for sde in (sde1, sde2, sde3, sde4):
    linear_sde_test(sde)
