# src/train_propagator_deeponet.py
# Cluster-ready script to train the one-step PropagatorDeepONet.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import yaml
import argparse
import matplotlib.pyplot as plt

from data_and_models import PropagatorDeepONet

class PropagatorDataset(Dataset):
    """One-step transition samples ``(T_k, w_k) -> T_{k+1}`` flattened from trajectories.

    The ``.npz`` archive must contain ``control_sequences`` and
    ``state_sequences``, both 3-D and indexed ``[trajectory, time, feature]``.
    Within each trajectory, every consecutive pair of time steps becomes one
    sample. The control at step ``k`` is paired with the state at ``k`` (input)
    and ``k + 1`` (target). The last control step of each trajectory has no
    successor state and is dropped.

    Args:
        filepath: Path to the ``.npz`` archive.
        config: Mapping providing ``NUM_BASIS_FUNCTIONS`` (control feature
            width) and ``M_SENSORS`` (state feature width).

    Raises:
        ValueError: If either array's feature width does not match the config,
            or the flattened control/state arrays disagree in sample count.
            Without this check, ``reshape`` could silently misalign rows.
    """

    def __init__(self, filepath, config):
        num_basis = config['NUM_BASIS_FUNCTIONS']
        m_sensors = config['M_SENSORS']
        # The context manager closes the NpzFile's file handle. Indexing the
        # archive reads each array fully into memory, so the arrays (and the
        # tensors built from them) remain valid after the file is closed.
        with np.load(filepath) as data:
            controls = data['control_sequences']
            states = data['state_sequences']

        if controls.shape[-1] != num_basis:
            raise ValueError(
                f"control_sequences feature width {controls.shape[-1]} != "
                f"NUM_BASIS_FUNCTIONS={num_basis} in {filepath}"
            )
        if states.shape[-1] != m_sensors:
            raise ValueError(
                f"state_sequences feature width {states.shape[-1]} != "
                f"M_SENSORS={m_sensors} in {filepath}"
            )

        # Flatten (trajectory, time) into a single sample axis.
        self.controls = torch.from_numpy(controls[:, :-1, :]).float().reshape(-1, num_basis)
        self.state_inputs = torch.from_numpy(states[:, :-1, :]).float().reshape(-1, m_sensors)
        self.state_outputs = torch.from_numpy(states[:, 1:, :]).float().reshape(-1, m_sensors)

        if not (len(self.controls) == len(self.state_inputs) == len(self.state_outputs)):
            raise ValueError(
                f"Mismatched sample counts in {filepath}: controls={len(self.controls)}, "
                f"state_inputs={len(self.state_inputs)}, state_outputs={len(self.state_outputs)}"
            )

    def __len__(self):
        return self.controls.shape[0]

    def __getitem__(self, idx):
        """Return ``(state_input, control, state_output)`` for sample ``idx``."""
        return self.state_inputs[idx], self.controls[idx], self.state_outputs[idx]

def _run_epoch(model, loader, criterion, sensor_locs, device, optimizer=None):
    """Run one full pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in train mode and updated after
    every batch. Otherwise it is put in eval mode and run without autograd.

    Raises:
        ValueError: If ``loader`` yields no samples (the mean is undefined).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    loss_sum, sample_count = 0.0, 0
    with torch.set_grad_enabled(training):
        for T_k, w_k, T_k_plus_1 in loader:
            T_k, w_k, T_k_plus_1 = T_k.to(device), w_k.to(device), T_k_plus_1.to(device)
            if training:
                optimizer.zero_grad()
            # NOTE(review): assumes the model output has a trailing singleton dim
            # matching T_k_plus_1 after squeeze(-1); confirm in PropagatorDeepONet.
            T_pred = model(T_k, w_k, sensor_locs).squeeze(-1)
            loss = criterion(T_pred, T_k_plus_1)
            if training:
                loss.backward()
                optimizer.step()
            # With a mean-reduced criterion (nn.MSELoss() in main), weight by batch
            # size so a smaller final batch does not skew the epoch mean.
            batch_size = T_k.shape[0]
            loss_sum += loss.item() * batch_size
            sample_count += batch_size
    if sample_count == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute an epoch loss.")
    return loss_sum / sample_count


def _save_loss_curve(train_losses, val_losses, path):
    """Plot per-epoch train/validation losses on a log y-axis and save them to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.set_title('Loss Curve')
    ax.legend()
    fig.savefig(path)
    plt.close(fig)  # release the figure; pyplot otherwise keeps a reference to it


def main(args):
    """Train a PropagatorDeepONet and write its artefacts to ``<output_base_dir>/<run_id>``.

    Reads the YAML config at ``args.config_path``. ``args.num_basis_functions``
    optionally overrides ``NUM_BASIS_FUNCTIONS``. Train/test data are loaded from
    ``<output_base_dir>/data/{train,test}_trajectories_m<M>.npz``.

    Training uses MSE loss with ReduceLROnPlateau and early stopping on the
    validation loss. The best checkpoint, ``hyperparams.yaml`` (the CLI args)
    and ``loss_curve.png`` are written to the run directory.

    Optional config keys: ``EARLY_STOPPING_PATIENCE`` (default 100) and
    ``LR_SCHEDULER_PATIENCE`` (default 10).

    Raises:
        FileNotFoundError: If the train or test data file is missing. The error
            makes the process exit non-zero, so batch schedulers see a failure.
    """
    # --- 1. Load Configs and Setup ---
    with open(args.config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # --- Override config with command-line arg if provided ---
    if args.num_basis_functions is not None:
        print(f"Overriding NUM_BASIS_FUNCTIONS from config. Using M = {args.num_basis_functions}.")
        config['NUM_BASIS_FUNCTIONS'] = args.num_basis_functions
    num_basis = config['NUM_BASIS_FUNCTIONS']

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training PropagatorDeepONet: {args.run_id} (M={num_basis}) ---")
    print(f"Using device: {DEVICE}")

    # --- 2. DataLoaders ---
    data_dir = os.path.join(args.output_base_dir, "data")
    train_path = os.path.join(data_dir, f"train_trajectories_m{num_basis}.npz")
    test_path = os.path.join(data_dir, f"test_trajectories_m{num_basis}.npz")
    # Raise instead of returning: a silent return exits with status 0 and the
    # job would look successful.
    for split, path in (("Training", train_path), ("Test", test_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"FATAL: {split} data for M={num_basis} not found at {path}")

    train_loader = DataLoader(PropagatorDataset(train_path, config), batch_size=config['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(PropagatorDataset(test_path, config), batch_size=config['BATCH_SIZE'], shuffle=False)

    # --- 3. Model, Optimizer ---
    arch_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    model_kwargs = {k: v for k, v in vars(args).items() if k in arch_keys}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=num_basis,
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs,
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer_cls = optim.AdamW if args.optimizer.lower() == 'adamw' else optim.Adam
    optimizer = optimizer_cls(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=config.get('LR_SCHEDULER_PATIENCE', 10)
    )

    # --- 4. Training Loop with Early Stopping ---
    print("\n--- Starting Training ---")
    patience = config.get('EARLY_STOPPING_PATIENCE', 100)
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_saved = False
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "propagator_deeponet_best.pth")

    # Sensor coordinates in [0, L], shaped (1, M_SENSORS, 1) and shared by all
    # samples. NOTE(review): presumably broadcast over the batch inside the model.
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    train_losses, val_losses = [], []

    for epoch in range(config['EPOCHS']):
        train_losses.append(_run_epoch(model, train_loader, criterion, sensor_locs, DEVICE, optimizer=optimizer))
        val_losses.append(_run_epoch(model, val_loader, criterion, sensor_locs, DEVICE))

        print(f"Epoch {epoch+1}/{config['EPOCHS']} | Train Loss: {train_losses[-1]:.4e} | Val Loss: {val_losses[-1]:.4e}")
        scheduler.step(val_losses[-1])

        # A NaN validation loss never compares as smaller, so it counts as "no improvement".
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---")
                break

    if checkpoint_saved:
        print(f"\nTraining finished. Best model saved to {best_model_path}")
    else:
        print(f"\nTraining finished. WARNING: no checkpoint was written to {best_model_path} "
              "(no epoch produced a validation loss below inf).")

    # --- 5. Save Outputs ---
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(vars(args), f)
    _save_loss_curve(train_losses, val_losses, os.path.join(output_dir, "loss_curve.png"))

if __name__ == "__main__":
    # Command-line interface. Every flag except --num_basis_functions is required.
    cli = argparse.ArgumentParser(description="Train the Propagator DeepONet model.")

    # Run-level settings: paths, run identifier and learning rate.
    for flag_name, flag_type in (
        ("--config_path", str),
        ("--output_base_dir", str),
        ("--run_id", str),
        ("--learning_rate", float),
    ):
        cli.add_argument(flag_name, type=flag_type, required=True)
    cli.add_argument("--optimizer", type=str, required=True, choices=['adam', 'adamw'])

    # Integer architecture hyperparameters forwarded to the model by main().
    for flag_name in ("--latent_dim", "--branch_depth", "--branch_width", "--trunk_depth", "--trunk_width"):
        cli.add_argument(flag_name, type=int, required=True)
    cli.add_argument("--activation_fn", type=str, required=True, choices=['relu', 'tanh'])

    # Optional: replaces NUM_BASIS_FUNCTIONS from the YAML config when given.
    cli.add_argument("--num_basis_functions", type=int, default=None, help="Override NUM_BASIS_FUNCTIONS from config file.")

    main(cli.parse_args())