"""
Utility functions for regression example.
"""

import os
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from scipy import stats

def _stack_truncated(sequences):
   """
   Stack 1-D sequences into a 2-D array of shape (n_sequences, min_length).

   Sequences longer than the shortest one are truncated from the end, so runs
   that recorded a different number of steps can still be averaged together.
   """
   arrays = [np.asarray(seq) for seq in sequences]
   min_length = min(len(arr) for arr in arrays)
   return np.array([arr[:min_length] for arr in arrays])

def _log_domain_ci(values, t_val):
   """
   Compute confidence-interval bounds along axis 0 in the log domain.

   Values are clamped to at least 1e-10 before taking the log (avoids log(0));
   the interval log_mean +/- t_val * SEM is mapped back with exp, so both
   bounds are strictly positive and look symmetric on a log-scale plot.

   Args:
      values: Array-like of samples; axis 0 indexes the runs
      t_val: Critical value of Student's t distribution for the desired level

   Returns:
      tuple: (lower, upper) bounds, each shaped like ``values`` without axis 0
   """
   log_values = np.log(np.maximum(1e-10, values))
   log_mean = np.mean(log_values, axis=0)
   log_se = stats.sem(log_values, axis=0)
   return np.exp(log_mean - t_val * log_se), np.exp(log_mean + t_val * log_se)

def aggregate_results(test_results):
   """
   Aggregate results from multiple test runs.

   Losses and MSE get an arithmetic mean plus a 95% t-based confidence
   interval computed in the log domain (strictly positive bounds).
   ``time_per_step`` gets a linear-domain interval whose lower bound is
   clamped at 0. Per-step series of unequal length are truncated to the
   shortest run before averaging. With a single run, the CI bounds equal
   the mean.

   NOTE: the reported mean is arithmetic, while the log-domain CI is centred
   on the geometric mean, so for widely dispersed runs the mean can lie
   above the upper bound.

   Args:
      test_results: Non-empty list of result dictionaries, each with at least
         'losses', 'mse' and 'training_time' (and optionally 'time_per_step')

   Returns:
      dict: Aggregated results with mean and confidence intervals

   Raises:
      ValueError: If ``test_results`` is empty
   """
   if not test_results:
      raise ValueError("aggregate_results() requires at least one test result")

   first = test_results[0]
   n_runs = len(test_results)
   # The t critical value needs at least one degree of freedom (two runs).
   t_val = stats.t.ppf(0.975, n_runs - 1) if n_runs > 1 else None

   # Attributes shared by all runs are copied from the first result.
   aggregated = {
      key: first[key]
      for key in ('name', 'config', 'resample_freq', 'continuous_adam', 'plot_data')
      if key in first
   }

   # Losses: CI in the log domain because they are plotted on a log scale.
   losses_array = _stack_truncated([r['losses'] for r in test_results])
   losses_mean = np.mean(losses_array, axis=0)
   if t_val is not None:
      losses_ci_lower, losses_ci_upper = _log_domain_ci(losses_array, t_val)
   else:
      losses_ci_lower = losses_ci_upper = losses_mean
   aggregated['losses'] = losses_mean.tolist()
   aggregated['losses_ci_lower'] = losses_ci_lower.tolist()
   aggregated['losses_ci_upper'] = losses_ci_upper.tolist()

   # Timing: linear-domain CI, with the lower bound clamped to be non-negative.
   if 'time_per_step' in first:
      times_array = _stack_truncated([r['time_per_step'] for r in test_results])
      times_mean = np.mean(times_array, axis=0)
      if t_val is not None:
         half_width = t_val * stats.sem(times_array, axis=0)
         times_ci_lower = np.maximum(0, times_mean - half_width)
         times_ci_upper = times_mean + half_width
      else:
         times_ci_lower = times_ci_upper = times_mean
      aggregated['time_per_step'] = times_mean.tolist()
      aggregated['time_per_step_ci_lower'] = times_ci_lower.tolist()
      aggregated['time_per_step_ci_upper'] = times_ci_upper.tolist()

   # MSE: log-domain CI handles skewed, strictly positive error distributions.
   mse_all = [r['mse'] for r in test_results]
   mse_mean = np.mean(mse_all)
   if t_val is not None:
      mse_ci_lower, mse_ci_upper = _log_domain_ci(mse_all, t_val)
   else:
      mse_ci_lower = mse_ci_upper = mse_mean
   aggregated['mse'] = float(mse_mean)
   aggregated['mse_ci_lower'] = float(mse_ci_lower)
   aggregated['mse_ci_upper'] = float(mse_ci_upper)

   # Training time: plain mean, stored as a built-in float for JSON output.
   aggregated['training_time'] = float(np.mean([r['training_time'] for r in test_results]))

   return aggregated

def load_config(config_path="config.json"):
   """
   Load configuration from a JSON file.

   The file is decoded as UTF-8 (the JSON standard encoding) rather than the
   platform's locale encoding, so configs with non-ASCII text load the same
   way on every OS.

   Args:
      config_path: Path to the configuration file (str or path-like)

   Returns:
      dict: Configuration dictionary

   Raises:
      FileNotFoundError: If ``config_path`` does not exist
      json.JSONDecodeError: If the file is not valid JSON
   """
   with open(config_path, 'r', encoding='utf-8') as f:
      return json.load(f)

def _json_default(obj):
   """
   Fallback serializer for ``json.dump`` that converts NumPy values.

   NumPy arrays become (nested) lists and NumPy scalars become the matching
   Python scalar. Any other type raises ``TypeError``, which is the same
   error ``json.dump`` raises when no ``default`` hook is given.
   """
   if isinstance(obj, np.ndarray):
      return obj.tolist()
   if isinstance(obj, np.generic):
      return obj.item()
   raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def save_results(results, test_name, data_dir, test_idx=None):
   """
   Save test results to a JSON file.

   The file is named ``data_<test_name>.json``, or
   ``data_<test_name>_run<test_idx>.json`` when ``test_idx`` is given.
   ``data_dir`` is created if missing. NumPy arrays and scalars inside
   ``results`` are converted to plain JSON lists and numbers.

   Args:
      results: Results dictionary
      test_name: Name of the test
      data_dir: Directory to save results
      test_idx: Index for multiple test runs, used for file naming

   Returns:
      str: Path to the saved file

   Raises:
      TypeError: If ``results`` contains a value that is neither JSON-native
         nor a NumPy array/scalar
   """
   # Add test index suffix if provided
   suffix = f"_run{test_idx}" if test_idx is not None else ""
   output_file = os.path.join(data_dir, f"data_{test_name}{suffix}.json")

   # Create the target directory on demand; "" means the current directory.
   if data_dir:
      os.makedirs(data_dir, exist_ok=True)

   with open(output_file, 'w', encoding='utf-8') as f:
      json.dump(results, f, indent=3, default=_json_default)

   print(f"Results saved to {output_file}")
   return output_file

def plot_results(config, data_dir, figure_dir,
               figures_data_dir=None, figures_loss_dir=None, figures_comparison_dir=None,
               pre_aggregated_results=None, show_confidence_intervals=False):
   """
   Generate comparison plots from saved test results.

   Args:
      config: Configuration dictionary
      data_dir: Directory containing test results
      figure_dir: Directory to save plot images
      figures_data_dir: Directory to save dataset visualization plots
      figures_loss_dir: Directory to save loss plots
      figures_comparison_dir: Directory to save comparison plots
      pre_aggregated_results: Optional pre-aggregated results from multiple runs
      show_confidence_intervals: Whether to show confidence intervals (default: False)
   """
   # Get plot_every parameter from config, default to 1 (plot every point)
   plot_every = config["problem_config"].get("plot_every", 1)
   # Set default directories if not provided
   if figures_data_dir is None:
      figures_data_dir = os.path.join(figure_dir, 'data')
   if figures_loss_dir is None:
      figures_loss_dir = os.path.join(figure_dir, 'loss')
   if figures_comparison_dir is None:
      figures_comparison_dir = os.path.join(figure_dir, 'comparison')

   # Ensure directories exist
   os.makedirs(figures_data_dir, exist_ok=True)
   os.makedirs(figures_loss_dir, exist_ok=True)
   os.makedirs(figures_comparison_dir, exist_ok=True)
   # If pre-aggregated results are provided, use them directly
   if pre_aggregated_results is not None:
      results = pre_aggregated_results
      return_early = True
   else:
      return_early = False
      # Check if we have multiple test runs
      ntests = config["problem_config"].get("ntests", 1)

   if return_early:
      pass  # Results already set from pre_aggregated_results
   elif ntests <= 1:
      # Original behavior - single run
      results = {}
      for test_name in config["tests"]:
         data_file = os.path.join(data_dir, f"data_{test_name}.json")
         if os.path.exists(data_file):
            with open(data_file, 'r') as f:
                  results[test_name] = json.load(f)
   else:
      # Load results from multiple test runs
      results = {}
      for test_name in config["tests"]:
         test_results = []
         for test_idx in range(ntests):
            data_file = os.path.join(data_dir, f"data_{test_name}_run{test_idx}.json")
            if os.path.exists(data_file):
               with open(data_file, 'r') as f:
                  test_results.append(json.load(f))

         if test_results:
            # Aggregate results from multiple runs
            results[test_name] = aggregate_results(test_results)
   
   if not results:
      print("No test results found. Run tests first.")
      return
   
   # Create figure directory if it doesn't exist
   os.makedirs(figure_dir, exist_ok=True)
   
   # Extract data for plotting
   names = [config["tests"][name]["name"] for name in results]
   total_epochs = len(next(iter(results.values()))["losses"])
   phase1_epochs = config["problem_config"]["phase1_epochs"]

   # Apply running average if plot_every > 1
   def apply_running_average(data, window_size):
      """Apply running average to data with specified window size."""
      if window_size <= 1:
         return data
      smoothed = []
      indices = range(0, len(data), window_size)
      for i in indices:
         smoothed.append(np.mean(data[i:i+window_size])) # Let Python handle index out of bounds
      return np.array(smoothed), np.array(indices) + 1

   # Prepare plot data
   plot_data = {
      "names": names,
      "losses": [],
      "indices": [],
      "times": [results[name]["training_time"] for name in results],
      "time_per_step": [],
      "mses": [results[name]["mse"] for name in results]
   }

   # Process losses and time_per_step with running average if needed
   for name in results:
      original_losses = np.array(results[name]["losses"])
      original_time_per_step = np.array(results[name].get("time_per_step", [results[name]["training_time"]]))

      if plot_every > 1:
         # Apply running average
         smoothed_losses, indices = apply_running_average(original_losses, plot_every)
         plot_data["losses"].append(smoothed_losses.tolist())
         plot_data["indices"].append(indices.tolist())

         # Also apply to time_per_step if available and same length as losses
         if len(original_time_per_step) == len(original_losses):
            smoothed_time = original_time_per_step[plot_every-1::plot_every].tolist()
            if len(original_losses) % plot_every != 0:
               last_idx = len(original_time_per_step) - 1
               smoothed_time.append(original_time_per_step[last_idx])
            plot_data["time_per_step"].append(smoothed_time)
         else:
            plot_data["time_per_step"].append(original_time_per_step.tolist())
      else:
         # Use original data
         plot_data["losses"].append(original_losses.tolist())
         plot_data["indices"].append(list(range(1, len(original_losses) + 1)))
         plot_data["time_per_step"].append(original_time_per_step.tolist())
   
   # Find best model by MSE
   best_model = min(results.items(), key=lambda x: x[1]["mse"])[0]
   best_model_name = config["tests"][best_model]["name"]
   
   # First, find Adam data and all its time information
   adam_data = None
   adam_test_name = None
   adam_phase1_time = 0
   adam_phase1_losses = []
   adam_phase2_time = 0

   # Get Adam's information first
   for test_name, test_data in results.items():
      if config["tests"][test_name]["name"] == "Adam":
         adam_test_name = test_name
         times = test_data.get("time_per_step", [])
         losses = test_data["losses"]

         if len(times) >= phase1_epochs:
            # Split Adam data into Phase I and Phase II
            adam_phase1_times = times[:phase1_epochs]
            adam_phase1_losses = losses[:phase1_epochs]

            # Get time at the end of Phase I - this is where other methods' Phase II will start
            adam_phase1_time = adam_phase1_times[-1] if adam_phase1_times else 0

            # Get Adam's Phase II data
            if len(times) > phase1_epochs:
               adam_phase2_times = times[phase1_epochs:]
               adam_phase2_losses = losses[phase1_epochs:]

               # Calculate Phase II duration for Adam
               adam_phase2_time = adam_phase2_times[-1] - adam_phase1_time

               adam_data = {
                  "full_times": times,
                  "full_losses": losses,
                  "phase1_times": adam_phase1_times,
                  "phase1_losses": adam_phase1_losses,
                  "phase2_times": adam_phase2_times,
                  "phase2_losses": adam_phase2_losses,
                  "phase2_duration": adam_phase2_time
               }
         break

   # ------------------------------------------------------------------------
   # 1. Create the epoch vs loss plot
   # ------------------------------------------------------------------------
   plt.figure(figsize=(10, 8))

   # Plot Adam first to ensure it's visible for the entire range
   adam_test = None
   non_adam_tests = []

   # Separate Adam from other optimizers
   for test_name in results:
      if config["tests"][test_name]["name"] == "Adam":
         adam_test = test_name
      else:
         non_adam_tests.append(test_name)

   # Plot Adam first (if it exists)
   if adam_test:
      # Get loss data from our processed plot_data
      adam_idx = list(results.keys()).index(adam_test)

      losses = np.array(plot_data["losses"][adam_idx]) / float(adam_data["full_losses"][0])
      epochs_range = plot_data["indices"][adam_idx]

      # Check if we have confidence interval data
      has_ci = "losses_ci_lower" in results[adam_test] and "losses_ci_upper" in results[adam_test]

      if has_ci and show_confidence_intervals:
         # If using running average, we also need to process the CIs
         if plot_every > 1:
            ci_lower, _ = apply_running_average(np.array(results[adam_test]["losses_ci_lower"]), plot_every)
            ci_upper, _ = apply_running_average(np.array(results[adam_test]["losses_ci_upper"]), plot_every)
         else:
            ci_lower = results[adam_test]["losses_ci_lower"]
            ci_upper = results[adam_test]["losses_ci_upper"]

         ci_lower = np.array(ci_lower) / float(adam_data["full_losses"][0])
         ci_upper = np.array(ci_upper) / float(adam_data["full_losses"][0])

         # Plot mean line
         plt.semilogy(epochs_range, losses,
                     label=config["tests"][adam_test]["name"], linewidth=2)

         # Plot confidence interval with semi-transparent fill
         plt.fill_between(epochs_range, ci_lower, ci_upper,
                        alpha=0.3, color=plt.gca().lines[-1].get_color())
      else:
         # Plot just the mean line without confidence intervals
         plt.semilogy(epochs_range, losses,
                     label=config["tests"][adam_test]["name"], linewidth=2)

   # Then plot other optimizers
   for i, test_name in enumerate(non_adam_tests):
      # Get index of this test in the plot_data
      test_idx = list(results.keys()).index(test_name)

      # Get smoothed loss data
      losses = np.array(plot_data["losses"][test_idx]) / float(adam_data["full_losses"][0])
      has_ci = "losses_ci_lower" in results[test_name] and "losses_ci_upper" in results[test_name]

      # Calculate adjusted phase1_epochs if using running average
      full_epochs_range = plot_data["indices"][test_idx]
      phase1_epochs_range = []
      phase2_epochs_range = []
      for i in range(len(full_epochs_range)):
         if full_epochs_range[i] <= phase1_epochs:
            phase1_epochs_range.append(full_epochs_range[i])
         else:
            phase2_epochs_range.append(full_epochs_range[i])
      
      # only plot Phase II data
      if has_ci and show_confidence_intervals:
         # Process CIs with running average if needed
         if plot_every > 1:
            ci_lower, _ = apply_running_average(np.array(results[test_name]["losses_ci_lower"]), plot_every)
            ci_upper, _ = apply_running_average(np.array(results[test_name]["losses_ci_upper"]), plot_every)
            ci_lower = ci_lower[len(phase1_epochs_range):]
            ci_upper = ci_upper[len(phase1_epochs_range):]
         else:
            ci_lower = results[test_name]["losses_ci_lower"][phase1_epochs:]
            ci_upper = results[test_name]["losses_ci_upper"][phase1_epochs:]

         ci_lower = np.array(ci_lower) / float(adam_data["full_losses"][0])
         ci_upper = np.array(ci_upper) / float(adam_data["full_losses"][0])

         # Plot mean line
         plt.semilogy(phase2_epochs_range, losses[len(phase1_epochs_range):],
                     label=config["tests"][test_name]["name"], linewidth=2)

         # Plot confidence interval with semi-transparent fill
         try:
            plt.fill_between(phase2_epochs_range, ci_lower, ci_upper,
                           alpha=0.3, color=plt.gca().lines[-1].get_color())
         except:
            print(f"Error plotting confidence intervals for epoch vs loss plot for {test_name}")
      else:
         # Plot just the mean line without confidence intervals
         plt.semilogy(phase2_epochs_range, losses[len(phase1_epochs_range):],
                     label=config["tests"][test_name]["name"], linewidth=2)

   # Add a vertical line to mark the transition from phase 1 to phase 2
   if phase1_epochs > 0:
      plt.axvline(x=phase1_epochs, color='gray', linestyle='--', alpha=0.7)

   plt.xlabel('Epoch', fontsize=20)
   plt.ylabel('Relative Loss', fontsize=20)
   plt.title('Epoch vs Relative Training Loss', fontsize=25)
   plt.grid(True)
   plt.legend(fontsize=20)
   plt.tight_layout()

   # Save the plot
   epoch_loss_file = os.path.join(figures_loss_dir, 'epoch_vs_loss.png')
   plt.savefig(epoch_loss_file, dpi=300)
   plt.close()

   # ------------------------------------------------------------------------
   # 2. Create the MSE bar chart
   # ------------------------------------------------------------------------
   """
   plt.figure(figsize=(10, 8))
   x_pos = np.arange(len(names))

   # Plot MSE bars
   mse_values = [results[name]["mse"] for name in results]
   plt.bar(x_pos, mse_values)

   # Add error bars for confidence intervals if available and enabled
   if show_confidence_intervals:
      for i, name in enumerate(results):
         # Check if we have confidence interval data
         if "mse_ci_lower" in results[name] and "mse_ci_upper" in results[name]:
            ci_lower = results[name]["mse_ci_lower"]
            ci_upper = results[name]["mse_ci_upper"]

            # Plot error bars for 95% confidence interval
            plt.errorbar(x_pos[i], mse_values[i],
                      yerr=[[mse_values[i]-ci_lower], [ci_upper-mse_values[i]]],
                      fmt='none', color='black', capsize=5)

   plt.ylabel('Mean Squared Error')
   plt.title('Final MSE Comparison (with 95% CI)')
   plt.xticks(x_pos, names, rotation=45, ha='right')
   plt.tight_layout()

   # Save the plot
   mse_file = os.path.join(figures_loss_dir, 'mse_comparison.png')
   plt.savefig(mse_file, dpi=300)
   plt.close()
   """
   
   # ------------------------------------------------------------------------
   # 3. Create the Time vs Loss plot
   # ------------------------------------------------------------------------
   plt.figure(figsize=(10, 8))

   # If we didn't find Adam data, we need to handle this case
   if not adam_data:
      # Single optimizer case without Adam as reference
      print("No Adam data found for time vs loss plot reference, adapting plot for single optimizer")

      # Just plot the time-vs-loss data for the single optimizer directly
      if len(results) == 1:
         test_name = next(iter(results.keys()))
         optimizer_name = config["tests"][test_name]["name"]

         # Get time step data
         if "time_per_step" in results[test_name]:
            times = results[test_name]["time_per_step"]
            losses = results[test_name]["losses"]

            # Check if we have confidence interval data
            has_ci = "losses_ci_lower" in results[test_name] and "losses_ci_upper" in results[test_name]

            # Plot the data directly
            if len(times) == len(losses):
               # Plot mean line
               plt.semilogy(times, losses, label=optimizer_name, linewidth=2)

               # Plot confidence intervals if available and enabled
               if has_ci and show_confidence_intervals:
                  ci_lower = results[test_name]["losses_ci_lower"]
                  ci_upper = results[test_name]["losses_ci_upper"]
                  plt.fill_between(times, ci_lower, ci_upper,
                                 alpha=0.3, color=plt.gca().lines[-1].get_color())

            # Set axis labels and title
            plt.xlabel('Optimization Time (s)')
            plt.ylabel('Loss (log scale)')
            plt.title('Time vs Loss')
            plt.grid(True)
            plt.legend()

            # Continue with the rest of the plots
            return

      # If multiple optimizers but no Adam, print warning and skip this plot
      print("Warning: No Adam data found for time vs loss plot reference, skipping time vs loss plot")
      return

   # Collect other optimizers' Phase II times relative to their own start
   # Our goal is to find the minimum Phase II duration
   optimizer_phase2_durations = {"Adam": adam_phase2_time}
   optimizer_data = {"Adam": adam_data}

   # Process non-Adam optimizers
   for test_name, test_data in results.items():
      optimizer_name = config["tests"][test_name]["name"]
      if optimizer_name != "Adam":
         times = test_data.get("time_per_step", [])
         losses = test_data["losses"]

         # We're only interested in Phase II for other optimizers
         if len(times) > phase1_epochs and len(losses) > phase1_epochs:
            # Get only Phase II portion
            phase2_times = times[phase1_epochs:]
            phase2_losses = losses[phase1_epochs:]

            if len(phase2_times) > 0:
               # Calculate how long Phase II took for this optimizer
               phase2_duration = phase2_times[-1] - phase2_times[0]
               optimizer_phase2_durations[optimizer_name] = phase2_duration

               # Store data for plotting later
               optimizer_data[optimizer_name] = {
                  "phase2_times": phase2_times,
                  "phase2_losses": phase2_losses,
                  "phase2_duration": phase2_duration
               }

   # Find the shortest Phase II duration among all optimizers
   min_phase2_duration = min(optimizer_phase2_durations.values())

   # Calculate where to cut off the x-axis
   # (Phase I time + shortest Phase II duration)
   max_plot_time = adam_phase1_time + min_phase2_duration

   print(f"Time plot: Phase I ends at {adam_phase1_time:.2f}s, shortest Phase II is {min_phase2_duration:.2f}s")
   print(f"Limiting time vs loss plot to: {max_plot_time:.2f}s")

   # Now plot Adam (full curve)
   if "Adam" in optimizer_data:
      # Get the index of Adam in our plot_data
      adam_idx = list(results.keys()).index(adam_test)

      # Use the processed time_per_step and losses if plot_every > 1
      if plot_every > 1:
         # We already have prepared data with running averages
         adam_times = plot_data["time_per_step"][adam_idx]
         adam_losses = plot_data["losses"][adam_idx]

         # Filter out points beyond max_plot_time
         plot_times = []
         plot_losses = []
         for i, t in enumerate(adam_times):
            if t <= max_plot_time and i < len(adam_losses):
               plot_times.append(t)
               plot_losses.append(adam_losses[i])
      else:
         # Get times and losses up to the cutoff point
         plot_times = []
         plot_losses = []
         for i, t in enumerate(adam_data["full_times"]):
            if t <= max_plot_time and i < len(adam_data["full_losses"]):
               plot_times.append(t)
               plot_losses.append(adam_data["full_losses"][i])

      # Check if we have confidence interval data
      has_ci = ("losses_ci_lower" in results[adam_test] and "losses_ci_upper" in results[adam_test] and
               "time_per_step_ci_lower" in results[adam_test] and "time_per_step_ci_upper" in results[adam_test])

      # If we have CIs and plot_every > 1, process them
      if has_ci and plot_every > 1:
         # Apply running average to CIs
         ci_lower_processed, _ = apply_running_average(np.array(results[adam_test]["losses_ci_lower"]), plot_every)
         ci_upper_processed, _ = apply_running_average(np.array(results[adam_test]["losses_ci_upper"]), plot_every)

         # Filter to match plot_times
         plot_ci_lower = []
         plot_ci_upper = []
         for i, t in enumerate(adam_times):
            if t <= max_plot_time and i < len(ci_lower_processed):
               plot_ci_lower.append(ci_lower_processed[i])
               plot_ci_upper.append(ci_upper_processed[i])
      # If we have CIs but plot_every is 1, collect original CIs
      elif has_ci:
         plot_ci_lower = []
         plot_ci_upper = []
         for i, t in enumerate(adam_data["full_times"]):
            if t <= max_plot_time and i < len(adam_data["full_losses"]):
               plot_ci_lower.append(results[adam_test]["losses_ci_lower"][i])
               plot_ci_upper.append(results[adam_test]["losses_ci_upper"][i])

      if plot_times:
         # Plot mean line
         relative_losses = np.array(plot_losses) / float(adam_data["full_losses"][0])
         plt.semilogy(plot_times, relative_losses, label="Adam", linewidth=2)

         # Plot confidence intervals if available and enabled
         if has_ci and show_confidence_intervals and plot_ci_lower and plot_ci_upper:
            relative_ci_lower = np.array(plot_ci_lower) / float(adam_data["full_losses"][0])
            relative_ci_upper = np.array(plot_ci_upper) / float(adam_data["full_losses"][0])
            plt.fill_between(plot_times, relative_ci_lower, relative_ci_upper,
                           alpha=0.3, color=plt.gca().lines[-1].get_color())

   # Plot other optimizers (Phase I from Adam + their own Phase II)
   for optimizer_name, data in optimizer_data.items():
      if optimizer_name != "Adam":
         # Find corresponding test name for this optimizer
         current_test = None
         for test_name in results:
            if config["tests"][test_name]["name"] == optimizer_name:
               current_test = test_name
               break

         if not current_test:
            continue

         # Get the index of this optimizer in our plot_data
         test_idx = list(results.keys()).index(current_test)

         # Check if we have confidence interval data
         has_ci = ("losses_ci_lower" in results[current_test] and
                  "losses_ci_upper" in results[current_test])

         # Calculate how many data points correspond to phase1 when using running averages
         if plot_every > 1:
            adam_idx = list(results.keys()).index(adam_test)
            # Number of phase1 data points in the running average
            adjusted_phase1 = (phase1_epochs + plot_every - 1) // plot_every
            # Get Adam's phase1 data from our processed plot_data
            phase1_times = plot_data["time_per_step"][adam_idx][:adjusted_phase1]
            phase1_losses = plot_data["losses"][adam_idx][:adjusted_phase1]

            # Get this optimizer's phase2 data
            all_times = plot_data["time_per_step"][test_idx]
            all_losses = plot_data["losses"][test_idx]

            if adjusted_phase1 < len(all_times):
               phase2_times = all_times[adjusted_phase1:]
               phase2_losses = all_losses[adjusted_phase1:]
            else:
               phase2_times = []
               phase2_losses = []

            # Prepare data for CI if needed
            if has_ci:
                # Process CIs for Phase 1 (from Adam)
                if "losses_ci_lower" in results[adam_test] and "losses_ci_upper" in results[adam_test]:
                    # Apply running average to Adam's CIs for Phase 1
                    adam_ci_lower, _ = apply_running_average(np.array(results[adam_test]["losses_ci_lower"]), plot_every)
                    adam_ci_upper, _ = apply_running_average(np.array(results[adam_test]["losses_ci_upper"]), plot_every)
                    adam_ci_lower = adam_ci_lower[:adjusted_phase1]
                    adam_ci_upper = adam_ci_upper[:adjusted_phase1]

                    # For Phase 2, use this optimizer's CIs
                    if adjusted_phase1 < len(apply_running_average(np.array(results[current_test]["losses_ci_lower"]), plot_every)[0]):
                        test_ci_lower, _ = apply_running_average(np.array(results[current_test]["losses_ci_lower"]), plot_every)
                        test_ci_upper, _ = apply_running_average(np.array(results[current_test]["losses_ci_upper"]), plot_every)
                        test_ci_lower = test_ci_lower[adjusted_phase1:]
                        test_ci_upper = test_ci_upper[adjusted_phase1:]
                    else:
                        test_ci_lower = []
                        test_ci_upper = []
         else:
            # For non-Adam optimizers, we use:
            # 1. Adam's Phase I data
            # 2. This optimizer's Phase II data, but with adjusted time

            # Start with Adam's Phase I data
            phase1_times = list(adam_data["phase1_times"])
            phase1_losses = list(adam_data["phase1_losses"])

            # Get Phase 2 data
            phase2_times = data["phase2_times"]
            phase2_losses = data["phase2_losses"]

            # Prepare CI data if available
            if has_ci:
                if "losses_ci_lower" in results[adam_test] and "losses_ci_upper" in results[adam_test]:
                    adam_ci_lower = list(results[adam_test]["losses_ci_lower"][:phase1_epochs])
                    adam_ci_upper = list(results[adam_test]["losses_ci_upper"][:phase1_epochs])

                    if len(results[current_test]["losses"]) > phase1_epochs:
                        test_ci_lower = list(results[current_test]["losses_ci_lower"][phase1_epochs:])
                        test_ci_upper = list(results[current_test]["losses_ci_upper"][phase1_epochs:])
                    else:
                        test_ci_lower = []
                        test_ci_upper = []

         # For plotting, we need:
         # 1. Phase 1 times and losses (from Adam)
         # 2. Phase 2 times adjusted to start after Adam's Phase 1

         # Start with Phase 1 data
         plot_times = list(phase1_times)
         plot_losses = list(phase1_losses)

         # If we have CI data
         if has_ci:
            plot_ci_lower = list(adam_ci_lower) if 'adam_ci_lower' in locals() else []
            plot_ci_upper = list(adam_ci_upper) if 'adam_ci_upper' in locals() else []

         # Now adjust Phase 2 times to start after Phase 1
         if len(phase2_times) > 0:
            # For running averages, we need to adjust the start time differently
            if plot_every > 1:
               # Get the last time point from Adam Phase 1
               phase1_end_time = phase1_times[-1] if phase1_times else adam_phase1_time

               # Adjust Phase 2 times to start right after Phase 1
               phase2_start_time = phase2_times[0]  # First time in Phase 2

               # Add adjusted Phase 2 points
               for i, t in enumerate(phase2_times):
                  # Calculate time relative to Phase 2 start
                  relative_time = t - phase2_start_time

                  # Adjust to start after Phase 1
                  adjusted_time = phase1_end_time + relative_time

                  # Only include points within our cutoff
                  if adjusted_time <= max_plot_time and i < len(phase2_losses):
                        plot_times.append(adjusted_time)
                        plot_losses.append(phase2_losses[i])

                        # Add CI data if available
                        if has_ci and i < len(test_ci_lower):
                           plot_ci_lower.append(test_ci_lower[i])
                           plot_ci_upper.append(test_ci_upper[i])
            else:
               # Original logic for non-running averages
               for i, t in enumerate(phase2_times):
                  # Calculate time relative to start of Phase II
                  relative_time = t - phase2_times[0]

                  # Add to Phase I end time to position correctly
                  adjusted_time = adam_phase1_time + relative_time

                  # Only include if within our plot cutoff
                  if adjusted_time <= max_plot_time and i < len(phase2_losses):
                        plot_times.append(adjusted_time)
                        plot_losses.append(phase2_losses[i])

                        if has_ci and i < len(test_ci_lower):
                           plot_ci_lower.append(test_ci_lower[i])
                           plot_ci_upper.append(test_ci_upper[i])

         # Now plot if we have points for Phase 2
         if len(plot_times) > len(phase1_times):  # Only if we added Phase II points
            # Find the index where Phase 1 ends
            phase1_end_idx = len(phase1_times)

            relative_losses = np.array(plot_losses) / float(adam_data["full_losses"][0])

            # Plot Phase I with low opacity
            #plt.semilogy(plot_times[:phase1_end_idx], relative_losses[:phase1_end_idx],
            #            alpha=0.3, linestyle='--', color='gray', linewidth=2)

            # Plot Phase II with full opacity
            plt.semilogy(plot_times[phase1_end_idx:], relative_losses[phase1_end_idx:],
                        label=optimizer_name, linewidth=2)

            # Plot confidence intervals if available and enabled
            if has_ci and show_confidence_intervals and len(plot_ci_lower) >= len(plot_times):
               # Phase II with CIs - with matching color
               ph2_color = plt.gca().lines[-1].get_color()
               relative_ci_lower = np.array(plot_ci_lower) / float(adam_data["full_losses"][0])
               relative_ci_upper = np.array(plot_ci_upper) / float(adam_data["full_losses"][0])
               try:
                  plt.fill_between(plot_times[phase1_end_idx:],
                                 relative_ci_lower[phase1_end_idx:],
                                 relative_ci_upper[phase1_end_idx:],
                                 alpha=0.3, color=ph2_color)
               except:
                  print(f"Error plotting confidence intervals for time vs loss plot for {optimizer_name}")

   # Add a vertical line to mark the transition from Phase I to Phase II
   if adam_phase1_time > 0:
      plt.axvline(x=adam_phase1_time, color='gray', linestyle='--', alpha=0.7)

   plt.xlabel('Optimization Time (s)', fontsize=20)
   plt.ylabel('Relative Loss', fontsize=20)
   plt.title('Time vs Relative Training Loss', fontsize=25)
   plt.xlim(0, max_plot_time)  # Set x-axis limit from 0 to max_plot_time
   plt.grid(True)
   plt.legend(fontsize=20)
   plt.tight_layout()

   # Save the time vs loss plot
   time_loss_file = os.path.join(figures_loss_dir, 'time_vs_loss.png')
   plt.savefig(time_loss_file, dpi=300)
   plt.close()

   # Get plot data for the best model
   plot_data = results[best_model]["plot_data"]
   X_test = np.array(plot_data["X_test"])
   Y_test = np.array(plot_data["Y_test"])
   Z_true = np.array(plot_data["Z_true"])
   Z_pred = np.array(plot_data["Z_pred"])

   # ------------------------------------------------------------------------
   # 4. Create the true function plot
   # ------------------------------------------------------------------------
   plt.figure(figsize=(10, 8))
   ax = plt.axes(projection='3d')
   surf = ax.plot_surface(X_test, Y_test, Z_true, cmap='viridis', alpha=0.8)
   ax.set_xlabel('X', fontsize=20)
   ax.set_ylabel('Y', fontsize=20)
   ax.set_zlabel('Z', fontsize=20)
   ax.set_title('True Franke Function', fontsize=25)
   plt.tight_layout()

   # Save the plot
   true_function_file = os.path.join(figures_comparison_dir, 'true_function.png')
   plt.savefig(true_function_file, dpi=300)
   plt.close()

   # ------------------------------------------------------------------------
   # 5. Create the best model prediction plot
   # ------------------------------------------------------------------------
   plt.figure(figsize=(10, 8))
   ax = plt.axes(projection='3d')
   surf = ax.plot_surface(X_test, Y_test, Z_pred, cmap='viridis', alpha=0.8)
   ax.set_xlabel('X')
   ax.set_ylabel('Y')
   ax.set_zlabel('Z')
   ax.set_title(f'Best Model ({best_model_name}) Prediction')
   plt.tight_layout()

   # Save the plot
   prediction_file = os.path.join(figures_comparison_dir, 'best_model_prediction.png')
   plt.savefig(prediction_file, dpi=300)
   plt.close()

   # ------------------------------------------------------------------------
   # 6. Create the error plot
   # ------------------------------------------------------------------------
   plt.figure(figsize=(10, 8))
   ax = plt.axes(projection='3d')
   error = np.abs(Z_pred - Z_true)
   surf = ax.plot_surface(X_test, Y_test, error, cmap='hot')
   ax.set_xlabel('X')
   ax.set_ylabel('Y')
   ax.set_zlabel('Error')
   ax.set_title(f'Absolute Error (MSE: {results[best_model]["mse"]:.6f})')
   plt.tight_layout()

   # Save the plot
   error_file = os.path.join(figures_comparison_dir, 'error.png')
   plt.savefig(error_file, dpi=300)
   plt.close()

   print(f"Loss plots saved to {figures_loss_dir}")
   print(f"Comparison plots saved to {figures_comparison_dir}")
   print(f"Data plots saved to {figures_data_dir}")

def clean_cache(model_dir):
   """
   Clean cached Phase 1 Adam model states.

   Removes ``phase1_adam_model.pkl``, ``phase1_adam_results.json`` and
   ``phase1_adam_optimizer_state.pkl`` from ``model_dir`` when present.
   This is best-effort: a failure to remove one file is reported and does
   not prevent removal of the others. Other files in ``model_dir`` are left
   untouched.

   Args:
      model_dir: Directory containing model files (str or path-like)

   Returns:
      bool: True if any files were removed
   """
   model_dir = Path(model_dir)
   cache_files = [
      model_dir / "phase1_adam_model.pkl",
      model_dir / "phase1_adam_results.json",
      model_dir / "phase1_adam_optimizer_state.pkl",
   ]

   print(f"Cleaning cached Phase 1 Adam model, results, and optimizer state from {model_dir}...")
   removed_files = False

   for cache_file in cache_files:
      try:
         # EAFP: unlink directly instead of exists() + unlink(), which races if
         # the file disappears between the check and the removal.
         cache_file.unlink()
      except (FileNotFoundError, NotADirectoryError):
         # Missing file, or model_dir itself is not a directory.
         print(f"{cache_file} not found.")
      except OSError as e:
         # e.g. permission denied, or the path is a directory.
         print(f"Error removing {cache_file}: {e}")
      else:
         print(f"Removed {cache_file}")
         removed_files = True

   if not removed_files:
      if not model_dir.exists():
         print(f"Cache directory {model_dir} does not exist or is empty.")
      elif model_dir.is_dir():
         # Guard iterdir(): it raises on non-directories or unreadable dirs,
         # which must not turn a best-effort cleanup into a crash.
         try:
            is_empty = not any(model_dir.iterdir())
         except OSError:
            is_empty = False
         if is_empty:
            print(f"Cache directory {model_dir} is empty.")

   return removed_files