import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Iterable
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import jax.tree_util as jtu
from diffusion_crf.base import *
from diffusion_crf.util.parallel_scan import parallel_segmented_scan, parallel_scan
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractLinearSDE, AbstractLinearTimeInvariantSDE
from diffusion_crf.matrix import *
from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
from diffusion_crf.gaussian.dist import MixedGaussian, NaturalGaussian, StandardGaussian
import diffusion_crf.util as util

# Public API of this module; this is also exactly what `from <module> import *`
# exports (the private `_to_matrix` helper is intentionally left out).
# NOTE(review): docstrings in this module mention `build_stochastic_bridge`,
# which is neither defined nor exported here -- confirm where it lives.
__all__ = ['GeneralLTISDE',
           'make_lti_sde',
           'build_conditional_sde']

class GeneralLTISDE(AbstractLinearTimeInvariantSDE):
  r"""A linear time-invariant SDE.  This is a stochastic
  differential equation of the form $dx_t = F x_t \, dt + L \, dW_t$.

  The class only stores the drift matrix `F` and the diffusion matrix `L`.
  It defines no methods of its own, so everything below is inherited
  (from `AbstractLinearTimeInvariantSDE` or its bases):

  - get_drift(t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']
    Returns the drift of the SDE

  - get_diffusion_coefficient(t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']
    Returns the diffusion coefficient

  - get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition
    Returns the (Gaussian) transition distribution of the form
    $N(x_{t+s} | A_s x_t, \Sigma_s)$

  In order to sample, evaluate likelihoods, compute probability flows, etc.,
  you need to introduce evidence (an initial point or a prior distribution)
  using `build_conditional_sde` or `build_stochastic_bridge`.
  """
  # Drift matrix F in dx_t = F x_t dt + L dW_t (`make_lti_sde` builds it
  # from a raw array with `_to_matrix`).
  F: AbstractMatrix
  # Diffusion matrix L multiplying the Brownian increment dW_t.
  L: AbstractMatrix

@dispatch
def _to_matrix(A: Float[Array, 'D']):
  """Wrap a 1d array as a `DiagonalMatrix` (its entries form the diagonal)."""
  diagonal = DiagonalMatrix(A, tags=TAGS.no_tags)
  return diagonal

@dispatch
def _to_matrix(A: Float[Array, 'D D']):
  """Wrap a square 2d array as a `DenseMatrix`."""
  dense = DenseMatrix(A, tags=TAGS.no_tags)
  return dense

def make_lti_sde(
  F: Union[Float[Array, 'D'], Float[Array, 'D D']],
  L: Union[Float[Array, 'D'], Float[Array, 'D D']]
) -> GeneralLTISDE:
  """Construct an LTI-SDE with the given parameters.

  **Arguments**

  - `F`: A 1 or 2 dimensional array representing the
       matrix that appears in the drift term.
       If 1d, will be interpreted as a diagonal matrix.
  - `L`: Represents the diffusion coefficient.  If it is 1d, will
       be interpreted as a diagonal matrix.

  Both arguments may be any array-like (JAX array, NumPy array, nested
  lists/tuples of floats); they are converted with `jnp.asarray` first.

  **Returns**
  Returns a GeneralLTISDE object.  Pass this to `build_conditional_sde` or
  `build_stochastic_bridge` along with evidence in order to perform inference.
  """
  # `_to_matrix` dispatches on `Float[Array, ...]`, i.e. on JAX arrays only;
  # NumPy arrays or Python lists would otherwise match no overload.
  # `jnp.asarray` does not copy JAX arrays and passes tracers through.
  F, L = jnp.asarray(F), jnp.asarray(L)
  # Wrap in a Matrix object in order to do symbolic computation if
  # needed during inference.
  F, L = _to_matrix(F), _to_matrix(L)
  return GeneralLTISDE(F, L)


def build_conditional_sde(
  sde: AbstractLinearSDE,
  ts: Float[Array, 'T'],
  means: Float[Array, 'T D'],
  covs: Optional[Union[Float[Array, 'T D D'],
              Float[Array, 'T D']]] = None,
  inverse_covs: Optional[Union[Float[Array, 'T D D'],
              Float[Array, 'T D']]] = None,
  parallel: bool = False,
  parameterization: Optional[str] = 'natural'
) -> ConditionedLinearSDE:
  r"""Condition the input SDE on observations with uncertainty.
  The returned object can perform inference in the conditioned SDE.

  **Arguments**
  - `sde`: An SDE that we will condition.
  - `ts`: The times that we want to condition at.  A scalar is treated as a
          single observation time.  Any array-like (JAX/NumPy array, list,
          tuple, Python scalar) is accepted and converted with `jnp.asarray`.
  - `means`: The means of the potentials that we'll use to condition, one row
             per entry of `ts`.  If `ts` is a scalar, pass a single 1d mean.
  - `covs`: The covariances of the potentials for conditioning.  If
            `covs` has one vector per time (shape `(T, D)`), each vector is
            used as a diagonal matrix.
              * For total uncertainty, use jnp.inf
              * For total certainty, use 0.0
  - `inverse_covs`: The inverse covariances of the potentials for conditioning.
           If both covs and inverse_covs are provided, then we will set the inverse of
           inverse_covs to be covs when needed.
            * For total uncertainty, use 0.0
            * For total certainty, use jnp.inf
  - `parallel`: Whether to use parallel scan for the conditioning.
  - `parameterization`: Whether to use the 'natural', 'mixed', or 'standard' parameterization.
    Natural is the most numerically stable parameterization but does not handle exact conditioning.

  **Returns**
  A ConditionedLinearSDE object that represents `sde` conditioned on the observations
  given by $N(\mu,\Sigma)$.

  **Raises**
  - `AssertionError` if `ts` is not a scalar or 1d, if `means` is not 2d
    (1d when `ts` is a scalar), or if `ts` and `means` differ in length.
  """
  # Accept any array-like up front (previously only list/tuple `ts` was
  # converted, so e.g. a Python-float `ts` or a list `means` crashed on
  # `.ndim`).  `jnp.asarray` does not copy JAX arrays and passes tracers through.
  ts = jnp.asarray(ts)
  means = jnp.asarray(means)
  if covs is not None:
    covs = jnp.asarray(covs)
  if inverse_covs is not None:
    inverse_covs = jnp.asarray(inverse_covs)

  # A scalar time is a single observation: add a length-1 leading time axis
  # to every input so the code below always sees a batch of observations.
  if ts.ndim == 0:
    assert means.ndim == 1, 'means must be unbatched if ts is unbatched'
    ts = ts[None]
    means = means[None]
    if covs is not None:
      covs = covs[None]
    if inverse_covs is not None:
      inverse_covs = inverse_covs[None]

  assert ts.ndim == 1, f'ts must be a scalar or 1d, got shape {ts.shape}'
  assert means.ndim == 2, f'means must have shape (T, D), got shape {means.shape}'
  assert ts.shape[0] == means.shape[0], 'ts and means must have the same length'

  # Local import, presumably to avoid a circular import at module load time.
  from diffusion_crf.timeseries import ProbabilisticTimeSeries
  # Construct the node potentials.
  # NOTE(review): `covs` is documented as a covariance but is forwarded as
  # `standard_deviation`; confirm which quantity ProbabilisticTimeSeries
  # expects (for diagonal entries, std = sqrt(cov)).
  pts = ProbabilisticTimeSeries(ts,
                                means,
                                standard_deviation=covs,
                                certainty=inverse_covs,
                                parameterization=parameterization)
  return ConditionedLinearSDE(sde, pts, parallel=parallel)

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc debugging script, not an automated test (nothing is asserted).
  # It conditions an Ornstein-Uhlenbeck SDE on noisy observations using this
  # module, computes the same quantities with the `latent_linear_sde`
  # implementation (presumably an older/reference version -- confirm), then
  # drops into pdb so the two sets of results can be compared by hand.
  # NOTE(review): `debug` and `latent_linear_sde` are local modules that must
  # be importable from the working directory.
  from debug import *
  from diffusion_crf.sde.simple_sdes import OrnsteinUhlenbeck

  N = 5  # number of conditioning (observation) times
  x_dim = 2  # state dimension
  ts = jnp.linspace(0, 1, N)
  save_times = jnp.linspace(0, 1, 10)  # NOTE(review): not used below
  sigma = 1e-0
  t = 0.3  # query time for marginals, matching items and the flow
  xt = jnp.array([0.5, 0.5])  # query state at time t

  # Create the parameters of the model
  key = random.PRNGKey(2)
  W = random.normal(key, (x_dim,))
  # Cosine observations with per-dimension frequencies W; shape (N, x_dim).
  means = jnp.cos(2*jnp.pi*ts[:,None]*W[None,:])
  covs = jnp.ones_like(means)*1/1e-2  # every entry is 100.0
  inverse_covs = 1/covs  # every entry is 0.01; only this is passed below
  sde = OrnsteinUhlenbeck(sigma, 0.1, x_dim)
  conditioned_sde = build_conditional_sde(sde, ts, means, inverse_covs=inverse_covs, parallel=True)

  # Try getting the local SDE at time t
  local_sde = conditioned_sde.get_local_sde_at_t(t)

  # Marginal at t from the full conditioned SDE vs. from its local SDE;
  # presumably these should agree -- inspect in pdb.
  pxt1 = conditioned_sde.get_marginal(t)
  pxt2 = local_sde.get_marginal(t)

  # import pdb; pdb.set_trace()


  # Results from this module's implementation (suffix 1).
  items1 = conditioned_sde.get_matching_items(t, xt)
  sampled_items1 = conditioned_sde.sample_matching_items(t, key)
  flow1 = conditioned_sde.get_flow(t, xt)

  # Presumably converts the node potentials to natural parameters -- confirm.
  node_potentials = conditioned_sde.node_potentials.to_nat()

  # Create the other model
  from latent_linear_sde import DiffusionCRF, DiagonalEagerLinearOperator
  from latent_linear_sde import NaturalGaussian as NatGaussian
  from latent_linear_sde import OrnsteinUhlenbeck as OU

  # Re-wrap one potential in latent_linear_sde's types.  Only the diagonal of
  # `pot.J` is kept (off-diagonal entries are dropped).
  def make_potential(pot):
    J = DiagonalEagerLinearOperator(jnp.diag(pot.J.as_matrix()))
    return NatGaussian(J, pot.h)

  # vmap over the leading (time) axis of the batched potentials.
  node_potentials = jax.vmap(make_potential)(node_potentials)
  sde2 = OU(sigma, 0.1, x_dim)
  dcrf = DiffusionCRF(sde2, ts, node_potentials)
  # Results from the other implementation (suffix 2).  The same `key` is used
  # for sampling, presumably so both samplers' outputs are directly comparable.
  items2 = dcrf.get_sde_at_t(t).get_matching_items(t, xt)
  sampled_items2 = dcrf.get_sde_at_t(t).sample_matching_items(t, key)
  flow2 = dcrf.get_sde_at_t(t).get_flow(t, xt)
  # Interactive breakpoint for comparing the *1 and *2 results.
  import pdb; pdb.set_trace()