# src/train_recurrent_controller_2d.py
# Adapted to train the controller for the 2D Heat Equation.
# Includes a 2D target generation function.

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

# --- CHANGE: Import from the 2D models file ---
from data_and_models_2d import PropagatorDeepONet, RecurrentController

# --- CHANGE: New function to generate 2D target profiles ---
def generate_random_targets_2d(config, batch_size, device):
    """
    Draw a batch of random, smooth 2D target profiles on the sensor grid.

    Each profile is a sum of 1-3 separable sine products
    ``A * sin(fx*pi*x + px) * sin(fy*pi*y + py)`` evaluated on an
    ``NY_SENSORS x NX_SENSORS`` grid spanning ``[0, L_X] x [0, L_Y]``.
    Before clipping, the profile is shifted so that its spatial mean equals a
    random offset in [0.25, 1.25). It is then flattened with y as the slow
    axis and x as the fast axis. Finally, the whole batch is clipped to
    [-1.5, 1.5].

    The wave count comes from NumPy's global RNG. Every other random
    parameter comes from torch's RNG on ``device``. Seed both RNGs to get
    reproducible targets.

    Returns:
        Tensor of shape (batch_size, NX_SENSORS * NY_SENSORS) on ``device``.
    """
    n_x = config['NX_SENSORS']
    n_y = config['NY_SENSORS']
    xs = torch.linspace(0, config['L_X'], n_x, device=device)
    ys = torch.linspace(0, config['L_Y'], n_y, device=device)
    # 'ij' indexing on (y, x) yields (n_y, n_x) grids, so reshape(-1) is y-major.
    Y, X = torch.meshgrid(ys, xs, indexing='ij')

    def draw_scalar(sampler, scale):
        # One scalar sample from torch's RNG on `device`, scaled.
        return sampler(1, device=device).item() * scale

    batch = torch.zeros(batch_size, n_x * n_y, device=device)
    for row in range(batch_size):
        field = torch.zeros(n_y, n_x, device=device)
        wave_count = np.random.randint(1, 4)
        for _ in range(wave_count):
            # Keep the draw order (amp, fx, fy, phase_x, phase_y) fixed so
            # that seeded runs stay reproducible.
            amp = draw_scalar(torch.randn, 0.7)
            f_x = draw_scalar(torch.randn, 2.0)
            f_y = draw_scalar(torch.randn, 2.0)
            ph_x = draw_scalar(torch.rand, 2 * np.pi)
            ph_y = draw_scalar(torch.rand, 2 * np.pi)
            field += amp * torch.sin(f_x * X * np.pi + ph_x) * torch.sin(f_y * Y * np.pi + ph_y)

        offset = draw_scalar(torch.rand, 1.0) + 0.25
        centred = offset + (field - field.mean())
        batch[row] = centred.reshape(-1)

    return batch.clamp(-1.5, 1.5)

def _load_frozen_propagator(args, config, total_sensors, total_basis_functions, device):
    """
    Build the 2D PropagatorDeepONet, load its best checkpoint, and freeze it.

    The architecture comes from the run's saved hyperparameter YAML. Freezing
    means eval mode plus ``requires_grad = False`` on every parameter.
    Gradients can still flow *through* the returned model to its inputs. This
    is what lets the controller be trained via the differentiable rollout.

    Returns:
        The frozen model, or None if the checkpoint file is missing (a FATAL
        message is printed in that case).
    """
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "propagator_deeponet_2d_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams_propagator_2d.yaml")

    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return None
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)

    # Only the architecture keys are forwarded to the constructor; any other
    # saved hyperparameters are ignored.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}

    physics_simulator = PropagatorDeepONet(
        M_sensors=total_sensors,
        num_basis_functions=total_basis_functions,
        trunk_input_dim=config['TRUNK_INPUT_DIM'],  # Should be 2 (x, y) for the 2D problem
        **deeponet_kwargs
    ).to(device)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=device))
    physics_simulator.eval()
    for param in physics_simulator.parameters():
        param.requires_grad = False
    return physics_simulator


def _sensor_locations(config, device):
    """
    Return the (x, y) coordinates of every sensor as a (1, NX*NY, 2) float tensor.

    np.meshgrid's default 'xy' indexing yields (NY, NX) arrays. ravel()
    therefore orders points with y as the slow axis and x as the fast axis.
    This matches the flattening order of generate_random_targets_2d.
    """
    x_locs = np.linspace(0, config['L_X'], config['NX_SENSORS'])
    y_locs = np.linspace(0, config['L_Y'], config['NY_SENSORS'])
    grid_x, grid_y = np.meshgrid(x_locs, y_locs)
    sensor_locs_np = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
    # Leading singleton batch dim; presumably broadcast over the batch by the model.
    return torch.from_numpy(sensor_locs_np).float().to(device).unsqueeze(0)


def _rollout_losses(controller, physics_simulator, T_final_target, sensor_locs, num_steps, mse_loss_fn):
    """
    Run a differentiable closed-loop rollout and return its loss components.

    The rollout takes ``num_steps`` control steps from a zero initial state.
    The state and target are flattened sensor vectors.

    Returns:
        (terminal_loss, average_tracking_loss, effort_loss):
        - terminal_loss: MSE between the final state and the target.
        - average_tracking_loss: mean MSE over the states reached after the
          first ``num_steps - 1`` steps. The final state is scored by
          terminal_loss instead.
        - effort_loss: batch-mean squared control norm, averaged over all
          ``num_steps`` control actions.

    NOTE(review): this applies NT_SOLVER control steps. If NT_SOLVER counts
    time grid points (NT_SOLVER - 1 intervals), the horizon is one step
    longer than the solver's. Confirm against how the propagator was trained.
    """
    T_current = torch.zeros_like(T_final_target)
    hidden_state = None
    total_effort, tracking_loss = 0.0, 0.0

    for step in range(num_steps):
        w_k, hidden_state = controller(T_current, T_final_target, hidden_state)
        total_effort += torch.mean(torch.sum(w_k**2, dim=1))
        T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)
        if step < num_steps - 1:
            tracking_loss += mse_loss_fn(T_current, T_final_target)

    terminal_loss = mse_loss_fn(T_current, T_final_target)
    # Every one of the num_steps actions contributed to total_effort, so
    # average over all of them (previously divided by num_steps - 1).
    effort_loss = total_effort / num_steps
    average_tracking_loss = tracking_loss / (num_steps - 1)
    return terminal_loss, average_tracking_loss, effort_loss


def main(args):
    """
    Train the 2D RecurrentController against a frozen PropagatorDeepONet.

    Each epoch does the following:
    - samples fresh random targets;
    - rolls the controller and the frozen propagator forward from a zero
      state;
    - minimizes terminal_weight * terminal MSE
      + running_weight * mean running MSE
      + effort_weight * mean squared control effort.

    The controller weights, the CLI arguments and a loss-curve plot are
    written to ``<output_base_dir>/<run_id>/``.

    Raises:
        ValueError: if config['NT_SOLVER'] < 2. The running-loss average
            would otherwise divide by zero.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training 2D Recurrent Controller: {args.run_id} ---")
    print(f"Using device: {DEVICE}")
    print(f"Using frozen 2D Propagator simulator from run: {args.deeponet_run_id}")

    num_steps = config['NT_SOLVER']
    if num_steps < 2:
        raise ValueError(f"NT_SOLVER must be >= 2 for the controller rollout, got {num_steps}")

    # Sensors and basis functions live on a flattened 2D grid.
    total_sensors = config['NX_SENSORS'] * config['NY_SENSORS']
    total_basis_functions = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']

    # --- 2. Load the FROZEN Physics Simulator (2D PropagatorDeepONet) ---
    physics_simulator = _load_frozen_propagator(args, config, total_sensors, total_basis_functions, DEVICE)
    if physics_simulator is None:
        return
    print("Pre-trained 2D PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ('hidden_dim', 'num_layers', 'activation_fn')}
    controller = RecurrentController(
        M_sensors=total_sensors,
        num_basis_functions=total_basis_functions,
        **controller_kwargs
    ).to(DEVICE)

    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    sensor_locs = _sensor_locations(config, DEVICE)

    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []
    num_epochs = config['CONTROLLER_EPOCHS']
    batch_size = config['CONTROLLER_BATCH_SIZE']

    for epoch in range(num_epochs):
        controller.train()
        T_final_target = generate_random_targets_2d(config, batch_size, DEVICE)
        terminal_loss, average_tracking_loss, effort_loss = _rollout_losses(
            controller, physics_simulator, T_final_target, sensor_locs, num_steps, mse_loss_fn
        )

        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * average_tracking_loss +
                      args.effort_weight * effort_loss)

        optimizer.zero_grad()
        total_loss.backward()
        # Cap the global gradient norm, presumably to guard against exploding
        # gradients through the unrolled recurrence.
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * average_tracking_loss.item())
        effort_losses.append(args.effort_weight * effort_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} | Total Loss: {total_loss.item():.4f} | Terminal Loss: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "recurrent_controller_2d_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams_controller_2d.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(total_losses, label='Total Loss')
    plt.plot(terminal_losses, label='Terminal Loss (Unweighted)')
    plt.plot(tracking_losses, label='Running Loss (Weighted)')
    plt.plot(effort_losses, label='Effort Loss (Weighted)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'2D Controller Loss ({args.run_id})')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss_curve_controller_2d.png"))
    plt.close(fig)  # release the figure instead of leaving it open in pyplot's registry
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the Recurrent Controller for the 2D Heat Equation.")
    # Each entry is (flag, type, extra add_argument keywords). Every option is
    # mandatory. Declaration order is preserved so --help lists them as before.
    cli_options = [
        ("--config_path", str, {"help": "Path to the 2D config YAML file."}),
        ("--output_base_dir", str, {}),
        ("--run_id", str, {"help": "Unique ID for this controller training run."}),
        ("--deeponet_run_id", str, {"help": "ID of the pre-trained 2D Propagator run."}),
        ("--learning_rate", float, {}),
        ("--optimizer", str, {"choices": ['adam', 'adamw']}),
        ("--hidden_dim", int, {}),
        ("--num_layers", int, {}),
        ("--activation_fn", str, {"choices": ['relu', 'tanh']}),
        ("--terminal_weight", float, {}),
        ("--running_weight", float, {}),
        ("--effort_weight", float, {}),
    ]
    for flag, flag_type, extra_kwargs in cli_options:
        parser.add_argument(flag, type=flag_type, required=True, **extra_kwargs)
    args = parser.parse_args()
    main(args)