import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.gaussian.transition import GaussianTransition
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState, AbstractHiddenState
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.ode_sde_simulation import AbstractSolverParams, ODESolverParams, ode_solve, SDESolverParams, sde_sample, DiffraxSolverState
from Models.models.rnn import MyGRUModelHypers, MyGRUModel, StackedGRUSequenceHypers, StackedGRURNN
from Models.models.ho_models.ho_base import AbstractBwdPredictorModel

from Models.models.neural_sde import LocalSDEWithState, LocalSDE, AbstractNeuralSDE
from Models.models.rnn import MyGRUModelState
from Models.models.resnet import TimeDependentResNet

################################################################################################################

class MyNeuralSDERNNBwd(AbstractNeuralSDE, AbstractBwdPredictorModel):
  """Neural SDE whose control is derived from predicted backward messages.

  A GRU (stored in ``transformer``) reads the conditioning observations and the
  latent states at the base times and emits one feature vector per
  intermediate time.  A time-dependent ResNet (with a residual connection)
  turns each feature into the vector parameter of the backward message + node
  potential at that intermediate time.  The matrix parameter of that potential
  is not learned: it is taken from CRF message passing over the observed
  sequence (see ``predict_current_backward_messages``).  ``predict_control``
  converts the predicted potentials into the drift ``F x + L L^T score``.
  """
  transformer: AbstractEncoderDecoderModel # Learnable parameters
  resnet: TimeDependentResNet

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  predict_cov: bool = eqx.field(static=True)

  potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'] = eqx.field(static=True)
  parametrization: Literal['nat', 'mixed', 'std'] = eqx.field(static=True)
  matrix_dim: int = eqx.field(static=True)
  dim: int = eqx.field(static=True)


  predict_flow_or_drift: Literal['flow', 'drift'] = eqx.field(static=True)

  _log_function_evaluations: bool = eqx.field(static=True)
  use_sequential_sampling: bool = eqx.field(static=True)


  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               potential_cov_type: Literal['diagonal', 'dense', 'block2x2', 'block3x3'],
               parametrization: Literal['nat', 'mixed', 'std'],
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray,
               predict_cov: bool = False):
    """Build the GRU feature extractor and the time-dependent ResNet.

    Arguments:
      linear_sde: Prior linear SDE; ``linear_sde.dim`` is the latent dimension.
      encoder: Emission potential encoder; ``encoder.y_dim`` is the observation dimension.
      hidden_size: Hidden width used by both the GRU and the ResNet.
      potential_cov_type: Structure of the potential covariance matrix.
      parametrization: Gaussian parametrization of the predicted potentials.
      interpolation_freq: Must be 1 (kept for reverse compatibility).
      seq_len: Length of the observed (not upsampled) series.
      cond_len: Number of observed time points that are conditioned on.
      latent_cond_len: Number of extra points before the generation start.  It is
        multiplied by ``1 + interpolation_freq`` to index the upsampled latent series.
      key: PRNG key used to initialize the networks.
      predict_cov: Must be False.

    Raises:
      ValueError: If ``parametrization`` or ``potential_cov_type`` is not recognized.
    """
    assert interpolation_freq == 1, "For reverse compatability"
    if parametrization not in ('nat', 'mixed', 'std'):
      raise ValueError(f"parametrization must be one of 'nat', 'mixed', 'std' but got {parametrization!r}.")
    assert not predict_cov, 'We don\'t support predicting the covariance for the reparametrized model'

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.potential_cov_type = potential_cov_type
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    # Multiply by the interpolation frequency to account for the fact that we are predicting at the original frequency
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq)
    self.parametrization = parametrization
    self.predict_cov = predict_cov
    self.use_sequential_sampling = False

    # latent_generation_start_index is provided by a base class and depends on the fields set above.
    assert self.latent_generation_start_index >= 0, f'latent_generation_start_index must be non-negative but is {self.latent_generation_start_index}.'

    print(f'Generation length: {self.generation_len}, latent_seq_len: {self.latent_seq_len}, latent_generation_start_index: {self.latent_generation_start_index}')

    dim = self.linear_sde.dim

    # matrix_dim presumably counts the entries needed to store one matrix of this structure.
    if potential_cov_type == 'diagonal':
      matrix_dim = dim
    elif potential_cov_type == 'dense':
      matrix_dim = dim**2
    elif potential_cov_type == 'block2x2':
      assert dim%2 == 0
      matrix_dim = dim*2
    elif potential_cov_type == 'block3x3':
      assert dim%3 == 0
      matrix_dim = dim*3
    else:
      raise ValueError(f"potential_cov_type must be one of 'diagonal', 'dense', 'block2x2', 'block3x3' but got {potential_cov_type!r}.")

    out_size = dim

    self.matrix_dim = matrix_dim
    self.dim = dim

    k1, k2 = random.split(key)

    # Create the GRU that produces one feature per intermediate time
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                                  in_channels=dim,
                                  out_channels=out_size, # We'll predict the bwd message at t_kp1 and give a feature to get the bwd at t
                                  hypers=hypers,
                                  key=k1)

    # Create the time dependent resnet to predict the backward messages at the intermediate times
    self.resnet = TimeDependentResNet(input_shape=(self.dim,),
                                      working_size=hidden_size,
                                      hidden_size=hidden_size,
                                      out_size=dim,
                                      n_blocks=3,
                                      cond_shape=(out_size,),
                                      embedding_size=2*hidden_size,
                                      out_features=hidden_size,
                                      key=k2)

    self.predict_flow_or_drift = 'drift'

    # Hack to get NFE results
    from Models.models.nfe_vs_tol import _HACK_TO_GET_NFE_RESULTS
    self._log_function_evaluations = _HACK_TO_GET_NFE_RESULTS


  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Mean-matching loss between true and predicted backward message + node potentials.

    Draws a random discretization, samples a latent path from the CRF posterior,
    predicts the potentials at the intermediate times from that path and compares
    their mixed-parametrization means with the true ones.

    Arguments:
      yts: The observed series at its original (not upsampled) length ``obs_seq_len``.
      key: PRNG key.
      debug: If True, drop into pdb before returning.

    Returns:
      ``{'mse': mean squared error between the true and predicted means}``.

    Raises:
      ValueError: If the random discretization does not have ``latent_seq_len`` times.
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'

    # k1 drives the random discretization, k2 the posterior sample.
    k1, k2 = random.split(key)

    # We need to draw random times in between the times in prob_series.ts
    info = self.get_random_discretization_info(k1, yts.ts)

    if info.ts.shape[0] != self.latent_seq_len:
      raise ValueError(f'info.ts must have length {self.latent_seq_len} but has length {info.ts.shape[0]}.')

    # Create the crf state to get the backward messages
    yts = self.make_fully_observed_series(yts)
    crf_state = CRFState(self.linear_sde, self.encoder, yts, info, parameterization='mixed')
    crf: CRF = crf_state.crf
    messages: Messages = crf_state.messages

    # Mask the unobserved parts of the series for sanity checking.
    # This doesn't affect anything if we implemented the model correctly.
    yts = self.mask_observation_space_sequence(yts)

    # Sample from p(x_{1:N} | Y_{1:N}).  Use k2: `key` was already consumed by the split above.
    xts_values = crf.sample(k2, messages=messages)
    xts = TimeSeries(info.ts, xts_values)

    # Compute the true smoothed transitions, p(x_{i+1} | x_{l:i}, Y_{1:N})
    # and condition on the previous latent variable to get the next distribution
    bwd_and_node = messages.bwd + crf.node_potentials
    true_next_bwd_and_node = bwd_and_node[self.latent_generation_start_index:-1] # minus 1 because we're predicting the current backward message
    assert true_next_bwd_and_node.batch_size == self.generation_len - 1
    true_next_bwd_and_node = info.filter_new_times(true_next_bwd_and_node)

    # Compute the backward messages from the model
    intermediate_xts = info.filter_new_times(xts)
    original_xts = info.filter_base_times(xts)
    predicted_next_bwd_and_node = self.predict_current_backward_messages(intermediate_xts, yts, original_xts)

    # Compute the potential matching loss on the means only (covariances are not learned)
    true_mean = true_next_bwd_and_node.to_mixed().mu
    predicted_mean = predicted_next_bwd_and_node.to_mixed().mu
    mse_loss = jnp.mean((true_mean - predicted_mean)**2)

    losses = dict(mse=mse_loss)

    if debug:
      import pdb; pdb.set_trace()

    return losses

  def get_initial_state(self, yts: TimeSeries, context: Optional[Float[Array, 'S C']] = None) -> AbstractHiddenState:
    """Not supported by this model."""
    raise NotImplementedError('Not implemented')

  def single_step_predict_control(self, t: float, xt: Float[Array, 'D'], state: AbstractHiddenState) -> Float[Array, 'D']:
    """Not supported by this model."""
    raise NotImplementedError('Not implemented')

  def predict_current_backward_messages(
    self,
    intermediate_xts: TimeSeries,
    yts: TimeSeries,
    xts: TimeSeries,
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> AbstractPotential:
    """Predict the backward message + node potential at each intermediate time.

    Arguments:
      intermediate_xts: Latent states at the newly inserted (intermediate) times.
      yts: The entire observed (possibly masked) sequence.  Only its first
           ``cond_len`` points are fed to the network; all of it is used to build
           the CRF whose matrices are reused.
      xts: Latent states at the base times.
      context: Optional precomputed ``self.transformer.create_context`` output.

    Returns:
      A batch of potentials, one per intermediate time.  The vector parameter
      (interpreted according to ``parametrization``) is predicted by the
      networks; the matrix parameter comes from CRF message passing.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    if self._log_function_evaluations:
      # Same module that provides _HACK_TO_GET_NFE_RESULTS in __init__.
      from Models.models.nfe_vs_tol import _MODEL_EVALUATION_PRINT_STATEMENT
      jax.debug.print(_MODEL_EVALUATION_PRINT_STATEMENT)

    #####################################
    # For each t_k, predict features to use for when we predict the backward message for time t
    #####################################

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context
    if context is None:
      context = self.transformer.create_context(yts_cond)

    # Drop the last output so there is exactly one feature per intermediate time
    features_k = self.transformer(yts_cond, xts, context)[:-1]
    assert features_k.shape[0] == len(intermediate_xts)

    #####################################
    # Use each of the features_k to predict the means of the backward message for time t_{k+1}
    #####################################
    outputs = jax.vmap(self.resnet)(intermediate_xts.ts, intermediate_xts.yts, features_k)
    outputs = outputs + features_k # Residual connection around the ResNet

    #####################################
    # Construct the covariance matrices
    #####################################

    crf_state = self.make_crf_state(yts) # This needs the entire observed sequence only for the times
    crf: CRF = crf_state.crf

    # Compute the backward messages.  The means of this are not correct
    # but the covariances are correct because of our assumption that the potential
    # functions have a covariance that only depends on the time and not the value of
    # the data.  This assumption works because the smoothed mean has a linear relationship
    # with Y, but the covariance has a nonlinear relationship with Y due to matrix inverses
    # during message passing.
    bwd_node = crf_state.messages.bwd + crf.node_potentials
    # NOTE(review): this slice starts one index later than the target slice in
    # loss_fn ([start:-1]); confirm filter_new_times selects matching times for both.
    bwd_node = bwd_node[self.latent_generation_start_index + 1:]

    # Depending on the parametrization, we need to convert the covariance matrix
    if self.parametrization == 'nat':
      mat = bwd_node.to_nat().J
    elif self.parametrization == 'mixed':
      mat = bwd_node.to_mixed().J
    elif self.parametrization == 'std':
      mat = bwd_node.to_std().Sigma
    else:
      raise ValueError(f'Unknown parametrization {self.parametrization!r}.')

    # Extract the relevant matrices.  The key is a placeholder; only the
    # new-time/base-time split of the discretization is used here.
    dummy_key = random.PRNGKey(0)
    info = self.get_random_discretization_info(dummy_key, yts.ts)
    intermediate_mat = info.filter_new_times(mat)

    def construct_potential(vec: Float[Array, 'D'], mat: AbstractMatrix) -> AbstractPotential:
      # Helper function to construct a potential from a vector and a matrix
      if self.parametrization == 'nat':
        return NaturalGaussian(mat, vec)
      elif self.parametrization == 'mixed':
        return MixedGaussian(vec, mat)
      elif self.parametrization == 'std':
        return StandardGaussian(vec, mat)
      raise ValueError(f'Unknown parametrization {self.parametrization!r}.')

    # Construct the predicted backward messages + node potentials
    predicted_bwd_node = jax.vmap(construct_potential)(outputs, intermediate_mat)
    return predicted_bwd_node

  def predict_control(
    self,
    yts: TimeSeries, # The observed data
    xts: TimeSeries, # The series that we have already generated
    *,
    context: Optional[Float[Array, 'S C']] = None,
    current_index: Optional[int] = None
  ) -> Float[Array, 'T D']:
    """Predict the drift for the series being generated.

    Arguments:
      yts: The entire observed (possibly masked) data, of length ``obs_seq_len``.
      xts: The latent series that we have already generated, of length ``generation_len``.
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.
      current_index: If given, only the control at this index of ``xts`` is returned.

    Returns:
      ``F x_t + L L^T score_t`` at every time in ``xts`` (or only at ``current_index``),
      where ``score_t`` is the score of the predicted backward message + node
      potential.  At base times a ``zeros_like`` placeholder potential is used.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xts) == self.generation_len, f'xts must have length {self.generation_len} but has length {len(xts)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, 'yts must have the same dimension as the observed data'
    assert xts.yts.shape[-1] == self.linear_sde.dim, "xts must have the same dimension as the SDE"

    dummy_key = random.PRNGKey(0)
    info = self.get_random_discretization_info(dummy_key, yts.ts)
    intermediate_xts = info.filter_new_times(xts)
    original_xts = info.filter_base_times(xts)
    bwd_messages: AbstractPotential = self.predict_current_backward_messages(intermediate_xts, yts, original_xts, context=context)

    # For completeness, we need to interleave the backward messages with the original xts
    T = len(original_xts)
    zero = bwd_messages[0].zeros_like(bwd_messages[0])
    dummy_messages = jtu.tree_map(lambda x: jnp.broadcast_to(x, (T, *x.shape)), zero)
    full_bwd_messages: AbstractPotential = info.interleave(bwd_messages, dummy_messages)
    full_bwd_messages = full_bwd_messages[self.latent_generation_start_index:]
    assert full_bwd_messages.batch_size == self.generation_len

    # Compute the score of the backward messages
    score = full_bwd_messages.score(xts.yts)

    # Add on the part of the drift/flow that is due to the linear SDE
    F, L = self.linear_sde.F, self.linear_sde.L
    def fix_control(xt, control):
      LTc = L.T@control
      return F@xt + L@LTc

    final_control = jax.vmap(fix_control)(xts.yts, score)

    if current_index is not None:
      return final_control[current_index]
    return final_control


class MySDEState(AbstractHiddenState):
  """Hidden state that bundles the GRU state with a set of covariance matrices.

  Both ``batch_size`` and ``index`` are forwarded to ``transformer_state``.
  """
  transformer_state: MyGRUModelState  # recurrent state of the GRU model
  covariance_matrices: AbstractMatrix  # matrices carried alongside the GRU state

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size of the wrapped GRU state."""
    inner_state = self.transformer_state
    return inner_state.batch_size

  @property
  def index(self) -> int:
    """Current index of the wrapped GRU state."""
    inner_state = self.transformer_state
    return inner_state.index

################################################################################################################

def blah():
  """Debug helper: load a trained run and draw one sample from its best model.

  Loads the 'my_neural_sde_rnn_bwd' experiment on 'noisy_double_pendulum',
  takes the first training series, samples from the best model, then stops
  in pdb so the results can be inspected.
  """
  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model
  from diffusion_crf import TAGS

  experiment_kwargs = dict(config_name='noisy_double_pendulum',
                           objective='mse',
                           model_name='my_neural_sde_rnn_bwd',
                           sde_type='tracking',
                           freq=1,
                           group='final_models',
                           seed=0)
  ei = ExperimentIdentifier.make_experiment_id(**experiment_kwargs)

  train_state = ei.get_train_state()
  model: AbstractModel = train_state.model
  best_model: AbstractModel = train_state.best_model

  datasets = ei.get_data_fixed()
  train_data = datasets['train_data']
  val_data = datasets['val_data']
  test_data = datasets['test_data']
  series = train_data[0]

  key = random.PRNGKey(0)
  # best_model.loss_fn(series, key, debug=True)
  out1 = best_model.sample(key, series, debug=False)
  import pdb; pdb.set_trace()


################################################################################################################

if __name__ == '__main__':
  # Ad-hoc debugging script: builds a small MyNeuralSDERNNBwd on a pickled
  # series and exercises loss_fn, predict_control and sampling.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import BrownianMotion, CriticallyDampedLangevinDynamics, TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde import HigherOrderTrackingModel
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation
  import wadler_lindig as wl
  # turn on x64
  # jax.config.update("jax_enable_x64", True)

  # blah()

  # N = 15
  N = 20
  # Close the file after loading.  NOTE: pickle can execute arbitrary code,
  # so only load a trusted, locally produced 'series.pkl'.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Keep only the first two observation channels
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts.astype(jnp.float32), series.yts.astype(jnp.float32)
  series = TimeSeries(data_times, yts)

  freq = 1

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1.0)

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # Pick the potential covariance structure that matches the type of the SDE's F matrix
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
  else:
    raise TypeError(f'Unsupported matrix type for sde.F: {type(sde.F).__name__}')

  model = MyNeuralSDERNNBwd(sde,
                            encoder,
                            hidden_size=16,
                            potential_cov_type=potential_cov_type,
                            parametrization='mixed',
                            interpolation_freq=freq,
                            seq_len=len(series),
                            cond_len=6,
                            latent_cond_len=1,
                            key=key,
                            predict_cov=False)

  xts = model.basic_interpolation(key, series)

  # Evaluate the loss function
  out = model.loss_fn(series, key=key, debug=False)
  control = model.predict_control(series, xts[model.latent_generation_start_index:])
  import pdb; pdb.set_trace()

  # Pull a regular sample from the model
  xts = model.sample(key, series, debug=True)
  import pdb; pdb.set_trace()

  context = model.transformer.create_context(series[:model.cond_len])

  # Predict the entire control.  predict_control expects only the generated part
  # of the latent series (length generation_len), as in the call above.
  control2 = model.predict_control(series, xts[model.latent_generation_start_index:], context=context)

  # Try predicting the control sequentially.
  # NOTE: get_initial_state / single_step_predict_control currently raise
  # NotImplementedError for MyNeuralSDERNNBwd, so this check cannot pass yet.
  state = model.get_initial_state(series)
  all_controls = []
  for i in range(model.latent_generation_start_index, len(xts)):
    t, xt = xts.ts[i], xts.yts[i]
    c, state = model.single_step_predict_control(t, xt, state)
    all_controls.append(c)
  all_controls = jnp.array(all_controls)

  # Compare against the batched control computed for the SAME sampled xts
  assert jnp.allclose(control2, all_controls)

  # Pull a regular sample from the model
  out1 = model.sample(key, series, debug=False)

  # Pull a sample from the model with state
  out2 = model.sequential_sample(key, series, debug=False)

  assert jnp.allclose(out1.yts, out2.yts)

  import pdb; pdb.set_trace()


