# src/train_deeponet.py
# FINAL version with the missing --trunk_depth argument fixed.

import sys
print("DEBUG: Starting Python interpreter.", flush=True)

# Every import is followed by a flushed progress line so that, if the
# interpreter stalls or dies while loading a dependency, the last line in the
# job log identifies the import that was in progress.
try:
    # --- Import Block ---
    import torch
    print("DEBUG: import torch ... OK", flush=True)
    import torch.nn as nn
    print("DEBUG: import torch.nn ... OK", flush=True)
    import torch.optim as optim
    print("DEBUG: import torch.optim ... OK", flush=True)
    from torch.utils.data import TensorDataset, DataLoader
    print("DEBUG: from torch.utils.data ... OK", flush=True)
    import numpy as np
    print("DEBUG: import numpy ... OK", flush=True)
    import os
    print("DEBUG: import os ... OK", flush=True)
    import argparse
    print("DEBUG: import argparse ... OK", flush=True)
    import yaml
    print("DEBUG: import yaml ... OK", flush=True)
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported.
    matplotlib.use('Agg')
    print("DEBUG: matplotlib.use('Agg') ... OK", flush=True)
    import matplotlib.pyplot as plt
    print("DEBUG: import matplotlib.pyplot ... OK", flush=True)
    from data_and_models import DeepONetWithBias
    print("DEBUG: from data_and_models ... OK", flush=True)
    print("\nDEBUG: All imports were successful.\n", flush=True)

except Exception as e:
    # The failure is reported on both stderr and stdout (presumably so it is
    # visible whichever stream the job log captures), then the process stops
    # with a non-zero status.
    print(f"CRITICAL ERROR during import: {e}", file=sys.stderr, flush=True)
    print(f"CRITICAL ERROR during import: {e}", flush=True)
    # sys.exit instead of the bare exit(): the latter is injected by the
    # `site` module and is not guaranteed to exist (e.g. `python -S`).
    sys.exit(1)

def _load_tensor_dataset(npz_path):
    """Read an ``.npz`` archive into a TensorDataset of float32 tensors.

    The archive must contain the arrays ``branch_inputs``, ``trunk_inputs``
    and ``outputs``; they become the three fields of each dataset item, in
    that order.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once the arrays are materialised.
    with np.load(npz_path) as archive:
        tensors = [
            torch.from_numpy(archive[key]).float()
            for key in ('branch_inputs', 'trunk_inputs', 'outputs')
        ]
    return TensorDataset(*tensors)


def _build_optimizer(name, parameters, learning_rate):
    """Return the optimizer called *name* (case-insensitive) over *parameters*.

    Raises:
        ValueError: If *name* is not one of 'adam', 'adamw' or 'sgd'.
    """
    factories = {'adam': optim.Adam, 'adamw': optim.AdamW, 'sgd': optim.SGD}
    factory = factories.get(name.lower())
    if factory is None:
        raise ValueError(f"Unsupported optimizer: {name}")
    return factory(parameters, lr=learning_rate)


def _mean_batch_loss(model, loader, criterion, device):
    """Return the criterion averaged over the batches of *loader*.

    Runs without gradient tracking. The result is the mean of the per-batch
    losses (not weighted by batch size), matching how the training loss is
    aggregated in main().
    """
    running = 0.0
    with torch.no_grad():
        for branch, trunk, out in loader:
            branch, trunk, out = branch.to(device), trunk.to(device), out.to(device)
            running += criterion(model(branch, trunk), out).item()
    return running / len(loader)


def main(args):
    """Train a DeepONetWithBias model and save the model, settings and loss plot.

    Data is read from ``<output_base_dir>/data/train_data.npz`` and
    ``test_data.npz``; the latter is used as the validation set. After the
    last epoch the model ``state_dict``, the run's hyperparameters and a
    log-scale loss curve are written to ``<output_base_dir>/<run_id>/``.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``config_path``, ``output_base_dir``, ``run_id``,
            ``learning_rate``, ``latent_dim``, ``branch_depth``,
            ``branch_width``, ``trunk_depth``, ``trunk_width``,
            ``activation_fn`` and ``optimizer``. The YAML file at
            ``config_path`` must define ``BATCH_SIZE``, ``M_SENSORS``,
            ``TRUNK_INPUT_DIM`` and ``EPOCHS``.

    Raises:
        SystemExit: With status 1 if either data file is missing.
        ValueError: If ``args.optimizer`` is not a supported optimizer name.
    """
    print("DEBUG: Inside main(). Loading config.", flush=True)
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}", flush=True)
    print(f"Running with arguments: {args}", flush=True)

    train_data_path = os.path.join(args.output_base_dir, "data", "train_data.npz")
    test_data_path = os.path.join(args.output_base_dir, "data", "test_data.npz")
    # Check both inputs before doing any work so that a missing validation
    # file produces the same clear message as a missing training file.
    for required_path in (train_data_path, test_data_path):
        if not os.path.exists(required_path):
            print(f"FATAL ERROR: Data not found at {required_path}. Please run data generation first.", flush=True)
            sys.exit(1)

    print(f"DEBUG: Loading training data from {train_data_path}...", flush=True)
    train_dataset = _load_tensor_dataset(train_data_path)
    print("DEBUG: Training data loaded. Creating DataLoader.", flush=True)
    train_loader = DataLoader(train_dataset, batch_size=config['BATCH_SIZE'], shuffle=True)

    print(f"DEBUG: Loading validation data from {test_data_path}...", flush=True)
    val_dataset = _load_tensor_dataset(test_data_path)
    print("DEBUG: Validation data loaded. Creating DataLoader.", flush=True)
    # Validation uses twice the training batch size and a fixed order.
    val_loader = DataLoader(val_dataset, batch_size=config['BATCH_SIZE'] * 2, shuffle=False)

    print("DEBUG: Initializing model.", flush=True)
    model = DeepONetWithBias(
        branch_input_dim=config['M_SENSORS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        latent_dim=args.latent_dim,
        branch_depth=args.branch_depth,
        branch_width=args.branch_width,
        trunk_depth=args.trunk_depth,
        trunk_width=args.trunk_width,
        activation_fn=args.activation_fn,
    ).to(device)
    criterion = nn.MSELoss()

    print("DEBUG: Selecting optimizer.", flush=True)
    optimizer = _build_optimizer(args.optimizer, model.parameters(), args.learning_rate)
    # scheduler.step() is called once per epoch below, so the learning rate
    # is halved every 500 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    print("\n--- Starting DeepONet Surrogate Training ---", flush=True)
    train_loss_history, val_loss_history = [], []
    num_epochs = config['EPOCHS']
    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        for branch, trunk, out in train_loader:
            branch, trunk, out = branch.to(device), trunk.to(device), out.to(device)
            optimizer.zero_grad()
            loss = criterion(model(branch, trunk), out)
            loss.backward()
            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_train_loss += loss.item()
        train_loss_history.append(epoch_train_loss / len(train_loader))

        model.eval()
        val_loss_history.append(_mean_batch_loss(model, val_loader, criterion, device))
        scheduler.step()

        # Progress is reported every tenth epoch.
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_history[-1]:.4e}, Val Loss: {val_loss_history[-1]:.4e}", flush=True)

    print("\n--- Training Finished. Saving results. ---", flush=True)
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)

    model_path = os.path.join(output_dir, "deeponet_model.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}", flush=True)

    hyperparams_path = os.path.join(output_dir, 'hyperparams.yaml')
    with open(hyperparams_path, 'w') as f:
        yaml.dump(vars(args), f)
    print(f"Hyperparameters saved to {hyperparams_path}", flush=True)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(train_loss_history, label='Training Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (MSE)')
    plt.title(f'DeepONet Loss (run: {args.run_id})')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    plot_path = os.path.join(output_dir, "loss_curve.png")
    plt.savefig(plot_path)
    # Release the figure; pyplot otherwise keeps it registered until exit.
    plt.close(fig)
    print(f"Loss curve plot saved to {plot_path}", flush=True)

if __name__ == "__main__":
    # Command-line entry point. Every option is mandatory; the table below
    # lists them in the order they are registered with the parser, as
    # (flag, type, help text, allowed values or None).
    parser = argparse.ArgumentParser(
        description="Train a DeepONet surrogate model on the cluster."
    )
    option_table = (
        # Locations and run identity.
        ("--config_path", str, "Path to the base YAML config file.", None),
        ("--output_base_dir", str, "Base directory for data and results.", None),
        ("--run_id", str, "A unique name for this training run.", None),
        # Numeric hyperparameters.
        ("--learning_rate", float, "Optimizer learning rate.", None),
        ("--latent_dim", int, "Latent dimension size.", None),
        ("--branch_depth", int, "Number of layers in the branch network.", None),
        ("--branch_width", int, "Width of the branch network layers.", None),
        ("--trunk_width", int, "Width of the trunk network layers.", None),
        ("--trunk_depth", int, "Number of layers in the trunk network.", None),
        # Categorical choices, validated by argparse.
        ("--activation_fn", str, "Activation function to use.", ['relu', 'tanh', 'silu']),
        ("--optimizer", str, "Optimizer to use.", ['adam', 'adamw', 'sgd']),
    )
    for flag, value_type, help_text, allowed in option_table:
        options = {"type": value_type, "required": True, "help": help_text}
        if allowed is not None:
            options["choices"] = allowed
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)