import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, build_conditional_sde, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times, ODESolverParams, ode_solve
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

"""
This module implements a baseline autoregressive model for forecasting the distribution q(y_{k+1:N} | Y_{1:k}).
"""

################################################################################################################

class AbstractBaselineAutoregressiveModel(AbstractModel, abc.ABC):
  """Baseline autoregressive model for q(y_{k+1:N} | Y_{1:k}).

  Although this model shouldn't directly inherit from AbstractModel because it
  doesn't predict the latent variables, we do so to give it the same interface
  as the other models.

  Subclasses implement `predict_next_state_distribution`, which returns a batch
  of Gaussian transition distributions between consecutive elements of a buffer
  series.  This class builds the maximum-likelihood loss and the autoregressive
  sampler on top of that method.
  """
  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  @abc.abstractmethod
  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    yts_buffer: TimeSeries, # The buffer of observed data
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.

    Arguments:
      yts: The observed data
      yts_buffer: The buffer of observed data.  The first self.cond_len elements should
                  be the same as yts[:self.cond_len].  The remaining elements should
                  be the observed data that we have already sampled.
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.

    Returns:
      The next distribution.  This is a batched Gaussian distribution with mean mu and covariance Sigma
      that represents the transition distributions between consecutive elements of yts_buffer.
    """
    raise NotImplementedError

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Negative log likelihood of the non-conditioned part of `yts`.

    Arguments:
      yts: The observed series at its original (not upsampled) resolution.
           Must have length `self.obs_seq_len`.
      key: Unused; kept so the signature matches the other models.
      debug: If True, drop into pdb before returning.

    Returns:
      A dict with a single entry `ml`, the mean negative log likelihood of
      yts.yts[self.cond_len:] under the predicted transitions.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'

    # Teacher forcing: the buffer is the true series itself.
    predicted_next_distribution = self.predict_next_state_distribution(yts, yts)

    # Transition i is scored against yts[i+1].
    log_probs = predicted_next_distribution.log_prob(yts.yts[1:])

    # Keep only the terms for yts[cond_len:], i.e. the part we do not condition on.
    log_probs = log_probs[self.cond_len-1:]

    ml_loss = -log_probs.mean()

    if debug:
      import pdb; pdb.set_trace()

    losses = dict(ml=ml_loss)
    return losses

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    debug: Optional[bool] = False,
    to_latent: bool = True,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(y_{k+1:N} | Y_{1:k}) autoregressively.

    Arguments:
      key: PRNG key.
      yts: The observed series with length `self.obs_seq_len`.  Entries at
           index >= self.cond_len are overwritten in the output by samples.
      debug: If True, run the sampling loop as a Python loop (instead of
             `jax.lax.scan`) with pdb breakpoints.
      to_latent: If True, return `self.basic_interpolation` of the sampled series.
      **kwargs: Ignored; accepted for interface compatibility.

    Returns:
      The sampled series, or its interpolation when `to_latent` is True.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    # Keep a buffer of the observed series.  There is a deprecation variable in TimeSeries (mask_value)
    # that causes an error if we just set yts_buffer to yts.  So we need to do this manually.
    yts_buffer = TimeSeries(ts=yts.ts, yts=yts.yts, observation_mask=yts.observation_mask)

    # Encode the conditioning series once so every step can reuse the context
    yts_cond = yts[:self.cond_len]
    context = self.transformer.create_context(yts_cond)

    #########################################################
    # Autoregressive sampling
    #########################################################
    def scan_body(carry, inputs, debug=False):
      yts_buffer = carry
      key, k = inputs

      # Predict all of the transitions.  The model is autoregressive so we don't need to worry about
      # masking yts_buffer or anything like that.  Also, the unobserved parts of yts are cut off
      # inside predict_next_state_distribution.
      transitions = self.predict_next_state_distribution(yts, yts_buffer, context=context)

      # Sample y_{k+1} from transition k
      ykp1 = transitions[k].sample(key)

      # Write the sample into the buffer and mark it as observed
      new_yts_values = util.fill_array(yts_buffer.yts, k+1, ykp1)
      new_observation_mask = util.fill_array(yts_buffer.observation_mask, k+1, True)
      new_buffer = TimeSeries(yts_buffer.ts, new_yts_values, new_observation_mask)
      if debug:
        import pdb; pdb.set_trace()

      # Only the final buffer is used, so don't emit a per-step output: stacking a full copy
      # of the buffer at every step would cost O(steps * N * D) memory for nothing.
      return new_buffer, None

    carry = yts_buffer
    keys = random.split(key, len(yts_buffer.ts)-1)
    inputs = (keys, jnp.arange(len(yts_buffer.ts)-1))

    # Start the autoregressive sampling from the last observed time point
    inputs = jtu.tree_map(lambda x: x[self.cond_len-1:], inputs)

    if debug:
      for item in zip(*inputs):
        carry, _ = scan_body(carry, item, debug=True)
      import pdb; pdb.set_trace()
    else:
      carry, _ = jax.lax.scan(scan_body, carry, inputs)

    assert len(carry) == self.obs_seq_len, f'output must have length {self.obs_seq_len} but has length {len(carry)}.'

    if to_latent:
      # NOTE(review): when cond_len == 1, keys[0] was also used for the first sampling step
      # above, so the same key is reused here.  Left as-is to keep samples reproducible.
      return self.basic_interpolation(keys[0], carry)
    else:
      return carry

################################################################################################################

class MyBaselineAutoregressiveModel(AbstractBaselineAutoregressiveModel):
  """Autoregressive baseline that uses `MyModel` to predict a Gaussian over each next observation.

  For every time step the network outputs `dim + matrix_dim` values.  The first
  `dim` are the Gaussian's vector parameter and the rest parametrize its matrix
  parameter.  How they are interpreted depends on `parametrization`:
  'nat' -> NaturalGaussian, 'mixed' -> MixedGaussian, 'std' -> StandardGaussian.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray):
    """Build the model and its `MyModel` network.

    Arguments:
      linear_sde: Stored on the model (not used by predict_next_state_distribution).
      encoder: Provides `y_dim`, the dimension of the observations.
      n_layers, filter_width, hidden_channel_size, num_transformer_heads:
        Forwarded to `MyModelHypers`.
      potential_cov_type: Ignored; the model always uses 'diagonal'.
      parametrization: One of 'nat', 'mixed', 'std'.
      interpolation_freq: Must be 0.
      seq_len: Length of the observed series.
      cond_len: Number of leading observations to condition on.
      key: PRNG key used to initialize the network.

    Raises:
      ValueError: If `parametrization` is not one of the supported values.
    """
    assert interpolation_freq == 0, 'This model is not designed to downsample the series.'
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f"Unknown parametrization {parametrization!r}; expected 'nat', 'mixed' or 'std'.")

    # Always use a diagonal matrix, regardless of the argument.
    potential_cov_type = 'diagonal'

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.parametrization = parametrization
    self.latent_cond_len = 0 # Not used in this model

    dim = self.encoder.y_dim

    # Number of network outputs needed to parametrize the matrix
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2 # dim//2 blocks of 2x2 elements
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3 # dim//3 blocks of 3x3 elements
    else:
      raise ValueError(f'Unknown potential_cov_type {potential_cov_type!r}.')

    # Parametrizing a gaussian distribution requires 1 matrix + 1 vector
    out_size = matrix_dim + dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the transformer
    hypers = MyModelHypers(n_blocks=n_layers,
                           hidden_channel_size=hidden_channel_size,
                           wavenet_kernel_width=filter_width,
                           num_transformer_heads=num_transformer_heads)
    self.transformer = MyModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=out_size,
                               hypers=hypers,
                               key=key,
                               strict_autoregressive=False,
                               causal_decoder=True)

  def predict_next_state_distribution(
    self,
    yts: TimeSeries, # The observed data
    yts_buffer: TimeSeries, # The buffer of observed data
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Predict the next distribution for the series.

    Arguments:
      yts: The observed data
      yts_buffer: The buffer of observed data.  The first self.cond_len elements should
                  be the same as yts[:self.cond_len].  The remaining elements should
                  be the observed data that we have already sampled.
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.

    Returns:
      The next distribution.  This is a batched Gaussian distribution with mean mu and covariance Sigma
      that represents the transition distributions between consecutive elements of yts_buffer.
      It has obs_seq_len - 1 elements (there is no transition out of the last element).
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Predict the potential parameters for every time step
    outputs = self.transformer(yts_cond, yts_buffer, context)
    assert outputs.shape[0] == self.obs_seq_len

    #####################################
    # Construct the covariance matrices
    #####################################

    def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Build a single Gaussian potential from a vector and a matrix parameter
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      raise ValueError(f'Unknown parametrization {self.parametrization!r}.')

    # Retrieve the parameters of the potential distribution
    vec, unscaled_mat_elements = outputs[...,:self.dim], outputs[...,self.dim:]

    ############################
    # Scale the matrix elements to ensure that
    # they are able to represent small covariance values
    ############################
    # -softplus(-x) <= 0, so the elements are bounded above by exp(c)
    def flipped_softplus(x):
      return -jax.nn.softplus(-x)

    # Eyeballed these values in desmos to get a good fit
    a = 1.6
    b = -10.8
    c = 1.4
    mat_elements = jnp.exp(flipped_softplus(a*unscaled_mat_elements + b) + c)

    # Reshape the matrix elements to the layout expected by util.to_matrix
    if self.potential_cov_type == 'diagonal':
      pass # Nothing to do
    elif self.potential_cov_type == 'dense':
      mat_elements = mat_elements.reshape((-1, self.dim, self.dim))
    elif self.potential_cov_type == 'block2x2':
      mat_elements = mat_elements.reshape((-1, 2, 2, self.dim//2))
    elif self.potential_cov_type == 'block3x3':
      mat_elements = mat_elements.reshape((-1, 3, 3, self.dim//3))
    else:
      raise ValueError(f'Unknown potential_cov_type {self.potential_cov_type!r}.')

    mat = jax.vmap(partial(util.to_matrix, symmetric=True))(mat_elements)

    # Add a bit of jitter to the covariance matrix to ensure that it has full rank
    jitter = 1e-8*mat.eye(mat.shape[1])
    mat = mat + jitter

    out_potentials = jax.vmap(construct_potential)(vec, mat)

    # We don't have a last transition.  This is the correct way to
    # slice as well in order to ensure that the Jacobian is lower
    # triangular.
    out_potentials = out_potentials[:-1]

    return out_potentials

################################################################################################################

class MyBaselineAutoregressiveRNNModel(MyBaselineAutoregressiveModel):
  """Variant of `MyBaselineAutoregressiveModel` whose network is a `MyGRUModel`.

  Only the constructor differs.  `predict_next_state_distribution` is inherited
  and calls the GRU through the `transformer` attribute.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters (a MyGRUModel despite the field name)

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray):
    """Build the model and its GRU network.

    Arguments:
      linear_sde: Stored on the model (not used by the inherited predict_next_state_distribution).
      encoder: Provides `y_dim`, the dimension of the observations.
      hidden_size: Forwarded to `MyGRUModelHypers`.
      potential_cov_type: Ignored; the model always uses 'diagonal'.
      parametrization: One of 'nat', 'mixed', 'std'.
      interpolation_freq: Must be 0.
      seq_len: Length of the observed series.
      cond_len: Number of leading observations to condition on.
      key: PRNG key used to initialize the network.

    Raises:
      ValueError: If `parametrization` is not one of the supported values.
    """
    assert interpolation_freq == 0, 'This model is not designed to downsample the series.'
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f"Unknown parametrization {parametrization!r}; expected 'nat', 'mixed' or 'std'.")

    # Always use a diagonal matrix, regardless of the argument.
    potential_cov_type = 'diagonal'

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.parametrization = parametrization
    self.latent_cond_len = 0 # Not used in this model

    dim = self.encoder.y_dim

    # Number of network outputs needed to parametrize the matrix
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      matrix_dim = dim*2 # dim//2 blocks of 2x2 elements
    elif potential_cov_type == 'block3x3':
      matrix_dim = dim*3 # dim//3 blocks of 3x3 elements
      assert dim%3 == 0
    else:
      raise ValueError(f'Unknown potential_cov_type {potential_cov_type!r}.')

    # Parametrizing a gaussian distribution requires 1 matrix + 1 vector
    out_size = matrix_dim + dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the GRU network.  It is stored in `self.transformer` so that the
    # inherited predict_next_state_distribution can call it.
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                                  in_channels=dim,
                                  out_channels=out_size,
                                  hypers=hypers,
                                  key=key)

################################################################################################################

if __name__ == '__main__':
  # Development smoke test: build a model on a small saved series, run the loss and the
  # sampler, and check that the model's predictions are autoregressive.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation

  N = 10
  # NOTE: pickle.load can execute arbitrary code; only load a trusted local file here.
  with open('series.pkl', 'rb') as f:
    series = pickle.load(f)[:N]

  # Keep only the first 2 observation dimensions
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 0

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # Pick the covariance type matching the SDE's matrix type
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
  else:
    # MyBaselineAutoregressiveModel forces 'diagonal' anyway; avoid an unbound name here.
    potential_cov_type = 'diagonal'

  model = MyBaselineAutoregressiveModel(sde,
                                        encoder,
                                        n_layers=4,
                                        filter_width=4,
                                        hidden_channel_size=16,
                                        num_transformer_heads=4,
                                        potential_cov_type=potential_cov_type,
                                        parametrization='std',
                                        interpolation_freq=freq,
                                        seq_len=len(series),
                                        cond_len=4,
                                        key=key)


  # Run the loss function
  out = model.loss_fn(series,
                      key=key,
                      debug=False)

  # Sample a series
  sampled_series = model.sample(key, series, debug=False)

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)

  ##################
  # Check the autoregressive property
  ##################

  # Sample a series
  sampled_series = model.sample(key, series, debug=False, to_latent=False)

  yts = series
  def output(xts_values):
    # Means of all predicted transitions as a function of the buffer values
    xts = TimeSeries(ts=sampled_series.ts, yts=xts_values)
    out_dist = model.predict_next_state_distribution(yts, xts)
    return out_dist.mu

  means = output(sampled_series.yts)

  # Compute the Jacobian.  Presumably J has shape (N-1, D, N, D), so summing over the
  # feature axes gives an (N-1, N) map of which outputs depend on which inputs.
  J = eqx.filter_jacfwd(output)(sampled_series.yts)
  J_flat = J.sum(axis=(1, 3))

  # Check that the Jacobian is lower triangular
  assert jnp.allclose(J_flat, jnp.tril(J_flat))

  import pdb; pdb.set_trace()

