# src/train_recurrent_controller.py
# (With modifications for ablation study automation)

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet, RecurrentController

def generate_random_targets(config, batch_size, device):
    """Sample a batch of random target profiles on the sensor grid.

    Each row is the sum of one to three sinusoids with randomly drawn
    amplitude, frequency and phase, evaluated at ``config['M_SENSORS']``
    evenly spaced points on ``[0, config['L']]``. Every row is then shifted
    so that its mean equals one value drawn once per call (shared by the
    whole batch), and the result is clipped to ``[-1.5, 1.5]``.

    Args:
        config: Mapping providing the ``'M_SENSORS'`` and ``'L'`` entries.
        batch_size: Number of target profiles (rows) to generate.
        device: Torch device on which the tensors are created.

    Returns:
        Tensor of shape ``(batch_size, config['M_SENSORS'])``.
    """
    n_sensors = config['M_SENSORS']
    profiles = torch.zeros(batch_size, n_sensors, device=device)
    grid = torch.linspace(0, config['L'], n_sensors, device=device)

    def _scaled_normal(scale):
        # One standard-normal draw on `device`, returned as a Python float.
        return torch.randn(1, device=device).item() * scale

    for row_idx in range(batch_size):
        # The number of superposed waves comes from NumPy's global RNG (1..3).
        wave_count = np.random.randint(1, 4)
        for _wave in range(wave_count):
            amp = _scaled_normal(0.7)
            freq = _scaled_normal(3.0)
            shift = torch.rand(1, device=device).item() * 2 * np.pi
            wave = amp * torch.sin(freq * grid * np.pi + shift)
            profiles[row_idx, :] += wave

    # A single mean level in [0.25, 1.25) is drawn once and applied to all rows.
    shared_mean = torch.rand(1, device=device).item() * 1.0 + 0.25
    for row_idx in range(batch_size):
        row = profiles[row_idx, :]
        profiles[row_idx, :] = shared_mean + (row - torch.mean(row))

    # Clipping happens after re-centering, so a clipped row's mean may drift.
    return torch.clip(profiles, -1.5, 1.5)

def main(args):
    """Train a RecurrentController through a frozen PropagatorDeepONet rollout.

    The function:
      1. Loads the YAML config at ``args.config_path`` and, if
         ``args.num_basis_functions`` is given, overrides
         ``config['NUM_BASIS_FUNCTIONS']`` with it.
      2. Loads the pre-trained simulator weights and hyperparameters from
         ``<output_base_dir>/<deeponet_run_id>/`` and freezes every parameter.
      3. Builds the controller and an Adam/AdamW optimizer.
      4. For ``config['CONTROLLER_EPOCHS']`` epochs, rolls the controller out
         for ``config['NT_SOLVER']`` steps through the simulator and minimizes
         a weighted sum of terminal, running (tracking) and control-effort
         losses.
      5. Saves the controller weights, the CLI hyperparameters and a loss
         curve to ``<output_base_dir>/<run_id>/``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``config_path``, ``output_base_dir``, ``run_id``,
            ``deeponet_run_id``, ``learning_rate``, ``optimizer``,
            ``hidden_dim``, ``num_layers``, ``activation_fn``,
            ``terminal_weight``, ``running_weight``, ``effort_weight`` and
            ``num_basis_functions``.

    Returns:
        None. Returns early (after printing a FATAL message) when the
        simulator checkpoint file does not exist.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    # --- Override config with command-line arg if provided ---
    if args.num_basis_functions is not None:
        print(f"Overriding NUM_BASIS_FUNCTIONS from config. Using M = {args.num_basis_functions}.")
        config['NUM_BASIS_FUNCTIONS'] = args.num_basis_functions

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Heat Eq. Recurrent Controller: {args.run_id} (M={config['NUM_BASIS_FUNCTIONS']}) ---")
    print(f"Using device: {DEVICE}")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "propagator_deeponet_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)

    # Only the architecture-defining entries of the saved hyperparameters are
    # forwarded to the model constructor.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only point this at
    # checkpoints produced by trusted runs.
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    # Freeze the simulator: gradients still flow *through* it to the
    # controller, but its own weights are never updated.
    for param in physics_simulator.parameters():
        param.requires_grad = False
    print("Pre-trained PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        **controller_kwargs,
    ).to(DEVICE)

    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    # Sensor coordinates reshaped to (1, M_SENSORS, 1) for the simulator call.
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []

    num_steps = config['NT_SOLVER']
    batch_size = config['CONTROLLER_BATCH_SIZE']

    for epoch in range(config['CONTROLLER_EPOCHS']):
        controller.train()
        T_final_target = generate_random_targets(config, batch_size, DEVICE)
        T_current = torch.full((batch_size, config['M_SENSORS']), 0.0, device=DEVICE)
        controller_hidden_state = None
        # Accumulate in 0-dim tensors rather than Python floats so that the
        # later `.item()` calls are valid even when the loop below runs zero
        # times (NT_SOLVER == 1).
        total_effort = torch.zeros((), device=DEVICE)
        tracking_loss = torch.zeros((), device=DEVICE)

        # All steps except the last contribute to the running tracking loss.
        for _ in range(num_steps - 1):
            w_k, controller_hidden_state = controller(T_current, T_final_target, controller_hidden_state)
            total_effort = total_effort + torch.mean(torch.sum(w_k**2, dim=1))
            T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)
            tracking_loss = tracking_loss + mse_loss_fn(T_current, T_final_target)

        # Final step: its state error is counted as the terminal loss instead.
        w_k, controller_hidden_state = controller(T_current, T_final_target, controller_hidden_state)
        total_effort = total_effort + torch.mean(torch.sum(w_k**2, dim=1))
        T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)

        terminal_loss = mse_loss_fn(T_current, T_final_target)
        effort_loss = total_effort / num_steps
        average_tracking_loss = tracking_loss / num_steps

        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * average_tracking_loss +
                      args.effort_weight * effort_loss)

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * average_tracking_loss.item())
        effort_losses.append(args.effort_weight * effort_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{config['CONTROLLER_EPOCHS']} | Total Loss: {total_loss.item():.4f} | Terminal Loss: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "recurrent_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(total_losses, label='Total Loss')
    plt.plot(terminal_losses, label='Terminal Loss (Unweighted)')
    plt.plot(tracking_losses, label='Running Loss (Weighted)')
    plt.plot(effort_losses, label='Effort Loss (Weighted)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'Heat Eq. Controller Loss ({args.run_id})')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    # Release the figure so repeated calls (e.g. ablation sweeps in one
    # process) do not accumulate open figures.
    plt.close(fig)
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    # Command-line entry point: build the parser, parse sys.argv, run training.
    parser = argparse.ArgumentParser(
        description="Train the Recurrent Controller for the Heat Equation."
    )

    # Required options as (flag, value type, permitted values or None),
    # registered in the order they should appear in the help output.
    mandatory_options = [
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--deeponet_run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--hidden_dim", int, None),
        ("--num_layers", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
        ("--terminal_weight", float, None),
        ("--running_weight", float, None),
        ("--effort_weight", float, None),
    ]
    for option_flag, option_type, permitted in mandatory_options:
        parser.add_argument(option_flag, type=option_type, required=True, choices=permitted)

    # Optional override used by the ablation automation; when omitted the
    # value from the config file is kept.
    parser.add_argument(
        "--num_basis_functions",
        type=int,
        default=None,
        help="Override NUM_BASIS_FUNCTIONS from config file.",
    )

    args = parser.parse_args()
    main(args)