import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import einops
import equinox as eqx
import abc
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo
from diffusion_crf.ssm.simple_encoder import AbstractEncoder
from diffusion_crf.ssm.simple_decoder import AbstractDecoder
from diffusion_crf import AbstractLinearSDE, ConditionedLinearSDE, interleave_times
import diffusion_crf.util as util
from diffusion_crf.crf import Messages, CRF

"""
This module defines the base abstractions for neural diffusion CRF models. It provides:

1. AbstractEncoderDecoderModel - Interface for encoder-decoder architectures that process time series data
2. CRFState - Container for precomputed Conditional Random Field states and messages
3. AbstractModel - Base class for all probabilistic time series models with common utilities for:
   - Time series discretization and interpolation
   - CRF state management
   - Sampling and loss computation interfaces

These abstractions form the foundation for implementing various diffusion-based and
autoregressive models for time series data, ensuring consistent interfaces for training,
sampling, and evaluation across different model implementations.
"""

################################################################################################################

class AbstractHiddenState(AbstractBatchableObject, abc.ABC):
  """Base class for the hidden state threaded through autoregressive generation.

  `AbstractEncoderDecoderModel.get_initial_state` produces one of these and
  `AbstractEncoderDecoderModel.single_step` consumes and returns an updated one.
  Concrete models subclass it to hold whatever they need to condition the next step.
  """

class AbstractEncoderDecoderModel(AbstractBatchableObject, abc.ABC):
  """Interface for the encoder-decoder networks used by the models in this module.

  Subclasses must implement `create_context` and `__call__`.  Models that support
  step-by-step autoregressive generation additionally override `get_initial_state`
  and `single_step` (the defaults raise NotImplementedError).
  """

  @abc.abstractmethod
  def create_context(self, series_enc: TimeSeries) -> Float[Array, 'S C']:
    """Create a context representation for the decoder network.

    Arguments:
      series_enc: The series to encode.

    Returns:
      context: The context for the decoder network (annotated as shape 'S C').
    """
    pass

  @abc.abstractmethod
  def __call__(self,
               condition_series: TimeSeries,
               latent_series: TimeSeries,
               context: Optional[Float[Array, 'S C']] = None) -> Float[Array, 'T C']:
    """Apply the encoder-decoder model.

    `condition_series` must be passed even when `context` is provided, because the
    decoder network takes as input `latent_series` prepended with part of the end
    of `condition_series`.

    Arguments:
      condition_series: The series to condition on (the encoder-side input).
      latent_series: The series to decode.
      context: Optional precomputed context for the decoder network, presumably
        the output of `create_context(condition_series)`.
        NOTE(review): an earlier docstring said the context has the same length as
        the decoded series, but it is annotated 'S C' while the output is 'T C' —
        confirm against the implementations.

    Returns:
      The decoder output (annotated as shape 'T C').
    """
    pass

  def get_initial_state(self) -> AbstractHiddenState:
    """Initialize the hidden state for sequential processing.

    Returns:
      An AbstractHiddenState object containing the initial internal state
      of the model for autoregressive generation.

    Raises:
      NotImplementedError: If the model does not support step-by-step generation.
    """
    raise NotImplementedError(f'{type(self).__name__} does not implement get_initial_state')

  def single_step(
    self,
    t: float,
    xt: Float[Array, 'D'],
    state: AbstractHiddenState
  ) -> Tuple[Float[Array, 'D'], AbstractHiddenState]:
    """Process a single time step for autoregressive generation.

    Arguments:
      t: The time of the current step.
      xt: Input tensor for the current time step.
      state: Current hidden state of the model.

    Returns:
      A tuple containing:
        - Output tensor
        - Updated hidden state for the next time step.

    Raises:
      NotImplementedError: If the model does not support step-by-step generation.
    """
    raise NotImplementedError(f'{type(self).__name__} does not implement single_step')

################################################################################################################

class CRFState(AbstractBatchableObject):
  """Precomputed state for the CRF p(x | Y) induced by an observed series.

  Bundles the linear SDE conditioned on the encoded observations, its discretization
  at a given set of times, and the forward/backward messages of that discrete CRF so
  they can be reused without being recomputed.

  Attributes:
    cond_sde: The linear SDE conditioned on the encoder's potentials for the observations.
    messages: Forward and backward messages of `crf`.
    crf: `cond_sde` discretized at the times described by `discretization_info`.
    discretization_info: The discretization used to build `crf`.
  """
  cond_sde: ConditionedLinearSDE
  messages: Messages
  crf: CRF
  discretization_info: DiscretizeInfo

  @property
  def batch_size(self) -> Union[None, int, Tuple[int]]:
    return self.crf.batch_size

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               yts: TimeSeries,
               discretization_info: DiscretizeInfo,
               parameterization: Literal['natural', 'mixed'] = 'natural'):
    """Build the CRF for `yts` and precompute its messages.

    Arguments:
      linear_sde: The prior linear SDE.
      encoder: Maps the observed series to a probabilistic series of potentials.
      yts: The observed series.  Callers in this module pass a TimeSeries
        (see `AbstractModel.make_crf_state`), not a raw array.
      discretization_info: Times at which to discretize the conditioned SDE.
      parameterization: Parameterization requested from the encoder.
    """
    prob_series = encoder(yts, parameterization=parameterization)

    # Construct the CRF p(x | Y).  Use the parallel implementation only when the
    # default device is a GPU.
    use_parallel = jax.devices()[0].platform == 'gpu'
    self.cond_sde = ConditionedLinearSDE(linear_sde, prob_series, parallel=use_parallel)

    # Discretize the CRF at a specified set of times.  Use AbstractModel.get_discretization_info
    # to do this without worrying about duplicating times.
    self.crf = self.cond_sde.discretize(info=discretization_info)
    self.discretization_info = discretization_info

    # Precompute the forward and backward messages
    self.messages = Messages.from_messages(None, self.crf, need_fwd=True, need_bwd=True)

################################################################################################################

class AbstractModel(AbstractBatchableObject, abc.ABC):
  """Base class for probabilistic time series models built on a linear SDE prior.

  Subclasses provide `linear_sde`, `encoder` and `transformer` and implement
  `loss_fn` and `sample`.  This class provides the bookkeeping that relates the
  observation sequence (length `obs_seq_len`) to the latent sequence, which inserts
  `interpolation_freq` extra time points between every pair of consecutive
  observation times (length `latent_seq_len`).
  """

  linear_sde: eqx.AbstractVar[AbstractLinearSDE]
  encoder: eqx.AbstractVar[AbstractEncoder]
  transformer: eqx.AbstractVar[AbstractEncoderDecoderModel]

  interpolation_freq: int = eqx.field(static=True) # Number of latent points inserted between consecutive observation times
  obs_seq_len: int = eqx.field(static=True) # Length of the observation sequence
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  @property
  def generation_len(self) -> int:
    """The number of time points to generate (in the latent space!).  This is the number
    of time points that we will sample from the model: everything from
    `latent_generation_start_index` to the end of the latent sequence, i.e. the latent
    prediction span plus the `latent_cond_len` points that precede it.

    This is also one more than the number of transitions that we will make because we count 1
    for the initial point that we generate.
    """
    return self.latent_seq_len - self.latent_generation_start_index

  @property
  def latent_generation_start_index(self) -> int:
    """The index in the latent sequence where generation starts.  This is
    `latent_cond_len` points before `latent_prediction_start_index`, so that a few
    latent points preceding the first predicted observation are generated as well."""
    return self.latent_prediction_start_index - self.latent_cond_len

  @property
  def latent_prediction_start_index(self) -> int:
    """The index in the latent sequence that corresponds to the cond_len index in the observation sequence.
    This is where the prediction starts in the latent space.  Observation i sits at latent
    index (1 + interpolation_freq)*i because each earlier gap contributes
    `interpolation_freq` interpolated points."""
    return (1 + self.interpolation_freq) * self.cond_len

  @property
  def pred_len(self) -> int:
    """The number of time points to predict (in the observation space!)"""
    return self.obs_seq_len - self.cond_len

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    return self.linear_sde.batch_size

  @property
  def latent_seq_len(self) -> int:
    """Length of the latent sequence: the `obs_seq_len` observation times plus
    `interpolation_freq` points in each of the `obs_seq_len - 1` gaps."""
    return (1 + self.interpolation_freq)*self.obs_seq_len - self.interpolation_freq

  def make_fully_observed(self, yts: TimeSeries) -> TimeSeries:
    """Mark the series as fully observed by rebuilding it without an observation mask.
    The stored values are kept as-is (masked entries are NOT zeroed; see
    `make_fully_observed_series` for that).  This is purely so that observations are
    not ignored when we need to pass the times of the unobserved parts of the series
    to the model."""
    return TimeSeries(yts.ts, yts.yts)

  def make_crf_state(self, yts: TimeSeries, parameterization: Literal['natural', 'mixed'] = 'natural') -> CRFState:
    """Make a CRF state for a given observed series.  Assume that yts is the entire observed (but maybe masked)
    sequence so that we know at what times we get potentials at.  Unobserved values are
    zeroed by `make_fully_observed_series` before encoding."""
    assert len(yts) == self.obs_seq_len, f'yts must have length {self.obs_seq_len} but has length {len(yts)}.'
    yts = self.make_fully_observed_series(yts)
    return CRFState(self.linear_sde, self.encoder, yts, self.get_discretization_info(yts.ts), parameterization=parameterization)

  def get_discretization_info(self, ts: Float[Array, 'T']) -> DiscretizeInfo:
    """Get the discretization info for the data.  We do a uniform interpolation of
    the input times: `interpolation_freq` evenly spaced new times are placed inside
    every gap between consecutive observation times.  The new times and the original
    times `ts` are passed separately to DiscretizeInfo."""
    assert ts.ndim == 1
    # Gap preceding each observation time.  The first observation has no predecessor,
    # so reuse the first gap; the points generated from it are discarded below.
    dts = jnp.diff(ts)
    dts = jnp.concatenate([dts[:1], dts])
    # Row i holds -(f - k)/(f + 1)*dts[i] for k = 0..f (f = interpolation_freq): evenly
    # spaced points in the gap before ts[i], ending with offset 0 (ts[i] itself).
    offsets = dts[:,None]*jnp.arange(-self.interpolation_freq, 1)/(self.interpolation_freq + 1)
    new_times = ts[:,None] + offsets
    # Drop the last column (the observation times themselves) and flatten in time order
    out = new_times[...,:-1].ravel()

    # We don't want to include points outside of the observed times: the first
    # `interpolation_freq` points lie before ts[0]
    out = out[self.interpolation_freq:]

    info = DiscretizeInfo(out, ts)
    return info

  def downsample_seq_to_original_freq(self, xts: TimeSeries) -> TimeSeries:
    """Downsample a latent sequence to the original frequency by keeping every
    (interpolation_freq + 1)-th point.  This is useful for comparing the performance
    of the model at different interpolation frequencies."""
    assert len(xts) == self.latent_seq_len
    xts_downsampled = xts[::self.interpolation_freq + 1]
    return xts_downsampled

  def basic_interpolation(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries
  ) -> TimeSeries:
    """A simple interpolation that upsamples yts by inserting `interpolation_freq`
    points between consecutive observation times.  It conditions the linear SDE on
    the potentials that we get from the input series and then samples from the
    conditioned SDE discretized with `get_discretization_info`.

    The length of `yts` is intentionally not checked against `obs_seq_len`.

    Arguments:
      key: The key to use for the random number generator.
      yts: The series to upsample.  Its last dimension must be `encoder.y_dim`.

    Returns:
      xts_upsampled: The upsampled series in the latent space.
    """
    assert yts.yts.shape[-1] == self.encoder.y_dim, f'yts must have the same dimension as the observed data'
    prob_series = self.encoder(yts)
    cond_sde = ConditionedLinearSDE(self.linear_sde, prob_series)
    discretization_info = self.get_discretization_info(prob_series.ts)
    crf = cond_sde.discretize(info=discretization_info)
    sample = crf.sample(key)
    xts_upsampled = TimeSeries(discretization_info.ts, sample)
    return xts_upsampled

  @abc.abstractmethod
  def loss_fn(
    self,
    yts: TimeSeries, # This is the observed series
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Scalar:
    """Compute the loss function from the series.

    Arguments:
      yts: The series to compute the loss function from.  This must NOT be upsampled!
      key: The key to use for the random number generator.
      debug: Whether to enter pdb after computing the loss function for debugging.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    debug: Optional[bool] = False,
  ) -> TimeSeries:
    """Returns a latent series that is sampled from the model.

    Arguments:
      key: The key to use for the random number generator.
      yts: The observed series to sample from.
      debug: Whether to enter pdb after sampling for debugging.
    """
    raise NotImplementedError

  def mask_observation_space_sequence(self, series: TimeSeries) -> TimeSeries:
    """Return the series with an observation mask that marks only the first `cond_len`
    time points as observed.  This is useful for checking that the samples don't
    depend on unobserved values."""
    assert len(series) == self.obs_seq_len, f'series must have length {self.obs_seq_len} but has length {len(series)}.'
    mask = jnp.zeros_like(series.yts).astype(bool)
    mask = mask.at[:self.cond_len, :].set(True)
    return TimeSeries(ts=series.ts, yts=series.yts, observation_mask=mask)

  def make_fully_observed_series(self, series: TimeSeries) -> TimeSeries:
    """Force the series to be fully observed, replacing unobserved values with 0."""
    points = jnp.where(series.observation_mask, series.yts, 0.0)
    return TimeSeries(ts=series.ts, yts=points)

  def get_random_discretization_info(self, key: PRNGKeyArray, ts: Float[Array, 'T']) -> DiscretizeInfo:
    """During training, generate random times in between the observed times.

    `interpolation_freq` times are drawn uniformly inside every gap
    [ts[i], ts[i+1]) and sorted within the gap.  This gives the same number of new
    times, (len(ts) - 1)*interpolation_freq, and the same placement relative to `ts`
    as `get_discretization_info`.  No new time falls outside [ts[0], ts[-1]),
    regardless of the spacing of `ts`.

    Arguments:
      key: The key to use for the random number generator.
      ts: The observation times (1D, assumed increasing).

    Returns:
      The DiscretizeInfo produced by `interleave_times` for the new and observed times.
    """
    assert ts.ndim == 1
    gaps = jnp.diff(ts)

    # Fraction of the way through each gap for every new point.  Sorting within a
    # gap keeps the new times increasing, because the gaps themselves are ordered.
    fractions = random.uniform(key, (gaps.shape[0], self.interpolation_freq))
    fractions = jnp.sort(fractions, axis=-1)

    # Scale by each gap's own width so every point stays inside its interval
    new_times = (ts[:-1, None] + fractions*gaps[:, None]).ravel()

    info = interleave_times(new_times=new_times, base_times=ts)
    return info

class EmptyModel(AbstractModel):
  """A concrete AbstractModel without an encoder-decoder network.

  `transformer` is always None and `latent_cond_len` is fixed to 0.  Only the
  utilities inherited from AbstractModel (discretization, CRF construction, basic
  interpolation, masking) are usable; `loss_fn` and `sample` raise
  NotImplementedError.
  """

  linear_sde: AbstractLinearSDE
  encoder: AbstractEncoder
  transformer: AbstractEncoderDecoderModel # Set to None by __init__

  interpolation_freq: int = eqx.field(static=True)
  obs_seq_len: int = eqx.field(static=True)
  cond_len: int = eqx.field(static=True) # The number of observed time points that we condition on
  latent_cond_len: int = eqx.field(static=True) # How many extra points before the latent prediction start to generate from

  def __init__(self,
               linear_sde: AbstractLinearSDE,
               encoder: AbstractEncoder,
               interpolation_freq: int,
               obs_seq_len: int,
               cond_len: int):
    # Static sequence-layout configuration
    self.interpolation_freq = interpolation_freq
    self.obs_seq_len = obs_seq_len
    self.cond_len = cond_len
    self.latent_cond_len = 0 # This placeholder model has no latent conditioning window

    # Non-static components; there is no encoder-decoder network
    self.linear_sde = linear_sde
    self.encoder = encoder
    self.transformer = None

  def loss_fn(
    self,
    yts: TimeSeries, # This is the observed series
    key: PRNGKeyArray,
    debug: Optional[bool] = False
  ) -> Scalar:
    """Not available: EmptyModel has no trainable objective."""
    raise NotImplementedError

  def sample(
    self,
    key: PRNGKeyArray,
    yts: TimeSeries, # This is the observed series
    debug: Optional[bool] = False,
  ) -> TimeSeries:
    """Not available: EmptyModel cannot generate samples."""
    raise NotImplementedError
