import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.sde import BrownianMotion
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder, IdentityEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.api.stochastic_interpolation import make_stochastic_interpolator
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.gaussian.transition import GaussianTransition
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState
from Models.models.timefeatures import TimeFeatures
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.ode_sde_simulation import ODESolverParams, ode_solve, SDESolverParams, sde_sample
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

"""
This module implements flow-based diffusion models for time series forecasting.
Predicts observation space flows, not latent space flows.
"""

################################################################################################################

class AbstractBaselineDiffusionModel(AbstractModel, abc.ABC):
  """A diffusion model that models the process p(y_{1:n} | y_{1:k}).

  Training is done with flow matching in observation space: the future part of the
  series is flattened to one vector, bridged to a standard normal sample with a
  Brownian bridge, and `predict_control` regresses the flow of that bridge.
  Sampling integrates the predicted flow from simulation time 0 to 1 via `LocalSDE`.

  Subclasses must implement `predict_control`.
  """

  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  def series_to_vec(self, series: TimeSeries) -> Float[Array, 'T*D']:
    """Convert a series to a vector of shape (T*D,)"""
    return einops.rearrange(series.yts, 'T D -> (T D)')

  @abc.abstractmethod
  def predict_control(
    self,
    s: Scalar,
    yts: TimeSeries, # The observed data
    ys_future_series: TimeSeries, # The series at simulation time s
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> Float[Array, 'P D']:
    """Predict the drift or flow for the series.

    Arguments:
      s: The current simulation time
      yts: The observed data
      ys_future_series: The future series at simulation time s
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.

    Returns:
      The drift or flow at the times in ys_future_series, which has length self.pred_len.
    """
    pass

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Compute the flow matching loss for a single observed series.

    Arguments:
      yts: The observed series.  Must have length self.obs_seq_len and feature
           dimension self.encoder.y_dim.
      key: PRNG key.  It is split into independent keys for the prior sample,
           the simulation time and the bridge sample.
      debug: If truthy, drop into pdb right before returning.

    Returns:
      A dict with `flow_matching` (mean squared error between the predicted and
      target flow) and `cos_sim` (mean cosine similarity between the two, which is
      a diagnostic only).
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'

    k1, k2, k3 = random.split(key, 3)

    # Only the future part is the regression target.  The full series is handed
    # to predict_control below, so the conditioning part is not sliced here.
    yts_future = yts[self.cond_len:]

    ########################
    # Brownian bridge endpoints
    ########################
    # Flatten the series to a vector that represents a sample from the target
    y1 = self.series_to_vec(yts_future)
    series_shape = yts_future.yts.shape

    # Sample from the prior
    y0 = random.normal(k1, y1.shape)
    endpoint_ts = jnp.array([0.0, 1.0])
    endpoint_yts = jnp.concatenate([y0[None], y1[None]], axis=0)
    endpoint_series = TimeSeries(endpoint_ts, endpoint_yts)

    ########################
    # Brownian bridge simulation
    ########################
    encoder = IdentityEncoder(dim=y0.shape[-1])
    prob_series = encoder(endpoint_series)
    bm = BrownianMotion(sigma=0.1, dim=y0.shape[-1])
    brownian_bridge = ConditionedLinearSDE(bm, prob_series, parallel=jax.devices()[0].platform == 'gpu')

    # Sample a random time in between the endpoints
    s = random.uniform(k2, shape=())

    # Use the dedicated sub-key k3 here.  Re-using the parent `key` after it has been
    # split would correlate the bridge noise with the draws made from k1/k2.
    items = brownian_bridge.multi_sample_matching_items(jnp.array([s]), k3) # Returns the items at the new time and the original times

    # Index 1 is presumably the entry for the sampled time s, which lies between the
    # endpoint times 0 and 1 -- confirm against multi_sample_matching_items.
    ys = items.xt[1]
    flow_target = items.flow[1] # We are going to do flow matching because it is the fastest at inference time
    flow_target = flow_target.reshape(series_shape)

    # Reshape the sampled part
    ys_future = ys.reshape(series_shape)

    # Predict the flow with our model
    ys_future_series = TimeSeries(yts_future.ts, ys_future)
    pred_flow = self.predict_control(s, yts, ys_future_series)

    # Compute the flow matching loss
    matching_loss = jnp.mean((pred_flow - flow_target)**2)

    # As a sanity check, check the cosine similarity between the predicted flow_target and the true flow_target
    cos_sim = jnp.mean(jnp.sum(pred_flow*flow_target, axis=-1) / (jnp.linalg.norm(pred_flow, axis=-1)*jnp.linalg.norm(flow_target, axis=-1)))

    losses = dict(flow_matching=matching_loss, cos_sim=cos_sim)

    if debug:
      import pdb; pdb.set_trace()

    return losses

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    to_latent: bool = True,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{1:N} | Y_{1:k}) non-autoregressively.

    Arguments:
      key: PRNG key, split between the prior sample and the final interpolation.
      yts: The observed series.  Must have length self.obs_seq_len.  Only the first
           self.cond_len entries are used as values; of the remainder only the
           times are read here.
      crf_state: Not used by this method.
      debug: If truthy, drop into pdb right before returning.
      to_latent: If True, the predicted series is passed through
                 self.basic_interpolation before being returned.
      **kwargs: Not used by this method.

    Returns:
      The conditioning part of yts concatenated with the generated future part
      (optionally post-processed by self.basic_interpolation).
    """
    # Same length requirement as in loss_fn: the series must be at the observed
    # (not upsampled) resolution.
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    # Split the series into the past and future parts
    past_yts, dummy_future_yts = yts[:self.cond_len], yts[self.cond_len:]

    # Create the context once so that the ODE solver does not re-run the encoder
    context = self.transformer.create_context(past_yts)

    # Create the local SDE that will perform the transport
    info = self.get_discretization_info(dummy_future_yts.ts)
    local_sde = LocalSDE(self, yts, context, info)

    # Sample from the prior
    k1, k2 = random.split(key)
    y0 = random.normal(k1, shape=(self.pred_len, self.encoder.y_dim))

    # Simulate the SDE
    save_times = jnp.array([0.0, 1.0])
    yts_flat = local_sde.simulate(y0.ravel(), save_times)

    # Entry 1 of the trajectory corresponds to the last save time (s = 1)
    yts_values = yts_flat[1].yts.reshape(y0.shape)
    pred_future_yts = TimeSeries(info.ts, yts_values)

    # Concatenate the observed series with the predicted future series
    all_ts = jnp.concatenate([past_yts.ts, pred_future_yts.ts])
    all_yts = jnp.concatenate([past_yts.yts, pred_future_yts.yts], axis=0)
    pred_yts = TimeSeries(all_ts, all_yts)

    if to_latent:
      pred_yts = self.basic_interpolation(k2, pred_yts)

    if debug:
      import pdb; pdb.set_trace()

    return pred_yts


class LocalSDE(AbstractSDE):
  """Helper class to pass into the ODE solver.

  Exposes the flow predicted by `model.predict_control` as the vector field of an
  ODE over the flattened future series.  Only `get_flow` (and `simulate`, which
  integrates it) are implemented; the drift, diffusion coefficient and transition
  distribution raise NotImplementedError.
  """

  model: AbstractBaselineDiffusionModel # Our model that contains a transformer to predict the control at all times

  yts: TimeSeries              # The observed data
  context: Float[Array, 'S C'] # The context embedding of the observed data

  info: DiscretizeInfo         # Provides the times (info.ts) of the series being generated

  @property
  def batch_size(self) -> int:
    """Batch size, forwarded from the wrapped model."""
    return self.model.batch_size

  def get_control(self, s: Scalar, xs_flat: Float[Array, 'T*D']) -> Float[Array, 'T*D']:
    """Get the control at simulation time s for the local SDE.  This is the control that
    will be used to transport the sample from the start distribution to the end distribution.

    Arguments:
      s: The current simulation time
      xs_flat: The current state as a flat vector of length T*D, where
               T = len(self.info.ts)

    Returns:
      The predicted control, flattened back to a vector with `ravel`.
    """
    # The solver works on flat vectors; the model works on (T, D) series
    xs = einops.rearrange(xs_flat, '(T D) -> T D', T=self.info.ts.shape[0])
    series_dec = TimeSeries(self.info.ts, xs)

    # Predict the control at the new time, re-using the precomputed context
    control = self.model.predict_control(s, self.yts, series_dec, context=self.context)
    return control.ravel()

  def get_drift(self, t: Scalar,  xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Get the drift of this SDE at time t.  Not available: only the flow is modelled."""
    raise NotImplementedError('we\'re doing flow matching')

  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Get the diffusion coefficient of this SDE at time t.  Not available: only the
    flow is modelled.
    """
    raise NotImplementedError('we\'re doing flow matching')

  def get_flow(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Get the vector field of the probability flow ODE of this SDE at time t."""
    return self.get_control(t, xt)

  def get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Not available: there is no closed form transition for a learned flow."""
    raise NotImplementedError('No closed form transition distribution for neural SDEs.')

  def simulate(
    self,
    x0: Float[Array, 'D'],
    save_times: Float[Array, 'T']
  ) -> TimeSeries:
    """Transport the sample x0 from the start time (at save_times[0]) to the end time (at save_times[-1])

    Arguments:
      x0: The flat initial state
      save_times: The simulation times at which the trajectory is saved

    Returns:
      Whatever `ode_solve` returns for these save times (annotated as a TimeSeries).
    """

    # Simulate the probability flow ODE of the neural SDE
    solver_params = ODESolverParams(rtol=1e-3,
                                    atol=1e-5,
                                    solver='dopri5',
                                    adjoint='recursive_checkpoint',
                                    stepsize_controller='pid',
                                    max_steps=2000,
                                    throw=False,
                                    progress_meter=None)
                                    # progress_meter='tqdm')

    simulated_trajectory = ode_solve(self,
                                      x0=x0,
                                      save_times=save_times,
                                      params=solver_params)

    return simulated_trajectory

################################################################################################################

class MyBaselineDiffusionModel(AbstractBaselineDiffusionModel):
  """Baseline flow model whose control network is a `MyModel` encoder-decoder.

  The simulation time is embedded with `TimeFeatures` and concatenated onto every
  step of the noisy future series before it is passed to the decoder.
  """

  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from
  dim: int = eqx.field(static=True)

  simulation_time_features: TimeFeatures # Learnable parameter.  Doing this here to avoid re-implementing the time features in the transformer.

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray):
    """Build the time embedding and the encoder-decoder network.

    Arguments:
      linear_sde: Stored on the model as a static field
      encoder: Stored as a static field; its y_dim fixes the data dimension
      n_layers: Number of blocks of the network
      filter_width: Wavenet kernel width of the network
      hidden_channel_size: Hidden size of the network and of the time embedding
      num_transformer_heads: Number of attention heads
      interpolation_freq: Must be 0
      seq_len: Length of the full observed series
      cond_len: Number of leading observations that are conditioned on
      key: PRNG key used to initialize the learnable parameters
    """
    assert interpolation_freq == 0, 'We don\'t support downsampling here!'

    network_key, time_key = random.split(key)
    obs_dim = encoder.y_dim

    # Static configuration
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = 0 # Not used in this model
    self.dim = obs_dim

    # Embedding of the simulation time, as wide as the hidden channels
    self.simulation_time_features = TimeFeatures(
      embedding_size=2*hidden_channel_size,
      out_features=hidden_channel_size,
      key=time_key,
    )

    # The decoder sees the noisy series with the time embedding appended to it
    self.transformer = MyModel(
      cond_in_channels=obs_dim,
      in_channels=obs_dim + hidden_channel_size,
      out_channels=obs_dim,
      hypers=MyModelHypers(
        n_blocks=n_layers,
        hidden_channel_size=hidden_channel_size,
        wavenet_kernel_width=filter_width,
        num_transformer_heads=num_transformer_heads,
      ),
      key=network_key,
      strict_autoregressive=False,
      causal_decoder=False, # We don't need to be autoregressive!
    )

  def predict_control(
    self,
    s: Scalar,
    yts: TimeSeries, # The observed data
    ys_future_series: TimeSeries, # The series at simulation time s
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> Float[Array, 'T D']:
    """Evaluate the network to get the flow of the future series.

    Arguments:
      s: The current simulation time
      yts: The full observed series, of which only the first self.cond_len
           entries are fed to the network
      ys_future_series: The future series at simulation time s
      context: Optional precomputed output of self.transformer.create_context for
               the conditioning part.  It is computed here when it is None.

    Returns:
      The network output, with one row per entry of ys_future_series.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(ys_future_series) == self.pred_len, f'ys_future_series must have length {self.pred_len} but has length {len(ys_future_series)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    assert ys_future_series.yts.shape[-1] == self.encoder.y_dim, "ys_future_series must have the same dimension as the SDE"

    # The network must never see anything past the conditioning window
    past = yts[:self.cond_len]
    if context is None:
      context = self.transformer.create_context(past)

    # Embed the simulation time once and repeat it for every future step
    s_embedding = self.simulation_time_features(s)
    per_step_embedding = jnp.broadcast_to(
      s_embedding[None],
      (len(ys_future_series), s_embedding.shape[-1]),
    )

    # Decoder input is [noisy values, time embedding] along the feature axis
    decoder_series = TimeSeries(
      ys_future_series.ts,
      jnp.concatenate([ys_future_series.yts, per_step_embedding], axis=-1),
    )

    flow = self.transformer(past, decoder_series, context=context)
    assert flow.shape[0] == self.pred_len
    return flow

################################################################################################################

class MyBaselineDiffusionRNNModel(MyBaselineDiffusionModel):
  """Variant of `MyBaselineDiffusionModel` whose control network is a `MyGRUModel`.

  Only the constructor differs; `predict_control` is inherited.  The network is still
  stored in the `transformer` field so that the inherited code can use it.
  """

  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from
  dim: int = eqx.field(static=True)

  simulation_time_features: TimeFeatures # Learnable parameter.  Doing this here to avoid re-implementing the time features in the transformer.

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               key: PRNGKeyArray):
    """Build the time embedding and the GRU network.

    Arguments:
      linear_sde: Stored on the model as a static field
      encoder: Stored as a static field; its y_dim fixes the data dimension
      hidden_size: Hidden size of the GRU and of the time embedding
      interpolation_freq: Must be 0
      seq_len: Length of the full observed series
      cond_len: Number of leading observations that are conditioned on
      key: PRNG key used to initialize the learnable parameters
    """
    assert interpolation_freq == 0, 'We don\'t support downsampling here!'
    self.interpolation_freq = interpolation_freq
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = 0 # Not used in this model

    k1, k2 = random.split(key)
    dim = self.encoder.y_dim

    self.dim = dim

    time_feature_size = hidden_size
    self.simulation_time_features = TimeFeatures(embedding_size=2*time_feature_size,
                                       out_features=time_feature_size,
                                       key=k2)

    # Create the GRU network.  It gets its own sub-key k1: the parent `key` has
    # already been split and must not be consumed again.
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim + time_feature_size,
                               out_channels=dim,
                               hypers=hypers,
                               key=k1)

################################################################################################################

if __name__ == '__main__':
  # Manual smoke test: builds a small model on a pickled series, runs the loss and
  # the sampler once, and then drops into pdb for interactive inspection.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import BrownianMotion, CriticallyDampedLangevinDynamics, TimeScaledLinearTimeInvariantSDE, HigherOrderTrackingModel
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation

  N = 10
  # NOTE: unpickling can execute arbitrary code; only load a trusted local file here.
  # The context manager closes the file handle once the data has been read.
  with open('series.pkl', 'rb') as f:
    series = pickle.load(f)[:N]

  # Keep only the first two feature columns of the values and of the mask
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 0

  # Independent keys for initialization, the loss and sampling
  key = random.PRNGKey(0)
  init_key, loss_key, sample_key = random.split(key, 3)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  sde = HigherOrderTrackingModel(sigma=0.1, position_dim=y_dim, order=2)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # potential_cov_type is not used further down; it is left available for inspection
  # in the pdb session at the end of the script.
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'

  model = MyBaselineDiffusionModel(sde,
                      encoder,
                      n_layers=4,
                      filter_width=4,
                      hidden_channel_size=16,
                      num_transformer_heads=4,
                      interpolation_freq=freq,
                      seq_len=len(series),
                      cond_len=5,
                      key=init_key)

  # Run the loss function
  out = model.loss_fn(series,
                      key=loss_key,
                      debug=False)

  # Sample a series
  sampled_series = model.sample(sample_key, series, debug=False)

  # Check that the samples don't depend on unobserved values.  The same key is used
  # on purpose so that any difference can only come from the masked inputs.
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(sample_key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)

  import pdb; pdb.set_trace()

