import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.gaussian.transition import GaussianTransition
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState, AbstractHiddenState
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.ode_sde_simulation import AbstractSolverParams, ODESolverParams, ode_solve, SDESolverParams, sde_sample, DiffraxSolverState
from Models.models.rnn import MyGRUModelHypers, MyGRUModel, StackedGRUSequenceHypers, StackedGRURNN

"""
This module implements neural SDE models for time series forecasting. It provides:

1. AbstractNeuralSDE - A base class for models that learn to control a stochastic differential
   equation, either by predicting drift or probability flow terms

2. LocalSDE / LocalSDEWithState - Helper classes that manage the SDE simulation for
   time-local predictions (driven by a latent buffer or by a recurrent hidden state)

3. MyNeuralSDE - A concrete implementation using a transformer-based architecture
   with flow or drift matching capabilities

These models leverage SDE dynamics to generate sample paths between observed points
by learning the appropriate drift or flow terms that guide the stochastic process
toward the target distribution. The approach provides a principled way to generate
realistic, diverse samples that respect the underlying physics-inspired dynamics.
"""

################################################################################################################

class AbstractNeuralSDE(AbstractModel, abc.ABC):
  """A neural SDE that models the process p(X | y_{1:k}).

  Subclasses implement `predict_control`, which returns either the SDE drift or the
  probability flow vector field (selected by `predict_flow_or_drift`).  This base class
  provides the matching loss (`loss_fn`) and two autoregressive samplers (`sample` and
  `sequential_sample`).

  NOTE(review): several members used below (`latent_seq_len`, `latent_generation_start_index`,
  `get_random_discretization_info`, `mask_observation_space_sequence`,
  `make_fully_observed_series`, `get_initial_state`, `single_step_predict_control`) are not
  defined in this class; they are expected to come from `AbstractModel` or the subclass.
  """

  linear_sde: AbstractLinearSDE             # Base linear SDE that is conditioned on the encoded observations
  encoder: AbstractEncoder                  # Maps the observed series to the series used to condition linear_sde
  transformer: AbstractEncoderDecoderModel  # Network that builds the context and backs the control prediction

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True) # Required length of the observed series
  cond_len: int = eqx.field(static=True)    # Number of leading observations used to build the context
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_flow_or_drift: Literal['flow', 'drift'] = eqx.field(static=True)
  use_sequential_sampling: bool = eqx.field(static=True)

  @abc.abstractmethod
  def predict_control(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    current_index: Optional[int] = None
  ) -> Float[Array, 'T D']:
    """Predict the drift or flow for the series.

    Arguments:
      yts: The observed data
      xts: The series that we have already generated
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      current_index: The index of the current time step.  If this is provided, we will
                     return the control at the current time step.

    Returns:
      The drift or flow at the times in xts.
    """
    pass

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Compute the flow/drift matching loss for a single observed series.

    The observed series is encoded, the base linear SDE is conditioned on it, and
    matching targets are sampled at a random discretization of the time axis.  The
    model's predicted control is then regressed onto the target, counting only the
    entries selected by `info.new_indices_mask`.

    Arguments:
      yts: The observed series.  Must have length `obs_seq_len` and last dimension
           `encoder.y_dim`.
      key: PRNG key.  It is split once: one half draws the random times, the other
           samples the matching items.
      debug: If truthy, drop into pdb right before the losses are assembled.

    Returns:
      A dict with the matching loss under 'flow_matching' or 'drift_matching'
      (depending on `predict_flow_or_drift`) and the diagnostic 'cos_sim'.

    Raises:
      ValueError: If the random discretization does not have `latent_seq_len` times or
                  `predict_flow_or_drift` is not 'flow' or 'drift'.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'

    k1, k2 = random.split(key)

    # Construct the CRF p(x | Y)
    prob_series = self.encoder(yts)
    cond_sde = ConditionedLinearSDE(self.linear_sde, prob_series, parallel=jax.devices()[0].platform == 'gpu')

    # We need to draw random times in between the times in prob_series.ts
    info = self.get_random_discretization_info(k1, prob_series.ts)

    if info.ts.shape[0] != self.latent_seq_len:
      raise ValueError(f'info.ts must have length {self.latent_seq_len} but has length {info.ts.shape[0]}.')

    # Sample the things we need for flow/drift matching.  Even though we pass in
    # the new (random) times to this function, it returns the items at both
    # the new times and the original times.
    items = cond_sde.multi_sample_matching_items(info.new_times, k2)
    xts = TimeSeries(items.t, items.xt)
    xts_generation_buffer = xts[self.latent_generation_start_index:] # Get x_{l:N}

    # Depending on our parameterization, extract the probability flow ODE vector field or SDE drift.
    # An unknown parameterization is rejected here instead of surfacing later as an unbound name.
    if self.predict_flow_or_drift == 'flow':
      control_target = items.flow
    elif self.predict_flow_or_drift == 'drift':
      control_target = items.drift
    else:
      raise ValueError(f'predict_flow_or_drift must be either "flow" or "drift" but got {self.predict_flow_or_drift}.')
    control_target = control_target[self.latent_generation_start_index:]

    # Predict the flow with our model.  Although we pass in the entire observed series to the model,
    # under the hood the model only uses part of yts for conditioning.
    pred_control = self.predict_control(yts, xts_generation_buffer)

    # We don't actually care about the last element of the control because we don't predict past
    # the last time
    pred_control = pred_control[:-1]
    control_target = control_target[:-1]

    # Get a mask to isolate the control target that we actually care about
    new_mask = info.new_indices_mask[self.latent_generation_start_index:-1]
    n_selected = jnp.sum(new_mask)

    pred_control = pred_control*new_mask[:,None]
    control_target = control_target*new_mask[:,None]

    matching_loss = jnp.sum((pred_control - control_target)**2)/n_selected

    # As a sanity check, check the cosine similarity between the predicted control_target and the true control_target.
    # Masked rows were zeroed above, so their norm product is 0; substitute 1 there so that
    # we never evaluate 0/0 for rows that are discarded anyway.
    dot_products = jnp.sum(pred_control*control_target, axis=-1)
    both_norms = jnp.linalg.norm(pred_control, axis=-1)*jnp.linalg.norm(control_target, axis=-1)
    safe_norms = jnp.where(new_mask, both_norms, 1.0)
    all_cos_sim = jnp.where(new_mask, dot_products/safe_norms, 0.0)
    cos_sim = jnp.sum(all_cos_sim)/n_selected

    if debug:
      import pdb; pdb.set_trace()

    if self.predict_flow_or_drift == 'flow':
      losses = dict(flow_matching=matching_loss, cos_sim=cos_sim)
    else:
      # Only 'drift' can reach this point; other values were rejected above
      losses = dict(drift_matching=matching_loss, cos_sim=cos_sim)

    return losses

  def get_default_solver_params(self) -> Union[ODESolverParams, SDESolverParams]:
    """Return the solver configuration used by the samplers when none is supplied.

    Returns:
      `ODESolverParams` when predicting the flow, `SDESolverParams` when predicting
      the drift.

    Raises:
      ValueError: If `predict_flow_or_drift` is not 'flow' or 'drift'.
    """
    if self.predict_flow_or_drift == 'flow':
      return ODESolverParams(rtol=1e-3,
                             atol=1e-3,
                             solver='dopri5',
                             adjoint='recursive_checkpoint',
                             stepsize_controller='pid',
                             max_steps=20_000,
                             throw=True,
                             progress_meter=None)
    elif self.predict_flow_or_drift == 'drift':
      return SDESolverParams(solver='shark',
                             adjoint='direct',
                             stepsize_controller='constant',
                             max_steps=100,
                             throw=False,
                             progress_meter=None,
                             brownian_simulation_type='unsafe')
    else:
      raise ValueError(f'predict_flow_or_drift must be either "flow" or "drift" but got {self.predict_flow_or_drift}.')

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    solver_params: Optional[AbstractSolverParams] = None,
    reuse_solver_state: Optional[bool] = False, # Doesn't seem like we do fewer function evaluations overall!
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{1:N} | Y_{1:k}) autoregressively.  This involves first sampling from p(x_1 | Y_{1:k}) and then iterating to sample from p(x_{i+1} | x_{1:i}, Y_{1:k})

    Arguments:
      key: PRNG key.
      yts: The observed series.  Must have length `obs_seq_len` and last dimension
           `encoder.y_dim`.
      crf_state: Currently ignored; a fresh `CRFState` is always constructed.
      debug: If truthy, run the steps in a Python loop with pdb breakpoints instead
             of `jax.lax.scan`.
      solver_params: Solver configuration.  Defaults to `get_default_solver_params()`.
      reuse_solver_state: If it is not `False`, the solver state returned by one step
                          is carried into the next step.
      **kwargs: `discretization_key` (optional) - if given, the start, end and save
                times passed to the solver come from a second random discretization
                drawn with this key.  Remaining kwargs are forwarded to
                `sequential_sample` when that path is taken.

    Returns:
      The latent series of length `latent_seq_len`.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    if 'GRU' in str(type(self.transformer)) and self.use_sequential_sampling:
      return self.sequential_sample(key, yts, crf_state=crf_state, debug=debug, solver_params=solver_params, **kwargs)

    yts = self.mask_observation_space_sequence(yts) # This is to ensure that the samples don't depend on unobserved values

    # NOTE(review): `key` is consumed by several draws in this method (discretization,
    # CRF sample, solver-state initialization and the per-step keys).  JAX recommends
    # splitting a key instead of reusing it; this is left untouched so that the samples
    # produced for a given key do not change.

    # Generate the buffer of latent variables that we will update.  This also samples the first point
    info = self.get_random_discretization_info(key, yts.ts)
    crf_state = CRFState(self.linear_sde, self.encoder, self.make_fully_observed_series(yts), info, parameterization='mixed')
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages

    # Generate the buffer of the predicted latent series
    xts_values = crf.sample(key, messages=messages)

    # We will assume that the first latent variable comes from the correct prior distribution.
    # See the discussion in the docstring for AbstractAutoregressiveModel for more details.
    mask = jnp.zeros(xts_values.shape[0], dtype=bool)
    mask = mask.at[:self.latent_generation_start_index+1].set(True)

    full_xts_buffer = TimeSeries(info.ts, xts_values, observation_mask=mask)
    xts_buffer = full_xts_buffer[self.latent_generation_start_index:]

    # Get the start and end times of the prediction interval
    original_times = xts_buffer.ts[::2] # Every other time is the random time (interpolation_freq=1)
    start_times = original_times[:-1]
    end_times = original_times[1:]

    # Get the indices of the times that we'll retrieve the control at.
    # We're using random times like this because the transformer expects
    # random looking times to use for conditioning.
    intermediate_indices = jnp.arange(1, len(xts_buffer), 2)
    intermediate_times = xts_buffer.ts[intermediate_indices]

    ########################
    ########################
    # Optionally override the solver time grid with a second random discretization
    discretization_key = kwargs.get('discretization_key', None)
    if discretization_key is not None:
      # info2 = self.get_discretization_info(yts.ts)
      info2 = self.get_random_discretization_info(discretization_key, yts.ts)
      dummy_ts = info2.ts[self.latent_generation_start_index:]
      original_times2 = dummy_ts[::2]
      start_times = original_times2[:-1]
      end_times = original_times2[1:]
      intermediate_times = dummy_ts[intermediate_indices]
    ########################
    ########################

    # Encode the encoder series context
    context = self.transformer.create_context(yts[:self.cond_len])

    # Get the solver params
    if solver_params is None:
      solver_params = self.get_default_solver_params()

    def scan_body(carry, inputs, debug=False):
      xts_buffer, diffrax_solver_state = carry
      # k is the index of the random time in between the current marginal index and the next marginal index
      key, k, t_save, t_start, t_end = inputs

      x_start = xts_buffer.yts[k-1]

      # We will want to save the value at t_save for the future
      save_times = jnp.array([t_start, t_save, t_end])

      # Create the local SDE for this time step and simulate it
      local_sde = LocalSDE(self, xts_buffer, k, yts, context, solver_params)
      simulated_trajectory, updated_diffrax_solver_state = local_sde.simulate(x0=x_start,
                                                save_times=save_times,
                                                key=key,
                                                diffrax_solver_state=diffrax_solver_state)
      if reuse_solver_state is False:
        updated_diffrax_solver_state = DiffraxSolverState()

      # Extract the values at the times that we saved
      assert len(simulated_trajectory) == 3
      xt_save = simulated_trajectory.yts[1]
      xt_end = simulated_trajectory.yts[2] # This is the starting point for the next iteration

      # Fill in the new values
      new_values = util.fill_array(xts_buffer.yts, k, xt_save)
      new_values = util.fill_array(new_values, k+1, xt_end)

      # Update the observation mask to reflect the new prediction
      new_observation_mask = util.fill_array(xts_buffer.observation_mask, k, True)
      new_observation_mask = util.fill_array(new_observation_mask, k+1, True)

      # Create the new time series
      new_xts_buffer = TimeSeries(xts_buffer.ts, new_values, new_observation_mask)

      if debug:
        import pdb; pdb.set_trace()
      return (new_xts_buffer, updated_diffrax_solver_state), new_xts_buffer

    if reuse_solver_state is False:
      diffrax_solver_state = DiffraxSolverState()
    else:
      # Initialize the solver state
      k0 = intermediate_indices[0]
      x0 = xts_buffer.yts[k0-1]
      t0 = start_times[0]
      t1 = end_times[0]
      local_sde = LocalSDE(self, xts_buffer, k0, yts, context, solver_params)
      diffrax_solver_state = solver_params.initialize_solve_state(sde=local_sde,
                                                                  x0=x0,
                                                                  t0=t0,
                                                                  t1=t1,
                                                                  key=key)

    carry = (xts_buffer, diffrax_solver_state)
    keys = random.split(key, len(intermediate_indices))
    inputs = (keys, intermediate_indices, intermediate_times, start_times, end_times)

    if debug:
      for i, item in enumerate(zip(*inputs)):
        carry, _ = scan_body(carry, item, debug=True)
      generated_latent_seq, _ = carry
      import pdb; pdb.set_trace()
    else:
      (generated_latent_seq, _), generated_xts = jax.lax.scan(scan_body, carry, inputs)

    # Concatenate the original latent sequence as well
    prev_latent_seq = full_xts_buffer[:self.latent_generation_start_index]
    generated_latent_seq = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), prev_latent_seq, generated_latent_seq)

    assert len(generated_latent_seq) == self.latent_seq_len, f'output must have length {self.latent_seq_len} but has length {len(generated_latent_seq)}.'
    return generated_latent_seq

  def sequential_sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    solver_params: Optional[AbstractSolverParams] = None,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{1:N} | Y_{1:k}) autoregressively.  This involves first sampling from p(x_1 | Y_{1:k}) and then iterating to sample from p(x_{i+1} | x_{1:i}, Y_{1:k})

    Unlike `sample`, each step carries a hidden state that is advanced with
    `single_step_predict_control` instead of re-reading a buffer of latent variables.

    Arguments:
      key: PRNG key.
      yts: The observed series.  Must have length `obs_seq_len` and last dimension
           `encoder.y_dim`.
      crf_state: Currently ignored; a fresh `CRFState` is always constructed.
      debug: If truthy, run the steps in a Python loop with pdb breakpoints instead
             of `jax.lax.scan`.
      solver_params: Solver configuration.  Defaults to `get_default_solver_params()`.
      **kwargs: Accepted but not used by this method.

    Returns:
      The latent series of length `latent_seq_len`.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    yts = self.mask_observation_space_sequence(yts) # This is to ensure that the samples don't depend on unobserved values

    # NOTE(review): `key` is reused for the discretization, the CRF sample and the
    # per-step keys.  Left untouched so that samples for a given key do not change.

    # Generate the buffer of latent variables that we will update.  This also samples the first point
    info = self.get_random_discretization_info(key, yts.ts)
    crf_state = CRFState(self.linear_sde, self.encoder, self.make_fully_observed_series(yts), info)
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages

    # Generate the buffer of the predicted latent series
    xts_values = crf.sample(key, messages=messages)

    # We will assume that the first latent variable comes from the correct prior distribution.
    # See the discussion in the docstring for AbstractAutoregressiveModel for more details.
    mask = jnp.zeros(xts_values.shape[0], dtype=bool)
    mask = mask.at[:self.latent_generation_start_index+1].set(True)

    full_xts_buffer = TimeSeries(info.ts, xts_values, observation_mask=mask)
    xts_buffer = full_xts_buffer[self.latent_generation_start_index:]

    # Get the start and end times of the prediction interval
    original_times = xts_buffer.ts[::2] # Every other time is the random time (interpolation_freq=1)
    start_times = original_times[:-1]
    end_times = original_times[1:]

    # Get the indices of the times that we'll retrieve the control at.
    # We're using random times like this because the transformer expects
    # random looking times to use for conditioning.
    intermediate_indices = jnp.arange(1, len(xts_buffer), 2)
    intermediate_times = xts_buffer.ts[intermediate_indices]

    # Encode the encoder series context
    context = self.transformer.create_context(yts[:self.cond_len])

    # Get the solver params
    if solver_params is None:
      solver_params = self.get_default_solver_params()

    def scan_body(carry, inputs, debug=False):
      x_start, state = carry
      key, k, t_save, t_start, t_end = inputs

      # Advance the hidden state with the value at the start of this interval
      _, state = self.single_step_predict_control(t_start, x_start, state)

      # We will want to save the value at t_save for the future
      save_times = jnp.array([t_start, t_save, t_end])

      # Create the local SDE for this time step and simulate it
      local_sde = LocalSDEWithState(self, state, yts, context, solver_params)
      simulated_trajectory = local_sde.simulate(x0=x_start,
                                                save_times=save_times,
                                                key=key)

      # Extract the values at the times that we saved
      assert len(simulated_trajectory) == 3
      xt_save = simulated_trajectory.yts[1]
      xt_end = simulated_trajectory.yts[2] # This is the starting point for the next iteration

      # Update the state
      _, new_state = self.single_step_predict_control(t_save, xt_save, state)

      if debug:
        import pdb; pdb.set_trace()
      return (xt_end, new_state), (xt_save, xt_end)

    # Get the initial state of the model
    initial_state = self.get_initial_state(yts, context=context)
    xt_start = xts_buffer.yts[0]
    carry = (xt_start, initial_state)

    keys = random.split(key, len(intermediate_indices))
    inputs = (keys, intermediate_indices, intermediate_times, start_times, end_times)

    if debug:
      # Seed the list with the starting point so that the values line up with
      # xts_buffer.ts, exactly like the scan path below.
      all_xts = [xt_start]
      for i, item in enumerate(zip(*inputs)):
        carry, out = scan_body(carry, item, debug=True)
        xt_intermediate, xt_end = out
        all_xts.append(xt_intermediate)
        all_xts.append(xt_end)
      generated_latent_seq = jnp.array(all_xts)
      import pdb; pdb.set_trace()
    else:
      _, (xt_intermediate, xt_end) = jax.lax.scan(scan_body, carry, inputs)
      # Interleave the intermediate and end points
      stacked = jnp.concatenate([xt_intermediate[None], xt_end[None]], axis=0)
      generated_latent_seq = einops.rearrange(stacked, 'B N D -> (N B) D', B=2, N=xt_end.shape[0], D=xt_end.shape[-1])
      generated_latent_seq = jnp.concatenate([xt_start[None], generated_latent_seq], axis=0)

    generated_latent_seq = TimeSeries(xts_buffer.ts, generated_latent_seq)

    # Concatenate the original latent sequence as well
    prev_latent_seq = full_xts_buffer[:self.latent_generation_start_index]
    generated_latent_seq = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), prev_latent_seq, generated_latent_seq)

    assert len(generated_latent_seq) == self.latent_seq_len, f'output must have length {self.latent_seq_len} but has length {len(generated_latent_seq)}.'
    return generated_latent_seq

################################################################################################################

class LocalSDE(AbstractSDE):
  """Represents the SDE that bridges two distributions.  We simulate this SDE in order
  to transport samples from the first distribution to the second distribution.  The transition
  distribution of the start and end distribution is constructed to be equal to p(x_{i+1} | x_{1:i}, Y_{1:k})

  The control at (t, xt) is obtained by writing (t, xt) into slot `k` of `xts_buffer`
  and asking the model for the control at that index.
  """

  model: AbstractNeuralSDE # Our model that contains a transformer to predict the control at all times
  xts_buffer: TimeSeries   # The buffer of latent variables that we have currently generated
  k: int                   # The index into the buffer that contains the control for this local SDE

  yts: TimeSeries              # The observed data
  context: Float[Array, 'S C'] # The context embedding of the observed data

  solver_params: Union[ODESolverParams, SDESolverParams]

  @property
  def batch_size(self) -> int:
    """The batch size of the wrapped model."""
    return self.model.batch_size

  def get_control(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Get the control at time t for the local SDE.  This is the control that will be used to transport
    the sample from the start distribution to the end distribution.
    """
    # Insert xt into the buffer so that we can make a prediction of the control at that time
    # (presumably util.fill_array returns a copy with index k replaced -- TODO confirm)
    new_ts = util.fill_array(self.xts_buffer.ts, self.k, t)
    new_yts = util.fill_array(self.xts_buffer.yts, self.k, xt)
    series_dec = TimeSeries(new_ts, new_yts)

    # Predict the control at the new time
    return self.model.predict_control(self.yts, series_dec, context=self.context, current_index=self.k)

  def get_drift(self, t: Scalar,  xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Get the drift of this SDE at time t."""
    return self.get_control(t, xt)

  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Get the diffusion coefficient of this SDE at time t.  By construction this is the same
    as the diffusion coefficient of the base linear SDE.
    """
    return self.model.linear_sde.get_diffusion_coefficient(t, xt)

  def get_flow(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Get the vector field of the probability flow ODE of this SDE at time t."""
    return self.get_control(t, xt)

  def get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Not available: always raises NotImplementedError."""
    raise NotImplementedError('No closed form transition distribution for neural SDEs.')

  def simulate(
    self,
    x0: Float[Array, 'D'],
    save_times: Float[Array, 'T'],
    key: PRNGKeyArray,
    diffrax_solver_state: Optional[DiffraxSolverState] = DiffraxSolverState()
  ) -> Tuple[TimeSeries, Any]:
    """Simulates the SDE or its probability flow ODE to transport samples through time.

    Dispatches to `ode_solve` when the model predicts the flow and to `sde_sample`
    when it predicts the drift.  Both are called with `return_solve_solution=True`.

    Args:
        x0: Initial state with shape [D] where D is the state dimension
        save_times: Array of 3 times [t0, t1, t2] at which to save the trajectory
        key: PRNG key for stochastic simulation (only passed to `sde_sample`)
        diffrax_solver_state: Solver state forwarded to the solve helper

    Returns:
        A pair `(simulated_trajectory, sol)`: the trajectory at `save_times` and the
        second value returned by the solve helper.  `AbstractNeuralSDE.sample` carries
        that second value forward as the solver state for the next step.

    Raises:
        ValueError: If predict_flow_or_drift is not 'flow' or 'drift'
        AssertionError: If save_times does not have exactly 3 elements
    """
    assert len(save_times) == 3, f'save_times must have length 3 but has length {len(save_times)}.'

    if self.model.predict_flow_or_drift == 'flow':
      # Simulate the probability flow ODE of the neural SDE
      simulated_trajectory, sol = ode_solve(self,
                                        x0=x0,
                                        save_times=save_times,
                                        params=self.solver_params,
                                        diffrax_solver_state=diffrax_solver_state,
                                        return_solve_solution=True)

    elif self.model.predict_flow_or_drift == 'drift':
      # Simulate the neural SDE
      simulated_trajectory, sol = sde_sample(self,
                                        x0=x0,
                                        key=key,
                                        save_times=save_times,
                                        params=self.solver_params,
                                        diffrax_solver_state=diffrax_solver_state,
                                        return_solve_solution=True)
    else:
      raise ValueError(f'predict_flow_or_drift must be either "flow" or "drift" but got {self.model.predict_flow_or_drift}.')

    return simulated_trajectory, sol

class LocalSDEWithState(AbstractSDE):
  """Local bridging SDE whose control is evaluated from a fixed hidden state.

  The control at (t, xt) comes from `model.single_step_predict_control(t, xt, state)`;
  the state returned by that call is discarded, so `state` stays fixed for the whole
  simulation of this local SDE.  The diffusion coefficient is taken from the model's
  base linear SDE.
  """

  model: AbstractNeuralSDE     # Model that provides single_step_predict_control and linear_sde
  state: AbstractHiddenState   # Hidden state passed to single_step_predict_control

  yts: TimeSeries              # The observed data
  context: Float[Array, 'S C'] # The context embedding of the observed data

  solver_params: Union[ODESolverParams, SDESolverParams]

  @property
  def batch_size(self) -> int:
    """The batch size of the wrapped model."""
    return self.model.batch_size

  def get_control(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Evaluate the model's control at (t, xt) using the stored hidden state."""
    control, _unused_state = self.model.single_step_predict_control(t, xt, self.state)
    return control

  def get_drift(self, t: Scalar,  xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Drift of this SDE at time t; identical to the control."""
    return self.get_control(t, xt)

  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Diffusion coefficient at time t, delegated to the model's base linear SDE."""
    base_sde = self.model.linear_sde
    return base_sde.get_diffusion_coefficient(t, xt)

  def get_flow(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Vector field of the probability flow ODE at time t; identical to the control."""
    return self.get_control(t, xt)

  def get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Not available: always raises NotImplementedError."""
    raise NotImplementedError('No closed form transition distribution for neural SDEs.')

  def simulate(
    self,
    x0: Float[Array, 'D'],
    save_times: Float[Array, 'T'],
    key: PRNGKeyArray
  ) -> TimeSeries:
    """Simulate the neural SDE, or its probability flow ODE, from x0.

    Calls `ode_solve` when the model predicts the flow and `sde_sample` when it
    predicts the drift, and returns whatever that helper returns.

    Args:
        x0: Initial state with shape [D] where D is the state dimension
        save_times: Array of 3 times [t0, t1, t2] at which to save the trajectory
        key: PRNG key for stochastic simulation (only passed to `sde_sample`)

    Returns:
        The simulated trajectory at the save_times

    Raises:
        ValueError: If predict_flow_or_drift is not 'flow' or 'drift'
        AssertionError: If save_times does not have exactly 3 elements
    """
    assert len(save_times) == 3, f'save_times must have length 3 but has length {len(save_times)}.'

    # Probability flow ODE of the neural SDE
    if self.model.predict_flow_or_drift == 'flow':
      return ode_solve(self,
                       x0=x0,
                       save_times=save_times,
                       params=self.solver_params)

    # The neural SDE itself
    if self.model.predict_flow_or_drift == 'drift':
      return sde_sample(self,
                        x0=x0,
                        key=key,
                        save_times=save_times,
                        params=self.solver_params)

    raise ValueError(f'predict_flow_or_drift must be either "flow" or "drift" but got {self.model.predict_flow_or_drift}.')

################################################################################################################

class MyNeuralSDE(AbstractNeuralSDE):
  """Neural SDE whose control is predicted by an encoder-decoder network.

  The network stored in `transformer` (a `MyModel`) is evaluated on the
  conditioning prefix of the observed series and on the latent series.  Its
  output is combined with the `F` and `L` matrices of `linear_sde` in
  `predict_control` to form the drift or flow.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True) # Asserted to be 1 in __init__
  obs_seq_len: int = eqx.field(static=True) # Required length of the observed series
  cond_len: int = eqx.field(static=True) # Number of leading observations that are conditioned on
  latent_cond_len: int = eqx.field(static=True) # Constructor argument times (1 + interpolation_freq)

  dim: int = eqx.field(static=True) # Copied from linear_sde.dim

  predict_flow_or_drift: Literal['flow', 'drift'] = eqx.field(static=True)

  _log_function_evaluations: bool = eqx.field(static=True)

  use_sequential_sampling: bool = eqx.field(static=True) # Always False for this class

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_flow_or_drift: str):
    """Store the configuration and build the `MyModel` network.

    Arguments:
      linear_sde: Linear SDE that supplies `dim`, `F` and `L`.
      encoder: Encoder whose `y_dim` is the channel count of the conditioning input.
      n_layers: Passed to `MyModelHypers` as `n_blocks`.
      filter_width: Passed to `MyModelHypers` as `wavenet_kernel_width`.
      hidden_channel_size: Passed to `MyModelHypers`.
      num_transformer_heads: Passed to `MyModelHypers`.
      interpolation_freq: Must be 1.
      seq_len: Length of the observed series; stored as `obs_seq_len`.
      cond_len: Number of leading observations that are conditioned on.
      latent_cond_len: Must be even; stored multiplied by `1 + interpolation_freq`.
      key: PRNG key used to initialize the network.
      predict_flow_or_drift: Either 'flow' or 'drift'.

    Raises:
      AssertionError: If `interpolation_freq != 1`, if `latent_generation_start_index`
        is negative, or if `predict_flow_or_drift` is not 'flow' or 'drift'.
      ValueError: If `latent_cond_len` is odd.
    """
    assert interpolation_freq == 1, "Neural SDEs are in continuous time, so there is no upsampling.  However, in practice our transformers will upsample the series by using interpolation_freq=1."
    self.interpolation_freq = interpolation_freq
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by (1 + interpolation_freq) to account for the fact that we are predicting at the original frequency
    self.use_sequential_sampling = False

    if latent_cond_len%2 != 0:
      raise ValueError(f'latent_cond_len must be even but is {latent_cond_len}.')

    # latent_generation_start_index is not defined in this class; presumably it is
    # a property inherited from AbstractNeuralSDE -- TODO confirm.
    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    assert predict_flow_or_drift in ['flow', 'drift'], f'predict_flow_or_drift must be either "flow" or "drift" but got {predict_flow_or_drift}.'
    self.predict_flow_or_drift = predict_flow_or_drift

    dim = self.linear_sde.dim

    self.dim = dim

    # Create the transformer.  The original key is passed through unchanged so
    # that parameter initialization for a given key is preserved.
    hypers = MyModelHypers(n_blocks=n_layers,
                           hidden_channel_size=hidden_channel_size,
                           wavenet_kernel_width=filter_width,
                           num_transformer_heads=num_transformer_heads)
    self.transformer = MyModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=dim,
                               hypers=hypers,
                               key=key,
                               strict_autoregressive=False,
                               causal_decoder=True)

    # Hack to get NFE results
    from Models.models.nfe_vs_tol import _HACK_TO_GET_NFE_RESULTS
    self._log_function_evaluations = _HACK_TO_GET_NFE_RESULTS

  def predict_control(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    current_index: Optional[int] = None
  ) -> Float[Array, 'T D']:
    """Predict the drift or flow for the series.

    Arguments:
      yts: The observed data.  Must have length `obs_seq_len`; only the first
           `cond_len` entries are passed to the network.
      xts: The series that we have already generated.  Must have length
           `generation_len`.
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      current_index: If given, only the control at this index of `xts` is returned.

    Returns:
      The drift or flow at the times in `xts`, computed per time step as
      `F@xt + L@(L.T@output)` where `output` is the network prediction.  If
      `current_index` is given, only that row is returned.

    Raises:
      AssertionError: If the lengths or trailing dimensions of `yts` / `xts`
        do not match the model configuration.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    if self._log_function_evaluations:
      # NOTE(review): __init__ reads its flag from Models.models.nfe_vs_tol while
      # this import uses Models.nfe_vs_tol -- confirm that both module paths exist.
      from Models.nfe_vs_tol import _MODEL_EVALUATION_PRINT_STATEMENT
      jax.debug.print(_MODEL_EVALUATION_PRINT_STATEMENT)

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Predict the drift
    outputs = self.transformer(yts_cond, xts, context)
    assert outputs.shape[0] == self.generation_len

    # Add on the part of the drift/flow that is due to the linear SDE
    F, L = self.linear_sde.F, self.linear_sde.L
    def fix_control(t, xt, control):
      # t is unused; it is accepted so that the vmap below can map over xts.ts as well.
      LTc = L.T@control
      return F@xt + L@LTc

    final_control = jax.vmap(fix_control)(xts.ts, xts.yts, outputs)

    if current_index is not None:
      return final_control[current_index]
    else:
      return final_control

  def single_step_predict_control(self, t: float, xt: Float[Array, 'D'], state: AbstractHiddenState) -> Float[Array, 'D']:
    """Single-step control prediction; not supported by this model.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError('Not implemented for transformer yet.')

################################################################################################################

class MyNeuralSDERNN(MyNeuralSDE):
  """Variant of `MyNeuralSDE` whose network is a `MyGRUModel`.

  Sets `use_sequential_sampling = True` and implements
  `single_step_predict_control`, which advances the network one time step
  at a time while threading a hidden state through the calls.
  """
  # Despite the field name, this holds a MyGRUModel (see __init__).
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True) # Asserted to be 1 in __init__
  obs_seq_len: int = eqx.field(static=True) # Required length of the observed series
  cond_len: int = eqx.field(static=True) # Number of leading observations that are conditioned on
  latent_cond_len: int = eqx.field(static=True) # Constructor argument times (1 + interpolation_freq)

  dim: int = eqx.field(static=True) # Copied from linear_sde.dim

  predict_flow_or_drift: Literal['flow', 'drift'] = eqx.field(static=True)

  _log_function_evaluations: bool = eqx.field(static=True)

  use_sequential_sampling: bool = eqx.field(static=True) # Always True for this class

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_flow_or_drift: str,
               n_layers: Optional[int] = None,
               intermediate_channels: Optional[int] = None):
    """Store the configuration and build the `MyGRUModel` network.

    Arguments:
      linear_sde: Linear SDE that supplies `dim`, `F` and `L`.
      encoder: Encoder whose `y_dim` is the channel count of the conditioning input.
      hidden_size: Passed to `MyGRUModelHypers`.
      interpolation_freq: Must be 1.
      seq_len: Length of the observed series; stored as `obs_seq_len`.
      cond_len: Number of leading observations that are conditioned on.
      latent_cond_len: Must be even; stored multiplied by `1 + interpolation_freq`.
      key: PRNG key used to initialize the network.
      predict_flow_or_drift: Either 'flow' or 'drift'.
      n_layers: Forwarded unchanged to `MyGRUModel`.
      intermediate_channels: Forwarded unchanged to `MyGRUModel`.

    Raises:
      AssertionError: If `interpolation_freq != 1`, if `latent_generation_start_index`
        is negative, or if `predict_flow_or_drift` is not 'flow' or 'drift'.
      ValueError: If `latent_cond_len` is odd.
    """
    assert interpolation_freq == 1, "Neural SDEs are in continuous time, so there is no upsampling.  However, in practice our transformers will upsample the series by using interpolation_freq=1."
    self.interpolation_freq = interpolation_freq
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by (1 + interpolation_freq) to account for the fact that we are predicting at the original frequency
    self.use_sequential_sampling = True

    if latent_cond_len%2 != 0:
      raise ValueError(f'latent_cond_len must be even but is {latent_cond_len}.')

    # latent_generation_start_index is not defined in this class; presumably it is
    # a property inherited from a base class -- TODO confirm.
    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    assert predict_flow_or_drift in ['flow', 'drift'], f'predict_flow_or_drift must be either "flow" or "drift" but got {predict_flow_or_drift}.'
    self.predict_flow_or_drift = predict_flow_or_drift

    dim = self.linear_sde.dim

    self.dim = dim

    # Create the GRU model.  The original key is passed through unchanged so
    # that parameter initialization for a given key is preserved.
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=dim,
                               hypers=hypers,
                               key=key,
                               n_layers=n_layers,
                               intermediate_channels=intermediate_channels)

    # Hack to get NFE results
    from Models.models.nfe_vs_tol import _HACK_TO_GET_NFE_RESULTS
    self._log_function_evaluations = _HACK_TO_GET_NFE_RESULTS

  def get_initial_state(self, yts: TimeSeries, context: Optional[Float[Array, 'S C']] = None) -> AbstractHiddenState:
    """Build the initial hidden state for sequential control prediction.

    Arguments:
      yts: The observed data.  Only the first `cond_len` entries are passed to
           the network.  Its length is checked against `obs_seq_len` only when
           `context` is None.
      context: Optional precomputed context.  When None it is created from the
               conditioning prefix of `yts`.

    Returns:
      The state returned by the network's `get_initial_state`.
    """
    if context is None:
      assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
      context = self.transformer.create_context(yts[:self.cond_len])
    return self.transformer.get_initial_state(yts[:self.cond_len], context=context)

  def single_step_predict_control(self, t: float, xt: Float[Array, 'D'], state: AbstractHiddenState) -> Tuple[Float[Array, 'D'], AbstractHiddenState]:
    """Predict the control at a single time and advance the hidden state.

    Arguments:
      t: The current time.
      xt: The current state of the series.
      state: The hidden state produced by `get_initial_state` or by a previous
             call to this method.

    Returns:
      A pair `(control, updated_state)` where `control` is
      `F@xt + L@(L.T@output)` and `output` is the network's single-step prediction.
    """
    output, updated_state = self.transformer.single_step(t, xt, state)
    F, L = self.linear_sde.F, self.linear_sde.L
    return F@xt + L@(L.T@output), updated_state

################################################################################################################

def blah():
  """Debugging helper that samples from a saved experiment and opens pdb.

  Loads the train state and the fixed datasets of a hard-coded experiment,
  draws one sample from both the current and the best model on the first
  test series, and then drops into the debugger.
  """
  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model
  from diffusion_crf import TAGS

  # Hard-coded experiment that this helper inspects
  experiment_kwargs = dict(config_name='noisy_double_pendulum',
                           objective=None,
                           model_name='my_neural_ode_rnn',
                           sde_type='tracking',
                           freq=1,
                           group='rnn_models',
                           seed=0)
  experiment_id = ExperimentIdentifier.make_experiment_id(**experiment_kwargs)

  # Current and best models from the saved train state
  train_state = experiment_id.get_train_state()
  model: AbstractModel = train_state.model
  best_model: AbstractModel = train_state.best_model

  # Use the first test series as the sampling input
  datasets = experiment_id.get_data_fixed()
  train_data = datasets['train_data']
  val_data = datasets['val_data']
  test_data = datasets['test_data']
  series = test_data[0]

  # Both models get the same sampling key but different discretization keys
  key = random.PRNGKey(0)
  dkey1, dkey2 = random.split(key)
  out1 = model.sample(key, series, debug=False, discretization_key=dkey1)
  out2 = best_model.sample(key, series, debug=False, discretization_key=dkey2)

  import pdb
  pdb.set_trace()


if __name__ == '__main__':
  # Ad-hoc debugging script.  It builds a small MyNeuralSDERNN on a pickled
  # series and compares batched vs. sequential control prediction and sampling.
  # It stops at several pdb breakpoints, so it is meant to be run interactively.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import BrownianMotion, CriticallyDampedLangevinDynamics, TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation
  import wadler_lindig as wl
  # turn on x64
  jax.config.update("jax_enable_x64", True)

  # blah()

  N = 8
  # N = 64
  # NOTE: pickle must only be used on trusted files; 'series.pkl' is expected to
  # be a local file.  The context manager closes the file handle explicitly.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Keep only the first two channels of the observations and of the mask
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts.astype(jnp.float32), series.yts.astype(jnp.float32)
  series = TimeSeries(data_times, yts)

  freq = 1

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  sde = BrownianMotion(sigma=0.1, dim=y_dim)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # NOTE(review): potential_cov_type is not read again in this script and is
  # left undefined if sde.F matches none of the matrix types below.
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'

  # model = MyNeuralSDE(sde,
  #                     encoder,
  #                     n_layers=10,
  #                     filter_width=4,
  #                     hidden_channel_size=128,
  #                     num_transformer_heads=16,
  #                     interpolation_freq=freq,
  #                     seq_len=len(series),
  #                     cond_len=5,
  #                     latent_cond_len=4,
  #                     key=key,
  #                     predict_flow_or_drift='flow')

  model = MyNeuralSDERNN(sde,
                      encoder,
                      hidden_size=8,
                      interpolation_freq=freq,
                      seq_len=len(series),
                      cond_len=2,
                      latent_cond_len=2,
                      key=key,
                      predict_flow_or_drift='flow',
                      n_layers=2,
                      intermediate_channels=4)

  # Pull a regular sample from the model
  xts = model.sample(key, series, debug=False)

  context = model.transformer.create_context(series[:model.cond_len])

  # Predict the entire control
  control = model.predict_control(series, xts)
  control2 = model.predict_control(series, xts, context=context)

  # Try predicting the control sequentially
  state = model.get_initial_state(series)
  all_controls = []
  for i in range(model.latent_generation_start_index, len(xts)):
    t, xt = xts.ts[i], xts.yts[i]
    c, state = model.single_step_predict_control(t, xt, state)
    all_controls.append(c)
  all_controls = jnp.array(all_controls)

  # The batched and the step-by-step predictions are expected to agree
  assert jnp.allclose(control, all_controls)



  # Pull a regular sample from the model
  out1 = model.sample(key, series, debug=False)


  # Pull a sample from the model with state
  out2 = model.sequential_sample(key, series, debug=False)

  assert jnp.allclose(out1.yts, out2.yts)


  import pdb; pdb.set_trace()




  # Run the loss function
  out = model.loss_fn(series,
                      key=key,
                      debug=False)

  # Sample a series
  dkeys = random.split(key, 64)
  # Replace the time stamps with a fixed grid of 64 values spaced 0.1875 apart
  vals = jnp.array([ 0.    ,  0.1875,  0.375 ,  0.5625,  0.75  ,  0.9375,  1.125 ,
        1.3125,  1.5   ,  1.6875,  1.875 ,  2.0625,  2.25  ,  2.4375,
        2.625 ,  2.8125,  3.    ,  3.1875,  3.375 ,  3.5625,  3.75  ,
        3.9375,  4.125 ,  4.3125,  4.5   ,  4.6875,  4.875 ,  5.0625,
        5.25  ,  5.4375,  5.625 ,  5.8125,  6.    ,  6.1875,  6.375 ,
        6.5625,  6.75  ,  6.9375,  7.125 ,  7.3125,  7.5   ,  7.6875,
        7.875 ,  8.0625,  8.25  ,  8.4375,  8.625 ,  8.8125,  9.    ,
        9.1875,  9.375 ,  9.5625,  9.75  ,  9.9375, 10.125 , 10.3125,
        10.5   , 10.6875, 10.875 , 11.0625, 11.25  , 11.4375, 11.625 ,
        11.8125], dtype=jnp.float32)
  series = eqx.tree_at(lambda x: x.ts, series, vals)

  out1 = model.sample(key, series, debug=False, discretization_key=key)
  out2 = model.sample(key, series, debug=False)
  import pdb; pdb.set_trace()

  def sample_fn(dkey):
    # Same sampling key for every call; only the discretization key varies
    return model.sample(key, series, debug=False, discretization_key=dkey)

  sample_fn(dkeys[5])

  sampled_series = jax.vmap(sample_fn)(dkeys)
  times = sampled_series.ts
  import pdb; pdb.set_trace()

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  # Tight-tolerance solve used as the reference trajectory
  hard_solver_params = ODESolverParams(rtol=1e-12,
                             atol=1e-12,
                             solver='kvaerno5',
                             adjoint='recursive_checkpoint',
                             stepsize_controller='pid',
                             max_steps=5000,
                             throw=False,
                             progress_meter=None)
  true_sampled_series = model.sample(key,
                                series,
                                debug=False,
                                solver_params=hard_solver_params,
                                reuse_solver_state=False)



  # Loose-tolerance solve, with and without reusing the solver state
  soft_solver_params = ODESolverParams(rtol=1e-4,
                             atol=1e-4,
                             solver='dopri5',
                             adjoint='recursive_checkpoint',
                             stepsize_controller='pid',
                             max_steps=2000,
                             throw=False,
                             progress_meter=None)

  # Sample a series
  sampled_series = model.sample(key,
                                series,
                                debug=False,
                                solver_params=soft_solver_params,
                                reuse_solver_state=False)
  sampled_series2 = model.sample(key,
                                 series,
                                 debug=False,
                                 solver_params=soft_solver_params,
                                 reuse_solver_state=True)

  # Differences between the reference solve and the two loose solves
  out = true_sampled_series.yts - sampled_series.yts
  out2 = true_sampled_series.yts - sampled_series2.yts
  import pdb; pdb.set_trace()

  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)

  import pdb; pdb.set_trace()

