# =============================================================================
#
#    BENO-inspired Time-Stepping Neural Operator (Recommended Baseline)
#
#    - Author: Gemini (Based on user's code and discussions)
#    - Date: 2025-07-28
#    - Description: This script provides a faithful adaptation of the BENO
#      architecture for time-stepping PDE problems. It serves as a strong,
#      fair baseline for comparison against other models like ROMs.
#
# =============================================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import pickle
import argparse

# ---------------------
# Seed every random number generator used by the script (reproducibility)
# ---------------------
seed = 42
# Python, NumPy and PyTorch each keep their own generator; seed them all.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True
# ---------------------

_started_at = time.strftime('%Y-%m-%d %H:%M:%S')
print(f"BENO-Stepper Baseline Script started at: {_started_at}")

# =============================================================================
# 0. 通用化数据集加载器 (与您之前的版本一致)
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of 1D PDE trajectories with per-sample z-score normalisation.

    Every entry of ``data_list`` is a mapping that provides

    * the state field(s): ``'U'``, or ``'rho'`` and ``'u'`` for legacy Euler
      data, each indexed as ``[time, space]``;
    * ``'BC_State'``: boundary state values, indexed as ``[time, feature]``;
    * optionally ``'BC_Control'``: control signals, indexed as
      ``[time, feature]``.

    The layout (state keys, grid size, number of boundary/control features)
    is inferred from the first sample only and is assumed to hold for all
    other samples.

    Args:
        data_list: Non-empty list of sample mappings.
        dataset_type: Dataset name; stored lower-cased in ``dataset_type``.
        train_nt_limit: If given, only the first ``train_nt_limit`` time
            steps of every sequence are returned; ``None`` returns the full
            sequences.

    Raises:
        ValueError: If ``data_list`` is empty or the first sample contains
            neither ``'U'`` nor ``'rho'``.
        KeyError: If the first sample has no ``'BC_State'`` entry.
    """

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type.lower()
        self.train_nt_limit = train_nt_limit

        first_sample = data_list[0]

        self.nt_from_sample_file = 0
        self.nx_from_sample_file = 0

        # Default layout: a single scalar field 'U' with two boundary values.
        self.state_keys = ['U']
        self.num_state_vars = 1
        self.expected_bc_state_dim = 2

        if 'U' in first_sample:
            self.nt_from_sample_file = first_sample['U'].shape[0]
            self.nx_from_sample_file = first_sample['U'].shape[1]
        elif 'rho' in first_sample:  # backward compatibility with old Euler data
            self.nt_from_sample_file = first_sample['rho'].shape[0]
            self.nx_from_sample_file = first_sample['rho'].shape[1]
            self.state_keys = ['rho', 'u']
            self.num_state_vars = 2
            self.expected_bc_state_dim = 4
        else:
            raise ValueError("Cannot determine state variables from first sample.")

        if self.train_nt_limit is not None:
            self.effective_nt_for_loader = self.train_nt_limit
        else:
            self.effective_nt_for_loader = self.nt_from_sample_file

        self.nx = self.nx_from_sample_file
        self.ny = 1  # every problem is assumed to be 1D in space
        self.spatial_dim = self.nx * self.ny

        self.bc_state_key = 'BC_State'
        if self.bc_state_key not in first_sample:
            raise KeyError(f"'{self.bc_state_key}' not found in the first sample!")
        self.bc_state_dim = first_sample[self.bc_state_key].shape[1]

        self.bc_control_key = 'BC_Control'
        first_control = first_sample.get(self.bc_control_key) if self.bc_control_key in first_sample else None
        if first_control is not None and first_control.size > 0:
            self.num_controls = first_control.shape[1]
        else:
            self.num_controls = 0

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _normalize_columns(seq, num_cols):
        """Z-score the first ``num_cols`` columns of ``seq`` independently.

        A column whose standard deviation is not above 1e-8 is only
        mean-centred and its recorded std stays 1, which avoids dividing by
        (almost) zero for constant signals.

        Returns:
            Tuple ``(normalised float32 array, column means, column stds)``.
        """
        normed = np.zeros_like(seq, dtype=np.float32)
        stds = np.ones(num_cols)
        if seq.size == 0:
            return normed, np.zeros(num_cols), stds

        means = np.mean(seq, axis=0)
        for k in range(num_cols):
            col = seq[:, k]
            mean_k, std_k = np.mean(col), np.std(col)
            if std_k > 1e-8:
                normed[:, k] = (col - mean_k) / std_k
                stds[k] = std_k
            else:
                normed[:, k] = col - mean_k
        return normed, means, stds

    def __getitem__(self, idx):
        """Return one normalised trajectory.

        Returns:
            Tuple ``(state, bc_ctrl, norm_factors)`` where ``state`` is a
            float tensor (or a list of tensors, one per state key, when there
            is more than one state variable), ``bc_ctrl`` is the normalised
            ``BC_State`` columns followed by the normalised ``BC_Control``
            columns, and ``norm_factors`` maps ``'<key>_mean'``/``'<key>_std'``
            and ``'<bc key>_means'``/``'<bc key>_stds'`` to the statistics
            used, so predictions can be de-normalised later.
        """
        sample = self.data_list[idx]
        nt = self.effective_nt_for_loader
        norm_factors = {}

        # State fields: one global mean/std per field over the kept window.
        state_tensors = []
        for key in self.state_keys:
            state_seq = sample[key][:nt, ...]
            state_mean = np.mean(state_seq)
            state_std = np.std(state_seq) + 1e-8
            state_tensors.append(torch.tensor((state_seq - state_mean) / state_std).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        # Boundary state: normalised column by column.
        bc_state_norm, bc_state_means, bc_state_stds = self._normalize_columns(
            sample[self.bc_state_key][:nt, :], self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = bc_state_means
        norm_factors[f'{self.bc_state_key}_stds'] = bc_state_stds
        bc_state_tensor = torch.tensor(bc_state_norm).float()

        # Boundary control: same treatment, or an empty block when absent.
        if self.num_controls > 0:
            bc_control_norm, bc_control_means, bc_control_stds = self._normalize_columns(
                sample[self.bc_control_key][:nt, :], self.num_controls)
            norm_factors[f'{self.bc_control_key}_means'] = bc_control_means
            norm_factors[f'{self.bc_control_key}_stds'] = bc_control_stds
            bc_control_tensor = torch.tensor(bc_control_norm).float()
        else:
            # Size the placeholder from the rows actually sliced: a sample
            # shorter than ``train_nt_limit`` would otherwise make the
            # concatenation below fail on mismatched row counts.
            bc_control_tensor = torch.empty((bc_state_tensor.shape[0], 0), dtype=torch.float32)

        bc_ctrl_tensor = torch.cat((bc_state_tensor, bc_control_tensor), dim=-1)
        output_state = state_tensors[0] if self.num_state_vars == 1 else state_tensors
        return output_state, bc_ctrl_tensor, norm_factors


# =============================================================================
# 1. 辅助模块 (MLP, Transformer层)
# =============================================================================
class MLP(nn.Module):
    """Plain fully connected network.

    Each hidden layer is ``Linear -> activation [-> Dropout]``; a final
    ``Linear`` maps to ``output_dim``. With no hidden layers the network is a
    single ``Linear``.

    Args:
        input_dim: Size of the last input dimension.
        output_dim: Size of the last output dimension.
        hidden_dims: Iterable of hidden layer widths; ``None`` (default) or
            an empty sequence means no hidden layer.
        activation: Zero-argument callable returning an activation module.
        dropout: Dropout probability after each hidden activation; dropout
            layers are only inserted when this is greater than zero.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=None, activation=nn.GELU, dropout=0.1):
        super().__init__()
        # ``None`` instead of a mutable ``[]`` default, which would be shared
        # between all calls.
        hidden_dims = () if hidden_dims is None else hidden_dims

        layers = []
        current_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(activation())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.net(x)

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    Self-attention and a two-layer feed-forward network, each wrapped in
    ``residual -> dropout -> LayerNorm``. The attention module is built with
    ``batch_first=True``.

    Args:
        d_model: Feature width of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention and both residual
            branches.
        activation: Callable applied between the two feed-forward layers.
    """

    def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1, activation=F.gelu):
        super().__init__()
        # Attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Run one encoder block on ``src``.

        ``src_mask`` and ``src_key_padding_mask`` are forwarded to the
        attention module; ``is_causal`` is accepted but not used.
        """
        attn_out = self.self_attn(
            src, src, src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
        )[0]
        hidden = self.norm1(src + self.dropout1(attn_out))

        ff_hidden = self.activation(self.linear1(hidden))
        ff_out = self.linear2(self.dropout(ff_hidden))
        return self.norm2(hidden + self.dropout2(ff_out))

# =============================================================================
# 2. 修正后的BENO-Stepper架构
# =============================================================================
class BoundaryEmbedder(nn.Module):
    """Encode boundary features into one global context vector per sample.

    Each boundary point is treated as a token: its features are linearly
    projected to ``d_model``, a learned positional embedding is added, the
    tokens pass through a Transformer encoder, and the result is mean-pooled
    over the boundary points and mapped to ``output_dim`` by an MLP.

    Args:
        num_bc_points: Number of boundary points (tokens).
        input_feat_dim: Number of features per boundary point.
        d_model: Transformer width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        output_dim: Size of the returned embedding.
    """

    def __init__(self, num_bc_points, input_feat_dim, d_model, nhead, num_encoder_layers, output_dim):
        super().__init__()
        self.num_bc_points = num_bc_points
        self.input_feat_dim = input_feat_dim

        self.input_proj = nn.Linear(input_feat_dim, d_model)
        # Learned positional embedding, one vector per boundary point.
        self.pos_encoder = nn.Parameter(torch.randn(1, num_bc_points, d_model) * 0.02)
        self.transformer_encoder = nn.TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_feedforward=d_model * 2),
            num_layers=num_encoder_layers,
        )
        self.output_mlp = MLP(d_model, output_dim, hidden_dims=[d_model // 2])

    def forward(self, boundary_features):
        """Map ``(batch, num_bc_points, input_feat_dim)`` features to a
        ``(batch, output_dim)`` embedding."""
        tokens = self.input_proj(boundary_features) + self.pos_encoder
        encoded = self.transformer_encoder(tokens)
        # Mean-pool over the boundary points to get one vector per sample.
        return self.output_mlp(encoded.mean(dim=1))

class GNNLikeProcessor(nn.Module):
    """Message-passing-style processor on a regular 1D grid.

    Every layer aggregates neighbouring nodes with a width-5 convolution
    (replicate padding), applies LayerNorm and GELU, concatenates the global
    boundary embedding to every node, and updates the node features with an
    MLP inside a residual connection. This is how the boundary condition
    drives the update of each node.

    Args:
        input_dim: Number of input features per grid node.
        output_dim: Number of output features per grid node.
        hidden_dim: Width of the node features inside the processor.
        num_layers: Number of processing layers.
        global_embed_dim: Size of the global boundary embedding.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_layers, global_embed_dim):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                'conv': nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2, padding_mode='replicate'),
                'norm': nn.LayerNorm(hidden_dim),
                'act': nn.GELU(),
                'node_mlp': MLP(hidden_dim + global_embed_dim, hidden_dim, [hidden_dim]),
            })
            for _ in range(num_layers)
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, global_boundary_embed):
        """Process node features ``x`` of shape ``(batch, nodes, input_dim)``
        conditioned on ``global_boundary_embed`` of shape
        ``(batch, global_embed_dim)``."""
        hidden = self.input_layer(x)
        for block in self.layers:
            # Conv1d expects channels before the node axis.
            neighbours = block['conv'](hidden.transpose(1, 2)).transpose(1, 2)
            neighbours = block['act'](block['norm'](neighbours))

            # Give every node its own copy of the global embedding.
            num_nodes = neighbours.shape[1]
            context = global_boundary_embed.unsqueeze(1).expand(-1, num_nodes, -1)

            update = block['node_mlp'](torch.cat([neighbours, context], dim=-1))
            hidden = hidden + update  # residual connection

        return self.output_layer(hidden)

class BENOStepper(nn.Module):
    """BENO-inspired one-step time stepper.

    A single forward path made of three parts: a boundary encoder producing a
    global embedding, a grid processor conditioned on that embedding, and an
    MLP decoder that predicts the state increment which is added to the
    current state.

    Args:
        nx: Number of grid points (stored, not otherwise used here).
        num_state_vars: Number of state variables per grid point.
        bc_ctrl_dim_input: Total width of the boundary/control vector; it is
            split evenly between the two boundary points with integer
            division.
        state_keys: Names of the state variables (stored for callers).
        embed_dim: Width of the boundary Transformer; the global embedding
            has size ``embed_dim // 2``.
        hidden_dim: Width of the grid processor and decoder.
        gnn_layers: Number of processor layers.
        transformer_layers: Number of boundary encoder layers.
        nhead: Number of attention heads in the boundary encoder.
    """

    def __init__(self, nx, num_state_vars, bc_ctrl_dim_input, state_keys,
                 embed_dim=128, hidden_dim=128, gnn_layers=3, transformer_layers=2, nhead=4):
        super().__init__()
        self.nx = nx
        self.num_state_vars = num_state_vars
        self.state_keys = state_keys
        self.num_bc_points = 2  # left and right end of the 1D domain

        self.ctrl_feat_per_bc_point = bc_ctrl_dim_input // self.num_bc_points
        self.bc_feat_dim = num_state_vars + self.ctrl_feat_per_bc_point
        print(f"BENO Initializing: num_state_vars={num_state_vars}, total_bc_ctrl_dim={bc_ctrl_dim_input}")
        print(f"  Derived ctrl_feat_per_bc_point={self.ctrl_feat_per_bc_point}, leading to bc_feat_dim={self.bc_feat_dim}")

        self.global_embed_dim = embed_dim // 2

        self.boundary_embedder = BoundaryEmbedder(
            num_bc_points=self.num_bc_points,
            input_feat_dim=self.bc_feat_dim,
            d_model=embed_dim,
            nhead=nhead,
            num_encoder_layers=transformer_layers,
            output_dim=self.global_embed_dim,
        )

        self.processor = GNNLikeProcessor(
            input_dim=num_state_vars,
            output_dim=hidden_dim,
            hidden_dim=hidden_dim,
            num_layers=gnn_layers,
            global_embed_dim=self.global_embed_dim,
        )

        # Decoder: maps processed node features back to state space.
        self.decoder_mlp = MLP(hidden_dim, num_state_vars, [hidden_dim, hidden_dim])

    def _extract_boundary_features(self, u_t, bc_ctrl_t):
        """Build one feature vector per boundary point.

        The state at the first/last grid node is concatenated with the
        first/second block of ``ctrl_feat_per_bc_point`` columns of
        ``bc_ctrl_t``; any columns beyond the second block are ignored.

        NOTE(review): this assumes the first block belongs to the left
        boundary and the second to the right one - confirm against the
        column order produced by the dataset.
        """
        width = self.ctrl_feat_per_bc_point
        left_token = torch.cat([u_t[:, 0, :], bc_ctrl_t[:, :width]], dim=-1)
        right_token = torch.cat([u_t[:, -1, :], bc_ctrl_t[:, width:2 * width]], dim=-1)
        return torch.stack([left_token, right_token], dim=1)

    def forward(self, u_t, bc_ctrl_t):
        """Predict the next state.

        1. Extract and encode the boundary information into a global
           embedding.
        2. Process the whole state field, driven by that embedding.
        3. Decode the state increment ``delta_u``.
        4. Return ``u_t + delta_u`` (residual update).
        """
        boundary_tokens = self._extract_boundary_features(u_t, bc_ctrl_t)
        boundary_embed = self.boundary_embedder(boundary_tokens)
        node_features = self.processor(u_t, boundary_embed)

        delta_u = self.decoder_mlp(node_features)
        return u_t + delta_u

# =============================================================================
# 4. 训练与验证函数
# =============================================================================
def train_beno_stepper(model, data_loader, train_nt_for_model,
                       lr=1e-3, num_epochs=50, device='cuda',
                       checkpoint_path='beno_ckpt.pt', clip_grad_norm=1.0):
    """Train a one-step stepper with teacher forcing and checkpoint the best epoch.

    For every batch the model is asked to predict ``u[t+1]`` from the
    ground-truth ``u[t]`` and the boundary/control vector at ``t`` for all
    ``t`` in the sequence; the per-step MSE losses are averaged and
    back-propagated once per batch.

    If ``checkpoint_path`` exists, training resumes from it. A checkpoint is
    written whenever the average training loss of an epoch improves, and the
    best checkpoint is loaded back into ``model`` before returning.

    Args:
        model: Module called as ``model(u_t, bc_ctrl_t)``; must expose
            ``state_keys`` (stored in the checkpoint).
        data_loader: Iterable of ``(state, bc_ctrl, norm_factors)`` batches.
            ``state`` is a ``(batch, time, space)`` tensor or a list of such
            tensors, one per state variable.
        train_nt_for_model: Number of time steps every batch must contain.
        lr: Initial AdamW learning rate.
        num_epochs: Total number of epochs (including already completed ones
            when resuming).
        device: Device used for the model and the data.
        checkpoint_path: File used to load/save the checkpoint.
        clip_grad_norm: Max gradient norm; clipping is skipped when <= 0.

    Returns:
        The trained ``model`` (same object).

    Raises:
        ValueError: If ``train_nt_for_model`` is smaller than 2 or a batch
            has a different number of time steps.
    """
    if train_nt_for_model < 2:
        # At least one (u[t], u[t+1]) pair is needed; otherwise the loss
        # average below would divide by zero.
        raise ValueError(f"train_nt_for_model must be >= 2, got {train_nt_for_model}")

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # ``verbose`` is deprecated on PyTorch schedulers, so learning-rate
    # reductions are reported manually after ``scheduler.step`` instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    mse_loss = nn.MSELoss(reduction='mean')
    start_epoch = 0
    best_loss = float('inf')

    # Resume from an existing checkpoint if there is one.
    if os.path.exists(checkpoint_path):
        print(f"Loading BENO checkpoint from {checkpoint_path}...")
        try:
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt['model_state_dict'])
            if 'optimizer_state_dict' in ckpt:
                try:
                    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
                except Exception as opt_err:
                    # Best effort: keep the freshly created optimizer.
                    print(f"Warning: BENO Optimizer state mismatch. ({opt_err})")
            start_epoch = ckpt.get('epoch', 0) + 1
            best_loss = ckpt.get('loss', float('inf'))
            print(f"Resuming BENO training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading BENO checkpoint: {e}. Starting fresh.")

    num_steps = train_nt_for_model - 1

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        batch_start_time = time.time()

        for i, (state_data_loaded, BC_Ctrl_tensor_loaded, _) in enumerate(data_loader):
            # Bring the state into (batch, time, space, variables) on the device.
            if isinstance(state_data_loaded, list):
                state_seq_true_train = torch.stack(state_data_loaded, dim=-1).to(device)
            else:
                state_seq_true_train = state_data_loaded.unsqueeze(-1).to(device)

            BC_Ctrl_seq_train = BC_Ctrl_tensor_loaded.to(device)
            nt_loaded = state_seq_true_train.shape[1]

            if nt_loaded != train_nt_for_model:
                raise ValueError(f"Data nt {nt_loaded} != train_nt {train_nt_for_model}")

            optimizer.zero_grad()
            total_seq_loss = 0.0

            # One-step (teacher-forced) training: every step starts from the
            # ground-truth state, not from the model's previous prediction.
            for t in range(num_steps):
                u_n_true = state_seq_true_train[:, t, :, :]
                bc_ctrl_n = BC_Ctrl_seq_train[:, t, :]
                u_np1_true = state_seq_true_train[:, t + 1, :, :]

                u_np1_pred = model(u_n_true, bc_ctrl_n)
                total_seq_loss += mse_loss(u_np1_pred, u_np1_true)

            batch_loss = total_seq_loss / num_steps
            epoch_loss += batch_loss.item()
            num_batches += 1

            batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            if (i + 1) % 50 == 0:
                elapsed = time.time() - batch_start_time
                print(f" BENO Ep {epoch+1} B {i+1}/{len(data_loader)}, Loss {batch_loss.item():.3e}, Time/50B {elapsed:.2f}s")
                batch_start_time = time.time()

        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        print(f"BENO Epoch {epoch+1}/{num_epochs} Avg Loss: {avg_epoch_loss:.6f}")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_epoch_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"BENO Epoch {epoch+1}: reducing learning rate from {lr_before:.4e} to {lr_after:.4e}.")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            print(f"Saving new best BENO checkpoint with loss {best_loss:.6f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
                'state_keys': model.state_keys
            }, checkpoint_path)

    print("BENO Training finished.")
    # Return the best model seen, not the one from the last epoch.
    if os.path.exists(checkpoint_path):
        print(f"Loading best BENO model for validation.")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
    return model

def validate_beno_stepper(model, data_loader, dataset_type,
                          train_nt_for_model_training: int, T_value_for_model_training: float,
                          full_T_in_datafile: float, full_nt_in_datafile: int,
                          dataset_params_for_plot: dict, device='cuda',
                          save_fig_path_prefix='beno_result'):
    """Roll the model out autoregressively on one validation sample and plot it.

    Only the first sample of the first batch of ``data_loader`` is used. The
    model starts from the true initial state and is fed its own predictions
    for three horizons: the training horizon, the midpoint between training
    and full horizon, and the full horizon. For each horizon the MSE and the
    relative Frobenius error are printed per state variable (on de-normalised
    data) and a figure with ground truth, prediction and absolute error is
    saved to ``save_fig_path_prefix + "_T<horizon>.png"``.

    Args:
        model: Trained stepper exposing ``state_keys`` and ``num_state_vars``.
        data_loader: Iterable of ``(state, bc_ctrl, norm_factors)`` batches.
        dataset_type: Dataset name, used in the figure title.
        train_nt_for_model_training: Unused; kept for interface compatibility.
        T_value_for_model_training: Time horizon used during training.
        full_T_in_datafile: Time horizon covered by the data file.
        full_nt_in_datafile: Number of time steps covered by the data file.
        dataset_params_for_plot: Plot parameters; ``'L'`` (domain length,
            default 1.0) is read.
        device: Device used for the rollout.
        save_fig_path_prefix: Prefix of the saved figure paths.

    Returns:
        None.
    """
    model.to(device)
    model.eval()
    state_keys_val = model.state_keys
    num_state_vars_val = model.num_state_vars

    # Time horizons to test: training horizon, midpoint, full horizon.
    test_horizons_T_values = [T_value_for_model_training]
    if full_T_in_datafile > T_value_for_model_training:
        test_horizons_T_values.append(T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training))
        test_horizons_T_values.append(full_T_in_datafile)
    test_horizons_T_values = sorted({h for h in test_horizons_T_values if h <= full_T_in_datafile + 1e-6})
    print(f"BENO Validation for T_horizons: {test_horizons_T_values}")

    with torch.no_grad():
        try:
            state_data_full_loaded, BC_Ctrl_tensor_full_loaded, norm_factors_batch = next(iter(data_loader))
        except StopIteration:
            print("Validation data_loader empty. Skipping BENO validation.")
            return

        # Keep the first sample of the batch as (time, space, variables).
        if isinstance(state_data_full_loaded, list):
            state_seq_true_norm_full = torch.stack(state_data_full_loaded, dim=-1)[0].to(device)
        else:
            state_seq_true_norm_full = state_data_full_loaded.unsqueeze(-1)[0].to(device)
        BC_Ctrl_seq_norm_full = BC_Ctrl_tensor_full_loaded[0].to(device)
        nt_available, nx_plot, _ = state_seq_true_norm_full.shape

        norm_factors_sample = {key: val[0].cpu().numpy() for key, val in norm_factors_batch.items()}
        u_initial_norm = state_seq_true_norm_full[0:1, :, :]

        for T_horizon_current in test_horizons_T_values:
            nt_for_rollout = int((T_horizon_current / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
            # Never roll out past the data that was actually loaded, so that
            # prediction and ground truth always have the same length.
            nt_for_rollout = min(nt_for_rollout, full_nt_in_datafile, nt_available)
            print(f"\n  Rollout for T_horizon = {T_horizon_current:.2f} (nt = {nt_for_rollout})")

            # Autoregressive inference: each prediction is the next input.
            u_pred_seq_norm_horizon = torch.zeros(nt_for_rollout, nx_plot, num_state_vars_val, device=device)
            u_current_pred_step = u_initial_norm.clone()
            u_pred_seq_norm_horizon[0, :, :] = u_current_pred_step.squeeze(0)
            BC_Ctrl_for_rollout = BC_Ctrl_seq_norm_full[:nt_for_rollout, :]

            for t_step in range(nt_for_rollout - 1):
                bc_ctrl_n_step = BC_Ctrl_for_rollout[t_step:t_step+1, :]
                u_next_pred_norm_step = model(u_current_pred_step, bc_ctrl_n_step)
                u_pred_seq_norm_horizon[t_step + 1, :, :] = u_next_pred_norm_step.squeeze(0)
                u_current_pred_step = u_next_pred_norm_step

            # De-normalise and compute the metrics.
            U_pred_denorm_h, U_gt_denorm_h = {}, {}
            state_true_norm_sliced_h = state_seq_true_norm_full[:nt_for_rollout, :, :]
            pred_np_h, gt_np_h = u_pred_seq_norm_horizon.cpu().numpy(), state_true_norm_sliced_h.cpu().numpy()

            for k_idx, key_val in enumerate(state_keys_val):
                mean_k, std_k = norm_factors_sample[f'{key_val}_mean'], norm_factors_sample[f'{key_val}_std']
                pred_denorm_v = pred_np_h[:, :, k_idx] * std_k + mean_k
                gt_denorm_v = gt_np_h[:, :, k_idx] * std_k + mean_k
                U_pred_denorm_h[key_val] = pred_denorm_v
                U_gt_denorm_h[key_val] = gt_denorm_v
                mse_k_h = np.mean((pred_denorm_v - gt_denorm_v)**2)
                rel_err_k_h = np.linalg.norm(pred_denorm_v - gt_denorm_v, 'fro') / (np.linalg.norm(gt_denorm_v, 'fro') + 1e-10)
                print(f"    Metrics '{key_val}' @ T={T_horizon_current:.1f}: MSE={mse_k_h:.3e}, RelErr={rel_err_k_h:.3e}")

            # Visualisation: one row per state variable.
            fig, axs = plt.subplots(num_state_vars_val, 3, figsize=(18, 5 * num_state_vars_val), squeeze=False)
            fig_L = dataset_params_for_plot.get('L', 1.0)
            # The arrays are transposed before plotting, so time runs along
            # the horizontal axis and x along the vertical one; the extent is
            # (left, right, bottom, top) and must follow that order.
            plot_ext = [0, T_horizon_current, 0, fig_L]
            for k_idx, key_val in enumerate(state_keys_val):
                gt_p, pred_p = U_gt_denorm_h[key_val], U_pred_denorm_h[key_val]
                diff_p = np.abs(pred_p - gt_p)
                max_err_p = np.max(diff_p)
                # Shared colour scale for ground truth and prediction.
                vmin_p, vmax_p = min(np.min(gt_p), np.min(pred_p)), max(np.max(gt_p), np.max(pred_p))

                ax_gt, ax_pred, ax_err = axs[k_idx, 0], axs[k_idx, 1], axs[k_idx, 2]
                im0 = ax_gt.imshow(gt_p.T, aspect='auto', origin='lower', vmin=vmin_p, vmax=vmax_p, extent=plot_ext, cmap='viridis')
                ax_gt.set_title(f"Ground Truth ({key_val})")
                im1 = ax_pred.imshow(pred_p.T, aspect='auto', origin='lower', vmin=vmin_p, vmax=vmax_p, extent=plot_ext, cmap='viridis')
                # Titles belong to the Axes; the image returned by imshow has
                # no set_title method.
                ax_pred.set_title(f"BENO Prediction ({key_val})")
                im2 = ax_err.imshow(diff_p.T, aspect='auto', origin='lower', extent=plot_ext, cmap='magma')
                ax_err.set_title(f"Absolute Error (Max:{max_err_p:.2e})")
                for ax_p, im_p in ((ax_gt, im0), (ax_pred, im1), (ax_err, im2)):
                    ax_p.set_xlabel("Time")
                    ax_p.set_ylabel("x")
                    plt.colorbar(im_p, ax=ax_p)

            fig.suptitle(f"BENO Baseline Validation ({dataset_type.capitalize()}) @ T={T_horizon_current:.1f}")
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            curr_fig_path = save_fig_path_prefix + f"_T{str(T_horizon_current).replace('.', 'p')}.png"
            fig.savefig(curr_fig_path)
            print(f"  Saved BENO validation plot to {curr_fig_path}")
            plt.show()
            # Release the figure so repeated horizons do not accumulate memory.
            plt.close(fig)

# =============================================================================
# 5. 主程序入口
# =============================================================================
if __name__ == '__main__':
    # Script entry point: load the chosen dataset, train the BENO stepper,
    # then validate it on the first held-out sample.
    parser = argparse.ArgumentParser(description="Run BENO-Stepper baseline.")
    parser.add_argument('--datatype', type=str, required=True,
                        choices=['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback',
                                 'heat_nonlinear_feedback_gain', 'convdiff'],
                        help='Type of dataset to use.')
    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    MODEL_TYPE = 'BENO_Stepper_Baseline'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"--- Running {MODEL_TYPE} for {DATASET_TYPE} on {device} ---")

    # --- Model hyper-parameters ---
    EMBED_DIM = 128
    HIDDEN_DIM = 128
    GNN_LAYERS = 4
    TRANSFORMER_LAYERS = 2
    NHEAD = 4
    LEARNING_RATE = 1e-3
    BATCH_SIZE = 32
    NUM_EPOCHS = 150
    CLIP_GRAD_NORM = 1.0

    # --- Time parameters of the dataset ---
    FULL_T_IN_DATAFILE = 2.0
    FULL_NT_IN_DATAFILE = 300
    TRAIN_T_TARGET = 1.5  # time span used for training
    TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    print(f"Dataset: {DATASET_TYPE}")
    print(f"Training with T_duration={TRAIN_T_TARGET}, nt_points={TRAIN_NT_FOR_MODEL}")

    # --- Dataset path ---
    dataset_path = f"./datasets_new_feedback/{DATASET_TYPE}_v1_5000s_64nx_300nt.pkl"
    dataset_params_for_plot = {'nx': 64, 'ny': 1, 'L': 1.0, 'T': FULL_T_IN_DATAFILE}

    # --- Load the data ---
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # dataset files that come from a trusted source.
    print(f"Loading dataset: {dataset_path}")
    try:
        with open(dataset_path, 'rb') as f:
            data_list_all = pickle.load(f)
        print(f"Loaded {len(data_list_all)} samples.")
    except FileNotFoundError:
        print(f"Error: File not found {dataset_path}")
        # ``exit()`` is an interactive helper injected by ``site`` and would
        # report success; terminate explicitly with a failure status.
        raise SystemExit(1)

    random.shuffle(data_list_all)
    n_train = int(0.8 * len(data_list_all))
    train_data_list, val_data_list = data_list_all[:n_train], data_list_all[n_train:]
    print(f"Train samples: {len(train_data_list)}, Validation samples: {len(val_data_list)}")

    # Build the datasets and loaders.
    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    # The validation set loads the full sequences.
    val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)

    # --- Build the model ---
    actual_bc_ctrl_dim = train_dataset.bc_state_dim + train_dataset.num_controls
    online_beno_model = BENOStepper(
        nx=train_dataset.nx,
        num_state_vars=train_dataset.num_state_vars,
        bc_ctrl_dim_input=actual_bc_ctrl_dim,
        state_keys=train_dataset.state_keys,
        embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM, gnn_layers=GNN_LAYERS,
        transformer_layers=TRANSFORMER_LAYERS, nhead=NHEAD
    )

    # --- Output locations ---
    run_name = f"{DATASET_TYPE}_{MODEL_TYPE}_emb{EMBED_DIM}"
    checkpoint_dir = f"./New_ckpt_BENO/checkpoints_{DATASET_TYPE}_{MODEL_TYPE}"
    results_dir = f"./result_all_BENO/results_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'model_{run_name}.pt')
    save_fig_path_prefix = os.path.join(results_dir, f'result_{run_name}')

    # --- Training ---
    print(f"\nStarting training for {MODEL_TYPE} on {DATASET_TYPE}...")
    start_train_time = time.time()
    online_beno_model = train_beno_stepper(
        online_beno_model, train_loader,
        train_nt_for_model=TRAIN_NT_FOR_MODEL,
        lr=LEARNING_RATE, num_epochs=NUM_EPOCHS, device=device,
        checkpoint_path=checkpoint_path, clip_grad_norm=CLIP_GRAD_NORM
    )
    print(f"Training took {time.time() - start_train_time:.2f} seconds.")

    # --- Validation ---
    if val_data_list:
        print(f"\nStarting validation for {MODEL_TYPE} on {DATASET_TYPE}...")
        validate_beno_stepper(
            online_beno_model, val_loader, dataset_type=DATASET_TYPE,
            train_nt_for_model_training=TRAIN_NT_FOR_MODEL,
            T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE,
            full_nt_in_datafile=FULL_NT_IN_DATAFILE,
            dataset_params_for_plot=dataset_params_for_plot, device=device,
            save_fig_path_prefix=save_fig_path_prefix
        )
    else:
        print("\nNo validation data. Skipping validation.")

    print("=" * 60)
    print(f"Run finished: {run_name}")
    print(f"Final checkpoint: {checkpoint_path}")
    if val_data_list:
        print(f"Validation figures saved with prefix: {save_fig_path_prefix}")
    print("=" * 60)