# src/train_recurrent_controller.py
# Trains a physics-aware recurrent controller for Burgers' equation with spatially varying viscosity nu(x).

import torch
import torch.nn as nn
import torch.optim as optim
import os
import yaml
import argparse
import numpy as np
import matplotlib.pyplot as plt

# MODIFIED: Import the spatial GRF generator
from data_and_models import PropagatorDeepONet, RecurrentController, generate_grf_spatial_series

def generate_random_targets(config, batch_size, device):
    """
    Sample a batch of target profiles on the sensor grid.

    Each batch item independently draws one of four shape families
    ('sum_of_sines', 'gaussian_bumps', 'step', 'high_freq'), covering both
    smooth and non-smooth targets. The straight line joining a profile's first
    and last samples is subtracted so every profile satisfies y(0)=0 and
    y(L)=0, and the batch is then clipped to [-1.5, 1.5].

    Args:
        config: dict providing 'L' (domain length) and 'M_SENSORS' (grid size).
        batch_size: number of profiles to draw.
        device: torch device for the returned tensor.

    Returns:
        float32 tensor of shape (batch_size, M_SENSORS).
    """
    length = config['L']
    xs = np.linspace(0, length, config['M_SENSORS'])

    # Each builder draws its own random parameters from the global NumPy RNG.
    def sum_of_sines():
        profile = np.zeros_like(xs)
        for _ in range(np.random.randint(1, 4)):
            amp = np.random.randn() * 0.7
            k = np.random.randint(1, 5)
            profile += amp * np.sin(k * np.pi * xs / length)
        return profile

    def gaussian_bumps():
        profile = np.zeros_like(xs)
        for _ in range(np.random.randint(1, 3)):
            amp = np.random.randn() * 1.5
            mu = np.random.rand() * length
            sigma = np.random.rand() * (length / 5) + 0.05
            profile += amp * np.exp(-((xs - mu)**2) / (2 * sigma**2))
        return profile

    def step():
        profile = np.zeros_like(xs)
        height = np.random.uniform(0.4, 0.8)
        profile[(xs > 0.3 * length) & (xs < 0.7 * length)] = height
        return profile

    def high_freq():
        low_amp = np.random.uniform(0.4, 0.6)
        high_amp = np.random.uniform(0.1, 0.3)
        low_k = np.random.randint(1, 3)
        high_k = np.random.randint(10, 15)
        low_part = low_amp * np.sin(low_k * np.pi * xs / length)
        high_part = high_amp * np.sin(high_k * np.pi * xs / length)
        return low_part + high_part

    builders = {
        'sum_of_sines': sum_of_sines,
        'gaussian_bumps': gaussian_bumps,
        'step': step,
        'high_freq': high_freq,
    }
    kinds = np.random.choice(list(builders), size=batch_size)

    rows = []
    for kind in kinds:
        raw = builders[kind]()
        # Remove the linear interpolant between the endpoints -> zero Dirichlet values.
        baseline = raw[0] + (raw[-1] - raw[0]) * xs / length
        rows.append(raw - baseline)

    bounded = np.clip(np.array(rows), -1.5, 1.5)
    return torch.from_numpy(bounded).float().to(device)

# Helper: samples a fresh batch of random viscosity profiles nu(x) on the fly.
def generate_random_viscosity_profiles(config, batch_size, device):
    """
    Draw a batch of smooth viscosity profiles nu(x) for training.

    A Gaussian-random-field sample (length scale config['VISCOSITY_LENGTH_SCALE'])
    is squashed through tanh into [0, 1] and then mapped affinely onto
    config['VISCOSITY_RANGE'] = (nu_min, nu_max). The result is positive
    provided nu_min > 0.

    Returns:
        float32 tensor on ``device``. Its shape follows whatever
        generate_grf_spatial_series returns (presumably (batch_size, M_SENSORS)
        -- confirm in data_and_models).
    """
    field = generate_grf_spatial_series(config, batch_size, config['VISCOSITY_LENGTH_SCALE'])

    low, high = config['VISCOSITY_RANGE']
    unit_interval = 0.5 * (np.tanh(field) + 1)
    scaled = low + (high - low) * unit_interval

    return torch.from_numpy(scaled).float().to(device)

def _load_frozen_propagator(config, args, device):
    """Load the pre-trained PropagatorDeepONet for ``args.deeponet_run_id`` and freeze its weights.

    Reads ``burgers_propagator_best.pth`` and ``hyperparams.yaml`` from
    ``<output_base_dir>/<deeponet_run_id>``. Architecture sizes come from that
    hyperparams file, while sensor/basis/trunk sizes come from ``config``.

    Raises:
        FileNotFoundError: if the checkpoint or its hyperparams file is missing.
    """
    run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    model_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")
    # Fail loudly (non-zero exit) instead of printing and returning normally,
    # so scripted sweeps do not mistake a missing simulator for a finished run.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"FATAL: Propagator model not found at {model_path}")
    if not os.path.exists(hyperparams_path):
        raise FileNotFoundError(f"FATAL: Propagator hyperparams not found at {hyperparams_path}")

    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **{key: hyperparams[key] for key in model_arg_keys},
    ).to(device)
    simulator.load_state_dict(torch.load(model_path, map_location=device))
    simulator.eval()
    # Gradients still flow *through* the simulator into the controller's actions;
    # only the simulator's own weights are excluded from training.
    for param in simulator.parameters():
        param.requires_grad = False
    return simulator


def _rollout_losses(controller, simulator, x_start, x_target, viscosity, x_grid_sensors, num_steps, mse_loss_fn):
    """Unroll controller + frozen simulator for ``num_steps`` steps, keeping the autograd graph.

    Returns:
        (terminal_loss, avg_tracking_loss, avg_effort_loss) as scalar tensors.
        Tracking is the MSE to the target after every step, and effort is the
        squared control summed over dim 1 and averaged over dim 0. Both are
        averaged over ``num_steps``.
    """
    x_current = x_start
    hidden_state = None
    total_effort = 0.0
    running_tracking_loss = 0.0
    for _ in range(num_steps):
        # Both controller and simulator receive nu(x) so actions/predictions adapt to the viscosity.
        w_k, hidden_state = controller(x_current, x_target, viscosity, hidden_state)
        total_effort += torch.mean(torch.sum(w_k**2, dim=1))
        x_current = simulator(x_current, w_k, viscosity, x_grid_sensors).squeeze(-1)
        running_tracking_loss += mse_loss_fn(x_current, x_target)
    terminal_loss = mse_loss_fn(x_current, x_target)
    return terminal_loss, running_tracking_loss / num_steps, total_effort / num_steps


def _save_outputs(controller, args, output_dir, histories):
    """Write controller weights, CLI hyperparameters and a log-scale loss plot to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    torch.save(controller.state_dict(), os.path.join(output_dir, "burgers_controller_model.pth"))
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(histories['total'], label='Total Loss (Weighted)')
    plt.plot(histories['terminal'], label='Terminal MSE (Unweighted)', linestyle='--')
    plt.plot(histories['tracking'], label='Running Loss (Weighted)', linestyle=':')
    plt.plot(histories['effort'], label='Effort Loss (Weighted)', linestyle=':')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.yscale('log')
    plt.title(f"Burgers' Controller Training Loss ({args.run_id})")
    plt.legend(); plt.grid(True); plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    # Release the figure so repeated calls (e.g. in a sweep) do not accumulate open figures.
    plt.close(fig)


def main(args):
    """Train a RecurrentController through a frozen Burgers' PropagatorDeepONet.

    Each epoch draws a fresh random problem batch (target profile, a half-scale
    random initial state and a viscosity profile nu(x)). It unrolls the controller
    through the frozen simulator for NT_SOLVER - 1 steps and minimizes
        terminal_weight * terminal MSE
        + running_weight * mean per-step tracking MSE
        + effort_weight * mean per-step control effort.
    The final controller, the CLI hyperparameters and a loss plot are written
    to ``<output_base_dir>/<run_id>``.

    Raises:
        FileNotFoundError: if the propagator checkpoint or its hyperparams are missing.
        ValueError: if config['NT_SOLVER'] < 2 (no control steps to roll out).
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Recurrent Controller: {args.run_id} ---")
    print(f"Using frozen Propagator simulator from run: {args.deeponet_run_id}")

    # A rollout of zero steps would divide by zero when averaging the losses.
    num_steps = config['NT_SOLVER'] - 1
    if num_steps < 1:
        raise ValueError(f"NT_SOLVER must be >= 2 to roll out at least one control step, got {config['NT_SOLVER']}")

    # --- 2. Load the FROZEN Physics Simulator (PropagatorDeepONet) ---
    physics_simulator = _load_frozen_propagator(config, args, DEVICE)
    print("Pre-trained Burgers' PropagatorDeepONet loaded and frozen.")

    # --- 3. Initialize Controller and Optimizer ---
    controller_kwargs = {k: v for k, v in vars(args).items() if k in ('hidden_dim', 'num_layers', 'activation_fn')}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **controller_kwargs,
    ).to(DEVICE)
    optimizer_cls = optim.AdamW if args.optimizer.lower() == 'adamw' else optim.Adam
    optimizer = optimizer_cls(controller.parameters(), lr=args.learning_rate)
    mse_loss_fn = nn.MSELoss()

    # Halve the learning rate once the total loss has not improved for 100 consecutive epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100)

    # --- 4. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training via Differentiable Rollout ---")
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    batch_size = config['CONTROLLER_BATCH_SIZE']
    num_epochs = config['CONTROLLER_EPOCHS']
    histories = {'total': [], 'terminal': [], 'tracking': [], 'effort': []}

    for epoch in range(num_epochs):
        controller.train()

        # Fresh random problem instance: target, smaller random initial state, viscosity nu(x).
        x_final_target = generate_random_targets(config, batch_size, DEVICE)
        x_start = generate_random_targets(config, batch_size, DEVICE) * 0.5
        viscosity_profile_batch = generate_random_viscosity_profiles(config, batch_size, DEVICE)

        terminal_loss, avg_tracking_loss, avg_effort_loss = _rollout_losses(
            controller, physics_simulator, x_start, x_final_target,
            viscosity_profile_batch, x_grid_sensors, num_steps, mse_loss_fn,
        )
        total_loss = (args.terminal_weight * terminal_loss +
                      args.running_weight * avg_tracking_loss +
                      args.effort_weight * avg_effort_loss)

        # Backpropagation and optimization step
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
        optimizer.step()

        # Logging
        total_value = total_loss.item()
        terminal_value = terminal_loss.item()
        histories['total'].append(total_value)
        histories['terminal'].append(terminal_value)
        histories['tracking'].append(args.running_weight * avg_tracking_loss.item())
        histories['effort'].append(args.effort_weight * avg_effort_loss.item())

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} | Total Loss: {total_value:.4f} | Terminal MSE: {terminal_value:.4f}")

        # The scheduler only needs the scalar value, not the autograd graph.
        scheduler.step(total_value)

    # --- 5. Save Model and Outputs ---
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    _save_outputs(controller, args, output_dir, histories)

    print(f"\n--- Training Finished. Model and plots for '{args.run_id}' saved to {output_dir} ---")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Train the Recurrent Controller for Burgers' Equation.")

    # Every option is mandatory: (flag, value type, allowed values or None for unrestricted).
    option_specs = [
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--deeponet_run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--hidden_dim", int, None),
        ("--num_layers", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
        ("--terminal_weight", float, None),
        ("--running_weight", float, None),
        ("--effort_weight", float, None),
    ]
    for flag, value_type, allowed in option_specs:
        cli.add_argument(flag, type=value_type, required=True, choices=allowed)

    main(cli.parse_args())