import argparse
import copy
import csv
import json
import logging
import os
import pickle
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, List, Dict # Added for 3.9 compatibility

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from tabulate import tabulate # For potentially printing summaries
from torch.utils.data import Dataset, DataLoader

# --- Script Configuration ---
# Format applied by setup_logger() to every handler it attaches: timestamp, level,
# process/thread names, then the message.
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(processName)s/%(threadName)s] - %(message)s'
# Small constant added to denominators (loss-weight ratios, per-task averages)
# so that a zero denominator cannot cause a division by zero.
EPSILON = 1e-8 # For numerical stability in loss weighting

# --- Logger Setup ---
# Module-level logger shared by every function in this file; it has no handlers
# of its own until setup_logger() is called.
logger = logging.getLogger(__name__)

def setup_logger(log_level_str: str = "INFO", log_file: Optional[str] = None):
    """Configures the global logger for the script.

    Handlers already attached to the module logger are removed *and closed*
    first, so calling this function repeatedly neither duplicates output nor
    leaks the file opened by a previous FileHandler.

    Args:
        log_level_str (str, optional): Logging level name (e.g., "INFO", "DEBUG"), case-insensitive.
            Names that do not resolve to a logging level fall back to INFO and a warning is logged.
            Defaults to "INFO".
        log_file (str, optional): Path to a log file. If provided, logs are also written here
            (UTF-8 encoded; missing parent directories are created). Defaults to None.
    """
    # getattr() may also resolve to non-level attributes of the logging module
    # (e.g. BASIC_FORMAT is a str), so only accept integer results as levels.
    resolved_level = getattr(logging, str(log_level_str).upper(), None)
    level_is_valid = isinstance(resolved_level, int)
    log_level = resolved_level if level_is_valid else logging.INFO
    logger.setLevel(log_level)

    # Iterate over a copy because removeHandler() mutates logger.handlers.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()  # Releases the underlying file of FileHandlers.

    formatter = logging.Formatter(DEFAULT_LOG_FORMAT)

    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(log_level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if log_file:
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        fh = logging.FileHandler(log_path, encoding="utf-8")
        fh.setLevel(log_level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    # Reported only now so the message goes through the freshly installed handlers.
    if not level_is_valid:
        logger.warning("Unknown log level %r; falling back to INFO.", log_level_str)

def format_time_seconds(seconds: float) -> str:
    """Formats a duration in seconds into H:M:S string.

    Fractional seconds are truncated. The hour field is zero-padded to two
    digits but grows as needed (e.g. 360000 seconds -> "100:00:00").

    Args:
        seconds (float): Duration in seconds.

    Returns:
        str: Formatted time string (H:M:S), or "--:--:--" when the input is
            negative, NaN or infinite (i.e. not a displayable duration).
    """
    # A single chained comparison rejects negatives, +inf (fails the upper
    # bound) and NaN (fails every comparison), none of which int() can format.
    if not (0 <= seconds < float("inf")):
        return "--:--:--"

    total_minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

# --- Data Handling ---
class EHRShapleyDataset(Dataset):
    """PyTorch Dataset for EHR features and corresponding Shapley values.

    Each item is a ``(features, shapley_values)`` pair where ``shapley_values``
    holds one value per task, ordered by sorted task name (``self.task_names``).
    NaNs in features and Shapley values are converted to 0.0 via
    ``np.nan_to_num`` (infinities are replaced according to that function's
    defaults). Both tensors are stored as float32.
    """
    def __init__(self, features: np.ndarray, shapley_values_dict: dict):
        """Initializes EHRShapleyDataset.

        Args:
            features (np.ndarray): Raw feature data, shape (num_samples, num_features).
            shapley_values_dict (dict): Dictionary where keys are task names (str)
                                        and values are 1-D array-likes of Shapley values
                                        (shape: num_samples) for that task. May be empty,
                                        in which case items carry a zero-length Shapley vector.

        Raises:
            ValueError: If a task's values are not 1-D, if tasks disagree on the
                number of samples, or if that number differs from the number
                of feature rows.
        """
        processed_features = np.nan_to_num(features, nan=0.0)
        self.features = torch.tensor(processed_features, dtype=torch.float32)
        num_feature_samples = self.features.shape[0]

        # Sorting fixes the column order of self.shapley_values independently of dict order.
        self.task_names = sorted(shapley_values_dict)

        if not shapley_values_dict:  # No tasks: keep the sample dimension, zero task columns.
            logger.info("Shapley values dictionary is empty. Dataset will only contain features.")
            self.shapley_values = torch.empty((num_feature_samples, 0), dtype=torch.float32)
            self.num_samples = num_feature_samples
            if num_feature_samples == 0:
                logger.warning("Both features and Shapley values are effectively empty for the dataset.")
            return

        # Normalise every task to an ndarray up front so that lists are accepted
        # and malformed shapes fail with a ValueError instead of producing a
        # mis-shaped tensor further down.
        task_arrays = []
        for task_name in self.task_names:
            task_values = np.asarray(shapley_values_dict[task_name])
            if task_values.ndim != 1:
                raise ValueError(
                    f"Shapley values for task '{task_name}' must be a 1-D array of shape "
                    f"(num_samples,), got shape {task_values.shape}."
                )
            task_arrays.append(task_values)

        num_samples_in_shapley_sets = {arr.shape[0] for arr in task_arrays}
        if len(num_samples_in_shapley_sets) > 1:
            raise ValueError(f"Inconsistent number of samples across different Shapley value tasks: {num_samples_in_shapley_sets}")

        num_shapley_samples = num_samples_in_shapley_sets.pop()
        if num_feature_samples != num_shapley_samples:
            raise ValueError(
                f"Mismatch in number of samples: Features have {num_feature_samples}, "
                f"Shapley values have {num_shapley_samples}."
            )
        self.num_samples = num_feature_samples

        # Stack tasks as columns -> shape (num_samples, num_tasks).
        stacked_values = np.stack([np.nan_to_num(arr, nan=0.0) for arr in task_arrays], axis=1)
        self.shapley_values = torch.tensor(stacked_values, dtype=torch.float32)

    def __len__(self) -> int:
        """Returns the number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns the feature row and the per-task Shapley values of sample ``idx``."""
        return self.features[idx], self.shapley_values[idx]

def _load_feature_matrix(feature_file_path: str) -> Optional[np.ndarray]:
    """Loads the .npy feature matrix; logs and returns None if it is missing, unreadable or not 2-D with >0 columns."""
    feature_file = Path(feature_file_path)
    if not feature_file.exists():
        logger.error(f"Feature file not found: {feature_file}")
        return None

    try:
        features_data_np = np.load(feature_file)
    except Exception as e:
        logger.error(f"Failed to load feature file {feature_file}: {e}", exc_info=True)
        return None

    # np.load can return non-array objects (e.g. for .npz archives), and arrays
    # of any rank; only a 2-D array with at least one column is usable here.
    is_valid_matrix = (
        isinstance(features_data_np, np.ndarray)
        and features_data_np.ndim == 2
        and features_data_np.shape[1] > 0
    )
    if not is_valid_matrix:
        found_shape = getattr(features_data_np, "shape", None)
        logger.error(f"Feature data at {feature_file_path} has invalid shape: {found_shape}. Expected 2D array with >0 features.")
        return None

    if features_data_np.shape[0] == 0:
        logger.warning(f"Feature file {feature_file_path} contains 0 samples but has {features_data_np.shape[1]} features. Proceeding with empty dataset.")
    return features_data_np


def _load_shapley_dict(shapley_pkl_path: Optional[str]) -> dict:
    """Loads the pickled task->Shapley-values dict; returns {} (after logging) on any problem."""
    if not shapley_pkl_path:
        logger.info("No Shapley PKL file provided. Proceeding without Shapley values.")
        return {}

    logger.info(f"Loading Shapley values from: {shapley_pkl_path}")
    shapley_file = Path(shapley_pkl_path)
    if not shapley_file.exists():
        logger.warning(f"Shapley PKL file not found: {shapley_file}. Proceeding without Shapley values.")
        return {}

    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point this at PKL files produced by a trusted pipeline.
        with open(shapley_file, 'rb') as f:
            loaded_shapley_data = pickle.load(f)
    except Exception as e:
        logger.error(f"Failed to load or parse Shapley PKL file {shapley_pkl_path}: {e}", exc_info=True)
        return {}  # Proceed with empty Shapley on error

    if not isinstance(loaded_shapley_data, dict):
        logger.warning(f"Shapley data in {shapley_pkl_path} is not a dictionary. Treating as empty.")
        return {}
    if not loaded_shapley_data:
        logger.info(f"Shapley PKL file {shapley_pkl_path} loaded successfully but is empty.")
    return loaded_shapley_data


def _split_indices_80_10_10(num_samples: int, split_random_seed: int, dataset_identifier: str) -> Tuple[list, list, list]:
    """Splits range(num_samples) into train/val/test index lists (80/10/10); falls back to all-train if splitting fails."""
    all_indices = list(range(num_samples))
    try:
        train_indices, temp_indices = train_test_split(
            all_indices,
            test_size=0.2,  # 20% held out for validation + test
            random_state=split_random_seed,
            stratify=None
        )
        val_indices, test_indices = train_test_split(
            temp_indices,
            test_size=0.5,  # Half of the held-out 20% -> 10% of the total each
            random_state=split_random_seed,
            stratify=None
        )
    except ValueError as e_split:  # e.g. too few samples to form non-empty splits
        logger.error(f"Error during data splitting for {dataset_identifier}: {e_split}. This might happen with very small datasets.")
        logger.warning("Using all data for training due to splitting error. Validation and test sets will be empty.")
        return all_indices, [], []
    return train_indices, val_indices, test_indices


def load_and_split_data(
    feature_file_path: str, 
    shapley_pkl_path: Optional[str], 
    dataset_identifier: str,
    batch_size_value: int, 
    split_random_seed: int, 
    output_base_dir: str,
    num_workers_dl: int = 0,
    pin_memory_dl: bool = False
) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader], List[str], int, Optional[Dict]]:
    """Loads feature and Shapley value data, splits it, and creates DataLoaders.

    Splits data into train/validation/test sets (80/10/10) and saves the split
    indices to ``<output_base_dir>/<dataset_identifier>_split_indices.json``.
    Failures are logged rather than raised; see Returns for how they are reported.

    Args:
        feature_file_path (str): Path to the .npy feature file.
        shapley_pkl_path (Optional[str]): Path to the .pkl Shapley values file (optional).
            The file is read with pickle, so it must come from a trusted source.
        dataset_identifier (str): Name for the dataset (e.g., "AKI", "MIMIC"), used in output naming.
        batch_size_value (int): Batch size for DataLoaders.
        split_random_seed (int): Random seed for reproducible data splitting.
        output_base_dir (str): Base directory for saving outputs like split indices.
        num_workers_dl (int, optional): Number of worker processes for DataLoader. Defaults to 0.
        pin_memory_dl (bool, optional): If True, DataLoader will copy Tensors into CUDA pinned memory. Defaults to False.

    Returns:
        tuple:
            - train_loader (Optional[DataLoader]): DataLoader for the training set (None on error).
            - val_loader (Optional[DataLoader]): DataLoader for the validation set (None on error).
            - test_loader (Optional[DataLoader]): DataLoader for the test set (None on error).
            - task_names_list (List[str]): Sorted list of task names from Shapley values.
            - inferred_input_dim (int): Feature dimension from data (-1 if the feature file is unusable).
            - data_split_indices (Optional[Dict]): Train, validation, and test indices
              (None on error or when the dataset has zero samples).
    """
    logger.info(f"Loading features from: {feature_file_path}")
    features_data_np = _load_feature_matrix(feature_file_path)
    if features_data_np is None:
        return None, None, None, [], -1, None

    inferred_input_dim = features_data_np.shape[1]
    logger.info(f"Inferred input feature dimension: {inferred_input_dim}")

    aggregated_shapley_data = _load_shapley_dict(shapley_pkl_path)

    try:
        full_dataset = EHRShapleyDataset(features_data_np, aggregated_shapley_data)
    except (ValueError, TypeError, AttributeError, IndexError) as e:
        # Besides explicit shape mismatches, a malformed PKL (values that are not
        # arrays) surfaces as one of the other exception types.
        logger.error(f"Error creating EHRShapleyDataset for {dataset_identifier}: {e}", exc_info=True)
        return None, None, None, [], inferred_input_dim, None  # The input dim is still known

    task_names_list = full_dataset.task_names
    logger.info(f"Dataset '{dataset_identifier}': {len(full_dataset)} samples, {len(task_names_list)} tasks: {task_names_list}")

    if len(full_dataset) == 0:
        logger.warning(f"Dataset '{dataset_identifier}' is empty. Returning empty DataLoaders.")
        # Hand back real (empty) DataLoaders so downstream code need not special-case None.
        empty_dl = DataLoader(full_dataset, batch_size=batch_size_value)
        return empty_dl, empty_dl, empty_dl, task_names_list, inferred_input_dim, None

    train_indices, val_indices, test_indices = _split_indices_80_10_10(
        len(full_dataset), split_random_seed, dataset_identifier
    )

    train_subset = torch.utils.data.Subset(full_dataset, train_indices)
    val_subset = torch.utils.data.Subset(full_dataset, val_indices)
    test_subset = torch.utils.data.Subset(full_dataset, test_indices)

    logger.info(f"Data split for '{dataset_identifier}': Train={len(train_subset)}, Validation={len(val_subset)}, Test={len(test_subset)}")

    # Persist the split so later runs/evaluations can reuse exactly the same partition.
    data_split_indices = {
        'train_idx': [int(i) for i in train_indices],
        'val_idx': [int(i) for i in val_indices],
        'test_idx': [int(i) for i in test_indices]
    }
    split_indices_filename = Path(output_base_dir) / f"{dataset_identifier}_split_indices.json"
    try:
        split_indices_filename.parent.mkdir(parents=True, exist_ok=True)
        with open(split_indices_filename, 'w') as f:
            json.dump(data_split_indices, f, indent=4)
        logger.info(f"Data split indices saved to: {split_indices_filename}")
    except OSError as e:
        # Best effort: a failed save must not prevent training.
        logger.error(f"Failed to save data split indices to {split_indices_filename}: {e}")

    loader_kwargs = dict(batch_size=batch_size_value, num_workers=num_workers_dl, pin_memory=pin_memory_dl)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_subset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader, task_names_list, inferred_input_dim, data_split_indices

# --- Model Definitions ---

class ShapleyPredictorG(nn.Module):
    """Task-specific first-stage network g^(t)(x_i, θ^(t)).

    A three-hidden-layer MLP that returns both the predicted Shapley value for
    task t and the activation of its last hidden layer, h^(t)(x_i).
    """
    def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int = 1, dropout_p: float = 0.1):
        """Builds the MLP layers.

        Args:
            input_dim (int): Size of the input feature vector x_i.
            hidden_dims (list[int]): Widths of the three hidden layers; the last
                                     entry is the size of h^(t)(x_i).
            output_dim (int, optional): Size of the prediction head output. Defaults to 1.
            dropout_p (float, optional): Dropout probability applied after the first
                                         two hidden layers. Defaults to 0.1.

        Raises:
            ValueError: If hidden_dims is not a list of exactly three positive ints.
        """
        super().__init__()
        dims_are_valid = (
            isinstance(hidden_dims, list)
            and len(hidden_dims) == 3
            and all(isinstance(width, int) and width > 0 for width in hidden_dims)
        )
        if not dims_are_valid:
            raise ValueError("hidden_dims must be a list of three positive integers.")

        width_1, width_2, width_h = hidden_dims

        # Block 1: input -> first hidden layer
        self.fc1 = nn.Linear(input_dim, width_1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(p=dropout_p)
        # Block 2: first -> second hidden layer
        self.fc2 = nn.Linear(width_1, width_2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(p=dropout_p)
        # Block 3: second hidden layer -> representation h^(t)(x_i) (no dropout)
        self.fc3 = nn.Linear(width_2, width_h)
        self.relu3 = nn.ReLU()
        # Prediction head on top of the representation
        self.fc_out = nn.Linear(width_h, output_dim)

        self.h_dim = width_h  # Exposed so other modules can size themselves to h

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns ``(shapley_prediction, h_representation)`` for a batch of inputs."""
        hidden = self.dropout1(self.relu1(self.fc1(x)))
        hidden = self.dropout2(self.relu2(self.fc2(hidden)))
        h_representation = self.relu3(self.fc3(hidden))  # h^(t)(x_i)
        return self.fc_out(h_representation), h_representation

class AttentionSubnetwork(nn.Module):
    """Attention subnetwork to compute attention weights α^(t)(x_i).

    Combines representations o(x_i) from Ψ and h^(t)(x_i) from g^(t).
    Refers to equations in the associated academic paper for its transformations.
    """
    def __init__(self, psi_o_dim: int, g_h_dim: int, attention_r_internal_dim: int, num_tasks: int):
        """Initializes AttentionSubnetwork.

        Args:
            psi_o_dim (int): Dimensionality of o(x_i) from the Ψ model.
            g_h_dim (int): Dimensionality of h^(t)(x_i) from g^(t) models.
            attention_r_internal_dim (int): Dimensionality of intermediate representation r^(t)(x_i).
            num_tasks (int): Total number of tasks (T).
        """
        super().__init__()
        self.num_tasks = num_tasks

        # Affine transformation for r^(t)(x_i) based on [o(x_i); h^(t)(x_i)] (Eq. 7 in paper)
        self.affine_r = nn.Linear(psi_o_dim + g_h_dim, attention_r_internal_dim)
        self.relu_r = nn.ReLU()

        # Affine transformation for unnormalized attention score ~α^(t)(x_i) (Eq. 8 in paper)
        self.affine_alpha_tilde = nn.Linear(attention_r_internal_dim, 1)

    def forward(self, o_x: torch.Tensor, h_x_all_tasks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes per-task attention weights for a batch.

        The same two affine layers are shared across all tasks. Empty batches
        (batch_size == 0) are supported and yield correspondingly empty outputs.

        Args:
            o_x: Representation o(x_i) from Ψ model, shape (batch_size, psi_o_dim).
            h_x_all_tasks: Stacked representations h^(t)(x_i) from all g^(t) models,
                           shape (batch_size, num_tasks, g_h_dim).
        Returns:
            alpha_weights (torch.Tensor): Computed attention weights α^(t)(x_i) after softmax
                                          over the task dimension, shape (batch_size, num_tasks).
            r_t_representations (torch.Tensor): Intermediate representations r^(t)(x_i),
                                                shape (batch_size, num_tasks, attention_r_internal_dim).
        """
        # Broadcast o(x_i) across tasks: (B, D_o) -> (B, 1, D_o) -> (B, T, D_o)
        o_x_expanded = o_x.unsqueeze(1).expand(-1, self.num_tasks, -1)

        # Pair o(x_i) with each h^(t)(x_i): (B, T, D_o + D_h)
        concat_input = torch.cat([o_x_expanded, h_x_all_tasks], dim=2)

        # nn.Linear operates on the last dimension and keeps the leading ones, so
        # the (B, T, ...) layout is preserved without flattening. This also avoids
        # view(..., -1), whose inferred size is ambiguous for zero-element tensors.
        r_t_representations = self.relu_r(self.affine_r(concat_input))  # (B, T, D_r)

        # Unnormalized scores ~α^(t)(x_i): (B, T, 1) -> (B, T)
        alpha_tilde_scores = self.affine_alpha_tilde(r_t_representations).squeeze(-1)

        # Normalize across tasks so the weights of each sample sum to 1.
        alpha_weights = F.softmax(alpha_tilde_scores, dim=1)

        return alpha_weights, r_t_representations

class FidelityPredictorPsi(nn.Module):
    """Second-stage network Ψ(x_i, θ) for overall EHR data fidelity.

    A three-hidden-layer MLP that returns both the fidelity prediction and the
    activation of its last hidden layer, o(x_i).
    """
    def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int = 1, dropout_p: float = 0.1):
        """Builds the MLP layers.

        Args:
            input_dim (int): Size of the input feature vector x_i.
            hidden_dims (list[int]): Widths of the three hidden layers; the last
                                     entry is the size of o(x_i).
            output_dim (int, optional): Size of the prediction head output. Defaults to 1.
            dropout_p (float, optional): Dropout probability applied after the first
                                         two hidden layers. Defaults to 0.1.

        Raises:
            ValueError: If hidden_dims is not a list of exactly three positive ints.
        """
        super().__init__()
        dims_are_valid = (
            isinstance(hidden_dims, list)
            and len(hidden_dims) == 3
            and all(isinstance(width, int) and width > 0 for width in hidden_dims)
        )
        if not dims_are_valid:
            raise ValueError("hidden_dims must be a list of three positive integers.")

        width_1, width_2, width_o = hidden_dims

        # Block 1: input -> first hidden layer
        self.fc1 = nn.Linear(input_dim, width_1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(p=dropout_p)
        # Block 2: first -> second hidden layer
        self.fc2 = nn.Linear(width_1, width_2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(p=dropout_p)
        # Block 3: second hidden layer -> representation o(x_i) (no dropout)
        self.fc3 = nn.Linear(width_2, width_o)
        self.relu3 = nn.ReLU()
        # Prediction head on top of the representation
        self.fc_out = nn.Linear(width_o, output_dim)

        self.o_dim = width_o  # Exposed so other modules can size themselves to o

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns ``(fidelity_prediction, o_representation)`` for a batch of inputs."""
        hidden = self.dropout1(self.relu1(self.fc1(x)))
        hidden = self.dropout2(self.relu2(self.fc2(hidden)))
        o_representation = self.relu3(self.fc3(hidden))  # o(x_i)
        return self.fc_out(o_representation), o_representation

# --- Training Logic ---
def _stage1_loader_has_samples(loader: Optional[DataLoader]) -> bool:
    """Returns True if the loader is set, yields at least one batch and wraps a non-empty dataset."""
    return bool(loader) and len(loader.dataset) > 0


def _stage1_write_csv_row(csv_path: Path, row: list, mode: str = 'a') -> None:
    """Writes one row to the epoch-summary CSV (proper quoting, '\\n' line endings)."""
    with open(csv_path, mode, newline='') as f_csv:
        csv.writer(f_csv, lineterminator="\n").writerow(row)


def _stage1_mean_task_losses(
    g_models_list: nn.ModuleList,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Optional[torch.Tensor]:
    """Returns the per-task mean loss over data_loader (no gradients), or None if no sample was seen."""
    task_loss_sums = torch.zeros(len(g_models_list), device=device)
    num_samples_seen = 0
    with torch.no_grad():
        for features, target_shapley_values in data_loader:
            if features.shape[0] == 0:
                continue
            features = features.to(device)
            target_shapley_values = target_shapley_values.to(device)
            for task_idx, model_g in enumerate(g_models_list):
                pred_shapley_task, _ = model_g(features)
                task_loss_sums[task_idx] += criterion(
                    pred_shapley_task.squeeze(-1), target_shapley_values[:, task_idx]
                ).sum()
            num_samples_seen += features.size(0)
    if num_samples_seen == 0:
        return None
    return task_loss_sums / num_samples_seen


def train_stage1_g_models(
    g_models_list: nn.ModuleList,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer_g: optim.Optimizer,
    num_epochs: int,
    device: torch.device,
    task_names: list[str],
    early_stopping_patience: int,
    experiment_output_dir: str # For saving epoch-wise details if needed
) -> tuple[nn.ModuleList, dict]:
    """Trains task-specific Shapley value predictor models (g^(t)).

    All models are optimised jointly on a weighted sum of per-task MSE losses.
    From the second epoch on, each task's weight is its current batch loss
    divided by that task's average training loss in the previous epoch; the
    first epoch uses uniform weights. When validation data is available, the
    state with the lowest summed per-task validation loss is restored at the
    end, and training stops early after ``early_stopping_patience`` consecutive
    epochs without improvement. Per-epoch losses are appended to
    ``<experiment_output_dir>/stage1_g_epoch_summary.csv`` (best effort).

    Args:
        g_models_list (nn.ModuleList): List of g^(t) models to train.
        train_loader (DataLoader): DataLoader for training data; batches are
            ``(features, targets)`` with targets of shape (batch_size, num_tasks).
        val_loader (DataLoader): DataLoader for validation data (may be None or empty).
        optimizer_g (optim.Optimizer): Optimizer for g^(t) models.
        num_epochs (int): Maximum training epochs.
        device (torch.device): Device for training (CPU/CUDA).
        task_names (list[str]): Names for tasks, corresponding to g^(t) models.
        early_stopping_patience (int): Epochs to wait for validation improvement before stopping.
        experiment_output_dir (str): Directory for saving artifacts like epoch summaries.

    Returns:
        tuple:
            - trained_g_models_list (nn.ModuleList): Trained g^(t) models (same object as the input).
            - training_summary (dict): ``status`` ("completed", "skipped" when there are no
              tasks, or "skipped_no_data" when no epoch could be trained), ``epochs_run``,
              ``early_stopped``, ``best_aggregate_validation_loss`` and
              ``final_average_train_losses_per_task``.
    """
    num_tasks = len(g_models_list)
    if num_tasks == 0:
        logger.info("Stage 1 Training: No tasks to train (g_models_list is empty). Skipping.")
        return g_models_list, {"status": "skipped", "reason": "No tasks"}

    # The loaders do not change during training, so emptiness is checked once.
    has_train_data = _stage1_loader_has_samples(train_loader)
    has_val_data = _stage1_loader_has_samples(val_loader)

    criterion = nn.MSELoss(reduction='none') # Per-sample loss so epoch averages can be accumulated exactly
    # Previous-epoch average task losses for the dynamic weights (ω^(t) in Eq. 5 of the paper);
    # the initial value is a placeholder that is never used because epoch 0 has uniform weights.
    prev_epoch_avg_task_losses_g = torch.ones(num_tasks, device=device)

    best_val_loss_g_aggregate = float('inf')
    epochs_without_improvement_g = 0
    best_g_model_state_dicts = None # List of state dicts of the best validation epoch, once there is one
    actual_epochs_run_g = 0

    # Prepare the epoch summary CSV (header row); CSV logging is disabled if this fails.
    epoch_summary_g_path = Path(experiment_output_dir) / "stage1_g_epoch_summary.csv"
    header_cols = ["Epoch"] + [f"TrainLoss_{tn}" for tn in task_names] + \
                  [f"ValLoss_{tn}" for tn in task_names] + ["ValLoss_Agg", "Status"]
    try:
        epoch_summary_g_path.parent.mkdir(parents=True, exist_ok=True)
        _stage1_write_csv_row(epoch_summary_g_path, header_cols, mode='w')
    except OSError as e:
        logger.error(f"Failed to create Stage 1 epoch summary file at {epoch_summary_g_path}: {e}")
        epoch_summary_g_path = None

    logger.info(f"--- Stage 1: Training g^(t) Models (Max Epochs: {num_epochs}, Patience: {early_stopping_patience}) ---")

    if not has_train_data:
        # Nothing can be learned; report that honestly instead of looping through
        # empty epochs and returning placeholder losses.
        logger.error("Stage 1: Training data is empty. Models are not trained.")
        return g_models_list, {
            "status": "skipped_no_data",
            "epochs_run": 0,
            "early_stopped": False,
            "best_aggregate_validation_loss": None,
            "final_average_train_losses_per_task": None
        }

    for epoch in range(num_epochs):
        actual_epochs_run_g = epoch + 1
        for model_g in g_models_list:
            model_g.train()

        # --- Training Phase ---
        train_task_loss_sums = torch.zeros(num_tasks, device=device)
        num_train_samples = 0

        for features, target_shapley_values in train_loader:
            if features.shape[0] == 0:
                continue
            features = features.to(device)
            target_shapley_values = target_shapley_values.to(device) # Shape: (batch_size, num_tasks)

            optimizer_g.zero_grad()

            batch_task_losses_mse = []
            for task_idx, model_g in enumerate(g_models_list):
                pred_shapley_task, _ = model_g(features) # Shape: (batch_size, 1)
                loss_task_per_sample = criterion(pred_shapley_task.squeeze(-1), target_shapley_values[:, task_idx])

                # detach(): the running sum is bookkeeping only and must not keep
                # autograd history alive across batches.
                train_task_loss_sums[task_idx] += loss_task_per_sample.detach().sum()
                batch_task_losses_mse.append(loss_task_per_sample.mean())
            num_train_samples += features.size(0)

            current_batch_task_losses = torch.stack(batch_task_losses_mse)

            if epoch == 0:
                # No previous epoch to compare against yet: uniform weights.
                omega_weights_g = torch.ones_like(current_batch_task_losses)
            else:
                # NOTE(review): ω is not detached, so gradients also flow through the
                # weight itself; for each task this equals 2·ω·∇L instead of ω·∇L.
                # Kept as in the original implementation - confirm against Eq. 5.
                omega_weights_g = current_batch_task_losses / (prev_epoch_avg_task_losses_g + EPSILON)

            total_loss_g_batch = (omega_weights_g * current_batch_task_losses).sum()
            total_loss_g_batch.backward()
            optimizer_g.step()

        if num_train_samples > 0:
            avg_epoch_train_task_losses = train_task_loss_sums / num_train_samples
        else:
            avg_epoch_train_task_losses = train_task_loss_sums # All zeros

        # --- Validation Phase ---
        current_epoch_val_task_losses_list = [0.0] * num_tasks # Logged as zeros when validation is skipped
        aggregate_val_loss_this_epoch = 0.0

        if has_val_data:
            for model_g in g_models_list:
                model_g.eval()
            avg_epoch_val_task_losses = _stage1_mean_task_losses(g_models_list, val_loader, criterion, device)

            if avg_epoch_val_task_losses is None:
                epoch_status_message = "Val Skip (No Data)"
            else:
                current_epoch_val_task_losses_list = avg_epoch_val_task_losses.tolist()
                aggregate_val_loss_this_epoch = avg_epoch_val_task_losses.sum().item() # Sum of per-task averages

                if aggregate_val_loss_this_epoch < best_val_loss_g_aggregate:
                    best_val_loss_g_aggregate = aggregate_val_loss_this_epoch
                    epochs_without_improvement_g = 0
                    # Deep copies: state_dict() returns references to the live parameters.
                    best_g_model_state_dicts = [copy.deepcopy(model_g.state_dict()) for model_g in g_models_list]
                    epoch_status_message = "New Best"
                else:
                    epochs_without_improvement_g += 1
                    epoch_status_message = f"No Improve ({epochs_without_improvement_g}/{early_stopping_patience})"
        else:
            # Without validation there is no early stopping and no "best" epoch;
            # the models simply keep the weights of the last trained epoch.
            epoch_status_message = "Val Skip (No Loader/Empty)"

        # --- Epoch Summary (log + CSV) ---
        train_loss_strs = [f"{l:.6f}" for l in avg_epoch_train_task_losses.tolist()]
        val_loss_strs = [f"{l:.6f}" for l in current_epoch_val_task_losses_list]
        log_msg_epoch = (f"Epoch {actual_epochs_run_g}/{num_epochs} [Stage 1 G] | "
                         f"Train Losses: [{', '.join(train_loss_strs)}] | "
                         f"Val Losses: [{', '.join(val_loss_strs)}] | Val Agg: {aggregate_val_loss_this_epoch:.6f} | Status: {epoch_status_message}")
        logger.info(log_msg_epoch)
        if epoch_summary_g_path:
            row_data = [str(actual_epochs_run_g)] + train_loss_strs + val_loss_strs + \
                       [f"{aggregate_val_loss_this_epoch:.6f}", epoch_status_message]
            try:
                _stage1_write_csv_row(epoch_summary_g_path, row_data)
            except OSError as e:
                logger.warning(f"Failed to write to Stage 1 epoch summary file: {e}")

        # Reference losses for the next epoch's dynamic weights.
        if num_train_samples > 0:
            prev_epoch_avg_task_losses_g = avg_epoch_train_task_losses

        if has_val_data and epochs_without_improvement_g >= early_stopping_patience:
            logger.info(f"Stage 1 Training: Early stopping triggered after {actual_epochs_run_g} epochs due to no improvement in validation loss for {early_stopping_patience} consecutive epochs.")
            break

    # Restore the best validation checkpoint, if validation ever produced one.
    if best_g_model_state_dicts is not None:
        for model_g, best_state in zip(g_models_list, best_g_model_state_dicts):
            model_g.load_state_dict(best_state)
        logger.info("Stage 1 Training: Loaded best performing g^(t) model states based on validation performance.")
    elif not has_val_data and actual_epochs_run_g > 0:
        logger.info("Stage 1 Training: No validation data was available. Using models from the last trained epoch.")
    else:
        logger.warning("Stage 1 Training: No best models to load (no validation or no training occurred).")

    final_training_summary = {
        "status": "completed" if actual_epochs_run_g > 0 else "skipped_no_data",
        "epochs_run": actual_epochs_run_g,
        "early_stopped": has_val_data and epochs_without_improvement_g >= early_stopping_patience,
        "best_aggregate_validation_loss": best_val_loss_g_aggregate if best_val_loss_g_aggregate != float('inf') else None,
        "final_average_train_losses_per_task": {name: loss.item() for name, loss in zip(task_names, prev_epoch_avg_task_losses_g)} if actual_epochs_run_g > 0 else None
    }
    logger.info(f"--- Stage 1: Training g^(t) Models Finished (Actual Epochs: {actual_epochs_run_g}) ---")
    return g_models_list, final_training_summary

def _calculate_psi_loss_components(
    features: torch.Tensor,
    psi_model: FidelityPredictorPsi,
    attention_model: AttentionSubnetwork,
    g_models_list_trained: nn.ModuleList,
    num_tasks: int,
    temperature_tau: float,
    device: torch.device,
    epsilon: float = EPSILON # Use global EPSILON by default
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Computes the three loss components used to train/evaluate the Psi model.

    The g^(t) models are run under ``torch.no_grad()``, so no gradient reaches
    them from any of the returned losses. Switching them to eval mode is left
    to the caller.

    Args:
        features (torch.Tensor): Input features for one batch.
        psi_model (FidelityPredictorPsi): The Psi model; called as
            ``psi_model(features)`` and expected to return a
            ``(prediction, representation)`` pair.
        attention_model (AttentionSubnetwork): The attention subnetwork; called
            with the Psi representation and the stacked g representations, its
            first return value is used as the attention weights.
        g_models_list_trained (nn.ModuleList): Pre-trained g models, indexed by task.
        num_tasks (int): Number of tasks (number of g models that are queried).
        temperature_tau (float): Temperature dividing the squared distances in L_sim.
        device (torch.device): Device on which helper tensors are created.
        epsilon (float): Small value added inside the log for numerical stability.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            - loss_kd: MSE between the Psi prediction and the (detached)
              attention-weighted sum of g predictions.
            - loss_ent_neg: Batch mean of sum(alpha * log(alpha + epsilon)),
              shifted by log(num_tasks).
            - loss_sim: Batch mean of sum over task pairs t < t' of
              alpha_t * alpha_t' * exp(-||g_t - g_t'||^2 / tau); 0 when
              num_tasks <= 1.
            - psi_fidelity_pred: Predictions from the Psi model.
            - alpha_attention_weights: Attention weights.
    """
    # Collect predictions and hidden representations from the g^(t) models.
    # no_grad keeps these tensors out of the autograd graph regardless of
    # whether the caller froze the parameters.
    g_outputs_all_tasks_list = []
    h_outputs_all_tasks_list = []
    with torch.no_grad():
        for task_idx in range(num_tasks):
            g_pred_task, h_repr_task = g_models_list_trained[task_idx](features)
            g_outputs_all_tasks_list.append(g_pred_task)
            # Insert a task axis so representations can be stacked per task.
            h_outputs_all_tasks_list.append(h_repr_task.unsqueeze(1))

    # NOTE(review): assumes each g prediction is (batch, 1) so the concatenation
    # is (batch, num_tasks) - confirm against the g model definition.
    g_outputs_all_tasks = torch.cat(g_outputs_all_tasks_list, dim=1)
    h_outputs_all_tasks = torch.cat(h_outputs_all_tasks_list, dim=1)

    psi_fidelity_pred, o_representation = psi_model(features)
    alpha_attention_weights, _ = attention_model(o_representation, h_outputs_all_tasks)

    # 1. Knowledge Distillation Loss (L_kd)
    # The target is detached: this term only pulls the Psi prediction towards
    # the attention-weighted g predictions, it does not train the attention
    # weights through the target.
    target_for_psi = (alpha_attention_weights * g_outputs_all_tasks).sum(dim=1, keepdim=True)
    loss_kd = F.mse_loss(psi_fidelity_pred, target_for_psi.detach())

    # 2. Entropy Constraint (L_ent_neg)
    # sum(alpha * log alpha) is the negative entropy; adding log(num_tasks)
    # shifts it so that (for weights summing to 1) uniform attention gives ~0.
    loss_ent_neg = (alpha_attention_weights * torch.log(alpha_attention_weights + epsilon)).sum(dim=1).mean()
    loss_ent_neg = loss_ent_neg + torch.log(torch.tensor(num_tasks, dtype=torch.float, device=device))

    # 3. Similarity Constraint (L_sim)
    loss_sim = torch.tensor(0.0, device=device)
    if num_tasks > 1:
        g_detached = g_outputs_all_tasks.detach()
        if g_detached.dim() == 2:
            # Scalar prediction per task: add an explicit output axis so that the
            # reduction below sums over output components and NOT over the t' axis.
            g_detached = g_detached.unsqueeze(-1)  # (B, T) -> (B, T, 1)

        # Pairwise squared distances between task predictions: (B, T, T).
        pairwise_diffs = g_detached.unsqueeze(2) - g_detached.unsqueeze(1)
        squared_diffs = pairwise_diffs.pow(2).sum(dim=-1)
        rho_tt_prime = torch.exp(-squared_diffs / temperature_tau)

        # Outer product of attention weights, alpha_t * alpha_t': (B, T, T).
        alpha_t = alpha_attention_weights.unsqueeze(2)
        alpha_t_prime = alpha_attention_weights.unsqueeze(1)

        product_terms = alpha_t * alpha_t_prime * rho_tt_prime

        # Strict upper triangle: every unordered pair (t < t') counted once.
        indices = torch.triu_indices(num_tasks, num_tasks, offset=1, device=device)
        loss_sim_per_sample = product_terms[:, indices[0], indices[1]].sum(dim=1)
        loss_sim = loss_sim_per_sample.mean()

    return loss_kd, loss_ent_neg, loss_sim, psi_fidelity_pred, alpha_attention_weights

def train_stage2_psi_model(
    psi_model: FidelityPredictorPsi,
    attention_model: Optional[AttentionSubnetwork], # Can be None if num_tasks is 0
    g_models_list_trained: nn.ModuleList, # Should be in eval mode
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer_psi: optim.Optimizer, # Optimizes both psi_model and attention_model parameters
    num_epochs: int,
    device: torch.device,
    temperature_tau: float,
    num_tasks: int, # Passed explicitly to handle cases where g_models_list might be empty but num_tasks was intended
    task_names: List[str], # For logging consistency
    early_stopping_patience: int,
    experiment_output_dir: str # For saving epoch-wise details
) -> Tuple[FidelityPredictorPsi, Optional[AttentionSubnetwork], Dict]:
    """Trains the second-stage Fidelity Predictor (Ψ) model and Attention Subnetwork.

    The g^(t) models are switched to eval mode and their parameters are frozen
    (``requires_grad = False``) before training starts. Each batch minimises
    ``λ_kd * L_kd + λ_ent * L_ent_neg + λ_sim * L_sim`` where every λ is the
    ratio of the current batch loss to the previous epoch's average of the same
    component (all λ are 1 during the first epoch). Validation uses the
    unweighted sum of the three components; the best validation state is
    restored at the end. Without validation data the state of the last epoch
    is kept. One CSV row per epoch is appended to
    ``<experiment_output_dir>/stage2_psi_epoch_summary.csv``.

    Args:
        psi_model (FidelityPredictorPsi): Ψ model instance.
        attention_model (Optional[AttentionSubnetwork]): AttentionSubnetwork instance. None if num_tasks is 0.
        g_models_list_trained (nn.ModuleList): Pre-trained g^(t) models.
        train_loader (DataLoader): DataLoader for training data (may be None/empty).
        val_loader (DataLoader): DataLoader for validation data (may be None/empty).
        optimizer_psi (optim.Optimizer): Optimizer for Ψ and Attention models.
        num_epochs (int): Maximum training epochs.
        device (torch.device): Device for training (CPU/CUDA).
        temperature_tau (float): Temperature (τ) for L_sim loss component.
        num_tasks (int): Number of tasks.
        task_names (List[str]): Task names (currently unused inside this function).
        early_stopping_patience (int): Consecutive non-improving validation epochs
            tolerated before training stops.
        experiment_output_dir (str): Directory for the per-epoch CSV summary.

    Returns:
        tuple:
            - trained_psi_model (FidelityPredictorPsi): Trained Ψ model.
            - trained_attention_model (Optional[AttentionSubnetwork]): Trained Attention model.
            - training_summary (Dict): Summary of Stage 2 training. When no training
              epoch could run (no epochs requested or no training data) its
              ``status`` is ``"skipped_no_data"`` and ``final_average_train_losses``
              is ``None``.
    """
    if num_tasks == 0 or not attention_model:
        logger.info("Stage 2 Training: num_tasks is 0 or attention_model is None. Skipping Stage 2 training.")
        # Return models as they are, Ψ might still be trainable with a different loss if designed so, but not with current setup.
        return psi_model, attention_model, {"status": "skipped", "reason": "No tasks or no attention model for multi-task learning components"}

    # Ensure g^(t) models are in evaluation mode and their parameters are frozen
    for model_g in g_models_list_trained:
        model_g.eval()
        for param in model_g.parameters():
            param.requires_grad = False

    # Data availability does not change between epochs, so it is decided once.
    has_train_data = bool(train_loader) and len(train_loader.dataset) > 0
    has_val_data = bool(val_loader) and len(val_loader.dataset) > 0

    # Previous-epoch average losses used as denominators of the dynamic weights
    # (λ_kd, λ_ent, λ_sim - Eq. 13). Epoch averages are used instead of
    # per-batch values for stability.
    prev_epoch_avg_losses_psi = {
        'kd': torch.tensor(1.0, device=device),
        'ent_neg': torch.tensor(1.0, device=device), # Negative Entropy (to be maximized, so loss is minimized)
        'sim': torch.tensor(1.0, device=device)
    }

    best_val_loss_psi_aggregate = float('inf')
    epochs_without_improvement_psi = 0
    best_psi_model_state_dict = None
    best_attention_model_state_dict = None
    actual_epochs_run_psi = 0
    # Averages of the most recent epoch that actually trained; stays None when
    # no training epoch ever runs so the summary never reads unbound values.
    final_average_train_losses = None

    epoch_summary_psi_path = Path(experiment_output_dir) / "stage2_psi_epoch_summary.csv"
    try:
        with open(epoch_summary_psi_path, 'w') as f_csv:
            header_cols = ["Epoch", "Train_L_kd", "Train_L_ent_neg", "Train_L_sim", "Train_L_total",
                           "Val_L_kd", "Val_L_ent_neg", "Val_L_sim", "Val_L_total", "Status"]
            f_csv.write(",".join(header_cols) + "\n")
    except IOError as e:
        logger.error(f"Failed to create Stage 2 epoch summary file at {epoch_summary_psi_path}: {e}")
        epoch_summary_psi_path = None  # Disables per-epoch CSV rows below.

    logger.info(f"--- Stage 2: Training Ψ Model (Max Epochs: {num_epochs}, Patience: {early_stopping_patience}) ---")

    for epoch in range(num_epochs):
        actual_epochs_run_psi = epoch + 1
        psi_model.train()
        attention_model.train()

        if not has_train_data:
            logger.warning(f"Epoch {actual_epochs_run_psi}/{num_epochs} (Stage 2): Training data is empty. Skipping training for this epoch.")
            if not has_val_data:
                if epoch == num_epochs - 1:
                    logger.error("Stage 2: Both training and validation data are empty. Ψ model is not trained.")
                logger.warning("Stage 2: Training and validation data are empty. Stopping Stage 2 training.")
                break
            continue

        # --- Training Phase ---
        epoch_train_loss_kd_sum, epoch_train_loss_ent_neg_sum, epoch_train_loss_sim_sum, epoch_train_loss_total_sum = 0.0, 0.0, 0.0, 0.0
        train_samples_seen = 0

        for features, _ in train_loader: # Shapley values from loader not used here
            features = features.to(device)
            batch_size = features.shape[0]
            if batch_size == 0:
                continue

            optimizer_psi.zero_grad()

            loss_kd, loss_ent_neg, loss_sim, _, _ = _calculate_psi_loss_components(
                features=features,
                psi_model=psi_model,
                attention_model=attention_model,
                g_models_list_trained=g_models_list_trained,
                num_tasks=num_tasks,
                temperature_tau=temperature_tau,
                device=device,
                epsilon=EPSILON
            )

            if epoch == 0:
                # No previous epoch to compare against: uniform weights.
                lambda_kd = torch.tensor(1.0, device=device)
                lambda_ent_neg = torch.tensor(1.0, device=device)
                lambda_sim = torch.tensor(1.0, device=device)
            else:
                # Dynamic loss weights (Eq. 13): current loss relative to the
                # previous epoch's average. Detached so they act as constants.
                lambda_kd = loss_kd.detach() / (prev_epoch_avg_losses_psi['kd'] + EPSILON)
                lambda_ent_neg = loss_ent_neg.abs().detach() / (prev_epoch_avg_losses_psi['ent_neg'].abs() + EPSILON)
                lambda_sim = loss_sim.detach() / (prev_epoch_avg_losses_psi['sim'] + EPSILON)

            # Overall Objective (L_Ψ - Eq. 12). loss_ent_neg is already the
            # quantity to minimise (negative entropy), so it is added.
            total_loss_psi_batch = (lambda_kd * loss_kd) + (lambda_ent_neg * loss_ent_neg) + (lambda_sim * loss_sim)

            total_loss_psi_batch.backward()
            optimizer_psi.step()

            # Weight by batch size so the epoch value is a per-sample average.
            epoch_train_loss_kd_sum += loss_kd.item() * batch_size
            epoch_train_loss_ent_neg_sum += loss_ent_neg.item() * batch_size
            epoch_train_loss_sim_sum += loss_sim.item() * batch_size
            epoch_train_loss_total_sum += total_loss_psi_batch.item() * batch_size
            train_samples_seen += batch_size

        # Divide by the samples actually processed (not len(dataset)) so the
        # averages stay correct when the loader drops or skips batches.
        if train_samples_seen > 0:
            avg_epoch_train_kd = epoch_train_loss_kd_sum / train_samples_seen
            avg_epoch_train_ent_neg = epoch_train_loss_ent_neg_sum / train_samples_seen
            avg_epoch_train_sim = epoch_train_loss_sim_sum / train_samples_seen
            avg_epoch_train_total = epoch_train_loss_total_sum / train_samples_seen
        else:
            avg_epoch_train_kd, avg_epoch_train_ent_neg, avg_epoch_train_sim, avg_epoch_train_total = 0.0, 0.0, 0.0, 0.0

        final_average_train_losses = {
            'L_kd': avg_epoch_train_kd,
            'L_ent_neg': avg_epoch_train_ent_neg,
            'L_sim': avg_epoch_train_sim,
            'L_total': avg_epoch_train_total
        }

        # --- Validation Phase ---
        epoch_status_message_psi = ""
        avg_epoch_val_kd, avg_epoch_val_ent_neg, avg_epoch_val_sim, avg_epoch_val_total = 0.0, 0.0, 0.0, 0.0

        if has_val_data:
            psi_model.eval()
            attention_model.eval()

            epoch_val_loss_kd_sum, epoch_val_loss_ent_neg_sum, epoch_val_loss_sim_sum, epoch_val_loss_total_sum = 0.0, 0.0, 0.0, 0.0
            val_samples_seen = 0

            with torch.no_grad():
                for features_val, _ in val_loader:
                    features_val = features_val.to(device)
                    val_batch_size = features_val.shape[0]
                    if val_batch_size == 0:
                        continue

                    loss_kd_val, loss_ent_neg_val, loss_sim_val, _, _ = _calculate_psi_loss_components(
                        features=features_val,
                        psi_model=psi_model,
                        attention_model=attention_model,
                        g_models_list_trained=g_models_list_trained,
                        num_tasks=num_tasks,
                        temperature_tau=temperature_tau,
                        device=device,
                        epsilon=EPSILON
                    )

                    # Validation uses fixed unit weights so epochs are comparable.
                    total_loss_psi_val_batch = loss_kd_val + loss_ent_neg_val + loss_sim_val

                    epoch_val_loss_kd_sum += loss_kd_val.item() * val_batch_size
                    epoch_val_loss_ent_neg_sum += loss_ent_neg_val.item() * val_batch_size
                    epoch_val_loss_sim_sum += loss_sim_val.item() * val_batch_size
                    epoch_val_loss_total_sum += total_loss_psi_val_batch.item() * val_batch_size
                    val_samples_seen += val_batch_size

            if val_samples_seen > 0:
                avg_epoch_val_kd = epoch_val_loss_kd_sum / val_samples_seen
                avg_epoch_val_ent_neg = epoch_val_loss_ent_neg_sum / val_samples_seen
                avg_epoch_val_sim = epoch_val_loss_sim_sum / val_samples_seen
                avg_epoch_val_total = epoch_val_loss_total_sum / val_samples_seen

                if avg_epoch_val_total < best_val_loss_psi_aggregate:
                    best_val_loss_psi_aggregate = avg_epoch_val_total
                    epochs_without_improvement_psi = 0
                    # Deep copies: state_dict() returns references to live tensors.
                    best_psi_model_state_dict = copy.deepcopy(psi_model.state_dict())
                    best_attention_model_state_dict = copy.deepcopy(attention_model.state_dict())
                    epoch_status_message_psi = "New Best"
                else:
                    epochs_without_improvement_psi += 1
                    epoch_status_message_psi = f"No Improve ({epochs_without_improvement_psi}/{early_stopping_patience})"
            else: # No val samples processed
                epoch_status_message_psi = "Val Skip (No Data)"
        else: # No val loader
            epoch_status_message_psi = "Val Skip (No Loader/Empty)"
            if epoch == num_epochs - 1: # Last epoch, no validation
                logger.warning("Stage 2: No validation data. Saving Ψ/Attention models from the last epoch.")
                best_psi_model_state_dict = copy.deepcopy(psi_model.state_dict())
                best_attention_model_state_dict = copy.deepcopy(attention_model.state_dict())

        # Logging epoch summary for Stage 2
        log_msg_psi_epoch = (
            f"Epoch {actual_epochs_run_psi}/{num_epochs} [Stage 2 Ψ] | "
            f"Train Losses (KD/EntN/Sim/Total): {avg_epoch_train_kd:.4f}/{avg_epoch_train_ent_neg:.4f}/{avg_epoch_train_sim:.4f}/{avg_epoch_train_total:.4f} | "
            f"Val Losses (KD/EntN/Sim/Total): {avg_epoch_val_kd:.4f}/{avg_epoch_val_ent_neg:.4f}/{avg_epoch_val_sim:.4f}/{avg_epoch_val_total:.4f} | "
            f"Status: {epoch_status_message_psi}"
        )
        logger.info(log_msg_psi_epoch)
        if epoch_summary_psi_path:
            try:
                with open(epoch_summary_psi_path, 'a') as f_csv:
                    row_data = [str(actual_epochs_run_psi),
                                f"{avg_epoch_train_kd:.6f}", f"{avg_epoch_train_ent_neg:.6f}", f"{avg_epoch_train_sim:.6f}", f"{avg_epoch_train_total:.6f}",
                                f"{avg_epoch_val_kd:.6f}", f"{avg_epoch_val_ent_neg:.6f}", f"{avg_epoch_val_sim:.6f}", f"{avg_epoch_val_total:.6f}",
                                epoch_status_message_psi]
                    f_csv.write(",".join(row_data) + "\n")
            except IOError as e:
                logger.warning(f"Failed to write to Stage 2 epoch summary file: {e}")

        # Update previous epoch average losses for dynamic weighting in the next epoch
        if train_samples_seen > 0:
            prev_epoch_avg_losses_psi['kd'] = torch.tensor(avg_epoch_train_kd, device=device)
            prev_epoch_avg_losses_psi['ent_neg'] = torch.tensor(avg_epoch_train_ent_neg, device=device)
            prev_epoch_avg_losses_psi['sim'] = torch.tensor(avg_epoch_train_sim, device=device)

        if has_val_data and epochs_without_improvement_psi >= early_stopping_patience:
            logger.info(f"Stage 2 Training: Early stopping triggered after {actual_epochs_run_psi} epochs.")
            break

    # Restore the retained states (best validation epoch, or the last epoch
    # when there was no validation data).
    psi_state_restored = best_psi_model_state_dict is not None
    attention_state_restored = best_attention_model_state_dict is not None
    if psi_state_restored:
        psi_model.load_state_dict(best_psi_model_state_dict)
    if attention_state_restored:
        attention_model.load_state_dict(best_attention_model_state_dict)

    # Report what was restored; only claim "based on validation" when
    # validation actually took place.
    if has_val_data and (psi_state_restored or attention_state_restored):
        if psi_state_restored:
            logger.info("Stage 2 Training: Loaded best performing Ψ model state based on validation.")
        if attention_state_restored:
            logger.info("Stage 2 Training: Loaded best performing Attention model state based on validation.")
    elif not has_val_data and final_average_train_losses is not None:
        logger.info("Stage 2 Training: No validation data. Using Ψ/Attention models from the last trained epoch.")
    else:
        logger.warning("Stage 2 Training: No best Ψ/Attention models to load.")

    final_training_summary_psi = {
        "status": "completed" if final_average_train_losses is not None else "skipped_no_data",
        "epochs_run": actual_epochs_run_psi,
        "early_stopped": epochs_without_improvement_psi >= early_stopping_patience if has_val_data else False,
        "best_aggregate_validation_loss": best_val_loss_psi_aggregate if best_val_loss_psi_aggregate != float('inf') else None,
        "final_average_train_losses": final_average_train_losses
    }
    logger.info(f"--- Stage 2: Training Ψ Model Finished (Actual Epochs: {actual_epochs_run_psi}) ---")
    return psi_model, attention_model, final_training_summary_psi

class NpEncoder(json.JSONEncoder):
    """JSON encoder for values the standard encoder rejects.

    Converts NumPy scalars (integer, floating, bool) to the matching Python
    scalar, NumPy arrays and torch tensors to (nested) lists, and
    ``pathlib.Path`` objects to strings. Any other unsupported object is
    passed to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not an np.integer subclass and is rejected by json otherwise.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            # Detach and move to CPU first so tensors from any device/graph serialize.
            return obj.detach().cpu().tolist()
        if isinstance(obj, Path):
            return str(obj)
        return super().default(obj)

def evaluate_models_on_test_set(
    g_models_list_eval: nn.ModuleList, # Trained g models
    psi_model_eval: FidelityPredictorPsi, # Trained Psi model
    attention_model_eval: Optional[AttentionSubnetwork], # Trained Attention model
    test_loader: DataLoader,
    device: torch.device,
    num_tasks: int,
    task_names: List[str],
    temperature_tau: float,
    criterion_g_loss_fn: nn.Module # e.g., nn.MSELoss() for g model loss
) -> Dict:
    """Computes test-set losses for the g^(t) models and for the Ψ model.

    Two independent passes over ``test_loader`` are made under
    ``torch.no_grad()``: one scoring every g^(t) model with
    ``criterion_g_loss_fn`` against its target column, one computing the Ψ loss
    components (L_kd, L_ent_neg, L_sim) and their unweighted sum.

    Args:
        g_models_list_eval (nn.ModuleList): Trained g^(t) models.
        psi_model_eval (FidelityPredictorPsi): Trained Ψ model.
        attention_model_eval (Optional[AttentionSubnetwork]): Trained Attention Subnetwork. None if num_tasks is 0.
        test_loader (DataLoader): DataLoader yielding (features, targets) batches.
        device (torch.device): Device for evaluation.
        num_tasks (int): Number of tasks.
        task_names (List[str]): Task names, used as keys of the per-task results.
        temperature_tau (float): Temperature for L_sim calculation.
        criterion_g_loss_fn (nn.Module): Loss function for the g^(t) models. Its
            result is multiplied by the batch size before accumulation, i.e. it
            is treated as a per-sample mean.

    Returns:
        Dict: ``{'stage1_g_test_losses_mse': {...}, 'stage2_psi_test_losses': {...}}``;
        a section stays empty when it is skipped. When the test data is empty the
        result is ``{"status": "skipped", "reason": "No test data"}`` instead.
    """
    logger.info("--- Starting Model Evaluation on Test Set ---")
    results = {
        'stage1_g_test_losses_mse': {},
        'stage2_psi_test_losses': {}
    }

    if not test_loader or len(test_loader.dataset) == 0:
        logger.warning("Test data is empty. Skipping evaluation.")
        return {"status": "skipped", "reason": "No test data"}

    # --- Stage 1: per-task loss of the g^(t) models ---
    if not (g_models_list_eval and num_tasks > 0):
        logger.info("Stage 1 (g^(t)) Test: Skipped (no g_models or num_tasks is 0).")
    else:
        for g_net in g_models_list_eval:
            g_net.eval()

        per_task_loss_total = torch.zeros(num_tasks, device=device)
        per_task_sample_count = torch.zeros(num_tasks, device=device)
        g_samples_seen = 0

        with torch.no_grad():
            for batch_features, batch_targets in test_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device)
                n_in_batch = batch_features.shape[0]
                if n_in_batch == 0:
                    continue
                g_samples_seen += n_in_batch

                for t in range(num_tasks):
                    task_pred, _ = g_models_list_eval[t](batch_features)
                    # Prediction loses its trailing axis to line up with the
                    # t-th target column.
                    task_loss = criterion_g_loss_fn(task_pred.squeeze(-1), batch_targets[:, t])
                    # Scale the batch loss back to a sum before accumulating.
                    per_task_loss_total[t] += task_loss.item() * n_in_batch
                    per_task_sample_count[t] += n_in_batch

        if g_samples_seen <= 0:
            logger.info("Stage 1 (g^(t)) Test: No samples processed.")
        else:
            # EPSILON guards the division; counts are positive here anyway.
            per_task_mean_loss = per_task_loss_total / (per_task_sample_count + EPSILON)
            stage1_results = results['stage1_g_test_losses_mse']
            for position, name in enumerate(task_names):
                stage1_results[name] = per_task_mean_loss[position].item()
            logger.info(f"Stage 1 (g^(t)) Test MSE Losses: {results['stage1_g_test_losses_mse']}")

    # --- Stage 2: loss components of the Ψ model ---
    if not (psi_model_eval and attention_model_eval and num_tasks > 0):
        logger.info("Stage 2 (Ψ) Test: Skipped (no Psi/Attention model or num_tasks is 0).")
    else:
        psi_model_eval.eval()
        attention_model_eval.eval()

        kd_total = 0.0
        ent_neg_total = 0.0
        sim_total = 0.0
        combined_total = 0.0
        psi_samples_seen = 0

        with torch.no_grad():
            for batch_features, _ in test_loader: # targets are not needed for the Ψ losses
                batch_features = batch_features.to(device)
                n_in_batch = batch_features.shape[0]
                if n_in_batch == 0:
                    continue
                psi_samples_seen += n_in_batch

                kd_batch, ent_neg_batch, sim_batch, _, _ = _calculate_psi_loss_components(
                    features=batch_features,
                    psi_model=psi_model_eval,
                    attention_model=attention_model_eval,
                    g_models_list_trained=g_models_list_eval,
                    num_tasks=num_tasks,
                    temperature_tau=temperature_tau,
                    device=device,
                    epsilon=EPSILON
                )

                # Unweighted sum of the three components.
                combined_batch = kd_batch + ent_neg_batch + sim_batch

                kd_total += kd_batch.item() * n_in_batch
                ent_neg_total += ent_neg_batch.item() * n_in_batch
                sim_total += sim_batch.item() * n_in_batch
                combined_total += combined_batch.item() * n_in_batch

        if psi_samples_seen <= 0:
            logger.info("Stage 2 (Ψ) Test: No samples processed.")
        else:
            results['stage2_psi_test_losses'] = {
                'L_kd': kd_total / psi_samples_seen,
                'L_ent_neg': ent_neg_total / psi_samples_seen,
                'L_sim': sim_total / psi_samples_seen,
                'L_total': combined_total / psi_samples_seen
            }
            logger.info(f"Stage 2 (Ψ) Test Losses: {results['stage2_psi_test_losses']}")

    logger.info("--- Model Evaluation on Test Set Finished ---")
    return results


# --- Model Persistence ---
def _save_checkpoint_best_effort(checkpoint: Dict, checkpoint_path: Path, label: str) -> None:
    """Writes `checkpoint` with torch.save; failures are logged, never raised."""
    try:
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"{label} saved to: {checkpoint_path}")
    except Exception as e:
        # Deliberately broad: one failed artifact must not prevent the others
        # from being written.
        logger.error(f"Failed to save {label}: {e}", exc_info=True)


def _last_hidden_dim_or_zero(model_config_params: Dict, key: str) -> int:
    """Returns the last entry of the dims list stored under `key`, or 0 if missing/None/empty."""
    dims = model_config_params.get(key)
    if dims is None or len(dims) == 0:
        return 0
    return dims[-1]


def save_trained_models_and_config(
    g_models_to_save: Optional[nn.ModuleList],
    psi_model_to_save: Optional[FidelityPredictorPsi],
    attention_model_to_save: Optional[AttentionSubnetwork],
    model_config_params: Dict, # Contains input_dim, hidden_dims, task_names etc.
    training_summaries: Dict, # Dict with 'stage1' and 'stage2' summaries
    run_params_to_log: Dict, # Original argparse parameters for the run
    model_output_dir: str,
    dataset_name_str: str
):
    """Saves trained models (g^(t), Ψ, Attention), configurations, and summaries.

    Everything is written to `model_output_dir`/`dataset_name_str`/trained_models/
    (created if needed):

    - ``run_config_and_summary.json``: architecture config, training summaries,
      run parameters and a UTC timestamp.
    - ``g_models.pt``, ``psi_model.pt``, ``attention_model.pt``: state dicts plus
      the parameters needed to rebuild each model.

    Each artifact is saved independently: a failure is logged and the remaining
    artifacts are still attempted.

    Args:
        g_models_to_save (Optional[nn.ModuleList]): Trained g^(t) models.
        psi_model_to_save (Optional[FidelityPredictorPsi]): Trained Ψ model.
        attention_model_to_save (Optional[AttentionSubnetwork]): Trained Attention Subnetwork.
        model_config_params (Dict): Model architecture parameters.
        training_summaries (Dict): Summaries from Stage 1 and Stage 2 training, plus test evaluation.
        run_params_to_log (Dict): Initial argparse arguments for the run.
        model_output_dir (str): Base directory for this run's outputs.
        dataset_name_str (str): Dataset name, used for subdirectory creation.
    """
    base_model_dir = Path(model_output_dir)
    dataset_specific_model_dir = base_model_dir / dataset_name_str / "trained_models"
    dataset_specific_model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving models and configuration to: {dataset_specific_model_dir}")

    # A config whose 'num_tasks' is missing or None counts as "no tasks".
    has_tasks = (model_config_params.get('num_tasks') or 0) > 0

    # --- Save Full Configuration and Summaries ---
    full_run_config_summary = {
        'model_architecture_config': model_config_params,
        'training_summaries': training_summaries,
        'original_run_parameters': run_params_to_log,
        'saved_timestamp': time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
    }
    config_summary_path = dataset_specific_model_dir / "run_config_and_summary.json"
    try:
        # Serialize fully before opening the file so a non-serializable value
        # cannot leave a truncated JSON file behind.
        serialized_summary = json.dumps(full_run_config_summary, indent=4, cls=NpEncoder)
        with open(config_summary_path, 'w') as f_json:
            f_json.write(serialized_summary)
        logger.info(f"Full run configuration and training summaries saved to: {config_summary_path}")
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError come from json (unsupported type, circular
        # reference); they must not abort saving the model checkpoints below.
        logger.error(f"Failed to save run_config_and_summary.json: {e}")

    # --- Save g^(t) Models ---
    if g_models_to_save and has_tasks:
        g_models_checkpoint = {
            'models_state_dict': [model.state_dict() for model in g_models_to_save],
            # Include relevant parts of model_config_params for g_models reconstruction
            'g_model_params': {
                'input_dim': model_config_params.get('input_dim'),
                'g_hidden_dims': model_config_params.get('g_hidden_dims'),
                'num_tasks': model_config_params.get('num_tasks'),
                'task_names': model_config_params.get('task_names')
            }
        }
        _save_checkpoint_best_effort(
            g_models_checkpoint, dataset_specific_model_dir / "g_models.pt", "g^(t) models"
        )
    else:
        logger.info("Skipping g^(t) models save (not provided or no tasks).")

    # --- Save Ψ Model ---
    if psi_model_to_save:
        psi_model_checkpoint = {
            'model_state_dict': psi_model_to_save.state_dict(),
            'psi_model_params': {
                'input_dim': model_config_params.get('input_dim'),
                'psi_hidden_dims': model_config_params.get('psi_hidden_dims')
            }
        }
        _save_checkpoint_best_effort(
            psi_model_checkpoint, dataset_specific_model_dir / "psi_model.pt", "Ψ model"
        )
    else:
        logger.info("Skipping Ψ model save (not provided).")

    # --- Save Attention Model ---
    if attention_model_to_save and has_tasks:
        # Prefer dimensions exposed by the model instances; fall back to the
        # last configured hidden dim (0 if that is unavailable too).
        if psi_model_to_save and hasattr(psi_model_to_save, 'o_dim'):
            psi_o_dim_val = psi_model_to_save.o_dim
        else:
            psi_o_dim_val = _last_hidden_dim_or_zero(model_config_params, 'psi_hidden_dims')

        if g_models_to_save and len(g_models_to_save) > 0 and hasattr(g_models_to_save[0], 'h_dim'):
            g_h_dim_val = g_models_to_save[0].h_dim
        else:
            g_h_dim_val = _last_hidden_dim_or_zero(model_config_params, 'g_hidden_dims')

        attention_model_checkpoint = {
            'model_state_dict': attention_model_to_save.state_dict(),
            'attention_model_params': {
                'psi_o_dim': psi_o_dim_val,
                'g_h_dim': g_h_dim_val,
                'attention_r_dim': model_config_params.get('attention_r_dim'),
                'num_tasks': model_config_params.get('num_tasks')
            }
        }
        _save_checkpoint_best_effort(
            attention_model_checkpoint, dataset_specific_model_dir / "attention_model.pt", "Attention model"
        )
    else:
        logger.info("Skipping Attention model save (not provided or no tasks).")

def load_models_and_config(model_load_dir: str, device: torch.device) -> Optional[Dict]:
    """Loads trained models (g^(t), Ψ, Attention) and configurations from a directory.

    Loading is best-effort per artifact: a missing or unreadable file is logged and
    leaves the corresponding entry empty/None while the remaining artifacts are still
    attempted. Constructor arguments come from the parameter dict stored in each
    checkpoint; when a checkpoint carries none, the 'model_architecture_config'
    section of run_config_and_summary.json is used instead.

    Args:
        model_load_dir (str): Directory to load models and config from
                              (e.g., .../dataset_name/trained_models/).
        device (torch.device): Device (CPU/CUDA) to load models onto.

    Returns:
        Optional[Dict]: None if ``model_load_dir`` is not an existing directory.
            Otherwise a dict with keys 'g_models' (nn.ModuleList, empty when the
            g^(t) models are missing or failed to load), 'psi_model',
            'attention_model' and 'full_run_config' (each None when unavailable).
            A partially filled dict is returned even if nothing could be loaded;
            that situation is only logged as an error.
    """
    logger.info(f"Loading models and configuration from: {model_load_dir}") # model_load_dir is str
    load_dir_path = Path(model_load_dir)
    if not load_dir_path.is_dir():
        logger.error(f"Model load directory not found: {load_dir_path}")
        return None

    loaded_artifacts = {
        'g_models': nn.ModuleList(),
        'psi_model': None,
        'attention_model': None,
        'full_run_config': None
    }

    # --- Load Full Configuration and Summaries ---
    arch_config = {}
    config_summary_path = load_dir_path / "run_config_and_summary.json"
    if config_summary_path.exists():
        try:
            with open(config_summary_path, 'r') as f_json:
                loaded_artifacts['full_run_config'] = json.load(f_json)
            logger.info("Successfully loaded run_config_and_summary.json.")
            # `or {}` guards against the key being present but null in the JSON.
            arch_config = loaded_artifacts['full_run_config'].get('model_architecture_config') or {}
        except Exception as e:
            logger.error(f"Failed to load or parse run_config_and_summary.json: {e}", exc_info=True)
            arch_config = {} # Attempt to load models even if full config fails, using params from checkpoints
    else:
        logger.warning("run_config_and_summary.json not found. Model parameters must be inferred from checkpoints.")

    # --- Load g^(t) Models ---
    # NOTE(review): the models below are rebuilt without an explicit dropout argument,
    # so they use the constructor default. Irrelevant in eval mode; confirm it is the
    # intended behaviour if loaded models are trained further.
    g_models_path = load_dir_path / "g_models.pt"
    if g_models_path.exists():
        try:
            checkpoint = torch.load(g_models_path, map_location=device)
            g_params = checkpoint.get('g_model_params', arch_config) # Prioritize checkpoint params

            # Ensure all necessary params are present
            if not all(k in g_params for k in ['input_dim', 'g_hidden_dims', 'num_tasks']):
                raise ValueError("Missing critical parameters for g_model reconstruction in checkpoint or config.")

            # Stage the models in a temporary list so that a failure part-way through
            # never leaves a truncated set of g^(t) models in the returned artifacts.
            restored_g_models = nn.ModuleList()
            for state_dict in checkpoint['models_state_dict']:
                model_g = ShapleyPredictorG(
                    input_dim=g_params['input_dim'],
                    hidden_dims=g_params['g_hidden_dims']
                )
                model_g.load_state_dict(state_dict)
                model_g.to(device).eval()
                restored_g_models.append(model_g)
            loaded_artifacts['g_models'] = restored_g_models
            logger.info(f"Successfully loaded {len(loaded_artifacts['g_models'])} g^(t) model(s).")
        except Exception as e:
            logger.error(f"Failed to load g^(t) models from {g_models_path}: {e}", exc_info=True)
            loaded_artifacts['g_models'] = nn.ModuleList() # Ensure it's empty on failure
    else:
        logger.info("g_models.pt not found. No g^(t) models loaded.")

    # --- Load Ψ Model ---
    psi_model_path = load_dir_path / "psi_model.pt"
    if psi_model_path.exists():
        try:
            checkpoint = torch.load(psi_model_path, map_location=device)
            psi_params = checkpoint.get('psi_model_params', arch_config)
            if not all(k in psi_params for k in ['input_dim', 'psi_hidden_dims']):
                raise ValueError("Missing critical parameters for psi_model reconstruction.")

            loaded_artifacts['psi_model'] = FidelityPredictorPsi(
                input_dim=psi_params['input_dim'],
                hidden_dims=psi_params['psi_hidden_dims']
            )
            loaded_artifacts['psi_model'].load_state_dict(checkpoint['model_state_dict'])
            loaded_artifacts['psi_model'].to(device).eval()
            logger.info("Successfully loaded Ψ model.")
        except Exception as e:
            logger.error(f"Failed to load Ψ model from {psi_model_path}: {e}", exc_info=True)
            loaded_artifacts['psi_model'] = None # Ensure it's None on failure
    else:
        logger.info("psi_model.pt not found. No Ψ model loaded.")

    # --- Load Attention Model ---
    attention_model_path = load_dir_path / "attention_model.pt"
    if attention_model_path.exists():
        try:
            checkpoint = torch.load(attention_model_path, map_location=device)
            # Work on a copy: the fallback may be arch_config itself, and the inferred
            # dims filled in below must not leak into the returned 'full_run_config'.
            attn_params = dict(checkpoint.get('attention_model_params', arch_config))
            # Infer missing params if possible from arch_config (if full_run_config was loaded)
            if 'psi_o_dim' not in attn_params and arch_config.get('psi_hidden_dims'):
                attn_params['psi_o_dim'] = arch_config['psi_hidden_dims'][-1]
            if 'g_h_dim' not in attn_params and arch_config.get('g_hidden_dims'):
                attn_params['g_h_dim'] = arch_config['g_hidden_dims'][-1]

            if not all(k in attn_params for k in ['psi_o_dim', 'g_h_dim', 'attention_r_dim', 'num_tasks']):
                raise ValueError("Missing critical parameters for attention_model reconstruction.")
            if attn_params['num_tasks'] > 0 : # Only create if num_tasks > 0
                loaded_artifacts['attention_model'] = AttentionSubnetwork(
                    psi_o_dim=attn_params['psi_o_dim'],
                    g_h_dim=attn_params['g_h_dim'],
                    attention_r_internal_dim=attn_params['attention_r_dim'],
                    num_tasks=attn_params['num_tasks']
                )
                loaded_artifacts['attention_model'].load_state_dict(checkpoint['model_state_dict'])
                loaded_artifacts['attention_model'].to(device).eval()
                logger.info("Successfully loaded Attention model.")
            else:
                logger.info("Attention model checkpoint found, but num_tasks is 0. Not loading.")
        except Exception as e:
            logger.error(f"Failed to load Attention model from {attention_model_path}: {e}", exc_info=True)
            loaded_artifacts['attention_model'] = None
    else:
        logger.info("attention_model.pt not found. No Attention model loaded.")

    # Nothing usable was found. This is reported but the (empty) artifact dict is still
    # returned, so callers fall back to fresh initialisation of every component.
    if not loaded_artifacts['full_run_config'] and not loaded_artifacts['g_models'] and not loaded_artifacts['psi_model']:
        logger.error("Critical failure: Neither config nor any model could be loaded.")

    return loaded_artifacts


def _setup_run_environment(args: argparse.Namespace) -> Tuple[Path, torch.device]: # Return Path
    """Prepares everything a run needs before any data or model is touched.

    Creates a timestamped output directory under ``args.base_output_dir``,
    points the script logger at the console and a log file inside that
    directory, seeds PyTorch and NumPy from ``args.global_random_seed`` and
    selects the compute device.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        Tuple[Path, torch.device]:
            - experiment_output_dir (Path): Output directory created for this run.
            - device (torch.device): "cuda" when CUDA is available, otherwise "cpu".
    """
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    experiment_output_dir = (
        Path(args.base_output_dir) / f"{args.run_name_prefix}_{args.dataset_name}_{run_stamp}"
    )
    experiment_output_dir.mkdir(parents=True, exist_ok=True)

    # setup_logger takes the log file location as a plain string.
    setup_logger(args.log_level, str(experiment_output_dir / "training_run_log.txt"))

    startup_messages = (
        "Starting Unified Distillation Training Script",
        f"Run Output Directory: {experiment_output_dir}",
        f"Command: python {' '.join(sys.argv)}",
        f"Parsed Arguments: {vars(args)}",
    )
    for message in startup_messages:
        logger.info(message)

    # Seed every RNG the script relies on. The cuDNN deterministic/benchmark
    # switches are left untouched; set them for strict reproducibility.
    seed = args.global_random_seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.manual_seed_all(seed)

    device = torch.device("cuda") if cuda_available else torch.device("cpu")
    logger.info(f"Using device: {device}")
    return experiment_output_dir, device

def _load_data_for_run(args: argparse.Namespace, experiment_output_dir: Path) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader], List[str], int, Optional[Dict]]: # Expect Path
    """Loads and splits data for the training run.

    Thin wrapper around ``load_and_split_data`` that maps the command-line
    arguments onto its parameters and aborts the script when loading fails.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        experiment_output_dir (Path): Path for saving data split indices.

    Returns:
        Tuple of (train_loader, val_loader, test_loader, task_names, input_dim,
        data_split_indices) exactly as returned by ``load_and_split_data``.

    Raises:
        SystemExit: With status 1 when no training loader was produced or the
            reported input dimension is -1.
    """
    # Only pin host memory when the flag is set AND CUDA is usable. This function is
    # not handed the run's device, so use the same availability check that the
    # environment setup uses to choose it (previously an undefined `device` name was
    # referenced here, which failed as soon as --dataloader_pin_memory was passed).
    pin_memory = bool(args.dataloader_pin_memory) and torch.cuda.is_available()

    train_loader, val_loader, test_loader, task_names, input_dim, data_split_indices = load_and_split_data(
        feature_file_path=args.feature_file,
        shapley_pkl_path=args.shapley_pkl_file,
        dataset_identifier=args.dataset_name,
        batch_size_value=args.batch_size,
        split_random_seed=args.random_seed_data_split,
        output_base_dir=str(experiment_output_dir), # load_and_split_data takes str for this, then converts
        num_workers_dl=args.dataloader_num_workers,
        pin_memory_dl=pin_memory
    )
    if train_loader is None or input_dim == -1:
        logger.error("Data loading failed. Exiting script.")
        sys.exit(1)

    num_tasks = len(task_names)
    logger.info(f"Successfully loaded data for dataset '{args.dataset_name}'. Input dimension: {input_dim}, Num tasks: {num_tasks}, Task names: {task_names}")
    return train_loader, val_loader, test_loader, task_names, input_dim, data_split_indices

def _initialize_or_load_models_for_run(args: argparse.Namespace, input_dim: int, num_tasks: int, task_names: List[str], device: torch.device) -> Tuple[nn.ModuleList, Optional[FidelityPredictorPsi], Optional[AttentionSubnetwork], Dict]:
    """Produces the g^(t), Ψ and Attention models for this run.

    When ``args.load_models_from_dir`` is set, models are restored from that
    directory first and a stored architecture config (if any) overrides the
    values derived from the arguments and the data. Any component that is
    still missing afterwards is freshly constructed. All models are moved to
    ``device`` before returning.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        input_dim (int): Input feature dimension.
        num_tasks (int): Number of tasks.
        task_names (List[str]): List of task names.
        device (torch.device): Computation device.

    Returns:
        Tuple containing g_models_list, psi_model, attention_subnetwork, model_architecture_params.
    """
    model_architecture_params = {
        'input_dim': input_dim,
        'g_hidden_dims': args.g_hidden_dims,
        'psi_hidden_dims': args.psi_hidden_dims,
        'attention_r_dim': args.attention_r_dim,
        'dropout_rate': args.dropout_rate,
        'num_tasks': num_tasks,
        'task_names': task_names
    }
    g_models_list, psi_model, attention_subnetwork = nn.ModuleList(), None, None

    # ---- Optional restore from a previous run ----
    if args.load_models_from_dir:
        logger.info(f"Attempting to load models from: {args.load_models_from_dir}")
        restored = load_models_and_config(args.load_models_from_dir, device)
        if not restored:
            logger.warning(f"Failed to load models from {args.load_models_from_dir}. Will proceed with initialization if training is not skipped.")
        else:
            g_models_list = restored.get('g_models', nn.ModuleList())
            psi_model = restored.get('psi_model')
            attention_subnetwork = restored.get('attention_model')

            saved_run_config = restored.get('full_run_config')
            if saved_run_config and saved_run_config.get('model_architecture_config'):
                # The stored architecture wins over the current arguments so that
                # anything built below stays compatible with the restored models.
                model_architecture_params.update(saved_run_config['model_architecture_config'])
                input_dim = model_architecture_params['input_dim']
                num_tasks = model_architecture_params['num_tasks']
                task_names = model_architecture_params['task_names']
                logger.info("Model architecture parameters updated from loaded configuration.")
            logger.info(f"Models loaded. g_models: {len(g_models_list)}, Psi: {psi_model is not None}, Attention: {attention_subnetwork is not None}")

    # ---- Build whatever was not restored ----
    dropout = model_architecture_params['dropout_rate']
    has_tasks = num_tasks > 0

    if not g_models_list and has_tasks:
        g_models_list = nn.ModuleList()
        for _ in range(num_tasks):
            g_models_list.append(
                ShapleyPredictorG(input_dim, model_architecture_params['g_hidden_dims'], dropout_p=dropout)
            )
        logger.info(f"Initialized {num_tasks} g^(t) models with hidden_dims: {model_architecture_params['g_hidden_dims']} and dropout: {model_architecture_params['dropout_rate']}.")

    if psi_model is None:
        psi_model = FidelityPredictorPsi(input_dim, model_architecture_params['psi_hidden_dims'], dropout_p=dropout)
        logger.info(f"Initialized Ψ model with hidden_dims: {model_architecture_params['psi_hidden_dims']} and dropout: {model_architecture_params['dropout_rate']}.")

    if attention_subnetwork is None and has_tasks:
        # Ψ is guaranteed to exist at this point; prefer the width reported by an
        # actual g model and fall back to the configured last hidden width.
        configured_g_width = model_architecture_params['g_hidden_dims'][-1]
        if len(g_models_list) > 0:
            g_width = g_models_list[0].h_dim
        else:
            g_width = configured_g_width

        attention_subnetwork = AttentionSubnetwork(
            psi_o_dim=psi_model.o_dim,
            g_h_dim=g_width,
            attention_r_internal_dim=model_architecture_params['attention_r_dim'],
            num_tasks=num_tasks
        )
        logger.info("Initialized Attention Subnetwork.")

    # ---- Device placement ----
    g_models_list.to(device)
    if psi_model:
        psi_model.to(device)
    if attention_subnetwork:
        attention_subnetwork.to(device)

    return g_models_list, psi_model, attention_subnetwork, model_architecture_params

def _initialize_optimizers_for_run(args: argparse.Namespace, g_models_list: nn.ModuleList, psi_model: Optional[FidelityPredictorPsi], attention_subnetwork: Optional[AttentionSubnetwork], num_tasks: int) -> Tuple[Optional[optim.Optimizer], Optional[optim.Optimizer]]:
    """Creates the Adam optimizers used by the two training stages.

    Stage 1 gets one optimizer over all g^(t) parameters; Stage 2 gets one
    optimizer over the combined Ψ and Attention parameters. An optimizer is
    only created when there is at least one non-empty parameter to optimize.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        g_models_list (nn.ModuleList): List of g models.
        psi_model (Optional[FidelityPredictorPsi]): Psi model.
        attention_subnetwork (Optional[AttentionSubnetwork]): Attention model.
        num_tasks (int): Number of tasks.

    Returns:
        Tuple[Optional[optim.Optimizer], Optional[optim.Optimizer]]: optimizer_g, optimizer_psi.
    """
    # ---- Stage 1: g^(t) models ----
    optimizer_g = None
    g_is_trainable = (
        num_tasks > 0
        and len(g_models_list) > 0
        and any(p.numel() for p in g_models_list.parameters())
    )
    if g_is_trainable:
        optimizer_g = optim.Adam(g_models_list.parameters(), lr=args.learning_rate_g)
        logger.info(f"Initialized Adam optimizer for g^(t) models with LR: {args.learning_rate_g}")

    # ---- Stage 2: Ψ and Attention share a single optimizer ----
    stage2_parameters = []
    for module in (psi_model, attention_subnetwork):
        if module and any(p.numel() for p in module.parameters()):
            stage2_parameters.extend(module.parameters())

    optimizer_psi = None
    if stage2_parameters:
        optimizer_psi = optim.Adam(stage2_parameters, lr=args.learning_rate_psi)
        logger.info(f"Initialized Adam optimizer for Ψ and Attention models with LR: {args.learning_rate_psi}")

    return optimizer_g, optimizer_psi

def _build_cli_parser() -> argparse.ArgumentParser:
    """Builds the command-line argument parser for the training script."""
    parser = argparse.ArgumentParser(description="Unified Two-Stage Knowledge Distillation Trainer for EHR Data Fidelity.",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # --- Data Arguments ---
    data_group = parser.add_argument_group("Data Input Arguments")
    data_group.add_argument("--dataset_name", type=str, required=True, help="Unique name for the dataset (e.g., AKI, MIMIC)")
    data_group.add_argument("--feature_file", type=str, required=True, help="Path to the .npy file containing input features (X).")
    data_group.add_argument("--shapley_pkl_file", type=str, default=None, help="Path to the .pkl file containing pre-calculated Shapley values (for g_t model targets). Optional.")

    # --- Model Hyperparameters ---
    model_hparam_group = parser.add_argument_group("Model Hyperparameters")
    model_hparam_group.add_argument("--g_hidden_dims", type=int, nargs=3, default=[64, 64, 32], help="Hidden layer dimensions for g^(t) models (list of 3 ints).")
    model_hparam_group.add_argument("--psi_hidden_dims", type=int, nargs=3, default=[128, 64, 64], help="Hidden layer dimensions for Ψ model (list of 3 ints).")
    model_hparam_group.add_argument("--attention_r_dim", type=int, default=32, help="Dimension of r^(t)(x) in the Attention Subnetwork.")
    model_hparam_group.add_argument("--dropout_rate", type=float, default=0.1, help="Dropout rate for g and Psi models.")

    # --- Training Parameters ---
    train_param_group = parser.add_argument_group("Training Parameters")
    train_param_group.add_argument("--learning_rate_g", type=float, default=1e-3, help="Learning rate for Stage 1 (g^(t) models).")
    train_param_group.add_argument("--learning_rate_psi", type=float, default=1e-3, help="Learning rate for Stage 2 (Ψ and Attention models).")
    train_param_group.add_argument("--num_epochs_g", type=int, default=100, help="Maximum epochs for Stage 1 training.")
    train_param_group.add_argument("--num_epochs_psi", type=int, default=100, help="Maximum epochs for Stage 2 training.")
    train_param_group.add_argument("--batch_size", type=int, default=64, help="Batch size for training and evaluation.")
    train_param_group.add_argument("--temperature_tau", type=float, default=1.0, help="Temperature (τ) for L_sim loss component in Stage 2.")
    train_param_group.add_argument("--early_stopping_patience_g", type=int, default=10, help="Patience for early stopping in Stage 1.")
    train_param_group.add_argument("--early_stopping_patience_psi", type=int, default=10, help="Patience for early stopping in Stage 2.")
    train_param_group.add_argument("--random_seed_data_split", type=int, default=42, help="Random seed for train/val/test data splitting.")
    train_param_group.add_argument("--global_random_seed", type=int, default=42, help="Global random seed for PyTorch and NumPy for reproducibility.")
    train_param_group.add_argument("--dataloader_num_workers", type=int, default=0, help="Number of workers for DataLoaders. Set >0 for parallel data loading.")
    train_param_group.add_argument("--dataloader_pin_memory", action="store_true", help="Enable pin_memory for DataLoaders (useful with CUDA).")

    # --- Output & Logging & Mode ---
    output_group = parser.add_argument_group("Output, Logging & Mode Arguments")
    output_group.add_argument("--base_output_dir", type=str, default="./distillation_run_outputs", help="Base directory for all run outputs.")
    output_group.add_argument("--run_name_prefix", type=str, default="distill_run", help="Prefix for the specific run's output directory.")
    output_group.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level.")
    output_group.add_argument("--skip_stage1_training", action="store_true", help="Skip Stage 1 (g model) training. Requires models to be loadable if Stage 2 is run.")
    output_group.add_argument("--skip_stage2_training", action="store_true", help="Skip Stage 2 (Psi model) training. Requires models to be loadable if evaluation is run.")
    output_group.add_argument("--skip_evaluation", action="store_true", help="Skip final model evaluation on the test set.")
    output_group.add_argument("--load_models_from_dir", type=str, default=None, help="Path to a directory (e.g., .../dataset_name/trained_models/) to load pre-trained models from. If set, training stages might be skipped based on other flags.")

    return parser


def main():
    """Main function orchestrating two-stage distillation model training and evaluation.

    Steps, in order: parse arguments, set up the run environment, load data,
    initialize or load models, create optimizers, run Stage 1 (g^(t)) and
    Stage 2 (Ψ + Attention) training unless skipped, evaluate on the test set
    unless skipped, then save models, configuration and a run summary.

    Each skipped stage records a small status dict (with the reason) in the
    saved summaries, so the output shows why a stage did not run.
    """
    args = _build_cli_parser().parse_args()

    experiment_output_dir, device = _setup_run_environment(args)

    train_loader, val_loader, test_loader, task_names, input_dim, data_split_indices = _load_data_for_run(args, experiment_output_dir)
    num_tasks = len(task_names)

    g_models_list, psi_model, attention_subnetwork, model_architecture_params = _initialize_or_load_models_for_run(
        args, input_dim, num_tasks, task_names, device
    )

    # Re-sync num_tasks and task_names from model_architecture_params if they were updated from loaded config
    num_tasks = model_architecture_params['num_tasks']
    task_names = model_architecture_params['task_names']

    optimizer_g, optimizer_psi = _initialize_optimizers_for_run(
        args, g_models_list, psi_model, attention_subnetwork, num_tasks
    )

    # --- Training Summaries Initialization ---
    training_summary_stage1 = None
    training_summary_stage2 = None

    # --- Stage 1 Training: g^(t) Models ---
    if not args.skip_stage1_training:
        if num_tasks > 0 and optimizer_g and train_loader and len(train_loader.dataset) > 0:
            g_models_list, training_summary_stage1 = train_stage1_g_models(
                g_models_list=g_models_list,
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer_g=optimizer_g,
                num_epochs=args.num_epochs_g,
                device=device,
                task_names=task_names,
                early_stopping_patience=args.early_stopping_patience_g,
                experiment_output_dir=experiment_output_dir
            )
        elif num_tasks == 0:
            logger.info("Skipping Stage 1 training: No tasks defined (num_tasks is 0).")
            training_summary_stage1 = {"status": "skipped", "reason": "No tasks"}
        elif not optimizer_g:
            logger.info("Skipping Stage 1 training: Optimizer for g_models not initialized (possibly no g_models or no parameters).")
            training_summary_stage1 = {"status": "skipped", "reason": "No optimizer_g"}
        else:
            logger.info("Skipping Stage 1 training: Training data is empty.")
            training_summary_stage1 = {"status": "skipped", "reason": "Empty train_loader"}
    else:
        logger.info("Stage 1 training explicitly skipped by user (--skip_stage1_training).")
        training_summary_stage1 = {"status": "skipped_by_user"}

    # --- Stage 2 Training: Ψ Model ---
    if not args.skip_stage2_training:
        if num_tasks > 0 and psi_model and attention_subnetwork and optimizer_psi and train_loader and len(train_loader.dataset) > 0:
            psi_model, attention_subnetwork, training_summary_stage2 = train_stage2_psi_model(
                psi_model=psi_model,
                attention_model=attention_subnetwork,
                g_models_list_trained=g_models_list, # Pass the (potentially trained) g_models
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer_psi=optimizer_psi,
                num_epochs=args.num_epochs_psi,
                device=device,
                temperature_tau=args.temperature_tau,
                num_tasks=num_tasks,
                task_names=task_names,
                early_stopping_patience=args.early_stopping_patience_psi,
                experiment_output_dir=experiment_output_dir
            )
        elif num_tasks == 0:
            logger.info("Skipping Stage 2 training: No tasks defined.")
            training_summary_stage2 = {"status": "skipped", "reason": "No tasks"}
        elif not (psi_model and attention_subnetwork and optimizer_psi):
            logger.info("Skipping Stage 2 training: Ψ model, Attention, or Optimizer not properly initialized.")
            training_summary_stage2 = {"status": "skipped", "reason": "Model/Optimizer not ready"}
        else:
            logger.info("Skipping Stage 2 training: Training data is empty.")
            training_summary_stage2 = {"status": "skipped", "reason": "Empty train_loader"}
    else:
        logger.info("Stage 2 training explicitly skipped by user (--skip_stage2_training).")
        training_summary_stage2 = {"status": "skipped_by_user"}

    # --- Evaluation ---
    test_evaluation_results = None
    if not args.skip_evaluation:
        if test_loader and len(test_loader.dataset) > 0:
            # Use a basic MSE loss for g_model evaluation on test set
            g_eval_criterion = nn.MSELoss(reduction='mean')
            test_evaluation_results = evaluate_models_on_test_set(
                g_models_list_eval=g_models_list,
                psi_model_eval=psi_model,
                attention_model_eval=attention_subnetwork,
                test_loader=test_loader,
                device=device,
                num_tasks=num_tasks,
                task_names=task_names,
                temperature_tau=args.temperature_tau,
                criterion_g_loss_fn=g_eval_criterion
            )
        else:
            logger.warning("Skipping evaluation: Test data is empty or not available.")
            test_evaluation_results = {"status": "skipped", "reason": "No test data"}
    else:
        logger.info("Evaluation explicitly skipped by user (--skip_evaluation).")
        test_evaluation_results = {"status": "skipped_by_user"}

    # --- Save Models and Final Summary ---
    # Consolidate all summaries and params for saving
    final_run_summaries = {
        'stage1_g_training': training_summary_stage1,
        'stage2_psi_training': training_summary_stage2,
        'test_set_evaluation': test_evaluation_results,
        'data_split_info_file': str(experiment_output_dir / f"{args.dataset_name}_split_indices.json") if data_split_indices else None
    }

    save_trained_models_and_config(
        g_models_to_save=g_models_list if num_tasks > 0 else None,
        psi_model_to_save=psi_model,
        attention_model_to_save=attention_subnetwork if num_tasks > 0 else None,
        model_config_params=model_architecture_params, # This now includes num_tasks, task_names etc.
        training_summaries=final_run_summaries,
        run_params_to_log=vars(args), # Save all initial run args
        model_output_dir=str(experiment_output_dir), # Pass as string, save_trained_models_and_config will convert
        dataset_name_str=args.dataset_name
    )

    # Save a combined summary of the entire run at the top level of experiment_output_dir
    # experiment_output_dir is already a Path object from _setup_run_environment
    overall_summary_path = experiment_output_dir / "overall_run_summary.json"
    try:
        with open(overall_summary_path, 'w') as f_summary_json:
            json.dump({
                "run_parameters": vars(args),
                "model_architecture": model_architecture_params,
                "results_and_summaries": final_run_summaries
            }, f_summary_json, indent=4, cls=NpEncoder)
        logger.info(f"Overall run summary saved to: {overall_summary_path}")
    except (OSError, TypeError, ValueError) as e:
        # Writing the summary is best-effort: besides I/O errors, json.dump raises
        # TypeError/ValueError for values it cannot serialize. The models are already
        # saved at this point, so report the problem instead of crashing the run.
        logger.error(f"Failed to save overall_run_summary.json: {e}", exc_info=True)

    logger.info(f"Unified Distillation Training Script finished for dataset '{args.dataset_name}'.")
    logger.info(f"All outputs, logs, and models are in: {experiment_output_dir}")


# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()