import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from diffusion_crf.base import AbstractBatchableObject, auto_vmap, AbstractPotential
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
import diffusion_crf
from diffusion_crf import AbstractLinearSDE, ConditionedLinearSDE
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF
from Models.models.base import AbstractModel, AbstractEncoderDecoderModel
from diffusion_crf import TimeSeries, HarmonicOscillator, PaddingLatentVariableEncoderWithPrior

################################################################################################################

class AbstractBwdPredictorModel(AbstractModel, abc.ABC):
  """Interface for (autoregressive) models that predict backward messages.

  A concrete subclass must provide the three components declared with
  ``eqx.AbstractVar`` and set the static sizing fields. The two prediction
  hooks below return an ``AbstractPotential``. They raise
  ``NotImplementedError`` unless a subclass overrides them.
  """

  # Components every concrete subclass has to declare.
  linear_sde: eqx.AbstractVar[AbstractLinearSDE]
  encoder: eqx.AbstractVar[AbstractEncoder]
  transformer: eqx.AbstractVar[AbstractEncoderDecoderModel]

  # Static pytree fields (eqx.field(static=True)).
  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  # The number of observed time points that we condition on.
  cond_len: int = eqx.field(static=True)
  # How many extra points before the latent prediction start to generate from.
  latent_cond_len: int = eqx.field(static=True)

  def predict_next_backward_messages(self,
                                     ts: Float[Array, 'T'],
                                     xts: TimeSeries) -> AbstractPotential:
    """Predict backward messages at ``ts`` given the already-generated ``xts``.

    Args:
      ts: Time points at which the message distribution is wanted.
      xts: The series that has been generated so far.

    Raises:
      NotImplementedError: unless a subclass overrides this hook.
    """
    raise NotImplementedError()

  def predict_current_backward_messages(self,
                                        ts: Float[Array, 'T'],
                                        yts: TimeSeries,
                                        xts: TimeSeries,
                                        *,
                                        context: Optional[Float[Array, 'S C']] = None) -> AbstractPotential:
    """Predict backward messages at ``ts`` from observations and generated series.

    Args:
      ts: Time points at which the message distribution is wanted.
      yts: Observed series.
      xts: Generated series.
      context: Optional extra conditioning array (keyword-only).

    Raises:
      NotImplementedError: unless a subclass overrides this hook.
    """
    raise NotImplementedError()

class DummyModel(AbstractBwdPredictorModel):
  """Non-learned backward-message predictor built on a fixed linear SDE.

  The prior is a ``HarmonicOscillator`` SDE. Backward messages are computed by
  conditioning that SDE on the already-generated series and discretizing it,
  rather than by a network, so ``transformer`` is always ``None``.

  ``loss_fn`` and ``sample`` raise ``NotImplementedError``.
  ``predict_current_backward_messages`` is not overridden, so it falls back to
  the base class and raises ``NotImplementedError``.
  """

  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  # This model uses no network; __init__ always stores None here.
  transformer: Optional[AbstractEncoderDecoderModel]

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  def __init__(self,
               process_noise: float,
               obs_noise: float,
               obs_seq_len: int,
               cond_len: int):
    """Build the fixed prior SDE and observation encoder.

    Args:
      process_noise: Process noise passed to the ``HarmonicOscillator`` SDE.
      obs_noise: ``sigma`` of the observation encoder.
      obs_seq_len: Stored as the static ``obs_seq_len`` field.
      cond_len: Number of observed time points to condition on.
    """
    # NOTE(review): observation_dim=1 with y_dim=1 / x_dim=2 presumably means a
    # scalar observed signal over a 2-D latent state. Confirm against
    # HarmonicOscillator.
    self.linear_sde = HarmonicOscillator(freq=1.0, coeff=0.0, process_noise=process_noise, observation_dim=1)
    self.encoder = PaddingLatentVariableEncoderWithPrior(y_dim=1, x_dim=2, sigma=obs_noise)
    self.transformer = None
    self.interpolation_freq = 0
    self.obs_seq_len = obs_seq_len
    self.cond_len = cond_len
    self.latent_cond_len = 0 # Dummy model doesn't have a latent conditioning length

  def loss_fn(
    self,
    yts: TimeSeries, # This is the observed series
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Scalar:
    """Not implemented for this model.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    debug: Optional[bool] = False,
  ) -> TimeSeries:
    """Not implemented for this model.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError

  def predict_next_backward_messages(
    self,
    ts: Float[Array, 'T'], # The time points that we want to predict the message distribution at
    xts: TimeSeries, # The series that we have already generated
    *,
    identity_sigma: float = 0.01,
  ) -> AbstractPotential:
    """Compute backward messages of the prior SDE conditioned on ``xts``.

    Args:
      ts: Time points at which the message distribution is wanted.
      xts: The series generated so far. It is turned into potentials by a
        2-D -> 2-D ``PaddingLatentVariableEncoderWithPrior``.
      identity_sigma: ``sigma`` of that conditioning encoder. It defaults to
        0.01, the value that was previously hard-coded. A small value
        presumably makes the conditioning close to exact; confirm against the
        encoder.

    Returns:
      The CRF backward messages after ``info.filter_new_times``. This
      presumably keeps only the entries at times in ``ts`` that are new
      relative to ``xts``; TODO confirm.
    """
    # Condition the prior SDE on xts, treating xts as (near-)direct latent observations.
    identity_encoder = PaddingLatentVariableEncoderWithPrior(y_dim=2, x_dim=2, sigma=identity_sigma)
    prob_series = identity_encoder(xts)
    sde = ConditionedLinearSDE(self.linear_sde, prob_series)

    # Discretize the conditioned SDE at ts and read off the chain's backward messages.
    result = sde.discretize(ts)
    info = result.info
    crf = result.crf

    bwd = crf.get_backward_messages()
    return info.filter_new_times(bwd)
