import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type, Dict
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.base import *
from diffusion_crf.crf import *
from diffusion_crf.continuous_crf import *
from diffusion_crf.matrix.matrix_base import *
from diffusion_crf.gaussian.dist import *
from diffusion_crf.gaussian.transition import *
from diffusion_crf.matrix.matrix_with_inverse import MatrixWithInverse
from diffusion_crf.sde.sde_base import AbstractLinearSDE, AbstractLinearTimeInvariantSDE
from plum import dispatch
import diffusion_crf.util as util
from diffusion_crf.timeseries import ProbabilisticTimeSeries, interleave_series

__all__ = ['ConditionedLinearSDE']

class ConditionedLinearSDE(AbstractLinearSDE, AbstractContinuousCRF):
  """A linear SDE conditioned on the potentials of a probabilistic time series.

  The unconditioned dynamics come from `sde` and the conditioning information
  comes from `probabilistic_time_series`.  Quantities at arbitrary times are
  computed by discretizing the model at the requested times (`self.discretize`)
  and running message passing on the resulting CRF.
  """

  # The unconditioned linear SDE.
  sde: AbstractLinearSDE
  # Time series whose potentials condition `sde`.
  probabilistic_time_series: ProbabilisticTimeSeries

  # Static flag set in __init__.  NOTE(review): not read anywhere in this class;
  # presumably consumed by a base class -- confirm.
  parallel: bool = eqx.field(static=True)

  def __init__(
    self,
    sde: AbstractLinearSDE,
    probabilistic_time_series: ProbabilisticTimeSeries,
    parallel: Optional[bool] = None
  ):
    """
    Args:
      sde: The linear SDE to condition.  If it is itself a ConditionedLinearSDE,
        its underlying SDE is used and its time series is interleaved with
        `probabilistic_time_series`, so nested conditioning stays flat.
      probabilistic_time_series: The potentials to condition on.
      parallel: Stored as a static field.  When None, it defaults to True iff
        the first JAX device is a GPU.
    """
    assert isinstance(sde, AbstractLinearSDE)
    assert isinstance(probabilistic_time_series, ProbabilisticTimeSeries)
    if isinstance(sde, ConditionedLinearSDE):
      # Flatten nested conditioning: condition the base SDE on both series at once.
      self.sde = sde.sde
      self.probabilistic_time_series = interleave_series(sde.probabilistic_time_series, probabilistic_time_series)
    else:
      self.sde = sde
      self.probabilistic_time_series = probabilistic_time_series
    if parallel is None:
      parallel = jax.devices()[0].platform == 'gpu'
    self.parallel = parallel

  def get_base_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Transition distribution of the *unconditioned* SDE from time s to time t."""
    return self.sde.get_transition_distribution(s, t)

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size of the conditioning time series."""
    return self.probabilistic_time_series.batch_size

  @property
  def dim(self) -> int:
    """State dimension of the underlying SDE."""
    return self.sde.dim

  def get_params(
    self,
    t: Scalar,
    *,
    messages: Optional[Messages] = None
  ) -> Tuple[Float[Array, 'D D'],
             Float[Array, 'D'],
             Float[Array, 'D D']]:
    """Return the parameters (F_cond, u_cond, L) of the conditioned SDE at time t.

    With base parameters (F, u, L) and the backward message at t in natural
    form (J, h):
        F_cond = F - L L^T J
        u_cond = u + L L^T h
    The diffusion coefficient L is returned unchanged.

    Args:
      t: Time at which to evaluate.  Python scalars are accepted.
      messages: Optional precomputed messages, forwarded to Messages.from_messages.
    """
    F, u, L = self.sde.get_params(t)
    LLT = L@L.T

    # Discretize at t (as a length-1 array; asarray lets plain Python scalars
    # through) and then get the backward messages.
    crf_result = self.discretize(jnp.asarray(t)[None])
    crf, info = crf_result.crf, crf_result.info

    # Get the index of the new time
    assert info.new_indices.size == 1
    new_index = info.new_indices[0]

    # Get the backward messages
    messages = Messages.from_messages(messages, crf, need_bwd=True)
    bwd = messages.bwd

    # The drift correction needs the natural parameters (J, h) of the message.
    bwdt = bwd[new_index]
    if not isinstance(bwdt, NaturalGaussian):
      bwdt = bwdt.to_nat()

    F_cond = F - LLT@bwdt.J
    u_cond = u + LLT@bwdt.h

    return F_cond, u_cond, L

  def get_drift(
    self,
    t: Scalar,
    xt: Float[Array, 'D'],
    *,
    messages: Optional[Messages] = None
  ) -> Float[Array, 'D']:
    """Drift of the conditioned SDE at (t, xt): F_cond @ xt + u_cond."""
    F, u, L = self.get_params(t, messages=messages)
    return F@xt + u

  def get_diffusion_coefficient(
    self,
    t: Scalar,
    xt: Float[Array, 'D'],
    *,
    messages: Optional[Messages] = None
  ) -> AbstractMatrix:
    """Diffusion coefficient L of the conditioned SDE (same as the base SDE's)."""
    _, _, L = self.get_params(t, messages=messages)
    return L

  def get_transition_distribution(
    self,
    s: Scalar,
    t: Scalar,
    *,
    messages: Optional[Messages] = None
  ) -> GaussianTransition:
    """Transition distribution of the conditioned SDE from time s to time t.

    The model is discretized at [s, t].  The CRF's per-step transitions are
    then composed with a reverse segmented scan that restarts at the newly
    inserted nodes.
    """
    st = jnp.array([s, t])

    # Discretize at s and t and get the per-step transitions of the CRF
    crf_result = self.discretize(st)
    crf, info = crf_result.crf, crf_result.info
    transitions = crf.get_transitions(messages=messages)

    # Get the mask for the parallel scan that says when to reset the scan
    segment_ends = info.new_indices
    reset_mask = jnp.arange(len(crf))[:,None] == segment_ends[None,:]
    reset_mask = reset_mask.any(axis=1)

    # Perform the parallel scan
    def operator(right: AbstractTransition, left: AbstractTransition) -> AbstractTransition:
      return left.chain(right) # The transitions are in reverse order

    transitions = util.parallel_segmented_scan(operator, transitions[::-1], reset_mask[::-1])[::-1]
    return transitions[segment_ends[1]]

  def get_marginal(
    self,
    t: Scalar,
    return_messages: bool = False,
    *,
    messages: Optional[Messages] = None
  ) -> Union[AbstractPotential, Tuple[AbstractPotential, AbstractPotential, AbstractPotential]]:
    """Marginal distribution of the conditioned SDE at time t.

    Args:
      t: Query time.
      return_messages: If True, also return `messages[index]`, the messages at
        the node inserted for t.
      messages: Optional precomputed messages, forwarded to Messages.from_messages.
    """
    t = jnp.array(t)

    # Discretize at t
    crf_result = self.discretize(t[None])
    crf, info = crf_result.crf, crf_result.info

    # This is the index of the node at time t
    index = info.new_indices[0]

    # Get the marginals.  Compute the messages out here to avoid recomputing them
    messages = Messages.from_messages(messages, crf, need_fwd=True, need_bwd=True)
    marginals = crf.get_marginals(messages=messages)
    pxt = marginals[index]

    if return_messages:
      return pxt, messages[index]
    else:
      return pxt

  def get_idx_before_t(self, t: Scalar) -> int:
    """Index of the last entry of `self.ts` that is <= t (-1 if t < ts[0])."""
    return jnp.searchsorted(self.ts, t, side='right') - 1

  def get_local_sde_at_t(
    self,
    t: Scalar,
    *,
    messages: Optional[Messages] = None
  ) -> 'ConditionedLinearSDE':
    """
    Get the local SDE at time t.  This SDE will be conditioned on only 2 node potentials
    and will have the same distribution as this SDE at the times in between the node potentials.

    This saves having to do message passing again when computing the probability flow ODE.

    NOTE(review): if t is at or after the last time in `self.ts`, `idx + 1`
    is past the end; confirm how the indexed containers handle that.
    """
    # Discretize and then get the messages
    crf = self.discretize()
    messages = Messages.from_messages(messages, crf, need_fwd=True, need_bwd=True)

    # Find the index of the node at time t
    idx = self.get_idx_before_t(t)

    # Fold the incoming messages into the potentials at both ends of the
    # interval containing t, so the 2-node SDE sees the same information.
    fwd_prior = messages.fwd[idx] + self.node_potentials[idx]
    bwd_prior = messages.bwd[idx+1] + self.node_potentials[idx+1]

    # Stack them along a new leading axis and make a new SDE
    new_potentials = jtu.tree_map(lambda f, b: jnp.concatenate([f[None],b[None]], axis=0), fwd_prior, bwd_prior)
    new_ts = jnp.array([self.ts[idx], self.ts[idx+1]])

    pts = ProbabilisticTimeSeries.from_potentials(new_ts, new_potentials)
    local_sde = ConditionedLinearSDE(self.sde, pts)
    return local_sde

  def get_flow(
    self,
    t: Scalar,
    xt: Float[Array, 'D'],
    method: str = 'score',
    *,
    messages: Optional[Messages] = None
  ) -> Float[Array, 'D']:
    """Flow vector field of the conditioned SDE at (t, xt).

    Args:
      t: Query time.
      xt: State at time t.
      method: 'score' returns vt + 0.5 L L^T (bwd.score(xt) - fwd.score(xt)).
        Here vt is the base SDE's drift, and fwd/bwd are the messages at t.
        'jvp' differentiates the reparametrized marginal sample with respect
        to t, keeping the noise that generated xt fixed.
      messages: Optional precomputed messages.

    Raises:
      ValueError: if `method` is neither 'score' nor 'jvp'.
    """
    if method not in ('jvp', 'score'):
      # Previously an unknown method silently returned None.
      raise ValueError(f"Unknown method '{method}'; expected 'score' or 'jvp'")

    t = jnp.array(t)

    if self.ts.shape[-1] > 2:
      # This avoids needing to do message passing when we do a jvp
      local_sde = self.get_local_sde_at_t(t, messages=messages)
      return local_sde.get_flow(t, xt, method=method)

    if method == 'jvp':
      # Get the noise that generated xt
      pxt = self.get_marginal(t, messages=messages)
      noise = pxt.get_noise(xt)

      def sample(t):
        pxt = self.get_marginal(t)
        return pxt._sample(noise)

      _, dxtdt = jax.jvp(sample, (t,), (jnp.ones_like(t),))
      return dxtdt

    # method == 'score': get the base drift and diffusion coefficient
    vt = self.sde.get_drift(t, xt)
    L = self.sde.get_diffusion_coefficient(t, xt)
    LLT = L@L.T

    crf_result = self.discretize(t[None])
    crf, info = crf_result.crf, crf_result.info
    index = info.new_indices[0]

    # Get the forward and backward messages
    messages = Messages.from_messages(messages, crf, need_fwd=True, need_bwd=True)
    fwd = messages.fwd
    bwd = messages.bwd
    fwdt, bwdt = fwd[index], bwd[index]

    return vt + 0.5*LLT@(bwdt.score(xt) - fwdt.score(xt))

  def get_matching_items(
    self,
    t: Scalar,
    xt: Float[Array, 'D'],
    *,
    messages: Optional[Messages] = None
  ) -> Mapping[str, Float[Array, 'D']]:
    """Return a dict of quantities at (t, xt).

    Keys: xt, flow, score, drift, pxt, fwd_score, bwd_score.  `flow` is the
    jvp (w.r.t. t) of the reparametrized marginal sample with xt's noise held
    fixed.  `score` is the marginal's score at xt.  `drift` is the conditioned
    drift.  fwd_score and bwd_score are the scores of the messages at t.
    """
    t = jnp.array(t)

    if self.ts.shape[-1] > 2:
      # This avoids needing to do message passing when we do a jvp
      local_sde = self.get_local_sde_at_t(t, messages=messages)
      return local_sde.get_matching_items(t, xt)

    #############
    # Get the marginal distribution at time t
    #############
    # Discretize at t
    crf_result = self.discretize(t[None])
    crf, info = crf_result.crf, crf_result.info

    # This is the index of the node at time t
    index = info.new_indices[0]

    # Get the marginals.  Compute the messages out here to avoid recomputing them
    messages = Messages.from_messages(messages, crf, need_fwd=True, need_bwd=True)
    marginals = crf.get_marginals(messages=messages)
    pxt = marginals[index]
    message_t = messages[index]
    fwd, bwd = message_t.fwd, message_t.bwd

    #############
    # Reparametrize xt
    #############
    noise = pxt.get_noise(xt)
    score = pxt.score(xt)
    fwd_score, bwd_score = fwd.score(xt), bwd.score(xt)

    def sample(t):
      pxt = self.get_marginal(t) # Can't pass in messages here because we need to recompute them for the jvp
      return pxt._sample(noise)

    _, dxtdt = jax.jvp(sample, (t,), (jnp.ones_like(t),))
    drift = self.get_drift(t, xt, messages=messages)
    return dict(xt=xt,
                flow=dxtdt,
                score=score,
                drift=drift,
                pxt=pxt,
                fwd_score=fwd_score,
                bwd_score=bwd_score)

  def sample_matching_items(
    self,
    t: Scalar,
    key: PRNGKeyArray,
    *,
    messages: Optional[Messages] = None
  ) -> Dict[str, Float[Array, 'D']]:
    """Sample xt from the marginal at t (using `key`) and return matching items.

    Returns the same keys as get_matching_items.  The sample and its jvp w.r.t.
    t are computed together by a single jax.jvp call.
    """
    t = jnp.array(t)

    if self.ts.shape[-1] > 2:
      # This avoids needing to do message passing when we do a jvp
      local_sde = self.get_local_sde_at_t(t, messages=messages)
      return local_sde.sample_matching_items(t, key)

    def sample(t):
      # Discretize at t
      crf_result = self.discretize(t[None])
      crf, info = crf_result.crf, crf_result.info

      # This is the index of the node at time t
      index = info.new_indices[0]

      # Get the messages (recomputed here so they are differentiated by the jvp)
      fwd, bwd = crf.get_forward_messages(), crf.get_backward_messages()
      messages = Messages(fwd, bwd)
      message_t = messages[index]

      # Get the marginal distribution and sample from it
      marginals = crf.get_marginals(messages=messages)
      pxt = marginals[index]
      xt = pxt.sample(key)
      return xt, (pxt, messages, message_t)

    xt, dxtdt, (pxt, messages, message_t) = jax.jvp(sample, (t,), (jnp.ones_like(t),), has_aux=True)
    fwd, bwd = message_t.fwd, message_t.bwd

    score = pxt.score(xt)
    fwd_score, bwd_score = fwd.score(xt), bwd.score(xt)
    drift = self.get_drift(t, xt, messages=messages)
    return dict(xt=xt,
                flow=dxtdt,
                score=score,
                drift=drift,
                pxt=pxt,
                fwd_score=fwd_score,
                bwd_score=bwd_score)

  def multi_sample_matching_items(
    self,
    ts: Float[Array, 'T'],
    key: PRNGKeyArray,
    *,
    messages: Optional[Messages] = None
  ) -> 'FlowItems':
    """Sample one trajectory and return per-node FlowItems.

    The model is discretized at `ts` and a trajectory is sampled from the CRF
    with `key`.  At each discretized time the following are evaluated:
    flow = vt + 0.5 L L^T (bwd_score - fwd_score), drift = vt + L L^T bwd_score,
    and the marginal score.

    NOTE: `messages` is currently ignored; messages are always recomputed for
    the discretized CRF.
    """

    # Discretize the SDE at the self.times and ts
    crf_result = self.discretize(ts)
    crf, info = crf_result.crf, crf_result.info

    # Get the forward and backward messages
    fwd = crf.get_forward_messages()
    bwd = crf.get_backward_messages()
    messages = Messages(fwd, bwd)

    # Sample a trajectory and get the marginals
    xts = crf.sample(key, messages=messages)
    marginals = crf.get_marginals(messages=messages)

    # Compute the flow at each point
    def get_items(t, xt, fwd, bwd, pxt):
      vt = self.sde.get_drift(t, xt)
      L = self.sde.get_diffusion_coefficient(t, xt)
      LLT = L@L.T
      fwd_score, bwd_score = fwd.score(xt), bwd.score(xt)
      flow = vt + 0.5*LLT@(bwd_score - fwd_score)
      drift = vt + LLT@bwd_score
      return FlowItems(t=t,
                       xt=xt,
                       flow=flow,
                       score=pxt.score(xt),
                       drift=drift,
                       fwd_score=fwd_score,
                       bwd_score=bwd_score)

    items = jax.vmap(get_items)(info.ts, xts, fwd, bwd, marginals)
    return items

class FlowItems(AbstractBatchableObject):
  """Per-time quantities evaluated along a sample of a conditioned linear SDE.

  `ConditionedLinearSDE.multi_sample_matching_items` builds these by vmapping
  over time, so every field can carry leading batch axes.
  """
  t: Scalar  # evaluation time
  xt: Float[Array, 'D']  # state at time t; the last axis is the state dimension
  flow: Float[Array, 'D']  # flow vector field evaluated at xt
  score: Float[Array, 'D']  # score of the marginal at xt
  drift: Float[Array, 'D']  # drift evaluated at xt
  fwd_score: Float[Array, 'D']  # score of the forward message at xt
  bwd_score: Float[Array, 'D']  # score of the backward message at xt

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Leading batch shape, inferred from `xt`.

    Returns:
      None when `xt` is 1-D (unbatched), an int when `xt` is 2-D, and the tuple
      of all leading axes when `xt` has more than two dimensions.

    Raises:
      ValueError: if `xt` is 0-D.
    """
    n_axes = self.xt.ndim
    if n_axes < 1:
      raise ValueError(f'xt has {self.xt.ndim} dimensions')
    if n_axes == 1:
      return None
    leading_axes = self.xt.shape[:-1]
    return leading_axes[0] if n_axes == 2 else leading_axes

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc debugging script: encode a dumped time series, condition a tracking
  # SDE on it, sample matching items at upsampled times, then open pdb.
  # `from debug import *` comes first so the explicit imports below keep the
  # final say over any names it might also export.
  from debug import *
  import pickle
  import matplotlib.pyplot as plt
  import diffusion_crf.util as util
  from diffusion_crf.sde.sde_base import linear_sde_test
  from diffusion_crf.gaussian.dist import MixedGaussian
  from diffusion_crf.timeseries import TimeSeries
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.tracking import HigherOrderTrackingModel
  jax.config.update('jax_enable_x64', True)

  # NOTE: pickle.load can execute arbitrary code -- only load trusted dumps.
  # Use a context manager so the file handle is closed after loading.
  with open('data_dump.pkl', 'rb') as f:
    data = pickle.load(f)
  ts, yts, observation_mask = data['ts'], data['yts'], data['observation_mask']
  series = TimeSeries(ts, yts, observation_mask)

  y_dim = series.observation_dim
  key = jax.random.PRNGKey(0)

  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.01)
  decoder = PaddingLatentVariableDecoder(y_dim=y_dim,
                                         x_dim=sde.dim)
  prob_series = encoder(series)
  cond_sde = ConditionedLinearSDE(sde, prob_series)

  # Kept as locals for inspection in the debugger below.
  crf = cond_sde.discretize()

  new_times = util.get_times_to_interleave_for_upsample(ts, 1)
  items = cond_sde.multi_sample_matching_items(new_times, key)

  import pdb; pdb.set_trace()