# gnot_stepper.py
# =============================================================================
#       GNOT-inspired Time-Stepping Neural Operator (Adapted for All Tasks)
# =============================================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import pickle
import argparse

# ---------------------
# Seed the Python, NumPy and PyTorch random number generators with one shared
# value so repeated runs draw the same random numbers.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    # Covers every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True
# ---------------------

print(f"GNOT-Stepper Script (Task Adapted) started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# =============================================================================
# 0. UniversalPDEDataset (Corrected and Expanded for Darcy)
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of PDE trajectories with per-sample normalisation.

    Every entry of ``data_list`` is a dict. It must hold one array per state
    variable (``'U'``, ``'rho'``/``'u'`` or ``'P'`` depending on
    ``dataset_type``) and a ``'BC_State'`` array; ``'BC_Control'`` and
    ``'params'`` are optional. The first axis of every array is treated as
    time; the second axis of a state array is treated as the spatial grid.

    Sizes (nt, nx, BC widths) are read from the FIRST sample only, so all
    samples are assumed to share them -- this is not verified.

    Args:
        data_list: Non-empty list of sample dicts.
        dataset_type: Case-insensitive dataset name. Unknown names fall back
            to a single state variable ``'U'`` after printing a warning.
        train_nt_limit: Optional maximum number of time steps returned per
            sample. Values larger than the number of steps in the first
            sample are clamped to that number.

    Raises:
        ValueError: If ``data_list`` is empty.
    """

    # dataset type -> (state variable keys, expected width of 'BC_State')
    _STATE_SPECS = {
        'advection': (('U',), 2),
        'burgers': (('U',), 2),
        'euler': (('rho', 'u'), 4),
        'darcy': (('P',), 2),
    }

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type.lower()
        self.train_nt_limit = train_nt_limit

        first_sample = data_list[0]
        params = first_sample.get('params', {})

        if self.dataset_type in self._STATE_SPECS:
            keys, expected_bc_dim = self._STATE_SPECS[self.dataset_type]
        else:
            # Fallback for dataset types that are not listed explicitly.
            print(f"Warning: Dataset type '{self.dataset_type}' not explicitly handled. Assuming 'U' as state variable.")
            keys, expected_bc_dim = ('U',), 2
        self.state_keys = list(keys)
        self.num_state_vars = len(self.state_keys)
        self.expected_bc_state_dim = expected_bc_dim

        # The first state variable defines the time/space extents.
        reference = first_sample[self.state_keys[0]]
        self.nt_from_sample_file = reference.shape[0]
        self.nx_from_sample_file = reference.shape[1]
        if self.dataset_type == 'darcy':
            # Darcy samples may carry an explicit grid size in their params.
            self.nx_from_sample_file = params.get('nx', reference.shape[1])
        self.ny_from_sample_file = 1

        # Never promise more time steps than the data holds; otherwise the
        # empty-controls placeholder built in __getitem__ would not line up
        # with the (shorter) sliced sequences.
        if self.train_nt_limit is None:
            self.effective_nt_for_loader = self.nt_from_sample_file
        else:
            self.effective_nt_for_loader = min(self.train_nt_limit, self.nt_from_sample_file)
        self.nx, self.ny = self.nx_from_sample_file, self.ny_from_sample_file

        self.bc_state_key = 'BC_State'
        self.bc_state_dim = first_sample[self.bc_state_key].shape[1]

        self.bc_control_key = 'BC_Control'
        controls = first_sample.get(self.bc_control_key)
        if controls is not None and controls.size > 0:
            self.num_controls = controls.shape[1]
        else:
            self.num_controls = 0

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _normalise_columns(seq, key, n_cols, norm_factors):
        """Standardise each column of ``seq`` and record the statistics.

        Columns whose standard deviation is at most 1e-8 are only centred and
        their recorded std stays 1. The statistics are written to
        ``norm_factors`` under ``f'{key}_means'`` and ``f'{key}_stds'``.
        Returns the normalised values as a float32 tensor.
        """
        normed = np.zeros_like(seq, dtype=np.float32)
        stds = np.ones(n_cols)
        if seq.size > 0:
            means = np.mean(seq, axis=0)
            raw_stds = np.std(seq, axis=0)
            usable = raw_stds > 1e-8
            stds[usable] = raw_stds[usable]
            normed[...] = (seq - means) / stds
        else:
            means = np.zeros(n_cols)
        norm_factors[f'{key}_means'] = means
        norm_factors[f'{key}_stds'] = stds
        return torch.tensor(normed).float()

    def __getitem__(self, idx):
        """Return ``(state, bc_ctrl, norm_factors)`` for sample ``idx``.

        ``state`` is a float tensor when there is one state variable and a
        list of tensors (one per key in ``state_keys``) otherwise. Each state
        variable is standardised with the scalar mean/std of its own
        truncated sequence. ``bc_ctrl`` concatenates the column-wise
        standardised ``BC_State`` and ``BC_Control`` along the last axis.
        ``norm_factors`` holds every mean/std needed to undo the scaling.
        """
        sample = self.data_list[idx]
        nt = self.effective_nt_for_loader
        norm_factors = {}

        state_tensors = []
        for key in self.state_keys:
            state_seq = sample[key][:nt, ...]
            state_mean = np.mean(state_seq)
            state_std = np.std(state_seq) + 1e-8  # epsilon guards constant fields
            state_tensors.append(torch.tensor((state_seq - state_mean) / state_std).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        bc_state_tensor = self._normalise_columns(
            sample[self.bc_state_key][:nt, :], self.bc_state_key, self.bc_state_dim, norm_factors)

        if self.num_controls > 0:
            bc_control_tensor = self._normalise_columns(
                sample[self.bc_control_key][:nt, :], self.bc_control_key, self.num_controls, norm_factors)
        else:
            # Zero-width placeholder; its length follows the actual BC
            # sequence so the concatenation below always lines up.
            bc_control_tensor = torch.empty((bc_state_tensor.shape[0], 0), dtype=torch.float32)

        bc_ctrl_tensor_norm = torch.cat((bc_state_tensor, bc_control_tensor), dim=-1)
        output_state_tensors = state_tensors[0] if self.num_state_vars == 1 else state_tensors
        return output_state_tensors, bc_ctrl_tensor_norm, norm_factors

# =============================================================================
# 1. GNOT-Stepper Architecture Components
# =============================================================================
class MLP(nn.Module):
    """A simple Multi-Layer Perceptron.

    Builds ``Linear -> activation`` for every entry of ``hidden_dims`` and
    finishes with a plain ``Linear`` to ``output_dim`` (no activation on the
    output).

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Size of the last axis of the output.
        hidden_dims: Iterable of hidden layer widths; ``None`` (the default)
            means no hidden layers, i.e. a single linear map.
        activation: Zero-argument callable returning an activation module.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=None, activation=nn.GELU):
        super().__init__()
        # ``None`` instead of a shared mutable ``[]`` default.
        widths = () if hidden_dims is None else hidden_dims
        layers = []
        current_dim = input_dim
        for h_dim in widths:
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(activation())
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)

class LinearAttention(nn.Module):
    """Multi-head attention with linear cost in the number of tokens.

    Queries and keys are passed through a softmax over the per-head feature
    axis. The keys and values are first contracted into a
    ``head_dim x head_dim`` context per head, which is then applied to the
    queries, so no ``N_q x N_kv`` attention matrix is formed.

    NOTE(review): the result is not divided by any normaliser, so its
    magnitude grows with the number of key/value tokens -- confirm that this
    is intended.

    Args:
        embed_dim: Width of the token embeddings; must be divisible by
            ``n_head``.
        n_head: Number of attention heads.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, embed_dim, n_head, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.head_dim = embed_dim // n_head
        assert self.head_dim * n_head == embed_dim, "embed_dim must be divisible by n_head"
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)

    def _split_heads(self, x, batch, tokens):
        """Reshape ``(B, N, C)`` into ``(B, n_head, N, head_dim)``."""
        return x.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x_query, x_kv):
        """Attend from ``x_query`` (B, N_q, C) to ``x_kv`` (B, N_kv, C).

        Returns a tensor of shape (B, N_q, C).
        """
        B, N_q, C = x_query.shape
        N_kv = x_kv.shape[1]
        q = self._split_heads(self.query(x_query), B, N_q).softmax(dim=-1)
        k = self._split_heads(self.key(x_kv), B, N_kv).softmax(dim=-1)
        v = self._split_heads(self.value(x_kv), B, N_kv)
        # Sum over the key/value tokens first: (B, H, head_dim, head_dim).
        context = torch.einsum('b h j d, b h j e -> b h d e', k, v)
        out = torch.einsum('b h i d, b h d e -> b h i e', q, context)
        # Merge the heads back into the embedding axis.
        out = out.transpose(1, 2).reshape(B, N_q, C)
        return self.attn_drop(self.out_proj(out))

class GNOTAttentionBlock(nn.Module):
    """One processing block: cross-attention, gated expert FFN, self-attention.

    Args:
        embed_dim: Width of the token embeddings.
        n_head: Number of heads of both attention layers.
        n_experts: Number of expert MLPs mixed by the gating network.
        space_dim: Width of the coordinate vectors fed to the gating network.
    """

    def __init__(self, embed_dim, n_head, n_experts, space_dim):
        super().__init__()
        self.embed_dim, self.n_experts = embed_dim, n_experts
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.cross_attention = LinearAttention(embed_dim, n_head)
        self.ln3 = nn.LayerNorm(embed_dim)
        self.self_attention = LinearAttention(embed_dim, n_head)
        # The gate sees only the coordinates, not the features.
        self.gating_network = MLP(space_dim, n_experts, [embed_dim, embed_dim])
        self.expert_mlps = nn.ModuleList(
            [MLP(embed_dim, embed_dim, [embed_dim * 2, embed_dim]) for _ in range(n_experts)]
        )

    def forward(self, query_embed, input_embed, coords):
        """Update ``query_embed`` using ``input_embed`` and ``coords``.

        Each of the three stages is added residually. The einsum below fixes
        the layouts: features are (batch, points, embed_dim) and the gate
        scores are (batch, points, n_experts).
        """
        # 1) Queries attend to the encoded input function.
        x = query_embed + self.cross_attention(self.ln1(query_embed), self.ln2(input_embed))
        # 2) Mixture of experts, weighted per point by its coordinates.
        gate_scores = F.softmax(self.gating_network(coords), dim=-1)
        expert_outputs = torch.stack([expert(x) for expert in self.expert_mlps], dim=-1)
        x = x + torch.einsum('b n c e, b n e -> b n c', expert_outputs, gate_scores)
        # 3) Self-attention; normalise once and reuse for queries and keys/values.
        normed = self.ln3(x)
        x = x + self.self_attention(normed, normed)
        return x

class GNOTStepper(nn.Module):
    """One-step-ahead neural operator on a 1-D grid.

    Given the state at one time level and the boundary/control vector of that
    level, predicts the state at the next time level.

    Args:
        nx: Number of grid points of the default grid (``linspace(0, 1, nx)``).
        num_state_vars: Number of state channels per grid point.
        bc_ctrl_dim: Width of the boundary/control vector.
        state_keys: Names of the state variables; stored for later use by
            callers and not used by the network itself.
        embed_dim: Width of all embeddings.
        n_layers: Number of ``GNOTAttentionBlock`` layers.
        n_head: Attention heads per block.
        n_experts: Expert MLPs per block.
    """

    def __init__(self, nx, num_state_vars, bc_ctrl_dim, state_keys, embed_dim=128, n_layers=4, n_head=8, n_experts=4):
        super().__init__()
        self.nx = nx
        self.num_state_vars = num_state_vars
        self.state_keys = state_keys
        self.embed_dim = embed_dim
        self.space_dim = 1
        self.coord_encoder = MLP(self.space_dim, embed_dim, [embed_dim, embed_dim])
        self.input_encoder = MLP(num_state_vars + self.space_dim, embed_dim, [embed_dim, embed_dim])
        self.bc_encoder = MLP(bc_ctrl_dim, embed_dim, [embed_dim, embed_dim])
        self.blocks = nn.ModuleList(
            [GNOTAttentionBlock(embed_dim, n_head, n_experts, self.space_dim) for _ in range(n_layers)]
        )
        self.decoder = MLP(embed_dim, num_state_vars, [embed_dim * 2, embed_dim])
        # Kept as a (persistent) buffer so existing state dicts still load.
        self.register_buffer('grid', torch.linspace(0, 1, nx).view(1, nx, 1))

    def forward(self, u_t, bc_ctrl_t):
        """Predict the next state.

        Args:
            u_t: Current state, shape (B, n_points, num_state_vars).
            bc_ctrl_t: Boundary/control vector, shape (B, bc_ctrl_dim).

        Returns:
            Predicted next state with the same shape as ``u_t``.
        """
        B, n_points = u_t.shape[0], u_t.shape[1]
        if n_points == self.nx:
            base_grid = self.grid.to(u_t.device)
        else:
            # Input sampled at another resolution: build a matching uniform
            # grid on [0, 1] instead of failing on a shape mismatch.
            base_grid = torch.linspace(
                0, 1, n_points, device=u_t.device, dtype=self.grid.dtype
            ).view(1, n_points, 1)
        # expand() creates a broadcast view; no per-batch copy is needed.
        grid = base_grid.expand(B, -1, -1)

        query_embed = self.coord_encoder(grid)
        input_embed = self.input_encoder(torch.cat([u_t, grid], dim=-1))
        # One BC embedding per batch element, broadcast over all grid points.
        x = query_embed + self.bc_encoder(bc_ctrl_t).unsqueeze(1)
        for block in self.blocks:
            x = block(x, input_embed, grid)
        return self.decoder(x)

# =============================================================================
# 2. Training and Validation Functions (Adapted for GNOT Stepper)
# =============================================================================
def _stack_state_batch(state_data, device):
    """Return the batch state as one tensor with a trailing channel axis."""
    if isinstance(state_data, list):
        # Several state variables: one tensor per variable -> stack as channels.
        return torch.stack(state_data, dim=-1).to(device)
    return state_data.unsqueeze(-1).to(device)


def train_gnot_stepper(model, data_loader, dataset_type, train_nt_for_model,
                       lr=1e-3, num_epochs=50, device='cuda',
                       checkpoint_path='gnot_ckpt.pt', clip_grad_norm=1.0):
    """Train ``model`` as a one-step predictor with teacher forcing.

    For every batch and every time level ``t`` the model receives the TRUE
    state at ``t`` and is penalised (MSE) on its prediction of ``t + 1``; the
    batch loss is the mean over all steps. The best epoch (lowest average
    loss) is written to ``checkpoint_path`` and reloaded before returning.
    If a checkpoint already exists, training resumes from it.

    Args:
        model: Module called as ``model(state_t, bc_ctrl_t)``.
        data_loader: Iterable with ``len()`` yielding
            ``(state, bc_ctrl, norm_factors)``; ``state`` is a tensor or a
            list of tensors (one per state variable).
        dataset_type: Unused; kept for interface compatibility.
        train_nt_for_model: Number of time levels used per sequence (>= 2).
            Clamped to the length of the loaded sequences.
        lr: Initial AdamW learning rate.
        num_epochs: Total number of epochs (including already trained ones
            when resuming).
        device: Device used for the model and the batches.
        checkpoint_path: File used for loading and saving checkpoints.
        clip_grad_norm: Max gradient norm; values <= 0 disable clipping.

    Returns:
        The trained model (with the best checkpoint loaded if one exists).

    Raises:
        ValueError: If ``train_nt_for_model`` is smaller than 2.
    """
    if train_nt_for_model < 2:
        raise ValueError("train_nt_for_model must be at least 2 to form (t, t+1) training pairs")

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # The deprecated ``verbose`` flag is not used; learning-rate reductions
    # are reported explicitly after each scheduler step below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    mse_loss = nn.MSELoss()

    start_epoch, best_loss = 0, float('inf')
    if os.path.exists(checkpoint_path):
        print(f"Loading GNOT checkpoint from {checkpoint_path}...")
        try:
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            # Older checkpoints carry no scheduler state; then it starts fresh.
            if 'scheduler_state_dict' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            start_epoch = ckpt.get('epoch', 0) + 1
            best_loss = ckpt.get('loss', float('inf'))
            print(f"Resuming GNOT training from epoch {start_epoch}")
        except Exception as e:
            # Deliberately best-effort: an unreadable checkpoint must not
            # prevent training.
            print(f"Error loading GNOT checkpoint: {e}. Starting fresh.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss, num_batches = 0.0, 0
        batch_start_time = time.time()
        for i, (state_data, bc_ctrl_data, _) in enumerate(data_loader):
            state_seq = _stack_state_batch(state_data, device)
            bc_ctrl_seq = bc_ctrl_data.to(device)
            # Never step past the end of the loaded sequences.
            n_steps = min(train_nt_for_model, state_seq.shape[1], bc_ctrl_seq.shape[1]) - 1
            if n_steps < 1:
                continue

            optimizer.zero_grad()
            running_loss = 0.0
            for t in range(n_steps):
                u_np1_pred = model(state_seq[:, t], bc_ctrl_seq[:, t])
                step_loss = mse_loss(u_np1_pred, state_seq[:, t + 1])
                # Every step is fed ground truth, so the steps share no graph.
                # Back-propagating each one immediately accumulates the same
                # gradient as one backward on the mean loss, while keeping
                # only a single step's activations in memory.
                (step_loss / n_steps).backward()
                running_loss = running_loss + step_loss.detach()

            batch_loss_value = (running_loss / n_steps).item()
            epoch_loss += batch_loss_value
            num_batches += 1
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            if (i + 1) % 50 == 0:
                elapsed = time.time() - batch_start_time
                print(f" GNOT Ep {epoch+1} B {i+1}/{len(data_loader)}, Loss {batch_loss_value:.3e}, Time/50B {elapsed:.2f}s")
                batch_start_time = time.time()

        if num_batches == 0:
            # Without any batch there is no loss; recording 0.0 here would
            # wrongly become the "best" checkpoint.
            print(f"GNOT Epoch {epoch+1}/{num_epochs}: no usable batches, skipping.")
            continue

        avg_epoch_loss = epoch_loss / num_batches
        print(f"GNOT Epoch {epoch+1}/{num_epochs} Avg Loss: {avg_epoch_loss:.6f}")

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_epoch_loss)
        for group_idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"Epoch {epoch+1}: reducing learning rate of group {group_idx} to {group['lr']:.4e}.")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            print(f"Saving GNOT checkpoint with loss {best_loss:.6f}")
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': best_loss}, checkpoint_path)

    print("GNOT Training finished.")
    if os.path.exists(checkpoint_path):
        print("Loading best GNOT model for validation.")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
    return model

def validate_gnot_stepper(model, data_loader, dataset_type,
                          T_value_for_model_training, full_T_in_datafile, full_nt_in_datafile,
                          dataset_params_for_plot, device='cuda',
                          save_fig_path_prefix='gnot_result'):
    """Roll the model out autoregressively on one sample and report errors.

    Only the first sample of the first batch of ``data_loader`` is used.
    Starting from its true initial state, the model is fed its own
    predictions. Metrics are printed and a truth / prediction / abs-error
    figure is saved for three time horizons: the training horizon, the
    full horizon, and the midpoint between them.

    Args:
        model: Trained stepper exposing ``state_keys`` and ``num_state_vars``.
        data_loader: Iterable yielding ``(state, bc_ctrl, norm_factors)``.
        dataset_type: Name used in the figure title.
        T_value_for_model_training: Time horizon the model was trained on.
        full_T_in_datafile: Time horizon covered by the full sequence.
        full_nt_in_datafile: Number of time levels of the full sequence.
        dataset_params_for_plot: Dict; ``'L'`` (default 1.0) sets the spatial
            extent of the plots.
        device: Device used for the rollout.
        save_fig_path_prefix: Prefix of the saved figure paths.

    Returns:
        None.
    """
    model.to(device)
    model.eval()
    state_keys_val, num_state_vars_val = model.state_keys, model.num_state_vars
    midpoint_T = T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training)
    test_horizons_T_values = sorted({T_value_for_model_training, midpoint_T, full_T_in_datafile})
    print(f"GNOT Validation for T_horizons: {test_horizons_T_values}")
    try:
        state_data_full_loaded, BC_Ctrl_tensor_full_loaded, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("Validation data_loader is empty.")
        return

    with torch.no_grad():
        # Keep only the first sample of the batch, shape (nt, nx, channels).
        if isinstance(state_data_full_loaded, list):
            state_seq_true_norm_full = torch.stack(state_data_full_loaded, dim=-1)[0].to(device)
        else:
            state_seq_true_norm_full = state_data_full_loaded.unsqueeze(-1)[0].to(device)
        BC_Ctrl_seq_norm_full = BC_Ctrl_tensor_full_loaded[0].to(device)
        nx_plot = state_seq_true_norm_full.shape[1]
        norm_factors_sample = {key: val[0].cpu().numpy() for key, val in norm_factors_batch.items()}

        # A rollout can never be longer than the data actually loaded.
        nt_available = min(state_seq_true_norm_full.shape[0], BC_Ctrl_seq_norm_full.shape[0])
        if nt_available < 1:
            print("Validation sample holds no time steps.")
            return
        horizon_nts = []
        for T_horizon_current in test_horizons_T_values:
            nt_for_rollout = int((T_horizon_current / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
            horizon_nts.append(min(nt_for_rollout, full_nt_in_datafile, nt_available))

        # The rollout is deterministic in eval mode, so every shorter horizon
        # is a prefix of the longest one: roll out once and slice.
        max_nt = max(horizon_nts)
        u_pred_seq_norm = torch.zeros(max_nt, nx_plot, num_state_vars_val, device=device)
        u_current_pred_step = state_seq_true_norm_full[0:1, :, :].clone()
        u_pred_seq_norm[0] = u_current_pred_step.squeeze(0)
        for t_step in range(max_nt - 1):
            bc_ctrl_n_step = BC_Ctrl_seq_norm_full[t_step:t_step + 1, :]
            u_current_pred_step = model(u_current_pred_step, bc_ctrl_n_step)
            u_pred_seq_norm[t_step + 1] = u_current_pred_step.squeeze(0)
        pred_np_full = u_pred_seq_norm.cpu().numpy()
        gt_np_full = state_seq_true_norm_full[:max_nt].cpu().numpy()

    for T_horizon_current, nt_for_rollout in zip(test_horizons_T_values, horizon_nts):
        print(f"\n  Rollout for T_horizon = {T_horizon_current:.2f} (nt = {nt_for_rollout})")
        pred_np_h, gt_np_h = pred_np_full[:nt_for_rollout], gt_np_full[:nt_for_rollout]

        U_pred_denorm_h, U_gt_denorm_h = {}, {}
        for k_idx, key_val in enumerate(state_keys_val):
            # Undo the per-sample standardisation before measuring errors.
            mean_k, std_k = norm_factors_sample[f'{key_val}_mean'], norm_factors_sample[f'{key_val}_std']
            pred_denorm_v = pred_np_h[:, :, k_idx] * std_k + mean_k
            gt_denorm_v = gt_np_h[:, :, k_idx] * std_k + mean_k
            U_pred_denorm_h[key_val], U_gt_denorm_h[key_val] = pred_denorm_v, gt_denorm_v
            mse_k_h = np.mean((pred_denorm_v - gt_denorm_v)**2)
            rel_err_k_h = np.linalg.norm(pred_denorm_v - gt_denorm_v, 'fro') / (np.linalg.norm(gt_denorm_v, 'fro') + 1e-10)
            print(f"    Metrics '{key_val}' @ T={T_horizon_current:.2f}: MSE={mse_k_h:.3e}, RelErr={rel_err_k_h:.4f}")

        fig, axs = plt.subplots(num_state_vars_val, 3, figsize=(18, 5*num_state_vars_val), squeeze=False)
        try:
            for k_idx, key_val in enumerate(state_keys_val):
                gt_p, pred_p = U_gt_denorm_h[key_val], U_pred_denorm_h[key_val]
                diff_p = np.abs(pred_p - gt_p)
                # Shared colour range so truth and prediction are comparable.
                v_min, v_max = min(np.min(gt_p), np.min(pred_p)), max(np.max(gt_p), np.max(pred_p))
                plot_ext = [0, dataset_params_for_plot.get('L', 1.0), 0, T_horizon_current]
                im0 = axs[k_idx, 0].imshow(gt_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
                axs[k_idx, 0].set_title(f"Truth ({key_val})")
                im1 = axs[k_idx, 1].imshow(pred_p, aspect='auto', origin='lower', vmin=v_min, vmax=v_max, extent=plot_ext, cmap='viridis')
                axs[k_idx, 1].set_title(f"GNOT Pred ({key_val})")
                im2 = axs[k_idx, 2].imshow(diff_p, aspect='auto', origin='lower', extent=plot_ext, cmap='magma')
                axs[k_idx, 2].set_title(f"Abs Error (Max:{np.max(diff_p):.2e})")
                plt.colorbar(im0, ax=axs[k_idx, 0])
                plt.colorbar(im1, ax=axs[k_idx, 1])
                plt.colorbar(im2, ax=axs[k_idx, 2])
            fig.suptitle(f"GNOT Validation ({dataset_type.capitalize()}) @ T={T_horizon_current:.2f}")
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            curr_fig_path = save_fig_path_prefix+f"_T{str(T_horizon_current).replace('.','p')}.png"
            fig.savefig(curr_fig_path)
            print(f"  Saved GNOT validation plot to {curr_fig_path}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

# =============================================================================
# 3. Main Execution Block
# =============================================================================
if __name__ == '__main__':
    # Script entry point: parse the dataset type, load and split the data,
    # build the model, train it and validate it on one held-out sample.
    parser = argparse.ArgumentParser(description="Train a GNOT-inspired stepper model.")
    parser.add_argument('--datatype', type=str, required=True, choices=['advection', 'euler', 'burgers', 'darcy'], help='Type of dataset to train on.')
    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    MODEL_TYPE = 'GNOT_Stepper'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"--- Running {MODEL_TYPE} for {DATASET_TYPE} on {device} ---")

    # --- Model and Training Hyperparameters ---
    EMBED_DIM = 128
    N_LAYERS = 4
    NHEAD = 8
    N_EXPERTS = 4
    LEARNING_RATE = 5e-4
    BATCH_SIZE = 8
    NUM_EPOCHS = 150
    CLIP_GRAD_NORM = 1.0

    # --- Dataset and Time Horizon Parameters ---
    FULL_T_IN_DATAFILE = 2.0
    FULL_NT_IN_DATAFILE = 600  # For the original 4 datasets
    TRAIN_T_TARGET = 1.0
    # Number of time levels covering [0, TRAIN_T_TARGET] on the file's grid.
    TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    # --- Data Loading ---
    # dataset type -> (pickle path, state variable keys)
    DATASET_SPECS = {
        'advection': ("./datasets_full/advection_data_10000s_128nx_600nt.pkl", ['U']),
        'euler': ("./datasets_full/euler_data_10000s_128nx_600nt.pkl", ['rho', 'u']),
        'burgers': ("./datasets_full/burgers_data_10000s_128nx_600nt.pkl", ['U']),
        'darcy': ("./datasets_full/darcy_data_10000s_128nx_600nt.pkl", ['P']),
    }
    if DATASET_TYPE not in DATASET_SPECS:
        raise ValueError(f"Unknown dataset type: {DATASET_TYPE}")
    dataset_path, main_state_keys = DATASET_SPECS[DATASET_TYPE]
    main_num_state_vars = len(main_state_keys)

    dataset_params_for_plot = {'nx': 128, 'ny': 1, 'L': 1.0, 'T': FULL_T_IN_DATAFILE}

    print(f"Dataset: {DATASET_TYPE}")
    print(f"Datafile T={FULL_T_IN_DATAFILE}, nt={FULL_NT_IN_DATAFILE}")
    print(f"Training with T_duration={TRAIN_T_TARGET}, nt_steps={TRAIN_NT_FOR_MODEL}")

    print(f"Loading dataset: {dataset_path}")
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # dataset files from a trusted source.
    with open(dataset_path, 'rb') as f:
        data_list_all = pickle.load(f)
    random.shuffle(data_list_all)
    n_train = int(0.8 * len(data_list_all))
    train_data_list, val_data_list = data_list_all[:n_train], data_list_all[n_train:]

    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    if train_dataset.nt_from_sample_file != FULL_NT_IN_DATAFILE:
        print(f"Warning: data file holds nt={train_dataset.nt_from_sample_file} time steps "
              f"but FULL_NT_IN_DATAFILE={FULL_NT_IN_DATAFILE}; time horizons may be inconsistent.")

    # The dataset constructor rejects an empty list, so only build the
    # validation pipeline when the split actually produced samples.
    val_loader = None
    if val_data_list:
        val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    # --- Model Initialization ---
    actual_bc_ctrl_dim = train_dataset.bc_state_dim + train_dataset.num_controls
    current_nx = train_dataset.nx

    print(f"\nInitializing {MODEL_TYPE} model...")
    gnot_model = GNOTStepper(
        nx=current_nx, num_state_vars=main_num_state_vars, bc_ctrl_dim=actual_bc_ctrl_dim,
        state_keys=main_state_keys, embed_dim=EMBED_DIM, n_layers=N_LAYERS,
        n_head=NHEAD, n_experts=N_EXPERTS
    )

    # --- Training and Validation ---
    run_name = f"{DATASET_TYPE}_{MODEL_TYPE}_trainT{TRAIN_T_TARGET}"
    checkpoint_dir = f"./checkpoints_{DATASET_TYPE}_{MODEL_TYPE}"
    results_dir = f"./results_{DATASET_TYPE}_{MODEL_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'model_{run_name}.pt')
    save_fig_path_prefix = os.path.join(results_dir, f'result_{run_name}')

    print(f"\nStarting training for {MODEL_TYPE} on {DATASET_TYPE}...")
    start_train_time = time.time()
    gnot_model = train_gnot_stepper(
        gnot_model, train_loader, dataset_type=DATASET_TYPE, train_nt_for_model=TRAIN_NT_FOR_MODEL,
        lr=LEARNING_RATE, num_epochs=NUM_EPOCHS, device=device,
        checkpoint_path=checkpoint_path, clip_grad_norm=CLIP_GRAD_NORM
    )
    print(f"Training took {time.time() - start_train_time:.2f} seconds.")

    if val_loader is not None:
        print(f"\nStarting validation for {MODEL_TYPE} on {DATASET_TYPE}...")
        validate_gnot_stepper(
            gnot_model, val_loader, dataset_type=DATASET_TYPE, T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE, full_nt_in_datafile=FULL_NT_IN_DATAFILE,
            dataset_params_for_plot=dataset_params_for_plot, device=device,
            save_fig_path_prefix=save_fig_path_prefix
        )

    print("="*60 + f"\nRun finished: {run_name}" + "\n" + "="*60)