import jax
import jax.numpy as jnp
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type, Iterable
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.base import *
from diffusion_crf.crf import *
from plum import dispatch
import diffusion_crf.util as util
from jax._src.util import curry
from diffusion_crf.timeseries import TimeSeries, ProbabilisticTimeSeries, DiscretizeInfo, interleave_times

__all__ = ['DiscretizeResult',
           'AbstractContinuousCRF']

################################################################################################################

class DiscretizeResult(eqx.Module):
  """Pairs a discretized `CRF` with the `DiscretizeInfo` it was built from.

  Returned by `AbstractContinuousCRF.discretize` when its `ts` argument is not None.
  """
  # The CRF constructed over the times in `info.ts`
  crf: CRF
  # The interleaving info (computed by `interleave_times` or supplied by the caller)
  info: DiscretizeInfo

class AbstractContinuousCRF(AbstractBatchableObject, abc.ABC):
  """Abstract continuous-time CRF defined on top of a probabilistic time series.

  Subclasses implement `get_base_transition_distribution`.  `discretize` uses it
  to build a `CRF` over a finite set of times; `log_prob` and `sample` both go
  through `discretize`.
  """

  probabilistic_time_series: eqx.AbstractVar[ProbabilisticTimeSeries]
  parallel: bool = eqx.field(static=True)

  def __init__(self, probabilistic_time_series: ProbabilisticTimeSeries, parallel: bool = False):
    """Store the time series and the `parallel` flag (forwarded to `CRF` in `discretize`)."""
    assert isinstance(probabilistic_time_series, ProbabilisticTimeSeries)
    self.probabilistic_time_series = probabilistic_time_series
    self.parallel = parallel

  @property
  def ts(self) -> Float[Array, 'T']:
    """The times of the wrapped probabilistic time series."""
    return self.probabilistic_time_series.ts

  @property
  def node_potentials(self) -> AbstractPotential:
    """The node potentials of the wrapped probabilistic time series."""
    return self.probabilistic_time_series.node_potentials

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size of this object, derived from the node potentials' batch size.

    An iterable node batch size yields its second-to-last entry, an integer one
    yields None, and anything else raises `ValueError`.
    """
    potentials_batch = self.node_potentials.batch_size
    if isinstance(potentials_batch, Iterable):
      # NOTE(review): presumably the last entry is the time axis and [-2] the batch
      # axis in front of it; leading entries (if any) are not reported -- confirm.
      return potentials_batch[-2]
    if isinstance(potentials_batch, int):
      return None
    raise ValueError(f"Invalid batch size: {potentials_batch}")

  @abc.abstractmethod
  def get_base_transition_distribution(self, s: Float[Array, 'D'], t: Float[Array, 'D']) -> AbstractTransition:
    """Return the transition from time `s` to time `t`.

    `discretize` vmaps this over consecutive pairs of the interleaved times.
    """
    pass

  def log_prob(self, series: TimeSeries) -> Scalar:
    """Evaluate the log probability of `series` under this distribution.

    Discretizes at `series.ts`, calls `marginalize` on the resulting CRF with
    the indices of the new times, and evaluates `series.yts` under the result.
    """
    discretized = self.discretize(series.ts)
    marginal = discretized.crf.marginalize(discretized.info.new_indices)
    return marginal.log_prob(series.yts)

  def sample(self, key: PRNGKeyArray, ts: Float[Array, 'T']) -> TimeSeries:
    """Draw a sample at the times `ts`.

    Samples from the CRF discretized at `ts` and keeps what
    `info.filter_new_times` selects from the draw.
    """
    discretized = self.discretize(ts)
    draws = discretized.crf.sample(key)
    return TimeSeries(ts, discretized.info.filter_new_times(draws))

  def discretize(self, ts: Union[Float[Array, 'T'], None] = None, info: Optional[DiscretizeInfo] = None) -> Union[DiscretizeResult, CRF]:
    """Build a discrete `CRF` over this object's times interleaved with `ts`.

    Args:
      ts: Optional 1-d array of new times.
      info: Optional precomputed interleaving; when None it is computed with
        `interleave_times(ts, self.ts)`.

    Returns:
      The bare `CRF` when `ts` is None (even if `info` was passed), otherwise a
      `DiscretizeResult` holding the CRF and the interleaving info.
    """
    if ts is not None:
      assert ts.ndim == 1

    # Merge the new times with our own unless the caller already did so
    if info is None:
      info = interleave_times(ts, self.ts)
    merged_ts = info.ts
    n_times = merged_ts.shape[-1]
    assert n_times > 1, 'There must be at least 2 times (including the times of this continuous CRF and the new times) when discretizing!'

    # One "total uncertainty" potential per merged time, stacked along a new leading axis
    template = self.node_potentials[0]
    blank = template.total_uncertainty_like(template)
    blank_potentials = eqx.filter_vmap(lambda _: blank)(jnp.arange(n_times))

    # Write our own node potentials into the slots that belong to the base times
    base_slots = info.base_indices
    node_potentials = jtu.tree_map(
      lambda leaf, prior: leaf.at[base_slots].set(prior),
      blank_potentials,
      self.node_potentials,
    )

    # One transition per pair of consecutive merged times.  The closure keeps
    # `self` out of the vmapped arguments.
    def transition_between(start, end):
      return self.get_base_transition_distribution(start, end)

    transitions = eqx.filter_vmap(transition_between)(merged_ts[:-1], merged_ts[1:])

    crf = CRF(node_potentials, transitions, parallel=self.parallel)
    return crf if ts is None else DiscretizeResult(crf, info)

################################################################################################################

if __name__ == '__main__':
  from debug import *
  import matplotlib.pyplot as plt

  # Enable 64-bit mode in JAX for this script
  jax.config.update('jax_enable_x64', True)

  # Smoke-test interleave_times: interleave 5 new times with 10 base times on
  # [0, 1], then print new_indices, base_indices and the merged ts in that order.
  existing_grid = jnp.linspace(0, 1, 10)
  query_grid = jnp.linspace(0, 1, 5)
  merged = interleave_times(query_grid, existing_grid)
  for shown in (merged.new_indices, merged.base_indices, merged.ts):
    print(shown)
