# src/train_recurrent_controller.py
# FINAL version for the Heat Equation project.
# Includes an upgraded, more diverse target generation function.

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet, RecurrentController

def generate_random_targets(config, batch_size, device, profile_types=None):
    """
    Generate a batch of random, smooth target state profiles for the Heat Equation project.

    Profiles are sampled on ``config['M_SENSORS']`` evenly spaced locations spanning
    ``[0, config['L']]``. They do NOT need to be zero at the boundaries.

    Args:
        config: dict providing ``'M_SENSORS'`` and ``'L'``. It may optionally provide
            ``'TARGET_PROFILE_TYPES'`` (a list of profile names), which is used when
            ``profile_types`` is not given.
        batch_size: number of profiles to generate.
        device: torch device on which the returned tensor is allocated.
        profile_types: optional name or sequence of names to sample from uniformly.
            Valid names are ``'sine_combo'``, ``'gaussian_bump'`` and ``'linear_ramp'``.
            Defaults to ``config.get('TARGET_PROFILE_TYPES', ('sine_combo',))``,
            which gives sine-only targets.

    Returns:
        Tensor of shape ``(batch_size, config['M_SENSORS'])``, clipped to ``[-1.5, 1.5]``.

    Raises:
        ValueError: if no profile type is given or a profile name is unknown.
    """
    known_profiles = ('sine_combo', 'gaussian_bump', 'linear_ramp')
    if profile_types is None:
        profile_types = config.get('TARGET_PROFILE_TYPES', ('sine_combo',))
    if isinstance(profile_types, str):
        # A bare string would otherwise be split into single characters.
        profile_types = [profile_types]
    profile_choices = list(profile_types)
    if not profile_choices:
        raise ValueError("profile_types must contain at least one profile name")
    unknown = sorted(set(profile_choices) - set(known_profiles))
    if unknown:
        raise ValueError(f"Unknown target profile type(s) {unknown}; expected any of {list(known_profiles)}")

    targets = torch.zeros(batch_size, config['M_SENSORS'], device=device)
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=device)

    for i in range(batch_size):
        # Randomly choose a profile family for each sample in the batch.
        profile_type = np.random.choice(profile_choices)

        if profile_type == 'sine_combo':
            # A sum of 1 to 3 sine waves with random amplitudes, frequencies and phases.
            # This gives smooth profiles that are generally non-zero at the boundaries.
            num_waves = np.random.randint(1, 4)
            for _ in range(num_waves):
                amplitude = torch.randn(1, device=device).item() * 0.7
                frequency = torch.randn(1, device=device).item() * 3.0
                phase = torch.rand(1, device=device).item() * 2 * np.pi
                targets[i, :] += amplitude * torch.sin(frequency * sensor_locs * np.pi + phase)

        elif profile_type == 'gaussian_bump':
            # A Gaussian bump at a random location, with width in [0.05, L/4 + 0.05).
            amplitude = torch.randn(1, device=device).item() * 1.5
            center = torch.rand(1, device=device).item() * config['L']
            width = torch.rand(1, device=device).item() * (config['L'] / 4) + 0.05
            targets[i, :] += amplitude * torch.exp(-((sensor_locs - center)**2) / (2 * width**2))

        elif profile_type == 'linear_ramp':
            # A linear ramp whose start and end values are drawn uniformly from [-0.5, 1.5).
            start_val = torch.rand(1, device=device).item() * 2.0 - 0.5
            end_val = torch.rand(1, device=device).item() * 2.0 - 0.5
            targets[i, :] += torch.linspace(start_val, end_val, config['M_SENSORS'], device=device)

    # A single reference level is drawn uniformly from [0.25, 1.25) (mean 0.75) and shared
    # by the whole batch. Each profile is shifted so that its own mean equals that level.
    mean_val = torch.rand(1, device=device).item() * 1.0 + 0.25
    targets = mean_val + (targets - targets.mean(dim=1, keepdim=True))

    # Clip to a reasonable range. Clipping can move a profile's mean away from mean_val.
    return torch.clip(targets, -1.5, 1.5)


def _rollout_losses(controller, physics_simulator, T_initial, T_target, sensor_locs, num_steps, loss_fn):
    """
    Roll the controller and the frozen simulator forward for ``num_steps`` control steps.

    At each step the controller maps (current state, target, hidden state) to a control
    ``w_k``. The simulator then advances the state using ``w_k``.

    Returns:
        Tuple ``(terminal_loss, average_tracking_loss, effort_loss)`` of scalar tensors:
        - terminal_loss: ``loss_fn`` between the final state and the target.
        - average_tracking_loss: ``loss_fn`` averaged over every state except the final one
          (zero when ``num_steps == 1``).
        - effort_loss: batch-mean squared control norm, averaged over all ``num_steps``
          controls that were applied.
    """
    T_current = T_initial
    hidden_state = None
    # Tensor accumulators, so that the averages below are always tensors (``.item()`` works).
    total_effort = T_initial.new_zeros(())
    tracking_loss = T_initial.new_zeros(())

    for step in range(num_steps):
        w_k, hidden_state = controller(T_current, T_target, hidden_state)
        total_effort = total_effort + torch.mean(torch.sum(w_k**2, dim=1))
        T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)
        if step < num_steps - 1:
            # Intermediate states feed the running (tracking) loss; the final one feeds the terminal loss.
            tracking_loss = tracking_loss + loss_fn(T_current, T_target)

    terminal_loss = loss_fn(T_current, T_target)
    # Average effort over the number of effort terms actually summed (one per control).
    effort_loss = total_effort / num_steps
    # Guard against division by zero when there are no intermediate states.
    average_tracking_loss = tracking_loss / max(num_steps - 1, 1)
    return terminal_loss, average_tracking_loss, effort_loss


def main(args):
    """
    Train a RecurrentController by differentiable rollout through a frozen PropagatorDeepONet.

    Steps:
    1. Load the YAML config at ``args.config_path``.
    2. Load the pre-trained propagator (weights plus ``hyperparams.yaml``) from
       ``<output_base_dir>/<deeponet_run_id>`` and freeze it.
    3. Train the controller for ``config['CONTROLLER_EPOCHS']`` epochs, drawing a fresh batch
       of random targets each epoch.
    4. Write the controller weights, ``vars(args)`` and a loss-curve plot to
       ``<output_base_dir>/<run_id>``.

    Args:
        args: argparse.Namespace produced by this script's command-line parser.

    Returns:
        None. Returns early, after printing a FATAL message, if the propagator checkpoint is missing.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Heat Eq. Recurrent Controller: {args.run_id} ---")
    print(f"Using device: {DEVICE}")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "propagator_deeponet_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)

    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = PropagatorDeepONet(M_sensors=config['M_SENSORS'], num_basis_functions=config['NUM_BASIS_FUNCTIONS'], trunk_input_dim=config['TRUNK_INPUT_DIM'], **deeponet_kwargs).to(DEVICE)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    for param in physics_simulator.parameters():
        param.requires_grad = False
    print("Pre-trained PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}
    # Note: control_scale is not needed for the heat equation controller
    controller = RecurrentController(M_sensors=config['M_SENSORS'], num_basis_functions=config['NUM_BASIS_FUNCTIONS'], **controller_kwargs).to(DEVICE)

    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    batch_size = config['CONTROLLER_BATCH_SIZE']
    # Each rollout applies NT_SOLVER controls, and always at least one.
    num_steps = max(config['NT_SOLVER'], 1)
    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []

    for epoch in range(config['CONTROLLER_EPOCHS']):
        controller.train()
        T_final_target = generate_random_targets(config, batch_size, DEVICE)
        # For the heat equation, we start from a zero initial state
        T_initial = torch.zeros(batch_size, config['M_SENSORS'], device=DEVICE)

        terminal_loss, average_tracking_loss, effort_loss = _rollout_losses(
            controller, physics_simulator, T_initial, T_final_target, sensor_locs, num_steps, mse_loss_fn
        )

        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * average_tracking_loss +
                      args.effort_weight * effort_loss)

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        # Terminal loss is logged unweighted; running and effort losses are logged weighted.
        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * average_tracking_loss.item())
        effort_losses.append(args.effort_weight * effort_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{config['CONTROLLER_EPOCHS']} | Total Loss: {total_loss.item():.4f} | Terminal Loss: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "recurrent_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    plt.figure(figsize=(10, 6))
    plt.plot(total_losses, label='Total Loss')
    plt.plot(terminal_losses, label='Terminal Loss (Unweighted)')
    plt.plot(tracking_losses, label='Running Loss (Weighted)')
    plt.plot(effort_losses, label='Effort Loss (Weighted)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'Heat Eq. Controller Loss ({args.run_id})')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    # Release the figure so repeated runs in one process do not accumulate open figures.
    plt.close()
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Train the Recurrent Controller for the Heat Equation.")
    # Each entry is (flag, value type, allowed values or None). Every option is mandatory,
    # and they are registered in this order, so --help lists them in this order.
    option_specs = [
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--deeponet_run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--hidden_dim", int, None),
        ("--num_layers", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
        ("--terminal_weight", float, None),
        ("--running_weight", float, None),
        ("--effort_weight", float, None),
    ]
    for flag, value_type, allowed in option_specs:
        cli.add_argument(flag, type=value_type, required=True, choices=allowed)
    main(cli.parse_args())