# src/train_recurrent_controller.py
# FINAL version for the Burgers' project, with corrected target generation.

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet, RecurrentController

def generate_random_targets(config, batch_size, device):
    """
    Builds a (batch_size, M_SENSORS) tensor of random target state profiles.

    Every row is drawn from one of three families, picked uniformly at random:
      * 'sine': a sum of 1-3 sine modes sin(k * pi * x / L) with k in 1..4,
      * 'gaussian_bump': a Gaussian minus the straight line through its two
        end-point values,
      * 'parabolic_segment': a scaled parabola x * (L - x).
    Each family is constructed to vanish at x = 0 and x = L (zero Dirichlet
    boundary conditions). The batch is finally clipped to [-1.5, 1.5].

    Discrete choices (family, number of waves, wave numbers) come from NumPy's
    global RNG; continuous parameters come from torch's RNG on `device`.

    Args:
        config: Mapping providing 'M_SENSORS' (number of sensor points) and
            'L' (domain length).
        batch_size: Number of profiles (rows) to generate.
        device: Torch device on which all tensors are created.

    Returns:
        Tensor of shape (batch_size, M_SENSORS) with values in [-1.5, 1.5].
    """
    domain_len = config['L']
    n_sensors = config['M_SENSORS']
    x = torch.linspace(0, domain_len, n_sensors, device=device)

    def _sine_mix():
        # Superpose a random number of sine modes; each mode is zero at both ends.
        profile = torch.zeros_like(x)
        for _ in range(np.random.randint(1, 4)):
            amp = torch.randn(1, device=device) * 0.7
            mode = np.random.randint(1, 5)
            profile = profile + amp * torch.sin(mode * np.pi * x / domain_len)
        return profile

    def _gaussian_bump():
        amp = torch.randn(1, device=device) * 1.5
        mu = torch.rand(1, device=device) * domain_len
        sigma = torch.rand(1, device=device) * (domain_len / 4) + 0.05
        two_var = 2 * sigma**2

        bump = amp * torch.exp(-((x - mu)**2) / two_var)

        # Subtract the line joining the bump's end-point values so the
        # corrected profile is zero at x = 0 and x = L.
        left = amp * torch.exp(-((0 - mu)**2) / two_var)
        right = amp * torch.exp(-((domain_len - mu)**2) / two_var)
        baseline = left + (right - left) * x / domain_len
        return bump - baseline

    def _parabola():
        amp = torch.randn(1, device=device) * 4.0
        return amp * x * (domain_len - x)

    # Insertion order fixes the order of the names handed to np.random.choice.
    builders = {
        'sine': _sine_mix,
        'gaussian_bump': _gaussian_bump,
        'parabolic_segment': _parabola,
    }
    family_names = list(builders)

    rows = torch.zeros(batch_size, n_sensors, device=device)
    for idx in range(batch_size):
        family = str(np.random.choice(family_names))
        rows[idx, :] = builders[family]()

    return torch.clamp(rows, -1.5, 1.5)

def main(args):
    """
    Trains the recurrent controller through a frozen PropagatorDeepONet.

    Loads the YAML config and the pre-trained propagator, builds the
    controller and optimizer, then for each epoch rolls the controller out
    through the propagator for NT_SOLVER - 1 steps and back-propagates a
    weighted sum of terminal, running-tracking and control-effort losses.
    Finally saves the controller weights, the CLI hyperparameters and a
    loss-curve plot under <output_base_dir>/<run_id>.

    Args:
        args: Parsed command-line namespace. Reads config_path,
            output_base_dir, run_id, deeponet_run_id, learning_rate,
            optimizer, hidden_dim, num_layers, activation_fn,
            terminal_weight, running_weight and effort_weight.
            The config file must provide M_SENSORS, NUM_BASIS_FUNCTIONS,
            TRUNK_INPUT_DIM, CONTROL_SCALE, L, CONTROLLER_EPOCHS,
            CONTROLLER_BATCH_SIZE and NT_SOLVER.

    Returns:
        None. Returns early, after printing a FATAL message, when the
        propagator checkpoint or its hyperparams file is missing.

    Raises:
        ValueError: If NT_SOLVER is below 2, i.e. the rollout would have
            no steps to average the losses over.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Recurrent Controller: {args.run_id} ---")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "burgers_propagator_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return
    # The hyperparams file is just as necessary as the weights: fail with the
    # same kind of message instead of an unhandled FileNotFoundError.
    if not os.path.exists(deeponet_hyperparams_path):
        print(f"FATAL: Propagator hyperparams not found at {deeponet_hyperparams_path}")
        return
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints produced by trusted training runs.
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    # Freeze the simulator: gradients flow through it to the controller,
    # but its own weights are never updated.
    for param in physics_simulator.parameters():
        param.requires_grad = False
    print("Pre-trained Burgers' PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **controller_kwargs,
    ).to(DEVICE)
    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    num_rollout_steps = config['NT_SOLVER'] - 1
    if num_rollout_steps < 1:
        # With no rollout steps the averages below would divide by zero and
        # the loss would not depend on the controller at all.
        raise ValueError(
            f"NT_SOLVER must be at least 2 to roll out the controller, got {config['NT_SOLVER']}"
        )
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    # Sensor grid shaped (1, M_SENSORS, 1), passed as the simulator's third argument.
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []
    for epoch in range(config['CONTROLLER_EPOCHS']):
        controller.train()
        # Fresh random targets every epoch; initial states come from the same
        # generator, scaled by one half.
        x_final_target = generate_random_targets(config, config['CONTROLLER_BATCH_SIZE'], DEVICE)
        x_current = generate_random_targets(config, config['CONTROLLER_BATCH_SIZE'], DEVICE) * 0.5
        controller_hidden_state = None
        total_effort, running_tracking_loss = 0.0, 0.0
        for _ in range(num_rollout_steps):
            w_k, controller_hidden_state = controller(x_current, x_final_target, controller_hidden_state)
            # Effort: squared control summed over dim 1, then averaged.
            total_effort += torch.mean(torch.sum(w_k**2, dim=1))
            x_next_pred = physics_simulator(x_current, w_k, x_grid_sensors).squeeze(-1)
            x_current = x_next_pred
            running_tracking_loss += mse_loss_fn(x_current, x_final_target)
        x_final_predicted = x_current
        terminal_loss = mse_loss_fn(x_final_predicted, x_final_target)
        avg_effort_loss = total_effort / num_rollout_steps
        avg_tracking_loss = running_tracking_loss / num_rollout_steps
        total_loss = (
            args.terminal_weight * terminal_loss
            + args.running_weight * avg_tracking_loss
            + args.effort_weight * avg_effort_loss
        )
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()
        # Logged curves: total and terminal as-is, running/effort weighted.
        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * avg_tracking_loss.item())
        effort_losses.append(args.effort_weight * avg_effort_loss.item())
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{config['CONTROLLER_EPOCHS']} | Total Loss: {total_loss.item():.4f} | Terminal MSE: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "burgers_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)
    fig = plt.figure(figsize=(10, 6))
    plt.plot(total_losses, label='Total Loss (Weighted)')
    plt.plot(terminal_losses, label='Terminal MSE (Unweighted)', linestyle='--')
    plt.plot(tracking_losses, label='Running Loss (Weighted)', linestyle=':')
    plt.plot(effort_losses, label='Effort Loss (Weighted)', linestyle=':')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f"Burgers' Controller Training Loss ({args.run_id})")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    # Release the figure so repeated calls in one process do not pile up
    # open figures.
    plt.close(fig)
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    # Command-line entry point: every option is required; none has a default.
    parser = argparse.ArgumentParser(description="Train the Recurrent Controller for Burgers' Equation.")

    # Paths and run identifiers.
    for flag in ("--config_path", "--output_base_dir", "--run_id", "--deeponet_run_id"):
        parser.add_argument(flag, type=str, required=True)

    # Optimizer settings.
    parser.add_argument("--learning_rate", type=float, required=True)
    parser.add_argument("--optimizer", type=str, required=True, choices=['adam', 'adamw'])

    # Controller architecture.
    for flag in ("--hidden_dim", "--num_layers"):
        parser.add_argument(flag, type=int, required=True)
    parser.add_argument("--activation_fn", type=str, required=True, choices=['relu', 'tanh'])

    # Weights of the three loss terms combined in main().
    for flag in ("--terminal_weight", "--running_weight", "--effort_weight"):
        parser.add_argument(flag, type=float, required=True)

    main(parser.parse_args())