import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray, Float, Scalar
from jax._src.util import curry
import abc
import jax.tree_util as jtu
from diffusion_crf.base import *
import diffusion_crf.util as util
import warnings
import lineax as lx
from diffusion_crf.gaussian.transition import *
from diffusion_crf.gaussian.dist import StandardGaussian, NaturalGaussian
from diffusion_crf.matrix import *
from diffusion_crf.matrix.matrix_with_inverse import MatrixWithInverse
from diffusion_crf.sde.sde_base import AbstractLinearTimeInvariantSDE, AbstractLinearSDE
from diffusion_crf.matrix.block_2x2 import Block2x2Matrix

# Names exported by `from <this module> import *`.
# NOTE(review): CriticallyDampedLangevinDynamics2 is defined in this file but
# is not listed here -- confirm that leaving it out is intended.
__all__ = ['CriticallyDampedLangevinDynamics', 'TOLD', 'HarmonicOscillator']

################################################################################################################

class CriticallyDampedLangevinDynamics(AbstractLinearTimeInvariantSDE):
  """Critically-damped Langevin dynamics, https://arxiv.org/pdf/2112.07068

  Builds the drift `F` and the diffusion `L` of a linear time-invariant SDE
  from per-coordinate `mass` and `beta` values.  With `gamma = sqrt(4*mass)`
  the drift entries handed to `Diagonal2x2BlockMatrix` are

      [[ 0,      beta/mass       ],
       [-beta,  -gamma*beta/mass ]]

  and the diagonal of `L` is `dim` zeros followed by `sqrt(2*gamma*beta)`.
  """
  F: Diagonal2x2BlockMatrix
  L: DiagonalMatrix

  def __init__(self,
               mass: Union[Float[Array, 'dim'], Scalar],
               beta: Union[Float[Array, 'dim'], Scalar],
               dim: Optional[int] = None):
    """Construct the SDE coefficients.

    Args:
      mass: Scalar or array of masses.  A scalar is broadcast to length `dim`.
      beta: Scalar or array of rates.  A scalar is broadcast to length `dim`.
      dim: Number of coordinates.  Required when `mass` is a scalar; when
        `mass` is an array it is overwritten by `mass.shape[-1]`.

    Raises:
      AssertionError: if `mass` is a scalar and `dim` is None, or if the
        shapes of `mass` and `beta` disagree.
    """
    mass, beta = jnp.array(mass), jnp.array(beta)

    # An array-valued mass fixes the size; a scalar mass needs an explicit dim.
    if mass.ndim != 0:
      dim = mass.shape[-1]
    else:
      assert dim is not None
      mass = jnp.ones(dim)*mass

    # beta is either broadcast to that size or checked against it.
    if beta.ndim != 0:
      assert beta.shape[-1] == dim
    else:
      beta = jnp.ones(dim)*beta

    assert mass.shape ==  beta.shape

    # Friction chosen so that gamma**2 == 4*mass (critical damping).
    damping = jnp.sqrt(4*mass)

    drift_entries = jnp.array([
        [jnp.zeros_like(damping), beta/mass],
        [-beta, -damping*beta/mass],
    ])
    self.F = Diagonal2x2BlockMatrix(drift_entries, tags=TAGS.no_tags)

    # Noise enters only the trailing half of the diagonal; the leading half
    # is filled with zeros by left-padding.
    noise_scale = jnp.sqrt(2*damping*beta)
    n_leading_zeros = noise_scale.shape[-1]
    diffusion_diag = jnp.pad(noise_scale, (n_leading_zeros, 0))
    self.L = DiagonalMatrix(diffusion_diag, tags=TAGS.symmetric_tags)


class CriticallyDampedLangevinDynamics2(AbstractLinearTimeInvariantSDE):
  """Critically-damped Langevin dynamics, https://arxiv.org/pdf/2112.07068

  The drift `F` is assembled with `Block2x2Matrix.from_blocks` out of four
  `DiagonalMatrix` blocks.  With `gamma = sqrt(4*mass)` the blocks are

      top-left:     0            top-right:    beta/mass
      bottom-left: -beta         bottom-right: -gamma*beta/mass

  and the diagonal of `L` is `dim` zeros followed by `sqrt(2*gamma*beta)`.
  """
  F: Block2x2Matrix
  L: DiagonalMatrix

  def __init__(self,
               mass: Union[Float[Array, 'dim'], Scalar],
               beta: Union[Float[Array, 'dim'], Scalar],
               dim: Optional[int] = None):
    """Construct the SDE coefficients.

    Args:
      mass: Scalar or array of masses.  A scalar is broadcast to length `dim`.
      beta: Scalar or array of rates.  A scalar is broadcast to length `dim`.
      dim: Number of coordinates.  Required when `mass` is a scalar; when
        `mass` is an array it is overwritten by `mass.shape[-1]`.

    Raises:
      AssertionError: if `mass` is a scalar and `dim` is None, or if the
        shapes of `mass` and `beta` disagree.
    """
    mass, beta = jnp.array(mass), jnp.array(beta)

    # An array-valued mass fixes the size; a scalar mass needs an explicit dim.
    if mass.ndim != 0:
      dim = mass.shape[-1]
    else:
      assert dim is not None
      mass = jnp.ones(dim)*mass

    # beta is either broadcast to that size or checked against it.
    if beta.ndim != 0:
      assert beta.shape[-1] == dim
    else:
      beta = jnp.ones(dim)*beta

    assert mass.shape ==  beta.shape

    def untagged_diag(values):
      # Every drift block is a diagonal matrix carrying no structural tags.
      return DiagonalMatrix(values, tags=TAGS.no_tags)

    # Friction chosen so that gamma**2 == 4*mass (critical damping).
    damping = jnp.sqrt(4*mass)

    self.F = Block2x2Matrix.from_blocks(
        untagged_diag(jnp.zeros_like(damping)),
        untagged_diag(beta/mass),
        untagged_diag(-beta),
        untagged_diag(-damping*beta/mass),
    )

    # Noise enters only the trailing half of the diagonal; the leading half
    # is filled with zeros by left-padding.
    noise_scale = jnp.sqrt(2*damping*beta)
    n_leading_zeros = noise_scale.shape[-1]
    diffusion_diag = jnp.pad(noise_scale, (n_leading_zeros, 0))
    self.L = DiagonalMatrix(diffusion_diag, tags=TAGS.symmetric_tags)


################################################################################################################

class TOLD(AbstractLinearTimeInvariantSDE):
  """Linear time-invariant SDE from https://arxiv.org/pdf/2409.07697

  The drift entries handed to `Diagonal3x3BlockMatrix` are, per coordinate,

      [[ 0,   1,          0         ],
       [-1,   0,          2*sqrt(2) ],
       [ 0,  -2*sqrt(2), -3*sqrt(3) ]]

  and the diagonal of `L` is `2*dim` zeros followed by
  `3**0.25 * sqrt(6/L)` repeated `dim` times.
  """

  F: Diagonal3x3BlockMatrix
  L: DiagonalMatrix

  def __init__(self,
               L: Scalar = 1,
               *,
               dim: int):
    """Construct the SDE coefficients.

    Args:
      L: Scalar that divides the noise level (the diffusion entries scale as
        `sqrt(6/L)`).  Note that this argument is distinct from the `L`
        field, which holds the resulting diffusion matrix.
      dim: Number of coordinates (keyword-only).
    """
    ones = jnp.ones(dim)
    zeros = jnp.zeros(dim)

    # The two non-trivial drift magnitudes, each used more than once below.
    coupling = 2*jnp.sqrt(2)*ones
    friction = 3*jnp.sqrt(3)*ones

    drift_rows = [
        [zeros, ones, zeros],
        [-ones, zeros, coupling],
        [zeros, -coupling, -friction],
    ]
    self.F = Diagonal3x3BlockMatrix(jnp.array(drift_rows), tags=TAGS.no_tags)

    # Noise enters only the last third of the diagonal; the first two thirds
    # are filled with zeros by left-padding.
    noise_scale = 3**(0.25)*jnp.sqrt(6/L)*ones
    n_leading_zeros = 2*noise_scale.shape[-1]
    diffusion_diag = jnp.pad(noise_scale, (n_leading_zeros, 0))
    self.L = DiagonalMatrix(diffusion_diag, tags=TAGS.symmetric_tags)

################################################################################################################

class HarmonicOscillator(AbstractLinearTimeInvariantSDE):
  """Damped, noise-driven harmonic oscillator as a linear time-invariant SDE.

  The drift entries handed to `Diagonal2x2BlockMatrix` are, per coordinate,

      [[ 0,         1     ],
       [-freq**2,  -coeff ]]

  and the diagonal of `L` is `observation_dim` zeros followed by `sigma`, so
  noise only enters the trailing half of the state.
  """
  F: Diagonal2x2BlockMatrix
  L: DiagonalMatrix

  def __init__(self,
               freq: Union[Float[Array, 'dim'], Scalar], # Frequency of the oscillator (enters the drift as -freq**2)
               coeff: Union[Float[Array, 'dim'], Scalar], # Damping coefficient
               sigma: Union[Float[Array, 'dim'], Scalar], # Noise level
               observation_dim: Optional[int] = None):
    """Construct the SDE coefficients.

    Args:
      freq: Scalar or array.  A scalar is broadcast to `observation_dim`.
      coeff: Scalar or array.  A scalar is broadcast to `observation_dim`.
      sigma: Scalar or array.  A scalar is broadcast to `observation_dim`.
      observation_dim: Number of coordinates.  Required when `freq` is a
        scalar; when `freq` is an array it is overwritten by
        `freq.shape[-1]`.

    Raises:
      AssertionError: if `freq` is a scalar and `observation_dim` is None, or
        if the shapes of `freq`, `coeff` and `sigma` disagree.
    """
    freq = jnp.array(freq)
    coeff = jnp.array(coeff)
    sigma = jnp.array(sigma)
    if freq.ndim == 0:
      assert observation_dim is not None
      freq = jnp.ones(observation_dim)*freq
    else:
      observation_dim = freq.shape[-1]

    if coeff.ndim == 0:
      coeff = jnp.ones(observation_dim)*coeff
    else:
      assert coeff.shape[-1] == observation_dim

    if sigma.ndim == 0:
      sigma = jnp.ones(observation_dim)*sigma
    else:
      # Compare against observation_dim: the previous code referenced an
      # undefined name `dim`, which raised NameError for array-valued sigma
      # unless a module-level `dim` happened to exist.
      assert sigma.shape[-1] == observation_dim

    assert freq.shape ==  coeff.shape == sigma.shape

    zero = jnp.zeros_like(freq)
    one = jnp.ones_like(freq)

    F_elements = jnp.array([[zero, one],
                            [-freq**2, -coeff]])
    self.F = Diagonal2x2BlockMatrix(F_elements, tags=TAGS.no_tags)

    # Left-pad with zeros so that only the trailing half of the diagonal
    # carries noise.
    elements = sigma
    elements = jnp.pad(elements, (elements.shape[-1], 0))
    self.L = DiagonalMatrix(elements, tags=TAGS.symmetric_tags)

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc smoke test: build a prior and two SDEs, then run the shared
  # linear-SDE test routine on each SDE.
  from debug import *
  from diffusion_crf.sde.sde_base import linear_sde_test
  # Imported after the star import so that `plt` is bound to pyplot.
  import matplotlib.pyplot as plt
  from diffusion_crf.gaussian.dist import MixedGaussian
  jax.config.update('jax_enable_x64', True)

  key = random.PRNGKey(0)
  dim = 2

  # Create a prior
  x0 = jnp.ones((dim,))

  # Flip to True to use a dense covariance instead of a diagonal one.
  use_dense_prior_covariance = False
  if use_dense_prior_covariance:
    Sigma_mat = random.normal(key, (dim, dim))
    Sigma0 = Sigma_mat@Sigma_mat.T
    Sigma0 = DenseMatrix(Sigma0, tags=TAGS.symmetric_tags)
  else:
    Sigma_mat = random.normal(key, (dim,))
    Sigma0 = Sigma_mat*Sigma_mat
    Sigma0 = DiagonalMatrix(Sigma0, tags=TAGS.symmetric_tags)
  # Alternative prior:
  # fwd_prior = StandardGaussian(x0, Sigma0).make_deterministic()

  # Built only to exercise construction; it is not passed to the tests below.
  fwd_prior = MixedGaussian(x0, Sigma0.get_inverse())#.make_deterministic()

  # Create the SDE
  sde1 = HarmonicOscillator(jnp.ones(dim)*1, jnp.ones(dim)*0.1, jnp.ones(dim)*0.1, dim)
  # Alternative SDE:
  # sde1 = CriticallyDampedLangevinDynamics(jnp.ones(dim)*1, jnp.ones(dim)*0.1, dim)

  # Result is discarded; the call only checks that it runs.
  sde1.get_transition_distribution(0.1, 0.2)

  sde2 = TOLD(dim=dim)
  print('Starting soft prior tests')
  linear_sde_test(sde1)
  linear_sde_test(sde2)
