
import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet, RecurrentController

def generate_random_targets(config, batch_size, device, profile_types=('sine_combo',), per_sample_mean=False):
    """
    Generate a batch of random, smooth target state profiles for the Heat Equation project.

    Profiles are sampled on ``config['M_SENSORS']`` evenly spaced sensor locations
    spanning ``[0, config['L']]``. These targets do NOT need to be zero at the
    boundaries.

    Args:
        config: Dict providing at least ``'M_SENSORS'`` and ``'L'``.
        batch_size: Number of target profiles to generate.
        device: Torch device on which the targets are created.
        profile_types: Profile families to sample from, chosen uniformly at
            random per sample. Supported: ``'sine_combo'``, ``'gaussian_bump'``,
            ``'linear_ramp'``. The default keeps the original sine-only behavior.
        per_sample_mean: If True, every profile is re-centered on its own random
            level. If False (default), the whole batch shares one level.

    Returns:
        Tensor of shape ``(batch_size, config['M_SENSORS'])`` with values clipped
        to ``[-1.5, 1.5]``.

    Raises:
        ValueError: If ``profile_types`` is empty or names an unsupported family.
    """
    supported = ('sine_combo', 'gaussian_bump', 'linear_ramp')
    if isinstance(profile_types, str):
        # Guard against a bare string being treated as a sequence of characters.
        profile_types = [profile_types]
    profile_types = list(profile_types)
    if not profile_types:
        raise ValueError("profile_types must contain at least one profile family")
    unknown = sorted(set(profile_types) - set(supported))
    if unknown:
        raise ValueError(f"Unknown profile type(s) {unknown}; supported: {list(supported)}")

    n_sensors = config['M_SENSORS']
    targets = torch.zeros(batch_size, n_sensors, device=device)
    sensor_locs = torch.linspace(0, config['L'], n_sensors, device=device)

    for i in range(batch_size):
        # Randomly choose a profile family for each sample in the batch.
        profile_type = np.random.choice(profile_types)

        if profile_type == 'sine_combo':
            # Sum of 1-3 sinusoids with random amplitude, frequency and phase.
            num_waves = np.random.randint(1, 4)
            for _ in range(num_waves):
                amplitude = torch.randn(1, device=device).item() * 0.7
                frequency = torch.randn(1, device=device).item() * 3.0
                phase = torch.rand(1, device=device).item() * 2 * np.pi
                targets[i, :] += amplitude * torch.sin(frequency * sensor_locs * np.pi + phase)

        elif profile_type == 'gaussian_bump':
            # A Gaussian bump at a random location inside [0, L].
            amplitude = torch.randn(1, device=device).item() * 1.5
            center = torch.rand(1, device=device).item() * config['L']
            width = torch.rand(1, device=device).item() * (config['L'] / 4) + 0.05
            targets[i, :] += amplitude * torch.exp(-((sensor_locs - center)**2) / (2 * width**2))

        elif profile_type == 'linear_ramp':
            # A linear ramp whose start and end values are drawn from [-0.5, 1.5).
            start_val = torch.rand(1, device=device).item() * 2.0 - 0.5
            end_val = torch.rand(1, device=device).item() * 2.0 - 0.5
            targets[i, :] += torch.linspace(start_val, end_val, n_sensors, device=device)

    # Re-center every profile on a random level drawn uniformly from [0.25, 1.25).
    # V_REF_VAL from the config is not used here.
    if per_sample_mean:
        mean_val = torch.rand(batch_size, 1, device=device) * 1.0 + 0.25
    else:
        mean_val = torch.rand(1, device=device).item() * 1.0 + 0.25

    # Remove each row's own mean, then shift it onto the chosen level (vectorized).
    targets = mean_val + (targets - targets.mean(dim=1, keepdim=True))

    return torch.clip(targets, -1.5, 1.5)  # Clip to a reasonable range


def _load_frozen_propagator(config, deeponet_run_dir, device):
    """Rebuild the pre-trained PropagatorDeepONet from ``deeponet_run_dir`` and freeze it.

    Returns the model in eval mode with ``requires_grad=False`` on every
    parameter, or None (after printing a FATAL message) if the checkpoint or
    its ``hyperparams.yaml`` is missing.
    """
    deeponet_model_path = os.path.join(deeponet_run_dir, "propagator_deeponet_best.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Propagator model not found at {deeponet_model_path}")
        return None
    if not os.path.exists(deeponet_hyperparams_path):
        print(f"FATAL: Propagator hyperparams not found at {deeponet_hyperparams_path}")
        return None
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)

    # Re-create the network with the architecture hyperparameters it was trained with.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(device)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=device))
    physics_simulator.eval()
    # Gradients still flow *through* the simulator to the controller, but its
    # own weights are never updated.
    for param in physics_simulator.parameters():
        param.requires_grad = False
    return physics_simulator


def _rollout_losses(controller, physics_simulator, T_initial, T_target, sensor_locs, n_steps, mse_loss_fn):
    """Roll controller + frozen simulator forward ``n_steps`` steps and score the trajectory.

    Returns ``(terminal_loss, average_tracking_loss, effort_loss)``:
      * terminal_loss: MSE between the state after the last step and the target.
      * average_tracking_loss: mean MSE over the intermediate states (every step
        except the last); a zero tensor when ``n_steps == 1``.
      * effort_loss: mean, over all ``n_steps`` control applications, of the
        batch-averaged squared control norm.
    """
    T_current = T_initial
    hidden_state = None
    total_effort = T_initial.new_zeros(())
    tracking_loss = T_initial.new_zeros(())

    for step in range(n_steps):
        w_k, hidden_state = controller(T_current, T_target, hidden_state)
        total_effort = total_effort + torch.mean(torch.sum(w_k**2, dim=1))
        T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)
        if step < n_steps - 1:
            # Intermediate states feed the running (tracking) loss; the final
            # state is scored separately by the terminal loss below.
            tracking_loss = tracking_loss + mse_loss_fn(T_current, T_target)

    terminal_loss = mse_loss_fn(T_current, T_target)
    # Effort was accumulated over all n_steps controls, so average over n_steps
    # (not n_steps - 1).
    effort_loss = total_effort / n_steps
    average_tracking_loss = tracking_loss / max(n_steps - 1, 1)
    return terminal_loss, average_tracking_loss, effort_loss


def _save_outputs(controller, args, output_dir, total_losses, terminal_losses, tracking_losses, effort_losses):
    """Write the controller weights, CLI hyperparameters and a log-scale loss plot to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "recurrent_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(total_losses, label='Total Loss')
    plt.plot(terminal_losses, label='Terminal Loss (Unweighted)')
    plt.plot(tracking_losses, label='Running Loss (Weighted)')
    plt.plot(effort_losses, label='Effort Loss (Weighted)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'Heat Eq. Controller Loss ({args.run_id})')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    plt.close(fig)  # release the figure's memory


def main(args):
    """Train the heat-equation RecurrentController by differentiable rollout through a frozen propagator.

    Each epoch samples fresh random targets and starts from an all-zero state.
    It rolls out ``config['NT_SOLVER']`` control steps, then minimizes a weighted
    sum of terminal, running (tracking) and control-effort losses. Outputs go to
    ``<output_base_dir>/<run_id>``.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` argument parser).

    Raises:
        ValueError: If ``config['NT_SOLVER']`` is smaller than 1.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Heat Eq. Recurrent Controller: {args.run_id} ---")
    print(f"Using device: {DEVICE}")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    n_steps = config['NT_SOLVER']
    if n_steps < 1:
        raise ValueError(f"NT_SOLVER must be >= 1 to roll out at least one control step, got {n_steps}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    physics_simulator = _load_frozen_propagator(config, deeponet_run_dir, DEVICE)
    if physics_simulator is None:
        return
    print("Pre-trained PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}
    # Note: control_scale is not needed for the heat equation controller
    controller = RecurrentController(M_sensors=config['M_SENSORS'], num_basis_functions=config['NUM_BASIS_FUNCTIONS'], **controller_kwargs).to(DEVICE)

    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(controller.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    # Sensor coordinates reshaped to (1, M_SENSORS, 1) for the simulator's trunk input.
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    total_losses, tracking_losses, effort_losses, terminal_losses = [], [], [], []
    batch_size = config['CONTROLLER_BATCH_SIZE']

    for epoch in range(config['CONTROLLER_EPOCHS']):
        controller.train()
        T_final_target = generate_random_targets(config, batch_size, DEVICE)
        # For the heat equation, we start from a zero initial state
        T_initial = torch.zeros(batch_size, config['M_SENSORS'], device=DEVICE)

        terminal_loss, average_tracking_loss, effort_loss = _rollout_losses(
            controller, physics_simulator, T_initial, T_final_target, sensor_locs, n_steps, mse_loss_fn)

        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * average_tracking_loss +
                      args.effort_weight * effort_loss)

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        # Terminal loss is logged unweighted; running and effort losses are logged weighted.
        total_losses.append(total_loss.item())
        terminal_losses.append(terminal_loss.item())
        tracking_losses.append(args.running_weight * average_tracking_loss.item())
        effort_losses.append(args.effort_weight * effort_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{config['CONTROLLER_EPOCHS']} | Total Loss: {total_loss.item():.4f} | Terminal Loss: {terminal_loss.item():.4f}")

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    _save_outputs(controller, args, output_dir, total_losses, terminal_losses, tracking_losses, effort_losses)
    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Train the Recurrent Controller for the Heat Equation.")
    # Every flag is mandatory. Each entry is (flag, value type, extra add_argument kwargs),
    # kept in the original registration order so --help output is unchanged.
    cli_spec = (
        ("--config_path", str, {}),
        ("--output_base_dir", str, {}),
        ("--run_id", str, {}),
        ("--deeponet_run_id", str, {}),
        ("--learning_rate", float, {}),
        ("--optimizer", str, {"choices": ['adam', 'adamw']}),
        ("--hidden_dim", int, {}),
        ("--num_layers", int, {}),
        ("--activation_fn", str, {"choices": ['relu', 'tanh']}),
        ("--terminal_weight", float, {}),
        ("--running_weight", float, {}),
        ("--effort_weight", float, {}),
    )
    for flag, flag_type, extra in cli_spec:
        cli.add_argument(flag, type=flag_type, required=True, **extra)
    main(cli.parse_args())