# src/train_controller_direct.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
import argparse
import yaml
import numpy as np
import matplotlib.pyplot as plt

from data_and_models import DeepONetWithBias

# --- Controller Model Definition ---
class DirectDecisionMaker(nn.Module):
    """Feed-forward MLP mapping a target profile ``zeta`` straight to controls.

    Layout of ``self.net`` (an ``nn.Sequential``):
        Linear(input_dim, hidden_dim), act,
        [Linear(hidden_dim, hidden_dim), act] * (num_layers - 1),
        Linear(hidden_dim, output_dim)            # no output activation

    Args:
        input_dim: length of the input vector.
        output_dim: length of the output vector.
        hidden_dim: width of every hidden layer.
        num_layers: number of activated Linear layers; values below 1 behave
            like 1 (the first activated layer is always present).
        activation_fn: 'relu' or 'tanh', case-insensitive; any other string
            falls back to ReLU.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation_fn):
        super().__init__()
        # One activation module is shared by all hidden layers; ReLU and Tanh
        # carry no parameters, so sharing does not couple the layers.
        act = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        widths = [input_dim] + [hidden_dim] * max(num_layers, 1)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), act]
        modules.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, zeta):
        """Return the controller output for input ``zeta`` (last dim = input_dim)."""
        controls = self.net(zeta)
        return controls

def generate_random_targets(config, batch_size, device):
    """Draw a batch of random, smooth target profiles sampled at the sensors.

    Each profile is the sum of 1 to 3 sine modes. The mode count comes from
    NumPy's global RNG. Each mode's amplitude (0.5 * N(0, 1)), frequency factor
    (3 * N(0, 1), multiplied by pi * x) and phase (2*pi * U[0, 1)) come from
    torch's RNG. The whole batch is then shifted by its mean, scaled by 0.2 and
    offset by ``config['V_REF_VAL']``.

    NOTE(review): the mean that is removed is taken over the entire batch, not
    per profile. Individual profiles are therefore not centred on V_REF_VAL,
    and each profile depends on the rest of the batch. Confirm this is intended.

    Args:
        config: mapping providing 'L' (upper end of the sensor grid
            ``linspace(0, L, M_SENSORS)``), 'M_SENSORS' and 'V_REF_VAL'.
        batch_size: number of profiles to draw.
        device: torch device on which the tensors are created.

    Returns:
        Tensor of shape (batch_size, config['M_SENSORS']).
    """
    num_sensors = config['M_SENSORS']
    x = torch.linspace(0, config['L'], num_sensors, device=device)
    profiles = torch.zeros(batch_size, num_sensors, device=device)
    for row in range(batch_size):
        num_modes = np.random.randint(1, 4)  # 1, 2 or 3 modes
        for _ in range(num_modes):
            # Draw order (amplitude, frequency, phase) is fixed so seeded runs
            # reproduce the same profiles.
            amplitude = torch.randn(1, device=device) * 0.5
            frequency = torch.randn(1, device=device) * 3.0
            phase = torch.rand(1, device=device) * 2 * np.pi
            profiles[row, :] += amplitude * torch.sin(frequency * x * np.pi + phase)
    centred = profiles - profiles.mean()
    return config['V_REF_VAL'] + centred * 0.2

def _build_trunk_input(sensor_locations, batch_size, t_final, device):
    """Return the trunk input that queries every sensor location at ``t_final``.

    ``sensor_locations`` is an (M, 1) column of positions. The result has shape
    (batch_size * M, 2). Column 0 cycles through the M locations once per batch
    item (``Tensor.repeat``), and column 1 is ``t_final`` everywhere.
    """
    num_sensors = sensor_locations.shape[0]
    locations = sensor_locations.repeat(batch_size, 1)
    # Pin the dtype so an integer T_FINAL from YAML still yields a float column.
    times = torch.full((num_sensors * batch_size, 1), t_final,
                       dtype=sensor_locations.dtype, device=device)
    return torch.cat([locations, times], dim=1)


def _predict_final_state(physics_simulator, u_at_sensors, trunk_input):
    """Evaluate the simulator at every sensor for each control in the batch.

    ``repeat_interleave`` repeats each control row M times. Row r of the query
    therefore pairs control ``r // M`` with sensor ``r % M`` (the cycling order
    produced by ``_build_trunk_input``). The flat output is reshaped to
    (batch_size, M), where entry [b, m] is the prediction for control b at
    sensor m.
    """
    batch_size, num_sensors = u_at_sensors.shape
    branch_input = torch.repeat_interleave(u_at_sensors, repeats=num_sensors, dim=0)
    return physics_simulator(branch_input, trunk_input).view(batch_size, num_sensors)


def main(args):
    """Train a DirectDecisionMaker through a frozen, pre-trained DeepONet.

    For a random target profile ``zeta`` (from ``generate_random_targets``), the
    controller outputs a control ``u`` at the sensors. The frozen DeepONet then
    predicts the state at ``T_FINAL`` for that control. The controller is
    trained on:

    * the tracking MSE against ``zeta``,
    * plus ``args.gamma`` times the mean squared control effort,
    * plus penalties for leaving [U_MIN, U_MAX]. Their weights (``mu_U_max``,
      ``mu_U_min``) grow by ``RHO`` times the current violation after every step.

    The controller with the lowest validation tracking MSE is saved to
    ``<output_base_dir>/<run_id>/controller_model_direct_best.pth``. The same
    directory also receives ``loss_curve.png`` and ``hyperparams.yaml`` (the
    CLI arguments).

    Args:
        args: argparse namespace with config_path, output_base_dir, run_id,
            deeponet_run_id, learning_rate, optimizer, hidden_dim, num_layers,
            activation_fn and gamma.

    Returns:
        None. Prints an error and returns early if the DeepONet checkpoint is
        missing.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'adam' nor 'adamw'.
    """
    # --- 1. Load Configurations & Setup ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Starting Controller Training: {args.run_id} ---")
    print(f"Using device: {DEVICE}")
    print(f"Using frozen DeepONet simulator from run: {args.deeponet_run_id}")

    # --- 2. Load the FROZEN Physics Simulator (DeepONet) ---
    deeponet_run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "deeponet_model.pth")
    deeponet_hyperparams_path = os.path.join(deeponet_run_dir, "hyperparams.yaml")
    if not os.path.exists(deeponet_model_path):
        print(f"FATAL: Pre-trained DeepONet model not found at {deeponet_model_path}")
        return
    with open(deeponet_hyperparams_path, 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)
    # Rebuild the simulator with the architecture recorded for its training run.
    model_arg_keys = ['latent_dim', 'branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = DeepONetWithBias(
        branch_input_dim=config['M_SENSORS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(DEVICE)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    # Freeze the simulator's weights. Gradients still flow through it back to
    # its input (the controls), but only the controller's parameters are optimized.
    for param in physics_simulator.parameters():
        param.requires_grad = False
    print("Pre-trained DeepONet physics simulator loaded and frozen.")

    # --- 3. Build the Controller and Optimizer ---
    controller_arg_keys = ['hidden_dim', 'num_layers', 'activation_fn']
    controller_kwargs = {key: getattr(args, key) for key in controller_arg_keys}
    decision_maker = DirectDecisionMaker(
        input_dim=config['M_SENSORS'],
        output_dim=config['M_SENSORS'],
        **controller_kwargs,
    ).to(DEVICE)

    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(decision_maker.parameters(), lr=args.learning_rate)
    elif args.optimizer.lower() == 'adam':
        optimizer = optim.Adam(decision_maker.parameters(), lr=args.learning_rate)
    else:
        raise ValueError("Optimizer not supported")
    mse_loss_fn = nn.MSELoss()

    # --- 4. Initialize Training Components ---
    # Penalty weights for the control bounds, grown in place after every step.
    mu_U_max = torch.zeros(1, requires_grad=False, device=DEVICE)
    mu_U_min = torch.zeros(1, requires_grad=False, device=DEVICE)
    val_batch_size = 512
    torch.manual_seed(42)
    # generate_random_targets also draws the number of modes per profile from
    # NumPy's global RNG. Seed it too, so the validation set is the same across
    # runs and validation losses can be compared.
    np.random.seed(42)
    zeta_val = generate_random_targets(config, val_batch_size, DEVICE)
    patience, patience_counter, best_val_loss = 100, 0, float('inf')
    train_loss_history, val_loss_history = [], []
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "controller_model_direct_best.pth")

    # --- 5. The Self-Supervised Training Loop ---
    print("\n--- Starting Self-Supervised Training Loop ---")
    batch_size = config['CONTROLLER_BATCH_SIZE']
    steps_per_epoch = config['CONTROLLER_STEPS_PER_EPOCH']
    num_epochs = config['CONTROLLER_EPOCHS']
    sensor_locations_torch = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).unsqueeze(1)
    # Trunk queries depend only on the sensor grid and T_FINAL, so build them once.
    trunk_input = _build_trunk_input(sensor_locations_torch, batch_size, config['T_FINAL'], DEVICE)
    trunk_input_val = _build_trunk_input(sensor_locations_torch, val_batch_size, config['T_FINAL'], DEVICE)
    for epoch in range(num_epochs):
        decision_maker.train()
        epoch_train_loss = 0.0
        for _ in range(steps_per_epoch):
            zeta_batch = generate_random_targets(config, batch_size, DEVICE)
            u_at_sensors = decision_maker(zeta_batch)
            V_predicted_at_T = _predict_final_state(physics_simulator, u_at_sensors, trunk_input)
            # Mean bound violation; zero when U_MIN <= u <= U_MAX everywhere.
            loss_U_up = torch.mean(torch.relu(u_at_sensors - config['U_MAX']))
            loss_U_down = torch.mean(torch.relu(config['U_MIN'] - u_at_sensors))
            tracking_loss = mse_loss_fn(V_predicted_at_T, zeta_batch)
            effort_loss = args.gamma * torch.mean(torch.sum(u_at_sensors**2, dim=1))
            loss = tracking_loss + effort_loss + mu_U_max.detach() * loss_U_up + mu_U_min.detach() * loss_U_down
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Grow each penalty weight by RHO times the current violation.
            with torch.no_grad():
                mu_U_max += config['RHO'] * loss_U_up.detach()
                mu_U_min += config['RHO'] * loss_U_down.detach()
            epoch_train_loss += loss.item()
        train_loss_history.append(epoch_train_loss / steps_per_epoch)

        # Validation measures tracking MSE only (no effort or penalty terms).
        decision_maker.eval()
        with torch.no_grad():
            u_val = decision_maker(zeta_val)
            V_predicted_at_T_val = _predict_final_state(physics_simulator, u_val, trunk_input_val)
            val_loss = mse_loss_fn(V_predicted_at_T_val, zeta_val).item()
        val_loss_history.append(val_loss)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_history[-1]:.4e}, Val Loss: {val_loss_history[-1]:.4e}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(decision_maker.state_dict(), best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {patience} epochs without improvement.")
                break
    print(f"\n--- Controller Training Finished. Best model for '{args.run_id}' saved to {best_model_path} ---")

    # --- 6. Save Final Outputs ---
    fig = plt.figure(figsize=(10, 6))
    plt.plot(train_loss_history, label='Training Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Controller Loss (run: {args.run_id})')
    plt.yscale('log')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    plt.close(fig)  # release the figure; main may be called repeatedly in one process
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

if __name__ == "__main__":
    # Every flag is mandatory; flags with an allowed-values list are validated
    # by argparse (choices=None means "any value").
    cli = argparse.ArgumentParser(description="Train a direct controller using a pre-trained DeepONet.")
    cli_spec = (
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--deeponet_run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--hidden_dim", int, None),
        ("--num_layers", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
        ("--gamma", float, None),
    )
    for flag, arg_type, allowed in cli_spec:
        cli.add_argument(flag, type=arg_type, required=True, choices=allowed)
    main(cli.parse_args())