# gnot_stepper.py
# =============================================================================
#       GNOT-inspired Time-Stepping Neural Operator (Adapted for Task)
# =============================================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import pickle
import argparse

# ---------------------
# Seed every random number generator used by this script (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices) with one shared value,
# and ask cuDNN for deterministic kernels, so repeated runs are comparable.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# ---------------------

# Start-up banner with the wall-clock time at which the module was loaded.
print("GNOT-Stepper Script (Task Adapted) started at: " + time.strftime('%Y-%m-%d %H:%M:%S'))

# =============================================================================
# 0. UniversalPDEDataset (identical to the dataset class in the BENO script)
#    Sharing this class ensures both models see data preprocessed in exactly the same way.
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of PDE trajectories with per-sample normalization.

    Every entry of ``data_list`` is a mapping describing one trajectory. It
    holds the state arrays named by ``state_keys`` (indexed ``[time, space]``),
    a ``'BC_State'`` array and, optionally, a ``'BC_Control'`` array (both
    indexed ``[time, feature]``). Layout information (``nt``, ``nx``, feature
    counts) is read from the first sample only.

    Args:
        data_list: Non-empty sequence of trajectory mappings.
        dataset_type: Name of the PDE family (case-insensitive). Selects which
            state keys are read.
        train_nt_limit: If given, only the first ``train_nt_limit`` time steps
            of every sequence are returned; otherwise the full sequence is.

    Raises:
        ValueError: If ``data_list`` is empty or ``dataset_type`` is unknown.
    """

    # Dataset types whose state is the single field 'U'.
    _SINGLE_FIELD_TYPES = (
        'advection', 'burgers', 'heat_delayed_feedback',
        'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff',
    )

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type.lower()
        self.train_nt_limit = train_nt_limit

        first_sample = data_list[0]
        self.nt_from_sample_file, self.nx_from_sample_file, self.ny_from_sample_file = 0, 0, 1

        if self.dataset_type in self._SINGLE_FIELD_TYPES:
            self.nt_from_sample_file = first_sample['U'].shape[0]
            self.nx_from_sample_file = first_sample['U'].shape[1]
            self.state_keys, self.num_state_vars, self.expected_bc_state_dim = ['U'], 1, 2
        elif self.dataset_type == 'euler':
            # Take the first two axes (time, space); unpacking ``shape[0]``
            # alone would try to unpack a plain int and raise TypeError.
            self.nt_from_sample_file, self.nx_from_sample_file = first_sample['rho'].shape[:2]
            self.state_keys, self.num_state_vars, self.expected_bc_state_dim = ['rho', 'u'], 2, 4
        # ... (Add other dataset types if necessary)
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

        self.effective_nt_for_loader = (
            self.train_nt_limit if self.train_nt_limit is not None else self.nt_from_sample_file
        )
        self.nx, self.ny = self.nx_from_sample_file, self.ny_from_sample_file

        self.bc_state_key = 'BC_State'
        self.bc_state_dim = first_sample[self.bc_state_key].shape[1]

        self.bc_control_key = 'BC_Control'
        first_control = first_sample.get(self.bc_control_key)
        if first_control is not None and first_control.size > 0:
            self.num_controls = first_control.shape[1]
        else:
            self.num_controls = 0

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.data_list)

    @staticmethod
    def _normalize_columns(seq, num_cols):
        """Standardize the first ``num_cols`` columns of ``seq`` independently.

        Columns whose standard deviation is at most 1e-8 are only centered and
        their recorded std stays at 1 so that de-normalization is still valid.

        Returns:
            Tuple ``(normalized, means, stds)`` where ``normalized`` is a
            float32 array shaped like ``seq``.
        """
        normalized = np.zeros_like(seq, dtype=np.float32)
        stds = np.ones(num_cols)
        if seq.size == 0:
            return normalized, np.zeros(num_cols), stds

        means = np.mean(seq, axis=0)
        for k_dim in range(num_cols):
            col = seq[:, k_dim]
            mean_k, std_k = np.mean(col), np.std(col)
            if std_k > 1e-8:
                normalized[:, k_dim] = (col - mean_k) / std_k
                stds[k_dim] = std_k
            else:
                normalized[:, k_dim] = col - mean_k
        return normalized, means, stds

    def __getitem__(self, idx):
        """Return one trajectory, normalized with its own statistics.

        Returns:
            Tuple ``(state, bc_ctrl, norm_factors)``:

            * ``state`` -- a float tensor when there is one state variable,
              otherwise a list with one float tensor per state key.
            * ``bc_ctrl`` -- float tensor with the normalized boundary-state
              columns followed by the normalized control columns.
            * ``norm_factors`` -- dict with the means/stds that were used, for
              de-normalizing predictions later.
        """
        sample = self.data_list[idx]
        norm_factors = {}
        current_nt_for_item = self.effective_nt_for_loader

        # Each state field is standardized with a single scalar mean/std
        # computed over the whole (time-truncated) field.
        state_tensors_norm_list = []
        for key in self.state_keys:
            state_seq = sample[key][:current_nt_for_item, ...]
            state_mean, state_std = np.mean(state_seq), np.std(state_seq) + 1e-8
            state_norm = (state_seq - state_mean) / state_std
            state_tensors_norm_list.append(torch.tensor(state_norm).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        bc_state_seq = sample[self.bc_state_key][:current_nt_for_item, :]
        bc_state_norm, bc_means, bc_stds = self._normalize_columns(bc_state_seq, self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = bc_means
        norm_factors[f'{self.bc_state_key}_stds'] = bc_stds
        bc_state_tensor_norm = torch.tensor(bc_state_norm).float()

        if self.num_controls > 0:
            bc_control_seq = sample[self.bc_control_key][:current_nt_for_item, :]
            bc_control_norm, ctrl_means, ctrl_stds = self._normalize_columns(
                bc_control_seq, self.num_controls
            )
            norm_factors[f'{self.bc_control_key}_means'] = ctrl_means
            norm_factors[f'{self.bc_control_key}_stds'] = ctrl_stds
            bc_control_tensor_norm = torch.tensor(bc_control_norm).float()
        else:
            # Match the real number of time steps: the slice above may be
            # shorter than ``current_nt_for_item`` when the limit exceeds the
            # data, and ``torch.cat`` needs equal leading dimensions.
            bc_control_tensor_norm = torch.empty(
                (bc_state_tensor_norm.shape[0], 0), dtype=torch.float32
            )

        bc_ctrl_tensor_norm = torch.cat((bc_state_tensor_norm, bc_control_tensor_norm), dim=-1)
        output_state_tensors = (
            state_tensors_norm_list[0] if self.num_state_vars == 1 else state_tensors_norm_list
        )
        return output_state_tensors, bc_ctrl_tensor_norm, norm_factors

# =============================================================================
# 1. GNOT-Stepper Architecture Components
# =============================================================================

class MLP(nn.Module):
    """A simple Multi-Layer Perceptron.

    Builds ``Linear -> activation`` pairs for every entry of ``hidden_dims``
    followed by a final ``Linear`` layer without activation.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dims: Widths of the hidden layers, in order. An empty sequence
            yields a single linear layer.
        activation: Zero-argument callable (typically an ``nn.Module`` class)
            producing the activation placed after each hidden layer.
    """

    # The default is an immutable tuple rather than a list so that it can
    # never be shared and mutated across calls.
    def __init__(self, input_dim, output_dim, hidden_dims=(), activation=nn.GELU):
        super().__init__()
        layers = []
        current_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(activation())
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)

class LinearAttention(nn.Module):
    """
    Multi-head linear attention.

    Queries and keys are passed through a softmax over the per-head feature
    dimension; keys and values are first contracted over the key/value
    sequence into a small per-head context matrix, which the queries then
    read from. Passing the same tensor twice gives self-attention, passing
    two different tensors gives cross-attention.
    """
    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = embed_dim // n_head
        assert embed_dim == n_head * self.head_dim, "embed_dim must be divisible by n_head"

        # Creation order is kept as-is so seeded initialization is unchanged.
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(0.1)

    def _split_heads(self, projected, batch, length):
        """Reshape (batch, length, embed) into (batch, head, length, head_dim)."""
        per_head = projected.view(batch, length, self.n_head, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x_query, x_kv):
        """Attend from ``x_query`` to ``x_kv``; both are (batch, length, embed)."""
        batch, n_query, channels = x_query.shape
        _, n_kv, _ = x_kv.shape

        queries = self._split_heads(self.query(x_query), batch, n_query)
        keys = self._split_heads(self.key(x_kv), batch, n_kv)
        values = self._split_heads(self.value(x_kv), batch, n_kv)

        # Softmax over the feature dimension of each head for both q and k.
        queries = queries.softmax(dim=-1)
        keys = keys.softmax(dim=-1)

        # Sum over key/value positions first: (head_dim x head_dim) per head.
        context = torch.einsum('bhjd,bhje->bhde', keys, values)
        # Each query position then reads from that shared context.
        attended = torch.einsum('bhid,bhde->bhie', queries, context)

        # Merge the heads back together and apply the output projection.
        merged = attended.transpose(1, 2).reshape(batch, n_query, channels)
        return self.attn_drop(self.out_proj(merged))

class GNOTAttentionBlock(nn.Module):
    """
    The core block of the GNOT architecture, adapted for time-stepping.

    Each block applies, with residual connections, in this order:
    cross-attention -> geometrically gated mixture-of-experts FFN ->
    self-attention.

    Args:
        embed_dim: Width of the token embeddings.
        n_head: Number of attention heads.
        n_experts: Number of expert MLPs mixed by the gating network.
        space_dim: Size of the last dimension of the coordinates fed to the
            gating network.
    """
    def __init__(self, embed_dim, n_head, n_experts, space_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_experts = n_experts

        # Attention Layers
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.cross_attention = LinearAttention(embed_dim, n_head)

        self.ln3 = nn.LayerNorm(embed_dim)
        self.self_attention = LinearAttention(embed_dim, n_head)

        # Geometric Gating (Mixture of Experts)
        self.gating_network = MLP(space_dim, n_experts, [embed_dim, embed_dim])
        self.expert_mlps = nn.ModuleList(
            [MLP(embed_dim, embed_dim, [embed_dim * 2, embed_dim]) for _ in range(n_experts)]
        )

    def forward(self, query_embed, input_embed, coords):
        """Run one block.

        Args:
            query_embed: Embeddings of the query points (spatial grid).
            input_embed: Embeddings of the input function (previous state u_t).
            coords: Geometric coordinates of the query points.

        Returns:
            Updated query embeddings, same shape as ``query_embed``.
        """
        # 1. Cross-Attention: Query points attend to the previous state to gather info
        x = query_embed + self.cross_attention(self.ln1(query_embed), self.ln2(input_embed))

        # 2. Geometric Gating (MoE FFN): the expert weights depend only on the
        #    coordinates, not on the features.
        gate_scores = F.softmax(self.gating_network(coords), dim=-1)  # (B, N, n_experts)
        expert_outputs = torch.stack(
            [expert(x) for expert in self.expert_mlps], dim=-1
        )  # (B, N, C, n_experts)
        ffn_out = torch.einsum('b n c e, b n e -> b n c', expert_outputs, gate_scores)
        x = x + ffn_out

        # 3. Self-Attention: Query points attend to each other to propagate info spatially.
        #    The normalized tensor serves as both query and key/value, so it is
        #    computed once instead of running the same LayerNorm twice.
        x_normed = self.ln3(x)
        x = x + self.self_attention(x_normed, x_normed)

        return x

class GNOTStepper(nn.Module):
    """
    GNOT-inspired neural operator that advances a 1D state by one time step.

    The previous state (concatenated with the grid coordinates) is embedded as
    the input function, the grid coordinates plus the boundary/control
    embedding form the queries, and a stack of ``GNOTAttentionBlock`` layers
    followed by an MLP decoder produces the next state.
    """
    def __init__(self, nx, num_state_vars, bc_ctrl_dim, state_keys,
                 embed_dim=128, n_layers=4, n_head=8, n_experts=4):
        super().__init__()
        self.nx = nx
        self.num_state_vars = num_state_vars
        self.state_keys = state_keys
        self.embed_dim = embed_dim
        self.space_dim = 1  # Assuming 1D spatial domain

        encoder_hidden = [embed_dim, embed_dim]

        # Input encoders (construction order kept so seeded init is unchanged).
        self.coord_encoder = MLP(self.space_dim, embed_dim, encoder_hidden)
        self.input_encoder = MLP(num_state_vars + self.space_dim, embed_dim, encoder_hidden)
        self.bc_encoder = MLP(bc_ctrl_dim, embed_dim, encoder_hidden)

        # Stack of attention blocks.
        block_list = []
        for _ in range(n_layers):
            block_list.append(GNOTAttentionBlock(embed_dim, n_head, n_experts, self.space_dim))
        self.blocks = nn.ModuleList(block_list)

        # Maps the final embeddings back to the state variables.
        self.decoder = MLP(embed_dim, num_state_vars, [embed_dim * 2, embed_dim])

        # Uniform query grid on [0, 1], stored as a (1, nx, 1) buffer.
        query_points = torch.linspace(0, 1, nx).view(1, nx, 1)
        self.register_buffer('grid', query_points)

    def forward(self, u_t, bc_ctrl_t):
        """Predict the next state from ``u_t`` and the boundary/control vector.

        ``u_t`` is concatenated with a (batch, nx, 1) grid along the last
        axis, and ``bc_ctrl_t`` is embedded once per batch element and shared
        by every grid point.
        """
        batch_size = u_t.shape[0]

        # One copy of the grid per batch element (a broadcast view suffices).
        coords = self.grid.expand(batch_size, -1, -1).to(u_t.device)

        # Input function tokens: previous state together with its coordinates.
        state_tokens = self.input_encoder(torch.cat([u_t, coords], dim=-1))

        # Query tokens: coordinate embedding plus the boundary embedding,
        # which is shared across all nx query points.
        bc_tokens = self.bc_encoder(bc_ctrl_t).unsqueeze(1).expand(-1, self.nx, -1)
        hidden = self.coord_encoder(coords) + bc_tokens

        for layer in self.blocks:
            hidden = layer(hidden, state_tokens, coords)

        return self.decoder(hidden)

# =============================================================================
# 2. Training and Validation Functions (Adapted for GNOT Stepper)
#    These are almost identical to the corresponding functions in the BENO script.
# =============================================================================
def train_gnot_stepper(model, data_loader, dataset_type, train_nt_for_model,
                       lr=1e-3, num_epochs=50, device='cuda',
                       checkpoint_path='gnot_ckpt.pt', clip_grad_norm=1.0):
    """Train ``model`` on one-step-ahead prediction.

    For every batch the model is asked to predict the state at ``t + 1`` from
    the true state and boundary/control vector at ``t`` for each
    ``t`` in ``range(train_nt_for_model - 1)``; the per-step MSE losses are
    averaged and back-propagated once per batch. The checkpoint with the
    lowest average epoch loss is written to ``checkpoint_path`` and reloaded
    before returning. If a checkpoint already exists, training resumes from it.

    Args:
        model: Callable module taking ``(u_t, bc_ctrl_t)``.
        data_loader: Iterable of ``(state, bc_ctrl, norm_factors)`` batches.
            ``state`` is either a tensor or a list of tensors (one per state
            variable) which is stacked along a new last axis.
        dataset_type: Unused here; kept for interface compatibility.
        train_nt_for_model: Number of time steps used for training (>= 2).
        lr: Initial AdamW learning rate.
        num_epochs: Total number of epochs (including already completed ones
            when resuming).
        device: Device to train on.
        checkpoint_path: Where the best checkpoint is stored / resumed from.
        clip_grad_norm: Max gradient norm; values <= 0 disable clipping.

    Returns:
        The trained model with the best checkpoint weights loaded.

    Raises:
        ValueError: If ``train_nt_for_model`` is smaller than 2.
    """
    if train_nt_for_model < 2:
        # With fewer than two time steps there is no (t, t+1) pair to learn
        # from, and the loss average below would divide by zero.
        raise ValueError(
            f"train_nt_for_model must be at least 2, got {train_nt_for_model}"
        )
    num_steps = train_nt_for_model - 1

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # ``verbose`` is deprecated in PyTorch, so learning-rate changes are
    # reported explicitly after ``scheduler.step`` instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    mse_loss = nn.MSELoss()

    start_epoch, best_loss = 0, float('inf')
    if os.path.exists(checkpoint_path):
        print(f"Loading GNOT checkpoint from {checkpoint_path}...")
        # Best-effort resume: any failure (e.g. an incompatible checkpoint)
        # falls back to training from the current weights.
        try:
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            start_epoch = ckpt.get('epoch', 0) + 1
            best_loss = ckpt.get('loss', float('inf'))
            print(f"Resuming GNOT training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading GNOT checkpoint: {e}. Starting fresh.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        batch_start_time = time.time()

        for i, (state_data, bc_ctrl_data, _) in enumerate(data_loader):
            # Bring the state into a trailing-channel layout: a list of
            # per-variable tensors is stacked, a single tensor gets a new
            # last axis of size 1.
            if isinstance(state_data, list):
                state_seq = torch.stack(state_data, dim=-1).to(device)
            else:
                state_seq = state_data.unsqueeze(-1).to(device)
            bc_ctrl_seq = bc_ctrl_data.to(device)

            optimizer.zero_grad()
            total_seq_loss = 0

            # The true state at t is always the input, so prediction errors
            # do not accumulate during training.
            for t in range(num_steps):
                u_n = state_seq[:, t, :, :]
                bc_ctrl_n = bc_ctrl_seq[:, t, :]
                u_np1_true = state_seq[:, t + 1, :, :]

                u_np1_pred = model(u_n, bc_ctrl_n)

                step_loss = mse_loss(u_np1_pred, u_np1_true)
                total_seq_loss += step_loss

            batch_loss = total_seq_loss / num_steps
            epoch_loss += batch_loss.item()
            num_batches += 1

            batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            if (i + 1) % 50 == 0:
                elapsed = time.time() - batch_start_time
                print(f" GNOT Ep {epoch+1} B {i+1}/{len(data_loader)}, Loss {batch_loss.item():.3e}, Time/50B {elapsed:.2f}s")
                batch_start_time = time.time()

        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        print(f"GNOT Epoch {epoch+1}/{num_epochs} Avg Loss: {avg_epoch_loss:.6f}")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_epoch_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"GNOT Epoch {epoch+1}: reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            print(f"Saving GNOT checkpoint with loss {best_loss:.6f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss
            }, checkpoint_path)

    print("GNOT Training finished.")
    if os.path.exists(checkpoint_path):
        print("Loading best GNOT model for validation.")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
    return model

def validate_gnot_stepper(model, data_loader, dataset_type,
                          T_value_for_model_training, full_T_in_datafile, full_nt_in_datafile,
                          dataset_params_for_plot, device='cuda',
                          save_fig_path_prefix='gnot_result'):
    """Roll the model out autoregressively on one validation sample.

    Only the first sample of the first batch of ``data_loader`` is used.
    Starting from its true initial state, the model is stepped forward using
    its own predictions as input, once for each of three time horizons: the
    training horizon, the midpoint between training horizon and the full
    duration, and the full duration. For each horizon the de-normalized MSE
    and relative Frobenius error are printed per state variable and a
    truth / prediction / absolute-error figure is saved.

    Args:
        model: Trained stepper exposing ``state_keys`` and ``num_state_vars``.
        data_loader: Loader yielding ``(state, bc_ctrl, norm_factors)``.
        dataset_type: Name used in the figure title.
        T_value_for_model_training: Physical time covered during training.
        full_T_in_datafile: Physical time covered by a full sequence.
        full_nt_in_datafile: Number of time steps of a full sequence.
        dataset_params_for_plot: Dict; ``'L'`` (default 1.0) sets the spatial
            extent of the plots.
        device: Device on which the rollout is executed.
        save_fig_path_prefix: Path prefix for the saved figures.

    Returns:
        None. Results are printed and written as PNG files.
    """
    # Make sure the weights live on the requested device even when this
    # function is called without a preceding training run.
    model.to(device)
    model.eval()
    state_keys_val = model.state_keys
    num_state_vars_val = model.num_state_vars

    # This multi-horizon testing logic is critical for fair comparison
    test_horizons_T_values = [
        T_value_for_model_training,
        T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training),
        full_T_in_datafile
    ]
    test_horizons_T_values = sorted(set(test_horizons_T_values))
    print(f"GNOT Validation for T_horizons: {test_horizons_T_values}")

    try:
        state_data_full_loaded, BC_Ctrl_tensor_full_loaded, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("Validation data_loader is empty. Skipping validation.")
        return

    with torch.no_grad():
        # Keep only the first sample of the batch, with a trailing channel axis.
        if isinstance(state_data_full_loaded, list):
            state_seq_true_norm_full = torch.stack(state_data_full_loaded, dim=-1)[0].to(device)
        else:
            state_seq_true_norm_full = state_data_full_loaded.unsqueeze(-1)[0].to(device)

        BC_Ctrl_seq_norm_full = BC_Ctrl_tensor_full_loaded[0].to(device)
        _, nx_plot, _ = state_seq_true_norm_full.shape

        # Number of time steps actually present in the loaded sample; the
        # rollout can never be longer than this.
        nt_available = min(state_seq_true_norm_full.shape[0], BC_Ctrl_seq_norm_full.shape[0])

        norm_factors_sample = {key: val[0].cpu().numpy() for key, val in norm_factors_batch.items()}

        # This is the single, consistent initial condition for all rollouts
        u_initial_norm = state_seq_true_norm_full[0:1, :, :]

        for T_horizon_current in test_horizons_T_values:
            nt_for_rollout = int((T_horizon_current / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
            nt_for_rollout = min(nt_for_rollout, full_nt_in_datafile)
            if nt_for_rollout > nt_available:
                # Without this clamp the boundary slices would become empty
                # and prediction / ground truth would differ in length.
                print(f"  Warning: requested nt = {nt_for_rollout} exceeds the {nt_available} "
                      f"time steps available in the sample; truncating rollout.")
                nt_for_rollout = nt_available
            print(f"\n  Rollout for T_horizon = {T_horizon_current:.2f} (nt = {nt_for_rollout})")

            u_pred_seq_norm_horizon = torch.zeros(nt_for_rollout, nx_plot, num_state_vars_val, device=device)
            u_current_pred_step = u_initial_norm.clone()
            u_pred_seq_norm_horizon[0,:,:] = u_current_pred_step.squeeze(0)

            # Autoregressive rollout: each prediction is the next input, while
            # the boundary/control vector is taken from the recorded sequence.
            for t_step in range(nt_for_rollout - 1):
                bc_ctrl_n_step = BC_Ctrl_seq_norm_full[t_step:t_step+1, :]
                u_next_pred_norm_step = model(u_current_pred_step, bc_ctrl_n_step)
                u_pred_seq_norm_horizon[t_step+1,:,:] = u_next_pred_norm_step.squeeze(0)
                u_current_pred_step = u_next_pred_norm_step

            # Denormalize & Calculate Metrics
            U_pred_denorm_h, U_gt_denorm_h = {}, {}
            state_true_norm_sliced_h = state_seq_true_norm_full[:nt_for_rollout,:,:]
            pred_np_h, gt_np_h = u_pred_seq_norm_horizon.cpu().numpy(), state_true_norm_sliced_h.cpu().numpy()

            for k_idx, key_val in enumerate(state_keys_val):
                mean_k = norm_factors_sample[f'{key_val}_mean']
                std_k = norm_factors_sample[f'{key_val}_std']
                pred_denorm_v = pred_np_h[:,:,k_idx] * std_k + mean_k
                gt_denorm_v = gt_np_h[:,:,k_idx] * std_k + mean_k
                U_pred_denorm_h[key_val], U_gt_denorm_h[key_val] = pred_denorm_v, gt_denorm_v

                mse_k_h = np.mean((pred_denorm_v - gt_denorm_v)**2)
                # Small epsilon keeps the ratio finite for an all-zero truth.
                rel_err_k_h = np.linalg.norm(pred_denorm_v - gt_denorm_v, 'fro') / (np.linalg.norm(gt_denorm_v, 'fro') + 1e-10)
                print(f"    Metrics '{key_val}' @ T={T_horizon_current:.2f}: MSE={mse_k_h:.3e}, RelErr={rel_err_k_h:.4f}")

            # Visualization: one row per state variable, columns are
            # truth / prediction / absolute error.
            fig, axs = plt.subplots(num_state_vars_val, 3, figsize=(18, 5*num_state_vars_val), squeeze=False)
            for k_idx, key_val in enumerate(state_keys_val):
                gt_p, pred_p = U_gt_denorm_h[key_val], U_pred_denorm_h[key_val]
                diff_p = np.abs(pred_p - gt_p)
                # Shared color range so truth and prediction are comparable.
                v_min, v_max = min(np.min(gt_p), np.min(pred_p)), max(np.max(gt_p), np.max(pred_p))
                plot_ext = [0, dataset_params_for_plot.get('L',1.0), 0, T_horizon_current]

                im0 = axs[k_idx,0].imshow(gt_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
                axs[k_idx,0].set_title(f"Truth ({key_val})")
                im1=axs[k_idx,1].imshow(pred_p, aspect='auto', origin='lower',vmin=v_min,vmax=v_max,extent=plot_ext,cmap='viridis')
                axs[k_idx,1].set_title(f"GNOT Pred ({key_val})") # GNOT Title
                im2=axs[k_idx,2].imshow(diff_p, aspect='auto', origin='lower', extent=plot_ext, cmap='magma')
                axs[k_idx,2].set_title(f"Abs Error (Max:{np.max(diff_p):.2e})")
                plt.colorbar(im0,ax=axs[k_idx,0]); plt.colorbar(im1,ax=axs[k_idx,1]); plt.colorbar(im2,ax=axs[k_idx,2])

            fig.suptitle(f"GNOT Validation ({dataset_type.capitalize()}) @ T={T_horizon_current:.2f}") # GNOT Title
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            curr_fig_path = save_fig_path_prefix+f"_T{str(T_horizon_current).replace('.','p')}.png"
            plt.savefig(curr_fig_path)
            print(f"  Saved GNOT validation plot to {curr_fig_path}")
            plt.close(fig)

# =============================================================================
# 3. Main Execution Block
# =============================================================================
if __name__ == '__main__':
    # Command-line entry point: load a pickled dataset, train the GNOT
    # stepper on the first part of every trajectory, then validate it on
    # longer horizons.
    parser = argparse.ArgumentParser(description="Train a GNOT-inspired stepper model.")
    parser.add_argument('--datatype', type=str, required=True, 
                        choices=['advection', 'euler', 'burgers', 'darcy','heat_delayed_feedback',
                                 'reaction_diffusion_neumann_feedback','heat_nonlinear_feedback_gain','convdiff'])
    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    MODEL_TYPE = 'GNOT_Stepper'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"--- Running {MODEL_TYPE} Baseline for {DATASET_TYPE} on {device} ---")

    # --- Model and Training Hyperparameters ---
    # Inspired by GNOT paper Table 4
    EMBED_DIM = 128
    N_LAYERS = 4  # Number of stacked GNOT attention blocks
    NHEAD = 8
    N_EXPERTS = 4 # Key GNOT parameter for geometric gating
    
    LEARNING_RATE = 5e-4
    BATCH_SIZE = 16
    NUM_EPOCHS = 150
    CLIP_GRAD_NORM = 1.0
    
    # --- Dataset and Time Horizon Parameters ---
    # NOTE(review): the fallback dataset file name below advertises 600 time
    # steps while FULL_NT_IN_DATAFILE is 300 -- confirm these values for the
    # older datasets.
    FULL_T_IN_DATAFILE = 2.0
    FULL_NT_IN_DATAFILE = 300
    TRAIN_T_TARGET = 1.5 # CRITICAL: Train only on partial data
    TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    # Dataset paths and parameters
    dataset_params_for_plot = {}
    if DATASET_TYPE in ['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff']:
        dataset_path = f"./datasets_new_feedback/{DATASET_TYPE}_v1_5000s_64nx_300nt.pkl"
        main_state_keys, main_num_state_vars = ['U'], 1
        dataset_params_for_plot = {'nx': 64, 'ny': 1, 'L': 1.0}
    else: # Fallback for older datasets
        dataset_path = f"./datasets_full/{DATASET_TYPE}_data_10000s_128nx_600nt.pkl"
        main_state_keys = ['rho','u'] if DATASET_TYPE == 'euler' else ['U']
        main_num_state_vars = 2 if DATASET_TYPE == 'euler' else 1
        dataset_params_for_plot = {'nx': 128, 'ny': 1, 'L': 1.0}
        if DATASET_TYPE not in ['advection', 'burgers', 'euler']:
             print(f"Warning: Using default path for {DATASET_TYPE}. Please verify.")

    print(f"Dataset: {DATASET_TYPE}")
    print(f"Datafile T={FULL_T_IN_DATAFILE}, nt={FULL_NT_IN_DATAFILE}")
    print(f"Training with T_duration={TRAIN_T_TARGET}, nt_steps={TRAIN_NT_FOR_MODEL}")

    # --- Data Loading ---
    print(f"Loading dataset: {dataset_path}")
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # script at dataset files from a trusted source.
    try:
        with open(dataset_path, 'rb') as f:
            data_list_all = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: File not found {dataset_path}")
        # Exit with a non-zero status so callers can detect the failure
        # (a bare ``exit()`` would report success).
        raise SystemExit(1)

    random.shuffle(data_list_all)
    n_train = int(0.8 * len(data_list_all))
    train_data_list, val_data_list = data_list_all[:n_train], data_list_all[n_train:]
    print(f"Train samples: {len(train_data_list)}, Validation samples: {len(val_data_list)}")

    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None) # Val loader gets full sequence

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    # --- Model Initialization ---
    # The model's boundary input is the concatenation of boundary-state and
    # control features, exactly as produced by the dataset.
    actual_bc_ctrl_dim = train_dataset.bc_state_dim + train_dataset.num_controls
    current_nx = train_dataset.nx

    print(f"\nInitializing {MODEL_TYPE} model...")
    print(f"  nx={current_nx}, num_state_vars={main_num_state_vars}, bc_ctrl_dim={actual_bc_ctrl_dim}")
    print(f"  Hyperparams: embed_dim={EMBED_DIM}, layers={N_LAYERS}, heads={NHEAD}, experts={N_EXPERTS}")
    
    gnot_model = GNOTStepper(
        nx=current_nx,
        num_state_vars=main_num_state_vars,
        bc_ctrl_dim=actual_bc_ctrl_dim,
        state_keys=main_state_keys,
        embed_dim=EMBED_DIM,
        n_layers=N_LAYERS,
        n_head=NHEAD,
        n_experts=N_EXPERTS
    )
    
    # --- Training and Validation ---
    run_name = f"{DATASET_TYPE}_{MODEL_TYPE}_trainT{TRAIN_T_TARGET}"
    checkpoint_dir = f"./checkpoints_{DATASET_TYPE}_{MODEL_TYPE}"
    results_dir = f"./results_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'model_{run_name}.pt')
    save_fig_path_prefix = os.path.join(results_dir, f'result_{run_name}')

    print(f"\nStarting training for {MODEL_TYPE} on {DATASET_TYPE}...")
    start_train_time = time.time()
    gnot_model = train_gnot_stepper(
        gnot_model, train_loader, dataset_type=DATASET_TYPE,
        train_nt_for_model=TRAIN_NT_FOR_MODEL,
        lr=LEARNING_RATE, num_epochs=NUM_EPOCHS, device=device,
        checkpoint_path=checkpoint_path, clip_grad_norm=CLIP_GRAD_NORM
    )
    end_train_time = time.time()
    print(f"Training took {end_train_time - start_train_time:.2f} seconds.")

    if val_data_list:
        print(f"\nStarting validation for {MODEL_TYPE} on {DATASET_TYPE}...")
        validate_gnot_stepper(
            gnot_model, val_loader, dataset_type=DATASET_TYPE,
            T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE,
            full_nt_in_datafile=FULL_NT_IN_DATAFILE,
            dataset_params_for_plot=dataset_params_for_plot, device=device,
            save_fig_path_prefix=save_fig_path_prefix
        )
    else:
        print("\nNo validation data. Skipping validation.")

    print("="*60)
    print(f"Run finished: {run_name}")