import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Iterable, Literal, List
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import jax.tree_util as jtu
from diffusion_crf.base import *
from diffusion_crf.util.parallel_scan import parallel_segmented_scan, parallel_scan
from jax._src.util import curry
from diffusion_crf.sde.sde_base import AbstractLinearSDE, AbstractLinearTimeInvariantSDE
from diffusion_crf.matrix import *
from diffusion_crf.gaussian.dist import MixedGaussian, NaturalGaussian, StandardGaussian
import diffusion_crf.util as util

# Names exported by a star-import of this module.
# NOTE(review): only `TimeSeries` is defined in the portion of the file reviewed here;
# the remaining names are presumably defined further down - confirm they all exist.
__all__ = ['TimeSeries',
           'ProbabilisticTimeSeries',
           'DiscretizeInfo',
           'interleave_times',
           'interleave_series']

################################################################################################################

def _make_windowed_batches(obj, window_size: int, T: Optional[int] = None):
  """Turn a single series-like object into a batch of overlapping windows.

  Window `i` gathers the indices `i, i+1, ..., i+window_size-1`, so
  `T - window_size + 1` windows are produced.  All windows are gathered with a
  single indexing call `obj[idx]`, where `idx` is an integer array of shape
  `(T - window_size + 1, window_size)`.

  Args:
    obj: Object that supports indexing with a 2D integer index array.
    window_size: Number of consecutive elements in every window.
    T: Number of elements to window over.  If None, `len(obj)` is used, falling
       back to `obj.batch_size` when `len(obj)` raises.

  Returns:
    Whatever `obj[idx]` returns for the 2D index array described above.
  """
  if T is None:
    try:
      T = len(obj)
    except Exception:
      # Fall back for objects without a usable __len__.  `Exception` (rather than a
      # bare `except:`) lets KeyboardInterrupt/SystemExit propagate.
      T = obj.batch_size

  window_starts = jnp.arange(T - window_size + 1)
  window_offsets = jnp.arange(window_size)
  idx = window_starts[:, None] + window_offsets[None, :]
  return obj[idx]

################################################################################################################
class TimeSeries(AbstractBatchableObject):

  ts: Float[Array, 'N']
  yts: Float[Array, 'N D']
  observation_mask: Bool[Array, 'N D']  # True if a dimension of an observation is present at that time step

  mask_value: Optional[Union[Scalar, None]] = eqx.field(static=True)

  def __init__(self,
               ts: Float[Array, 'N'],
               yts: Float[Array, 'N D'],
               observation_mask: Optional[Union[Bool[Array, 'N D'],
                                                Bool[Array, 'N'],
                                                None]] = None,
               mask_value: Optional[Union[Scalar, None]] = None):
    """Build a time series, filling unobserved entries with a placeholder value.

    Args:
      ts: Observation times.  Must have the shape of `yts` without its last axis.
      yts: Observed values; the last axis is the observation dimension.
      observation_mask: True where an entry of `yts` is observed.  Either the same
        shape as `yts`, or the shape of `yts` without its last axis (one flag per
        time step, also for inputs with leading batch axes), in which case it is
        replicated across the observation dimension.  If None, everything is
        treated as observed.
      mask_value: Value written into the unobserved entries of `yts`.  If None,
        unobserved entries are set to zero.

    Raises:
      ValueError: If the shapes of `ts`, `yts` and `observation_mask` are inconsistent.
    """
    self.ts = ts

    if observation_mask is None:
      # No mask supplied: every entry counts as observed
      observation_mask = jnp.ones_like(yts, dtype=bool)
    elif observation_mask.shape == yts.shape[:-1]:
      # One flag per time step: replicate it across the observation dimension.
      # Comparing against yts.shape[:-1] (rather than only (N,)) also covers inputs
      # that carry leading batch axes.
      observation_mask = jnp.broadcast_to(observation_mask[..., None], yts.shape)

    self.observation_mask = observation_mask

    # Overwrite whatever was stored at the unobserved positions
    if mask_value is None:
      fill = jnp.zeros_like(yts)
    else:
      fill = jnp.full_like(yts, mask_value)
    self.mask_value = mask_value
    self.yts = jnp.where(self.observation_mask, yts, fill)

    # Check that all of the shapes are compatible
    if self.yts.shape != self.observation_mask.shape:
      raise ValueError("yts and observation_mask must have the same shape")
    if self.ts.shape != self.yts.shape[:-1]:
      raise ValueError("ts must have the same shape as yts except for the last dimension")

  def slice_dimensions(self, indices: Sequence[int]):
    """Return a new TimeSeries restricted to the given observation dimensions.

    Args:
      indices: Indices into the last axis of `yts` / `observation_mask` to keep.

    Returns:
      A TimeSeries with the same `ts` and `mask_value`, and with `yts` and
      `observation_mask` indexed along their last axis.
    """
    # Forward mask_value: without it the constructor would refill the unobserved
    # entries with zeros and the new series would forget its placeholder value.
    return TimeSeries(self.ts,
                      self.yts[..., indices],
                      self.observation_mask[..., indices],
                      mask_value=self.mask_value)

  @property
  def points(self):
    """Values of the series with every unobserved entry replaced by NaN (useful for plotting)."""
    observed = self.observation_mask
    values = self.yts
    return jnp.where(observed, values, jnp.nan)

  @property
  def batch_size(self):
    """Batch extent of the series, derived from the rank of `ts`.

    Returns None when `ts` is 1-D, the size of the leading axis (an int) when `ts`
    is 2-D, and the tuple of all axes but the last one otherwise.
    """
    rank = self.ts.ndim
    if rank == 1:
      return None

    leading_shape = self.ts.shape[:-1]
    return leading_shape[0] if rank == 2 else leading_shape

  def __len__(self):
    """Size of the last axis of `ts`, i.e. the number of time points."""
    n_times = self.ts.shape[-1]
    return n_times

  @property
  def observation_dim(self):
    """Size of the last axis of `yts`, i.e. the number of observed dimensions."""
    n_dims = self.yts.shape[-1]
    return n_dims

  def is_fully_uncertain(self):
    """Per time step: True when no dimension is observed, False when at least one is."""
    any_observed = jnp.any(self.observation_mask, axis=-1)
    return jnp.logical_not(any_observed)

  @property
  def mask(self):
    """Per time step: True when at least one dimension is observed, False otherwise."""
    per_dimension = self.observation_mask
    return jnp.any(per_dimension, axis=-1)

  def get_missing_observation_mask(self) -> Bool[Array, 'N']:
    """Same result as `is_fully_uncertain()`: True at time steps with no observed dimension."""
    fully_missing = self.is_fully_uncertain()
    return fully_missing

  def to_probabilistic_series(self):
    """Convert to a ProbabilisticTimeSeries using the 'mixed' parameterization.

    Observed entries get infinite certainty and unobserved entries get zero certainty.

    Returns:
      A ProbabilisticTimeSeries built from `ts`, `yts` and the certainty array.
    """
    # Select the values explicitly: multiplying the boolean mask by inf would yield
    # NaN (0 * inf) at the unobserved entries instead of a certainty of 0.
    certainty = jnp.where(self.observation_mask,
                          jnp.inf,
                          jnp.zeros(self.observation_mask.shape))
    return ProbabilisticTimeSeries(self.ts,
                                   self.yts,
                                   certainty=certainty,
                                   parameterization='mixed')

  def add_to_plot(self, ax: 'plt.Axes', **kwargs):
    """Plot `points` (NaN where unobserved) against `ts` on `ax`.

    Extra keyword arguments are forwarded to `ax.plot`, whose return value is returned.
    """
    times, values = self.ts, self.points
    return ax.plot(times, values, **kwargs)

  def make_windowed_batches(self, window_size: int):
    """Return a batch of overlapping windows of length `window_size` cut from this series.

    Thin wrapper around the module-level `_make_windowed_batches` helper.
    """
    windows = _make_windowed_batches(self, window_size)
    return windows

  def plot_series(self,
                  index: Optional[Union[int, Literal['all']]] = None,
                  axes: Optional[List] = None,
                  fig: Optional['plt.Figure'] = None,
                  show_plot: bool = True,
                  add_title: bool = True,
                  title: Optional[str] = None,
                  line_colors: Optional[Union[str, List[str]]] = 'blue',
                  line_alpha: float = 0.7,
                  line_width: float = 1,
                  marker_colors: Optional[Union[str, List[str]]] = None,
                  marker_size: float = 25,
                  marker_style: str = 'o',
                  figsize: Optional[Tuple[float, float]] = None,
                  fig_width: float = 6,
                  fig_height_factor: float = 3,
                  legend_loc: str = 'upper center',
                  batch_color: str = 'blue',
                  batch_line_width: float = 0.5,
                  min_alpha: float = 0.1,
                  max_alpha: float = 1.0,
                  alpha_scaling: Literal['linear', 'sqrt', 'log'] = 'sqrt'):
    """
    Create a visualization of a TimeSeries object.

    This method generates one subplot per observation dimension, showing:
    - Lines through the observed values
    - Markers at the observed points (single series only)
    - All samples overlaid, without markers, if the TimeSeries is batched

    Args:
      self: The TimeSeries object to visualize
      index: Index of the sequence if self is batched. If 'all', plot all samples.
             If None and self is batched, will plot all samples.
      axes: Optional list of axes to plot on. Only used when `fig` is given as well;
            otherwise a new figure and new axes are created.
      fig: Optional figure to plot on. Only used when `axes` is given as well.
      show_plot: Whether to call plt.show() after creating the plot (new figures only)
      add_title: Whether to add a title to the plot (new figures only)
      title: Optional title string to use instead of default
      line_colors: Color(s) for the connecting lines. If a list is given, its first entry is used.
      line_alpha: Alpha (transparency) for the connecting lines
      line_width: Width of the connecting lines
      marker_colors: Color(s) for the markers. If None, will use line_colors.
      marker_size: Size of the markers
      marker_style: Marker style (e.g., 'o', 's', '^'). If None, no markers are drawn.
      figsize: Directly specify figure size as (width, height) in inches
      fig_width: Width of the figure in inches (used if figsize is None)
      fig_height_factor: Multiplier for height per dimension (used if figsize is None)
      legend_loc: Location for the legend
      batch_color: Color for batched samples when all samples are plotted (default: blue)
      batch_line_width: Line width for batched samples when all samples are plotted (default: 0.5)
      min_alpha: Minimum alpha value for batched samples (default: 0.1)
      max_alpha: Maximum alpha value for batched samples (default: 1.0)
      alpha_scaling: Method to scale alpha with batch size ('linear', 'sqrt', 'log')

    Returns:
      tuple: (fig, axes) - The figure and axes objects used for the plot
    """
    # Only the modules that are actually used are imported.  In particular
    # `matplotlib.cm.get_cmap` must not be imported: it no longer exists in recent
    # matplotlib releases and the import alone would make this method fail.
    import math
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.ticker import FuncFormatter, MaxNLocator

    # Convert marker_colors to match line_colors if not specified
    if marker_colors is None:
      marker_colors = line_colors

    # Only a single (integer) batch axis is treated as "batched"
    is_batched = self.batch_size is not None and isinstance(self.batch_size, int)
    plot_all_samples = is_batched and (index == 'all' or index is None)

    def calculate_alpha(batch_size):
      """Line alpha that shrinks with the batch size, clipped to [min_alpha, max_alpha]."""
      if batch_size <= 1:
        return max_alpha

      if alpha_scaling == 'linear':
        alpha = max_alpha / batch_size
      elif alpha_scaling == 'sqrt':
        alpha = max_alpha / math.sqrt(batch_size)
      elif alpha_scaling == 'log':
        alpha = max_alpha / (1 + math.log(batch_size))
      else:
        alpha = (min_alpha + max_alpha) / 2  # Default fallback

      return max(min_alpha, min(max_alpha, alpha))

    def padded_range(observed):
      """(low, high) y-limits around the finite observed values, with a 10% margin."""
      observed = observed[np.isfinite(observed)]
      if observed.size == 0:
        return (-1, 1)  # Default range if no data
      y_min, y_max = np.min(observed), np.max(observed)
      if y_max > y_min:
        buffer = 0.1 * (y_max - y_min)
      else:
        # Constant data: fall back to an absolute margin when the value is 0 so
        # the limits never collapse to a zero-height interval
        buffer = 0.1 * abs(y_max) if y_max != 0 else 0.1
      return (y_min - buffer, y_max + buffer)

    # Collect the samples to draw
    if plot_all_samples:
      samples = [self[i] for i in range(self.batch_size)]
      effective_alpha = calculate_alpha(len(samples))
    else:
      # Single sample case (or specified index)
      samples = [self[index] if (is_batched and index is not None) else self]

      # A list of colors collapses to its first entry
      if isinstance(line_colors, list):
        line_colors = line_colors[0] if line_colors else 'blue'
      if isinstance(marker_colors, list):
        marker_colors = marker_colors[0] if marker_colors else line_colors

    num_dims = samples[0].observation_dim if samples else self.observation_dim

    # Convert every sample to numpy once instead of once per dimension and pass
    arrays = [(np.array(s.ts), np.array(s.yts), np.array(s.observation_mask))
              for s in samples]

    # A new figure is created unless BOTH fig and axes were supplied
    create_new_figure = fig is None or axes is None

    if create_new_figure:
      n_cols = 1
      n_rows = num_dims

      # Calculate figure size
      if figsize is None:
        figsize = (n_cols*fig_width, fig_height_factor*n_rows)

      fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, sharex=True)

      # Set a publication-quality font (note: this changes the global rcParams)
      plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Palatino', 'DejaVu Serif', 'Times'],
        'mathtext.fontset': 'stix',
      })

      # Handle case where there's only one dimension (axes won't be array)
      if n_rows == 1:
        axes = [axes]

      # Add a title to the figure
      if add_title:
        title_text = title if title is not None else "Time Series Plot"
        fig.suptitle(title_text, fontsize=14, y=0.98)
        plt.subplots_adjust(top=0.95)  # Make room for title
    else:
      # Ensure axes is a list for consistent indexing
      if not isinstance(axes, list) and not isinstance(axes, np.ndarray):
        axes = [axes]

    # First pass: y-axis range per dimension, shared by all plotted samples
    y_ranges = []
    for k in range(num_dims):
      chunks = [values[:, k][mask[:, k]] for _, values, mask in arrays]
      observed = np.concatenate(chunks) if chunks else np.empty(0)
      y_ranges.append(padded_range(observed))

    # Create a tick formatter that has consistent decimal places
    def custom_formatter(x, pos):
      return f"{x:.2f}"

    formatter = FuncFormatter(custom_formatter)

    # Track handles for legend
    legend_handles = []
    legend_labels = []

    # Second pass: plot each dimension
    for k in range(num_dims):
      ax = axes[k]

      for ts, values, mask in arrays:
        observed_indices = np.where(mask[:, k])[0]
        if len(observed_indices) == 0:
          continue

        xs = ts[observed_indices]
        ys = values[observed_indices, k]

        if plot_all_samples:
          # Thin overlaid lines, no markers; a single legend entry for the whole batch
          if k == 0 and not legend_handles:
            line, = ax.plot(xs, ys,
                            color=batch_color, linewidth=batch_line_width, alpha=effective_alpha,
                            label="All Samples")
            legend_handles.append(line)
            legend_labels.append("All Samples")
          else:
            ax.plot(xs, ys,
                    color=batch_color, linewidth=batch_line_width, alpha=effective_alpha)
        else:
          # Single sample case - line through the observed points plus markers
          line, = ax.plot(xs, ys,
                          color=line_colors, linewidth=line_width, alpha=line_alpha)

          if marker_style is not None:
            ax.scatter(xs, ys,
                       color=marker_colors, marker=marker_style, s=marker_size)

          if k == 0:
            legend_handles.append(line)
            legend_labels.append("Observations")

      # Set y-axis range from our calculated ranges
      y_min, y_max = y_ranges[k]
      ax.set_ylim(y_min, y_max)

      # Standardize y-ticks and their format
      ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
      ax.yaxis.set_major_formatter(formatter)

      # Make y-tick labels smaller and align them right
      ax.tick_params(axis='y', labelsize=8)
      ax.yaxis.set_tick_params(pad=1)
      for label in ax.get_yticklabels():
        label.set_horizontalalignment('right')

      # Handle x-tick labels visibility
      if create_new_figure and k < num_dims - 1:
        # Hide x-tick labels but keep ticks visible for non-bottom plots
        plt.setp(ax.get_xticklabels(), visible=False)
      else:
        # Make bottom x-tick labels visible and smaller
        ax.tick_params(axis='x', labelsize=8)
        plt.setp(ax.get_xticklabels(), visible=True)

      # Remove a per-axes legend if one exists (the figure-level legend is used instead).
      # get_legend() returns the existing legend; ax.legend() would build a new one.
      existing_legend = ax.get_legend()
      if existing_legend is not None:
        existing_legend.remove()

      # Add y-axis label for each dimension
      ax.set_ylabel(f"Dim {k}", fontsize=10)
      ax.yaxis.labelpad = 10

    # Add x-axis label only to the bottom plot
    axes[-1].set_xlabel('Time', fontsize=10)

    # Ensure all x-tick marks are visible
    for ax in axes:
      ax.xaxis.set_tick_params(which='both', size=4, width=1, direction='out')

    # Add a legend if we have any legend handles
    if create_new_figure and legend_handles:
      fig.legend(handles=legend_handles, labels=legend_labels,
                loc=legend_loc, bbox_to_anchor=(0.5, 0.94),
                ncol=min(2, len(legend_handles)), fontsize=9,
                frameon=True, borderaxespad=0.)

    # Adjust layout
    if create_new_figure:
      plt.tight_layout(rect=[0, 0, 1, 0.95])

    if show_plot and create_new_figure:
      plt.show()
      if not fig._suptitle:  # Only close if not a custom figure we want to reuse
        plt.close()

    return fig, axes

  @staticmethod
  def plot_multiple_series(series_list: List['TimeSeries'],
                          index: Optional[Union[int, Literal['all']]] = 'all',
                          titles: Optional[List[str]] = None,
                          show_plot: bool = True,
                          common_title: Optional[str] = None,
                          line_colors: Optional[Union[str, List[str]]] = 'blue',
                          line_alpha: float = 0.7,
                          line_width: float = 1,
                          marker_colors: Optional[Union[str, List[str]]] = None,
                          marker_size: float = 25,
                          marker_style: str = 'o',
                          figsize: Optional[Tuple[float, float]] = None,
                          width_per_series: float = 6,
                          height_per_dim: float = 3,
                          batch_color: str = 'blue',
                          batch_line_width: float = 0.5,
                          min_alpha: float = 0.1,
                          max_alpha: float = 1.0,
                          alpha_scaling: Literal['linear', 'sqrt', 'log'] = 'sqrt',
                          use_max_dims: bool = False):
    """
    Create side-by-side plots of multiple TimeSeries objects for comparison.

    This function arranges multiple time series in columns, with each row
    representing a dimension of the data.  Rows share their y-axis and columns
    share their x-axis.

    Args:
      series_list: List of TimeSeries objects to compare
      index: Index of the sequence if the series are batched. If 'all' (or None), plot all
             samples from each batched series overlaid on the same axes.
      titles: Optional list of titles for each series (column)
      show_plot: Whether to call plt.show() after creating the plot
      common_title: Optional overall title for the plot
      line_colors: Color(s) for the connecting lines. Can be a single color or a list of
                   colors, which is cycled over the series.
      line_alpha: Alpha (transparency) for the connecting lines
      line_width: Width of the connecting lines
      marker_colors: Color(s) for the markers. If None, will use line_colors.
      marker_size: Size of the markers
      marker_style: Marker style (e.g., 'o', 's', '^')
      figsize: Directly specify figure size as (width, height) in inches
      width_per_series: Width in inches allocated for each series column
      height_per_dim: Height in inches allocated for each dimension row
      batch_color: Color for batched samples when all samples are plotted (default: blue)
      batch_line_width: Line width for batched samples when all samples are plotted (default: 0.5)
      min_alpha: Minimum alpha value for batched samples (default: 0.1)
      max_alpha: Maximum alpha value for batched samples (default: 1.0)
      alpha_scaling: Method to scale alpha with batch size ('linear', 'sqrt', 'log')
      use_max_dims: Whether to plot up to the maximum number of dimensions across
                    all series (True) or only dimensions present in all series (False)

    Returns:
      tuple: (fig, axes) - The figure and the 2D array of axes (rows: dimensions,
             columns: series) used for the plots

    Raises:
      ValueError: If `series_list` is empty or there is no dimension to plot.
    """
    # Only the modules that are actually used are imported.  In particular
    # `matplotlib.cm.get_cmap` must not be imported: it no longer exists in recent
    # matplotlib releases and the import alone would make this method fail.
    import math
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.ticker import FuncFormatter, MaxNLocator

    if not series_list:
      raise ValueError("No series provided for plotting")

    # Convert marker_colors to match line_colors if not specified
    if marker_colors is None:
      marker_colors = line_colors

    # None is treated like 'all' (same convention as plot_series); otherwise a batched
    # series would be plotted without selecting a sample first.
    plot_all_samples = index == 'all' or index is None

    # Number of columns in the plot (one per series in series_list)
    n_series = len(series_list)

    def calculate_alpha(batch_size):
      """Line alpha that shrinks with the batch size, clipped to [min_alpha, max_alpha]."""
      if batch_size <= 1:
        return max_alpha

      if alpha_scaling == 'linear':
        alpha = max_alpha / batch_size
      elif alpha_scaling == 'sqrt':
        alpha = max_alpha / math.sqrt(batch_size)
      elif alpha_scaling == 'log':
        alpha = max_alpha / (1 + math.log(batch_size))
      else:
        alpha = (min_alpha + max_alpha) / 2  # Default fallback

      return max(min_alpha, min(max_alpha, alpha))

    def pick_color(spec, position):
      """Color for the series in column `position`; lists are cycled, None/[] give 'blue'."""
      if isinstance(spec, list):
        return spec[position % len(spec)] if spec else 'blue'
      return 'blue' if spec is None else spec

    def padded_range(observed):
      """(low, high) y-limits around the finite observed values, with a 10% margin."""
      observed = observed[np.isfinite(observed)]
      if observed.size == 0:
        return (-1, 1)  # Default range if no finite data
      y_min, y_max = np.min(observed), np.max(observed)
      if y_max > y_min:
        buffer = 0.1 * (y_max - y_min)
      else:
        # Constant data: use an absolute margin when the value is exactly 0
        buffer = 0.1 * abs(y_max) if y_max != 0 else 0.1
      return (y_min - buffer, y_max + buffer)

    # Resolve, once per series, which samples are drawn and convert them to numpy.
    # Each entry is (number of dimensions, overlay-all-samples flag, [(ts, yts, mask), ...]).
    prepared = []
    for series in series_list:
      # Only a single (integer) batch axis is treated as "batched"
      is_batched = series.batch_size is not None and isinstance(series.batch_size, int)
      overlay = is_batched and plot_all_samples

      if overlay:
        members = [series[j] for j in range(series.batch_size)]
        series_dim = members[0].observation_dim if members else series.observation_dim
      else:
        members = [series[index] if is_batched else series]
        series_dim = series.observation_dim

      arrays = [(np.array(m.ts), np.array(m.yts), np.array(m.observation_mask))
                for m in members]
      prepared.append((series_dim, overlay, arrays))

    # Use either the minimum or maximum dimensions based on use_max_dims flag
    dims_list = [series_dim for series_dim, _, _ in prepared]
    plot_dims = max(dims_list) if use_max_dims else min(dims_list)

    if plot_dims == 0:
      raise ValueError("All series have 0 dimensions, cannot plot.")

    # Set a publication-quality font (note: this changes the global rcParams)
    plt.rcParams.update({
      'font.family': 'serif',
      'font.serif': ['Times New Roman', 'Palatino', 'DejaVu Serif', 'Times'],
      'mathtext.fontset': 'stix',
    })

    # Create figure with one column per series in series_list
    if figsize is None:
      figsize = (n_series*width_per_series, plot_dims*height_per_dim)

    # squeeze=False always returns a 2D (plot_dims, n_series) array of axes, so the
    # single-row / single-column cases need no special handling
    fig, axes = plt.subplots(plot_dims, n_series,
                          figsize=figsize,
                          sharex='col', sharey='row',
                          squeeze=False)

    # Set column titles if provided
    if titles is not None:
      for i, title in enumerate(titles[:n_series]):
        axes[0, i].set_title(title, fontsize=12)

    # Format tick formatter for consistent decimal places
    def custom_formatter(x, pos):
      return f"{x:.2f}"
    formatter = FuncFormatter(custom_formatter)

    # First pass: y-axis range per dimension (row), over every series that has it
    y_ranges = []
    for k in range(plot_dims):
      chunks = [values[mask[:, k], k]
                for series_dim, _, arrays in prepared if k < series_dim
                for _, values, mask in arrays]
      observed = np.concatenate(chunks) if chunks else np.empty(0)
      y_ranges.append(padded_range(observed))

    # Second pass: plot each series
    legend_handles = []
    legend_labels = []

    for i, (series_dim, overlay, arrays) in enumerate(prepared):
      series_line_color = pick_color(line_colors, i)
      series_marker_color = pick_color(marker_colors, i)
      series_label = f"{titles[i]}" if titles is not None and i < len(titles) else f"Series {i+1}"

      if overlay:
        # Thin lines in the batch color, alpha scaled with the batch size
        line_style = dict(color=batch_color, linewidth=batch_line_width,
                          alpha=calculate_alpha(len(arrays)))
      else:
        line_style = dict(color=series_line_color, linewidth=line_width, alpha=line_alpha)

      # Plot each dimension
      for k in range(plot_dims):
        ax = axes[k, i]

        # Skip plotting if this dimension doesn't exist for this series
        if k >= series_dim:
          # Hide ticks for empty axes
          ax.set_xticks([])
          ax.set_yticks([])
          if i == 0:
            ax.set_ylabel(f"Dim {k}", fontsize=10, alpha=0.5)
          continue

        # One legend entry per series, taken from the first line drawn in the top row
        legend_added = False

        for ts, values, mask in arrays:
          observed_indices = np.where(mask[:, k])[0]
          if len(observed_indices) == 0:
            continue

          xs = ts[observed_indices]
          ys = values[observed_indices, k]

          if k == 0 and not legend_added:
            line, = ax.plot(xs, ys, label=series_label, **line_style)
            legend_handles.append(line)
            legend_labels.append(series_label)
            legend_added = True
          else:
            ax.plot(xs, ys, **line_style)

          # Markers only when a single sample is drawn, not for overlaid batches
          if not overlay:
            ax.scatter(xs, ys,
                       color=series_marker_color, marker=marker_style, s=marker_size)

        # Set y-axis range
        y_min, y_max = y_ranges[k]
        ax.set_ylim(y_min, y_max)

        # Standardize y-ticks
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.yaxis.set_major_formatter(formatter)

        # Style y-tick labels
        ax.tick_params(axis='y', labelsize=8)
        ax.yaxis.set_tick_params(pad=1)

        # Right-align y-tick labels
        for label in ax.get_yticklabels():
          label.set_horizontalalignment('right')

        # Handle x-tick labels visibility
        if k < plot_dims - 1:
          plt.setp(ax.get_xticklabels(), visible=False)
        else:
          ax.tick_params(axis='x', labelsize=8)
          plt.setp(ax.get_xticklabels(), visible=True)

        # Only leftmost column needs y labels
        if i == 0:
          ax.set_ylabel(f"Dim {k}", fontsize=10)
          ax.yaxis.labelpad = 10

        # Bottom row gets x-axis label
        if k == plot_dims - 1:
          ax.set_xlabel('Time', fontsize=10)

        # Ensure tick marks are visible
        ax.xaxis.set_tick_params(which='both', size=4, width=1, direction='out')

        # Remove a per-axes legend if one exists (the figure-level legend is used instead).
        # get_legend() returns the existing legend; ax.legend() would build a new one.
        existing_legend = ax.get_legend()
        if existing_legend is not None:
          existing_legend.remove()

    # Add a common legend if we have any legend handles
    if legend_handles:
      # Calculate number of columns based on the number of legend entries
      ncol = min(5, len(legend_handles))

      fig.legend(handles=legend_handles, labels=legend_labels,
                loc='upper center', bbox_to_anchor=(0.5, 0.99),
                ncol=ncol, fontsize=10, frameon=True, borderaxespad=0.)

    # Add overall title if provided
    if common_title:
      fig.suptitle(common_title, fontsize=18, y=0.995)

    # Calculate spacing that scales with dimensions
    w_base, h_base = 0.2, 0.3
    w_scale = max(0.25, 1.0 / n_series)
    h_scale = max(0.5, 1.0 / plot_dims)

    # Layout adjustments
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.subplots_adjust(
      wspace=w_base * w_scale,
      hspace=h_base * h_scale,
      top=0.95
    )

    if show_plot:
      plt.show()

    return fig, axes


################################################################################################################

def _process_potentials(_,
                        ts: Float[Array, 'N'],
                        xts: Float[Array, 'N D'],
                        standard_deviation: Optional[Float[Array, 'N D']] = None,
                        certainty: Union[Float[Array, 'N D'], None] = None,
                        parameterization: Optional[Literal['natural', 'mixed', 'standard']] = 'natural'):
  """Build one Gaussian node potential per time step of a single (unbatched) series.

  **Arguments**:
    - `_`: Unused.  `ProbabilisticTimeSeries.__init__` passes the instance being built here.
    - `ts`: The times at which the potential functions are evaluated
    - `xts`: The value (mean) of the potential at each time
    - `standard_deviation`: Per-dimension standard deviation of each potential.  Entries with
                            magnitude below 1e-10 are treated as fully certain and infinite
                            entries as fully uncertain.
    - `certainty`: Per-dimension certainty of each potential.  `inf` means fully certain and a
                   magnitude below 1e-10 means fully uncertain.
    - `parameterization`: 'natural', 'mixed' or 'standard'.  Ignored (forced to 'mixed') when
                          neither `standard_deviation` nor `certainty` is given.

  At most one of `standard_deviation` and `certainty` may be given.  If neither is given, every
  potential is treated as fully certain.

  **Returns**:
    The potentials produced by vmapping the per-step constructor over time: a `NaturalGaussian`,
    `MixedGaussian` or `StandardGaussian` depending on `parameterization`.

  **Raises**:
    - `ValueError` if the shapes of `ts`, `xts` and `certainty` are inconsistent, if both
      `standard_deviation` and `certainty` are provided, or if `parameterization` is unknown.

  NOTE(review): `certainty` is computed as `1/standard_deviation` and then used directly as the
  matrix handed to `NaturalGaussian`/`MixedGaussian`.  If that matrix is meant to be a precision
  (inverse covariance), the reciprocal should arguably be squared -- confirm the intended
  convention before relying on `standard_deviation`.
  """
  if certainty is not None and xts.shape != certainty.shape:
    raise ValueError("xts and certainty must have the same shape")
  if ts.shape != xts.shape[:-1]:
    raise ValueError("ts must have the same shape as xts except for the last dimension")
  if standard_deviation is not None and certainty is not None:
    raise ValueError("Both standard_deviation and certainty cannot be provided")

  # Work out the per-dimension certainty and which time steps are degenerate
  if standard_deviation is None and certainty is None:
    # Nothing was specified, so every observation is treated as exact
    certainty = jnp.ones_like(xts)*jnp.inf
    is_fully_certain = jnp.ones_like(ts, dtype=bool)
    is_fully_uncertain = jnp.zeros_like(ts, dtype=bool)
    parameterization = 'mixed' # Only used mixed parameterization if we're doing a bridge

  elif standard_deviation is None:
    # A time step is degenerate only if every one of its dimensions is
    is_fully_certain = jnp.all(jnp.isinf(certainty), axis=-1)
    is_fully_uncertain = jnp.all(jnp.abs(certainty) < 1e-10, axis=-1)

  else:
    sd_is_zero = jnp.abs(standard_deviation) < 1e-10
    sd_is_inf = jnp.isinf(standard_deviation)

    # Swap degenerate entries for a harmless denominator before dividing.  The forward values
    # are unchanged (those entries are overwritten below), but this keeps 1/0 out of the
    # computation so that differentiating through the `where` does not produce NaNs.
    safe_sd = jnp.where(sd_is_zero | sd_is_inf, 1.0, standard_deviation)
    certainty = 1/safe_sd
    certainty = jnp.where(sd_is_inf, 0.0, certainty)
    certainty = jnp.where(sd_is_zero, jnp.inf, certainty)

    is_fully_certain = jnp.all(sd_is_zero, axis=-1)
    is_fully_uncertain = jnp.all(sd_is_inf, axis=-1)

  assert is_fully_certain.shape == ts.shape
  assert is_fully_uncertain.shape == ts.shape

  def override_tags(matrices, mask, tags):
    """Replace the tags of the matrices at the time steps selected by `mask`."""
    def per_step(matrix, flag):
      tagged = eqx.tree_at(lambda m: m.tags, matrix, tags)
      return util.where(flag, tagged, matrix)
    return jax.vmap(per_step)(matrices, mask)

  # Turn the certainty into a Matrix type, one matrix per time step
  J = jax.vmap(util.to_matrix)(certainty)

  # Certainties of inf correspond to fully certain potentials
  J = override_tags(J, is_fully_certain, TAGS.inf_tags)

  # Certainties of 0 correspond to fully uncertain potentials
  J = override_tags(J, is_fully_uncertain, TAGS.zero_tags)

  def process_potential(x: Float[Array, 'D'],
                        J: AbstractMatrix):
    """Wrap a single (value, certainty matrix) pair in the requested Gaussian type."""
    if parameterization == 'natural':
      return NaturalGaussian(J, J@x)
    elif parameterization == 'mixed':
      return MixedGaussian(x, J)
    elif parameterization == 'standard':
      return StandardGaussian(x, J.get_inverse())
    raise ValueError(f"Unknown parameterization: {parameterization}")

  # Construct the node potentials
  return jax.vmap(process_potential)(xts, J)

################################################################################################################

class ProbabilisticTimeSeries(AbstractBatchableObject):
  """A sequence of times paired with one Gaussian node potential per time."""

  ts: Float[Array, 'N']
  node_potentials: AbstractPotential # A Gaussian potential function for the latent variable at each time

  def __init__(self,
               ts: Float[Array, 'N'],
               xts: Float[Array, 'N D'],
               standard_deviation: Optional[Float[Array, 'N D']] = None,
               certainty: Union[Float[Array, 'N D'], None] = None,
               parameterization: Optional[Literal['natural', 'mixed', 'standard']] = 'natural'):
    """Build the series from values and an uncertainty specification.
    Batched inputs are accepted: the potentials are built through `auto_vmap`.

    **Arguments**:
      - `ts`: The times at which the potential functions are evaluated
      - `xts`: The value of the potential at each time.  If an `AbstractPotential` is passed
               instead, it is stored as-is and every remaining argument is ignored.
      - `standard_deviation`: The standard deviation of the potential functions.
      - `certainty`: The certainty of the potential functions.  See `_process_potentials`
                     for how 0 and inf are interpreted.
      - `parameterization`: The parameterization of the potential functions.  Defaults to 'natural'.

    **Raises**:
      - `ValueError` if `xts` and `certainty` differ in shape, or if `ts` does not match
        `xts` on all but the last axis.
    """
    if isinstance(xts, AbstractPotential):
      # Shortcut used by `from_potentials`: the potentials already exist
      self.ts, self.node_potentials = ts, xts
      return

    if certainty is not None and xts.shape != certainty.shape:
      raise ValueError("xts and certainty must have the same shape")
    if ts.shape != xts.shape[:-1]:
      raise ValueError("ts must have the same shape as xts except for the last dimension")

    self.ts = ts
    build_potentials = auto_vmap(_process_potentials)
    self.node_potentials = build_potentials(self,
                                            ts,
                                            xts,
                                            standard_deviation,
                                            certainty,
                                            parameterization=parameterization)

  @classmethod
  def from_potentials(cls, ts: Float[Array, 'N'], node_potentials: AbstractPotential):
    """Construct a series directly from existing node potentials."""
    # NOTE(review): this builds a ProbabilisticTimeSeries explicitly rather than `cls`, so a
    # subclass calling this gets a base-class instance -- confirm that this is intended.
    return ProbabilisticTimeSeries(ts, node_potentials)

  @property
  def batch_size(self):
    """None for an unbatched series, an int for one batch axis, otherwise the leading shape."""
    ndim = self.ts.ndim
    if ndim == 1:
      return None
    leading_shape = self.ts.shape[:-1]
    return leading_shape[0] if ndim == 2 else leading_shape

  def __len__(self):
    """Size of the last (time) axis of `ts`."""
    return self.ts.shape[-1]

  @property
  def is_fully_uncertain(self):
    """The `is_zero` flag of the tags on the potentials' `J` matrix."""
    potentials = self.node_potentials
    return potentials.J.tags.is_zero

  @property
  def is_fully_certain(self):
    """The `is_inf` flag of the tags on the potentials' `J` matrix."""
    potentials = self.node_potentials
    return potentials.J.tags.is_inf

  def to_mixed(self):
    """Return a copy whose potentials were converted with their `to_mixed()`."""
    converted = self.node_potentials.to_mixed()
    return eqx.tree_at(lambda series: series.node_potentials, self, converted)

  def to_nat(self):
    """Return a copy whose potentials were converted with their `to_nat()`."""
    converted = self.node_potentials.to_nat()
    return eqx.tree_at(lambda series: series.node_potentials, self, converted)

  def to_std(self):
    """Return a copy whose potentials were converted with their `to_std()`."""
    converted = self.node_potentials.to_std()
    return eqx.tree_at(lambda series: series.node_potentials, self, converted)

  def make_windowed_batches(self, window_size: int):
    """Turn this series into a batch of overlapping windows of length `window_size`."""
    return _make_windowed_batches(self, window_size=window_size)

################################################################################################################

class DiscretizeInfo(eqx.Module):
  """Bookkeeping for merging a set of "new" times into a set of "base" times.

  `ts` holds the merged, sorted times.  `new_indices` and `base_indices` are the integer
  positions inside `ts` of the new and base times respectively.  Either one is `None`
  when the corresponding set of times was missing or empty.
  """
  new_indices: Optional[Array]   # Integer positions of the new times in `ts`, or None
  base_indices: Optional[Array]  # Integer positions of the base times in `ts`, or None
  ts: Float[Array, 'T_new + T_old']

  def __init__(self, new_times: Union[Float[Array, 'T2'], None], base_times: Float[Array, 'T1']):
    """Compute where each new and base time lands in the merged time grid.

    When a new time equals a base time, the new time is placed first.
    The index arithmetic assumes both inputs are already sorted.
    """
    if new_times is None or new_times.size == 0:
      self.new_indices = None
      self.base_indices = jnp.arange(len(base_times))
      self.ts = base_times
    elif base_times is None or base_times.size == 0:
      self.new_indices = jnp.arange(len(new_times))
      self.base_indices = None
      self.ts = new_times
    else:
      # A base time moves right by the number of new times that are <= it
      def merged_index_of_base(i):
        return i + (new_times <= base_times[i]).sum()

      # A new time moves right by the number of base times that are strictly < it
      def merged_index_of_new(i):
        return i + (base_times < new_times[i]).sum()

      self.new_indices = jax.vmap(merged_index_of_new)(jnp.arange(len(new_times)))
      self.base_indices = jax.vmap(merged_index_of_base)(jnp.arange(len(base_times)))
      self.ts = jnp.sort(jnp.concatenate([base_times, new_times]))

  @staticmethod
  def _select(x, indices):
    """Index `x` along its first axis, selecting nothing when `indices` is None."""
    # `x[None]` would insert a new axis rather than return an empty selection
    return x[:0] if indices is None else x[indices]

  def _membership_mask(self, select_new: bool):
    """Boolean mask over `ts` marking the new times (or the base times)."""
    new_shape = (0,) if self.new_indices is None else self.new_indices.shape
    base_shape = (0,) if self.base_indices is None else self.base_indices.shape
    new_mask = jnp.full(new_shape, select_new, dtype=bool)
    base_mask = jnp.full(base_shape, not select_new, dtype=bool)
    return self.interleave(new_mask, base_mask)

  @property
  def new_indices_mask(self):
    """Boolean mask over `ts` that is True at the positions of the new times."""
    return self._membership_mask(True)

  @property
  def base_indices_mask(self):
    """Boolean mask over `ts` that is True at the positions of the base times."""
    return self._membership_mask(False)

  @property
  def new_times(self):
    """The entries of `ts` that came from the new times (empty if there were none)."""
    return self._select(self.ts, self.new_indices)

  @property
  def base_times(self):
    """The entries of `ts` that came from the base times (empty if there were none)."""
    return self._select(self.ts, self.base_indices)

  def transpose(self):
    """Return a copy with the roles of the new and base times swapped.

    The index arrays are exchanged directly and `ts` is kept unchanged.  Going back through
    the constructor is avoided on purpose: it expects times rather than indices, and it
    would break ties between equal times in the opposite direction.
    """
    return eqx.tree_at(lambda info: (info.new_indices, info.base_indices),
                       self,
                       (self.base_indices, self.new_indices),
                       is_leaf=lambda x: x is None) # lets tree_at address fields that are None

  def interleave(self, new_xts: Float[Array, 'T_new D'], base_xts: Float[Array, 'T_old D']) -> Float[Array, 'T_new + T_old D']:
    """Interleave the new values with the base values along the first axis.

    Both inputs may be pytrees.  Only their array leaves are interleaved, and the non-array
    part of the result is taken from `base_xts`.  If one side has no times, the other input
    is returned unchanged.
    """
    if self.new_indices is None:
      return base_xts
    if self.base_indices is None:
      return new_xts

    new_params, _ = eqx.partition(new_xts, eqx.is_array)
    base_params, base_static = eqx.partition(base_xts, eqx.is_array)

    # Allocate a buffer with one row per merged time for every array leaf
    T = self.ts.shape[0]
    def empty_buffer(x):
      return jnp.zeros((T, *x.shape[1:]), dtype=x.dtype)
    combined_params = jtu.tree_map(empty_buffer, base_params)

    # Write the base rows first, then the new rows
    combined_params = util.fill_array(combined_params, self.base_indices, base_params)
    combined_params = util.fill_array(combined_params, self.new_indices, new_params)

    return eqx.combine(combined_params, base_static)

  def filter_base_times(self, xts):
    """Select, from every leaf of `xts`, the rows at the positions of the base times."""
    return jtu.tree_map(lambda x: self._select(x, self.base_indices), xts)

  def filter_new_times(self, xts):
    """Select, from every leaf of `xts`, the rows at the positions of the new times."""
    return jtu.tree_map(lambda x: self._select(x, self.new_indices), xts)

################################################################################################################

def interleave_times(new_times: Union[Float[Array, 'T2'], None], base_times: Float[Array, 'T1']) -> DiscretizeInfo:
  """Build the `DiscretizeInfo` that describes merging `new_times` into `base_times`."""
  info = DiscretizeInfo(new_times=new_times, base_times=base_times)
  return info

################################################################################################################

def interleave_series(new_series: Union[TimeSeries, ProbabilisticTimeSeries],
                      base_series: Union[TimeSeries, ProbabilisticTimeSeries]) -> Union[TimeSeries, ProbabilisticTimeSeries]:
  """Interleave the new series with the base series, ordered by time.

  - If both inputs are plain `TimeSeries`, they are interleaved as they are.
  - Otherwise any `TimeSeries` is first converted with `to_probabilistic_series()`.  Both
    series are then brought to a common parameterization: mixed if either one is a
    `MixedGaussian`, natural otherwise.

  The merge itself is done by `DiscretizeInfo.interleave`.
  """
  both_plain = isinstance(new_series, TimeSeries) and isinstance(base_series, TimeSeries)

  if not both_plain:
    if isinstance(new_series, TimeSeries):
      new_series = new_series.to_probabilistic_series()
    if isinstance(base_series, TimeSeries):
      base_series = base_series.to_probabilistic_series()

    # Aligning the parameterization only makes sense for probabilistic series.  Previously
    # this also ran for two plain TimeSeries, which presumably have neither `node_potentials`
    # nor `to_mixed`/`to_nat` (none are among the TimeSeries fields visible in this file).
    new_series_is_mixed = isinstance(new_series.node_potentials, MixedGaussian)
    base_series_is_mixed = isinstance(base_series.node_potentials, MixedGaussian)

    if new_series_is_mixed or base_series_is_mixed:
      new_series = new_series.to_mixed()
      base_series = base_series.to_mixed()
    else:
      new_series = new_series.to_nat()
      base_series = base_series.to_nat()

  interleave_info = interleave_times(new_series.ts, base_series.ts)
  return interleave_info.interleave(new_series, base_series)

################################################################################################################

if __name__ == "__main__":
  # Interactive debugging scratchpad: loads a local data dump, builds a few series and a
  # conditioned SDE, and drops into pdb twice so the intermediate objects can be inspected.
  from debug import *
  import jax
  import jax.numpy as jnp
  import jax.random as random
  import matplotlib.pyplot as plt
  import tqdm
  from diffusion_crf.sde.simple_sdes import BrownianMotion
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior

  import pickle
  # SECURITY: unpickling executes arbitrary code, so only load a data_dump.pkl you created
  # yourself.  The context manager makes sure the file handle is closed.
  with open('data_dump.pkl', 'rb') as f:
    data = pickle.load(f)

  ts, yts, observation_mask = data['ts'], data['yts'], data['observation_mask']
  series = TimeSeries(ts, yts, observation_mask)
  y_dim = yts.shape[-1]
  x_dim = 2*y_dim
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                    x_dim=x_dim,
                                                    sigma=0.01)
  prob_series = encoder(series)
  prob_series_batches = _make_windowed_batches(prob_series, window_size=100)
  import pdb; pdb.set_trace()

  # Per-dimension certainties: a constant where observed, zero where missing
  inverse_covs = jnp.ones_like(yts)*798.6803363271825
  inverse_covs = jnp.where(observation_mask, inverse_covs, jnp.zeros_like(inverse_covs))

  means = jnp.where(observation_mask, yts, jnp.zeros_like(yts))

  x_dim = yts.shape[-1]
  key = random.PRNGKey(0)

  # Mark the first time step as observed and give it non-zero certainty
  inverse_covs = inverse_covs.at[0].add(1.0)
  observation_mask = observation_mask.at[0].set(True)

  sde = BrownianMotion(0.1, x_dim)
  # sde = DenseLTISDE(key, x_dim)
  # sde = OrnsteinUhlenbeck(0.1, 0.1, x_dim)
  # sde = VariancePreserving(beta_min=0.1, beta_max=20, dim=x_dim)

  # Exercises the default (fully certain) construction path; the result is overwritten below
  pts = ProbabilisticTimeSeries(ts,
                                means)

  pts = ProbabilisticTimeSeries(ts,
                                means,
                                certainty=inverse_covs)

  pts2 = ProbabilisticTimeSeries.from_potentials(pts.ts, pts.node_potentials)

  bridge = ConditionedLinearSDE(sde, pts2, parallel=False)

  crf = bridge.discretize()

  save_times = jnp.linspace(ts[0], ts[-1], 1000)

  def get_samples(key):
    return bridge.sample(key, save_times)
  samples = jax.vmap(get_samples)(random.split(key, 16))


  x0 = yts[0]
  flow = bridge.get_flow(ts[0], x0)


  from diffusion_crf.sde.ode_sde_solve import ODESolverParams, ode_solve
  ode_solver_params = ODESolverParams(solver='dopri5',
                                    stepsize_controller='pid',
                                    n_steps=20000)

  x0 = yts[0]
  def sample_ode(x0):
    return ode_solve(bridge.get_flow, x0, save_times=save_times, params=ode_solver_params)
  # Start each ODE solve from the first saved state of the corresponding SDE sample
  samples_ode = jax.vmap(sample_ode)(samples.yts[:,0])

  import pdb; pdb.set_trace()
