# src/ablation_study_basis.py
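# Ablation study: sweeps NUM_BASIS_FUNCTIONS (M) and compares the final MSE and
# runtime of every controller. Intended to be launched from the accompanying shell script.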

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time
import copy

# --- Import all evaluation functions from unified_comparison.py ---
# (unified_comparison.py must be importable, e.g. placed in the same directory)
from unified_comparison import (
    run_recurrent_evaluation,
    run_adjoint_evaluation,
    run_mpc_evaluation,
    run_sbto_evaluation,
    run_neural_mpc_evaluation,
    generate_target_profile
)

# ===========================================================================
# 1. PLOTTING FUNCTION FOR ABLATION RESULTS
# ===========================================================================
def generate_ablation_plots(all_results, output_dir):
    """
    Generates and saves line plots summarizing the ablation study results.

    One figure with two side-by-side panels is written to
    ``<output_dir>/ablation_study_basis_functions.pdf``:
    - left panel: final MSE vs. M, with a log-scaled y axis;
    - right panel: total computation time vs. M.

    Args:
        all_results: Nested dict ``{M: {method_key: {'mse': ..., 'time': ...}}}``.
            The curves drawn are the method keys of the smallest M, and every
            M must provide the same keys.
        output_dir: Directory the PDF is written to; created if missing.

    Method keys without a predefined style are still plotted, labelled with
    the key itself. An empty ``all_results`` produces no figure and does not
    touch the filesystem.
    """
    print("\n--- Generating Ablation Study Plots ---")
    if not all_results:
        print("No ablation results to plot; skipping.")
        return
    os.makedirs(output_dir, exist_ok=True)

    basis_counts = sorted(all_results.keys())
    method_keys = list(all_results[basis_counts[0]].keys())

    styles = {
        'recurrent': {'label': 'PDE-OP (Ours)', 'color': 'r', 'marker': 'o'},
        'adjoint': {'label': 'Adjoint Method', 'color': 'm', 'marker': 's'},
        'mpc': {'label': 'LMPC', 'color': 'c', 'marker': '^'},
        'sbto': {'label': 'SBTO', 'color': 'g', 'marker': 'D'},
        'deeponet_mpc': {'label': 'DeepONet MPC', 'color': 'orange', 'marker': 'x'}
    }

    def style_for(method):
        # Unknown methods use matplotlib's default color cycle instead of raising KeyError.
        return styles.get(method, {'label': method, 'marker': 'o'})

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))
    try:
        # --- Plot 1: Final MSE vs. Number of Basis Functions ---
        for method in method_keys:
            mses = [all_results[m][method]['mse'] for m in basis_counts]
            ax1.plot(basis_counts, mses, **style_for(method), markersize=10, linestyle='--')

        ax1.set_title('Performance vs. Basis Functions', fontsize=20)
        ax1.set_xlabel('Number of Basis Functions (M)', fontsize=16)
        ax1.set_ylabel('Final Mean Squared Error (MSE)', fontsize=16)
        ax1.set_yscale('log')
        ax1.grid(True, which="both", ls="--")
        ax1.tick_params(axis='both', which='major', labelsize=14)
        ax1.legend(fontsize=14)

        # --- Plot 2: Computation Time vs. Number of Basis Functions ---
        for method in method_keys:
            times = [all_results[m][method]['time'] for m in basis_counts]
            ax2.plot(basis_counts, times, **style_for(method), markersize=10, linestyle='--')

        ax2.set_title('Computation Time vs. Basis Functions', fontsize=20)
        ax2.set_xlabel('Number of Basis Functions (M)', fontsize=16)
        ax2.set_ylabel('Total Time (seconds)', fontsize=16)
        ax2.grid(True, which="both", ls="--")
        ax2.tick_params(axis='both', which='major', labelsize=14)
        ax2.legend(fontsize=14)

        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        save_path = os.path.join(output_dir, "ablation_study_basis_functions.pdf")
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved ablation plots to: {save_path}")

# ===========================================================================
# 2. MAIN ABLATION ORCHESTRATOR
# ===========================================================================
def _terminal_mse(states, target):
    """Return the mean squared error between the last entry of ``states`` and ``target``."""
    return np.mean((states[-1] - target) ** 2)


def _print_summary_table(results, basis_counts):
    """Print an aligned table with one row per M and one 'MSE / time' column per method."""
    print(f"\n\n{'='*80}")
    print("--- ABLATION STUDY FINAL SUMMARY ---")
    print(f"{'='*80}")
    methods = list(results[basis_counts[0]].keys())
    labels = [f"{method.upper()} (MSE / Time)" for method in methods]
    # Size every column to the widest header so header and data cells line up.
    width = max((len(label) for label in labels), default=0)
    header = f"{'M':<5}" + "".join(f" | {label:<{width}}" for label in labels)
    print(header)
    print("-" * len(header))
    for m in basis_counts:
        row = f"{m:<5}"
        for method in methods:
            entry = results[m][method]
            cell = f"{entry['mse']:.2e} / {entry['time']:>6.2f}s"
            row += f" | {cell:<{width}}"
        print(row)
    print("-" * len(header))


def _save_results_npz(results, path):
    """Save ablation results as flat numeric arrays in a compressed ``.npz`` file.

    The file holds:
    - ``basis_counts``: the sorted M values;
    - ``<method>_mse`` and ``<method>_time``: one array per method, aligned
      with ``basis_counts``.

    ``np.savez_compressed`` requires string keyword names, so the int-keyed
    nested dict cannot be splatted directly. Flattening also avoids pickled
    object arrays, so the file loads without ``allow_pickle=True``.
    """
    basis_counts = sorted(results)
    methods = list(results[basis_counts[0]].keys())
    arrays = {'basis_counts': np.asarray(basis_counts)}
    for method in methods:
        arrays[f"{method}_mse"] = np.asarray(
            [float(results[m][method]['mse']) for m in basis_counts]
        )
        arrays[f"{method}_time"] = np.asarray(
            [float(results[m][method]['time']) for m in basis_counts]
        )
    np.savez_compressed(path, **arrays)


def main(args):
    """Run every controller for each requested basis size and report the results.

    For each M in ``args.basis_counts``:
    - the YAML config at ``args.config_path`` is deep-copied and
      ``NUM_BASIS_FUNCTIONS`` is set to M;
    - ``args.run_id_template`` is formatted with M to set ``run_id``;
    - all five evaluations are run against the same ``'sine'`` target.

    For each evaluation, the MSE of its last state against the target and the
    time it returned are collected. The collected results are then:
    - printed as a summary table;
    - saved to ``<output_base_dir>/ablation_study_results/ablation_results.npz``;
    - plotted into the same directory.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
    """
    with open(args.config_path, 'r') as f:
        base_config = yaml.safe_load(f)

    # Drop repeated M values (keeping first-seen order) so no setting is evaluated twice.
    basis_counts = list(dict.fromkeys(args.basis_counts))

    all_ablation_results = {}
    target_type = 'sine'  # Use a single, consistent target for the whole study

    for m in basis_counts:
        print(f"\n{'='*80}")
        print(f"--- EVALUATING ABLATION FOR M = {m} BASIS FUNCTIONS ---")
        print(f"{'='*80}")

        config = copy.deepcopy(base_config)
        config['NUM_BASIS_FUNCTIONS'] = m

        ablation_args = copy.deepcopy(args)
        ablation_args.run_id = args.run_id_template.format(m)

        x_grid = np.linspace(0, config['L'], config['M_SENSORS'])
        target_T_np = generate_target_profile(target_type, x_grid, config['L'])

        # --- Run all evaluations ---
        rec_states, _, rec_time = run_recurrent_evaluation(ablation_args, config, target_T_np)
        adj_states, _, adj_time = run_adjoint_evaluation(config, target_T_np)
        mpc_states, _, mpc_time = run_mpc_evaluation(config, target_T_np)
        sbto_states, _, sbto_time = run_sbto_evaluation(ablation_args, config, target_T_np)
        nmpc_states, _, nmpc_time = run_neural_mpc_evaluation(ablation_args, config, target_T_np)

        # --- Collect Metrics ---
        runs = {
            'recurrent': (rec_states, rec_time),
            'adjoint': (adj_states, adj_time),
            'mpc': (mpc_states, mpc_time),
            'sbto': (sbto_states, sbto_time),
            'deeponet_mpc': (nmpc_states, nmpc_time),
        }
        all_ablation_results[m] = {
            method: {'mse': _terminal_mse(states, target_T_np), 'time': elapsed}
            for method, (states, elapsed) in runs.items()
        }

    # --- Print Final Summary Table ---
    _print_summary_table(all_ablation_results, basis_counts)

    # --- Save raw numbers first so a plotting failure cannot lose them ---
    output_dir = os.path.join(args.output_base_dir, "ablation_study_results")
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "ablation_results.npz")
    _save_results_npz(all_ablation_results, results_path)
    print(f"Saved ablation results to: {results_path}")

    # --- Generate and Save Plots ---
    generate_ablation_plots(all_ablation_results, output_dir)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Ablation study on the number of basis functions for the 1D Heat Eq."
    )

    # Required inputs: the YAML config, the output root directory, the run-id
    # template (formatted with each M in main), and the M values to sweep.
    for required_flag in ("--config_path", "--output_base_dir", "--run_id_template"):
        parser.add_argument(required_flag, type=str, required=True)
    parser.add_argument("--basis_counts", type=int, nargs='+', required=True)

    # SBTO and Neural MPC settings mirrored from the main comparison script;
    # they reach the evaluation functions through the parsed args namespace.
    tunables = (
        ("--sbto_optim_steps", int, 200),
        ("--sbto_lr", float, 1e-2),
        ("--sbto_terminal_weight", float, 1.0),
        ("--sbto_running_weight", float, 0.1),
        ("--sbto_effort_weight", float, 1e-5),
        ("--nmpc_horizon", int, 10),
        ("--nmpc_optim_steps", int, 25),
        ("--nmpc_lr", float, 2e-2),
        ("--nmpc_terminal_weight", float, 1.0),
        ("--nmpc_running_weight", float, 0.1),
        ("--nmpc_effort_weight", float, 1e-5),
    )
    for flag, value_type, default in tunables:
        parser.add_argument(flag, type=value_type, default=default)

    args = parser.parse_args()
    main(args)