# train_unisolver.py
# =============================================================================
#       Unisolver-inspired Time-Stepping Neural Operator (Adapted for Task)
# =============================================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import pickle
import argparse
from einops.layers.torch import Rearrange
from einops import rearrange, repeat

# ---------------------
# Fixed random seed for reproducibility
seed = 42

# Reproducibility: request deterministic cuDNN kernels, then seed the Python,
# NumPy and PyTorch (CPU) generators with the same value, plus every visible
# CUDA device when one is present.
torch.backends.cudnn.deterministic = True
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

print(f"Unisolver-Stepper Script (Task Adapted) started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# =============================================================================
# 0. UniversalPDEDataset (Identical to your previous scripts)
#    Ensures data handling is exactly the same for a fair comparison.
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of per-sample normalised PDE trajectories plus boundary/control signals.

    Each entry of ``data_list`` is a dict holding:
      * the state array(s): ``'U'`` for the scalar dataset types, or ``'rho'`` and
        ``'u'`` for ``'euler'``;
      * ``'BC_State'``: a 2-D array whose columns are boundary-state signals;
      * optionally ``'BC_Control'``: a 2-D array of control signals (it may be
        absent, ``None`` or empty).

    Axis 0 of every array is treated as time (everything is sliced with ``[:nt]``)
    and axis 1 of the state arrays is taken as ``nx``. Only the first sample is
    inspected to infer sizes, so all samples are assumed to share them.

    Args:
        data_list: non-empty list of sample dicts (see above).
        dataset_type: one of the supported type names (case-insensitive).
        train_nt_limit: if given, only the first ``train_nt_limit`` time steps of
            each sample are returned; otherwise the first sample's full length.

    Raises:
        ValueError: if ``data_list`` is empty or ``dataset_type`` is unsupported.
    """

    # Dataset types whose single state field is stored under key 'U'.
    _U_STATE_DATASET_TYPES = (
        'advection', 'burgers', 'heat_delayed_feedback',
        'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff',
    )

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type.lower()
        self.train_nt_limit = train_nt_limit

        first_sample = data_list[0]
        self.nt_from_sample_file, self.nx_from_sample_file, self.ny_from_sample_file = 0, 0, 1

        if self.dataset_type in self._U_STATE_DATASET_TYPES:
            self.nt_from_sample_file = first_sample['U'].shape[0]
            self.nx_from_sample_file = first_sample['U'].shape[1]
            self.state_keys, self.num_state_vars, self.expected_bc_state_dim = ['U'], 1, 2
        elif self.dataset_type == 'euler':
            # Read nt and nx separately. Unpacking the scalar shape[0] into two
            # names raised a TypeError for every euler dataset.
            self.nt_from_sample_file = first_sample['rho'].shape[0]
            self.nx_from_sample_file = first_sample['rho'].shape[1]
            self.state_keys, self.num_state_vars, self.expected_bc_state_dim = ['rho', 'u'], 2, 4
        # ... (Add other dataset types if necessary)
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

        self.effective_nt_for_loader = (
            self.train_nt_limit if self.train_nt_limit is not None else self.nt_from_sample_file
        )
        self.nx, self.ny = self.nx_from_sample_file, self.ny_from_sample_file

        self.bc_state_key = 'BC_State'
        self.bc_state_dim = first_sample[self.bc_state_key].shape[1]

        self.bc_control_key = 'BC_Control'
        controls = first_sample.get(self.bc_control_key)
        self.num_controls = controls.shape[1] if controls is not None and controls.size > 0 else 0

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _normalize_columns(seq, dim):
        """Z-score each column of a 2-D ``(time, dim)`` array.

        Columns whose standard deviation is <= 1e-8 are only mean-centred and
        report a std of 1.0, so ``norm * std + mean`` inverts the transform for
        every column.

        Returns:
            ``(normalized float32 array, per-column means, per-column stds)``.
            For an empty ``seq`` the result is zeros, ``zeros(dim)`` and ``ones(dim)``.
        """
        if seq.size == 0:
            return np.zeros_like(seq, dtype=np.float32), np.zeros(dim), np.ones(dim)
        means = np.mean(seq, axis=0)
        raw_stds = np.std(seq, axis=0)
        stds = np.where(raw_stds > 1e-8, raw_stds, 1.0).astype(np.float64)
        normalized = ((seq - means) / stds).astype(np.float32)
        return normalized, means, stds

    def __getitem__(self, idx):
        """Return ``(state, bc_ctrl, norm_factors)`` for sample ``idx``.

        ``state`` is a float32 tensor per state key, normalised by that key's global
        mean/std over the returned window. It is a single tensor when there is one
        state variable, otherwise a list in ``state_keys`` order.

        ``bc_ctrl`` is the column-wise normalised BC-state signals concatenated with
        the (optional) control signals along the last axis.

        ``norm_factors`` maps ``'<key>_mean'``/``'<key>_std'`` and
        ``'BC_State_means'``/``'BC_State_stds'`` (plus ``'BC_Control_*'`` when
        controls exist) to the statistics needed to undo the normalisation.
        """
        sample = self.data_list[idx]
        nt = self.effective_nt_for_loader
        norm_factors = {}

        state_tensors_norm_list = []
        for key in self.state_keys:
            state_seq = sample[key][:nt, ...]
            state_mean = np.mean(state_seq)
            state_std = np.std(state_seq) + 1e-8  # epsilon guards constant fields
            state_tensors_norm_list.append(torch.tensor((state_seq - state_mean) / state_std).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        bc_state_norm, bc_means, bc_stds = self._normalize_columns(
            sample[self.bc_state_key][:nt, :], self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = bc_means
        norm_factors[f'{self.bc_state_key}_stds'] = bc_stds
        bc_state_tensor_norm = torch.tensor(bc_state_norm).float()

        if self.num_controls > 0:
            ctrl_norm, ctrl_means, ctrl_stds = self._normalize_columns(
                sample[self.bc_control_key][:nt, :], self.num_controls)
            norm_factors[f'{self.bc_control_key}_means'] = ctrl_means
            norm_factors[f'{self.bc_control_key}_stds'] = ctrl_stds
            bc_control_tensor_norm = torch.tensor(ctrl_norm).float()
        else:
            # Match the actual BC length so the concatenation below cannot fail
            # when a sample is shorter than the nominal window.
            bc_control_tensor_norm = torch.empty((bc_state_tensor_norm.shape[0], 0), dtype=torch.float32)

        bc_ctrl_tensor_norm = torch.cat((bc_state_tensor_norm, bc_control_tensor_norm), dim=-1)
        output_state_tensors = state_tensors_norm_list[0] if self.num_state_vars == 1 else state_tensors_norm_list
        return output_state_tensors, bc_ctrl_tensor_norm, norm_factors

# =============================================================================
# 1. Unisolver-Stepper Architecture Components
# =============================================================================
def modulate(x, shift, scale):
    """Apply a conditional affine transform: scale ``x`` by ``1 + scale``, then add ``shift``.

    Works for anything supporting ``*`` and ``+`` with broadcasting (e.g. tensors
    whose ``shift``/``scale`` carry a singleton sequence axis).
    """
    scaled = x * (1 + scale)
    return scaled + shift

class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` sequence.

    A single bias-free linear layer produces queries, keys and values. Each head
    runs ``F.scaled_dot_product_attention`` with its defaults (no mask, no
    dropout). The heads are then concatenated and projected back to ``dim``.
    """
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        self.heads = heads
        qkv_width = heads * dim_head
        self.to_qkv = nn.Linear(dim, 3 * qkv_width, bias=False)
        self.to_out = nn.Linear(qkv_width, dim)

    def forward(self, x):
        batch, tokens, _ = x.shape
        # Split the fused projection into q/k/v, then each into
        # (batch, heads, tokens, head_dim); the head index is the outer factor.
        q, k, v = (
            part.reshape(batch, tokens, self.heads, -1).transpose(1, 2)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        attended = F.scaled_dot_product_attention(q, k, v)
        # Back to (batch, tokens, heads * head_dim) before the output projection.
        merged = attended.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)

class UnisolverBlock(nn.Module):
    """
    One conditional Transformer layer.

    Both sub-layers (self-attention and the MLP) use parameter-free pre-norms.
    The condition vector ``c`` is mapped to a per-sub-layer shift/scale, applied
    to the normalised input via ``modulate``, and to a gate that multiplies the
    sub-layer output before it is added back to the residual stream.
    """
    def __init__(self, dim, heads, dim_head, mlp_dim):
        super().__init__()
        # LayerNorms have no affine parameters; the affine part comes from c.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))
        # Produces six dim-sized vectors: (shift, scale, gate) for attention, then for the MLP.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def forward(self, x, c):
        # c: condition embedding, broadcast against x's token axis by the caller's shape.
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        attn_shift, attn_scale, attn_gate = modulation[:3]
        ff_shift, ff_scale, ff_gate = modulation[3:]

        attn_in = modulate(self.norm1(x), attn_shift, attn_scale)
        x = x + attn_gate * self.attn(attn_in)

        ff_in = modulate(self.norm2(x), ff_shift, ff_scale)
        return x + ff_gate * self.ff(ff_in)

class UnisolverStepper(nn.Module):
    """
    Unisolver-inspired Neural Operator adapted for time-stepping.

    Maps the state at one step, ``u_t`` (shape ``(batch, nx, num_state_vars)``),
    plus that step's boundary/control vector ``bc_ctrl_t`` (shape
    ``(batch, bc_ctrl_dim)``) to a prediction of the next state with the same
    shape as ``u_t``.

    The spatial axis is cut into ``nx // patch_size`` patches that become tokens.
    The BC/control vector is encoded once and conditions every block through
    adaptive scale/shift/gate modulation.

    Raises:
        ValueError: if ``nx`` is not divisible by ``patch_size``.
    """
    def __init__(self, nx, num_state_vars, bc_ctrl_dim, state_keys,
                 patch_size=8, embed_dim=256, depth=8, heads=8, mlp_dim=512):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if nx % patch_size != 0:
            raise ValueError(
                f"Spatial size nx={nx} must be divisible by the patch size ({patch_size})."
            )
        self.nx = nx
        self.num_state_vars = num_state_vars
        self.state_keys = state_keys
        self.patch_size = patch_size
        num_patches = nx // patch_size
        patch_dim = num_state_vars * patch_size

        # --- Encoders ---
        # 1. Patch embedding for the point-wise input field (u_t)
        self.patch_embed = nn.Sequential(
            Rearrange('b (n p) c -> b n (p c)', p=patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))

        # 2. MLP for the domain-wise condition (bc_ctrl_t)
        self.bc_encoder = nn.Sequential(
            nn.Linear(bc_ctrl_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

        # --- Core Transformer ---
        self.blocks = nn.ModuleList([
            UnisolverBlock(embed_dim, heads, embed_dim // heads, mlp_dim) for _ in range(depth)
        ])

        # --- Decoder ---
        # Kept as one Sequential (state_dict keys final_layer.0/1/2) for checkpoint
        # compatibility. forward() applies element 0, modulates, then applies the rest.
        self.final_layer = nn.Sequential(
            nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6),
            nn.Linear(embed_dim, patch_dim),
            Rearrange('b n (p c) -> b (n p) c', p=patch_size, c=num_state_vars)
        )
        self.adaLN_final = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 2 * embed_dim, bias=True)
        )

    def forward(self, u_t, bc_ctrl_t):
        # 1. Embed inputs: (b, nx, c) -> (b, num_patches, embed_dim)
        x = self.patch_embed(u_t) + self.pos_embed

        # 2. Encode the domain-wise condition; the singleton sequence axis lets
        #    it broadcast across all patch tokens during modulation.
        c = self.bc_encoder(bc_ctrl_t).unsqueeze(1)

        # 3. Process through conditional Transformer blocks
        for block in self.blocks:
            x = block(x, c)

        # 4. Final conditional layer and decoding back to (b, nx, c)
        shift, scale = self.adaLN_final(c).chunk(2, dim=-1)
        x_modulated = modulate(self.final_layer[0](x), shift, scale)
        u_tp1_pred = self.final_layer[1:](x_modulated)

        return u_tp1_pred

# =============================================================================
# 2. Training and Validation Functions (Adapted for Unisolver Stepper)
# =============================================================================
def train_unisolver_stepper(model, data_loader, dataset_type, train_nt_for_model,
                            lr=1e-3, num_epochs=50, device='cuda',
                            checkpoint_path='unisolver_ckpt.pt', clip_grad_norm=1.0):
    """Train ``model`` as a teacher-forced one-step stepper ``u_t -> u_{t+1}``.

    Each batch from ``data_loader`` is ``(state_data, bc_ctrl_data, _)`` as produced
    by ``UniversalPDEDataset``:
      * ``state_data`` is a ``(B, T, nx)`` tensor, or a list of them, one per state
        variable; they are stacked on a trailing channel axis;
      * ``bc_ctrl_data`` is ``(B, T, bc_ctrl_dim)``.

    The first ``train_nt_for_model`` steps give ``train_nt_for_model - 1``
    (input, target) pairs. All pairs are independent, so they are flattened into
    one forward pass. The resulting MSE equals the mean of the per-step MSEs.

    The best epoch (lowest mean training loss) is checkpointed together with the
    optimizer and LR-scheduler state. An existing checkpoint is resumed, and the
    best checkpoint is reloaded into ``model`` before returning.

    Args:
        dataset_type: unused here; kept for interface compatibility.
        clip_grad_norm: max gradient norm; ``<= 0`` disables clipping.

    Returns:
        The trained model (with best-checkpoint weights when a checkpoint exists).

    Raises:
        ValueError: if ``train_nt_for_model < 2`` or a batch has fewer time steps
            than ``train_nt_for_model``.
    """
    if train_nt_for_model < 2:
        raise ValueError(f"train_nt_for_model must be >= 2 to form (u_t, u_t+1) pairs, got {train_nt_for_model}")
    n_pairs = train_nt_for_model - 1
    log_every = 10  # batches between progress prints

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # `verbose=` is deprecated in recent PyTorch releases; LR reductions are printed below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    mse_loss = nn.MSELoss()

    start_epoch, best_loss = 0, float('inf')
    if os.path.exists(checkpoint_path):
        print(f"Loading Unisolver ckpt from {checkpoint_path}...")
        try:
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            # Older checkpoints have no scheduler state; keep the fresh scheduler then.
            if 'scheduler_state_dict' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            start_epoch = ckpt.get('epoch', 0) + 1
            best_loss = ckpt.get('loss', float('inf'))
            print(f"Resuming Unisolver training from epoch {start_epoch}")
        except Exception as e:  # best-effort resume: any failure falls back to fresh training
            print(f"Error loading ckpt: {e}. Starting fresh.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss, num_batches = 0.0, 0
        batch_start_time = time.time()

        for i, (state_data, bc_ctrl_data, _) in enumerate(data_loader):
            if isinstance(state_data, list):
                state_seq = torch.stack(state_data, dim=-1).to(device)
            else:
                state_seq = state_data.unsqueeze(-1).to(device)
            bc_ctrl_seq = bc_ctrl_data.to(device)

            if state_seq.shape[1] < train_nt_for_model:
                raise ValueError(
                    f"Batch has {state_seq.shape[1]} time steps but train_nt_for_model={train_nt_for_model}"
                )

            # Flatten (batch, step) into one axis: every teacher-forced step is independent.
            spatial_shape = state_seq.shape[2:]
            u_n = state_seq[:, :n_pairs].reshape(-1, *spatial_shape)
            u_np1_true = state_seq[:, 1:n_pairs + 1].reshape(-1, *spatial_shape)
            bc_ctrl_n = bc_ctrl_seq[:, :n_pairs].reshape(-1, bc_ctrl_seq.shape[-1])

            optimizer.zero_grad()
            u_np1_pred = model(u_n, bc_ctrl_n)
            batch_loss = mse_loss(u_np1_pred, u_np1_true)
            batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            epoch_loss += batch_loss.item()
            num_batches += 1

            if (i + 1) % log_every == 0:
                elapsed = time.time() - batch_start_time
                print(f" Unisolver Ep {epoch+1} B {i+1}/{len(data_loader)}, Loss {batch_loss.item():.3e}, Time/{log_every}B {elapsed:.2f}s")
                batch_start_time = time.time()

        if num_batches == 0:
            # e.g. drop_last=True with fewer samples than batch_size: no loss to
            # optimise. Stop rather than checkpointing an untrained model as "loss 0".
            print("Warning: data_loader yielded no batches; stopping training.")
            break

        avg_epoch_loss = epoch_loss / num_batches
        print(f"Unisolver Epoch {epoch+1}/{num_epochs} Avg Loss: {avg_epoch_loss:.6f}")
        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_epoch_loss)
        for group_idx, (old_lr, new_lr) in enumerate(zip(lrs_before, (g['lr'] for g in optimizer.param_groups))):
            if new_lr < old_lr:
                print(f"Epoch {epoch+1}: reducing learning rate of group {group_idx} to {new_lr:.4e}.")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            print(f"Saving Unisolver ckpt with loss {best_loss:.6f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
            }, checkpoint_path)

    print("Unisolver Training finished.")
    if os.path.exists(checkpoint_path):
        print("Loading best Unisolver model for validation.")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
    return model

def _save_validation_plot(U_gt_denorm_h, U_pred_denorm_h, state_keys_val, dataset_type,
                          T_horizon_current, dataset_params_for_plot, fig_path):
    """Save a truth / prediction / |error| space-time figure (one row per state key)."""
    n_rows = len(state_keys_val)
    fig, axs = plt.subplots(n_rows, 3, figsize=(18, 5 * n_rows), squeeze=False)
    try:
        # x-axis: space in [0, L]; y-axis: time in [0, T_horizon] (rows of the arrays).
        plot_ext = [0, dataset_params_for_plot.get('L', 1.0), 0, T_horizon_current]
        for k_idx, key_val in enumerate(state_keys_val):
            gt_p, pred_p = U_gt_denorm_h[key_val], U_pred_denorm_h[key_val]
            diff_p = np.abs(pred_p - gt_p)
            # Shared colour range so truth and prediction are visually comparable.
            v_min = min(np.min(gt_p), np.min(pred_p))
            v_max = max(np.max(gt_p), np.max(pred_p))
            row = axs[k_idx]
            im0 = row[0].imshow(gt_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
            im1 = row[1].imshow(pred_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
            im2 = row[2].imshow(diff_p, aspect='auto', origin='lower', extent=plot_ext, cmap='magma')
            row[0].set_title(f"Truth ({key_val})")
            row[1].set_title(f"Unisolver Pred ({key_val})")
            row[2].set_title(f"Abs Error (Max:{np.max(diff_p):.2e})")
            for ax, im in zip(row, (im0, im1, im2)):
                fig.colorbar(im, ax=ax)

        fig.suptitle(f"Unisolver Validation ({dataset_type.capitalize()}) @ T={T_horizon_current:.2f}")
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(fig_path)
    finally:
        plt.close(fig)  # release the figure even if plotting fails


def validate_unisolver_stepper(model, data_loader, dataset_type,
                               T_value_for_model_training, full_T_in_datafile, full_nt_in_datafile,
                               dataset_params_for_plot, device='cuda',
                               save_fig_path_prefix='unisolver_result'):
    """Autoregressively roll out ``model`` on the first validation trajectory.

    Starting from the true normalised initial state, the model is fed its own
    predictions together with the ground-truth BC/control signal of each step.
    Rollouts are evaluated at three horizons (deduplicated): the training
    horizon, the midpoint to the full horizon, and the full horizon.

    For each horizon and state key this prints the MSE and the relative
    Frobenius error on de-normalised fields, and saves a comparison figure to
    ``f"{save_fig_path_prefix}_T<horizon with '.'->'p'>.png"``.

    NOTE(review): the header comment described this as shared with the BENO/GNOT
    scripts' validation; confirm those stay in sync for a fair comparison.

    Returns:
        dict mapping ``(T_horizon, state_key)`` to ``{'mse': float, 'rel_err': float}``,
        or ``None`` when ``data_loader`` is empty.
    """
    model.eval()
    state_keys_val = model.state_keys
    num_state_vars_val = model.num_state_vars

    test_horizons_T_values = sorted({
        T_value_for_model_training,
        T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training),
        full_T_in_datafile,
    })
    print(f"Unisolver Validation for T_horizons: {test_horizons_T_values}")

    try:
        state_data_full_loaded, BC_Ctrl_tensor_full_loaded, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("Validation data_loader is empty. Skipping validation.")
        return None

    metrics = {}
    with torch.no_grad():
        # Only the first trajectory of the first batch is evaluated.
        if isinstance(state_data_full_loaded, list):
            state_seq_true_norm_full = torch.stack(state_data_full_loaded, dim=-1)[0].to(device)
        else:
            state_seq_true_norm_full = state_data_full_loaded.unsqueeze(-1)[0].to(device)

        BC_Ctrl_seq_norm_full = BC_Ctrl_tensor_full_loaded[0].to(device)
        nt_available, nx_plot, _ = state_seq_true_norm_full.shape
        norm_factors_sample = {key: val[0].cpu().numpy() for key, val in norm_factors_batch.items()}
        u_initial_norm = state_seq_true_norm_full[0:1, :, :]

        for T_horizon_current in test_horizons_T_values:
            nt_for_rollout = int((T_horizon_current / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
            # Never exceed what was actually loaded: ground truth needs nt rows and
            # the BC/control signal needs nt - 1 rows.
            nt_for_rollout = min(nt_for_rollout, full_nt_in_datafile, nt_available,
                                 BC_Ctrl_seq_norm_full.shape[0] + 1)
            print(f"\n  Rollout for T_horizon = {T_horizon_current:.2f} (nt = {nt_for_rollout})")

            u_pred_seq_norm_horizon = torch.zeros(nt_for_rollout, nx_plot, num_state_vars_val, device=device)
            u_current_pred_step = u_initial_norm.clone()
            u_pred_seq_norm_horizon[0] = u_current_pred_step.squeeze(0)
            for t_step in range(nt_for_rollout - 1):
                bc_ctrl_n_step = BC_Ctrl_seq_norm_full[t_step:t_step + 1, :]
                u_current_pred_step = model(u_current_pred_step, bc_ctrl_n_step)
                u_pred_seq_norm_horizon[t_step + 1] = u_current_pred_step.squeeze(0)

            pred_np_h = u_pred_seq_norm_horizon.cpu().numpy()
            gt_np_h = state_seq_true_norm_full[:nt_for_rollout].cpu().numpy()

            U_pred_denorm_h, U_gt_denorm_h = {}, {}
            for k_idx, key_val in enumerate(state_keys_val):
                mean_k, std_k = norm_factors_sample[f'{key_val}_mean'], norm_factors_sample[f'{key_val}_std']
                pred_denorm_v = pred_np_h[:, :, k_idx] * std_k + mean_k
                gt_denorm_v = gt_np_h[:, :, k_idx] * std_k + mean_k
                U_pred_denorm_h[key_val], U_gt_denorm_h[key_val] = pred_denorm_v, gt_denorm_v
                mse_k_h = np.mean((pred_denorm_v - gt_denorm_v)**2)
                rel_err_k_h = np.linalg.norm(pred_denorm_v - gt_denorm_v, 'fro') / (np.linalg.norm(gt_denorm_v, 'fro') + 1e-10)
                metrics[(T_horizon_current, key_val)] = {'mse': float(mse_k_h), 'rel_err': float(rel_err_k_h)}
                print(f"    Metrics '{key_val}' @ T={T_horizon_current:.2f}: MSE={mse_k_h:.3e}, RelErr={rel_err_k_h:.4f}")

            curr_fig_path = save_fig_path_prefix + f"_T{str(T_horizon_current).replace('.','p')}.png"
            _save_validation_plot(U_gt_denorm_h, U_pred_denorm_h, state_keys_val, dataset_type,
                                  T_horizon_current, dataset_params_for_plot, curr_fig_path)
            print(f"  Saved Unisolver validation plot to {curr_fig_path}")

    return metrics

# =============================================================================
# 3. Main Execution Block
# =============================================================================
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train a Unisolver-inspired stepper model.")
    # 'darcy' is intentionally not offered: UniversalPDEDataset has no branch for it
    # and would only fail ("Unknown dataset_type") after the dataset pickle is loaded.
    parser.add_argument('--datatype', type=str, required=True,
                        choices=['advection', 'euler', 'burgers', 'heat_delayed_feedback',
                                 'reaction_diffusion_neumann_feedback','heat_nonlinear_feedback_gain','convdiff'])
    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    MODEL_TYPE = 'Unisolver_Stepper'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"--- Running {MODEL_TYPE} for {DATASET_TYPE} on {device} ---")

    # --- Model and Training Hyperparameters ---
    EMBED_DIM = 256
    DEPTH = 8
    HEADS = 8
    MLP_DIM = 512
    PATCH_SIZE = 8
    LEARNING_RATE = 5e-4
    BATCH_SIZE = 16
    NUM_EPOCHS = 150
    CLIP_GRAD_NORM = 1.0

    # --- Dataset and Time Horizon Parameters ---
    # NOTE(review): these describe the 300-step feedback datasets; the fallback files
    # below are named *_600nt, so confirm T/nt before trusting fallback horizons.
    FULL_T_IN_DATAFILE = 2.0
    FULL_NT_IN_DATAFILE = 300
    TRAIN_T_TARGET = 1.5  # CRITICAL: train only on the first part of each trajectory
    # Number of time steps covering [0, TRAIN_T_TARGET] on the uniform grid (endpoints included).
    TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    # --- Data Loading and Model Init ---
    dataset_params_for_plot = {}
    if DATASET_TYPE in ['heat_delayed_feedback', 'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain', 'convdiff']:
        dataset_path = f"./datasets_new_feedback/{DATASET_TYPE}_v1_5000s_64nx_300nt.pkl"
        main_state_keys, main_num_state_vars = ['U'], 1
        dataset_params_for_plot = {'nx': 64, 'ny': 1, 'L': 1.0}
    else: # Fallback
        dataset_path = f"./datasets_full/{DATASET_TYPE}_data_10000s_128nx_600nt.pkl"
        main_state_keys = ['rho','u'] if DATASET_TYPE == 'euler' else ['U']
        main_num_state_vars = 2 if DATASET_TYPE == 'euler' else 1
        dataset_params_for_plot = {'nx': 128, 'ny': 1, 'L': 1.0}

    print(f"Loading dataset: {dataset_path}")
    # SECURITY: pickle.load can execute arbitrary code; only load dataset files you trust.
    with open(dataset_path, 'rb') as f:
        data_list_all = pickle.load(f)
    random.shuffle(data_list_all)
    n_train = int(0.8 * len(data_list_all))
    train_data_list, val_data_list = data_list_all[:n_train], data_list_all[n_train:]

    # Training sees only the first TRAIN_NT_FOR_MODEL steps; validation keeps full trajectories.
    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    actual_bc_ctrl_dim = train_dataset.bc_state_dim + train_dataset.num_controls
    current_nx = train_dataset.nx

    if current_nx % PATCH_SIZE != 0:
        raise ValueError(f"Dataset spatial dimension nx={current_nx} must be divisible by PATCH_SIZE={PATCH_SIZE}")

    print(f"\nInitializing {MODEL_TYPE} model...")
    unisolver_model = UnisolverStepper(
        nx=current_nx, num_state_vars=main_num_state_vars, bc_ctrl_dim=actual_bc_ctrl_dim,
        state_keys=main_state_keys, patch_size=PATCH_SIZE, embed_dim=EMBED_DIM,
        depth=DEPTH, heads=HEADS, mlp_dim=MLP_DIM
    )

    # --- Output paths ---
    run_name = f"{DATASET_TYPE}_{MODEL_TYPE}_trainT{TRAIN_T_TARGET}"
    checkpoint_dir = f"./checkpoints_{DATASET_TYPE}_{MODEL_TYPE}"
    results_dir = f"./results_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'model_{run_name}.pt')
    save_fig_path_prefix = os.path.join(results_dir, f'result_{run_name}')

    print(f"\nStarting training for {MODEL_TYPE} on {DATASET_TYPE}...")
    start_train_time = time.time()
    unisolver_model = train_unisolver_stepper(
        unisolver_model, train_loader, dataset_type=DATASET_TYPE, train_nt_for_model=TRAIN_NT_FOR_MODEL,
        lr=LEARNING_RATE, num_epochs=NUM_EPOCHS, device=device,
        checkpoint_path=checkpoint_path, clip_grad_norm=CLIP_GRAD_NORM
    )
    print(f"Training took {time.time() - start_train_time:.2f} seconds.")

    if val_data_list:
        print(f"\nStarting validation for {MODEL_TYPE} on {DATASET_TYPE}...")
        validate_unisolver_stepper(
            unisolver_model, val_loader, dataset_type=DATASET_TYPE,
            T_value_for_model_training=TRAIN_T_TARGET, full_T_in_datafile=FULL_T_IN_DATAFILE,
            full_nt_in_datafile=FULL_NT_IN_DATAFILE, dataset_params_for_plot=dataset_params_for_plot,
            device=device, save_fig_path_prefix=save_fig_path_prefix
        )

    print("="*60 + f"\nRun finished: {run_name}\n" + "="*60)