

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet, RecurrentController

def generate_random_targets(config, batch_size, device):
    """
    Build a batch of random, smooth target profiles on the sensor grid.

    Every row is drawn from one of three families, each of which vanishes
    (up to floating-point round-off) at both ends of the domain [0, L]:

      * 'sine'              -- sum of 1 to 3 sine modes sin(k * pi * x / L),
                               with k in 1..4 and a random amplitude each.
      * 'gaussian_bump'     -- a Gaussian minus the straight line through its
                               two end-point values.
      * 'parabolic_segment' -- a random multiple of x * (L - x).

    The finished batch is clamped element-wise to [-1.5, 1.5].

    Args:
        config: mapping that provides 'L' (domain length) and 'M_SENSORS'
            (number of grid points).
        batch_size: number of profiles to generate.
        device: torch device on which all tensors are created.

    Returns:
        Tensor of shape (batch_size, M_SENSORS).
    """
    length = config['L']
    n_sensors = config['M_SENSORS']
    grid = torch.linspace(0, length, n_sensors, device=device)
    batch = torch.zeros(batch_size, n_sensors, device=device)

    def bump(x, amp, mu, sigma):
        # Un-normalised Gaussian evaluated at x (a tensor or a plain number).
        return amp * torch.exp(-((x - mu)**2) / (2 * sigma**2))

    for row_idx in range(batch_size):
        family = np.random.choice(['sine', 'gaussian_bump', 'parabolic_segment'])

        if family == 'parabolic_segment':
            scale = torch.randn(1, device=device) * 4.0
            batch[row_idx] = scale * grid * (length - grid)
            continue

        if family == 'gaussian_bump':
            scale = torch.randn(1, device=device) * 1.5
            mu = torch.rand(1, device=device) * length
            sigma = torch.rand(1, device=device) * (length / 4) + 0.05
            raw_bump = bump(grid, scale, mu, sigma)
            # Subtract the line joining the two end-point values so the
            # corrected profile is zero at x = 0 and x = L.
            left_value = bump(0, scale, mu, sigma)
            right_value = bump(length, scale, mu, sigma)
            end_point_line = left_value + (right_value - left_value) * grid / length
            batch[row_idx] = raw_bump - end_point_line
            continue

        if family == 'sine':
            # Superpose a random number of sine modes onto the zero row.
            for _ in range(np.random.randint(1, 4)):
                scale = torch.randn(1, device=device) * 0.7
                mode = np.random.randint(1, 5)
                batch[row_idx] += scale * torch.sin(mode * np.pi * grid / length)

    return torch.clamp(batch, -1.5, 1.5)

def main(args):
    """
    Train a RecurrentController by differentiable rollout through a frozen
    PropagatorDeepONet, then save the controller, its hyperparameters and a
    loss plot.

    Each epoch draws a fresh batch of random target profiles and (half-scaled)
    random initial states, rolls the controller/simulator pair forward for
    NT_SOLVER - 1 steps, and minimises

        terminal_weight * MSE(final state, target)
      + running_weight  * mean-over-steps MSE(state, target)
      + effort_weight   * mean-over-steps of the batch-mean sum(u**2)

    with respect to the controller parameters only.

    Args:
        args: namespace with the attributes config_path, output_base_dir,
            run_id, deeponet_run_id, learning_rate, optimizer, hidden_dim,
            num_layers, activation_fn, terminal_weight, running_weight and
            effort_weight. The YAML file at config_path must provide
            M_SENSORS, NUM_BASIS_FUNCTIONS, TRUNK_INPUT_DIM, CONTROL_SCALE,
            L, CONTROLLER_EPOCHS, CONTROLLER_BATCH_SIZE and NT_SOLVER.

    Returns:
        None. If the simulator weights or its hyperparameter file are
        missing, a FATAL message is printed and the function returns without
        training. Otherwise 'burgers_controller_model.pth',
        'hyperparams.yaml' and 'loss_curve.png' are written to
        <output_base_dir>/<run_id>.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Recurrent Controller (Direct Control): {args.run_id} ---")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "burgers_propagator_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    # Both artefacts of the simulator run are required; report either one
    # missing in the same way instead of letting the second fail with a raw
    # FileNotFoundError.
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return
    if not os.path.exists(deeponet_hyperparams_path):
        print(f"FATAL: Propagator hyperparameters not found at {deeponet_hyperparams_path}")
        return
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)
    # Only the architecture-defining entries are forwarded to the constructor.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}

    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted run directory (consider weights_only=True where the installed
    # torch version supports it).
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    # Freeze the simulator: gradients flow through it to the controller, but
    # its own weights are never updated.
    for param in physics_simulator.parameters():
        param.requires_grad = False
    print("Pre-trained Burgers' PropagatorDeepONet (Direct Control version) loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}

    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **controller_kwargs,
    ).to(DEVICE)

    # Anything other than 'adamw' (case-insensitive) falls back to Adam.
    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    # Sensor coordinates shaped (1, M_SENSORS, 1), passed to the simulator as
    # its query locations.
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []
    for epoch in range(config['CONTROLLER_EPOCHS']):
        controller.train()
        x_final_target = generate_random_targets(config, config['CONTROLLER_BATCH_SIZE'], DEVICE)
        # Initial states come from the same generator, scaled down by half.
        x_current = generate_random_targets(config, config['CONTROLLER_BATCH_SIZE'], DEVICE) * 0.5
        controller_hidden_state = None
        total_effort, running_tracking_loss = 0.0, 0.0

        for _ in range(config['NT_SOLVER'] - 1):
            # The controller outputs the control values at the sensor
            # locations together with its updated recurrent state.
            u_k_at_sensors, controller_hidden_state = controller(x_current, x_final_target, controller_hidden_state)

            # Control effort: squared control summed over sensors, averaged
            # over the batch.
            total_effort += torch.mean(torch.sum(u_k_at_sensors**2, dim=1))

            # Advance one step with the frozen simulator; the trailing
            # singleton dimension of its output is dropped.
            x_next_pred = physics_simulator(x_current, u_k_at_sensors, x_grid_sensors).squeeze(-1)

            x_current = x_next_pred
            running_tracking_loss += mse_loss_fn(x_current, x_final_target)

        x_final_predicted = x_current
        terminal_loss = mse_loss_fn(x_final_predicted, x_final_target)
        avg_effort_loss = total_effort / (config['NT_SOLVER'] - 1)
        avg_tracking_loss = running_tracking_loss / (config['NT_SOLVER'] - 1)

        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * avg_tracking_loss +
                      args.effort_weight * avg_effort_loss)

        optimizer.zero_grad()
        total_loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        # Histories for the plot: total and terminal as computed, running and
        # effort terms multiplied by their weights.
        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * avg_tracking_loss.item())
        effort_losses.append(args.effort_weight * avg_effort_loss.item())

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{config['CONTROLLER_EPOCHS']} | Total Loss: {total_loss.item():.4f} | Terminal MSE: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "burgers_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(total_losses, label='Total Loss (Weighted)')
        plt.plot(terminal_losses, label='Terminal MSE (Unweighted)', linestyle='--')
        plt.plot(tracking_losses, label='Running Loss (Weighted)', linestyle=':')
        plt.plot(effort_losses, label='Effort Loss (Weighted)', linestyle=':')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.title(f"Burgers' Controller Training Loss ({args.run_id})")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    finally:
        # Release the figure so repeated in-process calls to main() do not
        # accumulate open figures.
        plt.close(fig)
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    # Script entry point: parse the command line and start training.
    cli_parser = argparse.ArgumentParser(description="Train the Recurrent Controller for Burgers' Equation using Direct Control.")

    # One row per option: (flag, value type, allowed values or None).
    # Every option is required; rows are registered in the order listed.
    option_table = [
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--deeponet_run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--hidden_dim", int, None),
        ("--num_layers", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
        ("--terminal_weight", float, None),
        ("--running_weight", float, None),
        ("--effort_weight", float, None),
    ]
    for flag, value_type, allowed in option_table:
        # Only pass 'choices' for the options that restrict their values.
        extra = {} if allowed is None else {"choices": allowed}
        cli_parser.add_argument(flag, type=value_type, required=True, **extra)

    args = cli_parser.parse_args()
    main(args)