import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray, Float, Scalar
from jax._src.util import curry
import abc
import jax.tree_util as jtu
from diffusion_crf.gaussian.dist import StandardGaussian
from diffusion_crf.gaussian.transition import *
from diffusion_crf.matrix import *
from diffusion_crf.gaussian.dist import *
from plum import dispatch
from diffusion_crf.sde.sde_base import AbstractLinearTimeInvariantSDE
import diffusion_crf.util as util

__all__ = ['TrackingModel',
           'HigherOrderTrackingModel']

################################################################################################################

class TrackingModel(AbstractLinearTimeInvariantSDE):
  """Second order tracking SDE over a 2 dimensional position.

  The state has 4 entries and the SDE matrices, written densely, are

  F = [[0, 0, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 0],
       [0, 0, 0, 0]]

  L = [[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, s, 0],
       [0, 0, 0, s]]

  where s is `sigma`.  F is stored as a 2x2 grid of diagonal blocks, so its
  `elements` array has shape (2, 2, 2) with elements[i, j] holding the
  diagonal of block (i, j).
  """

  sigma: Scalar
  F: Diagonal2x2BlockMatrix
  L: DiagonalMatrix

  def __init__(self, sigma: Scalar):
    """Store `sigma` and build the F and L matrices shown in the class docstring."""
    self.sigma = sigma

    # The same 2x2 pattern applies to both dimensions, so write it once
    # and replicate it along the trailing (per-dimension) axis.
    F_block = jnp.array([[0.0, 1.0],
                         [0.0, 0.0]])
    self.F = Diagonal2x2BlockMatrix(jnp.stack([F_block, F_block], axis=-1),
                                    tags=TAGS.no_tags)

    # Noise only enters through the last two state entries.
    L_diagonal = jnp.array([0.0, 0.0, self.sigma, self.sigma])
    self.L = DiagonalMatrix(L_diagonal, tags=TAGS.symmetric_tags)

  def get_transition_distribution(self,
                           s: Scalar,
                           t: Scalar) -> GaussianTransition:
    """Return the closed form Gaussian transition from time `s` to time `t`.

    Both times must be scalars.  With dt = t - s, each per-dimension block is

      A     = [[1, dt],      Sigma = sigma**2 * [[dt**3/3, dt**2/2],
               [0,  1]]                          [dt**2/2, dt     ]]

    and the transition has a zero offset.
    """
    s, t = jnp.array(s), jnp.array(t)
    assert t.shape == s.shape == ()
    dt = t - s

    # Forward transition matrix.
    A_block = jnp.array([[1, dt],
                         [0, 1]])
    A_forward = Diagonal2x2BlockMatrix(jnp.stack([A_block, A_block], axis=-1),
                                       tags=TAGS.no_tags)

    # Transpose of the inverse of A, written out explicitly so that the
    # inverse never has to be computed numerically.
    AinvT_block = jnp.array([[1, 0],
                             [-dt, 1]])
    AinvT = Diagonal2x2BlockMatrix(jnp.stack([AinvT_block, AinvT_block], axis=-1),
                                   tags=TAGS.no_tags)
    A = MatrixWithInverse(A_forward, AinvT.T)

    # Transition covariance.
    cubic = dt**3/3
    quadratic = dt**2/2
    linear = dt
    Sigma_block = jnp.array([[cubic, quadratic],
                             [quadratic, linear]])
    Sigma_elements = jnp.stack([Sigma_block, Sigma_block], axis=-1)*self.sigma**2
    Sigma = Diagonal2x2BlockMatrix(Sigma_elements, tags=TAGS.symmetric_tags)

    u = jnp.zeros((A.shape[-2],))

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)

    return GaussianTransition(A, u, Sigma)

################################################################################################################

class HigherOrderTrackingModel(AbstractLinearTimeInvariantSDE):
  """A higher order tracking model.  This has latent variables corresponding
  to different derivatives of position.  For example, if order=3, then the
  latent variables correspond to position, velocity and acceleration.  There
  will be no noise added to position and less noise added to the lower derivatives
  than the higher derivatives.

  Args:
    sigma: Either a scalar or an array of shape (order - 1,).
      * scalar: derivative k (k = 0 is position) receives noise
        sigma*(k/(order - 1))**3, so position gets none and the highest
        derivative gets exactly sigma.
      * array: entry k - 1 is the noise of derivative k; position gets none.
    position_dim: Number of position coordinates.  The state has
      order*position_dim entries, grouped by derivative.
    order: Number of derivatives tracked (including position).  Must be > 1.
  """

  # The concrete type depends on `order`: 2x2 blocks for order 2, 3x3 blocks
  # for order 3 and a dense matrix for anything higher.
  F: Union[Diagonal2x2BlockMatrix, Diagonal3x3BlockMatrix, DenseMatrix]
  L: DiagonalMatrix
  position_dim: int = eqx.field(static=True)
  order: int = eqx.field(static=True)

  def __init__(self,
               sigma: Scalar,
               position_dim: int,
               order: int):
    assert order > 1, "order must be at least 2"
    self.position_dim = position_dim
    self.order = order

    # F maps every derivative to the next one, i.e. block (k, k + 1) is the
    # identity and everything else is zero.
    if order == 2:
      elements = jnp.zeros((2, 2, self.position_dim))
      elements = elements.at[0,1,:].set(1.0)
      self.F = Diagonal2x2BlockMatrix(elements, tags=TAGS.no_tags)
    elif order == 3:
      elements = jnp.zeros((3, 3, self.position_dim))
      elements = elements.at[0,1,:].set(1.0)
      elements = elements.at[1,2,:].set(1.0)
      self.F = Diagonal3x3BlockMatrix(elements, tags=TAGS.no_tags)
    else:
      # No specialized block type here, so fall back to a dense matrix with
      # ones on the position_dim-th super diagonal.  Only built in this
      # branch because the specialized orders never use it.
      state_dim = self.order*self.position_dim
      dense_F = jnp.eye(state_dim, k=self.position_dim)
      self.F = DenseMatrix(dense_F, tags=TAGS.no_tags)

    # Construct the diffusion matrix.  Place less noise on lower order terms
    sigma = jnp.array(sigma)
    if sigma.ndim == 0:
      factor = 3
      # Same schedule as (1/sigma**(factor - 1))*linspace(0, sigma)**factor,
      # but written without the division so that sigma == 0 produces zeros
      # instead of inf*0 = NaN.
      order_noise = sigma*jnp.linspace(0.0, 1.0, self.order)**factor
      order_noise = jnp.repeat(order_noise, self.position_dim)
      self.L = DiagonalMatrix(order_noise, tags=TAGS.no_tags)
    else:
      assert sigma.shape == ((self.order - 1),), \
        "sigma must be a scalar or have shape (order - 1,)"
      # Prepend a zero so that position receives no noise.
      sigma = jnp.pad(sigma, (1, 0))
      sigma = jnp.repeat(sigma, self.position_dim)
      self.L = DiagonalMatrix(sigma, tags=TAGS.no_tags)

  def get_transition_distribution(self,
                                  s: Scalar,
                                  t: Scalar) -> GaussianTransition:
    """We don't have an efficient matrix exponential for DiagonalBlockMatrix yet so for
    the most numerical stability, we explicitly compute the matrix exponential

    Args:
      s: Scalar start time.
      t: Scalar end time.

    Returns:
      The Gaussian transition from time `s` to time `t`.  Orders 2 and 3 use
      closed form expressions; any other order defers to the parent class.
    """
    s = jnp.array(s)
    t = jnp.array(t)
    assert t.shape == s.shape == ()
    dt = t - s

    if self.order == 2:
      return self._second_order_transition(dt)
    if self.order == 3:
      return self._third_order_transition(dt)
    return super().get_transition_distribution(s, t)

  def _second_order_transition(self, dt: Scalar) -> GaussianTransition:
    """Closed form transition over a time step `dt` when order == 2."""
    block_shape = (2, 2, self.position_dim)

    # A = [[1, dt], [0, 1]] for every position coordinate.
    A_elements = jnp.zeros(block_shape)
    A_elements = A_elements.at[0,0,:].set(1.0)
    A_elements = A_elements.at[0,1,:].set(dt)
    A_elements = A_elements.at[1,1,:].set(1.0)
    A = Diagonal2x2BlockMatrix(A_elements, tags=TAGS.no_tags)

    # Transpose of the inverse of A, written out explicitly so that the
    # inverse never has to be computed numerically.
    AinvT_elements = jnp.zeros(block_shape)
    AinvT_elements = AinvT_elements.at[0,0,:].set(1.0)
    AinvT_elements = AinvT_elements.at[1,0,:].set(-dt)
    AinvT_elements = AinvT_elements.at[1,1,:].set(1.0)
    AinvT = Diagonal2x2BlockMatrix(AinvT_elements, tags=TAGS.no_tags)
    A = MatrixWithInverse(A, AinvT.T)

    # Squared noise on the first derivative (the second half of diag(L)).
    sigma2 = self.L.elements[self.position_dim:]**2
    assert sigma2.shape == (self.position_dim,)
    Sigma_elements = jnp.zeros(block_shape)
    Sigma_elements = Sigma_elements.at[0,0,:].set(sigma2*dt**3/3)
    Sigma_elements = Sigma_elements.at[0,1,:].set(sigma2*dt**2/2)
    Sigma_elements = Sigma_elements.at[1,0,:].set(sigma2*dt**2/2)
    Sigma_elements = Sigma_elements.at[1,1,:].set(sigma2*dt)
    Sigma = Diagonal2x2BlockMatrix(Sigma_elements, tags=TAGS.symmetric_tags)

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)

    u = jnp.zeros((A.shape[-2],))
    return GaussianTransition(A, u, Sigma)

  def _third_order_transition(self, dt: Scalar) -> GaussianTransition:
    """Closed form transition over a time step `dt` when order == 3.

    The entries are the Taylor series terms of expm(X*dt) for the block matrix

      X = [[F, L L^T],
           [0,  -F^T]]

    whose exponential has blocks [[A, Sigma A^{-T}], [0, A^{-T}]].  F**3 = 0,
    so the series stops after the dt**5 term and the result is exact.
    """
    block_shape = (3, 3, self.position_dim)

    # Create the matrix exponential
    A = jnp.zeros(block_shape) # Top left terms
    AinvT = jnp.zeros(block_shape) # Bottom right terms
    Sigma_AinvT = jnp.zeros(block_shape) # Top right terms

    # Squared noise on the first and second derivative respectively.
    Sigma1 = self.L.elements[self.position_dim:2*self.position_dim]**2
    Sigma2 = self.L.elements[2*self.position_dim:]**2

    # X0 is the identity
    A = A.at[0,0,:].add(1.0)
    A = A.at[1,1,:].add(1.0)
    A = A.at[2,2,:].add(1.0)
    AinvT = AinvT.at[0,0,:].add(1.0)
    AinvT = AinvT.at[1,1,:].add(1.0)
    AinvT = AinvT.at[2,2,:].add(1.0)

    # X*dt
    A = A.at[0,1,:].add(dt)
    A = A.at[1,2,:].add(dt)
    AinvT = AinvT.at[1,0,:].add(-dt)
    AinvT = AinvT.at[2,1,:].add(-dt)
    Sigma_AinvT = Sigma_AinvT.at[1,1,:].add(Sigma1*dt)
    Sigma_AinvT = Sigma_AinvT.at[2,2,:].add(Sigma2*dt)

    # 1/2*X2(dt**2)
    A = A.at[0,2,:].add(dt**2/2)
    AinvT = AinvT.at[2,0,:].add(dt**2/2)
    Sigma_AinvT = Sigma_AinvT.at[0,1,:].add(Sigma1*dt**2/2)
    Sigma_AinvT = Sigma_AinvT.at[1,0,:].add(-Sigma1*dt**2/2)
    Sigma_AinvT = Sigma_AinvT.at[1,2,:].add(Sigma2*dt**2/2)
    Sigma_AinvT = Sigma_AinvT.at[2,1,:].add(-Sigma2*dt**2/2)

    # 1/6*X3(dt**3)
    Sigma_AinvT = Sigma_AinvT.at[0,0,:].add(-Sigma1*dt**3/6)
    Sigma_AinvT = Sigma_AinvT.at[0,2,:].add(Sigma2*dt**3/6)
    Sigma_AinvT = Sigma_AinvT.at[1,1,:].add(-Sigma2*dt**3/6)
    Sigma_AinvT = Sigma_AinvT.at[2,0,:].add(Sigma2*dt**3/6)

    # 1/24*X4(dt**4)
    Sigma_AinvT = Sigma_AinvT.at[0,1,:].add(-Sigma2*dt**4/24)
    Sigma_AinvT = Sigma_AinvT.at[1,0,:].add(Sigma2*dt**4/24)

    # 1/120*X5(dt**5)
    Sigma_AinvT = Sigma_AinvT.at[0,0,:].add(Sigma2*dt**5/120)

    A = Diagonal3x3BlockMatrix(A, tags=TAGS.no_tags)
    AinvT = Diagonal3x3BlockMatrix(AinvT, tags=TAGS.no_tags)
    A = MatrixWithInverse(A, AinvT.T)
    # NOTE(review): Sigma A^{-T} is not symmetric in general (entry [0,1]
    # differs from [1,0] above), yet it is tagged with symmetric_tags.
    # Left unchanged because the tag semantics are not visible here - confirm.
    Sigma_AinvT = Diagonal3x3BlockMatrix(Sigma_AinvT, tags=TAGS.symmetric_tags)

    # Recover Sigma = (Sigma A^{-T}) A^T.
    Sigma = Sigma_AinvT@A.T
    Sigma = Sigma.set_symmetric()
    u = jnp.zeros((A.shape[-2],))

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)
    return GaussianTransition(A, u, Sigma)

################################################################################################################

if __name__ == '__main__':
  # Smoke test: run the generic linear SDE test against every model
  # defined in this file.
  import matplotlib.pyplot as plt
  from debug import *
  from diffusion_crf.sde.sde_base import linear_sde_test
  jax.config.update('jax_enable_x64', True)

  key = random.PRNGKey(0)
  sigma = 0.1
  dim = 2

  # Create a prior
  t0 = jnp.array(0.0)
  def make_prior(key, order):
    """Build a Gaussian prior over an order*dim state with mean 1 and a
    random non-negative diagonal covariance."""
    x0 = jnp.ones((order*dim,))
    Sigma_mat = random.normal(key, (order*dim,))
    # Squaring keeps the diagonal covariance entries non-negative.
    Sigma0 = Sigma_mat*Sigma_mat
    Sigma0 = DiagonalMatrix(Sigma0, tags=TAGS.symmetric_tags)
    # fwd_prior = StandardGaussian(x0, Sigma0)#.make_deterministic()
    fwd_prior = MixedGaussian(x0, Sigma0.get_inverse())#.make_deterministic()
    return fwd_prior

  # Create the SDE
  sde1 = TrackingModel(sigma)
  sde2 = HigherOrderTrackingModel(sigma, dim, order=2)
  sde3 = HigherOrderTrackingModel(sigma, dim, order=3)

  # grow_crf_test(sde1)
  for sde in (sde1, sde2, sde3):
    linear_sde_test(sde)
