# src/train_propagator_deeponet.py
# Cluster-ready script to train the one-step PropagatorDeepONet for the Burgers' Equation.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import yaml
import argparse
import matplotlib.pyplot as plt

# Import the configurable model definition
from data_and_models import PropagatorDeepONet

class PropagatorDataset(Dataset):
    """One-step samples ``(x_k, u_k) -> x_{k+1}`` cut from stored trajectories.

    Each trajectory in the archive is split into consecutive
    (state, control, next-state) triples, and the triples of all trajectories
    are flattened into a single sample axis.

    Args:
        filepath: Path to an ``.npz`` archive holding the arrays
            ``control_sequences`` and ``state_sequences``. Both are sliced as
            ``[:, time, :]``, so they must be 3-D.
        config: Mapping providing ``NUM_BASIS_FUNCTIONS`` (row width of the
            flattened controls) and ``M_SENSORS`` (row width of the flattened
            states).

    Raises:
        ValueError: If the flattened controls and states end up with a
            different number of rows, which would pair states with the wrong
            control inputs.
    """

    def __init__(self, filepath, config):
        print(f"Loading data for Burgers' Propagator from {filepath}...")
        # np.load keeps the archive open until it is closed explicitly; the
        # arrays are fully read inside the ``with`` so the handle can be
        # released immediately afterwards.
        with np.load(filepath) as data:
            controls = data['control_sequences']
            states = data['state_sequences']

        n_basis = config['NUM_BASIS_FUNCTIONS']
        n_sensors = config['M_SENSORS']

        # Drop the last time step of the inputs and the first of the targets
        # so that row i of every tensor belongs to the same transition.
        self.controls = torch.from_numpy(controls[:, :-1, :]).float().reshape(-1, n_basis)
        self.state_inputs = torch.from_numpy(states[:, :-1, :]).float().reshape(-1, n_sensors)
        self.state_outputs = torch.from_numpy(states[:, 1:, :]).float().reshape(-1, n_sensors)

        # state_inputs and state_outputs come from the same array and always
        # agree; only the controls can be out of step with them.
        if self.controls.shape[0] != self.state_inputs.shape[0]:
            raise ValueError(
                f"Sample count mismatch in {filepath}: "
                f"{self.controls.shape[0]} control rows vs "
                f"{self.state_inputs.shape[0]} state rows. Check the array "
                "shapes against NUM_BASIS_FUNCTIONS and M_SENSORS."
            )
        print("Data loaded successfully.")

    def __len__(self):
        """Return the number of one-step transitions."""
        return self.controls.shape[0]

    def __getitem__(self, idx):
        """Return the triple ``(x_k, u_k, x_{k+1})`` stored at row ``idx``."""
        return self.state_inputs[idx], self.controls[idx], self.state_outputs[idx]

def main(args):
    """Train the one-step PropagatorDeepONet and write the run artifacts.

    Reads the YAML config at ``args.config_path`` (keys used here:
    ``M_SENSORS``, ``NUM_BASIS_FUNCTIONS``, ``TRUNK_INPUT_DIM``,
    ``BATCH_SIZE``, ``L``, ``EPOCHS``), loads ``train_trajectories.npz`` and
    ``test_trajectories.npz`` from ``<output_base_dir>/data``, and trains with
    MSE loss, ReduceLROnPlateau scheduling and early stopping.

    Files written to ``<output_base_dir>/<run_id>``:
        * ``burgers_propagator_best.pth`` - state dict with the lowest
          validation loss seen so far.
        * ``hyperparams.yaml`` - dump of ``vars(args)``.
        * ``loss_curve.png`` - train/validation loss per epoch (log scale).

    Args:
        args: Parsed command-line namespace. Attributes read: ``config_path``,
            ``output_base_dir``, ``run_id``, ``optimizer``, ``learning_rate``,
            plus any of ``branch_depth``, ``branch_width``, ``trunk_depth``,
            ``trunk_width``, ``latent_dim``, ``activation_fn`` that are
            present (forwarded to the model constructor).

    Returns:
        None. Returns early, after printing a FATAL message, when either data
        file is missing.
    """
    # --- 1. Load Configs and Setup ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Propagator: {args.run_id} ---")
    print(f"Using device: {DEVICE}")

    # --- 2. DataLoaders (assumes data has been generated) ---
    data_dir = os.path.join(args.output_base_dir, "data")
    train_path = os.path.join(data_dir, "train_trajectories.npz")
    test_path = os.path.join(data_dir, "test_trajectories.npz")
    # Check both files up front so a missing validation set is reported the
    # same way as a missing training set instead of crashing inside np.load.
    for label, path in (("Training", train_path), ("Test", test_path)):
        if not os.path.exists(path):
            print(f"FATAL: {label} data not found at {path}. Run generate_data first.")
            return

    train_loader = DataLoader(PropagatorDataset(train_path, config), batch_size=config['BATCH_SIZE'], shuffle=True)
    # No gradients are kept during validation, so a larger batch is used.
    val_loader = DataLoader(PropagatorDataset(test_path, config), batch_size=config['BATCH_SIZE'] * 2, shuffle=False)

    # --- 3. Model, Optimizer ---
    # Forward only the architecture-related CLI options to the model.
    model_kwargs = {k: v for k, v in vars(args).items() if k in ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(DEVICE)

    criterion = nn.MSELoss()
    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    else:  # Default to Adam
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

    # --- 4. Training Loop with Early Stopping ---
    print("\n--- Starting Training ---")
    best_val_loss = float('inf')
    best_model_saved = False
    epochs_no_improve = 0
    patience = 50
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "burgers_propagator_best.pth")

    # Record the hyperparameters before training so they survive a job that
    # is killed or crashes part-way through.
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    # Sensor coordinates on [0, L], shaped (1, M_SENSORS, 1).
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    train_losses, val_losses = [], []

    for epoch in range(config['EPOCHS']):
        model.train()
        epoch_train_loss = 0.0
        for x_k, u_k, x_k_plus_1 in train_loader:
            x_k, u_k, x_k_plus_1 = x_k.to(DEVICE), u_k.to(DEVICE), x_k_plus_1.to(DEVICE)
            optimizer.zero_grad()
            # NOTE(review): squeeze(-1) presumes the model returns a trailing
            # singleton dimension - confirm against PropagatorDeepONet.
            x_pred = model(x_k, u_k, x_grid_sensors).squeeze(-1)
            loss = criterion(x_pred, x_k_plus_1)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
        # Mean of the per-batch losses.
        train_losses.append(epoch_train_loss / len(train_loader))

        model.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for x_k, u_k, x_k_plus_1 in val_loader:
                x_k, u_k, x_k_plus_1 = x_k.to(DEVICE), u_k.to(DEVICE), x_k_plus_1.to(DEVICE)
                x_pred = model(x_k, u_k, x_grid_sensors).squeeze(-1)
                epoch_val_loss += criterion(x_pred, x_k_plus_1).item()
        val_losses.append(epoch_val_loss / len(val_loader))

        print(f"Epoch {epoch+1}/{config['EPOCHS']} | Train Loss: {train_losses[-1]:.4e} | Val Loss: {val_losses[-1]:.4e}")
        scheduler.step(val_losses[-1])

        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            best_model_saved = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---")
                break

    if best_model_saved:
        print(f"\nTraining finished. Best model for run '{args.run_id}' saved to {best_model_path}")
    else:
        # Happens when EPOCHS is 0 or every validation loss was NaN (NaN never
        # compares below the initial inf), so no checkpoint was written.
        print(f"\nTraining finished. WARNING: no checkpoint was saved for run '{args.run_id}' (validation loss never improved).")

    # --- 5. Save Outputs ---
    fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'Loss Curve ({args.run_id})')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))
    # Release the figure; pyplot otherwise keeps it alive for the process.
    plt.close(fig)

if __name__ == "__main__":
    # Command-line interface: every option is mandatory. Each entry is
    # (flag, value type, allowed values or None), listed in help order.
    cli_spec = (
        ("--config_path", str, None),
        ("--output_base_dir", str, None),
        ("--run_id", str, None),
        ("--learning_rate", float, None),
        ("--optimizer", str, ['adam', 'adamw']),
        ("--latent_dim", int, None),
        ("--branch_depth", int, None),
        ("--branch_width", int, None),
        ("--trunk_depth", int, None),
        ("--trunk_width", int, None),
        ("--activation_fn", str, ['relu', 'tanh']),
    )
    cli = argparse.ArgumentParser(description="Train the Propagator DeepONet for Burgers' Equation.")
    for flag, value_type, allowed in cli_spec:
        cli.add_argument(flag, type=value_type, required=True, choices=allowed)
    main(cli.parse_args())