import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractSDE
from diffusion_crf.sde import BrownianMotion
from diffusion_crf.base import AbstractBatchableObject
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_series
from diffusion_crf.ssm.simple_encoder import AbstractEncoder, IdentityEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, MixedGaussian, NaturalGaussian, ConditionedLinearSDE, interleave_times
from diffusion_crf.matrix import AbstractMatrix
from diffusion_crf.base import AbstractTransition, AbstractPotential
from diffusion_crf.gaussian.transition import GaussianTransition
import diffusion_crf.util as util
from diffusion_crf.api.stochastic_interpolation import make_stochastic_interpolator
from diffusion_crf.crf import Messages, CRF
from diffusion_crf.gaussian import StandardGaussian, MixedGaussian, NaturalGaussian
from diffusion_crf.gaussian.transition import GaussianTransition
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from Models.models.base import AbstractModel, CRFState
from Models.models.timefeatures import TimeFeatures
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer import MyModel, MyModelHypers
from Models.models.ode_sde_simulation import ODESolverParams, ode_solve, SDESolverParams, sde_sample
from Models.models.rnn import MyGRUModelHypers, MyGRUModel

"""
This module implements flow-based diffusion models for time series forecasting. It provides:

1. AbstractDiffusionModel - A base class for models that transform a standard Gaussian prior
   into the target conditional distribution p(x_{1:n} | y_{1:k}) using flow matching

2. LocalSDE - A helper class that handles the simulation of the transport process from
   the prior to the target distribution

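3. MyDiffusionModel - A concrete implementation using a transformer-based architecture
   that incorporates simulation time as an additional feature

4. MyDiffusionRNNModel - A variant of MyDiffusionModel that uses a GRU-based network
   (MyGRUModel) in place of the transformer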

These models generate samples in a non-autoregressive manner by learning to transform
a simple prior distribution (standard Gaussian) into the target conditional distribution
in a single pass, enabling efficient parallel sampling while capturing complex dependencies
in the data.
"""

################################################################################################################

class AbstractDiffusionModel(AbstractModel, abc.ABC):
  """A diffusion model that models the process p(x_{1:n} | y_{1:k}) by going from a Gaussian prior to p(x_{1:n} | y_{1:k}).

  Training (`loss_fn`) uses flow matching along a Brownian bridge between a standard
  Gaussian sample and a flattened latent series.  Sampling (`sample`) integrates the
  learned flow from simulation time 0 to 1 with `LocalSDE`.

  Subclasses must implement `predict_control`.
  """

  linear_sde: AbstractLinearSDE              # Base linear SDE that defines the latent process
  encoder: AbstractEncoder                   # Maps an observed series to potentials (called as `self.encoder(yts)`)
  transformer: AbstractEncoderDecoderModel   # Network used to build the context and predict the control

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)  # Expected length of the observed series passed to loss_fn/sample
  cond_len: int = eqx.field(static=True)     # Number of leading observations that are conditioned on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  def series_to_vec(self, series: TimeSeries) -> Float[Array, 'T*D']:
    """Convert a series to a vector of shape (T*D,)"""
    return einops.rearrange(series.yts, 'T D -> (T D)')

  @abc.abstractmethod
  def predict_control(
    self,
    s: Scalar,
    yts: TimeSeries, # The observed data
    xs: TimeSeries, # The series at simulation time s
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> Float[Array, 'T D']:
    """Predict the drift or flow for the series.

    Arguments:
      s: The current simulation time
      yts: The observed data
      xs: The series at simulation time s
      context: The optional context associated with the encoder series that is passed
               to the decoder network from the encoder network.  This is used to
               avoid re-evaluating the encoder during inference.

    Returns:
      The drift or flow at the times in xs.
    """
    pass

  def loss_fn(
    self,
    yts: TimeSeries,
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Dict[str, Scalar]:
    """Compute the flow matching loss for a single observed series.

    Arguments:
      yts: The observed series.  Must have length `obs_seq_len` and feature dimension `encoder.y_dim`.
      key: PRNG key.  It is split so that every random draw below uses its own independent key.
      debug: If True, drop into pdb right before returning.

    Returns:
      A dictionary with the entries `flow_matching` (mean squared error between the predicted
      and target flow) and `cos_sim` (mean cosine similarity between them, a diagnostic).
    """
    assert len(yts) == self.obs_seq_len, f'series must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'

    # One independent key per source of randomness.  Reusing the parent key for several
    # draws would correlate the latent sample with the bridge sample.
    k_latent, k_prior, k_time, k_bridge = random.split(key, 4)

    # Sample from p(x_{1:N} | Y_{1:N})
    prob_series = self.encoder(yts)
    cond_sde = ConditionedLinearSDE(self.linear_sde, prob_series, parallel=jax.devices()[0].platform == 'gpu')
    info = self.get_discretization_info(yts.ts)
    crf = cond_sde.discretize(info=info)
    xts_values = crf.sample(k_latent)
    full_xts = TimeSeries(info.ts, xts_values)

    xts = full_xts[self.latent_generation_start_index:] # Get x_{l:N}
    series_shape = xts.yts.shape

    ########################
    # Brownian bridge endpoints
    ########################
    # Flatten the series to a vector that represents a sample from the target
    x1 = self.series_to_vec(xts)

    # Sample from the prior
    x0 = random.normal(k_prior, x1.shape)
    endpoint_ts = jnp.array([0.0, 1.0])
    endpoint_xts = jnp.concatenate([x0[None], x1[None]], axis=0)
    endpoint_series = TimeSeries(endpoint_ts, endpoint_xts)

    ########################
    # Brownian bridge simulation
    ########################
    endpoint_encoder = IdentityEncoder(dim=x0.shape[-1])
    endpoint_prob_series = endpoint_encoder(endpoint_series)
    bm = BrownianMotion(sigma=0.1, dim=x0.shape[-1])
    brownian_bridge = ConditionedLinearSDE(bm, endpoint_prob_series, parallel=jax.devices()[0].platform == 'gpu')

    # Sample a random time in between the endpoints
    s = random.uniform(k_time, shape=())
    items = brownian_bridge.multi_sample_matching_items(jnp.array([s]), k_bridge) # Returns the items at the new time and the original times

    # NOTE(review): index 1 is presumably the entry for the inserted time s (between the
    # endpoints at 0 and 1) -- confirm against multi_sample_matching_items.
    xs = items.xt[1]
    flow_target = items.flow[1] # We are going to do flow matching because it is the fastest at inference time
    flow_target = flow_target.reshape(series_shape)

    # Predict the flow with our model
    xs_series = TimeSeries(xts.ts, xs.reshape(series_shape))
    pred_flow = self.predict_control(s, yts, xs_series)

    # Compute the flow matching loss
    matching_loss = jnp.mean((pred_flow - flow_target)**2)

    # As a sanity check, check the cosine similarity between the predicted flow_target and the true flow_target
    cos_sim = jnp.mean(jnp.sum(pred_flow*flow_target, axis=-1) / (jnp.linalg.norm(pred_flow, axis=-1)*jnp.linalg.norm(flow_target, axis=-1)))

    losses = dict(flow_matching=matching_loss, cos_sim=cos_sim)

    if debug:
      import pdb; pdb.set_trace()

    return losses

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    *,
    crf_state: Optional[CRFState] = None,
    debug: Optional[bool] = False,
    **kwargs
  ) -> TimeSeries:
    """Sample from p(x_{1:N} | Y_{1:k}) non-autoregressively

    Arguments:
      key: PRNG key.  It is split so the prior draw and the interpolation use independent keys.
      yts: The observed series.  Must have length `obs_seq_len` and feature dimension `encoder.y_dim`.
      crf_state: Accepted for interface compatibility; not used here.
      debug: If True, drop into pdb right before returning.
      **kwargs: Accepted for interface compatibility; not used here.

    Returns:
      The latent series: the output of `basic_interpolation` before the generation start
      index, followed by the generated values from that index onwards.
    """
    assert len(yts) == self.obs_seq_len, f'yts must not be upsampled!  Got length {len(yts)} but expected {self.obs_seq_len}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, "yts must have the same dimension as the observed data"

    # Independent keys for the prior sample and for the interpolation of the leading segment
    k_prior, k_interp = random.split(key)

    # Only keep the parts of the series that we are conditioning on!!!
    yts_cond = yts[:self.cond_len]

    # Create the context once so the ODE solver does not re-evaluate the encoder at every step
    context = self.transformer.create_context(yts_cond)

    # Create the local SDE that will perform the transport
    info = self.get_discretization_info(yts.ts)
    ts = info.ts[self.latent_generation_start_index:]
    local_sde = LocalSDE(self, yts, context, ts)

    # Sample from the prior
    x0 = random.normal(k_prior, shape=(self.generation_len, self.linear_sde.dim))

    # Simulate the flow from simulation time 0 to simulation time 1
    save_times = jnp.array([0.0, 1.0])
    xts_flat = local_sde.simulate(x0.ravel(), save_times)

    # Entry 1 is the state at the final save time (1.0); unflatten it back to (T, D)
    xts_values = xts_flat[1].yts.reshape(x0.shape)
    xts = TimeSeries(ts, xts_values)

    # Concatenate the original latent sequence as well
    yts = self.mask_observation_space_sequence(yts)
    xts_naive = self.basic_interpolation(k_interp, yts)[:self.latent_generation_start_index]

    xts = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), xts_naive, xts)

    if debug:
      import pdb; pdb.set_trace()

    return xts


class LocalSDE(AbstractSDE):
  """Adapter that exposes a diffusion model's predicted flow through the SDE interface given to `ode_solve`.

  Only `get_flow` (via `get_control`) and `simulate` are functional.  The drift, diffusion
  coefficient and transition distribution raise NotImplementedError.
  """

  model: AbstractDiffusionModel # The model whose `predict_control` supplies the flow

  yts: TimeSeries              # Observed data, forwarded unchanged to `predict_control`
  context: Float[Array, 'S C'] # Precomputed context, forwarded unchanged to `predict_control`

  ts: Float[Array, 'T']        # Times attached to the series being transported

  @property
  def batch_size(self) -> int:
    """Batch size, delegated to the wrapped model."""
    return self.model.batch_size

  def get_control(self, s: Scalar, xs_flat: Float[Array, 'T D']) -> Float[Array, 'T D']:
    """Evaluate the model's control at simulation time s.

    Despite the annotations, `xs_flat` is unflattened from a vector of length T*D
    (with T = len(self.ts)) and the result is flattened again before being returned.
    """
    n_times = self.ts.shape[0]
    unflattened = einops.rearrange(xs_flat, '(T D) -> T D', T=n_times)

    # Re-attach the times so the model receives a proper series
    predicted = self.model.predict_control(
      s,
      self.yts,
      TimeSeries(self.ts, unflattened),
      context=self.context,
    )
    return predicted.ravel()

  def get_drift(self, t: Scalar,  xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Not available: only the flow is modelled."""
    raise NotImplementedError("we're doing flow matching")

  def get_diffusion_coefficient(self, t: Scalar, xt: Float[Array, 'D']) -> AbstractMatrix:
    """Not available: only the flow is modelled."""
    raise NotImplementedError("we're doing flow matching")

  def get_flow(self, t: Scalar, xt: Float[Array, 'D']) -> Float[Array, 'D']:
    """Vector field integrated by the ODE solver; identical to `get_control(t, xt)`."""
    return self.get_control(t, xt)

  def get_transition_distribution(self, s: Scalar, t: Scalar) -> AbstractTransition:
    """Not available for this class."""
    raise NotImplementedError("No closed form transition distribution for neural SDEs.")

  def simulate(
    self,
    x0: Float[Array, 'D'],
    save_times: Float[Array, 'T']
  ) -> TimeSeries:
    """Integrate the flow starting from x0 and return the trajectory produced by `ode_solve` for `save_times`."""
    return ode_solve(
      self,
      x0=x0,
      save_times=save_times,
      params=ODESolverParams(
        rtol=1e-3,
        atol=1e-5,
        solver='dopri5',
        adjoint='recursive_checkpoint',
        stepsize_controller='pid',
        max_steps=2000,
        throw=False,
        progress_meter=None,
      ),
    )

################################################################################################################

class MyDiffusionModel(AbstractDiffusionModel):
  """Concrete diffusion model whose control is predicted by a `MyModel` network.

  The simulation time is embedded with `TimeFeatures` and appended to every step of
  the input series before it is passed to the network.
  """

  transformer: AbstractEncoderDecoderModel # Learnable parameters

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)

  dim: int = eqx.field(static=True) # Copy of linear_sde.dim

  simulation_time_features: TimeFeatures # Learnable parameter.  Kept here to avoid re-implementing the time features in the transformer.

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               n_layers: int,
               filter_width: int,
               hidden_channel_size: int,
               num_transformer_heads: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray):
    """Build the time-feature embedding and the network.

    Arguments:
      linear_sde: The base SDE; its `dim` sets the input/output width of the network.
      encoder: The observation encoder; its `y_dim` sets the conditioning width.
      n_layers: Passed to MyModelHypers as `n_blocks`.
      filter_width: Passed to MyModelHypers as `wavenet_kernel_width`.
      hidden_channel_size: Hidden width of the network; also the size of the simulation time features.
      num_transformer_heads: Passed to MyModelHypers.
      interpolation_freq: Stored as `interpolation_freq`.
      seq_len: Stored as `obs_seq_len`.
      cond_len: Stored as `cond_len`.
      latent_cond_len: Stored after scaling by (1 + interpolation_freq).
      key: PRNG key, split between the network and the time features.
    """
    latent_dim = linear_sde.dim
    network_key, time_key = random.split(key)

    self.linear_sde = linear_sde
    self.encoder = encoder
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    # Multiply by (1 + interpolation frequency) because the argument is given at the original frequency
    self.latent_cond_len = latent_cond_len*(1 + interpolation_freq)
    self.dim = latent_dim

    # The time embedding is as wide as the hidden channels
    self.simulation_time_features = TimeFeatures(
      embedding_size=2*hidden_channel_size,
      out_features=hidden_channel_size,
      key=time_key,
    )

    # The network input is the latent state with the time embedding appended
    self.transformer = MyModel(
      cond_in_channels=encoder.y_dim,
      in_channels=latent_dim + hidden_channel_size,
      out_channels=latent_dim,
      hypers=MyModelHypers(
        n_blocks=n_layers,
        hidden_channel_size=hidden_channel_size,
        wavenet_kernel_width=filter_width,
        num_transformer_heads=num_transformer_heads,
      ),
      key=network_key,
      strict_autoregressive=False,
      causal_decoder=False, # We don't need to be autoregressive!
    )

  def predict_control(
    self,
    s: Scalar,
    yts: TimeSeries, # The observed data
    xs: TimeSeries, # The series at simulation time s
    *,
    context: Optional[Float[Array, 'S C']] = None,
  ) -> Float[Array, 'T D']:
    """Predict the drift or flow for the series.

    Arguments:
      s: The current simulation time
      yts: The full observed series; only the first `cond_len` entries are fed to the network
      xs: The series at simulation time s
      context: Optional precomputed context for the conditioning series.  When omitted it is
               created here; passing it avoids re-evaluating the encoder during inference.

    Returns:
      The network output, one row per time in xs.
    """
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.  We assume that yts is the entire observed (but maybe masked) sequence so that we know at what times we get potentials at.'
    assert len(xs) == self.generation_len, f'xs must have length {self.generation_len} but has length {len(xs)}.'
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    assert xs.yts.shape[-1] == self.linear_sde.dim, "xs must have the same dimension as the SDE"

    # Restrict the observations to the part we condition on
    cond_series = yts[:self.cond_len]

    # Build the context unless the caller already supplied one
    if context is None:
      context = self.transformer.create_context(cond_series)

    # Embed the simulation time and repeat it for every step of xs
    time_embedding = self.simulation_time_features(s)
    tiled_shape = (len(xs), time_embedding.shape[-1])
    tiled_embedding = jnp.broadcast_to(time_embedding[None], tiled_shape)

    # Append the time embedding to the series values along the feature axis
    decoder_values = jnp.concatenate([xs.yts, tiled_embedding], axis=-1)
    decoder_input = TimeSeries(xs.ts, decoder_values)

    # Predict the control
    control = self.transformer(cond_series, decoder_input, context=context)
    assert control.shape[0] == self.generation_len
    return control

################################################################################################################

class MyDiffusionRNNModel(MyDiffusionModel):
  """Variant of MyDiffusionModel whose network is a `MyGRUModel` instead of `MyModel`.

  `predict_control` is inherited unchanged; only the construction differs.
  """

  transformer: AbstractEncoderDecoderModel # Learnable parameters (a MyGRUModel, despite the field name)

  linear_sde: AbstractLinearSDE = eqx.field(static=True)
  encoder: AbstractEncoder = eqx.field(static=True)

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True)

  dim: int = eqx.field(static=True) # Copy of linear_sde.dim

  simulation_time_features: TimeFeatures # Learnable parameter.  Doing this here to avoid re-implementing the time features in the network.

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               hidden_size: int,
               *,
               interpolation_freq: int,
               seq_len: int,
               cond_len: int,
               latent_cond_len: int,
               key: PRNGKeyArray):
    """Build the time-feature embedding and the GRU network.

    Arguments:
      linear_sde: The base SDE; its `dim` sets the input/output width of the network.
      encoder: The observation encoder; its `y_dim` sets the conditioning width.
      hidden_size: Hidden size of the GRU; also the size of the simulation time features.
      interpolation_freq: Stored as `interpolation_freq`.
      seq_len: Stored as `obs_seq_len`.
      cond_len: Stored as `cond_len`.
      latent_cond_len: Stored after scaling by (1 + interpolation_freq).
      key: PRNG key, split between the network and the time features.
    """
    self.interpolation_freq = interpolation_freq
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.obs_seq_len = seq_len
    self.cond_len = cond_len
    self.latent_cond_len = latent_cond_len*(1 + self.interpolation_freq) # Multiply by (1 + interpolation frequency) because the argument is given at the original frequency

    k1, k2 = random.split(key)
    dim = self.linear_sde.dim

    self.dim = dim

    time_feature_size = hidden_size
    self.simulation_time_features = TimeFeatures(embedding_size=2*time_feature_size,
                                       out_features=time_feature_size,
                                       key=k2)

    # Create the GRU network.  Use the split key k1 (not the parent key, which must not
    # be reused after splitting) so its initialisation is independent of the time features.
    hypers = MyGRUModelHypers(hidden_size=hidden_size)
    self.transformer = MyGRUModel(cond_in_channels=encoder.y_dim,
                               in_channels=dim + time_feature_size,
                               out_channels=dim,
                               hypers=hypers,
                               key=k1)

################################################################################################################

if __name__ == '__main__':
  # Manual smoke test: build a small model on a pickled series, run the loss and the
  # sampler once, and drop into pdb for interactive inspection.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import BrownianMotion, CriticallyDampedLangevinDynamics, TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation

  N = 10
  # NOTE: pickle can execute arbitrary code while loading; only use a trusted 'series.pkl'.
  # The context manager closes the file handle once the data has been read.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]

  # Keep only the first two observed dimensions (and the matching mask columns)
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  freq = 1

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  sde = BrownianMotion(sigma=0.1, dim=y_dim)
  # sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/(1 + freq))

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.1)
  # Unused below; kept so it can be inspected from the debugger at the end
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  from diffusion_crf.matrix import *
  # potential_cov_type is left undefined when sde.F matches none of these types
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'

  model = MyDiffusionModel(sde,
                      encoder,
                      n_layers=4,
                      filter_width=4,
                      hidden_channel_size=16,
                      num_transformer_heads=4,
                      interpolation_freq=freq,
                      seq_len=len(series),
                      cond_len=5,
                      latent_cond_len=4,
                      key=key)

  # Run the loss function
  out = model.loss_fn(series,
                      key=key,
                      debug=False)

  # Sample a series
  sampled_series = model.sample(key, series, debug=False)

  # Check that the samples don't depend on unobserved values
  series2 = model.mask_observation_space_sequence(series)
  sampled_series2 = model.sample(key, series2, debug=False)
  assert jnp.allclose(sampled_series.yts, sampled_series2.yts)

  # Check downsampling
  latent_seq = model.downsample_seq_to_original_freq(sampled_series)

  import pdb; pdb.set_trace()

