import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, silhouette_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import logging
import os
import time
from tqdm import tqdm
import xgboost as xgb
import sys
from sklearn.model_selection import StratifiedKFold
import argparse

from contrastive_mapping_model_final import EnhancedMLPWithAttention

# Make relative paths used later in the script (default .npy data files, model
# checkpoints, the log file) resolve against this script's directory instead of
# the caller's working directory.
# NOTE(review): this runs at import time, so merely importing this module changes
# the process-wide working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
#print(f"Current working directory: {os.getcwd()}")

# Root logger as created at import time; setup_logging() attaches its file/stdout
# handlers to this object. The __main__ block later rebinds the name `logger` to
# the logger that setup_logging() returns.
logger = logging.getLogger() 
# Formatter shared by every handler that setup_logging() installs.
log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

def setup_logging(log_filename, log_level_str):
    """Attach a file handler and a stdout handler to the module-level ``logger``.

    Handlers already attached to ``logger`` are detached and closed first, so
    calling this more than once neither leaks open log files nor duplicates
    output.

    Args:
        log_filename: Path of the log file; opened in append mode.
        log_level_str: Level name such as ``"INFO"`` (case-insensitive). Names
            that are not attributes of the ``logging`` module fall back to INFO.

    Returns:
        ``logging.getLogger(__name__)`` -- the logger for this module, obtained
        after the handlers are installed.
    """
    log_level = getattr(logging, log_level_str.upper(), logging.INFO)
    logger.setLevel(log_level)

    # handlers.clear() alone would drop a previous FileHandler without closing
    # its file descriptor; remove and close each handler explicitly.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    for handler in (logging.FileHandler(log_filename, mode='a'),
                    logging.StreamHandler(sys.stdout)):
        handler.setLevel(log_level)
        handler.setFormatter(log_format)
        logger.addHandler(handler)

    # Get a logger for the current module after setup
    return logging.getLogger(__name__)


# Base filenames of the pre-split NumPy arrays. The __main__ block builds the
# actual load paths by prepending the --data_prefix command-line value to each.
X_TRAIN_FILENAME = "X_train.npy"
Y_TRAIN_FILENAME = "y_train.npy"
X_VAL_FILENAME = "X_val.npy"
Y_VAL_FILENAME = "y_val.npy"
X_TEST_FILENAME = "X_test.npy"
Y_TEST_FILENAME = "y_test.npy"


class LightweightTransformerClassifier(nn.Module):
    """Transformer-encoder classifier over feature vectors.

    A 2-D input ``(batch, input_dim)`` is treated as a length-1 sequence
    (``batch_first=True``). It is encoded, squeezed back to
    ``(batch, input_dim)``, layer-normalised, passed through dropout, and
    projected to ``num_classes`` logits.

    Args:
        input_dim: Feature width, used as the encoder's ``d_model``.
        num_classes: Number of output logits.
        hidden_dim: Feed-forward width inside each encoder layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Attention heads per layer (must divide ``input_dim``).
        dropout: Dropout probability inside the encoder and before the head.
    """

    def __init__(self, input_dim, num_classes, hidden_dim, num_layers, num_heads, dropout):
        super().__init__()
        encoder_block = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        # These attribute names become state_dict key prefixes; keep them stable
        # so previously saved checkpoints stay loadable.
        self.transformer = nn.TransformerEncoder(encoder_block, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        as_sequence = x.unsqueeze(1) if x.ndim == 2 else x
        encoded = self.transformer(as_sequence).squeeze(1)
        return self.fc(self.dropout(self.layer_norm(encoded)))


def train_eval_transformer(X_train, y_train, X_val, y_val, X_test, y_test, 
                           class_weights_tensor, logger, input_dim, num_classes,
                           classifier_config, device, title_suffix="", 
                           evaluate_only=False, model_save_path=None):
    """Train (or load) a LightweightTransformerClassifier and score it on the test split.

    Training mode (``evaluate_only=False``):
        - Optimises with AdamW and cross-entropy, class-weighted when weights are given.
        - Early-stops on validation loss after ``patience`` epochs without improvement.
        - Writes every improvement to ``model_save_path``.
        - Restores the best weights before testing.

    Evaluate-only mode (``evaluate_only=True``):
        - Loads the weights from ``model_save_path`` instead of training.

    Args:
        X_train, y_train, X_val, y_val, X_test, y_test: Features and integer
            labels for each split (anything ``torch.tensor`` accepts).
        class_weights_tensor: Per-class CrossEntropyLoss weights, or None.
        logger: Logger used for progress and result messages.
        input_dim: Feature width; also the transformer ``d_model``, so it must
            be divisible by ``classifier_config['num_heads']``.
        num_classes: Number of output classes. 2 selects binary AUC on the
            positive-class probability; otherwise one-vs-rest AUC is used.
        classifier_config: Dict with keys 'hidden_dim', 'num_layers',
            'num_heads', 'dropout', 'learning_rate', 'batch_size',
            'max_epochs', 'patience'.
        device: torch device (or device string) used for training/inference.
        title_suffix: Label appended to log messages and to the default save path.
        evaluate_only: Load a saved model instead of training.
        model_save_path: Checkpoint path; derived from ``title_suffix`` if falsy.

    Returns:
        7-tuple ``(auc, f1_class0, f1_weighted, precision_weighted,
        recall_weighted, test_preds, test_probs)``.
        - All seven entries are None when no usable model could be trained or loaded.
        - ``(0, 0, 0, 0, 0, [], [])`` is returned when the test split yields no labels.
    """
    failed = (None,) * 7  # returned whenever no usable model is available

    logger.info(f"Transformer{title_suffix}: Training on {len(X_train)}, Validating on {len(X_val)} (Input Dim: {input_dim})")

    batch_size = classifier_config['batch_size']
    # Pinned host memory only benefits CUDA transfers; on CPU PyTorch just warns.
    pin_memory = str(device).startswith("cuda")

    def _loader(features, targets, shuffle):
        """Wrap one split as a float32-feature / int64-label DataLoader."""
        dataset = TensorDataset(torch.tensor(features, dtype=torch.float32),
                                torch.tensor(targets, dtype=torch.long))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=2, pin_memory=pin_memory)

    train_loader = _loader(X_train, y_train, shuffle=True)
    val_loader = _loader(X_val, y_val, shuffle=False)
    test_loader = _loader(X_test, y_test, shuffle=False)

    model = LightweightTransformerClassifier(
        input_dim=input_dim,
        num_classes=num_classes,
        hidden_dim=classifier_config['hidden_dim'],
        num_layers=classifier_config['num_layers'],
        num_heads=classifier_config['num_heads'],
        dropout=classifier_config['dropout'],
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=classifier_config['learning_rate'])
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor.to(device) if class_weights_tensor is not None else None)

    if not model_save_path:
        model_save_path = f"transformer_model{title_suffix.replace(' ', '_').replace('/', '_')}.pth"
        logger.info(f"Model save path not provided, using default: {model_save_path}")

    if evaluate_only:
        if not os.path.exists(model_save_path):
            logger.error(f"Model file not found at '{model_save_path}'. Cannot proceed with evaluation.")
            return failed
        try:
            load_result = model.load_state_dict(torch.load(model_save_path, map_location=device), strict=False)
        except Exception as e:
            logger.error(f"Error loading model state dict for Transformer{title_suffix}: {e}")
            return failed
        # strict=False tolerates key mismatches; surface them rather than silently
        # evaluating a partially initialised model.
        if load_result.missing_keys or load_result.unexpected_keys:
            logger.warning(
                f"Transformer{title_suffix} checkpoint does not match the model exactly "
                f"(missing keys: {load_result.missing_keys}, unexpected keys: {load_result.unexpected_keys})."
            )
        logger.info(f"Loaded pre-trained Transformer{title_suffix} model from {model_save_path} for evaluation.")
    else:
        best_val_loss = float('inf')
        best_model_state = None
        patience_counter = 0

        logger.info(f"Starting Transformer{title_suffix} training...")
        training_start_time = time.time()

        for epoch in range(classifier_config['max_epochs']):
            model.train()
            epoch_train_loss = 0.0
            for embeddings, labels in train_loader:
                embeddings, labels = embeddings.to(device), labels.to(device)
                optimizer.zero_grad()
                loss = criterion(model(embeddings), labels)
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item()
            epoch_train_loss /= max(len(train_loader), 1)  # guard against an empty split

            model.eval()
            epoch_val_loss = 0.0
            with torch.no_grad():
                for embeddings, labels in val_loader:
                    embeddings, labels = embeddings.to(device), labels.to(device)
                    epoch_val_loss += criterion(model(embeddings), labels).item()
            epoch_val_loss /= max(len(val_loader), 1)

            logger.info(f"Epoch {epoch+1}/{classifier_config['max_epochs']}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                # state_dict() returns references to the live parameter tensors. A
                # shallow dict copy would keep tracking later optimizer steps, so the
                # "best" weights restored below would really be the final ones.
                # Clone each tensor to take a true snapshot.
                best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
                patience_counter = 0
                torch.save(best_model_state, model_save_path)
                logger.info(f"Validation loss improved to {best_val_loss:.4f}, saved Transformer{title_suffix} model to {model_save_path}")
            else:
                patience_counter += 1
                if patience_counter >= classifier_config['patience']:
                    logger.info(f"Early stopping triggered for Transformer{title_suffix} at epoch {epoch+1}.")
                    break

        training_end_time = time.time()
        logger.info(f"Transformer{title_suffix} training finished. Total time: {training_end_time - training_start_time:.2f} seconds.")

        if best_model_state is None:
            # Reached when no epoch ran or the validation loss never dropped below inf (e.g. NaN).
            logger.error(f"Training did not improve validation loss for Transformer{title_suffix}. Cannot evaluate with this model.")
            if not os.path.exists(model_save_path):
                return failed
            # A model from a previous run exists; try evaluating that instead.
            logger.warning(f"Attempting to load existing model from {model_save_path} due to failed training.")
            try:
                model.load_state_dict(torch.load(model_save_path, map_location=device))
                logger.info(f"Successfully loaded existing model {model_save_path} for evaluation.")
            except Exception as e:
                logger.error(f"Failed to load existing model {model_save_path}: {e}. Evaluation will be skipped.")
                return failed
        else:
            model.load_state_dict(best_model_state)

    model.eval()
    all_test_preds, all_test_probs, all_test_labels = [], [], []

    with torch.no_grad():
        for embeddings, labels in test_loader:
            logits = model(embeddings.to(device))
            probs = torch.softmax(logits, dim=1)
            all_test_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            # Binary AUC scores the positive-class probability; multiclass OvR needs the full matrix.
            if num_classes == 2:
                all_test_probs.extend(probs[:, 1].cpu().numpy())
            else:
                all_test_probs.extend(probs.cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())

    if not all_test_labels: # Should not happen if test_loader is not empty
        logger.warning(f"Transformer{title_suffix}: No test labels found, skipping metrics calculation.")
        return 0,0,0,0,0, [], []

    if num_classes == 2:
        test_auc = roc_auc_score(all_test_labels, all_test_probs)
    else:
        test_auc = roc_auc_score(all_test_labels, all_test_probs, multi_class='ovr')
    f1_scores_all_classes = f1_score(all_test_labels, all_test_preds, average=None, zero_division=0)
    test_f1_minority = f1_scores_all_classes[0] if len(f1_scores_all_classes) > 0 else 0.0  # F1 of class 0
    test_f1_overall = f1_score(all_test_labels, all_test_preds, average='weighted', zero_division=0)
    test_precision = precision_score(all_test_labels, all_test_preds, average='weighted', zero_division=0)
    test_recall = recall_score(all_test_labels, all_test_preds, average='weighted', zero_division=0)

    logger.info(f"\n--- Transformer{title_suffix} Test Set Performance ---")
    logger.info(f"AUC: {test_auc:.4f}")
    logger.info(f"Overall F1 Score: {test_f1_overall:.4f}")
    logger.info(f"Precision: {test_precision:.4f}")
    logger.info(f"Recall: {test_recall:.4f}")
    logger.info(f"F1 Scores per class: {f1_scores_all_classes}")

    return test_auc, test_f1_minority, test_f1_overall, test_precision, test_recall, all_test_preds, all_test_probs


def transform_embeddings(contrastive_model, raw_embeddings_np, batch_size=1024, device="cpu", logger=None):
    """Run raw embeddings through a contrastive model and collect its enhanced features.

    The model is switched to eval mode and applied batch by batch under
    ``torch.no_grad()``. The ``"enhanced_features"`` entry of each output dict
    is gathered, and the batches are stacked along axis 0.

    Args:
        contrastive_model: Module whose forward returns a dict containing
            ``"enhanced_features"``.
        raw_embeddings_np: Array-like of input embeddings (converted to float32).
        batch_size: Rows per forward pass.
        device: Device each batch is moved to before the forward pass.
        logger: Optional logger; when None, progress messages and the tqdm bar
            are suppressed.

    Returns:
        numpy array with the concatenated enhanced features, in input order.
    """
    if logger:
        logger.info(f"Transforming {len(raw_embeddings_np)} embeddings using the contrastive model...")
    contrastive_model.eval()

    inputs = torch.tensor(raw_embeddings_np, dtype=torch.float32)
    batches = DataLoader(TensorDataset(inputs), batch_size=batch_size, shuffle=False)
    progress = tqdm(batches, desc="Transforming Embeddings", leave=False, disable=logger is None)

    with torch.no_grad():
        chunks = [
            contrastive_model(batch[0].to(device))["enhanced_features"].cpu().numpy()
            for batch in progress
        ]

    stacked = np.concatenate(chunks, axis=0)
    if logger:
        logger.info(f"Transformation complete. Optimized shape: {stacked.shape}")
    return stacked


def compute_regression_metrics(y_true, y_pred):
    """Score ordinal predictions as a regression: MSE, RMSE, R^2 and per-class MSE.

    Both inputs are converted to float64 before differencing. This stops narrow
    integer label dtypes (e.g. int8 arrays from ``np.load``) from overflowing
    when errors are squared, and lets plain Python sequences be passed.

    Args:
        y_true: 1-D array-like of true labels.
        y_pred: Array-like of predictions with the same shape as ``y_true``.

    Returns:
        Tuple ``(mse, rmse, r2, mse_per_class)``:
        - ``r2`` is 0.0 when ``y_true`` is constant.
        - ``mse_per_class`` maps each label present in ``y_true`` (as int) to
          the MSE over that label's samples.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)

    squared_errors = (y_true - y_pred) ** 2
    mse = np.mean(squared_errors)
    rmse = np.sqrt(mse)
    ss_res = np.sum(squared_errors)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

    # Classes are taken from y_true itself, so every mask below is non-empty.
    mse_per_class = {
        int(cls_val): np.mean(squared_errors[y_true == cls_val])
        for cls_val in np.unique(y_true)
    }
    return mse, rmse, r2, mse_per_class


def calculate_variance_metrics(embeddings, labels):
    """Summarise class separation of ``embeddings`` via within/between-class variance.

    Args:
        embeddings: 2-D array-like of shape (n_samples, n_features).
        labels: Length-n_samples array-like of class labels.

    Returns:
        dict with these keys:
        - ``'within_class_var'``: pooled within-class sum of squares / n_samples.
        - ``'between_class_var'``: sum over classes of
          class_size * ||class_mean - global_mean||^2, divided by n_samples.
        - ``'ratio'``: within / between (``inf`` when between is 0).
        - ``'class_specific_var'``: ``{int(label): class sum of squares / class size}``.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    n_classes = len(classes)

    class_means = np.zeros((n_classes, embeddings.shape[1]))
    class_counts = np.zeros(n_classes)
    class_sum_sq = np.zeros(n_classes)  # per-class sum of squared deviations, computed once

    for i, c in enumerate(classes):
        class_embeddings = embeddings[labels == c]
        class_counts[i] = len(class_embeddings)
        if class_counts[i] == 0:
            # Defensive only (e.g. NaN labels never compare equal): the mean stays
            # zero and the class contributes nothing to the sums.
            continue
        class_means[i] = np.mean(class_embeddings, axis=0)
        class_sum_sq[i] = np.sum((class_embeddings - class_means[i]) ** 2)

    total_samples = np.sum(class_counts)
    avg_within_class_var = np.sum(class_sum_sq) / total_samples if total_samples > 0 else 0

    global_mean = np.mean(embeddings, axis=0)
    between_class_ss = np.sum(class_counts * np.sum((class_means - global_mean) ** 2, axis=1))
    between_class_var = between_class_ss / total_samples if total_samples > 0 else 0

    ratio = avg_within_class_var / between_class_var if between_class_var > 0 else float('inf')

    # Variance is sum of squares / N, per class.
    class_specific_var_values = {
        int(c): (class_sum_sq[i] / class_counts[i] if class_counts[i] > 0 else 0.0)
        for i, c in enumerate(classes)
    }

    return {
        'within_class_var': avg_within_class_var,
        'between_class_var': between_class_var,
        'ratio': ratio,
        'class_specific_var': class_specific_var_values,
    }

def calculate_class_overlap(embeddings, labels, n_neighbors=10):
    """Mean fraction of each sample's nearest neighbours that carry a different label.

    Args:
        embeddings: 2-D array-like of shape (n_samples, n_features).
        labels: Length-n_samples array-like of labels.
        n_neighbors: Neighbours examined per sample, not counting the sample
            itself. The default of 10 matches the previous hard-coded behaviour.

    Returns:
        Overlap in [0, 1], where 0 means every neighbour shares the sample's
        label. Returns 0.0 when there are at most ``n_neighbors`` samples.
    """
    if len(embeddings) < n_neighbors + 1:
        return 0.0  # Not enough samples for meaningful overlap

    from sklearn.neighbors import NearestNeighbors

    labels = np.asarray(labels)
    nn_model = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
    # kneighbors() without X queries the fitted points and excludes each point
    # from its own neighbour list by index. The old approach (query with X, then
    # drop column 0) assumed the point itself is always returned first. That
    # fails when duplicate embeddings tie at distance 0, and self then lands
    # among the "neighbours".
    _, neighbor_idx = nn_model.kneighbors()
    # Every row has exactly n_neighbors entries, so the global mean equals the
    # mean of the per-sample overlap fractions.
    return np.mean(labels[neighbor_idx] != labels[:, None])

def calculate_fisher_ratio(embeddings, labels):
    """Average pairwise, per-dimension Fisher discriminant ratio between classes.

    For each unordered pair of classes (a, b), the ratio
    ``(mean_a - mean_b)**2 / (var_a + var_b)`` is computed per dimension, using
    only the dimensions whose summed variance exceeds 1e-9. The per-dimension
    ratios are averaged for that pair, and the pair averages are averaged again.

    Args:
        embeddings: 2-D array, indexed by a boolean mask over its rows.
        labels: Array of class labels, compared element-wise with ``==``.

    Returns:
        The mean per-pair Fisher ratio. Returns 0.0 when there are fewer than
        two classes or no pair has a dimension with non-negligible variance.
    """
    classes = np.unique(labels)
    if len(classes) < 2:
        return 0.0

    means = {}
    variances = {}
    for label in classes:
        members = embeddings[labels == label]
        if len(members) == 0:
            # Defensive: np.unique should give every class at least one member.
            means[label] = np.zeros(embeddings.shape[1])
            variances[label] = np.zeros(embeddings.shape[1])
            continue
        means[label] = np.mean(members, axis=0)
        variances[label] = np.var(members, axis=0)  # population variance per dimension

    pair_scores = []
    for pos, first in enumerate(classes):
        for second in classes[pos + 1:]:
            if first not in means or second not in means:
                continue
            spread = variances[first] + variances[second]
            usable = spread > 1e-9  # skip (near-)constant dimensions to avoid division by ~0
            if not np.any(usable):
                continue
            separation = (means[first] - means[second]) ** 2
            pair_scores.append(np.mean(separation[usable] / spread[usable]))

    return np.mean(pair_scores) if pair_scores else 0.0


def parse_arguments():
    """Define and parse the command-line options for this experiment script.

    Options are organised into four argument groups:
    - data/model paths and logging,
    - contrastive-model architecture,
    - lightweight-transformer classifier hyper-parameters,
    - evaluation flags.

    Returns:
        argparse.Namespace parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Ordinal Classification Experiment Script")

    paths = cli.add_argument_group('Data and Model Paths')
    paths.add_argument('--data_prefix', type=str, default="",
                       help="Prefix for data NPY file paths.")
    paths.add_argument('--contrastive_model_path', type=str, default="contrastive_model.pth",
                       help="Path to the pre-trained contrastive model.")
    paths.add_argument('--log_filename', type=str, default='experiment_log.log',
                       help="Name of the log file.")
    paths.add_argument('--log_level', type=str, default='INFO',
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                       help="Logging level.")

    # Used to rebuild the contrastive model before its checkpoint is loaded.
    arch = cli.add_argument_group('Contrastive Model Architecture (for loading)')
    arch.add_argument('--contrastive_input_dim', type=int, default=1024,
                      help="Input dimension for the contrastive model.")
    arch.add_argument('--contrastive_hidden_dims', type=int, nargs='+', default=[512, 256],
                      help="Hidden dimensions for the contrastive model's MLP.")
    arch.add_argument('--contrastive_output_dim', type=int, default=64,
                      help="Output dimension of the contrastive model.")
    arch.add_argument('--num_classes', type=int, default=5,
                      help="Number of classes in the dataset.")
    arch.add_argument('--contrastive_alpha_mix', type=float, default=0.3,
                      help="Alpha mix parameter for the contrastive model (if its architecture expects it).")

    # Classifier hyper-parameters as (flag, type, default); these carry no help text.
    clf = cli.add_argument_group('Lightweight Transformer Classifier Parameters')
    for flag, value_type, fallback in (
        ('--classifier_hidden_dim', int, 128),
        ('--classifier_num_layers', int, 2),
        ('--classifier_num_heads', int, 2),
        ('--classifier_dropout', float, 0.2),
        ('--classifier_learning_rate', float, 1e-5),
        ('--classifier_batch_size', int, 512),
        ('--classifier_max_epochs', int, 200),
        ('--classifier_patience', int, 10),
    ):
        clf.add_argument(flag, type=value_type, default=fallback)

    flags = cli.add_argument_group('Evaluation Flags')
    flags.add_argument('--enable_magnitude_loss_metrics', action='store_true',
                       help="Enable calculation and reporting of ordinal metrics (MSE, RMSE, R^2) assuming magnitude loss was relevant.")
    flags.add_argument('--eval_only_transformer_raw', action='store_true',
                       help="Load pre-trained Transformer (RAW) model instead of training.")
    flags.add_argument('--transformer_raw_model_save_path', type=str, default="transformer_model_RAW.pth",
                       help="Path for saving/loading Transformer (RAW) model.")
    flags.add_argument('--eval_only_transformer_opt', action='store_true',
                       help="Load pre-trained Transformer (Optimized) model instead of training.")
    flags.add_argument('--transformer_opt_model_save_path', type=str, default="transformer_model_Optimized.pth",
                       help="Path for saving/loading Transformer (Optimized) model.")
    flags.add_argument('--n_splits_kfold', type=int, default=10,
                       help="Number of splits for K-fold cross-validation.")

    return cli.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    logger = setup_logging(args.log_filename, args.log_level)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    logger.info(f"Using data prefix: {args.data_prefix}")
    logger.info(f"Using mapping model: {args.contrastive_model_path}")
    logger.info(f"Number of classes: {args.num_classes}")
    if args.enable_magnitude_loss_metrics:
        logger.info("Ordinal regression metrics (MSE, RMSE, R^2) will be calculated and reported.")

    X_TRAIN_PATH = f"{args.data_prefix}{X_TRAIN_FILENAME}"
    Y_TRAIN_PATH = f"{args.data_prefix}{Y_TRAIN_FILENAME}"
    X_VAL_PATH = f"{args.data_prefix}{X_VAL_FILENAME}"
    Y_VAL_PATH = f"{args.data_prefix}{Y_VAL_FILENAME}"
    X_TEST_PATH = f"{args.data_prefix}{X_TEST_FILENAME}"
    Y_TEST_PATH = f"{args.data_prefix}{Y_TEST_FILENAME}"
    
    classifier_config = {
        'hidden_dim': args.classifier_hidden_dim,
        'num_layers': args.classifier_num_layers,
        'num_heads': args.classifier_num_heads,
        'dropout': args.classifier_dropout,
        'learning_rate': args.classifier_learning_rate,
        'batch_size': args.classifier_batch_size,
        'max_epochs': args.classifier_max_epochs,
        'patience': args.classifier_patience,
    }

    logger.info("Loading RAW pre-processed data...")
    try:
        X_train_raw = np.load(X_TRAIN_PATH)
        y_train = np.load(Y_TRAIN_PATH)
        X_val_raw = np.load(X_VAL_PATH)
        y_val = np.load(Y_VAL_PATH)
        X_test_raw = np.load(X_TEST_PATH)
        y_test = np.load(Y_TEST_PATH)
        logger.info("RAW Data loaded successfully.")
        logger.info(f"Raw Train shapes: X={X_train_raw.shape}, y={y_train.shape}")
        logger.info(f"Raw Val shapes:   X={X_val_raw.shape}, y={y_val.shape}")
        logger.info(f"Raw Test shapes:  X={X_test_raw.shape}, y={y_test.shape}")
    except FileNotFoundError as e:
        logger.error(f"Error loading data files: {e}. Please ensure .npy files exist at specified paths based on prefix '{args.data_prefix}'.")
        sys.exit(1)

    logger.info("Calculating class weights...")
    class_weights_tensor = None
    class_weights_dict = None
    try:
        unique_classes = np.unique(y_train)
        logger.info(f"Unique classes found in y_train: {unique_classes}")
        if not np.array_equal(unique_classes, np.arange(len(unique_classes))):
             logger.warning("Class labels are not 0-indexed contiguous. This might affect class weight indexing if not handled carefully.")
        
        weights = compute_class_weight(class_weight='balanced', classes=unique_classes, y=y_train)
        
        temp_class_weights_dict = dict(zip(unique_classes, weights))
        
        # Ensure class_weights_tensor has weights for all classes from 0 to num_classes-1
        class_weights_list_for_tensor = [temp_class_weights_dict.get(i, 1.0) for i in range(args.num_classes)] # Default to 1.0 if a class is missing in y_train
        class_weights_tensor = torch.tensor(class_weights_list_for_tensor, dtype=torch.float32)

        # For sklearn models that might take dicts keyed by actual class labels
        class_weights_dict = 'balanced' # Most sklearn classifiers accept 'balanced' directly
                                    # or use the temp_class_weights_dict if specific weights per actual label are needed
        
        logger.info(f"Calculated Class Weights for PyTorch (for {args.num_classes} classes): {class_weights_tensor.numpy()}")
        logger.info(f"Using 'balanced' for scikit-learn classifiers or specific weights: {temp_class_weights_dict if class_weights_dict != 'balanced' else 'balanced'}")

    except Exception as e:
        logger.error(f"Error calculating class weights: {e}")
        logger.warning(f"Using default uniform weights for {args.num_classes} classes.")
        class_weights_tensor = torch.ones(args.num_classes, dtype=torch.float32)
        class_weights_dict = {i: 1.0 for i in range(args.num_classes)} # Fallback for sklearn

    
    logger.info("\n" + "="*10 + " Experiment 1: SVM Baseline (RAW Embeddings) " + "="*10)
    svm_raw_start_time = time.time()
    svm_raw = SVC(kernel="linear", class_weight=class_weights_dict, random_state=42, probability=True)
    logger.info("Training SVM (RAW)...")
    svm_raw.fit(X_train_raw, y_train)
    svm_raw_train_end_time = time.time()
    logger.info(f"SVM (RAW) training finished. Time: {svm_raw_train_end_time - svm_raw_start_time:.2f} seconds.")
    logger.info("Evaluating SVM (RAW) on the test set...")
    y_pred_svm_raw = svm_raw.predict(X_test_raw)
    y_prob_svm_raw = svm_raw.predict_proba(X_test_raw)
    if args.num_classes == 2:
        y_prob_svm_raw = y_prob_svm_raw[:, 1]
        svm_raw_test_auc = roc_auc_score(y_test, y_prob_svm_raw)
    else:
        svm_raw_test_auc = roc_auc_score(y_test, y_prob_svm_raw, multi_class='ovr')
    svm_raw_f1_scores = f1_score(y_test, y_pred_svm_raw, average=None, zero_division=0)
    svm_raw_test_f1_minority = svm_raw_f1_scores[0] if len(svm_raw_f1_scores) > 0 else 0.0
    svm_raw_test_f1_overall = f1_score(y_test, y_pred_svm_raw, average='weighted', zero_division=0)
    svm_raw_test_precision = precision_score(y_test, y_pred_svm_raw, average='weighted', zero_division=0)
    svm_raw_test_recall = recall_score(y_test, y_pred_svm_raw, average='weighted', zero_division=0)
    logger.info(f"--- SVM (RAW) Test Set Performance ---")
    logger.info(f"AUC: {svm_raw_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {svm_raw_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {svm_raw_test_f1_overall:.4f}")
    logger.info(f"Precision: {svm_raw_test_precision:.4f}")
    logger.info(f"Recall: {svm_raw_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {svm_raw_f1_scores}")


    logger.info("\n" + "="*10 + " Experiment 2: Random Forest Baseline (RAW Embeddings) " + "="*10)
    rf_raw_start_time = time.time()
    rf_raw = RandomForestClassifier(n_estimators=100, class_weight=class_weights_dict, random_state=42, n_jobs=-1)
    logger.info("Training Random Forest (RAW)...")
    rf_raw.fit(X_train_raw, y_train)
    rf_raw_train_end_time = time.time()
    logger.info(f"RF (RAW) training finished. Time: {rf_raw_train_end_time - rf_raw_start_time:.2f} seconds.")
    logger.info("Evaluating Random Forest (RAW) on the test set...")
    y_pred_rf_raw = rf_raw.predict(X_test_raw)
    y_prob_rf_raw = rf_raw.predict_proba(X_test_raw)
    if args.num_classes == 2:
        y_prob_rf_raw = y_prob_rf_raw[:, 1]
        rf_raw_test_auc = roc_auc_score(y_test, y_prob_rf_raw)
    else:
        rf_raw_test_auc = roc_auc_score(y_test, y_prob_rf_raw, multi_class='ovr')
    rf_raw_f1_scores = f1_score(y_test, y_pred_rf_raw, average=None, zero_division=0)
    rf_raw_test_f1_minority = rf_raw_f1_scores[0] if len(rf_raw_f1_scores) > 0 else 0.0
    rf_raw_test_f1_overall = f1_score(y_test, y_pred_rf_raw, average='weighted', zero_division=0)
    rf_raw_test_precision = precision_score(y_test, y_pred_rf_raw, average='weighted', zero_division=0)
    rf_raw_test_recall = recall_score(y_test, y_pred_rf_raw, average='weighted', zero_division=0)
    logger.info(f"--- RF (RAW) Test Set Performance ---")
    logger.info(f"AUC: {rf_raw_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {rf_raw_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {rf_raw_test_f1_overall:.4f}")
    logger.info(f"Precision: {rf_raw_test_precision:.4f}")
    logger.info(f"Recall: {rf_raw_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {rf_raw_f1_scores}")

    logger.info("\n" + "="*10 + " Experiment 3: Logistic Regression Baseline (RAW Embeddings) " + "="*10)
    lr_raw_start_time = time.time()
    log_reg_raw = LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42, n_jobs=1)
    logger.info("Training Logistic Regression (RAW)...")
    log_reg_raw.fit(X_train_raw, y_train)
    lr_raw_train_end_time = time.time()
    logger.info(f"LR (RAW) training finished. Time: {lr_raw_train_end_time - lr_raw_start_time:.2f} seconds.")
    logger.info("Evaluating Logistic Regression (RAW) on the test set...")
    y_pred_lr_raw = log_reg_raw.predict(X_test_raw)
    y_prob_lr_raw = log_reg_raw.predict_proba(X_test_raw)
    if args.num_classes == 2:
        y_prob_lr_raw = y_prob_lr_raw[:, 1]
        lr_raw_test_auc = roc_auc_score(y_test, y_prob_lr_raw)
    else:
        lr_raw_test_auc = roc_auc_score(y_test, y_prob_lr_raw, multi_class='ovr')
    lr_raw_f1_scores = f1_score(y_test, y_pred_lr_raw, average=None, zero_division=0)
    lr_raw_test_f1_minority = lr_raw_f1_scores[0] if len(lr_raw_f1_scores) > 0 else 0.0
    lr_raw_test_f1_overall = f1_score(y_test, y_pred_lr_raw, average='weighted', zero_division=0)
    lr_raw_test_precision = precision_score(y_test, y_pred_lr_raw, average='weighted', zero_division=0)
    lr_raw_test_recall = recall_score(y_test, y_pred_lr_raw, average='weighted', zero_division=0)
    logger.info(f"--- LR (RAW) Test Set Performance ---")
    logger.info(f"AUC: {lr_raw_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {lr_raw_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {lr_raw_test_f1_overall:.4f}")
    logger.info(f"Precision: {lr_raw_test_precision:.4f}")
    logger.info(f"Recall: {lr_raw_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {lr_raw_f1_scores}")

    logger.info("\n" + "="*10 + " Experiment 4: XGBoost Baseline (RAW Embeddings) " + "="*10)
    logger.info("Training XGBoost (RAW)...")
    start_time = time.time()
    xgb_model_raw = xgb.XGBClassifier(
        objective='multi:softprob', 
        eval_metric='mlogloss', 
        num_class=args.num_classes,
        max_depth=6,
        learning_rate=0.1,
        n_estimators=150,
        random_state=42,
        use_label_encoder=False 
    )
    xgb_model_raw.fit(X_train_raw, y_train, eval_set=[(X_val_raw, y_val)], early_stopping_rounds=10, verbose=False)
    logger.info(f"XGBoost (RAW) training finished. Time: {time.time() - start_time:.2f} seconds.")
    logger.info("Evaluating XGBoost (RAW) on the test set...")
    xgb_probs_raw = xgb_model_raw.predict_proba(X_test_raw)
    xgb_preds_raw = np.argmax(xgb_probs_raw, axis=1) 
    if args.num_classes == 2:
        xgb_probs_raw = xgb_probs_raw[:, 1]
        xgb_raw_test_auc = roc_auc_score(y_test, xgb_probs_raw)
    else:
        xgb_raw_test_auc = roc_auc_score(y_test, xgb_probs_raw, multi_class='ovr')
    xgb_raw_f1_scores = f1_score(y_test, xgb_preds_raw, average=None, zero_division=0)
    xgb_raw_test_f1_minority = xgb_raw_f1_scores[0] if len(xgb_raw_f1_scores) > 0 else 0.0
    xgb_raw_test_f1_overall = f1_score(y_test, xgb_preds_raw, average='weighted', zero_division=0)
    xgb_raw_test_precision = precision_score(y_test, xgb_preds_raw, average='weighted', zero_division=0)
    xgb_raw_test_recall = recall_score(y_test, xgb_preds_raw, average='weighted', zero_division=0)
    logger.info(f"--- XGBoost (RAW) Test Set Performance ---")
    logger.info(f"AUC: {xgb_raw_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {xgb_raw_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {xgb_raw_test_f1_overall:.4f}")
    logger.info(f"Precision: {xgb_raw_test_precision:.4f}")
    logger.info(f"Recall: {xgb_raw_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {xgb_raw_f1_scores}")


    banner = "=" * 10
    logger.info(f"\n{banner} Experiment 5: Lightweight Transformer Baseline (RAW Embeddings) {banner}")
    # Positional data arguments: train, validation and test splits of the RAW embeddings.
    raw_data_splits = (X_train_raw, y_train, X_val_raw, y_val, X_test_raw, y_test)
    transformer_raw_results = train_eval_transformer(
        *raw_data_splits,
        class_weights_tensor,
        logger,
        input_dim=args.contrastive_input_dim,
        num_classes=args.num_classes,
        classifier_config=classifier_config,
        device=device,
        title_suffix=" (RAW)",
        evaluate_only=args.eval_only_transformer_raw,
        model_save_path=args.transformer_raw_model_save_path,
    )
    # A None first element signals that training/evaluation did not produce results;
    # zeros are used as placeholders so the final summary can report the failure.
    if transformer_raw_results[0] is None:
        transformer_raw_auc = transformer_raw_f1_minority = transformer_raw_f1_overall = 0
        transformer_raw_precision = transformer_raw_recall = 0
    else:
        (transformer_raw_auc, transformer_raw_f1_minority, transformer_raw_f1_overall,
         transformer_raw_precision, transformer_raw_recall, _, _) = transformer_raw_results


    logger.info("\n" + "="*10 + " Phase: Generating Optimized Embeddings " + "="*10)
    logger.info(f"Loading trained contrastive model from: {args.contrastive_model_path}")
    if not os.path.exists(args.contrastive_model_path):
        logger.error(f"Contrastive model file not found at '{args.contrastive_model_path}'. Cannot proceed.")
        sys.exit(1)

    contrastive_model = EnhancedMLPWithAttention(
        input_dim=args.contrastive_input_dim,
        hidden_dims=args.contrastive_hidden_dims,
        output_dim=args.contrastive_output_dim,
        num_classes=args.num_classes,
        alpha_mix=args.contrastive_alpha_mix
    ).to(device)

    try:
        checkpoint = torch.load(args.contrastive_model_path, map_location=device)
        # Accept either a training checkpoint dict or a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        # strict=False tolerates partial checkpoints, but on its own it would also
        # silently accept a checkpoint that matches no parameter at all (e.g. a
        # 'module.'-prefixed DataParallel dump), leaving the model randomly initialised.
        load_result = contrastive_model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        logger.error(f"Error loading contrastive model state dict: {e}")
        sys.exit(1)

    model_param_names = set(contrastive_model.state_dict().keys())
    if model_param_names and set(load_result.missing_keys) >= model_param_names:
        logger.error(
            "Contrastive checkpoint matches none of the model's parameters; "
            "refusing to generate embeddings from a randomly initialised model."
        )
        sys.exit(1)
    if load_result.missing_keys:
        logger.warning(
            f"Contrastive model: {len(load_result.missing_keys)} key(s) missing from checkpoint "
            f"(left at initialisation): {load_result.missing_keys}"
        )
    if load_result.unexpected_keys:
        logger.warning(
            f"Contrastive model: {len(load_result.unexpected_keys)} unexpected checkpoint key(s) ignored: "
            f"{load_result.unexpected_keys}"
        )
    logger.info("Contrastive model loaded successfully.")
    # The model is only used for inference below; make sure dropout / normalisation
    # layers are in evaluation mode (a no-op if transform_embeddings already does this).
    contrastive_model.eval()

    X_train_opt = transform_embeddings(contrastive_model, X_train_raw, classifier_config['batch_size'], device, logger)
    X_val_opt = transform_embeddings(contrastive_model, X_val_raw, classifier_config['batch_size'], device, logger)
    X_test_opt = transform_embeddings(contrastive_model, X_test_raw, classifier_config['batch_size'], device, logger)


    logger.info("\n" + "="*10 + " Experiment 6: SVM (OPTIMIZED Embeddings) " + "="*10)
    svm_opt_start_time = time.time()
    # probability=True enables predict_proba via internal Platt scaling (extra internal
    # CV fits); scikit-learn notes these probabilities can disagree with predict().
    svm_opt = SVC(kernel="linear", class_weight=class_weights_dict, random_state=42, probability=True)
    logger.info("Training SVM (Optimized)...")
    svm_opt.fit(X_train_opt, y_train)
    svm_opt_train_end_time = time.time()
    logger.info(f"SVM (Optimized) training finished. Time: {svm_opt_train_end_time - svm_opt_start_time:.2f} seconds.")
    logger.info("Evaluating SVM (Optimized) on the test set...")
    y_pred_svm_opt = svm_opt.predict(X_test_opt)
    y_prob_svm_opt = svm_opt.predict_proba(X_test_opt)
    if args.num_classes == 2:
        # Binary ROC-AUC takes the score of the second class column (classes_[1]) only.
        y_prob_svm_opt = y_prob_svm_opt[:, 1]
        svm_opt_test_auc = roc_auc_score(y_test, y_prob_svm_opt)
    else:
        svm_opt_test_auc = roc_auc_score(y_test, y_prob_svm_opt, multi_class='ovr')
    # labels= pins the per-class array to classes 0..num_classes-1, so index 0 is always
    # class 0 (average=None otherwise reports only labels present in y_test / predictions).
    svm_opt_f1_scores = f1_score(
        y_test, y_pred_svm_opt, labels=list(range(args.num_classes)), average=None, zero_division=0
    )
    svm_opt_test_f1_minority = svm_opt_f1_scores[0] if len(svm_opt_f1_scores) > 0 else 0.0
    svm_opt_test_f1_overall = f1_score(y_test, y_pred_svm_opt, average='weighted', zero_division=0)
    svm_opt_test_precision = precision_score(y_test, y_pred_svm_opt, average='weighted', zero_division=0)
    svm_opt_test_recall = recall_score(y_test, y_pred_svm_opt, average='weighted', zero_division=0)
    logger.info(f"--- SVM (Optimized) Test Set Performance ---")
    logger.info(f"AUC: {svm_opt_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {svm_opt_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {svm_opt_test_f1_overall:.4f}")
    logger.info(f"Precision: {svm_opt_test_precision:.4f}")
    logger.info(f"Recall: {svm_opt_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {svm_opt_f1_scores}")


    logger.info("\n" + "="*10 + " Experiment 7: Random Forest (OPTIMIZED Embeddings) " + "="*10)
    rf_opt_start_time = time.time()
    # class_weights_dict is built earlier in the script (presumably class -> weight; confirm).
    rf_opt = RandomForestClassifier(n_estimators=100, class_weight=class_weights_dict, random_state=42, n_jobs=-1)
    logger.info("Training Random Forest (Optimized)...")
    rf_opt.fit(X_train_opt, y_train)
    rf_opt_train_end_time = time.time()
    logger.info(f"RF (Optimized) training finished. Time: {rf_opt_train_end_time - rf_opt_start_time:.2f} seconds.")
    logger.info("Evaluating Random Forest (Optimized) on the test set...")
    y_pred_rf_opt = rf_opt.predict(X_test_opt)
    y_prob_rf_opt = rf_opt.predict_proba(X_test_opt)
    if args.num_classes == 2:
        # Binary ROC-AUC takes the score of the second class column (classes_[1]) only.
        y_prob_rf_opt = y_prob_rf_opt[:, 1]
        rf_opt_test_auc = roc_auc_score(y_test, y_prob_rf_opt)
    else:
        rf_opt_test_auc = roc_auc_score(y_test, y_prob_rf_opt, multi_class='ovr')
    # labels= pins the per-class array to classes 0..num_classes-1, so index 0 is always
    # class 0 (average=None otherwise reports only labels present in y_test / predictions).
    rf_opt_f1_scores = f1_score(
        y_test, y_pred_rf_opt, labels=list(range(args.num_classes)), average=None, zero_division=0
    )
    rf_opt_test_f1_minority = rf_opt_f1_scores[0] if len(rf_opt_f1_scores) > 0 else 0.0
    rf_opt_test_f1_overall = f1_score(y_test, y_pred_rf_opt, average='weighted', zero_division=0)
    rf_opt_test_precision = precision_score(y_test, y_pred_rf_opt, average='weighted', zero_division=0)
    rf_opt_test_recall = recall_score(y_test, y_pred_rf_opt, average='weighted', zero_division=0)
    logger.info(f"--- RF (Optimized) Test Set Performance ---")
    logger.info(f"AUC: {rf_opt_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {rf_opt_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {rf_opt_test_f1_overall:.4f}")
    logger.info(f"Precision: {rf_opt_test_precision:.4f}")
    logger.info(f"Recall: {rf_opt_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {rf_opt_f1_scores}")

    logger.info("\n" + "="*10 + " Experiment 8: Logistic Regression (OPTIMIZED Embeddings) " + "="*10)
    lr_opt_start_time = time.time()
    # NOTE(review): the liblinear solver only supports one-vs-rest for multiclass
    # problems; newer scikit-learn releases deprecate OvR inside LogisticRegression,
    # so num_classes > 2 may warn/fail there — confirm against the installed version.
    log_reg_opt = LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42, solver='liblinear', n_jobs=1)
    logger.info("Training Logistic Regression (Optimized)...")
    log_reg_opt.fit(X_train_opt, y_train)
    lr_opt_train_end_time = time.time()
    logger.info(f"LR (Optimized) training finished. Time: {lr_opt_train_end_time - lr_opt_start_time:.2f} seconds.")
    logger.info("Evaluating Logistic Regression (Optimized) on the test set...")
    y_pred_lr_opt = log_reg_opt.predict(X_test_opt)
    y_prob_lr_opt = log_reg_opt.predict_proba(X_test_opt)
    if args.num_classes == 2:
        # Binary ROC-AUC takes the score of the second class column (classes_[1]) only.
        y_prob_lr_opt = y_prob_lr_opt[:, 1]
        lr_opt_test_auc = roc_auc_score(y_test, y_prob_lr_opt)
    else:
        lr_opt_test_auc = roc_auc_score(y_test, y_prob_lr_opt, multi_class='ovr')
    # labels= pins the per-class array to classes 0..num_classes-1, so index 0 is always
    # class 0 (average=None otherwise reports only labels present in y_test / predictions).
    lr_opt_f1_scores = f1_score(
        y_test, y_pred_lr_opt, labels=list(range(args.num_classes)), average=None, zero_division=0
    )
    lr_opt_test_f1_minority = lr_opt_f1_scores[0] if len(lr_opt_f1_scores) > 0 else 0.0
    lr_opt_test_f1_overall = f1_score(y_test, y_pred_lr_opt, average='weighted', zero_division=0)
    lr_opt_test_precision = precision_score(y_test, y_pred_lr_opt, average='weighted', zero_division=0)
    lr_opt_test_recall = recall_score(y_test, y_pred_lr_opt, average='weighted', zero_division=0)
    logger.info(f"--- LR (Optimized) Test Set Performance ---")
    logger.info(f"AUC: {lr_opt_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {lr_opt_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {lr_opt_test_f1_overall:.4f}")
    logger.info(f"Precision: {lr_opt_test_precision:.4f}")
    logger.info(f"Recall: {lr_opt_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {lr_opt_f1_scores}")

    logger.info("\n" + "="*10 + " Experiment 9: XGBoost (OPTIMIZED Embeddings) " + "="*10)
    logger.info("Training XGBoost (Optimized)...")
    xgb_opt_start_time = time.time()
    # Requires xgboost >= 1.6: early stopping is configured on the estimator
    # (passing ``early_stopping_rounds`` to ``fit()`` was deprecated in 1.6 and
    # removed in 2.0), and the obsolete ``use_label_encoder`` flag is not passed
    # (it was removed in 2.0).
    xgb_model_opt = xgb.XGBClassifier(
        objective='multi:softprob',
        eval_metric='mlogloss',
        num_class=args.num_classes,
        max_depth=6,
        learning_rate=0.1,
        n_estimators=150,
        early_stopping_rounds=10,
        random_state=42,
    )
    # The validation split only drives early stopping; the test split is not seen by fit.
    xgb_model_opt.fit(X_train_opt, y_train, eval_set=[(X_val_opt, y_val)], verbose=False)
    xgb_opt_train_end_time = time.time()
    logger.info(f"XGBoost (Optimized) training finished. Time: {xgb_opt_train_end_time - xgb_opt_start_time:.2f} seconds.")
    logger.info("Evaluating XGBoost (Optimized) on the test set...")
    # multi:softprob returns one probability column per class; argmax gives the label.
    y_prob_xgb_opt = xgb_model_opt.predict_proba(X_test_opt)
    y_pred_xgb_opt = np.argmax(y_prob_xgb_opt, axis=1)
    if args.num_classes == 2:
        # Binary ROC-AUC takes the score of the second class column only.
        y_prob_xgb_opt = y_prob_xgb_opt[:, 1]
        xgb_opt_test_auc = roc_auc_score(y_test, y_prob_xgb_opt)
    else:
        xgb_opt_test_auc = roc_auc_score(y_test, y_prob_xgb_opt, multi_class='ovr')
    # labels= pins the per-class array to classes 0..num_classes-1, so index 0 is always
    # class 0 (average=None otherwise reports only labels present in y_test / predictions).
    xgb_opt_f1_scores = f1_score(
        y_test, y_pred_xgb_opt, labels=list(range(args.num_classes)), average=None, zero_division=0
    )
    xgb_opt_test_f1_minority = xgb_opt_f1_scores[0] if len(xgb_opt_f1_scores) > 0 else 0.0
    xgb_opt_test_f1_overall = f1_score(y_test, y_pred_xgb_opt, average='weighted', zero_division=0)
    xgb_opt_test_precision = precision_score(y_test, y_pred_xgb_opt, average='weighted', zero_division=0)
    xgb_opt_test_recall = recall_score(y_test, y_pred_xgb_opt, average='weighted', zero_division=0)
    logger.info(f"--- XGBoost (Optimized) Test Set Performance ---")
    logger.info(f"AUC: {xgb_opt_test_auc:.4f}")
    logger.info(f"F1 Score (Class 0): {xgb_opt_test_f1_minority:.4f}")
    logger.info(f"Overall F1 Score: {xgb_opt_test_f1_overall:.4f}")
    logger.info(f"Precision: {xgb_opt_test_precision:.4f}")
    logger.info(f"Recall: {xgb_opt_test_recall:.4f}")
    logger.info(f"F1 Scores per class: {xgb_opt_f1_scores}")

    separator = "=" * 10
    logger.info(f"\n{separator} Experiment 10: Lightweight Transformer (OPTIMIZED Embeddings) {separator}")
    # Positional data arguments: train, validation and test splits of the optimized embeddings.
    opt_data_splits = (X_train_opt, y_train, X_val_opt, y_val, X_test_opt, y_test)
    transformer_opt_results = train_eval_transformer(
        *opt_data_splits,
        class_weights_tensor,
        logger,
        input_dim=args.contrastive_output_dim,
        num_classes=args.num_classes,
        classifier_config=classifier_config,
        device=device,
        title_suffix=" (Optimized)",
        evaluate_only=args.eval_only_transformer_opt,
        model_save_path=args.transformer_opt_model_save_path,
    )
    # A None first element signals that training/evaluation did not produce results;
    # zeros are used as placeholders so the final summary can report the failure.
    if transformer_opt_results[0] is None:
        transformer_opt_auc = transformer_opt_f1_minority = transformer_opt_f1_overall = 0
        transformer_opt_precision = transformer_opt_recall = 0
    else:
        (transformer_opt_auc, transformer_opt_f1_minority, transformer_opt_f1_overall,
         transformer_opt_precision, transformer_opt_recall, _, _) = transformer_opt_results


    logger.info("\n" + "="*20 + " FINAL EXPERIMENT SUMMARY " + "="*20)
    # Shared template so every model row reports the same five test-set metrics.
    summary_row_fmt = "AUC={:.4f}, F1 (Class 0)={:.4f}, Overall F1={:.4f}, Precision={:.4f}, Recall={:.4f}"

    logger.info("Performance on RAW Embeddings:")
    for model_label, *row_values in (
        ("SVM:", svm_raw_test_auc, svm_raw_test_f1_minority, svm_raw_test_f1_overall, svm_raw_test_precision, svm_raw_test_recall),
        ("Random Forest:", rf_raw_test_auc, rf_raw_test_f1_minority, rf_raw_test_f1_overall, rf_raw_test_precision, rf_raw_test_recall),
        ("Logistic Regression:", lr_raw_test_auc, lr_raw_test_f1_minority, lr_raw_test_f1_overall, lr_raw_test_precision, lr_raw_test_recall),
        ("XGBoost:", xgb_raw_test_auc, xgb_raw_test_f1_minority, xgb_raw_test_f1_overall, xgb_raw_test_precision, xgb_raw_test_recall),
    ):
        logger.info(f"  {model_label:<23}" + summary_row_fmt.format(*row_values))
    # A non-positive (placeholder 0) AUC means the transformer run produced no results.
    if transformer_raw_auc is None or not transformer_raw_auc > 0:
        logger.info("  Lightweight Transformer (RAW) evaluation incomplete or failed.")
    else:
        logger.info("  Lightweight Transformer: " + summary_row_fmt.format(
            transformer_raw_auc, transformer_raw_f1_minority, transformer_raw_f1_overall,
            transformer_raw_precision, transformer_raw_recall))

    logger.info("\nPerformance on OPTIMIZED Embeddings:")
    for model_label, *row_values in (
        ("SVM:", svm_opt_test_auc, svm_opt_test_f1_minority, svm_opt_test_f1_overall, svm_opt_test_precision, svm_opt_test_recall),
        ("Random Forest:", rf_opt_test_auc, rf_opt_test_f1_minority, rf_opt_test_f1_overall, rf_opt_test_precision, rf_opt_test_recall),
        ("Logistic Regression:", lr_opt_test_auc, lr_opt_test_f1_minority, lr_opt_test_f1_overall, lr_opt_test_precision, lr_opt_test_recall),
        ("XGBoost:", xgb_opt_test_auc, xgb_opt_test_f1_minority, xgb_opt_test_f1_overall, xgb_opt_test_precision, xgb_opt_test_recall),
    ):
        logger.info(f"  {model_label:<23}" + summary_row_fmt.format(*row_values))
    if transformer_opt_auc is None or not transformer_opt_auc > 0:
        logger.info("  Lightweight Transformer (Optimized) evaluation incomplete or failed.")
    else:
        logger.info("  Lightweight Transformer: " + summary_row_fmt.format(
            transformer_opt_auc, transformer_opt_f1_minority, transformer_opt_f1_overall,
            transformer_opt_precision, transformer_opt_recall))


    logger.info("\n" + "="*20 + " Stratified K-Fold Cross-Validation on Test Set " + "="*20)
    # NOTE: nothing is refit here. StratifiedKFold only partitions the held-out test set
    # into disjoint, class-stratified evaluation folds, and every fold is scored with the
    # models already trained above. The reported std therefore reflects variability
    # across test subsets, not training variance.
    skf = StratifiedKFold(n_splits=args.n_splits_kfold, shuffle=True, random_state=42)
    logger.info(f"Performing {args.n_splits_kfold}-fold stratified cross-validation on the test set...")

    metric_keys = ['auc', 'f1_minority', 'f1_overall', 'precision', 'recall']
    per_class_metric_keys = ['f1_per_class']
    if args.enable_magnitude_loss_metrics:
        metric_keys.extend(['mse', 'rmse', 'r2'])
        per_class_metric_keys.append('mse_per_class')

    # Explicit label set for per-class metrics. With labels= pinned, index i of the
    # f1_score(average=None) output is always class i, even when a small fold has
    # neither a true nor a predicted sample of some class (without it the array shrinks
    # and later classes shift onto the wrong index).
    class_labels = list(range(args.num_classes))

    cv_metrics = {}
    model_names_for_cv = [
        'svm_raw', 'rf_raw', 'lr_raw', 'xgb_raw', 'transformer_raw',
        'svm_opt', 'rf_opt', 'lr_opt', 'xgb_opt', 'transformer_opt'
    ]
    for model_name in model_names_for_cv:
        cv_metrics[model_name] = {key: [] for key in metric_keys}
        for pk_key in per_class_metric_keys:
            cv_metrics[model_name][pk_key] = {i: [] for i in class_labels}

    def _record_fold_metrics(model_key, y_true, y_pred, auc, f1_overall, precision, recall):
        """Append one fold's scores for ``model_key`` to ``cv_metrics``.

        Per-class F1 (and the class-0 "minority" F1) are computed from ``y_pred``; the
        scalar AUC / weighted F1 / precision / recall are recorded as passed in.
        """
        entry = cv_metrics[model_key]
        f1_per_class = f1_score(y_true, y_pred, labels=class_labels, average=None, zero_division=0)
        entry['auc'].append(auc)
        entry['f1_minority'].append(f1_per_class[0] if len(f1_per_class) > 0 else 0.0)
        entry['f1_overall'].append(f1_overall)
        entry['precision'].append(precision)
        entry['recall'].append(recall)
        for i in class_labels:
            entry['f1_per_class'][i].append(f1_per_class[i])
        if args.enable_magnitude_loss_metrics:
            mse_val, rmse_val, r2_val, mse_per_class_val = compute_regression_metrics(y_true, np.asarray(y_pred))
            entry['mse'].append(mse_val)
            entry['rmse'].append(rmse_val)
            entry['r2'].append(r2_val)
            for i in class_labels:
                entry['mse_per_class'][i].append(mse_per_class_val.get(i, 0.0))  # .get for safety

    def _record_failed_fold(model_key):
        """Record 0.0 for every metric of ``model_key`` on a fold whose evaluation failed."""
        entry = cv_metrics[model_key]
        for key in metric_keys:
            entry[key].append(0.0)
        for pk_key in per_class_metric_keys:
            for i in class_labels:
                entry[pk_key][i].append(0.0)

    # The fitted sklearn / XGBoost models are fold-invariant, so the table is built once.
    models_to_eval_fold = {
        'svm_raw': svm_raw, 'rf_raw': rf_raw, 'lr_raw': log_reg_raw, 'xgb_raw': xgb_model_raw,
        'svm_opt': svm_opt, 'rf_opt': rf_opt, 'lr_opt': log_reg_opt, 'xgb_opt': xgb_model_opt
    }

    # Train indices are ignored: only the test fold of each split is scored.
    for fold, (_train_idx, test_idx) in enumerate(skf.split(X_test_raw, y_test)):
        logger.info(f"\n--- Fold {fold + 1}/{args.n_splits_kfold} ---")

        X_test_raw_fold, y_test_fold = X_test_raw[test_idx], y_test[test_idx]
        X_test_opt_fold = X_test_opt[test_idx]

        for model_key, model_instance in models_to_eval_fold.items():
            logger.debug(f"Evaluating {model_key} on fold {fold+1}")
            current_X_test_fold = X_test_raw_fold if model_key.endswith('_raw') else X_test_opt_fold

            y_pred_fold = model_instance.predict(current_X_test_fold)
            y_prob_fold = model_instance.predict_proba(current_X_test_fold)

            if args.num_classes == 2:
                y_prob_fold = y_prob_fold[:, 1]  # score of the second class column only
                if y_pred_fold.ndim > 1:  # defensive: collapse one-hot style predictions
                    y_pred_fold = np.argmax(y_pred_fold, axis=1)
                fold_auc = roc_auc_score(y_test_fold, y_prob_fold)
            else:
                fold_auc = roc_auc_score(y_test_fold, y_prob_fold, multi_class='ovr')

            _record_fold_metrics(
                model_key, y_test_fold, y_pred_fold, fold_auc,
                f1_score(y_test_fold, y_pred_fold, average='weighted', zero_division=0),
                precision_score(y_test_fold, y_pred_fold, average='weighted', zero_division=0),
                recall_score(y_test_fold, y_pred_fold, average='weighted', zero_division=0),
            )

        # Transformers are re-loaded from their saved weights (evaluate_only=True) and
        # scored on the same test fold.
        transformer_fold_specs = (
            ('transformer_raw', X_train_raw, X_val_raw, X_test_raw_fold,
             args.contrastive_input_dim, " (RAW)", args.transformer_raw_model_save_path),
            ('transformer_opt', X_train_opt, X_val_opt, X_test_opt_fold,
             args.contrastive_output_dim, " (Optimized)", args.transformer_opt_model_save_path),
        )
        for model_key, X_tr, X_va, X_te_fold, in_dim, suffix, save_path in transformer_fold_specs:
            logger.debug(f"Evaluating {model_key} on fold {fold+1}")
            fold_results = train_eval_transformer(
                X_tr, y_train, X_va, y_val, X_te_fold, y_test_fold,
                class_weights_tensor, logger, input_dim=in_dim,
                num_classes=args.num_classes, classifier_config=classifier_config, device=device,
                title_suffix=f"{suffix} Fold {fold+1}", evaluate_only=True,
                model_save_path=save_path
            )
            if fold_results[0] is None:
                # Zeros keep every metric list aligned with the fold count, but they pull
                # the CV mean down — hence the explicit warning.
                logger.warning(f"{model_key} evaluation failed on fold {fold+1}; recording 0.0 for all metrics.")
                _record_failed_fold(model_key)
                continue
            auc_tr, _f1m_tr, f1o_tr, p_tr, r_tr, pred_tr, _prob_tr = fold_results
            _record_fold_metrics(model_key, y_test_fold, pred_tr, auc_tr, f1o_tr, p_tr, r_tr)


    logger.info("\n" + "="*20 + " Cross-Validation Statistical Summary " + "="*20)
    # For each model: mean/std of every scalar metric across folds, then per-class breakdowns.
    for model_name_cv, model_scores in cv_metrics.items():
        logger.info(f"\nModel: {model_name_cv.upper()}")

        for metric_name in metric_keys:
            if metric_name not in model_scores:
                continue
            readable = metric_name.replace('_', ' ').capitalize()
            fold_values = model_scores[metric_name]
            if not fold_values:
                # The list stays empty when the metric was never recorded for any fold.
                logger.info(f"  {readable}: Not computed or all folds failed.")
                continue
            logger.info(f"  {readable}: Mean={np.mean(fold_values):.4f}, Std={np.std(fold_values):.4f}")

        for pk_metric_name in per_class_metric_keys:
            per_class_scores = model_scores.get(pk_metric_name)
            if not per_class_scores:
                continue
            readable = pk_metric_name.replace('_', ' ').capitalize()
            for class_idx in range(args.num_classes):
                class_values = per_class_scores.get(class_idx, [])
                if not class_values:
                    logger.info(f"  {readable} (Class {class_idx}): Not computed or all folds failed.")
                    continue
                logger.info(f"  {readable} (Class {class_idx}): Mean={np.mean(class_values):.4f}, Std={np.std(class_values):.4f}")
            

    logger.info("\n" + "="*20 + " GEOMETRIC QUALITY ANALYSIS OF EMBEDDINGS " + "="*20)
    logger.info("Analyzing geometric properties of RAW vs OPTIMIZED embeddings using K-fold splits from test set...")

    geometric_metric_keys_main = ['within_between_ratio', 'silhouette', 'overlap', 'fisher_ratio']
    geometric_metrics_cv = {
        'raw': {key: [] for key in geometric_metric_keys_main},
        'opt': {key: [] for key in geometric_metric_keys_main}
    }

    # skf has shuffle=True with an integer random_state, so splitting the same inputs
    # again reproduces the folds used for the classifier evaluation above.
    for fold, (_train_idx, test_idx) in enumerate(skf.split(X_test_raw, y_test)):
        logger.info(f"\n--- Geometric Analysis: Fold {fold + 1}/{args.n_splits_kfold} ---")
        X_test_raw_fold, y_test_fold = X_test_raw[test_idx], y_test[test_idx]
        X_test_opt_fold = X_test_opt[test_idx]

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1.
        n_fold_labels = len(np.unique(y_test_fold))
        if not 2 <= n_fold_labels <= len(y_test_fold) - 1:
            logger.warning(
                f"Fold {fold+1} has {n_fold_labels} unique class(es) for {len(y_test_fold)} samples; "
                f"silhouette score is undefined and recorded as 0.0 for this fold."
            )
            raw_silhouette = 0.0
            opt_silhouette = 0.0
        else:
            raw_silhouette = silhouette_score(X_test_raw_fold, y_test_fold)
            opt_silhouette = silhouette_score(X_test_opt_fold, y_test_fold)

        raw_variance = calculate_variance_metrics(X_test_raw_fold, y_test_fold)
        raw_overlap = calculate_class_overlap(X_test_raw_fold, y_test_fold)
        raw_fisher = calculate_fisher_ratio(X_test_raw_fold, y_test_fold)

        geometric_metrics_cv['raw']['within_between_ratio'].append(raw_variance['ratio'])
        geometric_metrics_cv['raw']['silhouette'].append(raw_silhouette)
        geometric_metrics_cv['raw']['overlap'].append(raw_overlap)
        geometric_metrics_cv['raw']['fisher_ratio'].append(raw_fisher)

        opt_variance = calculate_variance_metrics(X_test_opt_fold, y_test_fold)
        opt_overlap = calculate_class_overlap(X_test_opt_fold, y_test_fold)
        opt_fisher = calculate_fisher_ratio(X_test_opt_fold, y_test_fold)

        geometric_metrics_cv['opt']['within_between_ratio'].append(opt_variance['ratio'])
        geometric_metrics_cv['opt']['silhouette'].append(opt_silhouette)
        geometric_metrics_cv['opt']['overlap'].append(opt_overlap)
        geometric_metrics_cv['opt']['fisher_ratio'].append(opt_fisher)

        logger.info(f"RAW Embeddings - Fold {fold + 1}: Within/Between ratio: {raw_variance['ratio']:.4f}, Silhouette: {raw_silhouette:.4f}, Overlap: {raw_overlap:.4f}, Fisher: {raw_fisher:.4f}")
        logger.info(f"OPTIMIZED Embeddings - Fold {fold + 1}: Within/Between ratio: {opt_variance['ratio']:.4f}, Silhouette: {opt_silhouette:.4f}, Overlap: {opt_overlap:.4f}, Fisher: {opt_fisher:.4f}")

    logger.info("\n" + "="*20 + " GEOMETRIC QUALITY ANALYSIS SUMMARY (AVERAGE OVER K-FOLDS) " + "="*20)
    # Optimisation direction per metric: True = lower is better, False = higher is better.
    geo_lower_is_better = {
        'within_between_ratio': True,
        'overlap': True,
        'silhouette': False,
        'fisher_ratio': False,
    }
    summary_lines = []  # OPT-embedding lines collected for a final comparison
    for embed_type in ['raw', 'opt']:
        logger.info(f"{embed_type.upper()} Embeddings - Average across all folds:")
        for metric_key_geo in geometric_metric_keys_main:
            mean_val = np.mean(geometric_metrics_cv[embed_type][metric_key_geo])
            std_val = np.std(geometric_metrics_cv[embed_type][metric_key_geo])
            lower_better = geo_lower_is_better.get(metric_key_geo)
            if lower_better is None:
                direction = ""
            else:
                direction = "(lower is better)" if lower_better else "(higher is better)"
            log_line = f"  {metric_key_geo.replace('_', ' ').capitalize()}: {mean_val:.4f} ± {std_val:.4f} {direction}"
            logger.info(log_line)
            if embed_type == 'opt':
                summary_lines.append(log_line)

    logger.info("\nOverall Geometric Quality Improvement after Optimization (based on K-Fold averages):")
    raw_means = {k: np.mean(geometric_metrics_cv['raw'][k]) for k in geometric_metric_keys_main}
    opt_means = {k: np.mean(geometric_metrics_cv['opt'][k]) for k in geometric_metric_keys_main}

    for key_geo in geometric_metric_keys_main:
        raw_m, opt_m = raw_means[key_geo], opt_means[key_geo]
        pretty_name = key_geo.replace('_', ' ').capitalize()
        lower_better = geo_lower_is_better.get(key_geo)
        # A relative change needs a known direction and a non-(near-)zero RAW baseline.
        if lower_better is None or abs(raw_m) < 1e-9:
            logger.info(f"  {pretty_name}: RAW={raw_m:.4f}, OPT={opt_m:.4f} (improvement N/A or RAW is zero)")
            continue
        # Signed so that a positive percentage always means "better".
        delta = (raw_m - opt_m) if lower_better else (opt_m - raw_m)
        improvement = delta / abs(raw_m) * 100
        if opt_m == raw_m:
            change_type, verdict = "change", "[No Change]"
        else:
            change_type = "increase" if opt_m > raw_m else "decrease"
            verdict = "[Improved]" if improvement > 0 else "[Not Improved]"
        logger.info(f"  {pretty_name}: {improvement:.2f}% {change_type} (RAW={raw_m:.4f}, OPT={opt_m:.4f}) {verdict}")


    logger.info("\nEvaluation script finished.\n")