import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, build_conditional_sde, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times, ODESolverParams, ode_solve
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.ho_models.ho_base import AbstractBwdPredictorModel
from Models.models.base import CRFState, AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

################################################################################################################

class AbstractAutoregressiveBwdModel(AbstractBwdPredictorModel, abc.ABC):
  """Abstract autoregressive model that predicts the backward messages of a CRF.

  Subclasses implement `predict_next_backward_messages`.  This class combines those
  predictions with the base transitions of the CRF built from `linear_sde` and `encoder`
  in order to (1) form next-state distributions, (2) compute a mean-matching training
  loss against the exact backward messages and (3) sample a latent series autoregressively.
  """

  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True) # If we don't predict the covariance, then we will solve for it by running message passing.
                                             # This is completely valid if we assume that the potential functions have a covariance
                                             # function that only depends on the time and not the value of our data.

  def make_crf_state(self, yts: TimeSeries) -> CRFState:
    """Make a CRF state for a given observed series.  Assume that yts is the entire observed (but maybe masked)
    sequence so that we know at what times we get potentials at.

    Arguments:
      yts: The observed series.  Must have length `self.obs_seq_len`.

    Returns:
      The `CRFState` (in the 'mixed' parameterization) built from the fully observed version of yts.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.'
    yts = self.make_fully_observed_series(yts)
    return CRFState(self.linear_sde, self.encoder, yts, self.get_discretization_info(yts.ts), parameterization='mixed')

  @abc.abstractmethod
  def predict_next_backward_messages(
    self,
    yts: TimeSeries,
    xts: TimeSeries,
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Compute beta_{t_{k+1}}(x_{1:k})

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional encoder context that is passed to the decoder network.

    Returns:
      A batch of potentials.  The callers in this class use it as a batch of
      `self.generation_len - 1` potentials (one for every point of xts after the first).
    """

  def predict_current_backward_messages(
    self,
    intermediate_xts: TimeSeries,
    yts: TimeSeries,
    xts: TimeSeries,
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Compute beta_{t}(x_{1:k}).  Assumes that each t is in between two points in xts.ts

    Arguments:
      intermediate_xts: Series whose times are the intermediate times t.  Only its `ts` is read here.
      yts: The observed data
      xts: The series that we have already generated
      context: The optional encoder context that is passed to the decoder network.

    Returns:
      One potential per intermediate time, obtained by pushing the predicted message at the
      following time in xts.ts back through the transition of `linear_sde`.
    """
    ts = intermediate_xts.ts

    # Each t represents a time before a time in xts.ts
    assert ts.shape[0] == self.generation_len - 1, f'ts must have length {self.generation_len - 1} but has length {ts.shape[0]}.'
    tkp1 = xts.ts[1:]

    # Predict the next backward messages
    predicted_bwd_node = self.predict_next_backward_messages(yts, xts, context=context)

    def get_continuous_extension(t, t_next, bwd):
      # Transition of the base SDE from t to the next time in xts.ts, combined with the message at that time
      base_transition = self.linear_sde.get_transition_distribution(t, t_next)
      return base_transition.update_and_marginalize_out_y(bwd)

    return jax.vmap(get_continuous_extension)(ts, tkp1, predicted_bwd_node)

  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    crf_state: Optional[CRFState] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.  This is q(x_{i+1} | x_{l:i}, Y_{1:k}).

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      crf_state: The precomputed state for the CRF.  If this is not provided, then we will
                 build it from yts.

    Returns:
      next_distribution: The base transitions from the generation start onwards, updated with
                         the predicted backward messages and conditioned on xts.yts[:-1].
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    if crf_state is None:
      crf_state = self.make_crf_state(yts) # This needs the entire observed sequence only for the times
    crf: CRF = crf_state.crf

    # Construct the predicted backward messages + node potentials
    predicted_bwd_node = self.predict_next_backward_messages(yts, xts, context=context)

    # Compute the predicted next distribution
    base_transitions = crf.base_transitions[self.latent_generation_start_index:]
    updated_transitions = base_transitions.unnormalized_update_y(predicted_bwd_node)
    return updated_transitions.condition_on_x(xts.yts[:-1])

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Compute the training loss for a single observed series.

    A latent series is sampled from the CRF built on yts and the means of the predicted
    "backward message + node potential" terms are regressed onto the means of the exact ones.

    Arguments:
      yts: The observed series.  Must have length `self.obs_seq_len`.
      key: The random key used to sample the latent series.
      debug: If True, drop into pdb after the loss has been computed.

    Returns:
      A dictionary with the single entry 'mse'.

    Raises:
      ValueError: If the discretization times do not have length `self.latent_seq_len`.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'

    # Construct the CRF for the observed series and precompute the forward and backward messages
    crf_state = self.make_crf_state(yts)
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages
    info: DiscretizeInfo = crf_state.discretization_info

    if info.ts.shape[0] != self.latent_seq_len:
      raise ValueError(f'info.ts must have length {self.latent_seq_len} but has length {info.ts.shape[0]}.')

    # Mask the unobserved parts of the series for sanity checking.
    # This doesn't affect anything if we implemented the model correctly.
    yts = self.mask_observation_space_sequence(yts)

    # Sample from p(x_{1:N} | Y_{1:N})
    xts_values = crf.sample(key, messages=messages)
    xts = TimeSeries(info.ts, xts_values)
    xts_generation_buffer = xts[self.latent_generation_start_index:] # Get x_{l:N}

    # Compute the true smoothed transitions, p(x_{i+1} | x_{l:i}, Y_{1:N})
    # and condition on the previous latent variable to get the next distribution
    bwd_and_node = messages.bwd + crf.node_potentials
    true_next_bwd_and_node = bwd_and_node[self.latent_generation_start_index + 1:]
    assert true_next_bwd_and_node.batch_size == self.generation_len - 1

    # Predict our models next backward messages, beta_{t_{k+1}}(x_{1:k})
    predicted_next_bwd_and_node = self.predict_next_backward_messages(yts, xts_generation_buffer)

    # Compute the potential matching loss (only the means are compared)
    true_mean = true_next_bwd_and_node.to_mixed().mu
    predicted_mean = predicted_next_bwd_and_node.to_mixed().mu
    mse_loss = jnp.mean((true_mean - predicted_mean)**2)

    if debug:
      import pdb; pdb.set_trace()

    return dict(mse=mse_loss)

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{l:N} | Y_{1:k}) autoregressively.  This involves first sampling from p(x_l | Y_{1:k}) and then iterating to sample from p(x_{i+1} | x_{l:i}, Y_{1:k})

    Arguments:
      key: The random key.  It is split so that the initial CRF sample and the
           autoregressive steps use independent keys.
      yts: The observed series.  Must have length `self.obs_seq_len`.
      crf_state: Optional precomputed CRF state used to draw the initial latent buffer.
                 If it is not provided, yts is masked and the state is built from it.
      debug: If True, run the sampling loop in Python (instead of `jax.lax.scan`) with pdb breakpoints.
      **kwargs: Ignored.

    Returns:
      The latent series of length `self.latent_seq_len`: the latent points before the
      generation start followed by the autoregressively generated points.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    if crf_state is None:
      yts = self.mask_observation_space_sequence(yts) # This is to ensure that the samples don't depend on unobserved values
      crf_state = self.make_crf_state(yts)
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages
    info: DiscretizeInfo = crf_state.discretization_info

    # Never consume the same key twice: one key for the initial buffer, one for the autoregressive steps
    init_key, steps_key = random.split(key)

    # Generate the buffer of the predicted latent series
    xts_values = crf.sample(init_key, messages=messages)

    # We will assume that the first latent variable comes from the correct prior distribution.
    # See the discussion in the docstring for AbstractAutoregressiveModel for more details.
    mask = jnp.zeros(xts_values.shape[0], dtype=bool)
    mask = mask.at[:self.latent_generation_start_index+1].set(True)

    # Turn xts_values into a TimeSeries
    full_xts_buffer = TimeSeries(info.ts, xts_values, observation_mask=mask)
    xts_buffer = full_xts_buffer[self.latent_generation_start_index:]

    # Encode the encoder series context for our transformer
    yts_cond = yts[:self.cond_len]
    context = self.transformer.create_context(yts_cond)

    # Create a new crf_state that will get the target covariances for us
    dummy_yts = TimeSeries(yts.ts, jnp.zeros_like(yts.yts))
    crf_state = self.make_crf_state(dummy_yts)

    #########################################################
    # Autoregressive sampling
    #########################################################
    def scan_body(carry, inputs, debug=False):
      xts_buffer = carry
      step_key, k = inputs

      # Predict all of the transitions.  The model is autoregressive so we don't need to worry about
      # masking xts_buffer or anything like that.  Also, we cut off the unobserved parts of yts inside
      # the predict_next_state_distribution function.
      transitions = self.predict_next_state_distribution(yts, xts_buffer, context=context, crf_state=crf_state)

      # Get p(x_{i+1} | x_k)
      transition = transitions[k]

      # Sample x_{i+1} ~ p(x_{i+1} | x_k)
      xkp1 = transition.sample(step_key)

      # Update the buffer with the new prediction
      new_yts = util.fill_array(xts_buffer.yts, k+1, xkp1)

      # Update the observation mask to reflect the new prediction
      new_observation_mask = util.fill_array(xts_buffer.observation_mask, k+1, True)

      # Create the new time series
      new_xts = TimeSeries(xts_buffer.ts, new_yts, new_observation_mask)
      if debug:
        import pdb; pdb.set_trace()
      return new_xts, new_xts

    num_steps = len(xts_buffer.ts) - 1
    carry = xts_buffer
    inputs = (random.split(steps_key, num_steps), jnp.arange(num_steps))

    if debug:
      # Plain Python loop so that the breakpoints inside scan_body see concrete values
      for item in zip(*inputs):
        carry, _ = scan_body(carry, item, debug=True)
      generated_latent_seq = carry
      import pdb; pdb.set_trace()
    else:
      generated_latent_seq, _ = jax.lax.scan(scan_body, carry, inputs)

    # Concatenate the original latent sequence as well
    prev_latent_seq = full_xts_buffer[:self.latent_generation_start_index]
    generated_latent_seq = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), prev_latent_seq, generated_latent_seq)

    assert len(generated_latent_seq) == self.latent_seq_len, f'output must have length {self.latent_seq_len} but has length {len(generated_latent_seq)}.'
    return generated_latent_seq

################################################################################################################

class MyReparameterizedAutoregressiveRNNBwdModel(AbstractAutoregressiveBwdModel):
  """Autoregressive backward-message model whose network only predicts the vector part of each potential.

  The matrix part of every predicted potential is not learned: it is taken from the
  backward messages + node potentials of a CRF built from the observed series
  (see `predict_next_backward_messages`).  The network is a `MyGRUModel`.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """Build the model and its GRU network.

    Arguments:
      linear_sde: The linear SDE.  Its `dim` is the latent dimension.
      encoder: The encoder.  Its `y_dim` is the observation dimension.
      hidden_size: The hidden size passed to `MyGRUModelHypers`.
      potential_cov_type: The structure of the potential matrices.  Determines `matrix_dim`.
      parametrization: Which Gaussian parametrization the predicted potentials use.
      interpolation_freq: The interpolation frequency.
      seq_len: The length of the observed series (stored as `obs_seq_len`).
      cond_len: The number of observed time points that we condition on.
      latent_cond_len: Number of extra points, at the original frequency, before the
                       latent prediction start.  Stored multiplied by (1 + interpolation_freq).
      key: The random key used to initialize the network.
      predict_cov: Must be False; predicting the covariance is not supported by this model.

    Raises:
      ValueError: If `potential_cov_type` is not one of the supported values.
    """
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by the interpolation frequency to account for the fact that we are predicting at the original frequency
    self.parametrization = parametrization
    assert predict_cov == False, 'We don\'t support predicting the covariance for the reparametrized model'
    self.predict_cov = predict_cov

    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    print(f'Generation length: {self.generation_len}, latent_seq_len: {self.latent_seq_len}, latent_generation_start_index: {self.latent_generation_start_index}')

    dim = self.linear_sde.dim

    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2
      assert dim%2 == 0
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3
      assert dim%3 == 0
    else:
      # Without this branch an unknown value would only surface as an UnboundLocalError below
      raise ValueError(f'Unknown potential_cov_type: {potential_cov_type!r}')

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the network.  It only outputs a vector of size dim per time step
    # because the matrices are not predicted.
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=dim,
                               hypers=hypers,
                               key=key)

  def predict_next_backward_messages(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Predict the next backward messages + node potentials for the series.

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.

    Returns:
      predicted_bwd_node: A batch of `self.generation_len - 1` potentials in the
                          parametrization given by `self.parametrization`.  Their vector
                          parts are all but the last network output and their matrix parts
                          come from message passing on the CRF built from yts.

    Raises:
      ValueError: If `self.parametrization` is not one of 'nat', 'mixed' or 'std'.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Predict the backward messages parameters
    outputs = self.transformer(yts_cond, xts, context)
    assert outputs.shape[0] == self.generation_len

    #####################################
    # Construct the covariance matrices
    #####################################

    def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Helper function to construct a potential from a vector and a matrix
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      raise ValueError(f'Unknown parametrization: {self.parametrization!r}')

    crf_state = self.make_crf_state(yts) # This needs the entire observed sequence only for the times
    crf: CRF = crf_state.crf

    # Compute the backward messages.  The means of this are not correct
    # but the covariances are correct because of our assumption that the potential
    # functions have a covariance that only depends on the time and not the value of
    # the data.  This assumption works because the smoothed mean has a linear relationship
    # with Y, but the covariance has a nonlinear relationship with Y due to matrix inverses
    # during message passing.
    bwd_node = crf_state.messages.bwd + crf.node_potentials
    bwd_node = bwd_node[self.latent_generation_start_index + 1:]

    # Depending on the parametrization, we need to convert the covariance matrix
    if self.parametrization == 'nat':
      mat = bwd_node.to_nat().J
    elif self.parametrization == 'mixed':
      mat = bwd_node.to_mixed().J
    elif self.parametrization == 'std':
      mat = bwd_node.to_std().Sigma
    else:
      raise ValueError(f'Unknown parametrization: {self.parametrization!r}')

    # Construct the predicted backward messages + node potentials.  The last network
    # output is dropped so that there is one potential per point of xts after the first.
    predicted_bwd_node = jax.vmap(construct_potential)(outputs[:-1], mat)
    assert predicted_bwd_node.batch_size == self.generation_len - 1
    return predicted_bwd_node

################################################################################################################


if __name__ == '__main__':
  # Debug / smoke-test script: builds a small model on a pickled series, evaluates the loss,
  # samples from the model and checks that the predictions are autoregressive.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation
  # turn on x64
  # jax.config.update('jax_enable_x64', True)

  N = 20
  # NOTE: pickle must only be used on trusted files; series.pkl is expected to be a local artifact.
  # The context manager makes sure that the file handle is closed.
  with open('series.pkl', 'rb') as f:
    series = pickle.load(f)[:N]

  # Only keep the first two observed dimensions
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 0

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1,
                                                  use_prior=False)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # Pick the potential covariance structure that matches the matrix type of the SDE
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
  else:
    # Otherwise potential_cov_type would be unbound and fail with a NameError below
    raise TypeError(f'Unsupported matrix type for sde.F: {type(sde.F).__name__}')

  # model = MyAutoregressiveModel(sde,
  model = MyReparameterizedAutoregressiveRNNBwdModel(sde,
                                encoder,
                                hidden_size=16,
                                potential_cov_type=potential_cov_type,
                                parametrization='mixed',
                                interpolation_freq=freq,
                                seq_len=len(series),
                                cond_len=6,
                                latent_cond_len=1,
                                key=key,
                                predict_cov=False)

  # Evaluate the loss function (debug=True drops into pdb inside loss_fn)
  out = model.loss_fn(series, key=key, debug=True)

  # Sample a series
  sampled_series = model.sample(key, series, debug=False)

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  xts_generated = sampled_series[model.latent_generation_start_index:]
  def output(xts_values):
    # Means of the predicted next-state distributions as a function of the generated latent values
    xts = TimeSeries(ts=xts_generated.ts, yts=xts_values)
    out_dist = model.predict_next_state_distribution(series, xts)
    return out_dist.mu

  means = output(xts_generated.yts)

  # Compute the Jacobian
  J = eqx.filter_jacfwd(output)(xts_generated.yts)
  J_flat = J.sum(axis=(1, 3))

  # Check that the Jacobian is lower triangular, i.e. that no prediction depends on a later latent point
  assert jnp.allclose(J_flat, jnp.tril(J_flat))


  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)
  assert len(latent_seq) == len(series)
  import pdb; pdb.set_trace()

