import jax
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
import diffusion_crf
import jax.random as random
import jax.numpy as jnp
from diffusion_crf.base import AbstractBatchableObject, AbstractPotential
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries
from diffusion_crf.sde import AbstractLinearSDE, BrownianMotion, OrnsteinUhlenbeck, CriticallyDampedLangevinDynamics
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.crf import Messages
from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
import equinox as eqx
from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior, IdentityEncoder

class StochasticLatentInterpolator(AbstractBatchableObject):
  """Samples latent paths of a linear SDE conditioned on encoded observations.

  The observed series is passed through `encoder` to obtain a probabilistic
  series, which is used to condition `linear_sde` (via `ConditionedLinearSDE`).
  """

  # Prior (unconditioned) linear SDE over the latent path.
  linear_sde: AbstractLinearSDE
  # Maps an observed `TimeSeries` to the probabilistic series used for conditioning.
  encoder: AbstractEncoder

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size of this object, delegated to the underlying linear SDE."""
    return self.linear_sde.batch_size

  def interpolate(self,
                  key: PRNGKeyArray,
                  ts: Float[Array, 'S'],
                  yts: Float[Array, 'S D'],
                  interpolation_times: Float[Array, 'TminusS'],
                  return_backwards: bool = False,
                  return_potentials: bool = False) -> Tuple[Any, ...]:
    """Sample the conditioned SDE at the observation times and at extra times.

    **Arguments**:
      - key: PRNG key used to draw the sample.
      - ts: The times at which the original data is observed.
      - yts: The original data observed at the times in ts.
      - interpolation_times: The extra times at which we want to know about the interpolation.
           Note that this function returns samples at the union of ts and interpolation_times.
      - return_backwards: If True, also return the backward messages of the
        discretized CRF.
      - return_potentials: If True, also return the node potentials produced by
        the encoder.

    **Returns**: A tuple whose entries are, in order:
      - all_ts: The interleaved times reported by the discretization.
      - xts: The sample drawn from the discretized CRF.
      - bwd: The backward messages (only if `return_backwards` is True).
      - potentials: The encoder's node potentials (only if `return_potentials`
        is True).
    """
    ts = jnp.array(ts)
    yts = jnp.array(yts)
    interpolation_times = jnp.array(interpolation_times)

    # Encode the observations into potentials on the latent variables.
    series = TimeSeries(ts, yts)
    prob_series = self.encoder(series)
    potentials = prob_series.node_potentials

    # Construct the CRF p(x | \Theta(Y)).
    # The parallel code path is only requested when the first device is a GPU.
    parallel = jax.devices()[0].platform == 'gpu'
    cond_sde = ConditionedLinearSDE(self.linear_sde, prob_series, parallel=parallel)

    # Discretize the CRF so that we get a sequence of xts at the times in prob_series.ts
    # and also the times in interpolation_times
    result = cond_sde.discretize(interpolation_times)
    crf, interleaved_times = result.crf, result.info
    all_ts = interleaved_times.ts

    # Compute the backward messages
    bwd = crf.get_backward_messages()
    messages = Messages(fwd=None, bwd=bwd)

    # Sample from the conditioned SDE
    xts = crf.sample(key, messages=messages)

    # Optional extras are appended in a fixed order: bwd first, then potentials.
    outputs = (all_ts, xts)
    if return_backwards:
      outputs = outputs + (bwd,)
    if return_potentials:
      outputs = outputs + (potentials,)
    return outputs

def make_stochastic_interpolator(*,
                                 base_sde: Literal['brownian',
                                                   'ornstein_uhlenbeck'],
                                 sde_args: Dict[str, Any],
                                 observation_dim: int,
                                 sigma: Union[None,float] = None):
  """Build an interpolator whose latent dimension equals the observation dimension.

  Thin convenience wrapper around `make_latent_stochastic_interpolator`; all
  arguments are forwarded unchanged, with `latent_dim` set to `observation_dim`.
  """
  # Latent and observed spaces share one dimensionality here.
  shared_dim = observation_dim
  return make_latent_stochastic_interpolator(
    base_sde=base_sde,
    sde_args=sde_args,
    observation_dim=shared_dim,
    latent_dim=shared_dim,
    sigma=sigma,
  )

def make_latent_stochastic_interpolator(*,
                                        base_sde: Literal['brownian',
                                                          'ornstein_uhlenbeck',
                                                          'langevin'],
                                        sde_args: Dict[str, Any],
                                        observation_dim: int,
                                        latent_dim: int,
                                        sigma: Union[None,float] = None):
  """Build a `StochasticLatentInterpolator` from a named base SDE.

  **Arguments**:
    - base_sde: Which linear SDE to use: 'brownian', 'ornstein_uhlenbeck' or
      'langevin'.
    - sde_args: Parameters of the SDE. Keys read: 'sigma' for 'brownian';
      'sigma' and 'lambda' for 'ornstein_uhlenbeck'; 'mass' and 'beta' for
      'langevin'.
    - observation_dim: Dimension of the observed data.
    - latent_dim: Dimension passed to the SDE as `dim`.
    - sigma: If None, an `IdentityEncoder` is used (requires
      `observation_dim == latent_dim`); otherwise it is forwarded to
      `PaddingLatentVariableEncoderWithPrior`.

  **Returns**: A `StochasticLatentInterpolator` wrapping the SDE and encoder.

  **Raises**:
    - ValueError: If `sigma` is None while the dimensions differ, or if
      `base_sde` is not one of the supported names.
  """
  if sigma is None:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if observation_dim != latent_dim:
      raise ValueError(
        "observation_dim must equal latent_dim when sigma is None "
        f"(got observation_dim={observation_dim}, latent_dim={latent_dim})")
    encoder = IdentityEncoder(observation_dim)
  else:
    encoder = PaddingLatentVariableEncoderWithPrior(y_dim=observation_dim,
                                                  x_dim=latent_dim,
                                                  sigma=sigma)

  # Get the linear SDE
  if base_sde == 'brownian':
    linear_sde = BrownianMotion(sigma=sde_args['sigma'], dim=latent_dim)
  elif base_sde == 'ornstein_uhlenbeck':
    linear_sde = OrnsteinUhlenbeck(sigma=sde_args['sigma'],
                            lambda_=sde_args['lambda'],
                            dim=latent_dim)
  elif base_sde == 'langevin':
    linear_sde = CriticallyDampedLangevinDynamics(mass=sde_args['mass'],
                                            beta=sde_args['beta'],
                                            dim=latent_dim)
  else:
    raise ValueError(f"Invalid base SDE: {base_sde}")

  return StochasticLatentInterpolator(linear_sde, encoder)

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc debugging script: loads a pickled series from the working directory,
  # builds an interpolator two ways (directly and through the factory), runs
  # `interpolate` on each and drops into pdb for manual inspection.
  from debug import *
  from diffusion_crf.sde.sde_base import linear_sde_test
  import matplotlib.pyplot as plt
  import diffusion_crf.util as util
  from diffusion_crf.gaussian.dist import MixedGaussian
  from diffusion_crf.timeseries import TimeSeries
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  import pickle
  from diffusion_crf.sde import *

  N = 5
  # NOTE(review): unpickling can execute arbitrary code; only run this against
  # a trusted, locally produced 'series.pkl'.
  # The context manager makes sure the file handle is closed after loading.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Keep only the first two columns of the observations and of their mask.
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  y_dim = series.observation_dim

  sde = BrownianMotion(sigma=0.1, dim=y_dim)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.01)
  decoder = PaddingLatentVariableDecoder(y_dim=y_dim,
                                        x_dim=sde.dim)

  interpolator = StochasticLatentInterpolator(sde, encoder)

  # Make times to interpolate between
  interpolation_times = util.get_times_to_interleave_for_upsample(data_times, 2)
  # Interactive stop to inspect the inputs before sampling.
  import pdb; pdb.set_trace()

  key = random.PRNGKey(0)
  result = interpolator.interpolate(key, data_times, yts, interpolation_times)

  # Same experiment, this time constructing the interpolator through the factory.
  interpolator = make_stochastic_interpolator(base_sde='brownian',
                                              sde_args=dict(sigma=0.1),
                                              observation_dim=y_dim,
                                              sigma=0.01)

  out = interpolator.interpolate(key, data_times, yts, interpolation_times)

  # Interactive stop to inspect `result` and `out`.
  import pdb; pdb.set_trace()
