import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.base import AbstractBatchableObject, AbstractTransition
from diffusion_crf.matrix import *
from diffusion_crf.gaussian import *
from plum import dispatch
import diffusion_crf.util as util
from diffusion_crf.sde.ode_sde_solve import ode_solve

# Names exported by `from <this module> import *`.
# NOTE(review): `LinearTimeInvariantSDE` and `max_likelihood_ltisde` are defined in this
# module but not listed here; confirm whether that is intentional.
__all__ = ['AbstractSDE',
           'AbstractLinearSDE',
           'AbstractLinearTimeInvariantSDE',
           'TimeScaledLinearTimeInvariantSDE']

################################################################################################################

class AbstractSDE(AbstractBatchableObject, abc.ABC):
  """Base interface for an SDE  dx_t = f(t, x_t) dt + L(t, x_t) dW_t.

  An abstract SDE does NOT support sampling on its own: a potential (or an
  initial point) has to be incorporated before samples can be drawn.
  """

  @abc.abstractmethod
  def get_drift(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the drift f(t, x_t)."""

  @abc.abstractmethod
  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Return the diffusion coefficient L(t, x_t)."""

  def get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Return the transition distribution p(x_t | x_s).

    Not available for a general SDE; subclasses override this when the
    transition can be computed.
    """
    raise NotImplementedError()

################################################################################################################

class AbstractLinearSDE(AbstractSDE, abc.ABC):
  """Linear SDE  dx_t = (F(t) x_t + u(t)) dt + L(t) dW_t.

  Subclasses implement `get_params`.  The drift, the diffusion coefficient and
  the Gaussian transition distribution are all derived from it.
  """

  @abc.abstractmethod
  def get_params(self, t: Scalar) -> Tuple[AbstractMatrix,
                                           Float[Array, 'D'],
                                           AbstractMatrix]:
    """Get F, u, and L at time t.

    The parameters of a linear SDE do not depend on the state, so only the
    time is passed (every caller in this class calls `get_params(t)`).
    """
    pass

  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Return L(t).  `xt` is unused because the diffusion is state independent."""
    _, _, L = self.get_params(t)
    return L

  def get_drift(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Return the affine drift F(t) x_t + u(t)."""
    F, u, _ = self.get_params(t)
    return F@xt + u

  def get_transition_distribution(self,
                           s: Scalar,
                           t: Scalar) -> GaussianTransition:
    """Get the transition distribution from time s to time t. See section
    6.1 of Särkkä's book (https://users.aalto.fi/~asolin/sde-book/sde-book.pdf)
    for the math details.  This class solves everything in reverse in order to
    only need to solve one ODE.

    This solves for the transition parameters, A_{t,s}, u_{t,s}, and Sigma_{t,s} so that
    for a starting point x_s, the transition distribution p(x_t | x_s) is Gaussian
    N(x_t | A_{t,s} x_s + u_{t,s}, Sigma_{t,s}).  Viewed as functions of the lower
    time limit tau (with t fixed) they satisfy

      d/dtau A_{t,tau}     = -A_{t,tau} F(tau)
      d/dtau u_{t,tau}     = -A_{t,tau} u(tau)
      d/dtau Sigma_{t,tau} = -A_{t,tau} L L^T A_{t,tau}^T

    with A_{t,t} = I, u_{t,t} = 0 and Sigma_{t,t} = 0, which we integrate from tau = t to tau = s.

    **Arguments**

    - `s`: The start time
    - `t`: The end time

    **Returns**

    A `GaussianTransition(A, u, Sigma)` where
    - `A`: The transition matrix
    - `u`: The input
    - `Sigma`: The transition covariance (symbolically zero if it is numerically negligible)
    """
    # Call get params to get the data types
    F, u, L = self.get_params(s)
    I = (F.T@L).set_eye() # Initialize with identity matrix.  Do it this way to get the right data type

    D = self.dim
    psi_TT = I # Initialize with identity matrix
    uT = jnp.zeros(D)
    SigmaT = I.set_zero() # Initialize with 0 matrix

    # Remove the tags from the matrices so that we avoid symbolic computation
    psi_TT = eqx.tree_at(lambda x: x.tags, psi_TT, TAGS.no_tags)
    SigmaT = eqx.tree_at(lambda x: x.tags, SigmaT, TAGS.no_tags)

    # Initialize the ODE state
    yT = (psi_TT, uT, SigmaT)

    # The ODE solver should not try to update the tags, so need to partition
    yT_params, yT_static = eqx.partition(yT, eqx.is_inexact_array)

    def reverse_dynamics(tau, ytau_params):
      ytau = eqx.combine(ytau_params, yT_static)
      psi_Ttau, _, _ = ytau

      Ftau, utau, L = self.get_params(tau)
      LLT = L@L.T

      dpsi_Ttau = -psi_Ttau@Ftau
      du = -psi_Ttau@utau # The negative sign comes from reversing the ODE
      dSigma = -psi_Ttau@LLT@psi_Ttau.T # The negative sign comes from reversing the ODE

      dytau = (dpsi_Ttau, du, dSigma)
      dytau_params, _ = eqx.partition(dytau, eqx.is_inexact_array)
      return dytau_params

    # Solve the ODE backwards in time
    save_times = jnp.array([t, s])
    y0_params = ode_solve(reverse_dynamics, yT_params, save_times)

    # Extract the final time point (tau = s) and combine with the static data.
    # `jax.tree_map` is deprecated/removed in recent JAX; use the tree_util version.
    y0_params = jax.tree_util.tree_map(lambda x: x[-1], y0_params)
    y0 = eqx.combine(y0_params, yT_static)
    A, u, Sigma = y0

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)

    return GaussianTransition(A, u, Sigma)

################################################################################################################

class AbstractLinearTimeInvariantSDE(AbstractLinearSDE, abc.ABC):
  """Linear SDE with constant parameters,  dx_t = F x_t dt + L dW_t.

  The input u is identically zero.  Because the parameters are constant, the
  transition distribution only depends on t - s and is computed in closed form
  with matrix exponentials instead of an ODE solve.
  """

  F: eqx.AbstractVar[AbstractMatrix]
  L: eqx.AbstractVar[AbstractMatrix]

  @property
  def u(self) -> Float[Array, 'D']:
    """The (always zero) input vector."""
    return jnp.zeros((self.dim,))

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    return self.F.batch_size

  @property
  def dim(self) -> int:
    return self.F.shape[0]

  def get_params(self, t: Scalar) -> Tuple[AbstractMatrix,
                                           Float[Array, 'D'],
                                           AbstractMatrix]:
    """Return the constant parameters (F, u, L); `t` is ignored."""
    return self.F, self.u, self.L

  def get_transition_distribution(self,
                           s: Scalar,
                           t: Scalar) -> GaussianTransition:
    """Get the Gaussian transition p(x_t | x_s) = N(x_t | A x_s, Sigma).

    We use the matrix fraction decomposition from section 6.3 of Särkkä's
    book (also detailed here https://arxiv.org/pdf/2302.07261):

      expm([[F, L L^T], [0, -F^T]] dt) = [[A, Sigma A^{-T}], [0, A^{-T}]]

    with dt = t - s, plus elementwise shortcuts when F and L are diagonal.

    **Arguments**

    - `s`: The start time
    - `t`: The end time

    **Returns**

    A `GaussianTransition(A, u, Sigma)` with u = 0 and A stored together with its inverse.

    **Raises**

    - `ValueError`: if L is diagonal but F is not one of the supported matrix types.
    """
    D = self.F.shape[0]
    dt = t - s

    def matrix_fraction_decomposition():
      # Dense A, A^{-T} and Sigma read off from the blocks of
      # expm([[F, L L^T], [0, -F^T]] dt).
      zero = jnp.zeros((D, D))
      F, L = self.F.as_matrix(), self.L.as_matrix()
      X = jnp.block([[F, L@L.T], [zero, -F.T]])*dt
      Phi = jax.scipy.linalg.expm(X)

      A = Phi[:D,:D] # Top left: exp(F dt)
      Sigma_AinvT = Phi[:D,D:] # Top right: Sigma A^{-T}
      AinvT = Phi[D:,D:] # Bottom right: exp(-F^T dt) = A^{-T}
      Sigma = Sigma_AinvT@A.T # Remove the A^{-T} factor

      A = DenseMatrix(A, tags=TAGS.no_tags)
      AinvT = DenseMatrix(AinvT, tags=TAGS.no_tags)
      Sigma = DenseMatrix(Sigma, tags=TAGS.symmetric_tags)
      return A, AinvT, Sigma

    block_diagonal_types = (Diagonal2x2BlockMatrix, Diagonal3x3BlockMatrix, Block2x2Matrix)

    if isinstance(self.L, DiagonalMatrix):

      if isinstance(self.F, DiagonalMatrix):
        # Everything is diagonal, so the matrix exponentials act elementwise.
        A = (self.F*dt).get_exp()
        AinvT = (-self.F.T*dt).get_exp()
        A = MatrixWithInverse(A, AinvT.T)

        # Sigma = int_0^dt exp(F r) L L^T exp(F^T r) dr, which for diagonal F and L is
        #   L L^T * dt * expm1(2 F dt) / (2 F dt)      (elementwise).
        # (L L^T * dt alone is only correct when F = 0.)  The ratio tends to 1 as
        # F dt -> 0; the double `where` keeps value and gradient finite at exactly 0.
        LLT = self.L@self.L.T
        x = 2*self.F.elements*dt
        is_zero = x == 0
        safe_x = jnp.where(is_zero, 1.0, x)
        ratio = jnp.where(is_zero, 1.0, jnp.expm1(safe_x)/safe_x)
        Sigma = eqx.tree_at(lambda m: m.elements, LLT, LLT.elements*dt*ratio)
        # The rescaled elements no longer match any symbolic tag L L^T carried
        # (e.g. identity), so drop the tags.
        Sigma = eqx.tree_at(lambda m: m.tags, Sigma, TAGS.no_tags)

      elif isinstance(self.F, block_diagonal_types + (DenseMatrix,)):
        A, AinvT, Sigma = matrix_fraction_decomposition()

        if isinstance(self.F, block_diagonal_types):
          # Don't have a proof for this, but seems like if L is diagonal and F has diagonal blocks,
          # then A and Sigma also have diagonal blocks.
          # TODO: Find proof of correctness and figure out more efficient implementation that doesn't require
          # taking matrix exponential of full X matrix.
          A = self.F.project_dense(A)
          AinvT = self.F.project_dense(AinvT)
          Sigma = self.F.project_dense(Sigma)

        A = MatrixWithInverse(A, AinvT.T)

      else:
        raise ValueError('Invalid F type')

    else:
      A, AinvT, Sigma = matrix_fraction_decomposition()
      A = MatrixWithInverse(A, AinvT.T)

    u = jnp.zeros((D,))
    Sigma = Sigma.set_symmetric()

    # If all of Sigma has elements close to 0, symbolically set it to 0.
    Sigma = util.where(jnp.abs(Sigma.elements).mean() < 1e-8, Sigma.set_zero(), Sigma)

    return GaussianTransition(A, u, Sigma)

class LinearTimeInvariantSDE(AbstractLinearTimeInvariantSDE):
  """Concrete linear time invariant SDE  dx_t = F x_t dt + L dW_t  that stores
  F and L directly as fields."""
  F: AbstractMatrix  # Drift matrix F
  L: AbstractMatrix  # Diffusion coefficient L

################################################################################################################

class TimeScaledLinearTimeInvariantSDE(AbstractLinearTimeInvariantSDE):
  """Linear time invariant SDE obtained by rescaling the time axis of `sde`.

  If `sde` represents dx_s = F x_s ds + L dW_s and gamma = `time_scale`, then this
  represents the process y_t = x_{gamma*t}, which satisfies

    dy_t = gamma*F y_t dt + sqrt(gamma)*L dW_t.

  Only time is rescaled; the state itself is not scaled.
  """

  sde: AbstractLinearTimeInvariantSDE
  time_scale: Scalar

  def __init__(self,
               sde: AbstractLinearTimeInvariantSDE,
               time_scale: Scalar):
    self.sde = sde
    self.time_scale = time_scale

  @property
  def F(self) -> AbstractMatrix:
    """Drift matrix of the rescaled SDE, gamma*F."""
    return self.sde.F*self.time_scale

  @property
  def L(self) -> AbstractMatrix:
    """Diffusion coefficient of the rescaled SDE, sqrt(gamma)*L."""
    return self.sde.L*jnp.sqrt(self.time_scale)

  @property
  def order(self) -> int:
    """Forward `self.sde.order`, to be compatible with HigherOrderTracking.

    There is definitely a better way to access member variables of self.sde.

    **Raises**

    - `AttributeError`: if the wrapped SDE has no `order`.
    """
    try:
      return self.sde.order
    except AttributeError as err:
      raise AttributeError(f'SDE of type {type(self.sde)} does not have an order') from err

  def get_transition_distribution(self,
                                  s: Scalar,
                                  t: Scalar) -> GaussianTransition:
    """Transition from time s to time t, computed as the transition of `sde`
    from time time_scale*s to time time_scale*t."""
    return self.sde.get_transition_distribution(s*self.time_scale, t*self.time_scale)

################################################################################################################

def max_likelihood_ltisde(xts: Float[Array, 'B N D'], dt: Scalar) -> LinearTimeInvariantSDE:
  """Compute the maximum likelihood parameters for the SDE given a batch of data that is sampled uniformly
  in time with step size dt.

  First fits the exact discretization x_{k+1} ~ N(A x_k, Sigma) by least squares,
  then inverts the matrix fraction decomposition

    expm([[F, L L^T], [0, -F^T]] dt) = [[A, Sigma A^{-T}], [0, A^{-T}]]

  with a matrix logarithm to recover F and L L^T.

  **Arguments**

  - `xts`: A batch of trajectories of shape (B, N, D), or a single trajectory of
    shape (N, D).  Every trajectory needs N >= 2 samples.
  - `dt`: The time step between consecutive samples.

  **Returns**

  A `LinearTimeInvariantSDE` with dense F and L, where L is the Cholesky factor of
  the fitted L L^T.

  **Raises**

  - `ValueError`: if `xts` is not 2- or 3-dimensional or has fewer than 2 time steps.
  """
  import scipy.linalg

  if xts.ndim == 2:
    xts = xts[None]
  if xts.ndim != 3:
    raise ValueError(f'Expected xts with shape (B, N, D) or (N, D), got shape {xts.shape}')
  B, N, D = xts.shape
  if N < 2:
    raise ValueError(f'Need at least 2 time steps per trajectory, got {N}')

  x_prev, x_next = xts[:,:-1], xts[:,1:]

  # Least squares fit of x_{k+1} = A x_k:
  #   A = (sum x_{k+1} x_k^T) (sum x_k x_k^T)^{-1}
  # The normal matrix uses the *previous* states.  Since it is symmetric,
  # A^T = solve(prev_prevT, next_prevT^T).
  prev_prevT = jnp.einsum('bni,bnj->ij', x_prev, x_prev)
  next_prevT = jnp.einsum('bni,bnj->ij', x_next, x_prev)
  A = jnp.linalg.solve(prev_prevT, next_prevT.T).T

  # MLE of the noise covariance: average over all B*(N - 1) residuals.
  residuals = x_next - jnp.einsum('ij,bnj->bni', A, x_prev)
  Sigma = jnp.einsum('bni,bnj->ij', residuals, residuals) / (B*(N - 1))

  # Compute the block matrix [[A, Sigma A^{-T}], [0, A^{-T}]]
  AinvT = jnp.linalg.inv(A).T
  upper_left = A
  upper_right = Sigma@AinvT
  lower_left = jnp.zeros_like(upper_right)
  lower_right = AinvT
  block_matrix = jnp.block([[upper_left, upper_right], [lower_left, lower_right]])

  # Compute the matrix logarithm of the block matrix.  logm may return a complex
  # array with negligible imaginary parts, so keep only the real part.
  Psi = scipy.linalg.logm(block_matrix)/dt
  Psi = Psi.real

  F = Psi[:D,:D] # Top left
  LLT = Psi[:D,D:] # Top right

  # The logm estimate of L L^T is only symmetric up to round-off; symmetrize
  # before taking the Cholesky factor.
  LLT = 0.5*(LLT + LLT.T)
  L = jnp.linalg.cholesky(LLT)

  return LinearTimeInvariantSDE(DenseMatrix(F, tags=TAGS.no_tags), DenseMatrix(L, tags=TAGS.no_tags))


################################################################################################################
################################################################################################################

def _compare_crf_samples_to_ode(sde: AbstractSDE, parallel: bool = False):
  """Check that samples from `sde` and trajectories of the ODE `sde.get_flow` agree.

  Draws 10000 samples with `sde.sample` at N = 3 random increasing save times,
  integrates `sde.get_flow` with `ode_solve` from each sample's first point, builds
  per-time empirical distributions of both sample sets, and asserts that every
  per-time W2 distance between them is below 1e-2.

  NOTE(review): `parallel` is not used anywhere in this function.
  """
  # NOTE(review): redundant; `ode_solve` is already imported at module level.
  from diffusion_crf.sde.ode_sde_solve import ode_solve
  key = random.PRNGKey(0)

  # Test that samples from the SDE, ODE and CRF have the same distribution
  keys = random.split(key, 10000)
  N = 3
  # N increasing times: cumulative sum of N draws from uniform [0, 2/N).
  # NOTE(review): reuses `key`, which was also split above.
  save_times = random.uniform(key, (N,), minval=0.0, maxval=2/N).cumsum()

  # Exercise the discretized CRF API.  Apart from `info`, these results are not
  # used below (presumably a smoke test that these methods run); `samples_crf`
  # is overwritten after the vmap.
  crf_result = sde.discretize(save_times)
  crf, info = crf_result.crf, crf_result.info
  marginals = crf.get_marginals()
  samples_crf = crf.sample(keys[0])
  bwd = crf.get_backward_messages()
  smoothed_transitions = crf.get_transitions()
  flow = sde.get_flow(save_times[0], samples_crf[0])

  def get_samples(key):
    # One SDE sample at `save_times`, and the ODE trajectory started from its
    # first point, evaluated at `info.ts`.
    samples_crf = sde.sample(key, save_times)
    x0 = samples_crf.yts[0]
    samples_ode = ode_solve(sde.get_flow, x0, info.ts)
    return samples_crf, samples_ode

  samples_crf, samples_ode = jax.vmap(get_samples)(keys)
  samples_crf = samples_crf.yts
  samples_ode = samples_ode.yts

  # Look at the W2 distance between the marginals
  # (in_axes=1 maps over the time axis, so one empirical distribution per time).
  marginals_crf = jax.vmap(util.empirical_dist, in_axes=1, out_axes=0)(samples_crf)
  marginals_ode = jax.vmap(util.empirical_dist, in_axes=1, out_axes=0)(samples_ode)
  # Restrict the ODE marginals from `info.ts` back to the requested save times.
  marginals_ode = info.filter_new_times(marginals_ode)

  w2_dist = jax.vmap(util.w2_distance)(marginals_crf, marginals_ode)
  print(f'W2 distance between CRF and ODE: {w2_dist}')
  assert jnp.all(w2_dist < 1e-2)

def linear_sde_test(sde: AbstractSDE):
  """Condition `sde` on a single Gaussian potential and run sampling checks on it.

  Builds a `ConditionedLinearSDE` from `sde` and one potential at time 1.1
  centred at 5*ones with covariance 0.001*I (passed to `MixedGaussian` via
  `Sigma0.get_inverse()`, presumably a precision — TODO confirm).  It then
  discretizes on 100 points in [0, 1], draws one sample, and runs
  `_compare_crf_samples_to_ode` on the conditioned SDE.
  """
  from diffusion_crf.timeseries import ProbabilisticTimeSeries, TimeSeries
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  key = random.PRNGKey(0)

  # Create some evidence to condition on
  x0 = jnp.ones((sde.dim,))*5.0
  Sigma0 = DiagonalMatrix.eye(sde.dim)*0.001
  potential = MixedGaussian(x0, Sigma0.get_inverse())#.make_deterministic()
  ts = jnp.array([1.1])
  # Add a leading axis so the single potential lines up with the single time in `ts`.
  node_potentials = potential[None]

  pts = ProbabilisticTimeSeries.from_potentials(ts, node_potentials)
  cond_sde = ConditionedLinearSDE(sde, pts)

  save_times = jnp.linspace(0, 1, 100)
  result = cond_sde.discretize(save_times)
  crf1 = result.crf
  # NOTE(review): `xts` is not used below; this only checks that sampling runs.
  xts = crf1.sample(key)

  _compare_crf_samples_to_ode(cond_sde)

  print(f'\nFinished {sde} test\n\n')
