import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, AbstractPotential, DiscretizeInfo, interleave_series, ProbabilisticTimeSeries
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, build_conditional_sde, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times, ODESolverParams, ode_solve
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.gaussian.transition import GaussianTransition
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from diffusion_crf.matrix import *
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

"""
This module implements non-probabilistic forecasting models within the neural diffusion CRF framework. It provides:

1. AbstractNonProbabilisticModel - A base class for models that predict all node potentials
   for a latent sequence at once, rather than autoregressively

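2. MyNonProbabilisticModel - A concrete implementation using a transformer-based architecture
   with support for different covariance matrix structures and parameterizations

3. MyNonProbabilisticRNNModel - The same model with the transformer replaced by a GRU-based
   network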

These models generate samples by directly predicting an entire sequence of node potentials
and then sampling from the resulting conditional CRF. When used without covariance prediction,
they effectively function as deterministic forecasting models despite operating within
a probabilistic framework.
"""

################################################################################################################

class AbstractNonProbabilisticModel(AbstractModel, abc.ABC):
  """CRF of q(x_{1:N} | Y_{1:k}).  We generate samples at once (not autoregressively) by
    predicting the node potentials for the entire unobserved part of the sequence at once.
    When we don't predict the covariances then this is equivalent to a non-probabilistic
    forecasting model.

  Subclasses implement `predict_node_potentials`.  `loss_fn` trains by matching the predicted
  potentials to the encoder's potentials of the true series, and `sample` draws from the CRF
  built from the observed potentials followed by the predicted potentials.
  """

  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True) # If False, the network only predicts the potential vectors and the matrices are
                                             # supplied without being predicted (MyNonProbabilisticModel copies them from the
                                             # encoder).  This is valid if we assume that the potential functions have a
                                             # covariance function that only depends on the time and not the value of our data.

  @abc.abstractmethod
  def predict_node_potentials(
    self,
    yts: TimeSeries,
  ) -> AbstractPotential:
    """Predict the node potentials for the unobserved part of the latent sequence.

    Implementations should only look at the first `self.cond_len` points of `yts`.  The
    returned (batched) potential covers the points from index `self.cond_len` onward, i.e.
    it is compared against `self.encoder(yts).node_potentials[self.cond_len:]` in `loss_fn`
    and appended after the observed potentials in `sample`.

    Arguments:
      yts: The observed series (NOT upsampled, length `self.obs_seq_len`).

    Returns:
      The node potentials for the unobserved part of the latent sequence.
    """
    raise NotImplementedError

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Potential-matching loss between the predicted and the encoder's node potentials.

    Arguments:
      yts: The observed series at its original (NOT upsampled) resolution.  Must have
           length `self.obs_seq_len` and last dimension `self.encoder.y_dim`.
      key: Not used by this loss.
      debug: If True, drop into pdb before returning.

    Returns:
      A dict with a single entry `mse`: mean squared error between the mixed-parameter
      means plus mean squared error between the elements of the mixed-parameter J
      matrices, over the potentials after the first `self.cond_len`.
    """
    # yts is NOT upsampled; it is at the resolution of the observations.
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'

    # Get the ground truth potentials
    true_potentials = self.encoder(yts).node_potentials
    true_potentials = true_potentials[self.cond_len:] # Only predict the unobserved potentials

    # Predict all of the node potentials.  We cut off the unobserved parts of yts inside the predict_node_potentials function.
    predicted_potentials = self.predict_node_potentials(yts)

    # Compute the potential matching loss in the mixed parametrization
    true_potentials_mixed = true_potentials.to_mixed()
    predicted_potentials_mixed = predicted_potentials.to_mixed()
    true_mean, predicted_mean = true_potentials_mixed.mu, predicted_potentials_mixed.mu
    true_J, predicted_J = true_potentials_mixed.J, predicted_potentials_mixed.J

    J_diff = (true_J - predicted_J).elements
    mu_diff = (true_mean - predicted_mean)
    mse_loss = jnp.mean(mu_diff**2) + jnp.mean(J_diff**2)

    if debug:
      import pdb; pdb.set_trace()

    losses = dict(mse=mse_loss)
    return losses

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{1:N} | Y_{1:k}) in one shot (not autoregressively).

    The first `self.cond_len` node potentials of `self.encoder(yts)` are concatenated with
    the potentials from `predict_node_potentials`, the linear SDE is conditioned on them,
    the conditioned SDE is discretized with `self.get_discretization_info`, and a single
    sample is drawn from the resulting CRF.

    Arguments:
      key: PRNG key for the CRF sample.
      yts: The observed series at its original (NOT upsampled) resolution.
      crf_state: Not used by this method.
      debug: If True, drop into pdb before returning.

    Returns:
      The sampled latent series on the discretization times; its length must equal
      `self.latent_seq_len`.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    # Get the observed potentials
    observed_potentials = self.encoder(yts).node_potentials[:self.cond_len].to_nat()

    # Predict the node potentials.  We cut off the unobserved parts of yts inside the predict_node_potentials function.
    predicted_potentials = self.predict_node_potentials(yts).to_nat()

    # Combine the observed and predicted potentials along the time (batch) axis.
    # Both must have the same pytree structure for this to work.
    node_potentials = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), observed_potentials, predicted_potentials)

    prob_series = ProbabilisticTimeSeries(yts.ts, node_potentials.to_nat()) # Go to natural parameters to improve numerical stability

    # Construct the CRF p(x | Y)
    cond_sde = ConditionedLinearSDE(self.linear_sde, prob_series, parallel=jax.devices()[0].platform == 'gpu')

    # Discretize the CRF at a specified set of times.  Use AbstractModel.get_discretization_info
    # to do this without worrying about duplicating times.
    discretization_info = self.get_discretization_info(yts.ts)
    crf = cond_sde.discretize(info=discretization_info)

    pred_xts_values = crf.sample(key)
    pred_xts = TimeSeries(discretization_info.ts, pred_xts_values)

    if debug:
      import pdb; pdb.set_trace()

    assert len(pred_xts) == self.latent_seq_len, f'pred_xts must have length {self.latent_seq_len} but has length {len(pred_xts)}.'

    return pred_xts

################################################################################################################

class MyNonProbabilisticModel(AbstractNonProbabilisticModel):
  """Non-probabilistic forecaster whose node potentials are predicted by a transformer (`MyModel`).

  The network is conditioned on the first `cond_len` observations and produces one output per
  remaining point:
    * `predict_cov=False`: only the potential's vector.  The matrix is copied from the encoder's
      potentials of the fully observed series (valid when the potential covariances depend only
      on time, not on the data values).
    * `predict_cov=True`: the vector followed by the elements of a symmetric matrix whose
      structure is given by `potential_cov_type`.

  `parametrization` selects how the (vector, matrix) pair is interpreted:
  'nat' -> NaturalGaussian(matrix, vector), 'mixed' -> MixedGaussian(vector, matrix),
  'std' -> StandardGaussian(vector, matrix).
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True) # Number of matrix elements the network outputs per point
  dim: int = eqx.field(static=True)        # Latent dimension (linear_sde.dim)

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """
    Arguments:
      linear_sde: Prior linear SDE; its `dim` is used as the latent dimension.
      encoder: Maps a series to node potentials; `encoder.y_dim` is the observation dimension.
      n_layers: `MyModelHypers.n_blocks`.
      filter_width: `MyModelHypers.wavenet_kernel_width`.
      hidden_channel_size: `MyModelHypers.hidden_channel_size`.
      num_transformer_heads: `MyModelHypers.num_transformer_heads`.
      potential_cov_type: Structure of the predicted matrix (determines `matrix_dim`).
      parametrization: How the predicted (vector, matrix) pair is interpreted.
      interpolation_freq: Stored on the model.
      seq_len: Length of the (not upsampled) observed series.
      cond_len: Number of leading observations the prediction is conditioned on.
      key: PRNG key used to initialize the network.
      predict_cov: Whether the network also predicts the potential matrices.

    Raises:
      ValueError: If `potential_cov_type` or `parametrization` is not supported.
    """
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f"parametrization must be one of 'nat', 'mixed', 'std', got {parametrization!r}")

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.parametrization = parametrization
    self.predict_cov = predict_cov
    self.latent_cond_len = 0 # Not used in this model

    dim = self.linear_sde.dim

    # Number of matrix elements the network outputs per point.  predict_node_potentials
    # reshapes them to (dim, dim), (2, 2, dim//2) or (3, 3, dim//3) for the non-diagonal types.
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      assert dim%2 == 0, f'block2x2 requires an even latent dimension, got {dim}'
      matrix_dim = dim*2
    elif potential_cov_type == 'block3x3':
      assert dim%3 == 0, f'block3x3 requires a latent dimension divisible by 3, got {dim}'
      matrix_dim = dim*3
    else:
      raise ValueError(f"potential_cov_type must be one of 'diagonal', 'dense', 'block2x2', 'block3x3', got {potential_cov_type!r}")

    if self.predict_cov:
      # Parametrizing a potential requires 1 vector + 1 matrix
      out_size = dim + matrix_dim
    else:
      # Only the vector is predicted.  The matrices are copied from the encoder's potentials,
      # which is valid when the potential covariances depend only on time and not on the data.
      out_size = dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the transformer
    hypers = MyModelHypers(n_blocks=n_layers,
                           hidden_channel_size=hidden_channel_size,
                           wavenet_kernel_width=filter_width,
                           num_transformer_heads=num_transformer_heads)
    self.transformer = MyModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim,
                               out_channels=out_size,
                               hypers=hypers,
                               key=key,
                               strict_autoregressive=False,
                               causal_decoder=False) # Doesn't need to be autoregressive

  def predict_node_potentials(
    self,
    yts: TimeSeries,
  ) -> AbstractPotential:
    """Predict the node potentials for the points after the first `self.cond_len`.

    Arguments:
      yts: The observed series (NOT upsampled, length `self.obs_seq_len`).  Only the first
           `self.cond_len` points are given to the network.

    Returns:
      A batched potential with `self.pred_len` entries, built according to
      `self.parametrization`.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # The decoder input is all zeros, so only its times and length carry information.  Its
    # channel count matches the `in_channels=self.dim` the network was built with.
    dummy_input = TimeSeries(yts.ts, jnp.zeros((yts.ts.shape[0], self.dim)))[self.cond_len:]
    outputs = self.transformer(yts_cond, dummy_input)
    assert outputs.shape[0] == self.pred_len, f'Expected {self.pred_len} outputs, got {outputs.shape[0]}.'

    #####################################
    # Construct the potential matrices
    #####################################

    def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Build a single potential from a vector and a matrix in the chosen parametrization
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      raise ValueError(f'Unknown parametrization {self.parametrization!r}')

    if not self.predict_cov:

      # Copy the matrices from the encoder's potentials of the fully observed series
      true_potentials = self.encoder(self.make_fully_observed(yts)).node_potentials[self.cond_len:]

      # Pick the matrix that matches the parametrization of the predicted vector
      if self.parametrization == 'nat':
        mat = true_potentials.to_nat().J
      elif self.parametrization == 'mixed':
        mat = true_potentials.to_mixed().J
      elif self.parametrization == 'std':
        mat = true_potentials.to_std().Sigma
      else:
        raise ValueError(f'Unknown parametrization {self.parametrization!r}')

      vec = outputs

    else:

      # Retrieve the parameters of the potential distribution
      vec, mat_elements = outputs[...,:self.dim], outputs[...,self.dim:]

      # Reshape the matrix to the correct shape
      if self.potential_cov_type == 'diagonal':
        pass # Nothing to do
      elif self.potential_cov_type == 'dense':
        mat_elements = mat_elements.reshape((-1 ,self.dim, self.dim))
      elif self.potential_cov_type == 'block2x2':
        mat_elements = mat_elements.reshape((-1 ,2, 2, self.dim//2))
      elif self.potential_cov_type == 'block3x3':
        mat_elements = mat_elements.reshape((-1 ,3, 3, self.dim//3))

      mat = jax.vmap(partial(util.to_matrix, symmetric=True))(mat_elements)

      # Add a bit of jitter to the diagonal of the matrix.
      # NOTE(review): this does not by itself make an arbitrary symmetric matrix positive definite.
      jitter = 1e-8*mat.eye(mat.shape[1])
      mat = mat + jitter

    out_potentials = jax.vmap(construct_potential)(vec, mat)

    assert out_potentials.batch_size == self.pred_len
    return out_potentials

################################################################################################################

class MyNonProbabilisticRNNModel(MyNonProbabilisticModel):
  """Same as MyNonProbabilisticModel, but the node potentials are predicted by a GRU (`MyGRUModel`).

  The GRU is stored in the `transformer` field so that the inherited `predict_node_potentials`
  calls it as `self.transformer(yts_cond, dummy_input)` unchanged.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters (a GRU in this class)

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True) # Number of matrix elements the network outputs per point
  dim: int = eqx.field(static=True)        # Latent dimension (linear_sde.dim)

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """
    Arguments:
      linear_sde: Prior linear SDE; its `dim` is used as the latent dimension.
      encoder: Maps a series to node potentials; `encoder.y_dim` is the observation dimension.
      hidden_size: `MyGRUModelHypers.hidden_size`.
      potential_cov_type: Structure of the predicted matrix (determines `matrix_dim`).
      parametrization: How the predicted (vector, matrix) pair is interpreted.
      interpolation_freq: Stored on the model.
      seq_len: Length of the (not upsampled) observed series.
      cond_len: Number of leading observations the prediction is conditioned on.
      key: PRNG key used to initialize the network.
      predict_cov: Whether the network also predicts the potential matrices.

    Raises:
      ValueError: If `potential_cov_type` or `parametrization` is not supported.
    """
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f"parametrization must be one of 'nat', 'mixed', 'std', got {parametrization!r}")

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.parametrization = parametrization
    self.predict_cov = predict_cov
    self.latent_cond_len = 0 # Not used in this model

    dim = self.linear_sde.dim

    # Number of matrix elements the network outputs per point (reshaped according to
    # `potential_cov_type` in predict_node_potentials).
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      assert dim%2 == 0, f'block2x2 requires an even latent dimension, got {dim}'
      matrix_dim = dim*2
    elif potential_cov_type == 'block3x3':
      assert dim%3 == 0, f'block3x3 requires a latent dimension divisible by 3, got {dim}'
      matrix_dim = dim*3
    else:
      raise ValueError(f"potential_cov_type must be one of 'diagonal', 'dense', 'block2x2', 'block3x3', got {potential_cov_type!r}")

    if self.predict_cov:
      # Parametrizing a potential requires 1 vector + 1 matrix
      out_size = dim + matrix_dim
    else:
      # Only the vector is predicted.  The matrices are copied from the encoder's potentials,
      # which is valid when the potential covariances depend only on time and not on the data.
      out_size = dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    # Create the GRU (stored under the `transformer` name used by predict_node_potentials)
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                                  in_channels=dim,
                                  out_channels=out_size,
                                  hypers=hypers,
                                  key=key)

################################################################################################################

if __name__ == '__main__':
  # Debug / smoke-test script: builds a model on a pickled series, runs loss_fn and sample,
  # and checks that samples do not depend on the unobserved values.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation

  N = 100
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  with open('series.pkl', 'rb') as f:
    series = pickle.load(f)[:N]
  # Keep only the first two observation channels
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 2

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  # Not used below; presumably kept for inspection in the debugger at the end.
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # Choose the potential matrix structure to match the SDE's F matrix type
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
  else:
    raise TypeError(f'Unsupported matrix type for sde.F: {type(sde.F).__name__}')

  model = MyNonProbabilisticModel(sde,
                                 encoder,
                                 n_layers=4,
                                 filter_width=4,
                                 hidden_channel_size=16,
                                 num_transformer_heads=4,
                                 potential_cov_type=potential_cov_type,
                                 parametrization='std',
                                 interpolation_freq=freq,
                                 seq_len=len(series),
                                 cond_len=5,
                                 key=key,
                                 predict_cov=False)

  # Run the loss function
  out = model.loss_fn(series,
                      key=key,
                      debug=False)


  # Sample a series
  sampled_series = model.sample(key, series, debug=False)

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  import pdb; pdb.set_trace()
