# src/train_propagator_deeponet.py
# Cluster-ready script to train the one-step PropagatorDeepONet for the Burgers' Equation.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import yaml
import argparse
import matplotlib.pyplot as plt

# Import the configurable model definition
from data_and_models import PropagatorDeepONet

# MODIFIED: Dataset class now handles viscosity profiles
class PropagatorDataset(Dataset):
    """One-step transition samples ``(x_k, u_k, v(x)) -> x_{k+1}`` for the Burgers' propagator.

    Every trajectory stored in the ``.npz`` archive is split into consecutive
    transitions and flattened trajectory-major: all transitions of trajectory 0,
    then all transitions of trajectory 1, and so on.

    Args:
        filepath: Path to an ``.npz`` archive with ``state_sequences`` of shape
            ``(num_sims, num_steps, M_SENSORS)`` and ``control_sequences`` of
            shape ``(num_sims, num_steps or num_steps - 1, NUM_BASIS_FUNCTIONS)``.
            It may also contain ``viscosity_profiles`` of shape
            ``(num_sims, M_SENSORS)`` or ``(num_sims,)``.
        config: Mapping providing ``M_SENSORS`` and ``NUM_BASIS_FUNCTIONS``.
            It may also provide ``VISCOSITY_RANGE``, whose mean is used as the
            constant viscosity when the archive has no viscosity profiles.

    Raises:
        ValueError: If the arrays in the archive have inconsistent or
            unexpected shapes.
    """

    def __init__(self, filepath, config):
        print(f"Loading data for Burgers' Propagator from {filepath}...")
        m_sensors = config['M_SENSORS']
        n_basis = config['NUM_BASIS_FUNCTIONS']

        # NpzFile holds an open file handle; the context manager releases it.
        # Indexing an NpzFile reads the whole array into memory, so the arrays
        # remain valid after the `with` block exits.
        with np.load(filepath) as data:
            states = self._as_sequences(data['state_sequences'], 'state_sequences', m_sensors)
            controls = self._as_sequences(data['control_sequences'], 'control_sequences', n_basis)
            visc_profiles = data['viscosity_profiles'] if 'viscosity_profiles' in data.files else None

        num_sims, num_steps = states.shape[:2]
        num_transitions = num_steps - 1
        if num_transitions < 1:
            raise ValueError(f"'state_sequences' needs at least 2 time steps, got {num_steps}")
        if controls.shape[0] != num_sims:
            raise ValueError(
                f"'control_sequences' has {controls.shape[0]} trajectories, "
                f"'state_sequences' has {num_sims}"
            )
        # Controls may be stored per time step (last entry unused) or per
        # transition. Keep exactly one control per transition so the flattened
        # rows stay aligned with the flattened state rows.
        if controls.shape[1] not in (num_steps, num_transitions):
            raise ValueError(
                f"'control_sequences' has {controls.shape[1]} steps; expected "
                f"{num_steps} or {num_transitions} to match 'state_sequences'"
            )
        controls = controls[:, :num_transitions]

        # Flatten (num_sims, num_transitions, dim) -> (num_sims * num_transitions, dim).
        self.state_inputs = torch.from_numpy(states[:, :-1].reshape(-1, m_sensors)).float()
        self.state_outputs = torch.from_numpy(states[:, 1:].reshape(-1, m_sensors)).float()
        self.controls = torch.from_numpy(controls.reshape(-1, n_basis)).float()
        num_samples = self.state_inputs.shape[0]

        if visc_profiles is not None:
            print("Found viscosity profiles in dataset.")
            visc = torch.from_numpy(visc_profiles).float()
            if visc.ndim == 1:
                # One scalar viscosity per trajectory: represent it as a uniform
                # profile over the sensors, as the constant-viscosity branch does.
                visc = visc.unsqueeze(1).expand(-1, m_sensors)
            if tuple(visc.shape) != (num_sims, m_sensors):
                raise ValueError(
                    f"'viscosity_profiles' must have shape ({num_sims}, {m_sensors}) "
                    f"or ({num_sims},), got {tuple(visc_profiles.shape)}"
                )
            # Each trajectory's profile applies to every one of its transitions.
            # This matches the trajectory-major flattening above.
            self.viscosities = visc.repeat_interleave(num_transitions, dim=0)
        else:
            print("No viscosity profiles found. Running in constant viscosity mode.")
            # `or` also covers a config that sets VISCOSITY_RANGE to null or [].
            visc_range = config.get('VISCOSITY_RANGE') or [0.03]
            default_visc_value = float(np.mean(visc_range))
            self.viscosities = torch.full((num_samples, m_sensors), default_visc_value, dtype=torch.float32)

        print("Data loaded successfully.")

    @staticmethod
    def _as_sequences(array, name, feature_dim):
        """Return ``array`` reshaped to ``(num_sims, num_steps, feature_dim)``.

        Raises ValueError if the array is not at least 3-D, or if its trailing
        dimensions do not hold exactly ``feature_dim`` values per time step.
        """
        if array.ndim < 3:
            raise ValueError(f"'{name}' must have shape (num_sims, num_steps, ...), got {array.shape}")
        per_step = int(np.prod(array.shape[2:]))
        if per_step != feature_dim:
            raise ValueError(f"'{name}' has {per_step} values per time step, expected {feature_dim}")
        return array.reshape(array.shape[0], array.shape[1], feature_dim)

    def __len__(self):
        """Number of one-step transitions across all trajectories."""
        return self.controls.shape[0]

    def __getitem__(self, idx):
        """Return ``(x_k, u_k, viscosity_profile, x_{k+1})`` for transition ``idx``."""
        return self.state_inputs[idx], self.controls[idx], self.viscosities[idx], self.state_outputs[idx]

def _resolve_data_paths(data_dir):
    """Locate the train/test ``.npz`` files in ``data_dir``.

    Prefers the variable-viscosity files (``*_trajectories_variable.npz``) and
    falls back to the original names (``*_trajectories.npz``).

    Returns:
        Tuple ``(train_path, test_path)``.

    Raises:
        FileNotFoundError: If the chosen train or test file does not exist.
            Raising (instead of returning) makes the process exit non-zero,
            so a cluster scheduler marks the job as failed.
    """
    train_path = os.path.join(data_dir, "train_trajectories_variable.npz")
    test_path = os.path.join(data_dir, "test_trajectories_variable.npz")
    if not os.path.exists(train_path):
        print(f"WARNING: Variable dataset not found at {train_path}. Trying original name...")
        train_path = os.path.join(data_dir, "train_trajectories.npz")
        test_path = os.path.join(data_dir, "test_trajectories.npz")
    for path in (train_path, test_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"FATAL: Data file not found at {path}. Run generate_data first.")
    return train_path, test_path


def _build_optimizer(model, args):
    """Create the optimizer chosen by ``args.optimizer``.

    Returns AdamW for 'adamw' and Adam otherwise. Both use ``weight_decay=1e-5``.
    """
    optimizer_cls = optim.AdamW if args.optimizer.lower() == 'adamw' else optim.Adam
    return optimizer_cls(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)


def _run_epoch(model, loader, criterion, x_grid, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        The sample-weighted mean loss over the pass. Weighting by batch size
        stops a smaller final batch from skewing the epoch average.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, total_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for x_k, u_k, v_profile, x_k_plus_1 in loader:
            x_k, u_k = x_k.to(device), u_k.to(device)
            v_profile, x_k_plus_1 = v_profile.to(device), x_k_plus_1.to(device)
            if training:
                optimizer.zero_grad()
            # Presumably the model returns (batch, M_SENSORS, 1); squeeze to
            # match the (batch, M_SENSORS) targets. TODO confirm.
            x_pred = model(x_k, u_k, v_profile, x_grid).squeeze(-1)
            loss = criterion(x_pred, x_k_plus_1)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = x_k.shape[0]
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("DataLoader yielded no samples; check the dataset files.")
    return total_loss / total_samples


def _save_outputs(output_dir, args, train_losses, val_losses):
    """Write ``hyperparams.yaml`` (CLI arguments) and ``loss_curve.png`` to ``output_dir``."""
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f:
        yaml.dump(vars(args), f)
    fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(f'Loss Curve ({args.run_id})')
    plt.legend()
    plt.grid(True)
    fig.savefig(os.path.join(output_dir, "loss_curve.png"))
    plt.close(fig)  # release the figure; matters when many runs share one process


def main(args):
    """Train the one-step PropagatorDeepONet and write the artifacts for ``args.run_id``.

    Artifacts go to ``<output_base_dir>/<run_id>/``:
    ``burgers_propagator_best.pth`` (state dict with the lowest validation loss),
    ``hyperparams.yaml`` (CLI arguments) and ``loss_curve.png``.

    Config keys read here: M_SENSORS, NUM_BASIS_FUNCTIONS, TRUNK_INPUT_DIM,
    BATCH_SIZE, EPOCHS, L (right end of the sensor grid), and optionally
    EARLY_STOPPING_PATIENCE (default 50).

    Raises:
        FileNotFoundError: If the training or test data files are missing.
    """
    # --- 1. Load Configs and Setup ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Propagator: {args.run_id} ---")
    print(f"Using device: {DEVICE}")

    # --- 2. DataLoaders (assumes data has been generated) ---
    train_path, test_path = _resolve_data_paths(os.path.join(args.output_base_dir, "data"))
    train_loader = DataLoader(PropagatorDataset(train_path, config), batch_size=config['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(PropagatorDataset(test_path, config), batch_size=config['BATCH_SIZE'] * 2, shuffle=False)

    # --- 3. Model, Optimizer ---
    model_arg_names = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {name: getattr(args, name) for name in model_arg_names}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = _build_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

    # --- 4. Training Loop with Early Stopping ---
    print("\n--- Starting Training ---")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    patience = config.get('EARLY_STOPPING_PATIENCE', 50)
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "burgers_propagator_best.pth")

    # Sensor coordinates on [0, L], shape (1, M_SENSORS, 1). The same grid is
    # passed to the model for every batch.
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    train_losses, val_losses = [], []

    for epoch in range(config['EPOCHS']):
        train_losses.append(_run_epoch(model, train_loader, criterion, x_grid_sensors, DEVICE, optimizer))
        val_losses.append(_run_epoch(model, val_loader, criterion, x_grid_sensors, DEVICE))

        print(f"Epoch {epoch+1}/{config['EPOCHS']} | Train Loss: {train_losses[-1]:.4e} | Val Loss: {val_losses[-1]:.4e}")
        scheduler.step(val_losses[-1])

        # A NaN validation loss never compares as an improvement, so it counts
        # toward early stopping and never overwrites a good checkpoint.
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---")
                break

    if best_val_loss == float('inf'):
        print(f"\nWARNING: No epoch produced a finite validation loss; no checkpoint was written for run '{args.run_id}'.")
    else:
        print(f"\nTraining finished. Best model for run '{args.run_id}' saved to {best_model_path}")

    # --- 5. Save Outputs ---
    _save_outputs(output_dir, args, train_losses, val_losses)

if __name__ == "__main__":
    # Command-line entry point. Every flag is mandatory, so a cluster job
    # always records its full hyperparameter set explicitly.
    cli = argparse.ArgumentParser(description="Train the Propagator DeepONet for Burgers' Equation.")

    # Paths and run identification.
    for path_flag in ("--config_path", "--output_base_dir", "--run_id"):
        cli.add_argument(path_flag, type=str, required=True)

    # Optimisation settings.
    cli.add_argument("--learning_rate", type=float, required=True)
    cli.add_argument("--optimizer", type=str, required=True, choices=['adam', 'adamw'])

    # Network architecture: integer sizes, then the activation name.
    for size_flag in ("--latent_dim", "--branch_depth", "--branch_width", "--trunk_depth", "--trunk_width"):
        cli.add_argument(size_flag, type=int, required=True)
    cli.add_argument("--activation_fn", type=str, required=True, choices=['relu', 'tanh'])

    main(cli.parse_args())