# train_unisolver.py (Corrected and Adapted for All Datasets)
# =============================================================================
#       Unisolver-inspired Time-Stepping Neural Operator (Adapted for Task)
# =============================================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import pickle
import argparse
from einops.layers.torch import Rearrange
from einops import rearrange

# ---------------------
# Fixed random seed for reproducibility: Python, NumPy and torch (CPU, plus all
# CUDA devices when present) are seeded with the same value, and cuDNN is asked
# for deterministic kernels.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# ---------------------

_script_start_stamp = time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Unisolver-Stepper Script (Task Adapted) started at: {_script_start_stamp}")

# =============================================================================
# 0. UniversalPDEDataset (Corrected and Expanded)
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Per-sample normalized PDE trajectories plus boundary-condition / control signals.

    Each element of ``data_list`` is a dict holding the state array(s) for
    ``dataset_type`` (``'U'`` for advection/burgers, ``'rho'`` and ``'u'`` for
    euler, ``'P'`` for darcy), a ``'BC_State'`` array and optionally a
    ``'BC_Control'`` array. State arrays are indexed ``[time, space]`` and BC
    arrays ``[time, feature]``.

    Every item is normalized with statistics of that sample only (over the first
    ``train_nt_limit`` time levels when a limit is given). The statistics are
    returned alongside the tensors so predictions can be de-normalized.

    Args:
        data_list: non-empty list of sample dicts.
        dataset_type: ``'advection'``, ``'burgers'``, ``'euler'`` or ``'darcy'``
            (case-insensitive).
        train_nt_limit: if not ``None``, only the first ``train_nt_limit`` time
            levels of each sample are returned.

    Raises:
        ValueError: if ``data_list`` is empty or ``dataset_type`` is unsupported.
    """

    # dataset_type -> state variables, in output channel order.
    _STATE_KEYS = {
        'advection': ['U'],
        'burgers': ['U'],
        'euler': ['rho', 'u'],
        'darcy': ['P'],
    }

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        if not data_list: raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type.lower()
        self.train_nt_limit = train_nt_limit

        if self.dataset_type not in self._STATE_KEYS:
            raise ValueError(f"Unknown or unsupported dataset_type for this script: {self.dataset_type}")
        # Copy so instances never share (and mutate) the class-level lists.
        self.state_keys = list(self._STATE_KEYS[self.dataset_type])
        self.num_state_vars = len(self.state_keys)

        first_sample = data_list[0]
        # The grid size comes from the first state variable of the first sample.
        # All samples and variables are assumed to share it.
        first_state = first_sample[self.state_keys[0]]
        self.nt_from_sample_file, self.nx_from_sample_file = first_state.shape[0], first_state.shape[1]

        self.effective_nt_for_loader = self.train_nt_limit if self.train_nt_limit is not None else self.nt_from_sample_file
        self.nx = self.nx_from_sample_file

        self.bc_state_key = 'BC_State'
        self.bc_state_dim = first_sample[self.bc_state_key].shape[1]

        self.bc_control_key = 'BC_Control'
        bc_control = first_sample.get(self.bc_control_key)
        self.num_controls = bc_control.shape[1] if bc_control is not None and bc_control.size > 0 else 0

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _normalize_columns(seq, num_cols):
        """Z-score each column of a 2-D ``(time, features)`` array independently.

        Columns whose standard deviation is <= 1e-8 are only mean-centred, and
        their reported std is 1. For an empty ``seq`` the means are zeros and
        the stds are ones, both of length ``num_cols``.

        Returns:
            ``(normalized float32 array, column means, column stds)``.
        """
        if seq.size == 0:
            return np.zeros_like(seq, dtype=np.float32), np.zeros(num_cols), np.ones(num_cols)
        means = np.mean(seq, axis=0)
        raw_stds = np.std(seq, axis=0)
        stds = np.ones(raw_stds.shape)  # float64, as before
        varying = raw_stds > 1e-8
        stds[varying] = raw_stds[varying]
        normalized = ((seq - means) / stds).astype(np.float32)
        return normalized, means, stds

    def __getitem__(self, idx):
        """Return ``(state, bc_ctrl, norm_factors)`` for sample ``idx``.

        ``state`` is a float tensor ``(nt, nx)`` for single-variable datasets, or
        a list of such tensors (one per state key) otherwise. ``bc_ctrl`` is the
        normalized BC state followed by the normalized controls,
        ``(nt, bc_state_dim + num_controls)``. ``nt`` can be smaller than the
        limit when the sample is shorter.
        """
        sample = self.data_list[idx]
        nt = self.effective_nt_for_loader
        norm_factors = {}

        state_tensors_norm_list = []
        for key in self.state_keys:
            state_seq = sample[key][:nt, ...]
            state_mean, state_std = np.mean(state_seq), np.std(state_seq) + 1e-8
            state_tensors_norm_list.append(torch.tensor((state_seq - state_mean) / state_std).float())
            norm_factors[f'{key}_mean'], norm_factors[f'{key}_std'] = state_mean, state_std

        bc_state_seq = sample[self.bc_state_key][:nt, :]
        bc_state_norm, bc_means, bc_stds = self._normalize_columns(bc_state_seq, self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = bc_means
        norm_factors[f'{self.bc_state_key}_stds'] = bc_stds
        bc_parts = [torch.tensor(bc_state_norm).float()]

        # With no controls nothing is appended. The old code concatenated an
        # empty tensor with ``nt`` rows, which crashed whenever the sample had
        # fewer than ``nt`` time levels.
        if self.num_controls > 0:
            bc_control_seq = sample[self.bc_control_key][:nt, :]
            ctrl_norm, ctrl_means, ctrl_stds = self._normalize_columns(bc_control_seq, self.num_controls)
            norm_factors[f'{self.bc_control_key}_means'] = ctrl_means
            norm_factors[f'{self.bc_control_key}_stds'] = ctrl_stds
            bc_parts.append(torch.tensor(ctrl_norm).float())

        bc_ctrl_tensor_norm = torch.cat(bc_parts, dim=-1)
        output_state_tensors = state_tensors_norm_list[0] if self.num_state_vars == 1 else state_tensors_norm_list
        return output_state_tensors, bc_ctrl_tensor_norm, norm_factors


# =============================================================================
# 1. Unisolver-Stepper Architecture Components
# =============================================================================
def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation and return ``x * (1 + scale) + shift``.

    Works on tensors or plain numbers, using the usual broadcasting rules.
    """
    scaled = x * (scale + 1)
    return scaled + shift

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    A single bias-free linear layer projects ``x`` of shape ``(batch, tokens, dim)``
    to queries, keys and values. These are split into ``heads`` heads of width
    ``dim_head`` and combined with ``F.scaled_dot_product_attention`` (no mask,
    no dropout). The concatenated heads are then projected back to ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def _split_heads(self, t):
        # (b, n, h*d) -> (b, h, n, d)
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        q, k, v = (self._split_heads(part) for part in self.to_qkv(x).chunk(3, dim=-1))
        attended = F.scaled_dot_product_attention(q, k, v)
        # (b, h, n, d) -> (b, n, h*d)
        b, _, n, _ = attended.shape
        merged = attended.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(merged)

class UnisolverBlock(nn.Module):
    """Transformer block whose (non-affine) LayerNorms are modulated by a condition ``c``.

    ``adaLN_modulation(c)`` is split into six chunks: shift, scale and gate for
    the attention branch, then shift, scale and gate for the feed-forward
    branch. Each branch computes
    ``x + gate * branch(modulate(norm(x), shift, scale))``.
    ``c`` must broadcast against ``x`` (the stepper passes ``(batch, 1, dim)``).
    """

    def __init__(self, dim, heads, dim_head, mlp_dim):
        super().__init__()
        # Construction order matters: it fixes the seeded parameter initialization.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def forward(self, x, c):
        (attn_shift, attn_scale, attn_gate,
         ff_shift, ff_scale, ff_gate) = self.adaLN_modulation(c).chunk(6, dim=-1)
        attn_in = modulate(self.norm1(x), attn_shift, attn_scale)
        x = x + attn_gate * self.attn(attn_in)
        ff_in = modulate(self.norm2(x), ff_shift, ff_scale)
        return x + ff_gate * self.ff(ff_in)

class UnisolverStepper(nn.Module):
    """One-step neural operator that predicts ``u_{t+1}`` from ``u_t`` and the BC/control vector at ``t``.

    The 1-D field ``u_t`` of shape ``(batch, nx, num_state_vars)`` is cut into
    ``nx // patch_size`` non-overlapping patches. The patches are embedded with a
    learned positional embedding and processed by ``depth``
    :class:`UnisolverBlock` layers. Each layer is conditioned through adaLN on
    an MLP embedding of ``bc_ctrl_t`` of shape ``(batch, bc_ctrl_dim)``. A
    modulated final LayerNorm and a linear head map the tokens back to a field
    of the same shape as ``u_t``.

    Raises:
        ValueError: if ``nx`` is not divisible by ``patch_size``.
    """

    def __init__(self, nx, num_state_vars, bc_ctrl_dim, state_keys, patch_size=8, embed_dim=256, depth=8, heads=8, mlp_dim=512):
        super().__init__()
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if nx % patch_size != 0:
            raise ValueError(f"nx={nx} must be divisible by patch_size={patch_size}.")
        self.nx, self.num_state_vars, self.state_keys, self.patch_size = nx, num_state_vars, state_keys, patch_size
        num_patches, patch_dim = nx // patch_size, num_state_vars * patch_size
        # Submodule creation order is kept so seeded initialization is unchanged.
        self.patch_embed = nn.Sequential(
            Rearrange('b (n p) c -> b n (p c)', p=patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.bc_encoder = nn.Sequential(nn.Linear(bc_ctrl_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
        self.blocks = nn.ModuleList([UnisolverBlock(embed_dim, heads, embed_dim // heads, mlp_dim) for _ in range(depth)])
        # final_layer[0] is the norm; it is modulated separately in forward().
        self.final_layer = nn.Sequential(
            nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6),
            nn.Linear(embed_dim, patch_dim),
            Rearrange('b n (p c) -> b (n p) c', p=patch_size, c=num_state_vars),
        )
        self.adaLN_final = nn.Sequential(nn.SiLU(), nn.Linear(embed_dim, 2 * embed_dim, bias=True))

    def forward(self, u_t, bc_ctrl_t):
        """Map ``u_t`` ``(batch, nx, num_state_vars)`` and ``bc_ctrl_t`` ``(batch, bc_ctrl_dim)`` to ``u_{t+1}``."""
        x = self.patch_embed(u_t) + self.pos_embed
        # (batch, 1, embed_dim): one condition token broadcast over all patches.
        c = self.bc_encoder(bc_ctrl_t).unsqueeze(1)
        for block in self.blocks:
            x = block(x, c)
        shift, scale = self.adaLN_final(c).chunk(2, dim=-1)
        x_modulated = modulate(self.final_layer[0](x), shift, scale)
        return self.final_layer[1:](x_modulated)

# =============================================================================
# 2. Training and Validation Functions
# =============================================================================
def _load_unisolver_checkpoint(model, optimizer, scheduler, checkpoint_path, device):
    """Resume model, optimizer and (if saved) scheduler state from ``checkpoint_path``.

    Returns:
        ``(start_epoch, best_loss)``. This is ``(0, inf)`` when no checkpoint
        exists or it cannot be loaded. On a failed load all three objects are
        restored to their pre-load state, so "Starting fresh" really starts
        fresh (``load_state_dict`` may otherwise leave weights partially
        overwritten).
    """
    if not os.path.exists(checkpoint_path):
        return 0, float('inf')
    print(f"Loading Unisolver ckpt from {checkpoint_path}...")
    fresh_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    fresh_optimizer_state = optimizer.state_dict()
    fresh_scheduler_state = scheduler.state_dict()
    try:
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # Checkpoints written by older versions of this script have no scheduler state.
        if 'scheduler_state_dict' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch = ckpt.get('epoch', 0) + 1
        best_loss = ckpt.get('loss', float('inf'))
    except Exception as e:  # best-effort resume: any failure falls back to a fresh run
        print(f"Error loading ckpt: {e}. Starting fresh.")
        model.load_state_dict(fresh_model_state)
        optimizer.load_state_dict(fresh_optimizer_state)
        scheduler.load_state_dict(fresh_scheduler_state)
        return 0, float('inf')
    print(f"Resuming Unisolver training from epoch {start_epoch}")
    return start_epoch, best_loss


def train_unisolver_stepper(model, data_loader, train_nt_for_model, lr=1e-3, num_epochs=50, device='cuda', checkpoint_path='unisolver_ckpt.pt', clip_grad_norm=1.0):
    """Train ``model`` as a one-step predictor with teacher forcing.

    For each batch the loss is the mean over ``t = 0 .. n-2`` of
    ``MSE(model(u_t, bc_t), u_{t+1})``, where
    ``n = min(train_nt_for_model, sequence length)``. Whenever the epoch-average
    loss improves, the model, optimizer and LR-scheduler state are written to
    ``checkpoint_path``. The best checkpoint is reloaded into ``model`` before
    returning.

    Args:
        model: module called as ``model(u_t, bc_ctrl_t)`` that returns a tensor
            shaped like ``u_t``.
        data_loader: iterable of ``(state_data, bc_ctrl_data, norm_factors)``.
            ``state_data`` is ``(B, T, nx)`` or a list of such tensors (stacked
            as trailing channels); ``bc_ctrl_data`` is ``(B, T, bc_ctrl_dim)``.
        train_nt_for_model: time levels per sequence used for training (>= 2).
        lr: initial AdamW learning rate (halved on a 10-epoch loss plateau).
        num_epochs: total number of epochs, counting epochs already done when resuming.
        device: device the model and batches are moved to.
        checkpoint_path: where the best checkpoint is stored and resumed from.
        clip_grad_norm: max gradient norm; ``<= 0`` disables clipping.

    Returns:
        ``model``, holding the best checkpoint's weights if one exists.

    Raises:
        ValueError: if fewer than two time levels are available for training.
    """
    if train_nt_for_model < 2:
        raise ValueError(f"train_nt_for_model must be >= 2 to form (u_t, u_t+1) pairs, got {train_nt_for_model}")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # ``verbose`` is deprecated in recent PyTorch; LR reductions are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    mse_loss = nn.MSELoss()
    start_epoch, best_loss = _load_unisolver_checkpoint(model, optimizer, scheduler, checkpoint_path, device)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss, num_batches = 0.0, 0
        batch_start_time = time.time()
        for i, (state_data, bc_ctrl_data, _) in enumerate(data_loader):
            if isinstance(state_data, list):
                state_seq = torch.stack(state_data, dim=-1).to(device)
            else:
                state_seq = state_data.unsqueeze(-1).to(device)
            bc_ctrl_seq = bc_ctrl_data.to(device)
            # Never index past the end of the sequence actually loaded.
            num_steps = min(train_nt_for_model, state_seq.shape[1]) - 1
            if num_steps < 1:
                raise ValueError(f"Batch {i} has only {state_seq.shape[1]} time level(s); need at least 2.")
            optimizer.zero_grad()
            total_seq_loss = 0.0
            for t in range(num_steps):
                u_np1_pred = model(state_seq[:, t], bc_ctrl_seq[:, t])
                total_seq_loss = total_seq_loss + mse_loss(u_np1_pred, state_seq[:, t + 1])
            batch_loss = total_seq_loss / num_steps
            epoch_loss += batch_loss.item()
            num_batches += 1
            batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()
            if (i + 1) % 50 == 0:
                elapsed = time.time() - batch_start_time
                print(f" Unisolver Ep {epoch+1} B {i+1}/{len(data_loader)}, Loss {batch_loss.item():.3e}, Time/50B {elapsed:.2f}s")
                batch_start_time = time.time()

        if num_batches == 0:
            # Saving here would store an untrained model with loss 0.0 and block all later saves.
            print("Warning: data_loader yielded no batches; stopping Unisolver training.")
            break
        avg_epoch_loss = epoch_loss / num_batches
        print(f"Unisolver Epoch {epoch+1}/{num_epochs} Avg Loss: {avg_epoch_loss:.6f}")
        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_epoch_loss)
        lrs_after = [group['lr'] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Reducing Unisolver learning rate to {lrs_after}")
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            print(f"Saving Unisolver ckpt with loss {best_loss:.6f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
            }, checkpoint_path)

    print("Unisolver Training finished.")
    if os.path.exists(checkpoint_path):
        print("Loading best Unisolver model for validation.")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
    return model

def _rollout_unisolver(model, u_initial_norm, bc_ctrl_seq_norm, nt_for_rollout):
    """Autoregressively roll ``model`` forward from ``u_initial_norm`` (normalized space).

    Args:
        model: stepper called as ``model(u_t, bc_ctrl_t)``.
        u_initial_norm: initial state ``(1, nx, num_state_vars)``.
        bc_ctrl_seq_norm: BC/control sequence ``(nt, bc_ctrl_dim)``; row ``t``
            drives the step ``t -> t+1``.
        nt_for_rollout: number of time levels to produce, including the initial state.

    Returns:
        Tensor ``(nt_for_rollout, nx, num_state_vars)`` on ``u_initial_norm``'s device.
    """
    _, nx, num_vars = u_initial_norm.shape
    preds = torch.zeros(nt_for_rollout, nx, num_vars, device=u_initial_norm.device)
    u_current = u_initial_norm.clone()
    preds[0] = u_current.squeeze(0)
    for t_step in range(nt_for_rollout - 1):
        u_current = model(u_current, bc_ctrl_seq_norm[t_step:t_step + 1, :])
        preds[t_step + 1] = u_current.squeeze(0)
    return preds


def validate_unisolver_stepper(model, data_loader, dataset_type, T_value_for_model_training, full_T_in_datafile, full_nt_in_datafile, dataset_params_for_plot, device='cuda', save_fig_path_prefix='unisolver_result'):
    """Roll the trained stepper out on the first validation sample, print metrics and save plots.

    Three horizons are evaluated: the training horizon, the midpoint to the
    full horizon, and the full horizon (duplicates removed). For each, the
    prediction and the ground truth are de-normalized with the sample's
    ``norm_factors``. The MSE and relative Frobenius error per state variable
    are printed, and a truth / prediction / abs-error figure is saved to
    ``f"{save_fig_path_prefix}_T<horizon>.png"``.

    The rollout length is clamped to the time levels actually present in the
    sample, so shorter files no longer cause BC-slice or shape errors.
    Returns ``None``; it returns early if ``data_loader`` is empty.
    """
    model.eval()
    state_keys_val, num_state_vars_val = model.state_keys, model.num_state_vars
    test_horizons_T_values = sorted(set([T_value_for_model_training, T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training), full_T_in_datafile]))
    print(f"Unisolver Validation for T_horizons: {test_horizons_T_values}")
    try:
        state_data_full_loaded, BC_Ctrl_tensor_full_loaded, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("Validation data_loader is empty.")
        return
    with torch.no_grad():
        if isinstance(state_data_full_loaded, list):
            state_seq_true_norm_full = torch.stack(state_data_full_loaded, dim=-1)[0].to(device)
        else:
            state_seq_true_norm_full = state_data_full_loaded.unsqueeze(-1)[0].to(device)
        BC_Ctrl_seq_norm_full = BC_Ctrl_tensor_full_loaded[0].to(device)
        norm_factors_sample = {key: val[0].cpu().numpy() for key, val in norm_factors_batch.items()}
        u_initial_norm = state_seq_true_norm_full[0:1, :, :]
        # Time levels this sample can support: the state supplies the targets, and
        # step t needs BC row t, so nt <= (#BC rows + 1).
        available_nt = min(state_seq_true_norm_full.shape[0], BC_Ctrl_seq_norm_full.shape[0] + 1)
        for T_horizon_current in test_horizons_T_values:
            nt_for_rollout = int((T_horizon_current / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
            nt_for_rollout = min(nt_for_rollout, full_nt_in_datafile, available_nt)
            print(f"\n  Rollout for T_horizon = {T_horizon_current:.2f} (nt = {nt_for_rollout})")
            u_pred_seq_norm_horizon = _rollout_unisolver(model, u_initial_norm, BC_Ctrl_seq_norm_full, nt_for_rollout)

            U_pred_denorm_h, U_gt_denorm_h = {}, {}
            state_true_norm_sliced_h = state_seq_true_norm_full[:nt_for_rollout, :, :]
            pred_np_h, gt_np_h = u_pred_seq_norm_horizon.cpu().numpy(), state_true_norm_sliced_h.cpu().numpy()
            for k_idx, key_val in enumerate(state_keys_val):
                mean_k, std_k = norm_factors_sample[f'{key_val}_mean'], norm_factors_sample[f'{key_val}_std']
                pred_denorm_v, gt_denorm_v = pred_np_h[:, :, k_idx] * std_k + mean_k, gt_np_h[:, :, k_idx] * std_k + mean_k
                U_pred_denorm_h[key_val], U_gt_denorm_h[key_val] = pred_denorm_v, gt_denorm_v
                mse_k_h = np.mean((pred_denorm_v - gt_denorm_v)**2)
                rel_err_k_h = np.linalg.norm(pred_denorm_v - gt_denorm_v, 'fro') / (np.linalg.norm(gt_denorm_v, 'fro') + 1e-10)
                print(f"    Metrics '{key_val}' @ T={T_horizon_current:.2f}: MSE={mse_k_h:.3e}, RelErr={rel_err_k_h:.4f}")

            fig, axs = plt.subplots(num_state_vars_val, 3, figsize=(18, 5*num_state_vars_val), squeeze=False)
            for k_idx, key_val in enumerate(state_keys_val):
                gt_p, pred_p = U_gt_denorm_h[key_val], U_pred_denorm_h[key_val]
                diff_p = np.abs(pred_p - gt_p)
                v_min, v_max = min(np.min(gt_p), np.min(pred_p)), max(np.max(gt_p), np.max(pred_p))
                # x axis: space [0, L]; y axis: time [0, T_horizon] (rows are time levels).
                plot_ext = [0, dataset_params_for_plot.get('L', 1.0), 0, T_horizon_current]
                im0 = axs[k_idx, 0].imshow(gt_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
                im1 = axs[k_idx, 1].imshow(pred_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
                im2 = axs[k_idx, 2].imshow(diff_p, aspect='auto', origin='lower', extent=plot_ext, cmap='magma')
                axs[k_idx, 0].set_title(f"Truth ({key_val})")
                axs[k_idx, 1].set_title(f"Unisolver Pred ({key_val})")
                axs[k_idx, 2].set_title(f"Abs Error (Max:{np.max(diff_p):.2e})")
                plt.colorbar(im0, ax=axs[k_idx, 0])
                plt.colorbar(im1, ax=axs[k_idx, 1])
                plt.colorbar(im2, ax=axs[k_idx, 2])
            fig.suptitle(f"Unisolver Validation ({dataset_type.capitalize()}) @ T={T_horizon_current:.2f}")
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            curr_fig_path = save_fig_path_prefix+f"_T{str(T_horizon_current).replace('.','p')}.png"
            fig.savefig(curr_fig_path)
            print(f"  Saved Unisolver validation plot to {curr_fig_path}")
            plt.close(fig)

# =============================================================================
# 3. Main Execution Block
# =============================================================================
if __name__ == '__main__':
    # Dataset pickle for each supported --datatype. Insertion order sets the CLI choice order.
    DATASET_PATHS = {
        'advection': "./datasets_full/advection_data_10000s_128nx_600nt.pkl",
        'euler': "./datasets_full/euler_data_10000s_128nx_600nt.pkl",
        'burgers': "./datasets_full/burgers_data_10000s_128nx_600nt.pkl",
        'darcy': "./datasets_full/darcy_data_10000s_128nx_600nt.pkl",
    }

    parser = argparse.ArgumentParser(description="Train a Unisolver-inspired stepper model.")
    parser.add_argument('--datatype', type=str, required=True, choices=list(DATASET_PATHS), help='Type of dataset to train on.')
    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    MODEL_TYPE = 'Unisolver_Stepper'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"--- Running {MODEL_TYPE} for {DATASET_TYPE} on {device} ---")

    # --- Model and Training Hyperparameters ---
    EMBED_DIM = 256
    DEPTH = 8
    HEADS = 8
    MLP_DIM = 512
    PATCH_SIZE = 8
    LEARNING_RATE = 5e-4
    BATCH_SIZE = 16
    NUM_EPOCHS = 80
    CLIP_GRAD_NORM = 1.0

    # --- Dataset and Time Horizon Parameters ---
    FULL_T_IN_DATAFILE = 2.0
    FULL_NT_IN_DATAFILE = 600
    TRAIN_T_TARGET = 1.0  # training horizon; validation extrapolates up to FULL_T_IN_DATAFILE
    TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    # --- Data Loading ---
    dataset_path = DATASET_PATHS[DATASET_TYPE]
    dataset_params_for_plot = {'nx': 128, 'ny': 1, 'L': 1.0, 'T': FULL_T_IN_DATAFILE}

    print(f"Dataset: {DATASET_TYPE}, Training T={TRAIN_T_TARGET}s")
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally generated dataset files.
    with open(dataset_path, 'rb') as f:
        data_list_all = pickle.load(f)
    random.shuffle(data_list_all)
    n_train = int(0.8 * len(data_list_all))
    train_data_list, val_data_list = data_list_all[:n_train], data_list_all[n_train:]

    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    # UniversalPDEDataset rejects an empty list, so build the validation side only when there is data.
    val_loader = None
    if val_data_list:
        val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    if train_dataset.nt_from_sample_file != FULL_NT_IN_DATAFILE:
        print(f"Warning: data file has nt={train_dataset.nt_from_sample_file}, but FULL_NT_IN_DATAFILE={FULL_NT_IN_DATAFILE}; time horizons may be mis-scaled.")

    # --- Model Initialization with Dynamic Patch Size ---
    actual_bc_ctrl_dim, current_nx = train_dataset.bc_state_dim + train_dataset.num_controls, train_dataset.nx
    if current_nx % PATCH_SIZE != 0:
        # Largest divisor of nx that does not exceed the default patch size.
        fallback_patch = next((p for p in range(PATCH_SIZE, 0, -1) if current_nx % p == 0), None)
        if fallback_patch is None: raise ValueError(f"nx={current_nx} has no integer divisors <= {PATCH_SIZE}")
        PATCH_SIZE = fallback_patch
        print(f"Warning: nx={current_nx} not divisible by default patch size. Adjusted PATCH_SIZE to {PATCH_SIZE}.")

    print(f"\nInitializing {MODEL_TYPE} model...")
    # State keys and channel count come from the dataset so they cannot drift out of sync.
    unisolver_model = UnisolverStepper(
        nx=current_nx, num_state_vars=train_dataset.num_state_vars, bc_ctrl_dim=actual_bc_ctrl_dim,
        state_keys=train_dataset.state_keys, patch_size=PATCH_SIZE, embed_dim=EMBED_DIM,
        depth=DEPTH, heads=HEADS, mlp_dim=MLP_DIM
    )

    # --- Training and Validation ---
    run_name = f"{DATASET_TYPE}_{MODEL_TYPE}_trainT{TRAIN_T_TARGET}"
    checkpoint_dir = f"./checkpoints_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    results_dir = f"./results_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'model_{run_name}.pt')
    save_fig_path_prefix = os.path.join(results_dir, f'result_{run_name}')

    print(f"\nStarting training for {MODEL_TYPE} on {DATASET_TYPE}...")
    start_train_time = time.time()
    unisolver_model = train_unisolver_stepper(
        unisolver_model, train_loader, train_nt_for_model=TRAIN_NT_FOR_MODEL,
        lr=LEARNING_RATE, num_epochs=NUM_EPOCHS, device=device,
        checkpoint_path=checkpoint_path, clip_grad_norm=CLIP_GRAD_NORM
    )
    print(f"Training took {time.time() - start_train_time:.2f} seconds.")

    if val_loader is not None:
        print(f"\nStarting validation for {MODEL_TYPE} on {DATASET_TYPE}...")
        validate_unisolver_stepper(
            unisolver_model, val_loader, dataset_type=DATASET_TYPE, T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE, full_nt_in_datafile=FULL_NT_IN_DATAFILE,
            dataset_params_for_plot=dataset_params_for_plot, device=device,
            save_fig_path_prefix=save_fig_path_prefix
        )

    print("="*60 + f"\nRun finished: {run_name}\n" + "="*60)