# src/train_propagator_deeponet_2d.py
# Cluster-ready script to train the one-step PropagatorDeepONet for the 2D Heat Equation.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import yaml
import argparse
import matplotlib.pyplot as plt

# --- Model definition from the 2D models module ---
from data_and_models_2d import PropagatorDeepONet

# --- Dataset of one-step (T_k, w_k) -> T_{k+1} pairs, sized from the 2D config keys ---
class PropagatorDataset(Dataset):
    """One-step transition samples (T_k, w_k) -> T_{k+1} built from saved trajectories.

    The ``.npz`` archive must contain ``control_sequences`` and
    ``state_sequences`` arrays with three axes (the ``[:, :-1, :]`` slicing
    below requires this); axis 1 is treated as time. Every consecutive time
    pair of every trajectory becomes one sample.

    Args:
        filepath: Path to the ``.npz`` archive.
        config: Parsed config dict. ``NUM_BASIS_X * NUM_BASIS_Y`` sizes the
            control vector and ``NX_SENSORS * NY_SENSORS`` sizes the state vector.

    Raises:
        ValueError: if the control and state arrays yield different numbers of
            samples, which would otherwise silently misalign the pairs.
    """

    def __init__(self, filepath, config):
        # Total number of basis functions is the product of basis counts per dimension.
        num_basis_functions = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']
        # Total number of sensors is the product of sensor counts per axis.
        num_sensors = config['NX_SENSORS'] * config['NY_SENSORS']

        # Use NpzFile as a context manager so the underlying file handle is
        # released, and read each archive member only once (every key access
        # re-reads the member from disk).
        with np.load(filepath) as data:
            controls = data['control_sequences']
            states = data['state_sequences']

        # Per-step features are expected to already be flattened; the reshape
        # merges (trajectory, time) into a single sample axis. Inputs use time
        # steps [0, T-1), targets use steps [1, T).
        self.controls = torch.from_numpy(controls[:, :-1, :]).float().reshape(-1, num_basis_functions)
        self.state_inputs = torch.from_numpy(states[:, :-1, :]).float().reshape(-1, num_sensors)
        self.state_outputs = torch.from_numpy(states[:, 1:, :]).float().reshape(-1, num_sensors)

        if self.controls.shape[0] != self.state_inputs.shape[0]:
            raise ValueError(
                f"{filepath}: {self.controls.shape[0]} control samples vs "
                f"{self.state_inputs.shape[0]} state samples; check the time axes "
                "and the NUM_BASIS_*/N*_SENSORS config values."
            )

    def __len__(self):
        """Number of (T_k, w_k, T_{k+1}) samples."""
        return self.controls.shape[0]

    def __getitem__(self, idx):
        """Return ``(state_input, control, state_output)`` for sample ``idx``."""
        return self.state_inputs[idx], self.controls[idx], self.state_outputs[idx]

def _make_sensor_locs(config, device):
    """Build the (x, y) coordinates of the 2D sensor grid as a model-ready tensor.

    Sensors are placed uniformly on [0, L_X] x [0, L_Y] (endpoints included),
    ``NX_SENSORS`` along x and ``NY_SENSORS`` along y.

    Returns:
        Float tensor of shape (1, NX_SENSORS * NY_SENSORS, 2) on ``device``.
    """
    x_locs = np.linspace(0, config['L_X'], config['NX_SENSORS'])
    y_locs = np.linspace(0, config['L_Y'], config['NY_SENSORS'])
    # Default 'xy' indexing yields (NY, NX) grids, so after ravel() x varies fastest.
    # NOTE(review): this ordering must match how the data-generation script
    # flattened each state vector — confirm.
    grid_x, grid_y = np.meshgrid(x_locs, y_locs)
    sensor_locs_np = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
    # Leading batch dimension of 1; presumably broadcast across the batch by the model.
    return torch.from_numpy(sensor_locs_np).float().to(device).unsqueeze(0)


def _run_epoch(model, loader, criterion, sensor_locs, device, optimizer=None):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    When ``optimizer`` is given, the model is put in train mode and updated
    after every batch. Otherwise it runs in eval mode with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for T_k, w_k, T_k_plus_1 in loader:
            T_k, w_k, T_k_plus_1 = T_k.to(device), w_k.to(device), T_k_plus_1.to(device)
            T_pred = model(T_k, w_k, sensor_locs).squeeze(-1)
            loss = criterion(T_pred, T_k_plus_1)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # criterion uses mean reduction, so weighting by batch size gives the
            # exact dataset mean even when the final batch is shorter.
            batch_size = T_k.shape[0]
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples


def _save_outputs(output_dir, args, train_losses, val_losses):
    """Write the run's CLI hyperparameters (YAML) and loss-curve plot to ``output_dir``."""
    with open(os.path.join(output_dir, 'hyperparams_propagator_2d.yaml'), 'w') as f:
        yaml.dump(vars(args), f)
    fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title('Propagator 2D Loss Curve')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss_curve_propagator_2d.png"))
    # Release the figure; pyplot otherwise keeps it alive for the process lifetime.
    plt.close(fig)


def main(args):
    """Train a PropagatorDeepONet on one-step 2D heat-equation transitions.

    The function runs these steps:
    - Read the YAML config at ``args.config_path``.
    - Load ``<output_base_dir>/data/{train,test}_trajectories_2d.npz``.
    - Train with ReduceLROnPlateau and early stopping on the validation loss.
    - Write the best checkpoint, the CLI hyperparameters and a loss-curve plot
      to ``<output_base_dir>/<run_id>/``.

    Raises:
        FileNotFoundError: if the training or validation data file is missing,
            so a cluster job fails with a non-zero exit status.
        ValueError: if either dataset contains no samples.
    """
    # --- 1. Load Configs and Setup ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training 2D PropagatorDeepONet: {args.run_id} ---")
    print(f"Using device: {DEVICE}")

    # --- 2. DataLoaders ---
    data_dir = os.path.join(args.output_base_dir, "data")
    train_path = os.path.join(data_dir, "train_trajectories_2d.npz")
    test_path = os.path.join(data_dir, "test_trajectories_2d.npz")
    # Raise rather than print-and-return so the job exits with a failure status.
    if not os.path.exists(train_path):
        raise FileNotFoundError(f"FATAL: Training data not found at {train_path}")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"FATAL: Validation data not found at {test_path}")

    train_set = PropagatorDataset(train_path, config)
    val_set = PropagatorDataset(test_path, config)
    if len(train_set) == 0 or len(val_set) == 0:
        raise ValueError(
            f"Empty dataset: {len(train_set)} training / {len(val_set)} validation samples."
        )
    train_loader = DataLoader(train_set, batch_size=config['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=config['BATCH_SIZE'], shuffle=False)

    # --- 3. Model, Optimizer ---
    # Forward only the architecture hyperparameters from the CLI to the model.
    arch_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    model_kwargs = {k: v for k, v in vars(args).items() if k in arch_keys}

    # State and control vectors are flattened over the 2D grids.
    total_sensors = config['NX_SENSORS'] * config['NY_SENSORS']
    total_basis_functions = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']

    model = PropagatorDeepONet(
        M_sensors=total_sensors,
        num_basis_functions=total_basis_functions,
        # Sensor locations carry (x, y) coordinates, so this is expected to be 2.
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(DEVICE)

    criterion = nn.MSELoss()
    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    # Halve the learning rate when the validation loss plateaus (patience=10 epochs).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

    # --- 4. Training Loop with Early Stopping ---
    print("\n--- Starting Training ---")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    # Early-stopping patience in epochs; optional config override, default 100.
    patience = config.get('EARLY_STOPPING_PATIENCE', 100)
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "propagator_deeponet_2d_best.pth")

    sensor_locs = _make_sensor_locs(config, DEVICE)

    train_losses, val_losses = [], []
    for epoch in range(config['EPOCHS']):
        train_losses.append(_run_epoch(model, train_loader, criterion, sensor_locs, DEVICE, optimizer))
        val_losses.append(_run_epoch(model, val_loader, criterion, sensor_locs, DEVICE))

        print(f"Epoch {epoch+1}/{config['EPOCHS']} | Train Loss: {train_losses[-1]:.4e} | Val Loss: {val_losses[-1]:.4e}")
        scheduler.step(val_losses[-1])

        # A NaN validation loss compares False here and never overwrites the checkpoint.
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---")
                break

    if best_val_loss == float('inf'):
        print("\nWARNING: Training finished but no checkpoint was saved "
              "(no epoch produced a finite validation loss).")
    else:
        print(f"\nTraining finished. Best model saved to {best_model_path}")

    # --- 5. Save Outputs ---
    _save_outputs(output_dir, args, train_losses, val_losses)

# Command-line interface: every hyperparameter is a required flag, so each cluster run is fully specified on the command line.
if __name__ == "__main__":
    # Flags are registered in the same order as before so --help output is unchanged.
    cli = argparse.ArgumentParser(description="Train the 2D Propagator DeepONet model.")
    cli.add_argument("--config_path", type=str, required=True, help="Path to the 2D config YAML file.")
    for path_flag in ("--output_base_dir", "--run_id"):
        cli.add_argument(path_flag, type=str, required=True)
    cli.add_argument("--learning_rate", type=float, required=True)
    cli.add_argument("--optimizer", type=str, required=True, choices=['adam', 'adamw'])
    # Integer architecture sizes.
    for size_flag in ("--latent_dim", "--branch_depth", "--branch_width", "--trunk_depth", "--trunk_width"):
        cli.add_argument(size_flag, type=int, required=True)
    cli.add_argument("--activation_fn", type=str, required=True, choices=['relu', 'tanh'])
    main(cli.parse_args())