import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, build_conditional_sde, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times, ODESolverParams, ode_solve
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

"""
This module implements autoregressive models for time series forecasting within the neural diffusion CRF framework. It provides:

1. AbstractAutoregressiveModel - A base class defining the autoregressive sampling and training process
   where each state is predicted conditioned on previous states and observations

2. MyAutoregressiveModel - A concrete implementation using a transformer-based architecture
   with support for different covariance parameterizations and prediction strategies

The autoregressive approach generates samples by sequentially predicting distributions for
each time step conditioned on previous states, allowing for efficient one-pass generation
while maintaining consistency with the underlying SDE dynamics.
"""

################################################################################################################

class AbstractAutoregressiveModel(AbstractModel, abc.ABC):
  r"""Autoregressive model for q(x_{l:N} | Y_{1:k}).  We generate samples from the process:
          x_l ~ q(x_l | Y_{1:k})
          x_{i+1} ~ q(x_{i+1} | x_{l:i}, Y_{1:k}) = N(x_{i+1}; \mu_{i+1}(x_{l:i}, Y_{1:k}), \Sigma_{i+1}(x_{l:i}, Y_{1:k}))

  x_l is the l'th element of the latent sequence and is found by latent_seq_len - generation_len

  We don't explicitly model the prior distribution but instead use an approximation of it by
  sampling from the marginal distribution of a CRF constructed with node potentials coming
  from only Y_{1:k}.  Technically this is wrong because the true distribution is
  p(x_l | Y_{1:k}) = \int p(x_l | Y_{1:N})\mu(Y_{k+1:N} | Y_{1:k}) dY_{k+1:N},
  but we should expect this to be a good approximation because the value of x_l should be
  determined pretty well by Y_{1:k} if k is large enough.

  We also might not actually predict the covariance matrix but instead use the
  covariance coming from the CRF.  This is valid if we assume that the potential
  functions have a covariance that only depends on the time and not the value of
  our data.

  NOTE: this docstring is a raw string because it contains LaTeX-style backslashes
  that would otherwise be invalid escape sequences.
  """
  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True) # If we don't predict the covariance, then we will solve for it by running message passing.
                                             # This is completely valid if we assume that the potential functions have a covariance
                                             # function that only depends on the time and not the value of our data.

  @abc.abstractmethod
  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    crf_state: Optional[CRFState] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.  This is q(x_{i+1} | x_{l:i}, Y_{1:k}).

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      crf_state: The precomputed state for the CRF.  If this is not provided, then we will
                 run message passing to compute it.

    Returns:
      A batched potential with one entry per transition of xts.  loss_fn compares it
      against the true transitions, which have batch size self.generation_len - 1, and
      evaluates its log_prob on xts.yts[1:], so implementations are expected to return
      self.generation_len - 1 entries.
    """
    raise NotImplementedError

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Compute the training losses for a single (non-upsampled) observed series.

    Arguments:
      yts: The observed series.  Must have length self.obs_seq_len and feature
           dimension self.encoder.y_dim.
      key: PRNG key used to sample the latent sequence from the CRF.
      debug: If True, drop into pdb right before the losses are returned.

    Returns:
      A dictionary with two scalar entries:
        ml:  The negative mean log probability of the sampled latents under the
             predicted next state distributions.
        mse: The mean squared difference of the means plus the mean squared
             difference of the J matrix elements (both taken in the mixed
             parametrization) between the true and predicted distributions.

    Raises:
      ValueError: If the discretized times do not have length self.latent_seq_len.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'

    # Construct the CRF for the observed series and precompute the forward and backward messages
    crf_state = self.make_crf_state(yts)
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages
    info: DiscretizeInfo = crf_state.discretization_info

    if info.ts.shape[0] != self.latent_seq_len:
      raise ValueError(f'info.ts must have length {self.latent_seq_len} but has length {info.ts.shape[0]}.')

    # Mask the unobserved parts of the series for sanity checking.
    # This doesn't affect anything if we implemented the model correctly.
    yts = self.mask_observation_space_sequence(yts)

    # Sample from p(x_{1:N} | Y_{1:N})
    xts_values = crf.sample(key, messages=messages)
    xts = TimeSeries(info.ts, xts_values)
    xts_generation_buffer = xts[self.latent_generation_start_index:] # Get x_{l:N}

    # Compute the true smoothed transitions, p(x_{i+1} | x_{l:i}, Y_{1:N})
    # and condition on the previous latent variable to get the next distribution
    true_transitions = crf.get_transitions(messages=messages)[self.latent_generation_start_index:]
    true_next_distribution = true_transitions.condition_on_x(xts_generation_buffer.yts[:-1])
    assert true_next_distribution.batch_size == self.generation_len - 1

    # Predict our models next state distribution, q(x_{i+1} | x_{l:i}, Y_{1:k})
    # We cut off the unobserved parts of yts inside the predict_next_state_distribution function
    predicted_next_distribution = self.predict_next_state_distribution(yts, xts_generation_buffer, crf_state=crf_state)

    # Compute the potential matching loss
    true_next_distribution_mixed = true_next_distribution.to_mixed()
    predicted_next_distribution_mixed = predicted_next_distribution.to_mixed()
    true_mean, predicted_mean = true_next_distribution_mixed.mu, predicted_next_distribution_mixed.mu
    true_J, predicted_J = true_next_distribution_mixed.J, predicted_next_distribution_mixed.J

    J_diff = (true_J - predicted_J).elements
    mu_diff = (true_mean - predicted_mean)
    mse_loss = jnp.mean(mu_diff**2) + jnp.mean(J_diff**2)

    # Compute the log likelihood of the observed values under our model
    log_probs = predicted_next_distribution.log_prob(xts_generation_buffer.yts[1:])

    ml_loss = -log_probs.mean()

    if debug:
      import pdb; pdb.set_trace()

    losses = dict(ml=ml_loss, mse=mse_loss)
    return losses

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{l:N} | Y_{1:k}) autoregressively.  This involves first sampling from p(x_l | Y_{1:k}) and then iterating to sample from p(x_{i+1} | x_{l:i}, Y_{1:k})

    Arguments:
      key: PRNG key.  It is split into one key for the initial CRF sample and one
           key per autoregressive step so that no key is ever used twice.
      yts: The observed (non-upsampled) series of length self.obs_seq_len.
      crf_state: Optional precomputed CRF state used to draw the initial latent
                 sequence.  If it is not provided, yts is masked and the state is
                 built from the masked series.
      debug: If True, run the sampling loop in Python (instead of jax.lax.scan)
             and drop into pdb after every step.
      **kwargs: Ignored.

    Returns:
      The latent TimeSeries of length self.latent_seq_len.  Entries before
      self.latent_generation_start_index come from the initial CRF sample and
      the remaining entries are generated autoregressively.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    if crf_state is None:
      yts = self.mask_observation_space_sequence(yts) # This is to ensure that the samples don't depend on unobserved values
      crf_state = self.make_crf_state(yts)
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages
    info: DiscretizeInfo = crf_state.discretization_info

    # Never consume the same key twice: use one sub-key for the initial CRF sample
    # and derive the per-step keys from the other.  Reusing `key` for both would
    # correlate the step keys with the random bits of the initial sample.
    init_key, scan_key = random.split(key)

    # Generate the buffer of the predicted latent series
    xts_values = crf.sample(init_key, messages=messages)

    # We will assume that the first latent variable comes from the correct prior distribution.
    # See the discussion in the docstring for AbstractAutoregressiveModel for more details.
    mask = jnp.zeros(xts_values.shape[0], dtype=bool)
    mask = mask.at[:self.latent_generation_start_index+1].set(True)

    # Turn xts_values into a TimeSeries
    full_xts_buffer = TimeSeries(info.ts, xts_values, observation_mask=mask)
    xts_buffer = full_xts_buffer[self.latent_generation_start_index:]

    # Encode the encoder series context for our transformer
    yts_cond = yts[:self.cond_len]
    context = self.transformer.create_context(yts_cond)

    # Create a new crf_state that will get the target covariances for us
    dummy_yts = TimeSeries(yts.ts, jnp.zeros_like(yts.yts))
    crf_state = self.make_crf_state(dummy_yts)

    #########################################################
    # Autoregressive sampling
    #########################################################
    def scan_body(carry, inputs, debug=False):
      xts_buffer = carry
      step_key, k = inputs

      # Predict all of the transitions.  The model is autoregressive so we don't need to worry about
      # masking xts_buffer or anything like that.  Also, we cut off the unobserved parts of yts inside
      # the predict_next_state_distribution function.
      transitions = self.predict_next_state_distribution(yts, xts_buffer, context=context, crf_state=crf_state)

      # Get p(x_{i+1} | x_k)
      transition = transitions[k]

      # Sample x_{i+1} ~ p(x_{i+1} | x_k)
      xkp1 = transition.sample(step_key)

      # Update the buffer with the new prediction
      new_yts = util.fill_array(xts_buffer.yts, k+1, xkp1)

      # Update the observation mask to reflect the new prediction
      new_observation_mask = util.fill_array(xts_buffer.observation_mask, k+1, True)

      # Create the new time series
      new_xts = TimeSeries(xts_buffer.ts, new_yts, new_observation_mask)
      if debug:
        import pdb; pdb.set_trace()

      # Only the final buffer is needed, so don't stack the intermediate buffers
      return new_xts, None

    carry = xts_buffer
    num_steps = len(xts_buffer.ts) - 1
    keys = random.split(scan_key, num_steps)
    inputs = (keys, jnp.arange(num_steps))

    if debug:
      for item in zip(*inputs):
        carry, _ = scan_body(carry, item, debug=True)
      generated_latent_seq = carry
      import pdb; pdb.set_trace()
    else:
      generated_latent_seq, _ = jax.lax.scan(scan_body, carry, inputs)

    # Concatenate the original latent sequence as well
    prev_latent_seq = full_xts_buffer[:self.latent_generation_start_index]
    generated_latent_seq = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), prev_latent_seq, generated_latent_seq)

    assert len(generated_latent_seq) == self.latent_seq_len, f'output must have length {self.latent_seq_len} but has length {len(generated_latent_seq)}.'
    return generated_latent_seq

################################################################################################################

class MyAutoregressiveModel(AbstractAutoregressiveModel):
  """Autoregressive model whose next state distributions are parameterized by a
  MyModel encoder/decoder network.

  If predict_cov is False the network only outputs a vector per time step and the
  matrix of each potential is taken from the smoothed CRF transitions.  Otherwise
  the network outputs both the vector and the matrix elements.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """Build the model and its network.

    Arguments:
      linear_sde: The linear SDE.  Its dim is used as the latent dimension.
      encoder: The encoder.  Its y_dim is used as the conditioning input size of the network.
      n_layers: Passed to MyModelHypers as n_blocks.
      filter_width: Passed to MyModelHypers as wavenet_kernel_width.
      hidden_channel_size: Passed to MyModelHypers as hidden_channel_size.
      num_transformer_heads: Passed to MyModelHypers as num_transformer_heads.
      potential_cov_type: Structure of the predicted matrices.
      parametrization: Which Gaussian parametrization the network outputs are interpreted in.
      interpolation_freq: Stored as is; also used to rescale latent_cond_len.
      seq_len: Stored as obs_seq_len.
      cond_len: The number of observed time points that we condition on.
      latent_cond_len: Given at the original frequency; stored multiplied by (1 + interpolation_freq).
      key: PRNG key used to initialize the network.
      predict_cov: Whether the network also outputs the matrix elements.

    Raises:
      ValueError: If potential_cov_type or parametrization is not one of the supported values.
    """
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by the interpolation frequency to account for the fact that we are predicting at the original frequency
    self.parametrization = parametrization
    self.predict_cov = predict_cov

    # Fail early with a clear message instead of an obscure error at prediction time
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f'Unknown parametrization: {parametrization!r}.')

    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    print(f'Generation length: {self.generation_len}, latent_seq_len: {self.latent_seq_len}, latent_generation_start_index: {self.latent_generation_start_index}')

    dim = self.linear_sde.dim

    # Number of network outputs needed to fill one matrix of the given structure
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2
      assert dim%2 == 0
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3
      assert dim%3 == 0
    else:
      # Previously this fell through and surfaced as an UnboundLocalError on matrix_dim
      raise ValueError(f'Unknown potential_cov_type: {potential_cov_type!r}.')

    if not self.predict_cov:
      # We'll compute the covariances using message passing.  This is valid
      # when our potential functions have a covariance that only depends on the
      # time and not the value of the data.
      out_size = dim
    else:
      # Parametrizing a transition requires 1 matrix + 1 vector
      out_size = matrix_dim + dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the transformer
    hypers = MyModelHypers(n_blocks=n_layers,
                           hidden_channel_size=hidden_channel_size,
                           wavenet_kernel_width=filter_width,
                           num_transformer_heads=num_transformer_heads)
    self.transformer = MyModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=out_size,
                               hypers=hypers,
                               key=key,
                               strict_autoregressive=False,
                               causal_decoder=True)

  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    crf_state: Optional[CRFState] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      crf_state: The precomputed state for the CRF.  Only used when predict_cov is
                 False.  If it is not provided, it is built from yts.

    Returns:
      A batched potential with self.generation_len - 1 entries.  The network
      produces self.generation_len outputs but the last one has no transition
      to parameterize, so it is dropped.

    Raises:
      ValueError: If self.parametrization is not one of 'nat', 'mixed' or 'std'.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Predict the parameters of the next state distributions
    outputs = self.transformer(yts_cond, xts, context)
    assert outputs.shape[0] == self.generation_len

    #####################################
    # Construct the covariance matrices
    #####################################

    def construct_matrix(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Helper function to construct a potential from a vector and a matrix
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      else:
        # Without this branch the helper would silently return None
        raise ValueError(f'Unknown parametrization: {self.parametrization!r}.')

    if not self.predict_cov:
      if crf_state is None:
        crf_state = self.make_crf_state(yts) # This needs the entire observed sequence only for the times
      crf: CRF = crf_state.crf
      messages: Messages = crf_state.messages

      # Compute the smoothed transition distributions.  The means of this are not correct
      # but the covariances are correct because of our assumption that the potential
      # functions have a covariance that only depends on the time and not the value of
      # the data.  This assumption works because the smoothed mean has a linear relationship
      # with Y, but the covariance has a nonlinear relationship with Y due to matrix inverses
      # during message passing.
      true_transitions = crf.get_transitions(messages=messages)[self.latent_generation_start_index:]
      true_next_distribution = true_transitions.condition_on_x(xts.yts[:-1])

      # Depending on the parametrization, we need to convert the covariance matrix
      if self.parametrization == 'nat':
        mat = true_next_distribution.to_nat().J
      elif self.parametrization == 'mixed':
        mat = true_next_distribution.to_mixed().J
      elif self.parametrization == 'std':
        mat = true_next_distribution.to_std().Sigma
      else:
        raise ValueError(f'Unknown parametrization: {self.parametrization!r}.')

      # We don't have a last transition
      out_potentials = jax.vmap(construct_matrix)(outputs[:-1], mat)

    else:

      # Retrieve the parameters of the potential distribution
      vec, unscaled_mat_elements = outputs[...,:self.dim], outputs[...,self.dim:]

      ############################
      # Scale the matrix elements to ensure that
      # they are able to represent small covariance values
      ############################
      # Ensure that the max value isn't too large
      def flipped_softplus(x):
        return -jax.nn.softplus(-x)

      # Eyeballed these values in desmos to get a good fit
      a = 1.6
      b = -10.8
      c = 1.4
      mat_elements = jnp.exp(flipped_softplus(a*unscaled_mat_elements + b) + c)

      # Reshape the matrix to the correct shape
      if self.potential_cov_type == 'diagonal':
        pass # Nothing to do
      elif self.potential_cov_type == 'dense':
        mat_elements = mat_elements.reshape((-1, self.dim, self.dim))
      elif self.potential_cov_type == 'block2x2':
        mat_elements = mat_elements.reshape((-1, 2, 2, self.dim//2))
      elif self.potential_cov_type == 'block3x3':
        mat_elements = mat_elements.reshape((-1, 3, 3, self.dim//3))

      mat = jax.vmap(partial(util.to_matrix, symmetric=True))(mat_elements)

      # Add a bit of jitter to the covariance matrix to ensure that it has full rank
      jitter = 1e-8*mat.eye(mat.shape[1])
      mat = mat + jitter

      out_potentials = jax.vmap(construct_matrix)(vec, mat)

      # We don't have a last transition.  This is the correct way to
      # slice as well in order to ensure that the Jacobian is lower
      # triangular.
      out_potentials = out_potentials[:-1]

    assert out_potentials.batch_size == self.generation_len - 1
    return out_potentials

################################################################################################################

class MyReparameterizedAutoregressiveModel(AbstractAutoregressiveModel):
  """Autoregressive model that predicts the vector of the (backward message + node
  potential) at every step and combines it with the base transitions of the CRF to
  obtain the next state distribution.

  The matrices of the predicted potentials are taken from the CRF, so this model
  never predicts covariances (predict_cov must be False).
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """Build the model and its network.

    Arguments:
      linear_sde: The linear SDE.  Its dim is used as the latent dimension.
      encoder: The encoder.  Its y_dim is used as the conditioning input size of the network.
      n_layers: Passed to MyModelHypers as n_blocks.
      filter_width: Passed to MyModelHypers as wavenet_kernel_width.
      hidden_channel_size: Passed to MyModelHypers as hidden_channel_size.
      num_transformer_heads: Passed to MyModelHypers as num_transformer_heads.
      potential_cov_type: Structure of the matrices.  Only used to compute matrix_dim here.
      parametrization: Which Gaussian parametrization the network outputs are interpreted in.
      interpolation_freq: Stored as is; also used to rescale latent_cond_len.
      seq_len: Stored as obs_seq_len.
      cond_len: The number of observed time points that we condition on.
      latent_cond_len: Given at the original frequency; stored multiplied by (1 + interpolation_freq).
      key: PRNG key used to initialize the network.
      predict_cov: Must be False.

    Raises:
      ValueError: If potential_cov_type or parametrization is not one of the supported values.
    """
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by the interpolation frequency to account for the fact that we are predicting at the original frequency
    self.parametrization = parametrization
    assert not predict_cov, 'We don\'t support predicting the covariance for the reparametrized model'
    self.predict_cov = predict_cov

    # Fail early with a clear message instead of an obscure error at prediction time
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f'Unknown parametrization: {parametrization!r}.')

    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    print(f'Generation length: {self.generation_len}, latent_seq_len: {self.latent_seq_len}, latent_generation_start_index: {self.latent_generation_start_index}')

    dim = self.linear_sde.dim

    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2
      assert dim%2 == 0
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3
      assert dim%3 == 0
    else:
      # Previously this fell through and surfaced as an UnboundLocalError on matrix_dim
      raise ValueError(f'Unknown potential_cov_type: {potential_cov_type!r}.')

    # Only the vector is predicted; the matrices come from the CRF
    out_size = dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the transformer
    hypers = MyModelHypers(n_blocks=n_layers,
                           hidden_channel_size=hidden_channel_size,
                           wavenet_kernel_width=filter_width,
                           num_transformer_heads=num_transformer_heads)
    self.transformer = MyModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=out_size,
                               hypers=hypers,
                               key=key,
                               strict_autoregressive=False,
                               causal_decoder=True)

  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    crf_state: Optional[CRFState] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      crf_state: The precomputed state for the CRF.  If it is not provided, it is
                 built from yts.

    Returns:
      A batched potential with self.generation_len - 1 entries, obtained by updating
      the base transitions of the CRF with the predicted (backward message + node
      potential) and conditioning on xts.yts[:-1].

    Raises:
      ValueError: If self.parametrization is not one of 'nat', 'mixed' or 'std'.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Predict the backward messages parameters
    outputs = self.transformer(yts_cond, xts, context)
    assert outputs.shape[0] == self.generation_len

    #####################################
    # Construct the covariance matrices
    #####################################

    def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Helper function to construct a potential from a vector and a matrix
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      else:
        # Without this branch the helper would silently return None
        raise ValueError(f'Unknown parametrization: {self.parametrization!r}.')

    if crf_state is None:
      crf_state = self.make_crf_state(yts) # This needs the entire observed sequence only for the times
    crf: CRF = crf_state.crf

    # Compute the backward messages.  The means of this are not correct
    # but the covariances are correct because of our assumption that the potential
    # functions have a covariance that only depends on the time and not the value of
    # the data.  This assumption works because the smoothed mean has a linear relationship
    # with Y, but the covariance has a nonlinear relationship with Y due to matrix inverses
    # during message passing.
    backward_messages = crf_state.messages.bwd
    bwd_node = backward_messages + crf.node_potentials
    # Entry i of the transitions targets node i + 1, hence the shift by one
    bwd_node = bwd_node[self.latent_generation_start_index + 1:]

    # Depending on the parametrization, we need to convert the covariance matrix
    if self.parametrization == 'nat':
      mat = bwd_node.to_nat().J
    elif self.parametrization == 'mixed':
      mat = bwd_node.to_mixed().J
    elif self.parametrization == 'std':
      mat = bwd_node.to_std().Sigma
    else:
      raise ValueError(f'Unknown parametrization: {self.parametrization!r}.')

    # Construct the predicted backward messages + node potentials.
    # We don't have a last transition, so the last output is dropped.
    predicted_bwd_node = jax.vmap(construct_potential)(outputs[:-1], mat)

    # Compute the predicted next distribution
    updated_transitions = crf.base_transitions[self.latent_generation_start_index:].unnormalized_update_y(predicted_bwd_node)
    next_distribution = updated_transitions.condition_on_x(xts.yts[:-1])

    # Same sanity check that MyAutoregressiveModel performs on its output
    assert next_distribution.batch_size == self.generation_len - 1
    return next_distribution

class MyReparameterizedAutoregressiveRNNModel(MyReparameterizedAutoregressiveModel):
  """Variant of MyReparameterizedAutoregressiveModel whose network is a MyGRUModel
  instead of a MyModel.  The network is still stored in the `transformer` field so
  that the inherited prediction and sampling code can be reused unchanged.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """Build the model and its GRU network.

    Arguments:
      linear_sde: The linear SDE.  Its dim is used as the latent dimension.
      encoder: The encoder.  Its y_dim is used as the conditioning input size of the network.
      hidden_size: Passed to MyGRUModelHypers as hidden_size.
      potential_cov_type: Structure of the matrices.  Only used to compute matrix_dim here.
      parametrization: Which Gaussian parametrization the network outputs are interpreted in.
      interpolation_freq: Stored as is; also used to rescale latent_cond_len.
      seq_len: Stored as obs_seq_len.
      cond_len: The number of observed time points that we condition on.
      latent_cond_len: Given at the original frequency; stored multiplied by (1 + interpolation_freq).
      key: PRNG key used to initialize the network.
      predict_cov: Must be False.

    Raises:
      ValueError: If potential_cov_type is not one of the supported values.
    """
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by the interpolation frequency to account for the fact that we are predicting at the original frequency
    self.parametrization = parametrization
    assert not predict_cov, 'We don\'t support predicting the covariance for the reparametrized model'
    self.predict_cov = predict_cov

    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    print(f'Generation length: {self.generation_len}, latent_seq_len: {self.latent_seq_len}, latent_generation_start_index: {self.latent_generation_start_index}')

    dim = self.linear_sde.dim

    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2
      assert dim%2 == 0
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3
      assert dim%3 == 0
    else:
      # Previously this fell through and surfaced as an UnboundLocalError on matrix_dim
      raise ValueError(f'Unknown potential_cov_type: {potential_cov_type!r}.')

    # Only the vector is predicted; the matrices come from the CRF
    out_size = dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the GRU network (kept in the `transformer` field for the inherited code)
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                                  in_channels=dim,
                                  out_channels=out_size,
                                  hypers=hypers,
                                  key=key)

################################################################################################################

def blah():
  """Interactive scratch/debug routine for MyReparameterizedAutoregressiveModel.

  Builds an experiment identifier for the 'noisy_double_pendulum' config, takes
  the first 20 steps of the first test series, constructs a
  MyReparameterizedAutoregressiveModel from the SDE and encoder of an empty
  model loaded for that experiment, samples from it, and then evaluates several
  hand-written variants of the "update a transition with a potential, then
  condition on x" computation, running `jax.test_util.check_grads` on them.

  The function takes no arguments and returns nothing; it stops in `pdb` at
  several points so intermediate values can be inspected by hand.
  """
  key = random.PRNGKey(0)

  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model
  from diffusion_crf import TAGS
  ei = ExperimentIdentifier.make_experiment_id(config_name='noisy_double_pendulum',
                                              objective=None,
                                              model_name='my_autoregressive',
                                              sde_type='tracking',
                                              freq=0,
                                              group='no_leakage_latent_forecasting',
                                              seed=0)
  # ei = ExperimentIdentifier.make_experiment_id(config_name='stocks',
  #                                             objective=None,
  #                                             model_name='my_autoregressive',
  #                                             sde_type='brownian',
  #                                             freq=0,
  #                                             group='no_leakage_obs_forecasting',
  #                                             seed=0)
  datasets = ei.get_data_fixed()
  # Only test_data is used below; train_data and val_data are unpacked but unused.
  train_data, val_data, test_data = datasets['train_data'], datasets['val_data'], datasets['test_data']
  series = test_data[0][:20]

  # The empty model is only used as a source for the encoder and the linear SDE.
  dummy_model = load_empty_model(ei)
  encoder = dummy_model.encoder

  # encoder = eqx.tree_at(lambda x: x.use_prior, encoder, False)

  sde = dummy_model.linear_sde
  from diffusion_crf.matrix import DiagonalMatrix, DenseMatrix, Diagonal2x2BlockMatrix, Diagonal3x3BlockMatrix
  # Pick the potential covariance structure from the type of the SDE's F matrix.
  # NOTE(review): there is no fallback branch, so if sde.F is none of these four
  # types `potential_cov_type` stays unbound and the model construction below fails.
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'

  model = MyReparameterizedAutoregressiveModel(sde,
                                encoder,
                                n_layers=2,
                                filter_width=2,
                                hidden_channel_size=4,
                                num_transformer_heads=4,
                                potential_cov_type=potential_cov_type,
                                parametrization='mixed',
                                interpolation_freq=0,
                                seq_len=len(series),
                                cond_len=8,
                                latent_cond_len=4,
                                key=key,
                                predict_cov=False)

  # Sample a series and keep the part from latent_generation_start_index onwards.
  sampled_series = model.sample(key, series, debug=False)
  xts = sampled_series[model.latent_generation_start_index:]
  yts = series

  crf_state = model.make_crf_state(yts) # This needs the entire observed sequence only for the times
  crf: CRF = crf_state.crf

  # Compute the backward messages.  The means of this are not correct
  # but the covariances are correct because of our assumption that the potential
  # functions have a covariance that only depends on the time and not the value of
  # the data.  This assumption works because the smoothed mean has a linear relationship
  # with Y, but the covariance has a nonlinear relationship with Y due to matrix inverses
  # during message passing.
  backward_messages = crf_state.messages.bwd
  bwd_node = backward_messages + crf.node_potentials
  bwd_node = bwd_node[model.latent_generation_start_index + 1:]

  # Depending on the parametrization, we need to convert the covariance matrix
  # (`.J` is taken for 'nat' and 'mixed', `.Sigma` for 'std').
  if model.parametrization == 'nat':
    mat = bwd_node.to_nat().J
  elif model.parametrization == 'mixed':
    mat = bwd_node.to_mixed().J
  elif model.parametrization == 'std':
    mat = bwd_node.to_std().Sigma

  def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
    """Wrap (vec, mat) in the Gaussian class matching model.parametrization.

    Returns None implicitly if the parametrization is not 'nat', 'mixed' or 'std'.
    """
    # Helper function to construct a potential from a vector and a matrix
    if model.parametrization == 'nat':
      return NaturalGaussian(mat, vec)
    elif model.parametrization == 'mixed':
      return MixedGaussian(vec, mat)
    elif model.parametrization == 'std':
      return StandardGaussian(vec, mat)


  # Manual evaluation of the S and T matrices for the first generated step.
  # NOTE(review): the model above is built with parametrization='mixed', so `mat`
  # is the `.J` attribute, yet its first element is named Sigma_y and inverted to
  # get J_y — confirm which of the two is actually intended here.
  Sigma = crf.base_transitions[model.latent_generation_start_index:][0].Sigma
  Sigma_y = mat[0]
  J_y = Sigma_y.get_inverse()
  I = J_y.eye(J_y.shape[0])
  SigmaJy = Sigma@J_y
  Sy = SigmaJy@(I + SigmaJy).get_inverse()
  Ty = I - Sy

  import pdb; pdb.set_trace()

  def update_and_condition_raw(mu_y, J_y, x, A, u, Sigma):
    """Return mu = Ty@(A@x + u) + Sy@mu_y with Sy = Sigma@J_y@(I + Sigma@J_y)^{-1}, Ty = I - Sy.

    Stops in pdb right before returning.

    NOTE(review): `I` is a `jnp.eye` array while `.get_inverse()` is called on
    `I + SigmaJy`; the callers below pass a mix of project matrix objects and
    `.as_matrix()` results, and whether each combination supports
    `.get_inverse()` cannot be confirmed from this function alone.
    """
    I = jnp.eye(J_y.shape[0])
    SigmaJy = Sigma@J_y
    Sy = SigmaJy@(I + SigmaJy).get_inverse()
    Ty = I - Sy
    mu = Ty@(A@x + u) + Sy@mu_y
    import pdb; pdb.set_trace()
    return mu

  # NOTE(review): `A` and `u` are assigned further down in this same function
  # (from `transition.A, transition.u`), which makes them local names that are
  # still unbound at this point; reaching this call raises UnboundLocalError.
  update_and_condition_raw(xts.yts[-1], J_y, xts.yts[-1], A, u, Sigma)

  def update_and_condition(potential: AbstractPotential, transition: GaussianTransition, x: Float[Array, 'D']):
    """Unpack a potential and a transition and forward them to update_and_condition_raw.

    NOTE(review): `Sigma_y` (the `.Sigma` of the standard form) is passed in the
    position of the `J_y` parameter of update_and_condition_raw — confirm.
    """
    mu_y, Sigma_y = potential.to_std().mu, potential.to_std().Sigma.as_matrix()
    A, u, Sigma = transition.A, transition.u, transition.Sigma.as_matrix()
    return update_and_condition_raw(mu_y, Sigma_y, x, A, u, Sigma)

  @eqx.filter_jit
  def reparam(outputs: Float[Array, 'T D']):
    """Scalar test function of `outputs` used for the gradient checks below.

    Maps update_and_condition_raw over outputs[:-1], `mat`, xts.yts[:-1] and the
    base transitions from latent_generation_start_index onwards, and returns
    the sum of the result. Everything after the first `return` is unreachable
    and is the library-based version of the same computation.
    """
    transitions = crf.base_transitions[model.latent_generation_start_index:]
    A, u, Sigma = transitions.A, transitions.u, transitions.Sigma.as_matrix()
    mu = jax.vmap(update_and_condition_raw)(outputs[:-1], mat.as_matrix(), xts.yts[:-1], A, u, Sigma)
    return mu.sum()


    # --- unreachable from here on (early return above) ---
    # Construct the predicted backward messages + node potentials
    predicted_bwd_node = jax.vmap(construct_potential)(outputs[:-1], mat)

    # Compute the predicted next distribution
    updated_transitions: GaussianTransition = crf.base_transitions[model.latent_generation_start_index:].unnormalized_update_y(predicted_bwd_node)
    next_distribution = updated_transitions.condition_on_x(xts.yts[:-1])
    return next_distribution.mu.sum()


  # Random stand-in for the network outputs, same shape as the sampled latents.
  outputs = random.normal(key, xts.yts.shape)
  # Calculate the analytical gradient
  grad = eqx.filter_grad(reparam)(outputs)

  predicted_bwd_node = jax.vmap(construct_potential)(outputs[:-1], mat)

  # Calculate numerical gradient for comparison
  from jax.test_util import check_grads
  check_grads(reparam, (outputs,), order=2)
  import pdb; pdb.set_trace()





  def update_and_condition2(potential: AbstractPotential, transition: GaussianTransition, x: Float[Array, 'D']):
    """Matrix-object version of the update/condition step; returns Abar@x + ubar.

    Reads `.J`, `.mu` and `.logZ` from `potential` and `.A`, `.u`, `.Sigma` from
    `transition`. `logZy` and `Sigmabar` are computed but do not contribute to
    the returned value.
    """
    A, u, Sigma = transition.A, transition.u, transition.Sigma
    Jy, muy, logZy = potential.J, potential.mu, potential.logZ
    I = Sigma.set_eye()

    SigmaJ = Sigma@Jy
    I_plus_SigmaJ = I + SigmaJ
    R = I_plus_SigmaJ.T.solve(Jy).T # Jy@(I + Sigma@Jy)^{-1}
    S = Sigma@R                     # Sigma@Jy@(I + Sigma@Jy)^{-1}
    T = I - S
    # If potential is zero (Jy = zero, total uncertainty), then R = 0. Handled naturally
    # If potential is inf (Jy = inf, total certainty), then T = 0 and S = I. Handled naturally, but enforce here for safety
    T, S = util.where(Jy.tags.is_inf, (T.set_zero(), S.set_eye()), (T, S))
    T, S = util.where(Jy.tags.is_zero, (T.set_eye(), S.set_zero()), (T, S))

    # Same for both standard and natural:
    Sigmabar = T@Sigma
    Sigmabar = Sigmabar.set_symmetric()
    Abar = T@A
    ubar = T@u + S@muy
    return Abar@x + ubar

  update_and_condition2(predicted_bwd_node[-1], crf.base_transitions[-1], xts.yts[-1])

  # Small synthetic case: a unit-diagonal Gaussian potential and a single SDE
  # transition between t=0.1 and t=0.2.
  k1, k2 = random.split(key)
  vec = random.normal(k1, (model.dim,))
  # NOTE: this rebinds `mat`, which construct_potential's callers, reparam and
  # grad_check2 read from the enclosing scope; calls made after this line see
  # the DiagonalMatrix rather than the backward-message matrix.
  mat = DiagonalMatrix(jnp.ones(model.dim), tags=TAGS.symmetric_tags)
  potential = StandardGaussian(vec, mat)
  transition = model.linear_sde.get_transition_distribution(0.1, 0.2)
  mu_y, Sigma_y = potential.to_std().mu, potential.to_std().Sigma.as_matrix()
  A, u, Sigma = transition.A, transition.u, transition.Sigma.as_matrix()
  x = random.normal(k2, vec.shape)


  # Three evaluations of the same quantity for manual comparison. All three
  # condition on `vec`; the freshly drawn `x` is only used in grad_check1.
  out1 = update_and_condition(potential, transition, vec)
  out2 = transition.update_y(potential).condition_on_x(vec).mu
  out3 = update_and_condition2(potential.to_mixed(), transition, vec)
  # import pdb; pdb.set_trace()


  def grad_check1(mu_y):
    """Sum of update_and_condition_raw as a function of mu_y (other inputs from the enclosing scope)."""
    mu_y = jnp.array(mu_y)
    return update_and_condition_raw(mu_y, Sigma_y, x, A, u, Sigma).sum()

  def grad_check2(mu_y):
    """Sum of update_and_condition2 as a function of the mean of a MixedGaussian(mu_y, mat)."""
    mu_y = jnp.array(mu_y)
    potential = MixedGaussian(mu_y, mat)
    return update_and_condition2(potential, transition, vec).sum()


  # Compare analytical and numerical derivatives up to second order.
  check_grads(grad_check1, (mu_y,), order=2)
  check_grads(grad_check2, (mu_y,), order=2)
  import pdb; pdb.set_trace()



if __name__ == '__main__':
  # Manual smoke test for MyReparameterizedAutoregressiveModel.
  #
  # Loads a pickled series from 'series.pkl' in the working directory, builds a
  # time-scaled tracking SDE plus a padding encoder, constructs the model, and
  # then checks three things:
  #   1. sampling gives the same result on the masked observation sequence,
  #   2. the Jacobian of the predicted means w.r.t. the latent inputs is
  #      lower triangular,
  #   3. downsampling the sampled series returns to the original length.
  # The script ends in a pdb prompt so the results can be inspected by hand.
  from debug import *
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation
  import matplotlib.pyplot as plt
  # turn on x64
  # jax.config.update('jax_enable_x64', True)

  # blah()


  N = 10
  # NOTE: pickle.load can execute arbitrary code stored in the file; only use
  # this with a trusted, locally produced 'series.pkl'.
  # The context manager closes the file handle once the data is loaded.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Keep only the first two columns of the values and of the observation mask.
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 2

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # HigherOrderTrackingModel is presumably provided by one of the star imports
  # above — verify if those imports are ever narrowed.
  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1,
                                                  use_prior=False)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  from diffusion_crf.matrix import *
  # Pick the potential covariance structure from the type of the SDE's F matrix.
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
  else:
    # Fail here with a clear message instead of a NameError on
    # `potential_cov_type` at the model construction below.
    raise TypeError(f'Unsupported matrix type for sde.F: {type(sde.F).__name__}')

  # model = MyAutoregressiveModel(sde,
  model = MyReparameterizedAutoregressiveModel(sde,
                                encoder,
                                n_layers=4,
                                filter_width=4,
                                hidden_channel_size=16,
                                num_transformer_heads=4,
                                potential_cov_type=potential_cov_type,
                                parametrization='mixed',
                                interpolation_freq=freq,
                                seq_len=len(series),
                                cond_len=5,
                                latent_cond_len=3,
                                key=key,
                                predict_cov=False)

  # Evaluate the loss function
  # out = model.loss_fn(series, key=key, debug=False)

  # Sample a series
  sampled_series = model.sample(key, series, debug=True)

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  xts_generated = sampled_series[model.latent_generation_start_index:]
  def output(xts_values):
    """Predicted next-state means as a function of the latent values only."""
    xts = TimeSeries(ts=xts_generated.ts, yts=xts_values)
    out_dist = model.predict_next_state_distribution(series, xts)
    return out_dist.mu

  means = output(xts_generated.yts)

  # Compute the Jacobian
  J = eqx.filter_jacfwd(output)(xts_generated.yts)
  # Sum out axes 1 and 3 to get one entry per (output index, input index) pair;
  # presumably those are the feature axes — TODO confirm the layout of J.
  J_flat = J.sum(axis=(1, 3))

  # Check that the Jacobian is lower triangular
  assert jnp.allclose(J_flat, jnp.tril(J_flat))


  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)
  assert len(latent_seq) == len(series)
  # Deliberate stop so the objects built above can be inspected interactively.
  import pdb; pdb.set_trace()

