from debug import *
import os
import argparse
import numpy as np
from Utils.io_utils import load_yaml_config
from diffusion_crf import *
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
import jax.random as random
import jax.tree_util as jtu
import optax
import jax.numpy as jnp
import optax
import pickle
import equinox as eqx
import json
import tqdm
from jax._src.util import curry
import tqdm
from main import get_dataset, load_jax_model
from Models.experiment_identifier import ExperimentIdentifier, find_all_experiments
from Models.empirical_metrics import wasserstein2_distance, compute_univariate_metrics
from Utils.discriminative_metric_jax import discriminative_score_metrics
from Utils.metric import get_context_fid_score
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import datetime
from typing import Any, Optional, Dict, List, Union, Tuple, Set
from Models.experiment_samples import ExperimentResultData
"""
This module implements a comprehensive system for checkpointing experiment results during model evaluation.
It provides classes for:

1. SampleGenerationMetadata - Tracks and persists evaluation progress, including parameters and completion status
2. ResultDataCheckpointer - Manages the saving and loading of intermediate evaluation results
3. ExperimentManager - Coordinates multiple experiments and provides tools for filtering and analyzing results

The checkpointing system enables resumable evaluations of long-running experiments and facilitates
the aggregation and analysis of results across multiple experiment configurations.
"""

################################################################################################################

class SampleGenerationMetadata:
  """
  Evaluation-progress metadata that tracks completion by data index.

  Progress is stored as the highest test-data index processed so far. This makes
  the evaluation invariant to the checkpoint batch size: it can be resumed with a
  different batch size on different hardware.
  """

  def __init__(self,
               test_data_size: int,
               n_samples_for_empirical_distribution: int,
               highest_completed_index: int = -1,
               evaluation_started: Optional[str] = None,
               completion_time: Optional[str] = None,
               last_updated: Optional[str] = None):
    """
    Initialize the metadata with index-based tracking.

    Args:
      test_data_size: Size of the test dataset
      n_samples_for_empirical_distribution: Samples for empirical distribution
      highest_completed_index: Highest index that has been processed (-1 if none)
      evaluation_started: ISO timestamp when evaluation started (defaults to now)
      completion_time: ISO timestamp when evaluation completed. If the metadata is
        already complete and no time is given, the current time is recorded.
      last_updated: ISO timestamp of the last progress update (defaults to now).
        ``from_dict`` passes the persisted value so reloading does not reset it.
    """
    self.test_data_size = test_data_size
    self.n_samples_for_empirical_distribution = n_samples_for_empirical_distribution

    # Initialize highest completed index
    self.highest_completed_index = highest_completed_index

    # Use current time if not provided
    if evaluation_started is None:
      self.evaluation_started = self._get_current_time()
    else:
      self.evaluation_started = evaluation_started

    self.completion_time = completion_time

    # If we should be complete but don't have a timestamp yet, record one now
    if self.is_completed and not self.completion_time:
      self.completion_time = self._get_current_time()

    if last_updated is None:
      self.last_updated = self._get_current_time()
    else:
      self.last_updated = last_updated

  def _get_current_time(self) -> str:
    """Get the current (local, naive) time as an ISO format string."""
    return datetime.datetime.now().isoformat()

  @property
  def number_of_completed_iterations(self) -> int:
    """Number of indices processed so far (indices are 0-based, so highest + 1)."""
    return self.highest_completed_index + 1

  @property
  def is_completed(self) -> bool:
    """Check if all indices have been processed."""
    return self.number_of_completed_iterations >= self.test_data_size

  @property
  def progress_percentage(self) -> float:
    """Progress as a percentage of the test set, rounded to two decimals."""
    if self.test_data_size == 0:  # Avoid division by zero
      return 100.0 if self.is_completed else 0.0
    # Clamp so over-reported progress never exceeds 100%
    completed_count = min(self.number_of_completed_iterations, self.test_data_size)
    return round(100 * completed_count / self.test_data_size, 2)

  def get_next_indices(self, batch_size: int) -> Optional[List[int]]:
    """
    Get the next batch of indices to process.

    Args:
      batch_size: Number of indices to process in this batch

    Returns:
      Consecutive indices starting right after the highest completed one, truncated
      at test_data_size. Returns None if all indices are processed.
    """
    start_idx = self.number_of_completed_iterations

    # Nothing left to process
    if start_idx >= self.test_data_size:
      return None

    # Ensure we don't go beyond test_data_size
    end_idx = min(start_idx + batch_size, self.test_data_size)
    return list(range(start_idx, end_idx))

  def mark_highest_index_completed(self, highest_index: int) -> None:
    """
    Mark all indices up to and including highest_index as completed.

    Calls that do not advance the progress are ignored; they do not touch last_updated.

    Args:
      highest_index: Highest index that has been processed
    """
    if highest_index > self.highest_completed_index:
      self.highest_completed_index = highest_index
      self.last_updated = self._get_current_time()

      # Update completion timestamp if newly completed
      was_complete = self.completion_time is not None
      if self.is_completed and not was_complete:
        self.completion_time = self._get_current_time()

  def parameters_match(self, test_data_size: int, n_samples_for_empirical_distribution: int) -> bool:
    """
    Check if parameters match the provided values.

    Args:
      test_data_size: Test data size to check
      n_samples_for_empirical_distribution: Empirical distribution sample count to check

    Returns:
      bool: True if parameters match, False otherwise
    """
    return (self.test_data_size == test_data_size and
            self.n_samples_for_empirical_distribution == n_samples_for_empirical_distribution)

  def to_dict(self) -> dict:
    """
    Convert metadata to a dictionary for serialization.

    "completed" and "progress_percentage" are derived values. They are written for
    human readers and are ignored by ``from_dict``.

    Returns:
      dict: Dictionary representation of the metadata
    """
    return {
      "test_data_size": self.test_data_size,
      "n_samples_for_empirical_distribution": self.n_samples_for_empirical_distribution,
      "highest_completed_index": self.highest_completed_index,
      "completed": self.is_completed,
      "evaluation_started": self.evaluation_started,
      "completion_time": self.completion_time,
      "last_updated": self.last_updated,
      "progress_percentage": self.progress_percentage
    }

  @classmethod
  def from_dict(cls, data: dict) -> 'SampleGenerationMetadata':
    """
    Create metadata from a dictionary produced by ``to_dict``.

    Args:
      data: Dictionary containing metadata. "completion_time" and "last_updated"
        are optional; older files without "last_updated" get the current time.

    Returns:
      SampleGenerationMetadata: New metadata instance
    """
    return cls(
      test_data_size=data["test_data_size"],
      n_samples_for_empirical_distribution=data["n_samples_for_empirical_distribution"],
      highest_completed_index=data["highest_completed_index"],
      evaluation_started=data["evaluation_started"],
      completion_time=data.get("completion_time"),  # Optional
      last_updated=data.get("last_updated"),  # Optional; preserved across reloads
    )

  @classmethod
  def load(cls, file_path: str) -> Optional['SampleGenerationMetadata']:
    """
    Load metadata from a JSON file.

    Args:
      file_path: Path to the metadata file

    Returns:
      SampleGenerationMetadata: Loaded metadata, or None if file doesn't exist
    """
    if not os.path.exists(file_path):
      return None

    with open(file_path, 'r') as f:
      data = json.load(f)

    return cls.from_dict(data)

  def save(self, file_path: str) -> None:
    """
    Save metadata to a file as JSON.

    The JSON is first written to a temporary sibling file. That file is then moved
    into place with ``os.replace``, so an interruption mid-write does not leave a
    truncated file at ``file_path``.

    Args:
      file_path: Path to save the metadata; parent directories are created as needed
    """
    directory = os.path.dirname(file_path)
    if directory:  # a bare filename has no parent directory to create
      os.makedirs(directory, exist_ok=True)

    tmp_path = f"{file_path}.tmp"
    try:
      with open(tmp_path, 'w') as f:
        json.dump(self.to_dict(), f, indent=2)
      os.replace(tmp_path, file_path)
    except BaseException:
      # Don't leave a partially written temporary file behind
      if os.path.exists(tmp_path):
        os.remove(tmp_path)
      raise

  def __str__(self) -> str:
    """
    Return a human-readable string representation of the metadata.

    Returns:
      str: Formatted metadata information
    """
    status = "Completed" if self.is_completed else "In Progress"
    completed_count = min(self.highest_completed_index + 1, self.test_data_size)

    return (
      f"Evaluation Metadata:\n"
      f"  Status: {status} ({self.progress_percentage:.1f}%)\n"
      f"  Test data size: {self.test_data_size}\n"
      f"  Samples for empirical distribution: {self.n_samples_for_empirical_distribution}\n"
      f"  Progress: {completed_count}/{self.test_data_size} samples\n"
      f"  Started: {self.evaluation_started}\n"
      f"  Last updated: {self.last_updated}\n"
      f"  Completion time: {self.completion_time or 'Not completed'}"
    )

################################################################################################################

class ResultDataCheckpointer:
  """
  Checkpointer that stores experiment result data in files keyed by index range.

  Each saved batch goes to ``indices_<start>_<end>.npz`` in the experiment's
  checkpoint directory. Both bounds are inclusive and zero-padded to six digits.
  Progress is tracked in ``metadata2.json``. Because progress is tracked by data
  index rather than by batch number, evaluation can be resumed with a different
  batch size depending on available hardware resources.
  """

  experiment_identifier: ExperimentIdentifier

  # File-naming constants shared by every method that reads or writes checkpoints.
  _INDICES_PREFIX = "indices_"
  _LEGACY_PREFIX = "iteration_"  # old per-iteration format; only cleaned up, never read
  _CHECKPOINT_SUFFIX = ".npz"
  _COMBINED_CACHE_NAME = "combined_results.npz"
  _METADATA_NAME = "metadata2.json"  # metadata2.json is the index-based format

  def __init__(self, experiment_identifier: ExperimentIdentifier, restart_evaluation: bool = False):
    """
    Initialize the checkpointer for a specific experiment.

    Args:
      experiment_identifier: The identifier for the experiment being checkpointed
      restart_evaluation: If True, clear any existing checkpoints and start from scratch
    """
    self.experiment_identifier = experiment_identifier
    self.restart_evaluation = restart_evaluation

    # If restart_evaluation is True, clear any existing checkpoints
    if restart_evaluation:
      self.clear_checkpoints()

  def __str__(self):
    """
    Return a string representation of the checkpointer.

    Returns:
      str: A formatted string showing the experiment details associated with this checkpointer
    """
    return (
      f"ResultDataCheckpointer:\n"
      f"  Config File: {self.experiment_identifier.config_name}\n"
      f"  Model: {self.experiment_identifier.model_name}\n"
      f"  Training Objective: {self.experiment_identifier.objective}\n"
      f"  SDE Type: {self.experiment_identifier.sde_type}\n"
      f"  Interpolation Frequency: {self.experiment_identifier.freq}\n"
      f"  Experiment Group: {self.experiment_identifier.group}\n"
      f"  Random Seed: {self.experiment_identifier.global_key_seed}"
    )

  @classmethod
  def _is_indices_file(cls, filename: str) -> bool:
    """Return True if ``filename`` looks like an index-range checkpoint file."""
    return filename.startswith(cls._INDICES_PREFIX) and filename.endswith(cls._CHECKPOINT_SUFFIX)

  @classmethod
  def _parse_indices_filename(cls, filename: str) -> Optional[Tuple[int, int]]:
    """Return (start, end) from an ``indices_<start>_<end>.npz`` name, or None if malformed."""
    if not cls._is_indices_file(filename):
      return None
    stem = filename[len(cls._INDICES_PREFIX):-len(cls._CHECKPOINT_SUFFIX)]
    parts = stem.split("_")
    if len(parts) != 2:
      return None
    try:
      return int(parts[0]), int(parts[1])
    except ValueError:
      return None

  def _list_checkpoint_files(self) -> List[Tuple[int, int, str]]:
    """Return (start, end, path) for every well-formed index checkpoint file, sorted."""
    checkpoint_dir = self._get_checkpoints_dir()
    if not os.path.exists(checkpoint_dir):
      return []

    entries = []
    for filename in os.listdir(checkpoint_dir):
      parsed = self._parse_indices_filename(filename)
      if parsed is None:
        continue
      entries.append((parsed[0], parsed[1], os.path.join(checkpoint_dir, filename)))
    entries.sort()
    return entries

  @staticmethod
  def _select_contiguous_files(entries: List[Tuple[int, int, str]],
                               highest_index: Optional[int] = None) -> Tuple[List[Tuple[int, int, str]], int]:
    """
    Choose a gap-free, non-overlapping chain of checkpoint files starting at index 0.

    Stale or overlapping files can exist. For example, a batch file may have been
    written while the metadata update was interrupted, and the evaluation was then
    resumed with a different batch size. Concatenating every file would duplicate
    samples, so only one chain of consecutive ranges is kept.

    Args:
      entries: (start, end, path) tuples for the available checkpoint files
      highest_index: If given, no selected file may extend past this index, and a
        chain ending exactly on it is preferred

    Returns:
      (chain, chain_end): the selected entries in index order, and the last index
      they cover (-1 if the chain is empty)
    """
    # reached[p] = the file whose range ends at p - 1, i.e. the file through which
    # "next index to load == p" was first reached. Position 0 is the start.
    reached = {0: None}
    for entry in sorted(entries):
      start, end, _ = entry
      if end < start or start not in reached:
        continue
      if highest_index is not None and end > highest_index:
        continue
      # Files are visited in start order and every file ends at or after its start,
      # so all ways of reaching `start` were found before this point.
      reached.setdefault(end + 1, entry)

    if highest_index is not None and (highest_index + 1) in reached:
      target = highest_index + 1
    else:
      target = max(reached)

    chain = []
    position = target
    while reached[position] is not None:
      entry = reached[position]
      chain.append(entry)
      position = entry[0]
    chain.reverse()
    return chain, target - 1

  def has_existing_results(self) -> bool:
    """
    Check if there are existing checkpoint results for this experiment.

    This method looks for ``indices_*.npz`` files in the experiment's checkpoint
    directory. If restart_evaluation was set to True during initialization, this
    always returns False.

    Returns:
      bool: True if checkpoint files exist, False otherwise
    """
    # If we're restarting, we should always report no existing results
    if self.restart_evaluation:
      return False

    checkpoint_dir = self._get_checkpoints_dir()
    if not os.path.exists(checkpoint_dir):
      return False

    return any(self._is_indices_file(filename) for filename in os.listdir(checkpoint_dir))

  def _get_indices_filename(self, start_index: int, end_index: int) -> str:
    """
    Get the filename for a specific index range's checkpoint data.

    Args:
      start_index: The starting index (inclusive)
      end_index: The ending index (inclusive)

    Returns:
      str: The full path to the checkpoint file for this index range
    """
    # The experiment identifier already handles creating the checkpoint directory
    checkpoint_dir = self._get_checkpoints_dir()
    filename = f"{self._INDICES_PREFIX}{start_index:06d}_{end_index:06d}{self._CHECKPOINT_SUFFIX}"
    return os.path.join(checkpoint_dir, filename)

  def get_highest_completed_index(self) -> int:
    """
    Get the highest index up to which every sample has been processed.

    The evaluation metadata is the source of truth when it exists. Without
    metadata, the checkpoint directory is scanned. The result is the end of the
    longest gap-free chain of ``indices_*.npz`` files that starts at index 0. This
    is used to determine where to resume evaluation from after an interruption.

    Returns:
      int: The highest index that has been processed, or -1 if no indices have been processed
    """
    metadata: Union[SampleGenerationMetadata, None] = self.get_evaluation_metadata()
    if metadata is not None:
      return metadata.highest_completed_index
    _, chain_end = self._select_contiguous_files(self._list_checkpoint_files())
    return chain_end

  def clear_checkpoints(self):
    """
    Remove all checkpoint files for this experiment.

    Deletes index checkpoint files, legacy ``iteration_*.npz`` files, the metadata
    file and the combined-results cache. Failures to delete are reported and
    skipped (best effort). Called when restart_evaluation is True; can also be
    called manually to clean up.
    """
    checkpoint_dir = self._get_checkpoints_dir()

    if os.path.exists(checkpoint_dir):
      for filename in os.listdir(checkpoint_dir):
        is_legacy = filename.startswith(self._LEGACY_PREFIX) and filename.endswith(self._CHECKPOINT_SUFFIX)
        if self._is_indices_file(filename) or is_legacy:
          filepath = os.path.join(checkpoint_dir, filename)
          try:
            os.remove(filepath)
            print(f"Removed checkpoint file: {filepath}")
          except OSError as e:
            print(f"Warning: Could not remove {filepath}: {e}")

      # Delete metadata file if it exists
      metadata_path = self._get_evaluation_metadata_path()
      if os.path.exists(metadata_path):
        try:
          os.remove(metadata_path)
          print(f"Removed metadata file: {metadata_path}")
        except OSError as e:
          print(f"Warning: Could not remove metadata file: {e}")

    # Delete combined results cache if it exists
    combined_cache_path = os.path.join(checkpoint_dir, self._COMBINED_CACHE_NAME)
    if os.path.exists(combined_cache_path):
      try:
        os.remove(combined_cache_path)
        print(f"Removed combined results cache: {combined_cache_path}")
      except OSError as e:
        print(f"Warning: Could not remove combined results cache: {e}")

  def _get_checkpoints_dir(self) -> str:
    """
    Get the path to the checkpoints directory for this experiment.

    Returns:
      str: Path to the checkpoints directory
    """
    return self.experiment_identifier.checkpoint_dir

  def _get_evaluation_metadata_path(self) -> str:
    """
    Get the path to the metadata file for this experiment.

    Returns:
      str: Path to the metadata file
    """
    return os.path.join(self._get_checkpoints_dir(), self._METADATA_NAME)

  def has_evaluation_metadata(self) -> bool:
    """
    Check if metadata exists for this experiment.

    Returns:
      bool: True if metadata exists, False otherwise
    """
    return os.path.exists(self._get_evaluation_metadata_path())

  def get_evaluation_metadata(self) -> Optional[SampleGenerationMetadata]:
    """
    Get the metadata for this experiment.

    Returns:
      SampleGenerationMetadata: The metadata, or None if it doesn't exist
    """
    return SampleGenerationMetadata.load(self._get_evaluation_metadata_path())

  def save_evaluation_metadata(self, metadata: SampleGenerationMetadata) -> None:
    """
    Save metadata for this experiment.

    Args:
      metadata: The metadata to save
    """
    metadata.save(self._get_evaluation_metadata_path())

  def create_evaluation_metadata(self, test_data_size: int,
                     n_samples_for_empirical_distribution: int) -> SampleGenerationMetadata:
    """
    Create new metadata for this experiment and save it to disk.

    Args:
      test_data_size: Size of the test dataset
      n_samples_for_empirical_distribution: Number of samples for the empirical distribution

    Returns:
      SampleGenerationMetadata: The newly created metadata
    """
    metadata = SampleGenerationMetadata(
      test_data_size=test_data_size,
      n_samples_for_empirical_distribution=n_samples_for_empirical_distribution
    )
    self.save_evaluation_metadata(metadata)
    return metadata

  def save_result_data(self, result_data: ExperimentResultData, indices: List[int]):
    """
    Save result data for a specific batch of indices to disk.

    The result data is converted with ``to_numpy_dict`` and saved as a compressed
    .npz file named after the index range. The metadata's highest completed index
    is then advanced to ``max(indices)``.

    Args:
      result_data: The experiment result data to save
      indices: The indices that were processed in this batch

    Raises:
      ValueError: If ``indices`` is empty
      Exception: Re-raised if writing the checkpoint or updating the metadata fails
    """
    if not indices:
      raise ValueError("No indices provided to save_result_data")

    # Get the min and max indices for the filename
    start_index = min(indices)
    end_index = max(indices)

    filepath = self._get_indices_filename(start_index, end_index)
    numpy_dict = result_data.to_numpy_dict()
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    try:
      np.savez_compressed(filepath, **numpy_dict)
      print(f"Successfully saved checkpoint for indices {start_index}-{end_index} to {filepath}")

      # Update the metadata with the current progress
      metadata = self.get_evaluation_metadata()
      if metadata is None:
        # NOTE(review): these sizes are read from the batch being saved. If
        # data_series only holds this batch, test_data_size becomes the batch size
        # rather than the full test-set size. Callers should presumably create the
        # metadata up front via create_evaluation_metadata. TODO confirm.
        metadata = self.create_evaluation_metadata(
          test_data_size=result_data.data_series.batch_size,
          n_samples_for_empirical_distribution=result_data.latent_seq.batch_size
        )
      metadata.mark_highest_index_completed(end_index)
      self.save_evaluation_metadata(metadata)

    except Exception as e:
      print(f"Error saving checkpoint for indices {start_index}-{end_index}: {e}")
      raise

  def load_all_results(self) -> Optional[ExperimentResultData]:
    """
    Load all checkpoint results and combine them into a single result.

    This method:
    1. Returns the combined-results cache if it matches the current progress
    2. Otherwise selects a gap-free chain of checkpoint files covering indices
       0..highest completed index, ignoring stale or overlapping files
    3. Loads each file in order and concatenates the arrays along axis 0
    4. Writes the combined cache when the evaluation is complete and every file loaded

    Returns:
      ExperimentResultData: A single result object with data from all checkpoints
                           combined, or None if no checkpoints exist
    """
    checkpoint_dir = self._get_checkpoints_dir()
    if not os.path.exists(checkpoint_dir):
      return None

    combined_cache_path = os.path.join(checkpoint_dir, self._COMBINED_CACHE_NAME)

    highest_index = self.get_highest_completed_index()
    if highest_index < 0:
      return None

    # Use the combined cache if it was built for the same progress
    if os.path.exists(combined_cache_path):
      try:
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # cache files this code wrote itself.
        with np.load(combined_cache_path, allow_pickle=True) as cache_file:
          cache_dict = dict(cache_file)
        # Strip the bookkeeping keys this class adds before rebuilding the result
        cached_highest = cache_dict.pop('highest_index', None)
        cache_dict.pop('timestamp', None)

        if cached_highest is not None and int(cached_highest) == highest_index:
          print(f"Loading combined results from cache ({combined_cache_path})")
          return ExperimentResultData.from_numpy_dict(self.experiment_identifier, cache_dict)
      except Exception as e:
        print(f"Error loading cached results: {e}, will rebuild from individual files")

    selected, chain_end = self._select_contiguous_files(self._list_checkpoint_files(), highest_index)
    if not selected:
      return None
    if chain_end != highest_index:
      print(f"Warning: checkpoint files only cover indices 0-{chain_end}, "
            f"but indices up to {highest_index} are recorded as completed; loading the contiguous part")

    results = []
    load_failed = False
    for _, _, filepath in selected:
      try:
        with np.load(filepath) as checkpoint_file:
          numpy_dict = dict(checkpoint_file)
        results.append(ExperimentResultData.from_numpy_dict(self.experiment_identifier, numpy_dict))
      except Exception as e:
        load_failed = True
        print(f"Error loading checkpoint file {filepath}: {e}")

    if not results:
      return None

    # Split into array leaves and static structure, concatenate the arrays, recombine
    _, static = eqx.partition(results[0], eqx.is_array)
    results_array = [eqx.partition(r, eqx.is_array)[0] for r in results]
    combined_results_array = jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), *results_array)
    combined_results = eqx.combine(combined_results_array, static)

    # Only cache a complete, fully-loaded result; otherwise a partial result would
    # be served from the cache forever.
    metadata = self.get_evaluation_metadata()
    cache_is_valid = (metadata is not None and metadata.is_completed
                      and not load_failed and chain_end == highest_index)
    if cache_is_valid:
      try:
        numpy_dict = combined_results.to_numpy_dict()
        numpy_dict['highest_index'] = highest_index
        numpy_dict['timestamp'] = datetime.datetime.now().isoformat()
        np.savez_compressed(combined_cache_path, **numpy_dict)
        print(f"Saved combined results cache to {combined_cache_path}")
      except Exception as e:
        print(f"Error saving combined results cache: {e}")

    return combined_results

  def get_status(self) -> str:
    """
    Get a simple status report for this evaluation.

    Returns:
      str: A status string showing progress information
    """
    status_lines = [f"Experiment: {self.experiment_identifier.get_model_identifier()}"]

    if not self.has_evaluation_metadata():
      highest_index = self.get_highest_completed_index()
      if highest_index >= 0:
        status_lines.append(f"No metadata found. Highest completed index: {highest_index}")
      else:
        status_lines.append("No metadata found. No completed indices.")
      return "\n".join(status_lines)

    # Report metadata-based status
    metadata = self.get_evaluation_metadata()
    completed_count = min(metadata.highest_completed_index + 1, metadata.test_data_size)
    status_lines.append(f"Progress: {metadata.progress_percentage:.1f}% ({completed_count}/{metadata.test_data_size} samples)")
    status_lines.append(f"Status: {'Completed' if metadata.is_completed else 'In progress'}")
    status_lines.append(f"Last updated: {metadata.last_updated}")

    return "\n".join(status_lines)

def get_evaluation_status(experiment_identifier: ExperimentIdentifier) -> Dict[str, Any]:
  """
  Summarize how far the evaluation of an experiment has progressed.

  Args:
    experiment_identifier: The experiment whose evaluation status is requested

  Returns:
    A dictionary that always has "is_complete" and "status_summary". When
    evaluation metadata exists it also has "current_sample", "total_samples",
    "progress_percentage" and "last_updated".
  """
  checkpointer = ResultDataCheckpointer(experiment_identifier)

  # Without metadata, the highest completed index is the only signal available
  if not checkpointer.has_evaluation_metadata():
    highest = checkpointer.get_highest_completed_index()
    summary = (
      f"Partial evaluation data found (no metadata). Highest index: {highest}"
      if highest >= 0
      else "No evaluation data available"
    )
    return {"is_complete": False, "status_summary": summary}

  metadata = checkpointer.get_evaluation_metadata()
  total = metadata.test_data_size
  done = min(metadata.highest_completed_index + 1, total)
  finished = metadata.is_completed

  if finished:
    summary = f"Evaluation complete: {done}/{total} samples processed"
  else:
    summary = f"Evaluation in progress: {metadata.progress_percentage:.1f}% ({done}/{total} samples)"

  return {
    "is_complete": finished,
    "current_sample": metadata.highest_completed_index + 1,
    "total_samples": total,
    "progress_percentage": metadata.progress_percentage,
    "last_updated": metadata.last_updated,
    "status_summary": summary,
  }

################################################################################################################

class ExperimentManager:
  """
  A manager for working with multiple experiments and their result data.
  Provides capabilities for filtering, grouping, and batch-loading experiment results.

  Only experiments that already have evaluation metadata appear in ``self.df``,
  so only they can be selected by ``filter``/``group_by``. ``len()`` and indexing
  operate on every checkpointer the manager was built with.
  """

  # Columns of the DataFrame built by _create_dataframe. They are listed explicitly
  # so an empty manager still has a well-formed (empty) DataFrame to filter/group.
  _DATAFRAME_COLUMNS = ['index', 'config_name', 'model_name', 'objective', 'sde_type',
                        'freq', 'group', 'seed', 'checkpointer', 'progress', 'completed']

  def __init__(self, checkpointers=None):
    """
    Initialize an ExperimentManager with a collection of checkpointers.

    Args:
      checkpointers: Optional list of ResultDataCheckpointer instances. If None,
                     automatically loads all available checkpointers from disk.

    Examples:
      # Load all available experiments
      manager = ExperimentManager()

      # Create a manager with specific checkpointers
      specific_manager = ExperimentManager([ckpt1, ckpt2, ckpt3])

      # Create a filtered manager from an existing one
      subset = ExperimentManager(manager.filter(model_name='neural_crf').checkpointers)
    """
    if checkpointers is None:
      experiment_identifiers = find_all_experiments()
      self.checkpointers = [ResultDataCheckpointer(experiment_identifier=eid) for eid in experiment_identifiers]
    else:
      self.checkpointers = list(checkpointers)

    # Create a DataFrame for easy filtering and grouping
    self.df = self._create_dataframe()

  def _create_dataframe(self):
    """
    Convert checkpointers to a pandas DataFrame for easier manipulation.

    The 'index' column is the position of the row's checkpointer in
    ``self.checkpointers``. Experiments without evaluation metadata are skipped.
    """
    rows = []
    for i, ckpt in enumerate(self.checkpointers):
      if not ckpt.has_evaluation_metadata():
        continue
      metadata: SampleGenerationMetadata = ckpt.get_evaluation_metadata()
      if metadata is None:  # metadata vanished between the two calls
        continue

      ei: ExperimentIdentifier = ckpt.experiment_identifier
      rows.append({
        'index': i,
        'config_name': ei.config_name,
        'model_name': ei.model_name,
        'objective': ei.objective,
        'sde_type': ei.sde_type,
        'freq': ei.freq,
        'group': ei.group,
        'seed': ei.global_key_seed,
        'checkpointer': ckpt,
        'progress': f"{metadata.number_of_completed_iterations}/{metadata.test_data_size} ({metadata.progress_percentage:.1f}%)",
        'completed': metadata.is_completed,
      })

    return pd.DataFrame(rows, columns=self._DATAFRAME_COLUMNS)

  def filter(self, **kwargs):
    """
    Filter experiments by their attributes and return a new ExperimentManager.

    Args:
      **kwargs: Key-value pairs of attributes to filter by. The key should be one of:
                'config_name', 'model_name', 'objective', 'sde_type', 'freq',
                'group', or 'seed'. Unknown keys are ignored with a warning.
                Values can be single items or lists/tuples/sets of allowed values.

    Returns:
      A new ExperimentManager containing only the experiments that match all criteria.

    Examples:
      # Filter to just one model type
      diffusion_exps = manager.filter(model_name='diffusion_crf')

      # Multiple conditions
      brownian_elbo = manager.filter(
          sde_type='brownian',
          objective='mse',
          freq=0
      )

      # Filter with multiple allowed values
      multiple_models = manager.filter(
          model_name=['neural_crf', 'diffusion_crf'],
          group='paper_final'
      )

      # Chain filters for progressive refinement
      filtered = manager.filter(freq=4).filter(seed=[0, 1, 2])
    """
    df = self.df
    for key, value in kwargs.items():
      if key not in df.columns:
        print(f"Warning: ignoring unknown filter key '{key}'")
        continue
      if isinstance(value, (list, tuple, set, frozenset)):
        df = df[df[key].isin(list(value))]
      else:
        df = df[df[key] == value]

    # Create a new manager with the filtered checkpointers
    filtered_checkpointers = [self.checkpointers[i] for i in df['index'].tolist()]
    return ExperimentManager(filtered_checkpointers)

  def group_by(self, *attributes):
    """
    Group experiments by specified attributes.

    Args:
      *attributes: Attribute names to group by. Must be columns in the dataframe:
                  'config_name', 'model_name', 'objective', 'sde_type', 'freq',
                  'group', or 'seed'.

    Returns:
      A dictionary where:
        - Keys are strings representing the grouping values (e.g., "model_name=neural_crf, objective=elbo")
        - Values are ExperimentManager instances containing the experiments in that group
      The dictionary is empty when no experiment has evaluation metadata.

    Examples:
      # Group by a single attribute
      model_groups = manager.group_by('model_name')
      for name, group in model_groups.items():
          print(f"{name}: {len(group)} experiments")

      # Access a specific group
      neural_crf_exps = model_groups['model_name=neural_crf']

      # Group by multiple attributes
      complex_groups = manager.group_by('model_name', 'sde_type', 'objective')

      # Group after filtering
      filtered_groups = manager.filter(freq=4).group_by('model_name', 'seed')

      # Load results from each group for comparison
      group_results = {}
      for name, group in manager.group_by('model_name').items():
          group_results[name] = group.load_results()
    """
    if self.df.empty:
      return {}

    result = {}
    for group_key, group_df in self.df.groupby(list(attributes)):
      # Older pandas yields scalar keys for a single grouping column
      if not isinstance(group_key, tuple):
        group_key = (group_key,)

      group_checkpointers = [self.checkpointers[i] for i in group_df['index'].tolist()]
      group_name = ", ".join(f"{attr}={value}" for attr, value in zip(attributes, group_key))
      result[group_name] = ExperimentManager(group_checkpointers)

    return result

  def load_results(self):
    """
    Load the results of every checkpointer in this manager.

    Each checkpointer's ``load_all_results`` is called. Experiments with no stored
    results are skipped. The results are NOT concatenated across experiments.

    Returns:
      list[ExperimentResultData]: One combined result per experiment that has
      results, in checkpointer order. Returns None if no results are available.

    Examples:
      # Load all results from all experiments
      all_results = manager.load_results()

      # Load results from a filtered subset
      mse_results = manager.filter(objective='mse').load_results()

      # Compare results across different models
      model_results = {}
      for name, group in manager.group_by('model_name').items():
          model_results[name] = group.load_results()

      # Process the loaded results
      results = manager.filter(model_name='neural_crf').load_results()
      if results is not None:
          for result in results:
              crps, log_likelihood = result.get_crps_and_nll()
              print(f"CRPS: {float(jnp.mean(crps))}")

      # Visualize a sample from the first experiment's results
      results = manager.filter(model_name='diffusion_crf').load_results()
      if results is not None:
          results[0].create_plot(index=0)  # Plot the first sample
    """
    all_results = []
    for ckpt in self.checkpointers:
      result = ckpt.load_all_results()
      if result is not None:
        all_results.append(result)

    if not all_results:
      return None

    return all_results

  def __len__(self):
    """Number of checkpointers (including any without evaluation metadata)."""
    return len(self.checkpointers)

  def __getitem__(self, idx):
    """Return a checkpointer by position, or a new manager for a slice."""
    if isinstance(idx, slice):
      return ExperimentManager(self.checkpointers[idx])
    return self.checkpointers[idx]

  def __str__(self):
    if not self.checkpointers:
      return "ExperimentManager: No experiments"

    attribute_names = [col for col in self.df.columns if col not in ('checkpointer', 'index')]
    summary = f"ExperimentManager: {len(self.checkpointers)} experiments\n"
    summary += f"Attributes: {', '.join(attribute_names)}\n\n"

    if self.df.empty:
      # None of the experiments has evaluation metadata yet
      summary += "(no experiments with evaluation metadata)"
      return summary

    # Show a preview of the experiments with progress
    preview_cols = ['config_name', 'model_name', 'objective', 'progress', 'completed']
    summary += str(self.df[preview_cols].head(5))

    hidden_rows = len(self.df) - 5
    if hidden_rows > 0:
      summary += f"\n... and {hidden_rows} more"

    return summary

if __name__ == "__main__":
  # Ad-hoc inspection entry point: load the stored results of one experiment
  # configuration, then open the debugger so they can be explored interactively.
  manager = ExperimentManager()
  autoregressive_runs = manager.filter(
    model_name='my_autoregressive',
    config_name='double_pendulum',
    seed=1,
  )
  ar = autoregressive_runs.load_results()
  import pdb
  pdb.set_trace()
