"""
Test script for regression example.
"""

import jax
import jax.numpy as jnp
import numpy as np
import os
import argparse
import sys
import optax
import json
from foo.utils import get_file_directory
from model import create_model, evaluate_model
from data import generate_data, visualize_dataset
from optax_adam import (
   train_with_optax_adam, save_optax_adam_state, load_optax_adam_state
)
from optax_lbfgs import train_with_optax_lbfgs
from training import create_optimizer, train_optimizer
from utils import load_config, save_results, plot_results, clean_cache

# Problem parameters that must match exactly before a cached Phase 1 run is reused.
_PHASE1_CRITICAL_PARAMS = (
   "seed", "phase1_learning_rate", "phase1_epochs",
   "epsilon", "diffusion_coeff", "convection_coeff",
)

def _phase1_problem_params(problem_config):
   """Return the problem parameters that identify a cached Phase 1 Adam run.

   The same dictionary is stored next to the cached Phase 1 results and is
   compared against on load, so both sides are built by this one function.
   """
   return {
      "n_random_interior_batch": problem_config["n_random_interior_batch"],
      "n_close_interior_batch": problem_config["n_close_interior_batch"],
      "n_boundary_batch": problem_config["n_boundary_batch"],
      "epsilon": problem_config["epsilon"],
      "diffusion_coeff": problem_config["diffusion_coeff"],
      "convection_coeff": problem_config["convection_coeff"],
      "n_batches": problem_config["n_batches"],
      "hidden_layers": problem_config["hidden_layers"],
      "use_float64": problem_config.get("use_float64", False),
      "resample_freq": problem_config.get("resample_freq", 0),
      # Critical validation parameters (see _PHASE1_CRITICAL_PARAMS):
      "phase1_learning_rate": float(problem_config["phase1_learning_rate"]),
      "phase1_epochs": int(problem_config["phase1_epochs"]),
      "seed": int(problem_config["seed"]),
   }

def _cached_params_match(cached_params, current_params):
   """Decide whether cached Phase 1 parameters are compatible with the current ones.

   Only keys present in both dictionaries are compared. A mismatch on a key in
   _PHASE1_CRITICAL_PARAMS makes the cache unusable; other mismatches only
   print a warning. A cache without stored parameters is never accepted,
   because the critical parameters cannot be validated.

   Returns:
      bool: True when the cached run may be reused.
   """
   if not cached_params:
      print("Warning: No problem parameters found in cached file")
      print("Critical parameters like learning rate, epochs, and seed cannot be validated")
      return False

   params_match = True
   for name, value in current_params.items():
      if name not in cached_params or cached_params[name] == value:
         continue
      if name in _PHASE1_CRITICAL_PARAMS:
         print(f"ERROR: Critical parameter mismatch for '{name}': cached={cached_params[name]}, current={value}")
         print(f"      This parameter must match exactly for valid model loading")
         params_match = False
      else:
         print(f"Warning: Parameter mismatch for '{name}': cached={cached_params[name]}, current={value}")
   return params_match

def _unpack_phase1_results(saved):
   """Normalize a cached Phase 1 results dictionary (as written by run_test).

   "losses" and "time" are required (KeyError otherwise). The optional fields
   fall back to defaults so that cache files written before those fields
   existed can still be read.
   """
   phase1_time = saved["time"]
   return {
      "losses": saved["losses"],
      "losses_pde": saved.get("losses_pde", []),
      "losses_bc": saved.get("losses_bc", []),
      "time": phase1_time,
      "time_per_step": saved.get("time_per_step", [phase1_time]),
      # Restoring the counter keeps Phase 2 on the same resampling pattern
      # whether Phase 1 was run locally or loaded from cache.
      "resampling_counter": saved.get("resampling_counter", 0),
      "problem_params": saved.get("problem_params", {}),
   }

def run_test(test_name, config, data_dir, figure_dir, run_phase1=True, args=None, test_idx=None):
   """
   Run a complete test for the specified optimizer.

   Phase 1 trains with Optax Adam. If a cached Phase 1 run exists and its
   stored problem parameters validate, that run is reused instead. Phase 2
   then trains with the optimizer configured for ``test_name``. The final
   parameters are evaluated, and the results are written with save_results.

   Args:
      test_name: Name of the test to run (key into config["tests"])
      config: Configuration dictionary
      data_dir: Directory to save data
      figure_dir: Directory to save figures
      run_phase1: Whether to run Phase 1 (or load it from cache). When False,
         Phase 2 starts from the freshly initialized model.
      args: Command-line arguments (not used by this function)
      test_idx: Index for multiple test runs, used for file naming

   Returns:
      dict: Test results
   """
   problem_config = config["problem_config"]
   test_config = config["tests"][test_name]

   # The model module reads the physics coefficients from module-level globals.
   import model as model_module
   model_module.diffusion_coeff = problem_config["diffusion_coeff"]
   model_module.convection_coeff = problem_config["convection_coeff"]

   # Adam and L-BFGS tests use the Optax training loops; every other optimizer
   # goes through create_optimizer/train_optimizer.
   is_adam_test = test_config["name"].startswith("Adam")
   is_lbfgs_test = test_config["name"].startswith("L-BFGS")
   is_optax_test = is_adam_test or is_lbfgs_test

   # Numeric precision
   if problem_config["use_float64"]:
      dtype_jax = jnp.float64
      jax.config.update("jax_enable_x64", True)
   else:
      dtype_jax = jnp.float32

   jax_device = jax.devices()[0]

   # Seeds: `seed` drives data generation and Phase 1; Phase 2 has its own seed.
   seed = problem_config["seed"]
   np.random.seed(seed)
   key = jax.random.PRNGKey(seed)
   phase2_seed = problem_config.get("phase2_seed", seed + 1)

   # Resampling frequency (0 when not configured)
   resample_freq = problem_config.get("resample_freq", 0)

   # Generate the initial dataset
   n_batches = problem_config["n_batches"]
   n_random_interior = problem_config["n_random_interior_batch"] * n_batches
   n_close_interior = problem_config["n_close_interior_batch"] * n_batches
   n_boundary = problem_config["n_boundary_batch"] * n_batches
   epsilon = problem_config.get("epsilon", 0.01)

   random_interior_points, close_interior_points, boundary_points, interior_source_terms = generate_data(
      n_random_interior,
      n_close_interior,
      n_boundary,
      epsilon,
      seed
   )

   # Combine interior points for backward compatibility
   interior_points = jnp.concatenate([random_interior_points, close_interior_points], axis=0)

   # Visualize the initial dataset in figures/data; the run index is added to
   # the filename when a test is repeated.
   dataset_suffix = f"_{test_name}"
   if test_idx is not None:
      dataset_suffix += f"_run{test_idx}"
   figures_data_dir = os.path.join(figure_dir, 'data')
   os.makedirs(figures_data_dir, exist_ok=True)
   visualize_dataset(random_interior_points, close_interior_points, boundary_points,
                     figures_data_dir, epsilon,
                     filename=f"dataset_visualization{dataset_suffix}.png")

   # Convert to JAX arrays with the requested dtype and move them to the device
   X_interior = jnp.array(interior_points, dtype=dtype_jax)
   X_boundary = jnp.array(boundary_points, dtype=dtype_jax)
   f_interior = jnp.array(interior_source_terms, dtype=dtype_jax)
   u_boundary = jnp.zeros((len(boundary_points), 1), dtype=dtype_jax)
   if jax_device:
      X_interior = jax.device_put(X_interior, jax_device)
      X_boundary = jax.device_put(X_boundary, jax_device)
      f_interior = jax.device_put(f_interior, jax_device)
      u_boundary = jax.device_put(u_boundary, jax_device)
   train_data = ((X_interior, X_boundary), (f_interior, u_boundary))

   # Create the model with identical initialization for a fair comparison
   key, subkey = jax.random.split(key)
   model, state = create_model(
      subkey,
      problem_config["hidden_layers"],
      (X_interior[:1], X_boundary[:1]),
      dtype=dtype_jax
   )

   # Loss / timing histories accumulated across both phases
   total_losses = []
   total_losses_pde = []
   total_losses_bc = []
   total_time = 0
   total_time_per_step = []  # cumulative time at each optimization step
   resampling_counter = 0    # carried across phases so resampling continues seamlessly

   # Locations of the cached Phase 1 artifacts (suffixed per run when repeating)
   script_dir = get_file_directory()
   config_name = config.get("name", "default")
   model_save_dir = os.path.join(script_dir, 'outputs', config_name, 'model')
   suffix = f"_run{test_idx}" if test_idx is not None else ""
   phase1_model_name = f"phase1_adam_model{suffix}"
   phase1_results_file = os.path.join(model_save_dir, f"phase1_adam_results{suffix}.json")
   phase1_optimizer_state_file = os.path.join(model_save_dir, f"phase1_adam_optimizer_state{suffix}.pkl")
   os.makedirs(model_save_dir, exist_ok=True)

   # Keyword arguments shared by every Optax training call
   optax_common_kwargs = dict(
      resample_freq=resample_freq,
      n_boundary_batch=problem_config["n_boundary_batch"],
      n_random_interior_batch=problem_config["n_random_interior_batch"],
      n_close_interior_batch=problem_config["n_close_interior_batch"],
      epsilon=problem_config["epsilon"],
      diffusion_coeff=problem_config["diffusion_coeff"],
      convection_coeff=problem_config["convection_coeff"],
      weight_pde=problem_config["weight_pde"],
      weight_bc=problem_config["weight_bc"],
      print_every=problem_config["print_every"],
      jax_device=jax_device,
      dtype=dtype_jax,
   )

   continuous_adam = is_adam_test and run_phase1
   state_after_phase1 = state

   # ---------------- Phase 1: load from cache or train with Optax Adam ----------------
   if run_phase1:
      loaded_phase1_successfully = False
      try:
         print(f"Attempting to load cached Phase 1 Adam results from {model_save_dir}...")
         _, cached_model_params = model.__class__.load_model(name=phase1_model_name, path=model_save_dir)
         cached_state = jax.device_put(state.replace(params=cached_model_params), jax_device)

         with open(phase1_results_file, 'r') as f:
            cached = _unpack_phase1_results(json.load(f))

         if _cached_params_match(cached["problem_params"], _phase1_problem_params(problem_config)):
            total_losses.extend(cached["losses"])
            # Restore the per-term histories too, so they stay aligned with total_losses.
            total_losses_pde.extend(cached["losses_pde"])
            total_losses_bc.extend(cached["losses_bc"])
            total_time += cached["time"]
            total_time_per_step.extend(cached["time_per_step"])
            resampling_counter = cached["resampling_counter"]
            state_after_phase1 = cached_state
            loaded_phase1_successfully = True
            print("Problem parameters validated successfully")
            print(f"Successfully loaded cached Phase 1 Adam results. Skipping local Phase 1 execution.")
         else:
            print("Parameter mismatch detected. Re-running phase 1 with current parameters.")
      except FileNotFoundError:
         print(f"No cached Phase 1 Adam results found at {model_save_dir}. Local Phase 1 will run.")
      except Exception as e:
         # Best effort: any problem reading the cache falls back to a fresh Phase 1.
         print(f"Error loading cached Phase 1 Adam results: {e}. Local Phase 1 will run.")

      if not loaded_phase1_successfully:
         print("Executing Phase 1 (Adam)...")
         phase1_lr = problem_config["phase1_learning_rate"]
         print(f"Training with Phase 1 (Initial Optax Adam, lr={phase1_lr})...")

         (final_compat_state, _, phase1_losses, phase1_losses_pde, phase1_losses_bc,
          phase1_time, phase1_time_per_step, resampling_counter, optax_state) = train_with_optax_adam(
            state.params,
            model.optax_apply_fn,
            state.batch_stats,
            train_data,
            n_batches,
            problem_config["phase1_epochs"],
            learning_rate=phase1_lr,
            base_seed=seed,
            resample_counter=resampling_counter,
            phase_name=f"Phase 1 (Optax Adam, lr={phase1_lr})",
            shuffle_seed=seed,
            **optax_common_kwargs,
         )

         # Cache the Phase 1 parameters, optimizer state and histories
         print(f"Saving Phase 1 Adam results to {model_save_dir}...")
         model.save_model(final_compat_state.params, name=phase1_model_name, path=model_save_dir)

         print(f"Saving Optax Adam state to {model_save_dir}...")
         save_optax_adam_state(
            optax_state,
            model_save_dir,
            phase1_optimizer_state_file
         )

         with open(phase1_results_file, 'w') as f:
            json.dump({
               "losses": phase1_losses,
               "losses_pde": phase1_losses_pde,
               "losses_bc": phase1_losses_bc,
               "time": float(phase1_time),
               "time_per_step": phase1_time_per_step,  # used for time-vs-loss plots
               "resampling_counter": int(resampling_counter),  # continuity into Phase 2
               "problem_params": _phase1_problem_params(problem_config)
            }, f, indent=3)
         print("Phase 1 Adam results saved/cached.")

         total_losses.extend(phase1_losses)
         total_losses_pde.extend(phase1_losses_pde)
         total_losses_bc.extend(phase1_losses_bc)
         total_time += phase1_time
         total_time_per_step.extend(phase1_time_per_step)
         state_after_phase1 = final_compat_state

   # ---------------- Phase 2: the optimizer under test ----------------
   if is_adam_test:
      phase2_lr = test_config["config"]["optimizer"]["learning_rate"]
      print(f"Setting up Phase 2 Optax Adam with lr={phase2_lr}...")

      if run_phase1:
         # Continue from Phase 1's Optax state (including Adam moments). That state
         # was either just saved above or belongs to the validated cached run.
         loaded_optax_state, optimizer_state_loaded = None, False
         if os.path.exists(phase1_optimizer_state_file):
            print(f"Loading Optax Adam state from {phase1_optimizer_state_file}...")
            loaded_optax_state, optimizer_state_loaded = load_optax_adam_state(
               model_save_dir,
               phase1_optimizer_state_file,
               state_after_phase1.apply_fn,
               phase2_lr
            )
         if not optimizer_state_loaded or loaded_optax_state is None:
            print(f"Failed to load Optax Adam state or no state file exists, something is wrong")
            sys.exit(1)

         loaded_optax_state = loaded_optax_state.replace(apply_fn=model.optax_apply_fn)
         adam_start = (loaded_optax_state.params, loaded_optax_state.apply_fn, loaded_optax_state.batch_stats)
         adam_extra_kwargs = {"current_optax_state": loaded_optax_state}
         adam_desc = f"Optax Adam, lr={phase2_lr}, with momentum"
      else:
         # No Phase 1: start a fresh Adam run from the initial model.
         adam_start = (state_after_phase1.params, model.optax_apply_fn, state_after_phase1.batch_stats)
         adam_extra_kwargs = {}
         adam_desc = f"Optax Adam, lr={phase2_lr}"

      print(f"Training with Phase 2 ({adam_desc}, seed={phase2_seed})...")
      (final_compat_state, phase2_states, phase2_losses, phase2_losses_pde, phase2_losses_bc,
       phase2_time, phase2_time_per_step, resampling_counter, _) = train_with_optax_adam(
         *adam_start,
         train_data,
         n_batches,
         problem_config["phase2_epochs"],
         learning_rate=phase2_lr,
         base_seed=phase2_seed,
         resample_counter=resampling_counter,
         phase_name=f"Phase 2 ({adam_desc})",
         shuffle_seed=phase2_seed,
         **adam_extra_kwargs,
         **optax_common_kwargs,
      )

   elif is_lbfgs_test:
      optimizer_config = test_config["config"]["optimizer"]
      max_linesearch_steps = optimizer_config.get("max_linesearch_steps", 20)
      memory_size = optimizer_config.get("memory_size", 10)
      initial_guess_strategy = optimizer_config.get("initial_guess_strategy", "one")
      # optax only accepts 'one' or 'keep'
      if initial_guess_strategy not in ("one", "keep"):
         print(f"Warning: Invalid initial_guess_strategy '{initial_guess_strategy}'. Using 'one' instead.")
         initial_guess_strategy = "one"

      optimizer_desc = f"Optax L-BFGS (memory={memory_size}, steps={max_linesearch_steps})"
      print(f"Training with Phase 2 ({optimizer_desc}, seed={phase2_seed})...")

      (final_compat_state, phase2_states, phase2_losses, phase2_losses_pde, phase2_losses_bc,
       phase2_time, phase2_time_per_step, resampling_counter, _) = train_with_optax_lbfgs(
         state_after_phase1.params,
         model.optax_apply_fn,
         state_after_phase1.batch_stats,
         train_data,
         n_batches,
         problem_config["phase2_epochs"],
         max_linesearch_steps=max_linesearch_steps,
         memory_size=memory_size,
         scale_init_precond=optimizer_config.get("scale_init_precond", True),
         sufficient_decrease=optimizer_config.get("sufficient_decrease", 0.1),  # Used as slope_rtol in implementation
         curvature=optimizer_config.get("curvature", 0.9),                     # Used as curv_rtol in implementation
         initial_guess_strategy=initial_guess_strategy,
         base_seed=phase2_seed,
         resample_counter=resampling_counter,
         phase_name=f"Phase 2 ({optimizer_desc})",
         shuffle_seed=phase2_seed,
         **optax_common_kwargs,
      )

   else:
      # Non-Optax optimizers use the project's own workflow
      phase2_optimizer = create_optimizer(
         test_config["name"],
         state_after_phase1,
         test_config["config"],
         model,
         weights_pde=problem_config["weight_pde"],
         weights_bc=problem_config["weight_bc"]
      )

      print(f"Training with Phase 2 ({test_config['name']}, seed={phase2_seed})...")
      phase2_states, phase2_losses, phase2_losses_pde, phase2_losses_bc, phase2_time, phase2_time_per_step, resampling_counter = train_optimizer(
         phase2_optimizer,
         state_after_phase1,
         train_data,
         n_batches,
         problem_config["phase2_epochs"],
         resample_freq,
         problem_config["n_boundary_batch"],
         problem_config["n_random_interior_batch"],
         problem_config["n_close_interior_batch"],
         problem_config["epsilon"],
         problem_config["diffusion_coeff"],
         problem_config["convection_coeff"],
         problem_config["weight_pde"],
         problem_config["weight_bc"],
         phase2_seed,  # resampling seed
         resampling_counter,
         problem_config["print_every"],
         f"Phase 2 ({test_config['name']})",
         jax_device,
         dtype_jax,
         shuffle_seed=phase2_seed  # data shuffling seed
      )

   total_losses.extend(phase2_losses)
   total_losses_pde.extend(phase2_losses_pde)
   total_losses_bc.extend(phase2_losses_bc)
   total_time += phase2_time
   # Offset Phase 2 step times by the last recorded timestamp so the combined
   # timeline stays cumulative.
   time_offset = total_time_per_step[-1] if total_time_per_step else 0
   total_time_per_step.extend(time_offset + t for t in phase2_time_per_step)

   final_state = final_compat_state if is_optax_test else phase2_states[-1]

   # Evaluate on a test grid
   epsilon = problem_config["epsilon"]
   X_eval, Y_source, G_pred, laplacian_pred, laplacian_true = evaluate_model(
      model, final_state.params, resolution=500, epsilon=epsilon
   )

   # Operator (PDE residual) error, and the mean square of G itself (no ground
   # truth for G is available).
   mse_operator = jnp.mean((laplacian_pred - laplacian_true) ** 2)
   mse_g = jnp.mean(G_pred ** 2)

   print(f"{test_config['name']} final MSE (G): {mse_g:.6f}")
   print(f"{test_config['name']} final MSE (Operator): {mse_operator:.6f}")

   results = {
      "name": test_config["name"],
      "losses": total_losses,
      "losses_pde": total_losses_pde,
      "losses_bc": total_losses_bc,
      "training_time": total_time,
      "time_per_step": total_time_per_step,  # for time-vs-loss plots
      "mse": float(mse_operator),  # operator MSE is the primary error metric
      "mse_g": float(mse_g),
      "config": test_config["config"],
      "resample_freq": resample_freq,
      "continuous_adam": continuous_adam,
      "epsilon": float(epsilon)
   }

   # Test grid data for plotting
   results["plot_data"] = {
      "X_eval": X_eval.tolist(),
      "Y_source": Y_source.tolist(),
      "G_pred": G_pred.tolist(),
      "laplacian_pred": laplacian_pred.tolist(),
      "laplacian_true": laplacian_true.tolist()
   }

   save_results(results, test_name, data_dir, test_idx)
   return results

def run_all_tests(config, data_dir, figure_dir, args=None):
   """
   Run every test listed in ``config["tests"]``.

   When ``problem_config["ntests"]`` is greater than 1, each test is repeated
   that many times. Repetition ``i`` uses ``seed + i`` and ``phase2_seed + i``,
   where ``phase2_seed`` defaults to ``seed + 1``.

   Returns:
      dict: ``{test_name: result}`` for a single run, or
      ``{test_name: [result_run0, result_run1, ...]}`` for repeated runs.
   """
   problem_config = config["problem_config"]
   repetitions = problem_config.get("ntests", 1)
   first_seed = problem_config["seed"]
   first_phase2_seed = problem_config.get("phase2_seed", first_seed + 1)

   all_results = {}
   for test_name in config["tests"]:
      if repetitions <= 1:
         # Single run with the configuration exactly as given
         print(f"\n=========== Starting {test_name} test ===========")
         all_results[test_name] = run_test(test_name, config, data_dir, figure_dir, run_phase1=True, args=args)
         continue

      runs = all_results[test_name] = []
      for run_idx in range(repetitions):
         # Shallow copies: only the two seeds differ between repetitions
         seeded_problem = dict(problem_config, seed=first_seed + run_idx,
                               phase2_seed=first_phase2_seed + run_idx)
         seeded_config = {**config, "problem_config": seeded_problem}

         print(f"\n=========== Starting {test_name} test (run {run_idx+1}/{repetitions}) ===========")
         runs.append(run_test(test_name, seeded_config, data_dir, figure_dir,
                              run_phase1=True, args=args, test_idx=run_idx))

   return all_results

def _resolve_config_path(script_dir, config_arg):
   """Return *config_arg* unchanged if it is absolute, otherwise resolve it against *script_dir*."""
   if os.path.isabs(config_arg):
      return config_arg
   return os.path.join(script_dir, config_arg)

def _show_confidence_intervals(args, config):
   """Decide whether plots show confidence intervals.

   --no-ci always wins. Otherwise --ci, or the config's
   ``show_confidence_intervals`` flag (default False), enables them.
   """
   if args.no_ci:
      return False
   return args.ci or config["problem_config"].get("show_confidence_intervals", False)

def _plot_all(config, data_dir, figure_dir, args):
   """Ensure the figure subdirectories exist, then plot the stored results in *data_dir*."""
   figures_data_dir = os.path.join(figure_dir, 'data')
   figures_loss_dir = os.path.join(figure_dir, 'loss')
   figures_comparison_dir = os.path.join(figure_dir, 'comparison')
   for directory in (figures_data_dir, figures_loss_dir, figures_comparison_dir):
      os.makedirs(directory, exist_ok=True)

   plot_results(config, data_dir, figure_dir,
                figures_data_dir=figures_data_dir,
                figures_loss_dir=figures_loss_dir,
                figures_comparison_dir=figures_comparison_dir,
                show_confidence_intervals=_show_confidence_intervals(args, config))

def main():
   """Parse command-line arguments and dispatch.

   Dispatch rules:
   - --clean removes the cached Phase 1 Adam artifacts and exits.
   - --test runs one test, repeated ``ntests`` times with shifted seeds when
     ``ntests > 1``.
   - --all runs every test.
   - --plot re-plots stored results.
   - With none of these flags, all tests are run and then plotted.

   Exits with status 1 when the configuration file is missing, defines no
   tests, or does not contain the requested --test.
   """
   parser = argparse.ArgumentParser(description='Run regression tests for optimizers')
   parser.add_argument('--test', type=str, help='Run a specific test')
   parser.add_argument('--all', action='store_true', help='Run all tests')
   parser.add_argument('--plot', action='store_true', help='Generate plots from existing data')
   parser.add_argument('--clean', action='store_true', help='Remove cached Phase 1 Adam model, results, and optimizer state')
   parser.add_argument('--config', type=str, default='config.json', help='Path to custom configuration file')
   parser.add_argument('--ci', action='store_true', help='Show confidence intervals on plots (overrides config setting)')
   parser.add_argument('--no-ci', action='store_true', help='Do not show confidence intervals on plots (overrides config setting, this will override --ci)')
   args = parser.parse_args()

   # Relative config paths (including the default) are relative to the script directory
   script_dir = get_file_directory()
   config_path = _resolve_config_path(script_dir, args.config)
   if not os.path.exists(config_path):
      print(f"Error: Configuration file '{config_path}' not found.")
      sys.exit(1)
   print(f"Using configuration from: {config_path}")

   config = load_config(config_path)
   config_name = config.get("name", "default")
   print(f"Using configuration name: {config_name}")

   # Output layout: outputs/<config name>/{data,figures,model}
   outputs_dir = os.path.join(script_dir, 'outputs', config_name)
   data_dir = os.path.join(outputs_dir, 'data')
   figure_dir = os.path.join(outputs_dir, 'figures')
   model_save_dir = os.path.join(outputs_dir, 'model')

   if args.clean:
      clean_cache(model_save_dir)
      return  # Exit after cleaning

   for directory in (data_dir, figure_dir, model_save_dir,
                     os.path.join(figure_dir, 'data'),
                     os.path.join(figure_dir, 'loss'),
                     os.path.join(figure_dir, 'comparison')):
      os.makedirs(directory, exist_ok=True)

   # Default path: no explicit mode selected, so run everything and plot
   run_everything = not (args.test or args.all or args.plot)
   if (args.all or run_everything) and not config["tests"]:
      print(f"Error: No tests defined in {config_path}.")
      sys.exit(1)

   if args.test:
      if args.test not in config["tests"]:
         print(f"Error: Test '{args.test}' not found in configuration")
         sys.exit(1)

      ntests = config["problem_config"].get("ntests", 1)
      if ntests <= 1:
         run_test(args.test, config, data_dir, figure_dir, args=args)
      else:
         print(f"Running {ntests} repetitions of {args.test} test...")
         # run_all_tests applies the per-run seed offsets; restrict it to this
         # single test. Results are not plotted in this setting.
         single_test_config = {**config, "tests": {args.test: config["tests"][args.test]}}
         run_all_tests(single_test_config, data_dir, figure_dir, args=args)
   elif args.all:
      run_all_tests(config, data_dir, figure_dir, args=args)
   elif args.plot:
      _plot_all(config, data_dir, figure_dir, args)
   else:
      run_all_tests(config, data_dir, figure_dir, args=args)
      _plot_all(config, data_dir, figure_dir, args)

if __name__ == '__main__':
   # Run the command-line interface only when executed as a script, not on import.
   main()