import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve
import argparse
import pickle # Ensure pickle is imported

# ---------------------
# Fix the random seeds of Python's `random`, NumPy and PyTorch (plus every
# CUDA device when CUDA is available) to the same constant value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Ask cuDNN to restrict itself to deterministic algorithms.
torch.backends.cudnn.deterministic = True
# ---------------------

# =============================================================================
# 2. 通用化数据集定义 (UniversalPDEDataset)
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of PDE trajectories with per-sample normalization.

    Every element of ``data_list`` is a dict-like sample. The state array is
    read from ``sample['U']`` and indexed as ``[time, space]``; the boundary
    state is read from ``sample['BC_State']`` and indexed as
    ``[time, bc_dim]``; an optional control signal is read from
    ``sample['BC_Control']`` and is only used when it is two-dimensional
    (``[time, num_controls]``) in the first sample.

    Layout parameters (``nx``, number of controls, BC dimension, ...) are
    inferred from the first sample only.
    """

    # Every supported dataset type currently shares the same layout: a single
    # state variable 'U' on a 1D grid with two boundary-state channels.
    _SUPPORTED_DATASET_TYPES = (
        'heat_delayed_feedback',
        'reaction_diffusion_neumann_feedback',
        'heat_nonlinear_feedback_gain',
        'convdiff',
    )

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        """
        Args:
            data_list: Non-empty list of sample dictionaries.
            dataset_type: String identifying the dataset type; must be one of
                ``_SUPPORTED_DATASET_TYPES``.
            train_nt_limit: If given, every sequence is truncated to this many
                leading time steps in ``__getitem__``.

        Raises:
            ValueError: If ``data_list`` is empty or ``dataset_type`` is unknown.
            KeyError: If the first sample has no ``'BC_State'`` entry.
        """
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type
        self.train_nt_limit = train_nt_limit

        # --- Infer layout parameters from the first sample ---
        first_sample = data_list[0]
        if dataset_type not in self._SUPPORTED_DATASET_TYPES:
            raise ValueError(f"Unknown dataset_type: {dataset_type}")

        u_shape = first_sample['U'].shape
        self.nt_from_sample = u_shape[0]
        self.nx_from_sample = u_shape[1]
        # Set for every dataset type (the 'convdiff' path used to omit it).
        self.ny_from_sample = 1
        self.state_keys = ['U']
        self.num_state_vars = 1
        self.nx = self.nx_from_sample
        self.ny = self.ny_from_sample
        self.expected_bc_state_dim = 2

        # Number of time steps actually served by __getitem__.
        self.effective_nt = self.train_nt_limit if self.train_nt_limit is not None else self.nt_from_sample
        self.spatial_dim = self.nx * self.ny

        self.bc_state_key = 'BC_State'
        if self.bc_state_key not in first_sample:
            raise KeyError(f"'{self.bc_state_key}' not found in the first sample!")
        actual_bc_state_dim = first_sample[self.bc_state_key].shape[1]

        if actual_bc_state_dim != self.expected_bc_state_dim:
            # The data wins over the expectation; only warn.
            print(f"Warning: BC_State dimension mismatch for {dataset_type}. "
                  f"Expected {self.expected_bc_state_dim}, got {actual_bc_state_dim}. "
                  f"Using actual dimension: {actual_bc_state_dim}")
            self.bc_state_dim = actual_bc_state_dim
        else:
            self.bc_state_dim = self.expected_bc_state_dim

        self.bc_control_key = 'BC_Control'
        self.num_controls = self._infer_num_controls(first_sample)

    def _infer_num_controls(self, first_sample):
        """Return the control width found in the first sample (0 if unusable)."""
        has_control = (self.bc_control_key in first_sample
                       and first_sample[self.bc_control_key] is not None)
        num_controls = 0
        if has_control:
            control = first_sample[self.bc_control_key]
            if control.ndim == 2:  # Expected shape [nt, num_controls]
                num_controls = control.shape[1]
            elif control.ndim == 1 and control.shape[0] == self.effective_nt:
                # A 1D control could mean a single control stored as [nt];
                # it is deliberately treated as "no controls" here.
                print(f"Warning: '{self.bc_control_key}' is 1D. Assuming num_controls = 0 or check data format.")

        if num_controls == 0:
            if has_control:
                control = first_sample[self.bc_control_key]
                shape_repr = control.shape if hasattr(control, 'shape') else 'N/A'
                print(f"Info: '{self.bc_control_key}' found but num_controls is 0 for sample 0. Shape: {shape_repr}")
            else:
                print(f"Info: '{self.bc_control_key}' not found or is None in the first sample. Assuming num_controls = 0.")
        return num_controls

    @staticmethod
    def _normalize_columns(seq, num_cols):
        """Z-score the first ``num_cols`` columns of ``seq`` independently.

        Columns whose standard deviation is not above 1e-8 are only centered
        and their recorded std stays 1.0.

        Returns:
            Tuple ``(normalized float32 array, means, stds)``.
        """
        normalized = np.zeros_like(seq, dtype=np.float32)
        means = np.zeros(num_cols)
        stds = np.ones(num_cols)
        for k_dim in range(num_cols):
            col = seq[:, k_dim]
            mean_k = np.mean(col)
            std_k = np.std(col)
            means[k_dim] = mean_k
            if std_k > 1e-8:
                normalized[:, k_dim] = (col - mean_k) / std_k
                stds[k_dim] = std_k
            else:
                # Constant column: center only.
                normalized[:, k_dim] = col - mean_k
        return normalized, means, stds

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return one normalized sample.

        Returns:
            Tuple ``(state_tensors, bc_ctrl_tensor, norm_factors)`` where
            ``state_tensors`` is a list with one float tensor per state key,
            ``bc_ctrl_tensor`` is the normalized BC state concatenated with the
            normalized control along the last axis, and ``norm_factors`` maps
            ``'<key>_mean'/'<key>_std'`` (states) and
            ``'<key>_means'/'<key>_stds'`` (BC state / control) to the
            statistics used.

        Raises:
            KeyError: If a state key is missing from the sample.
            ValueError: If a sequence is shorter than the effective number of
                time steps, or the control width differs from the first sample.
        """
        sample = self.data_list[idx]
        norm_factors = {}
        current_nt = self.effective_nt

        # --- State variables: one global mean/std per variable ---
        state_tensors_norm_list = []
        for key in self.state_keys:
            try:
                state_seq_full = sample[key]
            except KeyError:
                raise KeyError(f"State variable key '{key}' not found in sample {idx} for dataset type '{self.dataset_type}'")
            state_seq = state_seq_full[:current_nt, ...]

            if state_seq.shape[0] != current_nt:
                raise ValueError(f"Time dimension mismatch after potential truncation for {key}. Expected {current_nt}, got {state_seq.shape[0]}")

            state_mean = np.mean(state_seq)
            state_std = np.std(state_seq) + 1e-8  # epsilon avoids division by zero
            state_norm = (state_seq - state_mean) / state_std
            state_tensors_norm_list.append(torch.tensor(state_norm).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        # --- BC state: per-column statistics ---
        bc_state_seq = sample[self.bc_state_key][:current_nt, :]
        if bc_state_seq.shape[0] != current_nt:
            raise ValueError(f"Time dimension mismatch for BC_State. Expected {current_nt}, got {bc_state_seq.shape[0]}")

        bc_state_norm, bc_means, bc_stds = self._normalize_columns(bc_state_seq, self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = bc_means
        norm_factors[f'{self.bc_state_key}_stds'] = bc_stds
        bc_state_tensor_norm = torch.tensor(bc_state_norm).float()

        # --- BC control: per-column statistics, zeros if the sample lacks it ---
        if self.num_controls > 0:
            try:
                bc_control_seq_full = sample[self.bc_control_key]
            except KeyError:
                # Unlikely, since num_controls > 0 was derived from sample 0.
                print(f"Warning: Sample {idx} missing '{self.bc_control_key}' despite num_controls > 0. Using zeros.")
                bc_control_tensor_norm = torch.zeros((current_nt, self.num_controls), dtype=torch.float32)
                norm_factors[f'{self.bc_control_key}_means'] = np.zeros(self.num_controls)
                norm_factors[f'{self.bc_control_key}_stds'] = np.ones(self.num_controls)
            else:
                bc_control_seq = bc_control_seq_full[:current_nt, :]

                if bc_control_seq.shape[0] != current_nt:
                    raise ValueError(f"Time dimension mismatch for BC_Control. Expected {current_nt}, got {bc_control_seq.shape[0]}.")
                if bc_control_seq.shape[1] != self.num_controls:
                    raise ValueError(f"Control dimension mismatch for sample {idx}. Expected {self.num_controls}, got {bc_control_seq.shape[1]}")

                bc_control_norm, ctrl_means, ctrl_stds = self._normalize_columns(bc_control_seq, self.num_controls)
                norm_factors[f'{self.bc_control_key}_means'] = ctrl_means
                norm_factors[f'{self.bc_control_key}_stds'] = ctrl_stds
                bc_control_tensor_norm = torch.tensor(bc_control_norm).float()
        else:
            # No controls: zero-width tensor so the concatenation below still works.
            bc_control_tensor_norm = torch.empty((current_nt, 0), dtype=torch.float32)

        # Concatenate normalized BC_State and BC_Control along the feature axis.
        bc_ctrl_tensor_norm = torch.cat((bc_state_tensor_norm, bc_control_tensor_norm), dim=-1)
        return state_tensors_norm_list, bc_ctrl_tensor_norm, norm_factors

# =============================================================================
# 3. 通用化 POD 基计算 (compute_pod_basis_generic)
# =============================================================================
def compute_pod_basis_generic(data_list, dataset_type, state_variable_key,
                              nx, nt, basis_dim, 
                              max_snapshots_pod=100):
    """Compute a POD basis for one state variable from lifted snapshots.

    For each usable sample the first ``nt`` rows of
    ``sample[state_variable_key]`` are taken, a linear interpolation between
    the first and last spatial columns is subtracted, and the residual rows
    are stacked. The basis is formed from the leading right singular vectors
    of the mean-centered stack.

    Args:
        data_list: Iterable of sample dicts.
        dataset_type: Accepted for interface compatibility; not used here.
        state_variable_key: Key of the state array inside each sample.
        nx: Expected number of spatial points (second axis of the array).
        nt: Number of leading time steps to use from each sample.
        basis_dim: Requested number of basis vectors.
        max_snapshots_pod: Maximum number of samples contributing snapshots.

    Returns:
        ``float32`` array of shape ``(nx, basis_dim)`` (zero-padded columns
        when fewer modes are available), or ``None`` on failure.
    """
    n_grid = nx
    # Row vector of interpolation weights running from 0 to 1 across the grid.
    ramp = np.linspace(0, 1, n_grid)[np.newaxis, :]
    print(f"  Computing POD for '{state_variable_key}' using {nt} timesteps, linear interp U_B...")

    residual_blocks = []
    for sample_idx, sample in enumerate(data_list):
        if len(residual_blocks) >= max_snapshots_pod:
            break
        if state_variable_key not in sample:
            print(f"Warning: Key '{state_variable_key}' not found in sample {sample_idx}. Skipping.")
            continue

        trajectory = sample[state_variable_key][:nt, :]

        if trajectory.shape[0] != nt:
            print(f"Warning: U_seq actual timesteps {trajectory.shape[0]} != requested nt {nt} for POD in sample {sample_idx}. Skipping.")
            continue
        if trajectory.shape[1] != n_grid:
            print(f"Warning: Mismatch nx in sample {sample_idx} for {state_variable_key}. Expected {n_grid}, got {trajectory.shape[1]}. Skipping.")
            continue

        left_edge = trajectory[:, 0:1]
        right_edge = trajectory[:, -1:]
        edges_finite = np.isfinite(left_edge).all() and np.isfinite(right_edge).all()
        if not edges_finite:
            print(f"Warning: NaN/Inf in boundary values for sample {sample_idx}, key '{state_variable_key}'. Skipping sample for POD.")
            continue

        # Remove the linear boundary lifting so residuals vanish at both ends.
        lifting = left_edge * (1 - ramp) + right_edge * ramp
        residual_blocks.append(trajectory - lifting)

    if not residual_blocks:
        print(f"Error: No valid snapshots collected for POD for '{state_variable_key}'. Ensure 'nt' ({nt}) is appropriate.")
        return None

    try:
        stacked = np.concatenate(residual_blocks, axis=0)
    except ValueError as e:
        print(f"Error concatenating snapshots for '{state_variable_key}': {e}. Check snapshot shapes. Collected {len(residual_blocks)} lists of snapshots.")
        return None

    if not np.isfinite(stacked).all():
        print(f"Warning: NaN/Inf found in snapshots for '{state_variable_key}' before POD. Clamping.")
        stacked = np.nan_to_num(stacked, nan=0.0, posinf=1e6, neginf=-1e6)
        if np.all(np.abs(stacked) < 1e-12):
            print(f"Error: All snapshots became zero after clamping for '{state_variable_key}'.")
            return None

    column_mean = np.mean(stacked, axis=0, keepdims=True)
    centered = stacked - column_mean

    try:
        _, singular_values, right_vectors = np.linalg.svd(centered, full_matrices=False)
        rank = np.sum(singular_values > 1e-10)
        kept_dim = min(basis_dim, rank, n_grid)
        if kept_dim == 0:
            print(f"Error: Data rank is zero or too low for '{state_variable_key}' after SVD. Effective rank: {rank}")
            return None
        if kept_dim < basis_dim:
            print(f"Warning: Requested basis_dim {basis_dim} but effective data rank is ~{rank} for '{state_variable_key}'. Using {kept_dim}.")
        # Rows of the right-vector matrix are modes; transpose to (nx, kept_dim).
        basis = right_vectors[:kept_dim, :].T
    except np.linalg.LinAlgError as e:
        print(f"SVD failed for '{state_variable_key}': {e}.")
        return None
    except Exception as e: 
        print(f"Error during SVD for '{state_variable_key}': {e}.")
        return None

    # Re-normalize each column, leaving (near-)zero columns untouched.
    column_norms = np.linalg.norm(basis, axis=0)
    column_norms[column_norms < 1e-10] = 1.0
    basis = basis / column_norms[np.newaxis, :]

    missing = basis_dim - kept_dim
    if missing > 0:
        print(f"Padding POD basis for '{state_variable_key}' from dim {kept_dim} to {basis_dim}")
        basis = np.hstack((basis, np.zeros((n_grid, missing))))

    print(f"  Successfully computed POD basis for '{state_variable_key}' with shape {basis.shape}.")
    return basis.astype(np.float32)


# =============================================================================
# 4. 模型定义 - 公共组件 (Model Definitions - Common Components)
# =============================================================================

# 4.1. Feedforward 更新网络 (ImprovedUpdateFFN - General Version)
class ImprovedUpdateFFN(nn.Module):
    """MLP followed by LayerNorm, with a residual link when widths agree.

    The MLP has ``num_layers - 1`` hidden stages (Linear -> GELU -> Dropout)
    and a final Linear layer. When ``output_dim`` equals ``input_dim`` the
    input is added to the MLP output before normalization.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of every hidden stage.
        num_layers: Total number of Linear layers.
        dropout: Dropout probability used after each hidden activation.
        output_dim: Width of the output; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=3, dropout=0.1, output_dim=None):
        super().__init__()
        out_width = input_dim if output_dim is None else output_dim

        # Widths seen by successive Linear layers: input, then hidden stages.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend([nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(dropout)])
        stages.append(nn.Linear(widths[-1], out_width))
        self.mlp = nn.Sequential(*stages)

        self.layernorm = nn.LayerNorm(out_width)
        self.input_dim_for_residual = input_dim
        self.output_dim_for_residual = out_width

    def forward(self, x):
        """Apply the MLP (plus residual when possible) and normalize."""
        hidden = self.mlp(x)
        widths_match = self.input_dim_for_residual == self.output_dim_for_residual
        if widths_match:
            hidden = hidden + x  # residual connection
        return self.layernorm(hidden)

# 4.2. Lifting 模块 (UniversalLifting)
class UniversalLifting(nn.Module):
    """Learned lifting from boundary/control inputs to per-variable fields.

    Each boundary-state channel goes through its own small branch
    (Linear -> GELU); the control channels, if any, go through a shared MLP.
    The resulting features are concatenated and mapped by a fusion MLP to
    ``num_state_vars * nx`` values, returned as ``[batch, num_state_vars, nx]``.

    When there are no input features at all, the module has no fusion
    network and ``forward`` returns zeros.

    Args:
        num_state_vars: Number of state variables to produce.
        bc_state_dim: Number of leading boundary-state channels in the input.
        num_controls: Number of control channels following the BC channels.
        output_dim_per_var: Must equal ``nx``.
        nx: Number of spatial points per variable.
        hidden_dims_state_branch: Width of each per-channel state branch.
        hidden_dims_control: Hidden widths of the control MLP.
        hidden_dims_fusion: Hidden widths of the fusion MLP.
        dropout: Dropout probability in the control and fusion MLPs.
    """

    def __init__(self, num_state_vars, bc_state_dim, num_controls, output_dim_per_var, nx,
                 hidden_dims_state_branch=64,
                 hidden_dims_control=(64, 128),
                 hidden_dims_fusion=(256, 512, 256),
                 dropout=0.1):
        super().__init__()
        self.num_state_vars = num_state_vars
        self.bc_state_dim = bc_state_dim
        self.num_controls = num_controls
        self.nx = nx
        assert output_dim_per_var == nx, "output_dim_per_var must equal nx"

        # One independent scalar -> feature branch per boundary-state channel.
        self.state_branches = nn.ModuleList()
        if self.bc_state_dim > 0:
            for _ in range(bc_state_dim):
                self.state_branches.append(nn.Sequential(
                    nn.Linear(1, hidden_dims_state_branch),
                    nn.GELU(),
                ))
            state_feature_dim = bc_state_dim * hidden_dims_state_branch
        else:
            state_feature_dim = 0

        # Shared MLP over all control channels (empty when there are none).
        control_feature_dim = 0
        if self.num_controls > 0:
            control_layers, control_feature_dim = self._hidden_stack(
                num_controls, hidden_dims_control, dropout)
            self.control_mlp = nn.Sequential(*control_layers)
        else:
            self.control_mlp = nn.Sequential()

        fusion_input_dim = state_feature_dim + control_feature_dim
        if fusion_input_dim > 0:
            fusion_layers, fusion_width = self._hidden_stack(
                fusion_input_dim, hidden_dims_fusion, dropout)
            fusion_layers.append(nn.Linear(fusion_width, num_state_vars * nx))
            self.fusion = nn.Sequential(*fusion_layers)
        else:
            # Nothing to fuse: forward() falls back to a zero lifting.
            self.fusion = None

    @staticmethod
    def _hidden_stack(in_dim, hidden_dims, dropout):
        """Build Linear -> GELU -> Dropout stages; return (layers, last width)."""
        layers = []
        width = in_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(width, h_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            width = h_dim
        return layers, width

    def _zero_lifting(self, BC_Ctrl):
        """Return an all-zero lifting matching the batch size of ``BC_Ctrl``."""
        if BC_Ctrl is None:
            return torch.zeros(1, self.num_state_vars, self.nx, device='cpu')
        # Use the leading dimension even for a zero-width input ([B, 0]);
        # checking nelement() here would collapse the batch to size 1.
        return torch.zeros(BC_Ctrl.shape[0], self.num_state_vars, self.nx,
                           device=BC_Ctrl.device)

    def forward(self, BC_Ctrl):
        """Map ``BC_Ctrl`` to the lifting field.

        Args:
            BC_Ctrl: Tensor indexed ``[batch, bc_state_dim + num_controls]``;
                the first ``bc_state_dim`` columns are boundary-state values.

        Returns:
            Tensor of shape ``[batch, num_state_vars, nx]``.
        """
        if self.fusion is None:
            return self._zero_lifting(BC_Ctrl)

        features_to_concat = []
        if self.bc_state_dim > 0:
            BC_state = BC_Ctrl[:, :self.bc_state_dim]
            branch_outputs = [
                branch(BC_state[:, i:i + 1])
                for i, branch in enumerate(self.state_branches)
            ]
            features_to_concat.append(torch.cat(branch_outputs, dim=-1))

        if self.num_controls > 0:
            BC_control = BC_Ctrl[:, self.bc_state_dim:]
            features_to_concat.append(self.control_mlp(BC_control))

        # A fusion network only exists when at least one feature group does,
        # so features_to_concat is never empty here.
        if len(features_to_concat) == 1:
            concat_features = features_to_concat[0]
        else:
            concat_features = torch.cat(features_to_concat, dim=-1)

        fused_output = self.fusion(concat_features)
        return fused_output.view(-1, self.num_state_vars, self.nx)

# 4.3. 多头注意力机制 (CustomMultiHeadAttentionMechanism)
class CustomMultiHeadAttentionMechanism(nn.Module):
    """Multi-head attention variant without softmax.

    Per head the output is ``Q @ (K^T @ V)``; the heads are then concatenated
    and passed through a linear output projection. Q, K and V are used as
    given (this module applies no input projections).

    Args:
        basis_dim: Accepted for interface compatibility; not used here.
        d_model: Feature width of Q, K and V; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, basis_dim, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Mixes the concatenated head outputs back into d_model features.
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        """Compute the attention output for inputs shaped [batch, seq_len, d_model]."""
        n_batch, len_q, width_q = Q.size()
        _, len_kv, width_kv = K.size()

        assert len_q == len_kv, "Sequence lengths of Q and K/V must match for this attention"
        assert width_q == width_kv, "d_model of Q and K/V must match"

        def split_heads(tensor):
            # [batch, seq, d_model] -> [batch, heads, seq, head_dim]
            return tensor.view(n_batch, len_q, self.num_heads, self.head_dim).transpose(1, 2)

        q_heads = split_heads(Q)
        k_heads = split_heads(K)
        v_heads = split_heads(V)

        # Contract over the sequence axis first: [batch, heads, head_dim, head_dim]
        key_value = torch.matmul(k_heads.transpose(-2, -1), v_heads)
        per_head = torch.matmul(q_heads, key_value)  # [batch, heads, seq, head_dim]

        merged = per_head.transpose(1, 2).contiguous().view(n_batch, len_q, width_q)
        return self.out_proj(merged)

# =============================================================================
# 4.4. 原模型定义 (Original Model Definitions)
# =============================================================================
# (Omitted.)
# =============================================================================
# 4.5. 消融模型定义 (Ablation Model Definitions)
# =============================================================================

# 4.5.1 No-Attention ROM with EXPLICIT BC processing
class NoAttention_ExplicitBC_ROM(nn.Module):
    """Ablation ROM without attention and with an explicit BC-driven update.

    Each state variable has a trainable spatial basis ``Phi[key]`` of shape
    ``(nx, basis_dim)``. One time step updates the latent coefficients as
    ``a + alpha * FFN(a) + Linear(BCProcessor(BC_Ctrl))`` and reconstructs the
    field as ``U_B + Phi @ a_next``, where ``U_B`` comes either from a learned
    ``UniversalLifting`` network or from a fixed linear interpolation between
    pairs of boundary values.

    The ``shared_attention`` flag (name kept for interface compatibility)
    selects whether the FFN, initial mapping, alpha and BC processors are
    shared across state variables (stored under the key ``'shared'``) or
    created once per state key.
    """

    def __init__(self, state_variable_keys, nx, basis_dim, d_model, # d_model is the FFN hidden width
                 bc_state_dim, num_controls, 
                 add_error_estimator=False, shared_attention=False, # shares FFN / BC processors
                 dropout_lifting=0.1, dropout_ffn=0.1, initial_alpha=0.1,
                 use_fixed_lifting=False,
                 bc_processed_dim=64, 
                 hidden_bc_processor_dim=128
                 ):
        super().__init__()
        self.state_keys = state_variable_keys
        self.num_state_vars = len(state_variable_keys)
        self.nx = nx
        self.basis_dim = basis_dim
        self.d_model = d_model
        self.bc_state_dim = bc_state_dim
        self.num_controls = num_controls
        self.add_error_estimator = add_error_estimator
        self.shared_attention = shared_attention
        self.use_fixed_lifting = use_fixed_lifting
        self.bc_processed_dim = bc_processed_dim

        # Trainable spatial basis per state variable, orthogonally initialized.
        self.Phi = nn.ParameterDict()
        for key in self.state_keys:
            basis_param = nn.Parameter(torch.randn(nx, basis_dim))
            nn.init.orthogonal_(basis_param)
            self.Phi[key] = basis_param

        if self.use_fixed_lifting:
            self.lifting = None
            # Interpolation weights from 0 to 1, broadcastable over [batch, 1, nx].
            ramp = torch.linspace(0, 1, self.nx, dtype=torch.float32)
            self.register_buffer('lin_interp_coeffs', ramp.view(1, 1, -1))
        else:
            self.lifting = UniversalLifting(
                num_state_vars=self.num_state_vars,
                bc_state_dim=bc_state_dim,
                num_controls=num_controls,
                output_dim_per_var=nx,
                nx=nx,
                dropout=dropout_lifting
            )

        # Without attention only the FFN, BC processing and alpha remain.
        self.update_ffn = nn.ModuleDict()
        self.a0_mapping = nn.ModuleDict()
        self.alphas = nn.ParameterDict()
        self.bc_feature_processor = nn.ModuleDict()
        self.bc_to_a_update = nn.ModuleDict()
        bc_input_width = self.bc_state_dim + self.num_controls

        # Either a single 'shared' slot or one slot per state variable.
        slot_names = ['shared'] if self.shared_attention else list(self.state_keys)
        for slot in slot_names:
            self.update_ffn[slot] = ImprovedUpdateFFN(input_dim=basis_dim, output_dim=basis_dim,
                                                      hidden_dim=d_model, dropout=dropout_ffn)
            self.a0_mapping[slot] = nn.Sequential(
                nn.Linear(basis_dim, basis_dim), nn.ReLU(), nn.LayerNorm(basis_dim)
            )
            self.alphas[slot] = nn.Parameter(torch.tensor(initial_alpha))
            if bc_input_width > 0:
                self.bc_feature_processor[slot] = nn.Sequential(
                    nn.Linear(bc_input_width, hidden_bc_processor_dim),
                    nn.GELU(),
                    nn.Linear(hidden_bc_processor_dim, self.bc_processed_dim)
                )
            else:
                # No BC/control inputs: keep an empty processor as a placeholder.
                self.bc_feature_processor[slot] = nn.Sequential()
            self.bc_to_a_update[slot] = nn.Linear(self.bc_processed_dim, basis_dim)

        if self.add_error_estimator:
            # Scalar estimate from the concatenated coefficients of all variables.
            self.error_estimator = nn.Linear(self.num_state_vars * basis_dim, 1)

    @staticmethod
    def _has_parameters(layer):
        """Return True if ``layer`` has a ``weight`` attribute or any parameter."""
        return hasattr(layer, 'weight') or len(list(layer.parameters())) > 0

    def _get_layer(self, module_dict, key):
        """Pick the shared entry when sharing is enabled, else the per-key one."""
        if self.shared_attention:
            return module_dict['shared']
        return module_dict[key]

    def _get_alpha(self, key):
        """Return the update scale for ``key`` (or the shared one)."""
        if self.shared_attention:
            return self.alphas['shared']
        return self.alphas[key]

    def _interpolate_between(self, BC_Ctrl_n, left_col, right_col):
        """Linearly blend two BC columns across the grid -> [batch, 1, nx]."""
        left_val = BC_Ctrl_n[:, left_col:left_col + 1].unsqueeze(-1)
        right_val = BC_Ctrl_n[:, right_col:right_col + 1].unsqueeze(-1)
        return left_val * (1 - self.lin_interp_coeffs) + \
               right_val * self.lin_interp_coeffs

    def _compute_U_B(self, BC_Ctrl_n):
        """Return the lifting field for one step, shaped [batch, num_state_vars, nx].

        Uses the learned lifting network unless ``use_fixed_lifting`` is set,
        in which case columns ``2*i`` and ``2*i + 1`` of ``BC_Ctrl_n`` are
        interpolated linearly for variable ``i``. Falls back to zeros (with a
        printed warning) when too few boundary values are available.
        """
        if not self.use_fixed_lifting:
            if self.lifting is None:
                raise ValueError("Lifting network is None but use_fixed_lifting is False.")
            return self.lifting(BC_Ctrl_n)

        if self.num_state_vars == 1:
            if self.bc_state_dim < 2:
                print(f"Warning (NoAttnExplicit): Fixed lifting for 1 state var expects bc_state_dim >= 2, got {self.bc_state_dim}. Returning zeros for U_B.")
                return torch.zeros(BC_Ctrl_n.shape[0], 1, self.nx, device=BC_Ctrl_n.device)
            return self._interpolate_between(BC_Ctrl_n, 0, 1)

        if self.bc_state_dim < 2 * self.num_state_vars and self.num_state_vars > 0:
            print(f"Warning (NoAttnExplicit): Fixed lifting for {self.num_state_vars} vars expects bc_state_dim >= {2*self.num_state_vars}, got {self.bc_state_dim}. Returning zeros for U_B.")
            return torch.zeros(BC_Ctrl_n.shape[0], self.num_state_vars, self.nx, device=BC_Ctrl_n.device)

        per_variable = []
        for i_var in range(self.num_state_vars):
            left_col = i_var * 2
            right_col = left_col + 1
            if right_col >= self.bc_state_dim:
                print(f"Warning (NoAttnExplicit): Not enough BC_State values for fixed lifting of var {i_var}. Using zeros for this var's U_B.")
                per_variable.append(
                    torch.zeros(BC_Ctrl_n.shape[0], 1, self.nx, device=BC_Ctrl_n.device))
            else:
                per_variable.append(self._interpolate_between(BC_Ctrl_n, left_col, right_col))
        return torch.cat(per_variable, dim=1)

    def forward_step(self, a_n_dict, BC_Ctrl_n, params=None):
        """Advance the latent coefficients by one step and reconstruct fields.

        Args:
            a_n_dict: Maps each state key to its coefficients ``[batch, basis_dim]``.
            BC_Ctrl_n: Boundary/control inputs for this step, ``[batch, features]``.
            params: Unused; accepted for interface compatibility.

        Returns:
            Tuple ``(a_next_dict, U_hat_dict, err_est)``. ``U_hat_dict`` holds
            tensors of shape ``[batch, nx, 1]``; ``err_est`` is ``None``
            unless the error estimator is enabled.
        """
        batch_size = list(a_n_dict.values())[0].size(0)
        U_B_stacked = self._compute_U_B(BC_Ctrl_n)

        # Processed BC features per state key; None means "no BC contribution".
        bc_features = {key: None for key in self.state_keys}
        if self.bc_state_dim + self.num_controls > 0:
            if self.shared_attention:
                processor = self._get_layer(self.bc_feature_processor, 'shared')
                if self._has_parameters(processor):
                    # Computed once and reused for every state variable.
                    shared_features = processor(BC_Ctrl_n)
                    for key in self.state_keys:
                        bc_features[key] = shared_features
            else:
                for key in self.state_keys:
                    processor = self._get_layer(self.bc_feature_processor, key)
                    if self._has_parameters(processor):
                        bc_features[key] = processor(BC_Ctrl_n)

        a_next_dict = {}
        U_hat_dict = {}
        for i_var, key in enumerate(self.state_keys):
            coeffs = a_n_dict[key]
            basis = self.Phi[key]
            ffn = self._get_layer(self.update_ffn, key)
            alpha = self._get_alpha(key)
            bc_head = self._get_layer(self.bc_to_a_update, key)

            intrinsic_update = ffn(coeffs)

            features = bc_features[key]
            if features is not None and self._has_parameters(bc_head):
                bc_update = bc_head(features)
            else:
                bc_update = torch.zeros_like(coeffs)

            # Latent update without any attention term.
            coeffs_next = coeffs + alpha * intrinsic_update + bc_update
            a_next_dict[key] = coeffs_next

            # Reconstruct: lifting field plus basis expansion, both [batch, nx, 1].
            lifting_field = U_B_stacked[:, i_var, :].unsqueeze(-1)
            batched_basis = basis.unsqueeze(0).expand(batch_size, -1, -1)
            reconstruction = torch.bmm(batched_basis, coeffs_next.unsqueeze(-1))
            U_hat_dict[key] = lifting_field + reconstruction

        err_est = None
        if self.add_error_estimator:
            all_coeffs = torch.cat(list(a_next_dict.values()), dim=-1)
            err_est = self.error_estimator(all_coeffs)
        return a_next_dict, U_hat_dict, err_est

    def forward(self, a0_dict, BC_Ctrl_seq, T, params=None):
        """Roll the model out for ``T`` steps.

        Args:
            a0_dict: Maps each state key to initial coefficients ``[batch, basis_dim]``.
            BC_Ctrl_seq: Boundary/control inputs indexed ``[batch, time, features]``.
            T: Number of steps to roll out.
            params: Passed through to ``forward_step`` (unused there).

        Returns:
            Tuple ``(U_hat_seq_dict, err_seq)``: per-key lists of per-step
            reconstructions, and a list of per-step error estimates (``None``
            when the error estimator is disabled).
        """
        a_current = {
            key: self._get_layer(self.a0_mapping, key)(a0_dict[key])
            for key in self.state_keys
        }

        U_hat_seq_dict = {key: [] for key in self.state_keys}
        err_seq = [] if self.add_error_estimator else None

        for t_step in range(T):
            a_next, U_hat_step, err_step = self.forward_step(
                a_current, BC_Ctrl_seq[:, t_step, :], params
            )
            for key in self.state_keys:
                U_hat_seq_dict[key].append(U_hat_step[key])
            if self.add_error_estimator and err_step is not None:
                err_seq.append(err_step)
            a_current = a_next
        return U_hat_seq_dict, err_seq

    def get_basis(self, key):
        """Return the trainable basis parameter for ``key``."""
        return self.Phi[key]

# 4.5.2 No-Attention ROM with IMPLICIT BC handling
class NoAttention_ImplicitBC_ROM(nn.Module):
    """Attention-free reduced-order model with implicit boundary handling.

    Every state variable is reconstructed as ``U_hat = U_B + Phi @ a``:

    * ``Phi`` is a trainable ``(nx, basis_dim)`` basis per variable,
      initialised orthogonally.
    * ``U_B`` is a lifting of the boundary/control vector, produced either by
      a learned ``UniversalLifting`` network or by fixed linear interpolation
      between a left and a right boundary value.
    * ``a`` are reduced coefficients advanced by the residual update
      ``a <- a + alpha * FFN(a)``.

    The boundary/control input only enters through ``U_B``; the coefficient
    update itself does not see it.

    Args:
        state_variable_keys: Names of the state variables.
        nx: Number of spatial points per variable.
        basis_dim: Number of basis vectors per variable.
        d_model: Hidden width of the update FFN.
        bc_state_dim: Number of boundary-state entries in the BC/control
            vector (fixed lifting reads two per variable: left, right).
        num_controls: Number of control entries, forwarded to the lifting
            network.
        add_error_estimator: If True, a linear head maps the concatenated
            coefficients of all variables to a scalar error estimate.
        shared_attention: If True, one FFN / a0 mapping / alpha is shared by
            all variables (the name is kept for interface compatibility).
        dropout_lifting: Dropout passed to ``UniversalLifting``.
        dropout_ffn: Dropout passed to ``ImprovedUpdateFFN``.
        initial_alpha: Initial value of the trainable update scale.
        use_fixed_lifting: If True, use linear interpolation instead of the
            learned lifting network.
    """

    def __init__(self, state_variable_keys, nx, basis_dim, d_model, # d_model for FFN hidden dim
                 bc_state_dim, num_controls,
                 add_error_estimator=False, shared_attention=False, # shared_attention for FFN
                 dropout_lifting=0.1, dropout_ffn=0.1, initial_alpha=0.1,
                 use_fixed_lifting=False):
        super().__init__()
        self.state_keys = state_variable_keys
        self.num_state_vars = len(state_variable_keys)
        self.nx = nx
        self.basis_dim = basis_dim
        self.d_model = d_model
        self.bc_state_dim = bc_state_dim
        self.num_controls = num_controls
        self.add_error_estimator = add_error_estimator
        self.shared_attention = shared_attention
        self.use_fixed_lifting = use_fixed_lifting

        # One trainable basis per state variable.
        self.Phi = nn.ParameterDict()
        for key in self.state_keys:
            phi_param = nn.Parameter(torch.randn(nx, basis_dim))
            nn.init.orthogonal_(phi_param)
            self.Phi[key] = phi_param

        if not self.use_fixed_lifting:
            self.lifting = UniversalLifting(
                num_state_vars=self.num_state_vars,
                bc_state_dim=bc_state_dim,
                num_controls=num_controls,
                output_dim_per_var=nx,
                nx=nx,
                dropout=dropout_lifting
            )
        else:
            self.lifting = None
            # Interpolation weights from 0 (left boundary) to 1 (right
            # boundary), stored as [1, 1, nx] so they broadcast over batch
            # and variable dimensions.
            lin_interp_coeffs = torch.linspace(0, 1, self.nx, dtype=torch.float32)
            self.register_buffer('lin_interp_coeffs', lin_interp_coeffs.view(1, 1, -1))

        self.update_ffn = nn.ModuleDict()
        self.a0_mapping = nn.ModuleDict()
        self.alphas = nn.ParameterDict()

        # Either a single 'shared' set of components or one set per variable.
        # Construction order (FFN, a0 mapping, alpha per key) is kept so that
        # seeded initialisation stays reproducible.
        component_keys = ['shared'] if self.shared_attention else list(self.state_keys)
        for comp_key in component_keys:
            self.update_ffn[comp_key] = ImprovedUpdateFFN(input_dim=basis_dim, output_dim=basis_dim,
                                                          hidden_dim=d_model, dropout=dropout_ffn)
            self.a0_mapping[comp_key] = nn.Sequential(
                nn.Linear(basis_dim, basis_dim), nn.ReLU(), nn.LayerNorm(basis_dim)
            )
            # float(...) guards against an integer initial_alpha, which would
            # yield an integer tensor that cannot require gradients.
            self.alphas[comp_key] = nn.Parameter(torch.tensor(float(initial_alpha)))

        if self.add_error_estimator:
            total_basis_dim_all_vars = self.num_state_vars * basis_dim
            self.error_estimator = nn.Linear(total_basis_dim_all_vars, 1)

    def _get_layer(self, module_dict, key):
        """Return the shared entry of ``module_dict`` or the one for ``key``."""
        return module_dict['shared'] if self.shared_attention else module_dict[key]

    def _get_alpha(self, key):
        """Return the shared update scale or the one belonging to ``key``."""
        return self.alphas['shared'] if self.shared_attention else self.alphas[key]

    def _compute_U_B(self, BC_Ctrl_n):
        """Compute the boundary lifting ``U_B`` for one time step.

        Args:
            BC_Ctrl_n: Boundary/control values for the step, indexed as
                ``[batch, feature]``. With fixed lifting, columns ``2*i`` and
                ``2*i + 1`` are read as the left/right boundary values of
                variable ``i``.

        Returns:
            With fixed lifting, a ``[batch, num_state_vars, nx]`` tensor.
            With learned lifting, whatever ``self.lifting`` returns
            (presumably the same layout, since callers index ``[:, i, :]``).

        Raises:
            ValueError: If learned lifting is requested but no lifting
                network exists.
        """
        if not self.use_fixed_lifting:
            if self.lifting is None:
                 raise ValueError("Lifting network is None but use_fixed_lifting is False.")
            return self.lifting(BC_Ctrl_n)

        n_vars = self.num_state_vars
        n_needed = 2 * n_vars  # one (left, right) pair per variable
        if self.bc_state_dim < n_needed:
            if n_vars == 1:
                print(f"Warning (NoAttnImplicit): Fixed lifting for 1 state var expects bc_state_dim >= 2, got {self.bc_state_dim}. Returning zeros for U_B.")
            else:
                print(f"Warning (NoAttnImplicit): Fixed lifting for {self.num_state_vars} vars expects bc_state_dim >= {2*self.num_state_vars}, got {self.bc_state_dim}. Returning zeros for U_B.")
            # new_zeros keeps both device and dtype of the input, matching
            # what the interpolation path below would produce.
            return BC_Ctrl_n.new_zeros(BC_Ctrl_n.shape[0], n_vars, self.nx)

        # Even columns hold left values, odd columns right values; the
        # strided slices pick them for all variables at once.
        left_vals = BC_Ctrl_n[:, 0:n_needed:2].unsqueeze(-1)   # [batch, n_vars, 1]
        right_vals = BC_Ctrl_n[:, 1:n_needed:2].unsqueeze(-1)  # [batch, n_vars, 1]
        weights = self.lin_interp_coeffs                       # [1, 1, nx]
        return left_vals * (1 - weights) + right_vals * weights

    def forward_step(self, a_n_dict, BC_Ctrl_n, params=None):
        """Advance the reduced coefficients by one step and reconstruct.

        Args:
            a_n_dict: Mapping from state key to current coefficients
                ``[batch, basis_dim]``.
            BC_Ctrl_n: Boundary/control values for this step; only used to
                compute the lifting ``U_B``.
            params: Unused; accepted for interface parity with other ROMs.

        Returns:
            Tuple ``(a_next_dict, U_hat_dict, err_est)`` where ``U_hat_dict``
            holds ``U_B + Phi @ a_next`` per key as ``[batch, nx, 1]`` and
            ``err_est`` is ``None`` unless the error estimator is enabled.
        """
        batch_size = list(a_n_dict.values())[0].size(0)
        a_next_dict = {}
        U_hat_dict = {}
        U_B_stacked = self._compute_U_B(BC_Ctrl_n)

        for i, key in enumerate(self.state_keys):
            a_n_var = a_n_dict[key]
            Phi_var = self.Phi[key]
            ffn_var = self._get_layer(self.update_ffn, key)
            alpha_var = self._get_alpha(key)

            # Residual update driven only by the scaled FFN output.
            a_next_var = a_n_var + alpha_var * ffn_var(a_n_var)
            a_next_dict[key] = a_next_var

            U_B_var = U_B_stacked[:, i, :].unsqueeze(-1)
            Phi_exp = Phi_var.unsqueeze(0).expand(batch_size, -1, -1)
            U_recon = torch.bmm(Phi_exp, a_next_var.unsqueeze(-1))
            U_hat_dict[key] = U_B_var + U_recon

        err_est = None
        if self.add_error_estimator:
            a_next_combined = torch.cat(list(a_next_dict.values()), dim=-1)
            err_est = self.error_estimator(a_next_combined)
        return a_next_dict, U_hat_dict, err_est

    def forward(self, a0_dict, BC_Ctrl_seq, T, params=None):
        """Unroll ``forward_step`` for ``T`` steps.

        Args:
            a0_dict: Mapping from state key to initial coefficients; each is
                first passed through the matching ``a0_mapping`` layer.
            BC_Ctrl_seq: Boundary/control sequence, indexed
                ``BC_Ctrl_seq[:, t, :]`` for step ``t``.
            T: Number of steps to unroll.
            params: Forwarded to ``forward_step``.

        Returns:
            Tuple ``(U_hat_seq_dict, err_seq)``: per-key lists of per-step
            reconstructions, and a list of per-step error estimates (or
            ``None`` when the error estimator is disabled).
        """
        a_current_dict = {}
        for key in self.state_keys:
            a0_map = self._get_layer(self.a0_mapping, key)
            a_current_dict[key] = a0_map(a0_dict[key])
        U_hat_seq_dict = {key: [] for key in self.state_keys}
        err_seq = [] if self.add_error_estimator else None
        for t in range(T):
            BC_Ctrl_n = BC_Ctrl_seq[:, t, :]
            a_next_dict, U_hat_dict, err_est = self.forward_step(a_current_dict, BC_Ctrl_n, params)
            for key in self.state_keys:
                U_hat_seq_dict[key].append(U_hat_dict[key])
            if self.add_error_estimator and err_est is not None:
                err_seq.append(err_est)
            a_current_dict = a_next_dict
        return U_hat_seq_dict, err_seq

    def get_basis(self, key):
        """Return the trainable basis ``Phi`` of state variable ``key``."""
        return self.Phi[key]

# =============================================================================
# 5. 训练与验证函数 (train_multivar_model, validate_multivar_model)
#    (Using versions from the explicit BC code as they are more comprehensive)
# =============================================================================
def _resume_training_state(checkpoint_path, model, optimizer, scheduler, device):
    """Restore model/optimizer/scheduler state from ``checkpoint_path`` if present.

    Returns ``(start_epoch, best_loss)``; ``(0, inf)`` when there is no
    checkpoint or it cannot be loaded.
    """
    if not os.path.exists(checkpoint_path):
        return 0, float('inf')

    print(f"Loading existing checkpoint from {checkpoint_path} ...")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except Exception:
                print("Warning: Could not load optimizer state. Reinitializing optimizer.")
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_loss = checkpoint.get('loss', float('inf'))
        if 'scheduler_state_dict' in checkpoint and hasattr(scheduler, 'load_state_dict'):
            try:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            except Exception:
                print("Warning: Could not load scheduler state. Reinitializing scheduler.")
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch, best_loss
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting fresh.")
        return 0, float('inf')


def _rom_checkpoint_payload(model, optimizer, scheduler, epoch, loss, dataset_type):
    """Assemble the dictionary written to disk for a ROM checkpoint."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'dataset_type': dataset_type,
        'state_keys': model.state_keys,
        'nx': model.nx,
        'basis_dim': model.basis_dim,
        'd_model': model.d_model if hasattr(model, 'd_model') else None,
        'bc_state_dim': model.bc_state_dim if hasattr(model, 'bc_state_dim') else model.lifting.bc_state_dim,
        'num_controls': model.num_controls if hasattr(model, 'num_controls') else None,
        'num_heads': model.num_heads if hasattr(model, 'num_heads') else None,
        'shared_attention': model.shared_attention if hasattr(model, 'shared_attention') else None,
        'use_fixed_lifting': model.use_fixed_lifting if hasattr(model, 'use_fixed_lifting') else None,
        # Explicit-BC specific settings; stored only if the model has them.
        'bc_processed_dim': model.bc_processed_dim if hasattr(model, 'bc_processed_dim') else None,
        'hidden_bc_processor_dim': getattr(model, 'hidden_bc_processor_dim', None)
    }


def train_multivar_model(model, data_loader, dataset_type, train_nt_target,
                         lr=1e-3, num_epochs=50, device='cuda',
                         checkpoint_path='rom_checkpoint.pt', lambda_res=0.05,
                         lambda_orth=0.001, lambda_bc_penalty=0.01,
                         clip_grad_norm=1.0):
    """Train a multi-variable ROM and keep the best checkpoint on disk.

    The loss per batch is the sum of

    * the time-averaged reconstruction MSE per state variable,
    * ``lambda_res`` times the MSE of the residual projected onto the basis,
    * ``lambda_orth`` times ``||Phi^T Phi - I||_F^2`` (averaged over variables),
    * ``lambda_bc_penalty`` times the MSE of the first and last basis rows
      against zero (averaged over variables).

    Args:
        model: ROM exposing ``state_keys``, ``basis_dim``, ``nx``,
            ``get_basis(key)``, ``_compute_U_B(bc)`` and a forward call
            ``model(a0_dict, BC_Ctrl_seq, T=..., params=...)``.
        data_loader: Iterable of ``(state_data, BC_Ctrl_tensor, norm_factors)``
            batches; ``state_data`` is a tensor or a list of tensors shaped
            ``[batch, nt, nx]``.
        dataset_type: Label stored in the checkpoint.
        train_nt_target: Expected number of time steps per batch.
        lr: AdamW learning rate.
        num_epochs: Total number of epochs (resumed runs continue from the
            checkpoint's epoch).
        device: Device used for the model and batches.
        checkpoint_path: File used for resuming and for saving the best model.
        lambda_res, lambda_orth, lambda_bc_penalty: Loss weights; a term is
            skipped when its weight is not positive.
        clip_grad_norm: Max gradient norm; clipping is skipped when <= 0.

    Returns:
        ``model``, with the best checkpoint's weights loaded if a checkpoint
        file exists at the end of training.

    Raises:
        ValueError: If a batch's ``nt`` differs from ``train_nt_target``.
    """
    model.to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-5)
    mse_loss = nn.MSELoss()
    # The `verbose` keyword of ReduceLROnPlateau is deprecated and no longer
    # accepted by recent PyTorch releases; learning-rate reductions are
    # logged explicitly after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    # NOTE: "best_val_loss" tracks the best average *training* loss.
    start_epoch, best_val_loss = _resume_training_state(
        checkpoint_path, model, optimizer, scheduler, device
    )

    state_keys = model.state_keys
    num_state_keys = len(state_keys)
    # Loop-invariant identity for the orthogonality penalty.
    identity = torch.eye(model.basis_dim, device=device) if lambda_orth > 0 else None

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        count = 0

        for state_data, BC_Ctrl_tensor, _norm_factors in data_loader:
            if not isinstance(state_data, list):
                state_data = [state_data]
            state_tensors = [s.to(device) for s in state_data]
            batch_size, nt, _ = state_tensors[0].shape
            BC_Ctrl_tensor = BC_Ctrl_tensor.to(device)

            if nt != train_nt_target:
                raise ValueError(f"Mismatch: nt from DataLoader ({nt}) != train_nt_target ({train_nt_target})")

            optimizer.zero_grad()

            # Initial coefficients: project (U0 - U_B(t0)) onto each basis.
            # The lifting at t0 is treated as a constant for this projection.
            a0_dict = {}
            BC_ctrl_combined_t0 = BC_Ctrl_tensor[:, 0, :]
            with torch.no_grad():
                U_B_stacked_t0 = model._compute_U_B(BC_ctrl_combined_t0)

            for k, key in enumerate(state_keys):
                U0_var = state_tensors[k][:, 0, :].unsqueeze(-1)
                U_B_t0_var = U_B_stacked_t0[:, k, :].unsqueeze(-1)
                U0_star_var = U0_var - U_B_t0_var

                Phi_var = model.get_basis(key)
                Phi_T_var = Phi_var.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1)
                a0_dict[key] = torch.bmm(Phi_T_var, U0_star_var).squeeze(-1)

            U_hat_seq_dict, _ = model(a0_dict, BC_Ctrl_tensor, T=train_nt_target, params=None)

            mse_recon_loss_val = 0.0
            residual_orth_loss_val = 0.0
            orth_loss_val = 0.0
            boundary_penalty_val = 0.0

            for k, key in enumerate(state_keys):
                Phi_var = model.get_basis(key)
                Phi_T_var = Phi_var.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1)
                U_hat_seq_var = U_hat_seq_dict[key]
                U_target_seq_var = state_tensors[k]

                var_mse_loss = 0.0
                var_res_orth_loss = 0.0
                for t_step in range(train_nt_target):
                    pred = U_hat_seq_var[t_step]
                    target = U_target_seq_var[:, t_step, :].unsqueeze(-1)
                    var_mse_loss += mse_loss(pred, target)

                    if lambda_res > 0:
                        # Penalise the part of the residual that still lies
                        # in the span of the basis.
                        r_proj = torch.bmm(Phi_T_var, target - pred)
                        var_res_orth_loss += mse_loss(r_proj, torch.zeros_like(r_proj))

                mse_recon_loss_val += (var_mse_loss / train_nt_target)
                residual_orth_loss_val += (var_res_orth_loss / train_nt_target)

                if lambda_orth > 0:
                    PhiT_Phi = torch.matmul(Phi_var.transpose(0, 1), Phi_var)
                    orth_loss_val += torch.norm(PhiT_Phi - identity, p='fro')**2

                if lambda_bc_penalty > 0:
                    # Push the basis towards zero at the domain end points.
                    if Phi_var.shape[0] > 1:
                        boundary_penalty_val += mse_loss(Phi_var[0, :], torch.zeros_like(Phi_var[0, :])) + \
                                             mse_loss(Phi_var[-1, :], torch.zeros_like(Phi_var[-1, :]))
                    elif Phi_var.shape[0] == 1:
                        boundary_penalty_val += mse_loss(Phi_var[0, :], torch.zeros_like(Phi_var[0, :]))

            if num_state_keys > 0:
                orth_loss_val /= num_state_keys
                boundary_penalty_val /= num_state_keys
            else:
                orth_loss_val = 0
                boundary_penalty_val = 0

            total_batch_loss = mse_recon_loss_val + \
                               lambda_res * residual_orth_loss_val + \
                               lambda_orth * orth_loss_val + \
                               lambda_bc_penalty * boundary_penalty_val

            total_batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            epoch_loss += total_batch_loss.item()
            count += 1

        if count == 0:
            # Without batches there is no loss to report; stepping the
            # scheduler or saving a "best" checkpoint with a fake loss of 0
            # would block every later checkpoint save.
            print(f"Epoch {epoch+1}/{num_epochs}: data_loader yielded no batches; skipping scheduler step and checkpoint.")
            continue

        avg_epoch_loss = epoch_loss / count
        print(f"Epoch {epoch+1}/{num_epochs} finished. Average Training Loss: {avg_epoch_loss:.6f}")

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_epoch_loss)
        for group_idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group['lr'] != old_lr:
                print(f"Epoch {epoch+1}: reducing learning rate of group {group_idx} to {group['lr']:.4e}.")

        if avg_epoch_loss < best_val_loss:
            best_val_loss = avg_epoch_loss
            print(f"Saving checkpoint with loss {best_val_loss:.6f} to {checkpoint_path}")
            save_dict = _rom_checkpoint_payload(
                model, optimizer, scheduler, epoch, best_val_loss, dataset_type
            )
            torch.save(save_dict, checkpoint_path)

    print("Training finished.")
    if os.path.exists(checkpoint_path):
        print(f"Loading best model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    return model

def _first_sample_norm_factors(norm_factors_batch):
    """Pick the first sample's entry from each batched normalisation factor."""
    first = {}
    for name, value in norm_factors_batch.items():
        if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim > 0:
            first[name] = value[0]
        else:
            # Scalars (0-d arrays/tensors or plain numbers) are used as-is.
            first[name] = value
    return first


def _horizon_figure_path(save_fig_path, T_test_horizon):
    """Insert a ``_T<horizon>`` tag before the extension of ``save_fig_path``."""
    root, ext = os.path.splitext(save_fig_path)
    tag = str(T_test_horizon).replace(".", "p")
    return f"{root}_T{tag}{ext or '.png'}"


def validate_multivar_model(model, data_loader, dataset_type, 
                            train_nt_for_model_training: int,
                            T_value_for_model_training: float,
                            full_T_in_datafile: float,
                            full_nt_in_datafile: int, device='cuda',
                            save_fig_path='rom_result.png'):
    """Roll out the ROM on one validation sample and report/plot the errors.

    Only the first sample of the first batch from ``data_loader`` is used.
    The model is rolled out over up to three horizons: the training horizon,
    the full horizon in the data file, and the midpoint between them. For
    each horizon the de-normalised prediction is compared with the ground
    truth, metrics are printed, and one figure per horizon is saved next to
    ``save_fig_path`` with a ``_T<horizon>`` tag before the extension.

    Args:
        model: ROM exposing ``state_keys``, ``basis_dim``, ``get_basis(key)``,
            ``_compute_U_B(bc)`` and ``model(a0_dict, BC_seq, T=...)``.
        data_loader: Iterable of ``(state_data, BC_Ctrl_tensor, norm_factors)``
            batches; ``norm_factors`` must contain ``'<key>_mean'`` and
            ``'<key>_std'`` for every state key.
        dataset_type: Label used in the figure title.
        train_nt_for_model_training: Number of time steps used in training
            (only reported).
        T_value_for_model_training: Training horizon; summary metrics are
            collected for this horizon only.
        full_T_in_datafile: Final time covered by the data file.
        full_nt_in_datafile: Number of time steps in the data file.
        device: Device used for the rollout.
        save_fig_path: Base path for the saved figures.

    Returns:
        None. Results are printed and figures are written to disk.
    """
    model.eval()
    results = {key: {'mse': [], 'rmse': [], 'relative_error': [], 'max_error': []}
               for key in model.state_keys}
    overall_rel_err_at_train_T = [] 

    candidate_horizons = [T_value_for_model_training,
                          T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training),
                          full_T_in_datafile]
    # De-duplicate and keep only horizons inside (0, full_T].
    test_horizons_T_values = sorted({h for h in candidate_horizons if 0 < h <= full_T_in_datafile})

    print(f"Validation Horizons (T values): {test_horizons_T_values}")
    print(f"Model was trained with nt={train_nt_for_model_training} for T={T_value_for_model_training}")
    print(f"Datafile contains nt={full_nt_in_datafile} for T={full_T_in_datafile}")

    try:
        state_data_full_batch, BC_Ctrl_tensor_full_batch, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("No validation data. Skipping validation.")
        return

    if not isinstance(state_data_full_batch, list):
        state_data_full_batch = [state_data_full_batch]
    # Keep the batch dimension (size 1) so the model sees batched input.
    state_tensors_full_sample = [s[0:1].to(device) for s in state_data_full_batch]
    _, nt_loaded, nx_loaded = state_tensors_full_sample[0].shape
    BC_Ctrl_tensor_full_sample = BC_Ctrl_tensor_full_batch[0:1].to(device)

    if nt_loaded != full_nt_in_datafile:
        # The rollout below is capped at nt_loaded, so validation proceeds.
        print(f"Warning: nt from val_loader ({nt_loaded}) != full_nt_in_datafile ({full_nt_in_datafile}). Check val_dataset setup.")

    norm_factors_sample = _first_sample_norm_factors(norm_factors_batch)

    state_keys = model.state_keys
    current_batch_size = 1

    # Initial coefficients: project (U0 - U_B(t0)) onto each basis. Autograd
    # is not needed anywhere in validation.
    a0_dict = {}
    with torch.no_grad():
        BC0_full_sample = BC_Ctrl_tensor_full_sample[:, 0, :]
        U_B0_lifted_sample = model._compute_U_B(BC0_full_sample)
        for k_var_idx, key_a0 in enumerate(state_keys):
            U0_full_var_s = state_tensors_full_sample[k_var_idx][:, 0, :].unsqueeze(-1)
            U_B0_var_s = U_B0_lifted_sample[:, k_var_idx, :].unsqueeze(-1)
            Phi_s = model.get_basis(key_a0).to(device)
            Phi_T_s = Phi_s.transpose(0, 1).unsqueeze(0)
            a0_dict[key_a0] = torch.bmm(Phi_T_s, U0_full_var_s - U_B0_var_s).squeeze(-1)

    for T_test_horizon in test_horizons_T_values:
        # Map the horizon to a step count on the data file's time grid and
        # cap it at what was actually loaded.
        nt_for_this_horizon = int((T_test_horizon / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
        nt_for_this_horizon = min(nt_for_this_horizon, nt_loaded)

        print(f"\n--- Validating for T_horizon = {T_test_horizon:.2f} (nt = {nt_for_this_horizon}) ---")
        BC_seq_for_pred = BC_Ctrl_tensor_full_sample[:, :nt_for_this_horizon, :]
        with torch.no_grad():
            U_hat_seq_dict, _ = model(a0_dict, BC_seq_for_pred, T=nt_for_this_horizon)

        combined_pred_denorm = []
        combined_gt_denorm = []
        num_vars_plot = len(state_keys)
        fig, axs = plt.subplots(num_vars_plot, 3, figsize=(18, 5 * num_vars_plot), squeeze=False)
        L_vis = 1.0  # Spatial extent used for the plot axis only.

        for k_var_idx_plot, key_plot in enumerate(state_keys):
            # Per-step predictions are concatenated along dim 0 and viewed
            # as [nt, batch, nx] before dropping the batch dimension.
            pred_norm_stacked = torch.cat(U_hat_seq_dict[key_plot], dim=0)
            pred_norm_reshaped = pred_norm_stacked.view(nt_for_this_horizon, current_batch_size, nx_loaded)
            pred_norm_final = pred_norm_reshaped.squeeze(1).detach().cpu().numpy()

            mean_k_val = norm_factors_sample[f'{key_plot}_mean']
            std_k_val = norm_factors_sample[f'{key_plot}_std']
            mean_k = mean_k_val.item() if hasattr(mean_k_val, 'item') else mean_k_val
            std_k = std_k_val.item() if hasattr(std_k_val, 'item') else std_k_val
            pred_denorm = pred_norm_final * std_k + mean_k

            gt_norm_full_var = state_tensors_full_sample[k_var_idx_plot].squeeze(0).cpu().numpy()
            gt_norm_sliced = gt_norm_full_var[:nt_for_this_horizon, :]
            gt_denorm = gt_norm_sliced * std_k + mean_k

            combined_pred_denorm.append(pred_denorm.flatten())
            combined_gt_denorm.append(gt_denorm.flatten())

            mse_k = np.mean((pred_denorm - gt_denorm)**2)
            rmse_k = np.sqrt(mse_k)
            rel_err_k = np.linalg.norm(pred_denorm - gt_denorm, 'fro') / (np.linalg.norm(gt_denorm, 'fro') + 1e-10)
            max_err_k = np.max(np.abs(pred_denorm - gt_denorm))
            print(f"  '{key_plot}': MSE={mse_k:.3e}, RMSE={rmse_k:.3e}, RelErr={rel_err_k:.3e}, MaxErr={max_err_k:.3e}")

            if abs(T_test_horizon - T_value_for_model_training) < 1e-5:
                results[key_plot]['mse'].append(mse_k)
                results[key_plot]['rmse'].append(rmse_k)
                results[key_plot]['relative_error'].append(rel_err_k)
                results[key_plot]['max_error'].append(max_err_k)

            diff_plot = np.abs(pred_denorm - gt_denorm)
            # Shared colour range so ground truth and prediction are comparable.
            vmin_plot = min(gt_denorm.min(), pred_denorm.min())
            vmax_plot = max(gt_denorm.max(), pred_denorm.max())

            im0 = axs[k_var_idx_plot,0].imshow(gt_denorm, aspect='auto', origin='lower', vmin=vmin_plot, vmax=vmax_plot, extent=[0, L_vis, 0, T_test_horizon], cmap='viridis')
            axs[k_var_idx_plot,0].set_title(f"GT ({key_plot}) T={T_test_horizon:.1f}"); axs[k_var_idx_plot,0].set_ylabel("t")
            plt.colorbar(im0, ax=axs[k_var_idx_plot,0])
            im1 = axs[k_var_idx_plot,1].imshow(pred_denorm, aspect='auto', origin='lower', vmin=vmin_plot, vmax=vmax_plot, extent=[0, L_vis, 0, T_test_horizon], cmap='viridis')
            axs[k_var_idx_plot,1].set_title(f"Pred ({key_plot}) T={T_test_horizon:.1f}")
            plt.colorbar(im1, ax=axs[k_var_idx_plot,1])
            im2 = axs[k_var_idx_plot,2].imshow(diff_plot, aspect='auto', origin='lower', extent=[0, L_vis, 0, T_test_horizon], cmap='magma')
            axs[k_var_idx_plot,2].set_title(f"Error ({key_plot}) (Max {max_err_k:.2e})")
            plt.colorbar(im2, ax=axs[k_var_idx_plot,2])
            for j_plot in range(3): axs[k_var_idx_plot,j_plot].set_xlabel("x")

        overall_rel_err_horizon = np.linalg.norm(np.concatenate(combined_pred_denorm) - np.concatenate(combined_gt_denorm)) / \
                                  (np.linalg.norm(np.concatenate(combined_gt_denorm)) + 1e-10)
        print(f"  Overall RelErr for T={T_test_horizon:.1f}: {overall_rel_err_horizon:.3e}")
        if abs(T_test_horizon - T_value_for_model_training) < 1e-5:
            overall_rel_err_at_train_T.append(overall_rel_err_horizon)

        fig.suptitle(f"Validation @ T={T_test_horizon:.1f} ({dataset_type.upper()}) — basis={model.basis_dim}, d_model={getattr(model, 'd_model', 'N/A')}")
        fig.tight_layout(rect=[0,0.03,1,0.95])
        # Tag the file name per horizon so figures never overwrite each
        # other, whatever extension save_fig_path carries.
        horizon_fig_path = _horizon_figure_path(save_fig_path, T_test_horizon)
        plt.savefig(horizon_fig_path)
        print(f"Saved validation figure to: {horizon_fig_path}")
        plt.close(fig)  # Free the figure's memory.

    print(f"\n--- Validation Summary (Metrics for T={T_value_for_model_training:.1f}) ---")
    for key_sum in state_keys:
        if results[key_sum]['mse']:
            avg_mse = np.mean(results[key_sum]['mse'])
            avg_rmse = np.mean(results[key_sum]['rmse'])
            avg_rel = np.mean(results[key_sum]['relative_error'])
            avg_max = np.mean(results[key_sum]['max_error'])
            print(f"  {key_sum}: MSE={avg_mse:.4e}, RMSE={avg_rmse:.4e}, RelErr={avg_rel:.4e}, MaxErr={avg_max:.4e}")
    if overall_rel_err_at_train_T:
        print(f"Overall Avg RelErr for T={T_value_for_model_training:.1f}: {np.mean(overall_rel_err_at_train_T):.4e}")


# =============================================================================
# 6. 主流程 (Main script)
# =============================================================================
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train ROM models with various configurations including ablations.")
    parser.add_argument('--datatype', type=str, required=True,
                        choices=['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback',
                                 'heat_nonlinear_feedback_gain', 'convdiff'],
                        help='Type of dataset to use.')
    parser.add_argument('--model_variant', type=str, required=True,
                        choices=['explicit_bc_no_attn', 'implicit_bc_no_attn'],
                        help='Which model variant to run.')
    
    parser.add_argument('--use_fixed_lifting', action='store_true',
                        help='Use fixed linear interpolation for U_B instead of learned lifting.')
    parser.add_argument('--random_phi_init', action='store_true',
                        help='Use random orthogonal initialization for Phi instead of POD.')
    
    # General Hyperparameters
    parser.add_argument('--basis_dim', type=int, default=32, help='Dimension of the reduced basis.')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension for FFN hidden dim in NoAttention models.') # Re-purposed for FFN
    # num_heads is not used by NoAttention models
    parser.add_argument('--initial_alpha', type=float, default=0.1, help='Initial value for alpha scaling FFN updates.')
    parser.add_argument('--shared_components', action='store_true', 
                        help='Use shared FFN/BC processors for NoAttention models.')

    # Explicit BC NoAttention ROM specific
    parser.add_argument('--bc_processed_dim', type=int, default=32, 
                        help='Dimension of processed BC features for explicit BC NoAttention ROM.')
    parser.add_argument('--hidden_bc_processor_dim', type=int, default=128,
                        help='Hidden dimension for the BC feature processor MLP in explicit BC NoAttention ROM.')
    
    # Training params
    parser.add_argument('--num_epochs', type=int, default=150)
    parser.add_argument('--lr', type=float, default=5e-4) # Default from user's explicit_bc script
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--clip_grad_norm', type=float, default=1.0)
    parser.add_argument('--lambda_res', type=float, default=0.05)
    parser.add_argument('--lambda_orth', type=float, default=0.001)
    parser.add_argument('--lambda_bc_penalty', type=float, default=0.01)

    args = parser.parse_args()

    DATASET_TYPE = args.datatype 
    MODEL_VARIANT = args.model_variant
    USE_FIXED_LIFTING_ABLATION = args.use_fixed_lifting
    RANDOM_PHI_INIT_ABLATION = args.random_phi_init
    SHARED_COMPONENTS = args.shared_components

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Selected Dataset Type: {DATASET_TYPE.upper()}")
    print(f"Selected Model Variant: {MODEL_VARIANT.upper()}")
    print(f"Use Fixed Lifting: {USE_FIXED_LIFTING_ABLATION}")
    print(f"Random Phi Init: {RANDOM_PHI_INIT_ABLATION}")
    print(f"Shared FFN/BC Processors: {SHARED_COMPONENTS}")

    # --- 时间参数 (Time parameters) ---
    if DATASET_TYPE in ['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff']:
        FULL_T_IN_DATAFILE = 2.0
        FULL_NT_IN_DATAFILE = 300 
    else: 
        FULL_T_IN_DATAFILE = 2.0 
        FULL_NT_IN_DATAFILE = 600 
        
    if DATASET_TYPE in ['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff']:
        TRAIN_T_TARGET = 1.5 
    else:
        TRAIN_T_TARGET = 1.0 
    
    if FULL_NT_IN_DATAFILE <= 1:
        TRAIN_NT_FOR_MODEL = FULL_NT_IN_DATAFILE
    else:
        TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    print(f"Full data T={FULL_T_IN_DATAFILE}, nt={FULL_NT_IN_DATAFILE}")
    print(f"Training will use T={TRAIN_T_TARGET}, nt={TRAIN_NT_FOR_MODEL}")

    # --- Path settings ---
    # NOTE(review): `timestamp` is not read anywhere in the lines visible in
    # this section -- confirm whether it is still needed.
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # Suffixes that encode the active ablation flags in the run name.
    suffix_lift = ""
    if USE_FIXED_LIFTING_ABLATION:
        suffix_lift = "_fixedlift"
    suffix_phi = ""
    if RANDOM_PHI_INIT_ABLATION:
        suffix_phi = "_randphi"
    suffix_shared = "_pervar"
    if SHARED_COMPONENTS:
        suffix_shared = "_shared"

    # The model suffix depends on the selected no-attention variant; any other
    # variant name is rejected here.
    if MODEL_VARIANT not in ('explicit_bc_no_attn', 'implicit_bc_no_attn'):
        raise ValueError(f"Unknown model_variant: {MODEL_VARIANT}")
    if MODEL_VARIANT == 'explicit_bc_no_attn':
        suffix_model = f"_NoAttnExpBC_ffn{args.d_model}_bcp{args.bc_processed_dim}"
    else:
        suffix_model = f"_NoAttnImpBC_ffn{args.d_model}"

    run_name = "".join([
        f"{DATASET_TYPE}_b{args.basis_dim}",
        suffix_model,
        suffix_shared,
        suffix_lift,
        suffix_phi,
    ])

    # Checkpoints and results of the no-attention ablations live under their
    # own base directory, in one sub-directory per dataset type.
    base_ckpt_dir_no_attn = "./New_ckpt_NoAttention/"
    checkpoint_dir = os.path.join(base_ckpt_dir_no_attn, f"_checkpoints_{DATASET_TYPE}")
    results_dir = os.path.join(base_ckpt_dir_no_attn, f"results_{DATASET_TYPE}")
    for _output_dir in (checkpoint_dir, results_dir):
        os.makedirs(_output_dir, exist_ok=True)

    # 'rom_' prefix kept for consistency with the other file names.
    checkpoint_path = os.path.join(checkpoint_dir, f'rom_{run_name}.pt')
    save_fig_path = os.path.join(results_dir, f'rom_result_{run_name}.png')

    # The POD basis directory is only created when POD initialization is used.
    basis_dir = os.path.join(checkpoint_dir, 'pod_bases_no_attn')
    if not RANDOM_PHI_INIT_ABLATION:
        os.makedirs(basis_dir, exist_ok=True)

    # --- Dataset-specific parameters and loading ---
    # One entry per supported dataset type: (pickle path, nx, state keys).
    _DATASET_SPECS = {
        'heat_delayed_feedback': (
            "./datasets_new_feedback/heat_delayed_feedback_v1_5000s_64nx_300nt.pkl", 64, ['U']),
        'reaction_diffusion_neumann_feedback': (
            "./datasets_new_feedback/reaction_diffusion_neumann_feedback_v1_5000s_64nx_300nt.pkl", 64, ['U']),
        'heat_nonlinear_feedback_gain': (
            "./datasets_new_feedback/heat_nonlinear_feedback_gain_v1_5000s_64nx_300nt.pkl", 64, ['U']),
        'convdiff': (
            "./datasets_new_feedback/convdiff_v1_5000s_64nx_300nt.pkl", 64, ['U']),
    }
    if DATASET_TYPE not in _DATASET_SPECS:
        # Should have been caught by argparse choices, but good to have a fallback.
        raise ValueError(f"Unknown dataset_type: {DATASET_TYPE}")
    dataset_path, nx_data_param, state_keys = _DATASET_SPECS[DATASET_TYPE]
    state_keys = list(state_keys)  # private copy so the table entry is never mutated

    print(f"Loading dataset for {DATASET_TYPE} from {dataset_path}")
    try:
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file -- only load dataset files from a trusted source.
        with open(dataset_path, 'rb') as f:
            data_list = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: Dataset file not found at {dataset_path}")
        # Exit with a non-zero status so callers can detect the failure; the
        # argument-less exit() used previously reported success.
        raise SystemExit(1)
    print(f"Loaded {len(data_list)} samples.")

    if not data_list:
        print("No data generated, exiting.")
        raise SystemExit(1)

    # --- Data splitting and loading ---
    # In-place shuffle; `random` was seeded at the top of the file, so the
    # resulting split is the same from run to run.
    random.shuffle(data_list)
    n_total = len(data_list)
    n_train = int(0.8 * n_total)  # 80/20 train/validation split
    train_data_list = data_list[:n_train]
    val_data_list = data_list[n_train:]
    print(f"Train samples: {len(train_data_list)}, Validation samples: {len(val_data_list)}")

    # Training sequences are truncated to TRAIN_NT_FOR_MODEL steps; validation
    # keeps the full sequences (train_nt_limit=None).
    train_dataset = UniversalPDEDataset(
        train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    val_dataset = UniversalPDEDataset(
        val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)

    # Use args.num_workers when the parser defines it; otherwise fall back to
    # the previous hard-coded value of 6 worker processes.
    num_workers = getattr(args, 'num_workers', None)
    if num_workers is None:
        num_workers = 6

    if len(train_data_list) < args.batch_size:
        # drop_last=True discards incomplete batches, so in this situation the
        # training loader would produce nothing at all.
        print(f"WARNING: only {len(train_data_list)} training samples for "
              f"batch_size={args.batch_size}; with drop_last=True the training "
              f"loader will yield no batches.")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    # Model dimensions are taken from what the dataset inferred from its samples.
    current_nx_model = train_dataset.nx
    model_num_controls = train_dataset.num_controls
    model_bc_state_dim = train_dataset.bc_state_dim

    print(f"Model Configuration: nx={current_nx_model}, basis_dim={args.basis_dim}")
    print(f"Model BC_State dim: {model_bc_state_dim}, Model Num Controls: {model_num_controls}")

    # --- Model initialization ---
    # Constructor arguments shared by both no-attention variants.
    _common_model_kwargs = dict(
        state_variable_keys=state_keys,
        nx=current_nx_model,
        basis_dim=args.basis_dim,
        d_model=args.d_model,  # For FFN hidden layer
        bc_state_dim=model_bc_state_dim,
        num_controls=model_num_controls,
        add_error_estimator=False,  # Set as needed from args if added
        use_fixed_lifting=USE_FIXED_LIFTING_ABLATION,
        initial_alpha=args.initial_alpha,
    )

    if MODEL_VARIANT == 'explicit_bc_no_attn':
        print("Initializing NoAttention_ExplicitBC_ROM:")
        print(f"  FFN hidden (d_model)={args.d_model}")
        print(f"  BC processed dim={args.bc_processed_dim}, BC processor hidden={args.hidden_bc_processor_dim}")
        # The explicit-BC variant additionally takes the BC processor sizes.
        online_model = NoAttention_ExplicitBC_ROM(
            bc_processed_dim=args.bc_processed_dim,
            hidden_bc_processor_dim=args.hidden_bc_processor_dim,
            **_common_model_kwargs
        )
    elif MODEL_VARIANT == 'implicit_bc_no_attn':
        print("Initializing NoAttention_ImplicitBC_ROM:")
        print(f"  FFN hidden (d_model)={args.d_model}")
        online_model = NoAttention_ImplicitBC_ROM(**_common_model_kwargs)
    else:
        raise ValueError(f"Unknown model_variant specified: {MODEL_VARIANT}")

    # Every branch above either assigns online_model or raises, so no separate
    # "model is None" check is needed here.
    print(f"Successfully instantiated: {online_model.__class__.__name__}")
        
    # --- POD basis initialization ---
    # Unless the random-Phi ablation is active, each state variable's Phi is
    # initialized from a POD basis that is loaded from the on-disk cache when
    # a file of the expected shape exists and computed (then cached) otherwise.
    # Fatal problems terminate the script with a non-zero exit status.
    if not RANDOM_PHI_INIT_ABLATION:
        print("\nInitializing Phi with POD bases...")
        pod_bases = {}
        if not train_dataset:
            print("Error: Training dataset is empty. Cannot compute POD.")
            raise SystemExit(1)
        try:
            first_sample_data_pod, _, _ = train_dataset[0]
            first_state_tensor_pod = first_sample_data_pod[0]
            actual_nt_for_pod = first_state_tensor_pod.shape[0]  # Should be TRAIN_NT_FOR_MODEL
        except IndexError:
            print("Error: Could not access shape from the first sample in train_dataset for POD.")
            raise SystemExit(1)

        for key_pod_loop in state_keys:
            # Include more specific info in basis filename to avoid conflicts.
            basis_filename = (
                f'pod_basis_{key_pod_loop}_nx{current_nx_model}_nt{actual_nt_for_pod}'
                f'_bdim{args.basis_dim}_{MODEL_VARIANT}{suffix_shared}.npy'
            )
            basis_path = os.path.join(basis_dir, basis_filename)
            loaded_basis = None
            if os.path.exists(basis_path):
                print(f"  Loading existing POD basis for '{key_pod_loop}' from {basis_path}...")
                try:
                    loaded_basis = np.load(basis_path)
                except Exception as e:
                    # Deliberately broad: an unreadable cache file is not
                    # fatal, the basis is simply recomputed below.
                    print(f"  Error loading {basis_path}: {e}. Will recompute.")
                    loaded_basis = None
                if loaded_basis is not None and loaded_basis.shape != (current_nx_model, args.basis_dim):
                    print(f"  Shape mismatch for loaded basis '{key_pod_loop}'. Expected ({current_nx_model}, {args.basis_dim}), got {loaded_basis.shape}. Recomputing.")
                    loaded_basis = None

            if loaded_basis is not None:
                pod_bases[key_pod_loop] = loaded_basis
                continue

            print(f"  Computing POD basis for '{key_pod_loop}' (using nt={actual_nt_for_pod}, basis_dim={args.basis_dim})...")
            computed_basis = compute_pod_basis_generic(
                data_list=train_data_list, dataset_type=DATASET_TYPE, state_variable_key=key_pod_loop,
                nx=current_nx_model, nt=actual_nt_for_pod, basis_dim=args.basis_dim)
            if computed_basis is None:
                print(f"ERROR: Failed to compute POD basis for '{key_pod_loop}'. Exiting.")
                raise SystemExit(1)
            pod_bases[key_pod_loop] = computed_basis
            os.makedirs(os.path.dirname(basis_path), exist_ok=True)
            np.save(basis_path, computed_basis)
            print(f"  Saved computed POD basis for '{key_pod_loop}' to {basis_path}")

        # Copy the bases into the model parameters without recording gradients.
        with torch.no_grad():
            for key_phi_init in state_keys:
                if key_phi_init in pod_bases and hasattr(online_model, 'Phi') and key_phi_init in online_model.Phi:
                    model_phi_param = online_model.Phi[key_phi_init]
                    pod_phi_tensor = torch.tensor(pod_bases[key_phi_init].astype(np.float32), device=model_phi_param.device)
                    if model_phi_param.shape == pod_phi_tensor.shape:
                        model_phi_param.copy_(pod_phi_tensor)
                        print(f"  Initialized Phi for '{key_phi_init}' with POD.")
                    else:
                        print(f"  WARNING: Shape mismatch for Phi '{key_phi_init}'. Expected {model_phi_param.shape}, got {pod_phi_tensor.shape}. Using random init.")
                else:
                    print(f"  WARNING: No POD basis found or Phi module not present for '{key_phi_init}'. Using random init for Phi.")
    else:
        print("\nSkipping POD basis initialization for Phi (using random orthogonal initialization).")

    # --- Training ---
    print(f"\nStarting training for {run_name}...")

    # Keyword arguments for the training routine, gathered in one place.
    _training_options = dict(
        dataset_type=DATASET_TYPE,
        train_nt_target=TRAIN_NT_FOR_MODEL,
        lr=args.lr,
        num_epochs=args.num_epochs,
        device=device,
        checkpoint_path=checkpoint_path,
        lambda_res=args.lambda_res,
        lambda_orth=args.lambda_orth,
        lambda_bc_penalty=args.lambda_bc_penalty,
        clip_grad_norm=args.clip_grad_norm,
    )

    # Wall-clock timing around the training call; the returned model replaces
    # the one that was passed in.
    start_train_time = time.time()
    online_model = train_multivar_model(online_model, train_loader, **_training_options)
    end_train_time = time.time()
    print(f"Training took {end_train_time - start_train_time:.2f} seconds.")

    # --- Validation ---
    # Validation only runs when the split left at least one validation sample.
    if not val_data_list:
        print("\nNo validation data. Skipping validation.")
    else:
        print(f"\nStarting validation for {run_name}...")
        validate_multivar_model(
            online_model,
            val_loader,
            dataset_type=DATASET_TYPE,
            device=device,
            save_fig_path=save_fig_path,
            # Training horizon (steps and T) and the full horizon stored in
            # the data file are both passed on to the validation routine.
            train_nt_for_model_training=TRAIN_NT_FOR_MODEL,
            T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE,
            full_nt_in_datafile=FULL_NT_IN_DATAFILE,
        )
    
    # Final summary, framed by two 60-character separator lines.
    _separator = "=" * 60
    _summary_lines = [
        _separator,
        f"Run finished for dataset: {DATASET_TYPE.upper()} - {run_name}",
        f"Final checkpoint saved to: {checkpoint_path}",
    ]
    if val_data_list:
        # The figure path is reported without its '.png' extension.
        _figure_prefix = save_fig_path.replace('.png', '')
        _summary_lines.append(f"Validation figure(s) saved with prefix: {_figure_prefix}")
    _summary_lines.append(_separator)
    for _summary_line in _summary_lines:
        print(_summary_line)
