from debug import *
import os
import argparse
import numpy as np
from Utils.io_utils import load_yaml_config
from diffusion_crf import *
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
import jax.random as random
import jax.tree_util as jtu
import optax
import jax.numpy as jnp
import optax
import pickle
import equinox as eqx
import json
import tqdm
from jax._src.util import curry
import tqdm
from main import get_dataset, load_jax_model
from Models.experiment_identifier import ExperimentIdentifier, find_all_experiments
from Models.empirical_metrics import wasserstein2_distance, compute_univariate_metrics
from Utils.discriminative_metric_jax import discriminative_score_metrics
from Utils.metric import get_context_fid_score
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import datetime
from typing import Optional, Dict, List, Union, Tuple, Set

class ExperimentResultData(AbstractBatchableObject):
  """
  Container for experiment result data: the evaluated data series, latent sequences,
  and model samples, together with helpers to score and plot them.

  The class stores the results from model evaluation and provides methods for
  computing metrics (CRPS / NLL), (de)serializing to NumPy dictionaries, and
  visualizing the results.

  Attributes:
    experiment_identifier: Identifies the experiment; used to build its config
      (dataset seq_length / pred_length) and to obtain model names for plot titles.
    data_series: TimeSeries representing the original data being evaluated (shape: T, D)
    latent_seq: Batch of latent sequences sampled from p(x_{1:N} | y_{1:N}) (shape: B, T, D)
    model_samples: Batch of model samples from q(x_{1:N} | y_{1:k}) (shape: B, T, D)
  """

  experiment_identifier: ExperimentIdentifier = eqx.field(static=True)
  data_series: TimeSeries # This represents a single time series (y_{1:N}) that we are evaluating.  shape is (T, D)
  latent_seq: TimeSeries # This is a batch of latent sequences (x_{1:N} ~ p(x_{1:N} | y_{1:N})) associated with data_series.  shape is (B, T, D)
  model_samples: TimeSeries # This is a batch of model samples (x_{1:N} ~ q(x_{1:N} | y_{1:k})) associated with data_series.  shape is (B, T, D)

  @property
  def batch_size(self):
    """
    Get the batch size of the data series.

    Returns:
      int: The batch size of the data_series (delegates to `data_series.batch_size`;
        `create_plot` treats None as "not batched").
    """
    return self.data_series.batch_size

  @auto_vmap
  def get_crps_and_nll(self,
                       *,
                       compare_against_observations: bool = False,
                       evaluate_after_observations_times: bool = False):
    """
    Compute the Continuous Ranked Probability Score (CRPS) and negative log-likelihood
    of the model samples, per time step and per data dimension.

    For every (time step, dimension) pair, `compute_univariate_metrics` receives all
    B model-sample values at that point together with the single ground-truth value.

    Args:
      compare_against_observations: If True, the ground truth is the observed data
        (`data_series.yts`).  Otherwise it is the first latent sequence
        (`latent_seq[0].yts`).
      evaluate_after_observations_times: If True, keep only the last `pred_length`
        time steps (read from the experiment's dataset config) of both the ground
        truth and the model samples, i.e. score only the forecast horizon.

    Returns:
      tuple: (crps, nll)
        - crps: Array of CRPS values, one per (time step, dimension)
        - nll: Array of negative log-likelihood values, one per (time step, dimension)
          (presumably shape (T, D), or (pred_length, D) when truncated, for an
          unbatched item — TODO confirm against compute_univariate_metrics)

    Example:
      # Compute metrics for model evaluation
      crps, nll = result_data.get_crps_and_nll()

      # Average across dimensions
      mean_crps = jnp.mean(crps)
      mean_nll = jnp.mean(nll)

      print(f"Average CRPS: {mean_crps}, Average Negative Log-Likelihood: {mean_nll}")
    """

    # Retrieve the sequence that we'll use as the ground truth
    if compare_against_observations:
      true_seq = self.data_series.yts
    else:
      true_seq = self.latent_seq[0].yts

    # Retrieve the sequence that we'll use as the predicted sequence
    # We might need to truncate any latent dimensions that we don't want to compare against,
    # for example if we use Langevin dynamics to generate the latent sequence.  Ideally
    # we would pass the latent sequence through a decoder to go back to the original space,
    # but we don't have any setting where this is necessary.
    T, D = true_seq.shape
    pred_seq = self.model_samples.yts
    pred_seq = pred_seq[...,:D]

    # When only scoring the forecast horizon, keep the last pred_len time steps of
    # both the ground truth and the samples.
    if evaluate_after_observations_times:
      config = self.experiment_identifier.create_config()
      dataset_config = config['dataset']
      pred_len = dataset_config['pred_length']
      true_seq = true_seq[...,-pred_len:,:]
      pred_seq = pred_seq[...,-pred_len:,:]

    assert true_seq.ndim == 2
    assert pred_seq.ndim == 3

    # pred_seq is (B, T, D) and true_seq is (T, D).  The inner vmap maps over the
    # dimension axis of a (B, D) / (D,) slice, the outer one over the time axis.
    eval_fn = jax.vmap(compute_univariate_metrics, in_axes=(1, 0)) # Vmap over the dimension axis
    eval_fn = jax.vmap(eval_fn, in_axes=(1, 0)) # Vmap over the time axis

    # Compute the CRPS and log likelihood of the test sequences under the model.
    crps, log_likelihood, _ = eval_fn(pred_seq, true_seq)

    nll = -log_likelihood
    return crps, nll

  def to_numpy_dict(self):
    """
    Convert the ExperimentResultData to a dictionary of NumPy arrays.

    This is used for serializing the data to disk.  The `ts` and `yts` of each of the
    three time series are converted to NumPy arrays.  `experiment_identifier` is not
    included and must be supplied again to `from_numpy_dict`.

    Returns:
      dict: Keys data_ts, data_yts, latent_seq_ts, latent_seq_yts,
        model_samples_ts, model_samples_yts mapped to NumPy arrays.

    Example:
      # Convert to numpy dict for saving
      numpy_dict = result_data.to_numpy_dict()

      # Save to disk
      np.savez_compressed('experiment_results.npz', **numpy_dict)
    """
    data_ts = np.array(self.data_series.ts)
    data_yts = np.array(self.data_series.yts)

    latent_seq_ts = np.array(self.latent_seq.ts)
    latent_seq_yts = np.array(self.latent_seq.yts)

    model_samples_ts = np.array(self.model_samples.ts)
    model_samples_yts = np.array(self.model_samples.yts)

    return dict(
      data_ts=data_ts,
      data_yts=data_yts,
      latent_seq_ts=latent_seq_ts,
      latent_seq_yts=latent_seq_yts,
      model_samples_ts=model_samples_ts,
      model_samples_yts=model_samples_yts
    )

  @staticmethod
  def from_numpy_dict(experiment_identifier: ExperimentIdentifier, numpy_dict: dict):
    """
    Create an ExperimentResultData instance from a dictionary of NumPy arrays.

    This is the inverse of to_numpy_dict() and is used for deserializing saved data.

    Args:
      experiment_identifier: Identifier of the experiment the data belongs to (it is
        not stored by to_numpy_dict()).
      numpy_dict: Mapping with the keys produced by to_numpy_dict() (a loaded .npz
        file works as well, since only item access is used).

    Returns:
      ExperimentResultData: Reconstructed instance with JAX arrays.

    Example:
      # Load previously saved results
      loaded_dict = np.load('experiment_results.npz')

      # Convert back to ExperimentResultData
      result_data = ExperimentResultData.from_numpy_dict(experiment_identifier, loaded_dict)

      # Now we can use methods on the data
      crps, nll = result_data.get_crps_and_nll()
    """
    as_jnp = jnp.array
    return ExperimentResultData(
      experiment_identifier=experiment_identifier,
      data_series=TimeSeries(as_jnp(numpy_dict['data_ts']), as_jnp(numpy_dict['data_yts'])),
      latent_seq=TimeSeries(as_jnp(numpy_dict['latent_seq_ts']), as_jnp(numpy_dict['latent_seq_yts'])),
      model_samples=TimeSeries(as_jnp(numpy_dict['model_samples_ts']), as_jnp(numpy_dict['model_samples_yts']))
    )

  def create_plot(self, index: Optional[int] = None,
                axes: Optional[List] = None,
                fig: Optional[plt.Figure] = None,
                show_plot: bool = True,
                add_title: bool = True,
                add_legend: bool = True):
    """
    Create a visualization of the ground truth, predictions, and latent series.

    This method generates one subplot per data dimension, showing:
    - Blue lines: Model predictions (samples)
    - Green lines: Latent sequences
    - Red markers: The observations the model was conditioned on, i.e. the first
      `seq_length - pred_length` time steps (from the experiment's dataset config)

    Args:
      index: Index of the sequence in the batch to visualize (required when the
        data is batched, ignored otherwise)
      axes: Optional list/array of axes to plot on.  A new figure is created unless
        both `axes` and `fig` are given.  When axes are supplied and there are fewer
        axes than data dimensions, only the first len(axes) dimensions are plotted.
      fig: Optional figure to plot on (see `axes`)
      show_plot: Whether to call plt.show() after creating the plot (only applies
        when a new figure was created)
      add_title: Whether to add a title to the plot (only applies to new figures)
      add_legend: Whether to add a legend to the plot (only applies to new figures)

    Returns:
      tuple: (fig, axes) - The figure and axes objects used for the plot

    Raises:
      ValueError: If the data is batched and `index` is None.

    Example:
      # Visualize the first sequence in the batch
      result_data.create_plot(index=0)

      # Loop through several examples for comparison
      for i in range(5):
          result_data.create_plot(index=i)
          plt.savefig(f"sample_{i}.png")

      # Use with custom axes
      fig, axes = plt.subplots(n_rows, n_cols)
      result_data.create_plot(index=0, axes=axes, fig=fig, show_plot=False)
      plt.title("Custom title")
      plt.show()
    """
    from matplotlib.ticker import FuncFormatter, MaxNLocator

    if self.batch_size is not None:
      if index is None:
        raise ValueError("index must be provided if batch_size is not None")
      ground_truth = self.data_series[index]
      predictions = self.model_samples[index]
      latent_series = self.latent_seq[index]
    else:
      ground_truth = self.data_series
      predictions = self.model_samples
      latent_series = self.latent_seq

    # Get the number of points that we are conditioning on
    config = self.experiment_identifier.create_config()
    dataset_config = config['dataset']
    cond_len = dataset_config['seq_length'] - dataset_config['pred_length']

    num_dims = predictions.yts.shape[-1]
    ts = ground_truth.ts

    def series_to_dataframe(series, dim, prefix="Sample"):
      """One column per batch element for dimension `dim`, indexed by time."""
      # Uses the time grid of the first element for every column (assumes all
      # elements share one time grid).
      df = pd.DataFrame({"time": series.ts[0]})
      for i in range(series.batch_size):
        df[f"{prefix} {i}"] = np.array(series.yts[i, :, dim])
      return df

    def padded_range(arrays):
      """(min, max) over the finite values of `arrays`, padded by 10%; (-1, 1) if none."""
      if not arrays:
        return (-1, 1)
      values = np.concatenate([np.asarray(a).ravel() for a in arrays])
      values = values[np.isfinite(values)]
      if values.size == 0:
        return (-1, 1)
      y_min = np.min(values)
      y_max = np.max(values)
      if y_max > y_min:
        buffer = 0.1 * (y_max - y_min)
      else:
        # min == max: pad relative to the magnitude, or by a fixed amount at zero
        buffer = 0.1 * abs(y_max) if y_max != 0 else 0.1
      lo = y_min - buffer
      hi = y_max + buffer
      # Extreme inputs can overflow the padded range; fall back to a default range.
      if not np.isfinite(lo) or not np.isfinite(hi):
        return (-1, 1)
      return (lo, hi)

    # Check if we need to create new axes or use provided ones
    create_new_figure = fig is None or axes is None

    # Legend position in figure coordinates.  Bound up front so it exists on every
    # path (a new figure with add_title=False previously left it undefined).
    legend_y_pos = 0.94

    if create_new_figure:
      n_cols = 1
      n_rows = num_dims
      size = 6
      # NOTE(review): the height uses n_cols (always 1), so new figures are a fixed
      # 18 inches tall regardless of num_dims — confirm whether n_rows was intended.
      fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*size, 3*n_cols*size), sharex=True)

      # Set a publication-quality font
      plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Palatino', 'DejaVu Serif', 'Times'],
        'mathtext.fontset': 'stix',
      })

      # Handle case where there's only one dimension (axes won't be array)
      if n_rows == 1:
        axes = [axes]

      # ADJUSTABLE PARAMETER: Controls the top margin and related spacing
      # Increase this value (closer to 1.0) to reduce whitespace
      # Decrease this value (closer to 0.8) to add more whitespace
      top_margin = 0.95

      if add_title:
        model_name = self.experiment_identifier.better_model_name

        # Adjust figure to make room for title - controlled by top_margin
        plt.subplots_adjust(top=top_margin)

        # Title sits just above the legend (legend_y_pos)
        title_y_pos = 0.98
        fig.suptitle(f"{model_name}", fontsize=14, y=title_y_pos)
    else:
      # Ensure axes is indexable for consistent access
      if not isinstance(axes, (list, np.ndarray)):
        axes = [axes]

      # Only plot as many dimensions as there are axes to draw on (e.g. when a
      # caller provides a shorter column of axes than this result has dimensions).
      num_dims = min(num_dims, len(axes))

    # First pass: determine y-axis ranges for all plots
    y_ranges = []
    for k in range(num_dims):
      arrays = []
      if predictions is not None:
        arrays.append(predictions.yts[..., k])
      if latent_series is not None:
        arrays.append(latent_series.yts[..., k])
      if ground_truth is not None and k < ground_truth.yts.shape[-1]:
        arrays.append(ground_truth.yts[..., k])
      y_ranges.append(padded_range(arrays))

    # Consistent 2-decimal tick labels for positive and negative numbers
    formatter = FuncFormatter(lambda x, pos: f"{x:.2f}")

    for k in range(num_dims):
      ax = axes[k]

      # Plot predictions and latent series
      if predictions is not None:
        series_to_dataframe(predictions, k).set_index('time').plot(ax=ax, alpha=0.2, color='blue')

      if latent_series is not None:
        series_to_dataframe(latent_series, k).set_index('time').plot(ax=ax, alpha=0.2, color='green')

      # Plot the conditioning observations: red x's with hollow red squares, joined by a thin line
      if ground_truth is not None and k < ground_truth.yts.shape[-1]:
        cond_ts = ts[:cond_len]
        cond_ys = ground_truth.yts[:cond_len, k]
        ax.scatter(cond_ts, cond_ys, color='red', marker='x')
        ax.scatter(cond_ts, cond_ys, color='red', marker='s', facecolors='none', s=64)
        ax.plot(cond_ts, cond_ys, color='red', linewidth=0.5)

      # Set y-axis range from our calculated ranges
      y_min, y_max = y_ranges[k]
      ax.set_ylim(y_min, y_max)

      # Standardize y-ticks: at most 5 ticks, consistently formatted
      ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
      ax.yaxis.set_major_formatter(formatter)

      # Make y-tick labels smaller and align them right
      ax.tick_params(axis='y', labelsize=8)
      ax.yaxis.set_tick_params(pad=1)

      # Right-align y-tick labels for better alignment with different length numbers
      for label in ax.get_yticklabels():
        label.set_horizontalalignment('right')

      # Handle x-tick labels visibility
      if create_new_figure and k < num_dims - 1:
        # Hide x-tick labels but keep ticks visible for non-bottom plots
        plt.setp(ax.get_xticklabels(), visible=False)
      else:
        # Make bottom x-tick labels visible and smaller
        ax.tick_params(axis='x', labelsize=8)
        plt.setp(ax.get_xticklabels(), visible=True)

      # Drop the per-axes legend that pandas' DataFrame.plot adds
      ax.legend().remove()

    # Ensure all x-tick marks are visible
    for ax in axes:
      ax.xaxis.set_tick_params(which='both', size=4, width=1, direction='out')

    # Add y-axis labels for each dimension
    for k, ax in enumerate(axes):
      ax.set_ylabel(f"Dim {k}", fontsize=10)
      # Add extra space for y-label
      ax.yaxis.labelpad = 10

    # Add x-axis label only to the bottom plot (time)
    axes[-1].set_xlabel('Time', fontsize=10)

    # Add a legend to explain the colors
    if add_legend and create_new_figure:
      legend_elements = [
        plt.Line2D([0], [0], color='blue', alpha=0.5, lw=2, label='Model Samples'),
        plt.Line2D([0], [0], color='green', alpha=0.5, lw=2, label='Latent Sequence'),
        plt.Line2D([0], [0], color='red', marker='x', lw=1, label='Observations')
      ]
      # Position legend outside and above all plots
      fig.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, legend_y_pos),
                ncol=3, fontsize=9, frameon=True, borderaxespad=0.)

    # Adjust layout to account for right-aligned y-tick labels without affecting the title area
    if create_new_figure:
      plt.tight_layout(rect=[0, 0, 1, 0.95])

    if show_plot and create_new_figure:
      plt.show()
      if not fig._suptitle:  # Only close if not a custom figure we want to reuse
        plt.close()

    return fig, axes


def create_plots(results: List[ExperimentResultData],
                 index: int,
                 show_plot: bool = True,
                 title_suffixes: Optional[List[str]] = None,
                 sup_title_suffix: str = "",
                 use_max_dims: bool = True):
  """
  Create side-by-side plots of multiple experiment results for comparison.

  Results are arranged in columns (one per model) and rows (one per data dimension).
  Each column is drawn by that result's `create_plot` method on the column's axes.
  The figure is titled with the first result's dataset name and gets one shared legend.

  Args:
    results: List of ExperimentResultData objects to compare
    index: Index of the sequence in the batch to visualize for all models
    show_plot: Whether to call plt.show() after creating the plot
    title_suffixes: Optional list of strings to append to model names in the column
      titles; must have the same length as `results`
    sup_title_suffix: Optional string to append to the overall plot title
    use_max_dims: Whether to use as many rows as the largest dimension count across
      results (True) or as the smallest (False).  Rows for dimensions a result
      does not have are left empty with their ticks hidden.

  Returns:
    tuple: (fig, axes) - The figure and the 2-D (num_dims, n_models) array of axes

  Raises:
    ValueError: If `results` is empty, or `title_suffixes` has the wrong length.

  Example:
    # Compare three different models on the same dataset
    create_plots([result1, result2, result3], index=0)

    # Save the comparison plot
    fig, _ = create_plots([result1, result2], index=2)
    fig.savefig("model_comparison.png", dpi=300)
  """
  if not results:
    raise ValueError("No results provided for plotting")

  if title_suffixes is not None and len(title_suffixes) != len(results):
    raise ValueError(f"Number of title suffixes must match number of results, got {len(title_suffixes)} and {len(results)}")

  # Number of rows: the maximum or minimum dimension count across results
  dims_list = [result.model_samples.yts.shape[-1] for result in results]
  num_dims = max(dims_list) if use_max_dims else min(dims_list)

  n_models = len(results)

  # Set a publication-quality font
  plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Palatino', 'DejaVu Serif', 'Times'],
    'mathtext.fontset': 'stix',
  })

  # One column per model, one row per dimension; the size scales with the grid.
  # squeeze=False guarantees a 2-D (num_dims, n_models) axes array even for a
  # single row and/or column.
  width_per_model = 6
  height_per_dim = 3
  fig, axes = plt.subplots(num_dims, n_models,
                           figsize=(n_models*width_per_model, num_dims*height_per_dim),
                           sharex='col', sharey='row', squeeze=False)

  # Create column labels (model names)
  for i, result in enumerate(results):
    title_str = result.experiment_identifier.better_model_name
    if title_suffixes is not None:
      title_str += f" {title_suffixes[i]}"
    axes[0, i].set_title(title_str, fontsize=12)

  # Plot each model in its column
  for i, result in enumerate(results):
    model_axes = axes[:, i]

    # Number of dimensions this particular result has
    result_dims = result.model_samples.yts.shape[-1]

    if result_dims > 0:
      # Slicing clamps to num_dims, so at most num_dims axes are handed over
      result.create_plot(index=index,
                         axes=model_axes[:result_dims].tolist(),
                         fig=fig,
                         show_plot=False,
                         add_title=False,
                         add_legend=False)

    # Rows for dimensions this result does not have stay empty
    for k in range(result_dims, num_dims):
      ax = model_axes[k]
      # Hide ticks via tick_params, which only affects this Axes.  set_xticks([]) /
      # set_yticks([]) would replace the locator that is shared across the column
      # (sharex='col') / row (sharey='row') and wipe the ticks of populated subplots.
      ax.tick_params(axis='both', which='both', bottom=False, left=False,
                     labelbottom=False, labelleft=False)
      # Faded dimension label, only on the leftmost column
      if i == 0:
        ax.set_ylabel(f"Dim {k}", fontsize=10, alpha=0.5)

    # Only the leftmost column needs y labels
    if i > 0:
      for ax in model_axes:
        ax.set_ylabel("")

  # Add a legend to explain the colors (only once for the entire figure)
  legend_elements = [
    plt.Line2D([0], [0], color='blue', alpha=0.5, lw=2, label='Model Samples'),
    plt.Line2D([0], [0], color='green', alpha=0.5, lw=2, label='Latent Sequence'),
    plt.Line2D([0], [0], color='red', marker='x', lw=1, label='Observations')
  ]

  # Position legend at the top center of the figure
  fig.legend(handles=legend_elements, loc='upper center',
            bbox_to_anchor=(0.5, 0.96), ncol=3, fontsize=10,
            frameon=True, borderaxespad=0.)

  # Add dataset name as the figure title
  dataset_name = results[0].experiment_identifier.get_nice_dataset_name()
  fig.suptitle(f"Dataset: {dataset_name} {sup_title_suffix}", fontsize=18, y=1.0)

  # Spacing shrinks as the grid grows (inverse scaling, with lower bounds)
  w_base, h_base = 0.2, 0.3  # Base values for a small grid
  w_scale = max(0.25, 1.0 / n_models)  # Scale spacing down as columns increase
  h_scale = max(0.5, 1.0 / num_dims)   # Scale spacing down as rows increase

  # Adjust layout for the overall figure
  plt.tight_layout(rect=[0, 0, 1, 0.97])  # Less reserved space at top
  plt.subplots_adjust(
      wspace=w_base * w_scale,  # Width spacing scales with columns
      hspace=h_base * h_scale,  # Height spacing scales with rows
      top=0.90  # Move plots much closer to the top (title and legend)
  )

  # The figure always has a suptitle (set above), so it is shown but never closed
  # here; callers can keep reusing/saving the returned figure.
  if show_plot:
    plt.show()

  return fig, axes



