# src/ablation_study_basis.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time
from data_and_models import PropagatorDeepONet, RecurrentController

# ==============================================================================
# 1. TARGET GENERATION
# ==============================================================================
def generate_target(config, target_type, device):
    """Build one named target profile for the 1D heat-equation control task.

    Args:
        config: mapping that must provide ``'L'`` (domain length) and
            ``'M_SENSORS'`` (number of grid points).
        target_type: one of ``'sine'``, ``'step'`` or ``'complex_gaussian'``.
        device: torch device on which the returned tensors are allocated.

    Returns:
        ``(target, x)`` where ``x`` is the uniform grid
        ``linspace(0, L, M_SENSORS)`` and ``target`` is the profile sampled on
        it (same length as ``x``).

    Raises:
        ValueError: if ``target_type`` is not one of the supported names.
    """
    length = config['L']
    grid = torch.linspace(0, length, config['M_SENSORS'], device=device)

    def _sine():
        # One full period over the domain: 0.6 + 0.3 * sin(2*pi*x/L).
        return 0.6 + 0.3 * torch.sin(2 * np.pi * grid / length)

    def _step():
        # Plateau of 0.6 strictly inside (0.3L, 0.7L), zero everywhere else.
        inside = (grid > 0.3 * length) & (grid < 0.7 * length)
        return torch.zeros_like(grid).masked_fill(inside, 0.6)

    def _complex_gaussian():
        # Five signed Gaussian bumps drawn from a fixed-seed generator, so the
        # profile is identical on every call regardless of global RNG state.
        rng = np.random.RandomState(42)
        grid_np = grid.cpu().numpy()
        accum = np.zeros(config['M_SENSORS'])
        for _ in range(5):
            # Draw order (uniform, rand, uniform, uniform) fixes the sequence.
            magnitude = rng.uniform(0.5, 1.5)
            sign = 1 if rng.rand() > 0.5 else -1
            center = rng.uniform(0.1 * length, 0.9 * length)
            width = rng.uniform(0.05 * length, 0.15 * length)
            accum += magnitude * sign * np.exp(-((grid_np - center)**2) / (2 * width**2))
        return torch.from_numpy(accum).float().to(device)

    builders = {
        'sine': _sine,
        'step': _step,
        'complex_gaussian': _complex_gaussian,
    }
    build = builders.get(target_type)
    if build is None:
        raise ValueError(f"Unknown target: {target_type}")

    return build(), grid

# ==============================================================================
# 2. EVALUATION LOGIC
# ==============================================================================
def _sync_cuda(device):
    """Block until queued CUDA work has finished when ``device`` is a CUDA device.

    CUDA kernel launches are asynchronous, so a wall-clock reading taken
    without synchronizing can be recorded before the GPU finishes the work
    being timed. ``torch.device(...)`` accepts both strings (including
    indexed forms such as ``"cuda:0"``) and ``torch.device`` objects.
    """
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def _load_models(m, args, config, device):
    """Construct and load the (controller, propagator) pair for basis size ``m``.

    The controller run directory is ``args.output_base_dir /
    args.run_id_template.format(m)``. The propagator run id is taken from
    that directory's ``hyperparams.yaml`` (key ``deeponet_run_id``).

    Returns:
        ``(controller, propagator)`` in eval mode, or ``None`` when the
        controller checkpoint for this ``m`` does not exist.
    """
    run_id = args.run_id_template.format(m)
    ctrl_dir = os.path.join(args.output_base_dir, run_id)
    ctrl_path = os.path.join(ctrl_dir, "recurrent_controller_model.pth")
    if not os.path.exists(ctrl_path):
        return None  # Skip if model doesn't exist

    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as f:
        ch_params = yaml.safe_load(f)

    prop_dir = os.path.join(args.output_base_dir, ch_params['deeponet_run_id'])
    prop_path = os.path.join(prop_dir, "propagator_deeponet_best.pth")
    with open(os.path.join(prop_dir, "hyperparams.yaml"), 'r') as f:
        dh_params = yaml.safe_load(f)

    # Forward only constructor hyperparameters. Controller keys are optional
    # (filtered); propagator keys are required (KeyError if any is absent).
    ctrl_keys = ('hidden_dim', 'num_layers', 'activation_fn')
    ctrl_kwargs = {k: v for k, v in ch_params.items() if k in ctrl_keys}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=m,
        **ctrl_kwargs
    ).to(device)

    prop_arg_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    prop_kwargs = {key: dh_params[key] for key in prop_arg_keys}
    propagator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=m,
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **prop_kwargs
    ).to(device)

    controller.load_state_dict(torch.load(ctrl_path, map_location=device))
    propagator.load_state_dict(torch.load(prop_path, map_location=device))
    controller.eval()
    propagator.eval()
    return controller, propagator


def _rollout(controller, propagator, target_input, sensor_locs_input, n_steps, n_sensors, device):
    """Run the closed control loop for ``n_steps`` from an all-zero initial state.

    Returns:
        ``(final_state, elapsed_seconds)``, where the time covers only the
        ``n_steps`` controller/propagator iterations, measured with the
        monotonic high-resolution ``time.perf_counter`` and bracketed by CUDA
        synchronization.
    """
    T_current = torch.zeros(1, n_sensors, device=device)

    _sync_cuda(device)
    start_time = time.perf_counter()

    with torch.no_grad():
        hidden_state = None
        for _ in range(n_steps):
            w_k, hidden_state = controller(T_current, target_input, hidden_state)
            T_current = propagator(T_current, w_k, sensor_locs_input).squeeze(-1)

    _sync_cuda(device)
    return T_current, time.perf_counter() - start_time


def evaluate_single_m(m, args, config, device):
    """Evaluate the trained controller/propagator pair for one basis size ``m``.

    Each target from ``generate_target`` ('sine', 'step', 'complex_gaussian')
    is rolled out for ``config['NT_SOLVER']`` steps from a zero initial state.
    The final state is then scored against the target.

    Args:
        m: number of basis functions; selects the run via ``args.run_id_template``.
        args: parsed CLI namespace with ``output_base_dir`` and ``run_id_template``.
        config: experiment config (``L``, ``M_SENSORS``, ``TRUNK_INPUT_DIM``,
            ``NT_SOLVER``).
        device: torch device (string or ``torch.device``).

    Returns:
        ``None`` if the controller checkpoint is missing. Otherwise a dict
        ``{target_name: {'mse': ..., 'time': seconds}}``.
    """
    models = _load_models(m, args, config, device)
    if models is None:
        return None
    controller, propagator = models

    n_sensors = config['M_SENSORS']
    # Sensor locations input for the propagator (static): shape (1, M_SENSORS, 1).
    x_grid = torch.linspace(0, config['L'], n_sensors, device=device)
    sensor_locs_input = x_grid.unsqueeze(0).unsqueeze(-1)

    target_types = ('sine', 'step', 'complex_gaussian')
    targets = {t_name: generate_target(config, t_name, device)[0] for t_name in target_types}

    # Untimed warm-up step: the first forward pass of a freshly loaded model
    # typically pays one-off costs (e.g. lazy library/kernel initialisation)
    # that would otherwise be charged to whichever target is timed first.
    with torch.no_grad():
        warm_state = torch.zeros(1, n_sensors, device=device)
        warm_w, _ = controller(warm_state, targets[target_types[0]].unsqueeze(0), None)
        propagator(warm_state, warm_w, sensor_locs_input)
    _sync_cuda(device)

    results_m = {}
    for t_name, target_T in targets.items():
        final_T, total_time = _rollout(
            controller, propagator, target_T.unsqueeze(0), sensor_locs_input,
            config['NT_SOLVER'], n_sensors, device,
        )
        final_state = final_T.squeeze().cpu().numpy()
        mse = np.mean((final_state - target_T.cpu().numpy())**2)
        results_m[t_name] = {'mse': mse, 'time': total_time}

    return results_m

# ==============================================================================
# 3. PLOTTING
# ==============================================================================
def plot_ablation(all_data, basis_counts, output_dir):
    """Plot final MSE against basis dimension, one line per target profile.

    Args:
        all_data: mapping ``M -> {target_name: {'mse': ..., 'time': ...}}``.
            A ``None`` value or a missing key marks an M without results;
            that M is skipped.
        basis_counts: the M values to plot, in x-axis order.
        output_dir: directory for ``ablation_mse_vs_m.pdf``; created if it
            does not exist.
    """
    targets = ['sine', 'step', 'complex_gaussian']
    colors = {'sine': 'blue', 'step': 'red', 'complex_gaussian': 'green'}
    markers = {'sine': 'o', 'step': 's', 'complex_gaussian': '^'}
    labels = {'sine': 'Sine Wave', 'step': 'Step Function', 'complex_gaussian': 'Complex Gaussian'}

    # Use an explicit figure and always close it. pyplot otherwise keeps
    # every figure alive, so repeated calls would accumulate open figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        any_plotted = False
        for t_name in targets:
            points = [(m, all_data[m][t_name]['mse'])
                      for m in basis_counts if all_data.get(m) is not None]
            if not points:
                continue
            valid_ms, mses = zip(*points)
            ax.plot(valid_ms, mses, marker=markers[t_name], color=colors[t_name],
                    label=labels[t_name], linewidth=2, markersize=8)
            any_plotted = True

        ax.set_xlabel('Number of Basis Functions ($M$)', fontsize=14)
        ax.set_ylabel('Final MSE (Log Scale)', fontsize=14)
        ax.set_yscale('log')
        ax.set_title('PDE-OP Performance vs. Basis Dimension', fontsize=16)
        ax.grid(True, which="both", ls="--", alpha=0.5)
        if any_plotted:
            # A legend with no labelled artists only triggers a warning.
            ax.legend(fontsize=12)
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "ablation_mse_vs_m.pdf")
        fig.savefig(save_path)
    finally:
        plt.close(fig)

    print(f"\nPlot saved to: {save_path}")

# ==============================================================================
# 4. MAIN
# ==============================================================================
def main():
    """Command-line entry point for the basis-dimension ablation.

    Parses the CLI and loads the YAML config. Evaluates every requested basis
    count M and prints an accuracy/timing table; an M whose controller
    checkpoint is missing gets a ``MISSING`` row. Finally writes the
    MSE-vs-M plot into ``--output_base_dir``.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--config_path", "--output_base_dir", "--run_id_template"):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--basis_counts", type=int, nargs='+', required=True)
    args = parser.parse_args()

    with open(args.config_path, 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # --- Data Collection ---
    print("\nRunning Ablation Evaluation...")
    all_data = {}
    for basis_m in args.basis_counts:
        # The carriage return keeps progress messages on one console line.
        print(f"Processing M={basis_m}...", end="\r")
        all_data[basis_m] = evaluate_single_m(basis_m, args, config, device)
    # Trailing spaces overwrite leftovers of the last progress message.
    print("Processing complete.        ")

    # --- Print Table ---
    width = 85
    rule = "=" * width
    print("\n" + rule)
    print(f"{'ABLATION RESULTS: Time & Accuracy vs. Basis Dimension (M)':^85}")
    print(rule)
    print(f"{'M':<5} | {'Sine MSE':<12} {'Time(s)':<9} | {'Step MSE':<12} {'Time(s)':<9} | {'C.Gauss MSE':<12} {'Time(s)':<9}")
    print("-" * width)

    for basis_m in args.basis_counts:
        entry = all_data[basis_m]
        if entry is None:
            print(f"{basis_m:<5} | {'MISSING':<22} | {'MISSING':<22} | {'MISSING':<22}")
            continue
        cells = [
            f"{entry[name]['mse']:<12.2e} {entry[name]['time']:<9.4f}"
            for name in ('sine', 'step', 'complex_gaussian')
        ]
        # Stripping ' ' and '|' trims the padding after the final time column.
        print(" | ".join([f"{basis_m:<5}", *cells]).strip(" |"))

    print(rule)

    # --- Generate Plot ---
    plot_ablation(all_data, args.basis_counts, args.output_base_dir)


if __name__ == "__main__":
    main()