
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import yaml
import argparse
import matplotlib.pyplot as plt

# Import the configurable model definition from our updated file
from data_and_models import PropagatorDeepONet

class PropagatorDataset(Dataset):
    """One-step transition samples ``(x_k, u_k, x_{k+1})`` for the Burgers' propagator.

    Loads an ``.npz`` file containing ``state_sequences`` and
    ``control_sequences`` arrays, each indexed ``[simulation, timestep, sensor]``
    with ``M_SENSORS`` on the last axis. Every consecutive-timestep pair of every
    simulation is flattened into one sample.

    Args:
        filepath: Path to the ``.npz`` trajectory file.
        config: Mapping that must provide ``'M_SENSORS'``, the number of sensor
            locations.

    Raises:
        ValueError: If either array is not 3-D with ``M_SENSORS`` on its last
            axis, or if the flattened inputs, controls and targets disagree in
            sample count.
    """

    def __init__(self, filepath, config):
        print(f"Loading direct control data for Burgers' Propagator from {filepath}...")
        m_sensors = config['M_SENSORS']

        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager closes it once both arrays have been read into memory.
        with np.load(filepath) as data:
            controls = data['control_sequences']
            states = data['state_sequences']

        # reshape(-1, M_SENSORS) below would silently interleave timesteps if the
        # sensor axis did not match, so reject such arrays up front.
        for key, arr in (('control_sequences', controls), ('state_sequences', states)):
            if arr.ndim != 3 or arr.shape[-1] != m_sensors:
                raise ValueError(
                    f"{filepath}: '{key}' has shape {arr.shape}; expected "
                    f"(num_sims, num_timesteps, {m_sensors})."
                )

        # Control values u_k at the M_SENSORS locations for each transition k
        # (last timestep dropped so controls pair with x_k), flattened over sims.
        self.controls_at_sensors = torch.from_numpy(controls[:, :-1, :]).float().reshape(-1, m_sensors)

        # State inputs are x_k at sensor locations.
        self.state_inputs = torch.from_numpy(states[:, :-1, :]).float().reshape(-1, m_sensors)

        # State outputs are the target x_{k+1} at sensor locations.
        self.state_outputs = torch.from_numpy(states[:, 1:, :]).float().reshape(-1, m_sensors)

        # Explicit check (not `assert`, which is stripped under `python -O`).
        n_samples = self.state_inputs.shape[0]
        n_controls = self.controls_at_sensors.shape[0]
        n_targets = self.state_outputs.shape[0]
        if n_controls != n_samples or n_targets != n_samples:
            raise ValueError(
                f"{filepath}: sample count mismatch (states={n_samples}, "
                f"controls={n_controls}, targets={n_targets}); control_sequences "
                f"and state_sequences must share the same timestep axis length."
            )
        print(f"Data loaded successfully. Found {len(self)} one-step samples.")

    def __len__(self):
        """Return the number of one-step transition samples."""
        return self.state_inputs.shape[0]

    def __getitem__(self, idx):
        """Return ``(x_k, u_k_at_sensors, x_{k+1})`` for sample ``idx``."""
        return self.state_inputs[idx], self.controls_at_sensors[idx], self.state_outputs[idx]

def main(args):
    # --- 1. Load Configs and Setup ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Training Burgers' Propagator (Direct Control): {args.run_id} ---")
    print(f"Using device: {DEVICE}")

    # --- 2. DataLoaders ---
    data_dir = os.path.join(args.output_base_dir, "data")
    train_path = os.path.join(data_dir, "train_trajectories.npz")
    test_path = os.path.join(data_dir, "test_trajectories.npz")
    if not os.path.exists(train_path):
        print(f"FATAL: Training data not found at {train_path}. Run generate_data first."); return

    train_loader = DataLoader(PropagatorDataset(train_path, config), batch_size=config['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(PropagatorDataset(test_path, config), batch_size=config['BATCH_SIZE'] * 2, shuffle=False)

    # --- 3. Model, Optimizer ---
    model_kwargs = {k: v for k, v in vars(args).items() if k in ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'], # Note: This arg is kept for API consistency but is unused by the direct control model's branch network.
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(DEVICE)
    
    criterion = nn.MSELoss()
    if args.optimizer.lower() == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    else: # Default to Adam
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

    # --- 4. Training Loop with Early Stopping ---
    print("\n--- Starting Training ---")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    patience = 50
    output_dir = os.path.join(args.output_base_dir, args.run_id)
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, "burgers_propagator_best.pth")
    
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    train_losses, val_losses = [], []

    for epoch in range(config['EPOCHS']):
        model.train(); epoch_train_loss = 0.0
        # The second item from the loader is now u_k at sensor locations
        for x_k, u_k_at_sensors, x_k_plus_1 in train_loader:
            x_k, u_k_at_sensors, x_k_plus_1 = x_k.to(DEVICE), u_k_at_sensors.to(DEVICE), x_k_plus_1.to(DEVICE)
            
            optimizer.zero_grad()
            
            # Pass the state and control values at sensor locations to the model
            x_pred = model(x_k, u_k_at_sensors, x_grid_sensors).squeeze(-1)
            
            loss = criterion(x_pred, x_k_plus_1)
            loss.backward(); optimizer.step()
            epoch_train_loss += loss.item()
        train_losses.append(epoch_train_loss / len(train_loader))

        model.eval(); epoch_val_loss = 0.0
        with torch.no_grad():
            for x_k, u_k_at_sensors, x_k_plus_1 in val_loader:
                x_k, u_k_at_sensors, x_k_plus_1 = x_k.to(DEVICE), u_k_at_sensors.to(DEVICE), x_k_plus_1.to(DEVICE)
                
                # The model call is the same, but the data in u_k_at_sensors is different
                x_pred = model(x_k, u_k_at_sensors, x_grid_sensors).squeeze(-1)
                
                epoch_val_loss += criterion(x_pred, x_k_plus_1).item()
        val_losses.append(epoch_val_loss / len(val_loader))

        print(f"Epoch {epoch+1}/{config['EPOCHS']} | Train Loss: {train_losses[-1]:.4e} | Val Loss: {val_losses[-1]:.4e}")
        scheduler.step(val_losses[-1])

        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---"); break
    
    print(f"\nTraining finished. Best model for run '{args.run_id}' saved to {best_model_path}")
    
    # --- 5. Save Outputs ---
    with open(os.path.join(output_dir, 'hyperparams.yaml'), 'w') as f: yaml.dump(vars(args), f)
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss'); plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.yscale('log'); plt.title(f'Loss Curve ({args.run_id})')
    plt.legend(); plt.grid(True)
    plt.savefig(os.path.join(output_dir, "loss_curve.png"))

if __name__ == "__main__":
    # Every flag is mandatory, so each run records an explicit and complete set
    # of hyperparameters. Flags are registered in the listed order, which fixes
    # the order shown in --help.
    parser = argparse.ArgumentParser(description="Train the Propagator DeepONet for Burgers' Equation with Direct Control.")
    for _flag, _opts in (
        ("--config_path", {"type": str}),
        ("--output_base_dir", {"type": str}),
        ("--run_id", {"type": str}),
        ("--learning_rate", {"type": float}),
        ("--optimizer", {"type": str, "choices": ['adam', 'adamw']}),
        ("--latent_dim", {"type": int}),
        ("--branch_depth", {"type": int}),
        ("--branch_width", {"type": int}),
        ("--trunk_depth", {"type": int}),
        ("--trunk_width", {"type": int}),
        ("--activation_fn", {"type": str, "choices": ['relu', 'tanh']}),
    ):
        parser.add_argument(_flag, required=True, **_opts)
    args = parser.parse_args()
    main(args)