#!/usr/bin/env python3
"""
HSBench: MLP外推变体实验

Linear回归外推:
  - 在锚点上拟合1D/ND Linear回归来预测目标模型的0/1正确性。
  - 提供五种特征输入：
      (a) mlp_logit_avg: 源模型MLP概率在源上取均值（s_bar）
      (b) mlp_label_avg: 源模型原始0/1得分在源上取均值（y_bar）
      (c) mlp_logit_vec: 源模型MLP logits向量（每源一个logit值，logits[:,1]）
      (d) mlp_label_vec: 源模型原始0/1得分向量（每源一个0/1值）
      (e) question_embed_vec: 题目embedding向量（输入随聚类方式变化：mean=题目级向量；concat=按source拼接后flatten）

基线方法:
  - naive_mean: 目标模型锚点正确率均值（anchor mean baseline）
  - apw_weighted: 基于聚类簇大小的加权平均
  - tailoredbench_scaling: 簇内TailoredBench风格缩放（基线保留）

MLP分数统一计算:
  - 无论聚类方法如何（mean/concat），我们始终使用concat风格的每源
    MLP概率作为基础分数（每模型×问题 -> 每源分数矩阵）。
"""

import warnings

warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)

import os
import sys
import time
import signal
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from rich.console import Console

# Directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of BASE_DIR; prepended to sys.path below if it is not already there.
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Shared rich console used for status output throughout this module.
console = Console()

try:
    # The base script is expected to live in the same directory as this file.
    from HSbench_cukm_01TB_review import (
        HSExperiment as _BaseHSExperiment,
        _child_init_with_shared,
    )
except Exception as exc:
    # Chain the original error explicitly so the real cause (missing file,
    # missing dependency, error raised inside the base module, ...) is reported
    # as the direct cause of this RuntimeError.
    raise RuntimeError("无法导入基础HSExperiment模块，请确保HSbench_cukm_01TB_review.py存在于同一目录中。") from exc


# Module-level shared data, used by spawned worker processes.
SHARED = None

def _child_init_with_shared_local(pid_q, shared):
    """Wrap the original initializer and also set this module's ``SHARED``.

    Args:
        pid_q: Forwarded unchanged to ``_child_init_with_shared``.
        shared: Forwarded to ``_child_init_with_shared`` and then stored in
            the module-level ``SHARED`` variable.
    """
    global SHARED
    _child_init_with_shared(pid_q, shared)
    SHARED = shared


def _process_combo_task_with_shared_variants(args_tuple):
    """Worker-process entry point: build the experiment subclass and run one combo.

    Args:
        args_tuple: Flat tuple of task arguments. It is unpacked positionally,
            so its element order must match the unpacking below exactly.

    Returns:
        ``(combo_idx, result)`` where ``result`` is the return value of
        ``_process_single_combo_with_shared``, or ``(combo_idx, None)`` when
        that call raised an exception (the error is printed, not re-raised).
    """
    (
        combo_idx,
        train_labels_dict,
        test_labels_dict,
        train_models,
        test_models,
        dataset,
        feats_dir,
        output_dir,
        device,
        layer_type,
        operate_type,
        anchor_points,
        true_accuracies,
        label_data_dir,
        use_expand,
        strict_id_alignment,
        gpu_id,
        num_runs_per_anchor,
        test_features_for_similarity,
        exclude_extreme_samples,
        exclude_extreme_samples_predict_all,
        dynamic_source_selection,
        lambda_max,
        num_models,
        svd_retain,
        svd_for_clustering,
        svd_exclude_dominant_ratio,
        svd_spectral_shrinkage,
        svd_specific_components,
        use_zscore,
        use_euclidean,
        ridge_alpha,
        decorr_lambda,
        model_strategy,
        family,
        strength,
    ) = args_tuple

    # Set thread-count environment variables in the child process to avoid
    # CUDA/BLAS thread conflicts.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    # Bind this worker to its GPU (same as the original script).
    if gpu_id is not None and torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        try:
            import cupy as cp

            cp.cuda.Device(gpu_id).use()
        except Exception:
            # Best effort: any failure while importing/selecting via cupy is ignored.
            pass
        device = f"cuda:{gpu_id}"

    exp = HSExperimentMLPVariants(
        dataset,
        feats_dir,
        output_dir,
        device,
        label_data_dir,
        use_expand=use_expand,
        lambda_max=lambda_max,
        num_models=num_models,
        svd_retain=float(svd_retain),
        svd_for_clustering=bool(svd_for_clustering),
        svd_exclude_dominant_ratio=svd_exclude_dominant_ratio,
        svd_spectral_shrinkage=svd_spectral_shrinkage,
        svd_specific_components=svd_specific_components,
        use_zscore=use_zscore,
        use_euclidean=use_euclidean,
        ridge_alpha=ridge_alpha,
        decorr_lambda=decorr_lambda,
        model_strategy=model_strategy,
        family=family,
        strength=strength,
    )

    try:
        # SHARED is the module-level value set by _child_init_with_shared_local.
        global SHARED
        result = exp._process_single_combo_with_shared(
            combo_idx=combo_idx,
            train_labels_dict=train_labels_dict,
            test_labels_dict=test_labels_dict,
            train_models=train_models,
            test_models=test_models,
            layer_type=layer_type,
            operate_type=operate_type,
            anchor_points=anchor_points,
            true_accuracies=true_accuracies,
            strict_id_alignment=strict_id_alignment,
            num_runs_per_anchor=num_runs_per_anchor,
            test_features_for_similarity=test_features_for_similarity,
            shared_feature_data=SHARED,
            exclude_extreme_samples=exclude_extreme_samples,
            exclude_extreme_samples_predict_all=exclude_extreme_samples_predict_all,
            dynamic_source_selection=dynamic_source_selection,
            family=family,
            strength=strength,
        )
        return combo_idx, result
    except Exception as e:
        # Report the failure and return None for this combo instead of raising.
        console.print(f"[red]组合 {combo_idx} 处理失败: {e}[/red]")
        import traceback
        traceback.print_exc()
        return combo_idx, None
    finally:
        # Runs on both the success and the failure path.
        exp.cleanup_resources()


class HSExperimentMLPVariants(_BaseHSExperiment):
    """仅重写评估/保存以支持MLP外推变体。"""

    def __init__(
        self,
        dataset,
        feats_dir,
        output_dir,
        device,
        label_data_dir,
        use_expand: bool = False,
        lambda_max: Optional[float] = None,
        num_models: Optional[int] = None,
        svd_retain: float = 0.95,
        svd_for_clustering: bool = False,
        svd_exclude_dominant_ratio: Optional[float] = None,
        svd_spectral_shrinkage: bool = False,
        svd_specific_components: Optional[List[int]] = None,
        use_zscore: bool = False,
        use_euclidean: bool = False,
        ridge_alpha: float = 1.0,
        decorr_lambda: float = 0.0,
        model_strategy: str = None,
        family: str = None,
        strength: str = None,
    ):
        """Construct the experiment and record the variant-specific settings.

        Worker processes started with ``spawn`` do not share the parent's
        state, so the SVD-related options must be passed explicitly as task
        arguments and applied here when the instance is built in the child.
        """
        super().__init__(dataset, feats_dir, output_dir, device, label_data_dir, use_expand, lambda_max)

        self.num_models = num_models

        # SVD options; numeric/boolean inputs are coerced to plain Python types.
        self.svd_retain = float(svd_retain)
        self.svd_for_clustering = bool(svd_for_clustering)
        if svd_exclude_dominant_ratio is None:
            self.svd_exclude_dominant_ratio = None
        else:
            self.svd_exclude_dominant_ratio = float(svd_exclude_dominant_ratio)
        self.svd_spectral_shrinkage = bool(svd_spectral_shrinkage)
        self.svd_specific_components = svd_specific_components

        # Normalisation / distance switches.
        self.use_zscore = bool(use_zscore)
        self.use_euclidean = bool(use_euclidean)

        # Regression and loss hyper-parameters.
        self.ridge_alpha = float(ridge_alpha)
        self.decorr_lambda = float(decorr_lambda)

        # Model-selection descriptors, stored as given.
        self.model_strategy = model_strategy
        self.family = family
        self.strength = strength

    def get_mlp_model_path(
        self, combo_idx: int, layer_type: str = "last", operate_type: str = "last_token", use_expand: bool = False,
        lambda_max: float = None, num_models: int = None, model_strategy: str = None,
        family: str = None, strength: str = None
    ) -> str:
        """Return the checkpoint path of the MLP model for one combo (``combo<N>.pth``).

        Overrides the base-class method to additionally support
        ``decorr_lambda``, ``num_models``, ``model_strategy``, ``family`` and
        ``strength``.

        The leading directories are chosen in this order:
          1. ``use_expand``: ``mlp_models_expand/lambda_<lambda_max>``
          2. ``self.decorr_lambda > 0``: ``mlp_models_decorr/lambda_<decorr_lambda>``
          3. ``model_strategy`` given: ``mlp_models_<strategy>`` plus a
             ``<family>`` / ``<strength>`` sub-directory for the
             ``family_diverse`` / ``strength`` strategies when that value is set
          4. ``num_models`` given: ``mlp_models_<num_models>models``
          5. otherwise: ``mlp_models``

        They are followed by ``dataset/layer_type/operate_type/combo<N>.pth``
        under ``PROJECT_ROOT/main_experiment/results``.

        Raises:
            ValueError: ``use_expand`` is true but ``lambda_max`` is None.
        """
        checkpoint_name = f"combo{combo_idx}.pth"

        if use_expand:
            # Expand-loss checkpoints are grouped by their lambda hyper-parameter.
            if lambda_max is None:
                raise ValueError("使用 use_expand=True 时，必须提供 lambda_max 参数")
            # Lambda values are rendered with 3 significant digits.
            lambda_tag = f"{lambda_max:.3g}"
            console.print(f"[blue]加载MLP模型路径: use_expand=True, lambda={lambda_tag}[/blue]")
            leading_dirs = ("mlp_models_expand", f"lambda_{lambda_tag}")
        elif self.decorr_lambda > 0:
            # Decorrelation-loss checkpoints, also grouped by lambda.
            lambda_tag = f"{self.decorr_lambda:.3g}"
            console.print(f"[blue]加载MLP模型路径: use_decorr=True, lambda={lambda_tag}[/blue]")
            leading_dirs = ("mlp_models_decorr", f"lambda_{lambda_tag}")
        elif model_strategy is not None:
            # Plain training; priority: model_strategy (+ family/strength) > num_models > default.
            if model_strategy == 'family_diverse' and family is not None:
                strategy_dir = f"mlp_models_{model_strategy}/{family}"
            elif model_strategy == 'strength' and strength is not None:
                strategy_dir = f"mlp_models_{model_strategy}/{strength}"
            else:
                # Covers 'temporal_shift' and every other strategy name.
                strategy_dir = f"mlp_models_{model_strategy}"
            console.print(f"[blue]加载MLP模型路径: model_strategy={model_strategy}, family={family}, strength={strength}[/blue]")
            leading_dirs = (strategy_dir,)
        elif num_models is not None:
            console.print(f"[blue]加载MLP模型路径: num_models={num_models}[/blue]")
            leading_dirs = (f"mlp_models_{num_models}models",)
        else:
            console.print("[blue]加载MLP模型路径: use_expand=False, num_models=None, model_strategy=None[/blue]")
            leading_dirs = ("mlp_models",)

        return os.path.join(
            PROJECT_ROOT,
            "main_experiment/results",
            *leading_dirs,
            self.dataset,
            layer_type,
            operate_type,
            checkpoint_name,
        )

    def compute_question_representations(
        self, embeddings_dict: Dict[str, torch.Tensor], method: str = "mean"
    ) -> Dict[int, torch.Tensor]:
        """
        计算题目表征（向量化优化版本）。
        根据 self.use_zscore 和 self.use_euclidean 决定是否进行归一化。

        Args:
            embeddings_dict: {model_name: embeddings} 形式的字典，embeddings shape: (n_samples, embed_dim)
            method: "mean" 或 "concat"

        Returns:
            {question_id: representation} 形式的字典
        """
        n_samples = list(embeddings_dict.values())[0].shape[0]

        if method == "mean":
            # 方法1: 使用训练集模型嵌入表示的均值（向量化实现）
            model_names = list(embeddings_dict.keys())
            stacked_embeddings = torch.stack(
                [embeddings_dict[model_name] for model_name in model_names], dim=0
            )  # (n_models, n_samples, embed_dim)

            # 沿模型维度计算均值：(n_samples, embed_dim)
            mean_representations = torch.mean(stacked_embeddings, dim=0)

            # --- Z-score ---
            if self.use_zscore:
                eps = 1e-6
                mean = torch.mean(mean_representations, dim=0, keepdim=True)
                std = torch.std(mean_representations, dim=0, keepdim=True)
                mean_representations = (mean_representations - mean) / (std + eps)

            # --- L2 Normalize ---
            # 如果不使用欧氏距离 (即使用 Cosine)，则进行 L2 Normalize
            # 如果使用欧氏距离，则保留幅值 (不做 L2 Normalize)
            if not self.use_euclidean:
                mean_representations = torch.nn.functional.normalize(mean_representations, p=2, dim=1)

            # 转换为字典格式
            question_reps = {i: mean_representations[i] for i in range(n_samples)}

        elif method == "concat":
            # 方法2: 使用训练集模型嵌入的拼接（向量化实现）
            model_names = list(embeddings_dict.keys())

            # 沿最后一个维度拼接：(n_samples, n_models * embed_dim)
            concat_representations = torch.cat([embeddings_dict[model_name] for model_name in model_names], dim=1)

            # --- Z-score ---
            if self.use_zscore:
                eps = 1e-6
                mean = torch.mean(concat_representations, dim=0, keepdim=True)
                std = torch.std(concat_representations, dim=0, keepdim=True)
                concat_representations = (concat_representations - mean) / (std + eps)

            # --- L2 Normalize ---
            if not self.use_euclidean:
                concat_representations = torch.nn.functional.normalize(concat_representations, p=2, dim=1)

            # 转换为字典格式
            question_reps = {i: concat_representations[i] for i in range(n_samples)}

        else:
            raise ValueError(f"Unsupported method: {method}")

        return question_reps

    def select_anchor_questions(
        self,
        question_reps: Dict[int, torch.Tensor],
        cluster_dict: Dict[int, List[int]],
        kmeans_model,
        n_anchors: int,
    ) -> List[int]:
        """
        每个簇选择一个最接近质心的题目作为锚点。

        根据 self.use_euclidean 决定度量方式：
        - True (Euclidean): 计算欧氏距离，选 ArgMin (最小距离)。
        - False (Cosine): 计算余弦相似度，选 ArgMax (最大相似度)。

        Args:
            question_reps: 题目表征字典
            cluster_dict: 聚类结果，cluster数量等于n_anchors
            kmeans_model: KMeans模型（包含质心信息）
            n_anchors: 要选择的锚点数量（等于聚类数量）

        Returns:
            锚点题目ID列表
        """
        question_ids = list(question_reps.keys())

        # 获取所有非空簇
        non_empty_clusters = {
            cluster_id: questions
            for cluster_id, questions in cluster_dict.items()
            if len(questions) > 0
        }

        if not non_empty_clusters:
            console.print("[red]警告: 没有非空聚类，随机选择锚点[/red]")
            return np.random.choice(
                question_ids, min(n_anchors, len(question_ids)), replace=False
            ).tolist()

        # 获取质心
        centroids = None
        if hasattr(kmeans_model, "cluster_centers_"):
            centroids = kmeans_model.cluster_centers_
            if hasattr(centroids, "get"):
                centroids = centroids.get()  # cupy -> numpy
            centroids = np.asarray(centroids)

        # 从每个簇中选择一个最接近质心的题目
        anchor_questions = []

        for cluster_id, cluster_questions in non_empty_clusters.items():
            if len(cluster_questions) == 0:
                continue

            if centroids is not None and cluster_id < len(centroids):
                centroid = centroids[cluster_id]

                # 收集该簇内所有向量 & 对应 id
                reps = []
                ids = []
                for qid in cluster_questions:
                    rep = question_reps[qid]
                    if hasattr(rep, "cpu"):
                        rep = rep.cpu().numpy()
                    reps.append(rep)
                    ids.append(qid)

                reps = np.stack(reps, axis=0)  # (N, dim)

                if self.use_euclidean:
                    # Euclidean 模式: 欧氏距离，越小越好 (ArgMin)
                    # 直接计算 ||x - c||
                    diff = reps - centroid
                    dists = np.linalg.norm(diff, axis=1)  # (N,)
                    best_idx = int(np.argmin(dists))

                else:
                    # Cosine 模式: 余弦相似度，越大越好 (ArgMax)
                    # 遵循基类逻辑: 归一化质心 -> 点积
                    norm = np.linalg.norm(centroid) + 1e-12
                    centroid_norm = centroid / norm

                    # 假设 reps 已经归一化 (在 compute_question_representations 中保证)
                    # 计算点积
                    similarities = np.dot(reps, centroid_norm) # (N,)
                    best_idx = int(np.argmax(similarities))

                closest_question = ids[best_idx]
                anchor_questions.append(closest_question)
            else:
                # 如果没有质心信息，随机选择一个
                selected_question = np.random.choice(cluster_questions)
                anchor_questions.append(selected_question)
        return anchor_questions

    def _generate_experiment_dir_name(self, *args, **kwargs) -> str:
        """Extend the parent's directory name with this variant's configuration.

        Resulting format: ``{base}[_zscore]_{l2|cos}[_svd...]`` (the original
        note says this matches HSbench_logistic_logits_exploration.py).

        The SVD suffix is added only when ``svd_for_clustering`` is on; the
        first matching option wins: specific components, spectral shrinkage,
        dominant-ratio exclusion, then the retain ratio.
        """
        base = super()._generate_experiment_dir_name(*args, **kwargs)

        tags = []
        if self.use_zscore:
            tags.append("zscore")

        # Distance metric: l2 = Euclidean, cos = cosine similarity.
        tags.append("l2" if self.use_euclidean else "cos")

        if self.svd_for_clustering:
            if self.svd_specific_components is not None:
                joined = "_".join(str(comp) for comp in self.svd_specific_components)
                tags.append(f"svd_spec{joined}")
            elif self.svd_spectral_shrinkage:
                tags.append("svd_shrink")
            elif self.svd_exclude_dominant_ratio is not None:
                tags.append(f"svd_excl{self.svd_exclude_dominant_ratio:.2g}")
            else:
                tags.append(f"svd{self.svd_retain:.3g}")

        suffix = "".join(f"_{tag}" for tag in tags)
        return f"{base}{suffix}"

    def _svd_reduce_matrix(
        self,
        X: np.ndarray,
        retain: float = 0.95,
        return_rank: bool = False,
        exclude_dominant_ratio: Optional[float] = None,
        specific_components: Optional[List[int]] = None,
    ):
        """
        Tangent-space SVD denoising with low-rank reconstruction back to original dim.

        输入:
        X: (n_samples, d)

        输出:
        - return_rank=False: (n_samples, d)  # 返回原维度“去噪向量”，不是k维坐标
        - return_rank=True:  ((n_samples, d), k)  # k的含义取决于分支：
                * retain 分支：截断阶数
                * exclude_dominant_ratio 分支：保留的分量数
                * spectral_shrinkage 分支：K = min(n, d)（不截断，仅收缩）
                * specific_components 分支：保留的分量数 (len(specific_components))
        """
        X = np.asarray(X, dtype=np.float32)
        if X.ndim != 2:
            raise ValueError(f"SVD expects 2D array, got shape={X.shape}")
        n, d = X.shape
        if n <= 1 or d <= 1:
            return (X, d) if return_rank else X

        retain = float(retain)
        if not (0.0 < retain <= 1.0):
            raise ValueError(f"svd retain must be in (0,1], got {retain}")

        eps = 1e-12

        # =========================================================
        # Standard PCA Logic (Euclidean)
        # Replaces Tangent-space SVD
        # =========================================================

        # 1) Column-wise mean centering
        mean_vec = X.mean(axis=0, keepdims=True)
        Xc = X - mean_vec

        # 2) Standard SVD on centered data
        U, S, Vt = np.linalg.svd(Xc.astype(np.float64), full_matrices=False)
        if S.size == 0:
            return (X, 1) if return_rank else X

        var = S ** 2
        total = float(var.sum())
        if not np.isfinite(total) or total <= eps:
            return (X, 1) if return_rank else X

        # -------- branch S: specific components (user specified) --------
        if specific_components is not None:
            # 这里的 components 是 1-based index
            # 需要转换为 0-based index
            # 同时要确保不超过维度范围
            indices = []
            max_idx = Vt.shape[0] - 1
            for c in specific_components:
                idx = c - 1
                if 0 <= idx <= max_idx:
                    indices.append(idx)
                else:
                    # 如果指定的维度超出了实际维度，忽略
                    pass

            if not indices:
                # 如果没有有效的分量，默认保留第一个
                keep_indices = np.array([0], dtype=np.int64)
            else:
                keep_indices = np.array(indices, dtype=np.int64)

            Vk = Vt[keep_indices, :].T  # (d, k_new)
            k = int(keep_indices.size)

            # Reconstruct
            Xc_hat = (Xc.astype(np.float64) @ Vk.astype(np.float64)) @ Vk.astype(np.float64).T

            if not np.isfinite(Xc_hat).all():
                bad = np.sum(~np.isfinite(Xc_hat))
                raise FloatingPointError(
                     f"[SVD specific_components] Non-finite values in Xc_hat: {bad} elements are NaN/Inf. "
                     f"components={specific_components}, kept={k}/{Vt.shape[0]}"
                )

        # -------- branch A: spectral shrinkage (no truncation) --------
        elif getattr(self, "svd_spectral_shrinkage", False):
            # r_i = sigma_i^2 / total
            r = var / (total + eps)

            # shrink factor: sqrt(1 - r_i), with clipping for numerical safety
            shrink = np.sqrt(np.clip(1.0 - r, 0.0, 1.0))
            # (optional) avoid all-zero shrink in extreme cases
            shrink = np.maximum(shrink, 1e-12)

            S_new = S * shrink

            # Reconstruct: Xc_hat = U * diag(S_new) * Vt
            # (U * S_new[None, :]) uses broadcasting to scale columns of U
            Xc_hat = (U * S_new[None, :]) @ Vt
            k = int(S.shape[0])  # keep all components (economy rank)

            if not np.isfinite(Xc_hat).all():
                bad = np.sum(~np.isfinite(Xc_hat))
                raise FloatingPointError(
                    f"[SVD spectral shrinkage] Non-finite values in Xc_hat: {bad} elements are NaN/Inf. "
                )

        # -------- branch B: exclude components with dominant energy ratio --------
        elif exclude_dominant_ratio is not None:
            thr = float(exclude_dominant_ratio)
            if not (0.0 <= thr <= 1.0):
                raise ValueError(f"exclude_dominant_ratio must be in [0,1], got {thr}")

            ratios = var / (total + eps)
            keep_indices = np.where(ratios <= thr)[0]

            if keep_indices.size == 0:
                # fallback: keep the smallest-energy component
                # keep_indices = np.array([Vt.shape[0] - 1], dtype=np.int64)
                pass

            Vk = Vt[keep_indices, :].T  # (d, k_new)
            k = int(keep_indices.size)

            # Reconstruct: Xc_hat = (Xc @ Vk) @ Vk.T
            # Note: Xc @ Vk gives the scores (n, k), then @ Vk.T projects back to (n, d)
            Xc_hat = (Xc.astype(np.float64) @ Vk.astype(np.float64)) @ Vk.astype(np.float64).T

            if not np.isfinite(Xc_hat).all():
                bad = np.sum(~np.isfinite(Xc_hat))
                raise FloatingPointError(
                    f"[SVD exclude_dominant_ratio] Non-finite values in Xc_hat: {bad} elements are NaN/Inf. "
                    f"thr={thr}, total={total:.3e}, kept={k}/{Vt.shape[0]}"
                )

        # -------- branch C: retain-based truncation --------
        else:
            cumsum = np.cumsum(var) / (total + eps)
            k = int(np.searchsorted(cumsum, retain, side="left") + 1)
            k = max(1, min(k, Vt.shape[0]))

            Vk = Vt[:k, :].T  # (d, k)

            # Reconstruct: Xc_hat = (Xc @ Vk) @ Vk.T
            Xc_hat = (Xc.astype(np.float64) @ Vk.astype(np.float64)) @ Vk.astype(np.float64).T

            if not np.isfinite(Xc_hat).all():
                bad = np.sum(~np.isfinite(Xc_hat))
                raise FloatingPointError(
                    f"[SVD retain] Non-finite values in Xc_hat: {bad} elements are NaN/Inf. "
                    f"retain={retain}, k={k}, total={total:.3e}"
                )

        # 3) Add mean back
        X_hat = Xc_hat.astype(np.float32) + mean_vec.astype(np.float32)

        if not np.isfinite(X_hat).all():
            bad = np.sum(~np.isfinite(X_hat))
            raise FloatingPointError(
                f"[SVD output] Non-finite values in X_hat: {bad} elements are NaN/Inf."
            )

        return (X_hat, k) if return_rank else X_hat

    def _variant_method_keys(self) -> List[str]:
        """本实验新增的外推方法键。

        新方案：在锚点上拟合 1D/ND Linear 回归来预测目标模型的 0/1 正确性，
        并外推到所有非锚点题目。提供五种特征输入：
          - mlp_logit_avg:    源模型 MLP 概率在源上取均值（s_bar）
          - mlp_label_avg:    源模型原始 0/1 得分在源上取均值（y_bar）
          - mlp_logit_vec:    源模型 MLP logits 向量作为输入（每源一个logit1）
          - mlp_label_vec:    源模型原始 0/1 得分向量作为输入（每源一个0/1）
          - question_embed_vec: 题目embedding向量作为输入（mean=题目级向量；concat=按source拼接flatten）
        """
        return ["mlp_logit_avg", "mlp_label_avg", "mlp_logit_vec", "mlp_label_vec", "question_embed_vec"]

    def _fit_anchor_linear_1d(
        self,
        x_anchor: np.ndarray,
        y_anchor: np.ndarray,
    ):
        """在锚点上拟合 1D Linear:  x -> y。

        - x_anchor: shape (n_anchor,), 源模型在锚点题目上的平均分数
        - y_anchor: shape (n_anchor,), 目标模型在锚点题目上的 0/1 标签
        """
        x_anchor = np.asarray(x_anchor, dtype=np.float32).reshape(-1)
        y_anchor = np.asarray(y_anchor, dtype=np.float32).reshape(-1) # LinearRegression target float

        meta = {
            "one_class": False,
            "const_p": None,
            "coef": None,
            "intercept": None,
        }

        # 兜底：空锚点（理论上不会发生）
        if y_anchor.size == 0:
            raise RuntimeError("锚点为空")

        # 统一使用 Ridge 回归 (固定 alpha)
        # 即使对于 1D 特征，为了避免小样本过拟合并保持一致性，也使用 Ridge
        from sklearn.linear_model import Ridge

        clf = Ridge(alpha=self.ridge_alpha)
        clf.fit(x_anchor.reshape(-1, 1), y_anchor)

        meta["coef"] = float(clf.coef_.reshape(-1)[0]) if hasattr(clf, "coef_") else None
        meta["intercept"] = float(clf.intercept_) if hasattr(clf, "intercept_") else None

        def _pred(x: np.ndarray) -> np.ndarray:
            xx = np.asarray(x, dtype=np.float32).reshape(-1, 1)
            return clf.predict(xx).astype(np.float32)

        return _pred, meta

    def _fit_anchor_linear_nd(
        self,
        X_anchor: np.ndarray,
        y_anchor: np.ndarray,
    ):
        """在锚点上拟合多维 Linear:  X -> y。

        - X_anchor: shape (n_anchor, n_feat)
        - y_anchor: shape (n_anchor,), 目标模型在锚点题目上的 0/1 标签
        """
        X_anchor = np.asarray(X_anchor, dtype=np.float32)
        y_anchor = np.asarray(y_anchor, dtype=np.float32).reshape(-1) # LinearRegression target float

        meta = {
            "one_class": False,
            "const_p": None,
            "coef": None,
            "intercept": None,
        }

        if y_anchor.size == 0:
            raise RuntimeError("锚点为空")

        # 使用固定 alpha 的 Ridge 回归
        from sklearn.linear_model import Ridge

        clf = Ridge(alpha=self.ridge_alpha)
        clf.fit(X_anchor, y_anchor)

        if hasattr(clf, "coef_"):
            meta["coef"] = [float(v) for v in clf.coef_.reshape(-1)]
        if hasattr(clf, "intercept_"):
            meta["intercept"] = float(clf.intercept_)

        def _pred(X: np.ndarray) -> np.ndarray:
            XX = np.asarray(X, dtype=np.float32)
            if XX.ndim == 1:
                XX = XX.reshape(1, -1)  # 单样本：正确 reshape 为 (1, n_features)
            return clf.predict(XX).astype(np.float32)

        return _pred, meta


    def _get_concat_mlp_probs_per_source(self) -> np.ndarray:
        """返回 P[q, s] = 每个源模型s和问题q的MLP概率（正确）。

        始终使用concat风格的缓存嵌入：self._current_encoded_features_concat_for_mlp
        形状为 (n_questions, n_sources, 32)。

        注意：当启用SVD时，使用的是原始32维embeddings，而不是降维后的表征。
        """
        if not hasattr(self, "_current_mlp_model"):
            raise RuntimeError("Missing _current_mlp_model")
        if not hasattr(self, "_current_encoded_features_concat_for_mlp"):
            raise RuntimeError("Missing _current_encoded_features_concat_for_mlp")

        mlp_model = self._current_mlp_model
        feats = torch.as_tensor(self._current_encoded_features_concat_for_mlp)

        if feats.dim() != 3:
            raise RuntimeError(f"期望concat特征为3D (n_q, n_src, 32)，但得到shape={tuple(feats.shape)}")

        n_q, n_s, feat_dim = feats.shape
        if feat_dim == 32 or feat_dim == 16 or feat_dim == 64:
            pass
        else:
            # 模型的clf期望32维输入；如果您的训练不同，请在此调整。
            raise RuntimeError(f"期望feat_dim=32, 16, 64，但得到{feat_dim}")

        with torch.no_grad():
            flat = feats.reshape(-1, feat_dim).to(self.device)
            logits = mlp_model.clf(flat)
            probs = torch.softmax(logits, dim=1)[:, 1]
            probs = probs.view(n_q, n_s)

        return probs.detach().cpu().numpy().astype(np.float64)

    def _get_source_labels_matrix(self, train_models_data: Dict[str, Dict]) -> np.ndarray:
        """返回 Y[q, s] = 每个问题和源模型的源0/1标签。

        源顺序遵循self.train_models（与concat嵌入堆叠的方式相同）。
        """
        if not self.train_models:
            raise RuntimeError("self.train_models is empty")

        ys: List[np.ndarray] = []
        for m in self.train_models:
            if m not in train_models_data:
                raise RuntimeError(f"train_models_data missing source model: {m}")
            y = train_models_data[m]["labels"].detach().cpu().numpy().astype(np.int8)
            ys.append(y)
        # 堆叠为 (n_sources, n_q) 然后转置
        Y = np.stack(ys, axis=0).T
        return Y


    def _get_question_embedding_matrix(
        self,
        clustering_method: Optional[str] = None,
        apply_svd: Optional[bool] = None,
        svd_retain: Optional[float] = None,
    ) -> Tuple[np.ndarray, str]:
        """返回题目级 embedding 矩阵 E。

        输入向量取决于 clustering_method：
          - mean: 题目级向量（通常是 (n_q, d)，常见 d=32）
          - concat: 按 source 维度拼接后 flatten（(n_q, n_src * d)）

        如果 clustering_method 为 None，则使用 self._current_representation_method
        （base 类在调用 evaluate_on_test_set 前会设置此变量）。
        """
        if clustering_method is None:
            # 必须显式指定 clustering_method，不能从缓存变量推断
            # 因为 base 类同时设置了 _current_encoded_features 和 _current_encoded_features_concat
            if hasattr(self, "_current_representation_method"):
                clustering_method = self._current_representation_method
            else:
                raise ValueError(
                    "clustering_method not provided and _current_representation_method not set. "
                    "Please either pass clustering_method explicitly or ensure base class sets _current_representation_method."
                )

        if clustering_method == "mean":
            # 使用 base 类设置的特征变量
            if not hasattr(self, "_current_encoded_features"):
                raise RuntimeError(
                    "Missing _current_encoded_features for mean clustering. "
                    "Base class should set this variable."
                )
            feats = torch.as_tensor(self._current_encoded_features)
            # mean 特征已经是 2D: (n_q, d)
        elif clustering_method == "concat":
            if not hasattr(self, "_current_encoded_features_concat"):
                raise RuntimeError("Missing _current_encoded_features_concat for concat clustering")
            feats = torch.as_tensor(self._current_encoded_features_concat)
            # concat 特征是 3D: (n_q, n_src, d)，需要 reshape 为 2D
            if feats.dim() == 3:
                feats = feats.reshape(int(feats.shape[0]), -1)  # (n_q, n_src*d)
        elif clustering_method == "fused":
            # Fused 策略: Mean + Concat 融合
            # 递归调用获取基础特征 (已经处理过 Z-score 和 SVD)
            E_mean, _ = self._get_question_embedding_matrix("mean", apply_svd=apply_svd, svd_retain=svd_retain)
            E_concat, _ = self._get_question_embedding_matrix("concat", apply_svd=apply_svd, svd_retain=svd_retain)

            Xm = E_mean
            Xc = E_concat

            # 维度平衡缩放 (1/sqrt(d))
            # 无论是否 Z-score 或 SVD，平衡维度贡献通常是有益的
            d_m = Xm.shape[1]
            d_c = Xc.shape[1]
            Xm = Xm * (1.0 / np.sqrt(d_m))
            Xc = Xc * (1.0 / np.sqrt(d_c))

            # 拼接
            X_fused = np.concatenate([Xm, Xc], axis=1)

            # 最终 L2 Normalize 控制
            if not self.use_euclidean:
                # Cosine 模式: 需要 L2 Normalize
                norm_f = np.linalg.norm(X_fused, axis=1, keepdims=True)
                norm_f[norm_f < 1e-12] = 1.0
                E_fused = X_fused / norm_f
            else:
                # Euclidean 模式: 不做 L2 Normalize
                E_fused = X_fused

            return E_fused.astype(np.float32), clustering_method

        else:
            raise ValueError(f"Unknown clustering_method: {clustering_method}")

        if feats.dim() != 2:
            raise RuntimeError(f"Unexpected question embedding shape={tuple(feats.shape)}; expected 2D")

        E = feats.detach().cpu().numpy().astype(np.float32)

        # ===== Z-score Normalization (Global Switch) =====
        # 优先级最高：对所有方法适用 (包括 SVD 和非 SVD)
        # 注意：如果已经做了 SVD 缓存 (already_reduced)，则不需要再做 Z-score，
        # 因为缓存的数据应该是已经处理过的。
        # 但这里的 already_reduced 检查是在后面做的。
        # 我们需要在 SVD 逻辑之前应用 Z-score。

        # 为了避免对缓存数据重复 Z-score，我们需要先判断是否命中缓存。
        # 只有当需要重新计算 (apply_svd or not already_reduced) 时才应用 Z-score

        if apply_svd is None:
            apply_svd = self.svd_for_clustering
        if svd_retain is None:
            svd_retain = self.svd_retain

        # 检查特征是否已经被SVD降维过
        # 逻辑：如果svd_for_clustering=True且存在对应的SVD缓存，则说明已经降维过了
        # 因为只有在svd_for_clustering=True时才会设置这些缓存
        already_reduced = False
        if self.svd_for_clustering:
            if clustering_method == "mean":
                # 如果存在SVD降维后的mean缓存，说明已经降维
                if hasattr(self, "_svd_reduced_mean_embeddings") and self._svd_reduced_mean_embeddings is not None:
                    already_reduced = True
            elif clustering_method == "concat":
                # 如果存在SVD降维后的concat缓存，说明已经降维
                if hasattr(self, "_svd_reduced_concat_embeddings") and self._svd_reduced_concat_embeddings is not None:
                    already_reduced = True

        if not already_reduced and self.use_zscore:
            # 应用 Column-wise Z-score
            eps = 1e-6
            mean = np.mean(E, axis=0, keepdims=True)
            std = np.std(E, axis=0, keepdims=True)
            E = (E - mean) / (std + eps)

        # ===== SVD reduction =====
        if apply_svd and not already_reduced:
            E = self._svd_reduce_matrix(
                E,
                retain=float(svd_retain),
                exclude_dominant_ratio=self.svd_exclude_dominant_ratio,
                specific_components=self.svd_specific_components
            )

        # ===== L2 Normalization (controlled by use_euclidean) =====
        # 如果不使用欧氏距离 (即使用 Cosine)，则需要 L2 Normalize
        # 注意：
        # - Tangent-space SVD 内部通常会做 L2 Normalize。
        # - Standard PCA 内部只做 Center，不做 L2 Normalize。
        # 因此，对于 Standard PCA，如果后续用 Cosine 聚类，必须在此处进行归一化。

        # 只要 self.use_euclidean 为 False，无论是否做了 SVD，最终都应该归一化，
        # 以保证后续的点积计算等效于 Cosine 相似度。

        if not already_reduced and not self.use_euclidean:
            # Cosine 模式: 手动 L2 Normalize
            norm = np.linalg.norm(E, axis=1, keepdims=True)
            norm[norm < 1e-12] = 1.0
            E = E / norm

        return E, clustering_method

    def evaluate_on_test_set(
        self,
        anchor_questions: List[int],
        test_models_data: Dict[str, Dict],
        train_models_data: Dict[str, Dict],
        cluster_dict: Optional[Dict[int, List[int]]] = None,
        reference_ids: Optional[List[str]] = None,
        strict_id_alignment: bool = True,
        dynamic_source_selection: bool = False,
        clustering_method: Optional[str] = None,
    ) -> Dict[str, Dict[str, Dict]]:
        """Evaluate target models with APW and anchor-fitted linear extrapolation.

        For every target model (models listed in ``self.train_models`` are
        skipped) the following estimates are produced:

        * ``apw_weighted``: anchor accuracy weighted by
          ``self._calculate_apw_weights``; its ``agreement`` propagates each
          cluster's anchor label to every member of that cluster.
        * ``mlp_logit_avg``: 1D fit on ``s_bar`` (per-source MLP probabilities
          averaged over sources).
        * ``mlp_label_avg``: 1D fit on ``y_bar`` (source 0/1 labels averaged
          over sources).
        * ``mlp_logit_vec``: ND fit on the per-source vector ``L``. Note that
          ``L`` is the very same matrix as the per-source probabilities ``P``
          returned by ``self._get_concat_mlp_probs_per_source``.
        * ``mlp_label_vec``: ND fit on the per-source 0/1 label matrix.
        * ``question_embed_vec``: ND fit on the question embedding matrix
          returned by ``self._get_question_embedding_matrix``.

        Each linear variant is fitted on the anchor questions only, predicts
        all questions, clips predictions to [0, 1], overwrites the anchor
        positions with the true target labels, and reports the mean as the
        estimated accuracy. ``agreement`` is the fraction of questions whose
        thresholded (>= 0.5) prediction equals the true label. Handling of
        one-class anchor labels is delegated to ``self._fit_anchor_linear_1d``
        / ``self._fit_anchor_linear_nd`` and surfaced through the
        ``one_class`` / ``const_p`` entries of their metadata.

        Args:
            anchor_questions: Indices of the anchor questions.
            test_models_data: Per-model dict with a ``"labels"`` tensor and an
                optional ``"ids"`` entry.
            train_models_data: Source-model data, forwarded to
                ``self._get_source_labels_matrix``.
            cluster_dict: Cluster id -> member question indices. Required as
                soon as at least one target model is evaluated.
            reference_ids: Reference question ids used for the alignment check.
            strict_id_alignment: Raise on an id mismatch when True, otherwise
                only print a warning.
            dynamic_source_selection: Not used by this implementation.
            clustering_method: Forwarded to
                ``self._get_question_embedding_matrix``.

        Returns:
            ``{method_key: {model_name: metrics_dict}}``.

        Raises:
            RuntimeError: On dimension mismatches, empty ``anchor_questions``,
                missing APW weights, or a prediction of unexpected length.
            ValueError: On id misalignment when ``strict_id_alignment`` is True.
        """
        # naive_mean and tailoredbench_scaling are intentionally not produced here.
        results: Dict[str, Dict[str, Dict]] = {"apw_weighted": {}}
        for k in self._variant_method_keys():
            results[k] = {}

        # APW weights (cluster-size ratios).
        anchor_weights = None
        if cluster_dict:
            anchor_weights = self._calculate_apw_weights(anchor_questions, cluster_dict)

        # Per-source MLP probabilities and their mean over sources.
        P = self._get_concat_mlp_probs_per_source()  # (n_q, n_src)
        s_bar = P.mean(axis=1).astype(np.float64)    # (n_q,)

        # Per-source vector feature: reuse P to avoid a second inference pass.
        L = P  # (n_q, n_src)

        # Source 0/1 labels and their mean over sources.
        Y_src = self._get_source_labels_matrix(train_models_data)  # (n_q, n_src)
        y_bar = Y_src.mean(axis=1).astype(np.float64)              # (n_q,)

        # Question embedding matrix; its content depends on the clustering method.
        E, embed_cm = self._get_question_embedding_matrix(clustering_method=clustering_method)

        n_q = int(s_bar.shape[0])

        # Consistency checks: features, labels and source-model count must line up.
        n_src = len(self.train_models)
        if P.shape[1] != n_src:
            raise RuntimeError(
                f"MLP probability source dimension mismatch: P.shape[1]={P.shape[1]} vs "
                f"len(self.train_models)={n_src}. This indicates feature/label misalignment."
            )
        if Y_src.shape[1] != n_src:
            raise RuntimeError(
                f"Source label dimension mismatch: Y_src.shape[1]={Y_src.shape[1]} vs "
                f"len(self.train_models)={n_src}. This indicates feature/label misalignment."
            )

        if int(E.shape[0]) != n_q:
            raise RuntimeError(
                f"question embedding row mismatch: E has {int(E.shape[0])} rows, expected n_q={n_q}"
            )
        embed_dim = int(E.shape[1])
        anchor_qs = np.asarray(anchor_questions, dtype=np.int64).reshape(-1)
        if anchor_qs.size == 0:
            raise RuntimeError("anchor_questions is empty")

        # The cluster -> anchor assignment does not depend on the target model,
        # so resolve it once here instead of rescanning anchors per model.
        apw_cluster_anchors = []
        if anchor_weights is not None:
            for member_indices in cluster_dict.values():
                members = np.asarray(member_indices).reshape(-1)
                hits = np.flatnonzero(np.isin(members, anchor_qs))
                if hits.size:
                    # First member (in list order) that is an anchor.
                    apw_cluster_anchors.append((member_indices, members[hits[0]]))

        def _extrapolate(method_key, fit_fn, feats_all, y_all, y_anchor):
            """Fit on anchors, predict every question, return (estimate, agreement, meta)."""
            pred_fn, meta = fit_fn(feats_all[anchor_qs], y_anchor)
            p_all = np.asarray(pred_fn(feats_all), dtype=np.float64).reshape(-1)
            if p_all.shape[0] != n_q:
                raise RuntimeError(f"{method_key} pred shape mismatch: got {p_all.shape}, expected ({n_q},)")
            pred_prob = np.clip(p_all, 0.0, 1.0)
            # Anchors are observed, so their true labels replace the predictions.
            pred_prob[anchor_qs] = y_anchor.astype(np.float64)
            estimate = float(np.mean(pred_prob))
            pred_labels = (pred_prob >= 0.5).astype(np.int8)
            return estimate, float(np.mean(pred_labels == y_all)), meta

        # (result key, fit function, full feature array, variant tag, feature tag)
        source_variants = (
            ("mlp_logit_avg", self._fit_anchor_linear_1d, s_bar, "anchor_linear_1d", "src_mlp_prob_mean"),
            ("mlp_label_avg", self._fit_anchor_linear_1d, y_bar, "anchor_linear_1d", "src_label_mean"),
            ("mlp_logit_vec", self._fit_anchor_linear_nd, L, "anchor_linear_nd", "src_mlp_logit_vec"),
            ("mlp_label_vec", self._fit_anchor_linear_nd, Y_src, "anchor_linear_nd", "src_label_vec"),
        )

        for model_name, model_data in test_models_data.items():
            if model_name in self.train_models:
                continue

            test_labels = model_data["labels"]
            test_ids = model_data.get("ids", None)

            # ----- ID alignment check -----
            if reference_ids is not None and test_ids is not None:
                # Best-effort conversion to a plain list; on failure the ids
                # are compared in whatever form they arrived.
                try:
                    if hasattr(test_ids, "detach"):
                        test_ids = test_ids.detach().cpu().tolist()
                    elif hasattr(test_ids, "tolist"):
                        test_ids = test_ids.tolist()
                except Exception:
                    pass

                if len(test_ids) != len(reference_ids):
                    msg = f"测试模型 {model_name} 的ID长度 ({len(test_ids)}) 与训练集参考ID长度 ({len(reference_ids)}) 不匹配"
                    if strict_id_alignment:
                        raise ValueError(msg)
                    console.print(f"[yellow]警告: {msg}[/yellow]")
                elif list(test_ids) != list(reference_ids):
                    msg = f"测试模型 {model_name} 的题目ID顺序与训练集参考ID不一致"
                    if strict_id_alignment:
                        raise ValueError(msg)
                    console.print(f"[yellow]警告: {msg}，可能导致评估错误[/yellow]")

            # Target labels: all questions / anchors only.
            y_all = test_labels.detach().cpu().numpy().astype(np.int8, copy=False)
            y_anchor = y_all[anchor_qs].astype(np.int8, copy=False)

            # ----- APW weighted mean -----
            if anchor_weights is None:
                raise RuntimeError("APW weighted requires cluster_dict for anchor weights")

            anchor_labels_t = test_labels[torch.tensor(anchor_questions, dtype=torch.long, device=test_labels.device)]
            weighted_accuracy = self._calculate_weighted_anchor_accuracy(anchor_labels_t, anchor_weights)

            # APW agreement: every member of a cluster inherits the true label
            # of that cluster's anchor; clusters without an anchor stay at 0.
            pred_labels_apw = np.zeros(n_q, dtype=np.int8)
            for member_indices, anchor_in_cluster in apw_cluster_anchors:
                pred_labels_apw[member_indices] = y_all[anchor_in_cluster]
            agreement_apw = float(np.mean(pred_labels_apw == y_all))

            results["apw_weighted"][model_name] = {
                "estimated_model_accuracy": float(weighted_accuracy),
                "n_anchor": int(anchor_qs.size),
                "agreement": agreement_apw,
            }

            # ----- Anchor-fitted linear extrapolation on source-side features -----
            for method_key, fit_fn, feats_all, variant_tag, feature_tag in source_variants:
                estimate, agreement, meta = _extrapolate(method_key, fit_fn, feats_all, y_all, y_anchor)
                results[method_key][model_name] = {
                    "estimated_model_accuracy": estimate,
                    "n_anchor": int(anchor_qs.size),
                    "variant": variant_tag,
                    "feature": feature_tag,
                    "one_class_smoothing": bool(meta.get("one_class", False)),
                    "const_p": meta.get("const_p", None),
                    "coef": meta.get("coef", None),
                    "intercept": meta.get("intercept", None),
                    "agreement": agreement,
                }

            # ----- Anchor-fitted linear extrapolation on question embeddings -----
            estimate5, agreement5, meta5 = _extrapolate(
                "question_embed_vec", self._fit_anchor_linear_nd, E, y_all, y_anchor
            )

            # Only the L2 norm of the coefficient vector is stored, not the vector itself.
            coef = meta5.get("coef", None)
            coef_l2 = None
            if coef is not None:
                coef_l2 = float(np.linalg.norm(coef))

            results["question_embed_vec"][model_name] = {
                "estimated_model_accuracy": estimate5,
                "n_anchor": int(anchor_qs.size),
                "variant": "anchor_linear_nd",
                "feature": f"question_embedding_{embed_cm}",
                "embed_dim": int(embed_dim),
                "one_class_smoothing": bool(meta5.get("one_class", False)),
                "const_p": meta5.get("const_p", None),
                "intercept": meta5.get("intercept", None),
                "coef_l2": coef_l2,
                "agreement": agreement5,
            }
        return results

    def _process_single_combo_with_shared(
        self,
        combo_idx: int,
        train_labels_dict: Dict,
        test_labels_dict: Dict,
        train_models: List[str],
        test_models: List[str],
        layer_type: str,
        operate_type: str,
        anchor_points: List[int],
        true_accuracies: Dict[str, float] = None,
        strict_id_alignment: bool = True,
        num_runs_per_anchor: int = 1,
        test_features_for_similarity: bool = False,
        shared_feature_data: Dict[str, Dict] = None,
        exclude_extreme_samples: bool = False,
        exclude_extreme_samples_predict_all: bool = False,
        dynamic_source_selection: bool = False,
        family: str = None,
        strength: str = None,
    ) -> Dict:
        """处理单个训练/测试组合 - 重写以使用概率聚类

        exclude_extreme_samples_predict_all 模式：
        - 聚类时：排除极端样本
        - 外推时：使用全量数据（包括极端样本），回归头预测所有样本
        - 不做 R/E 加权合并
        """
        try:
            mode_str = "(predict_all模式)" if exclude_extreme_samples_predict_all else "(基于概率聚类)"
            console.print(f"[cyan]处理组合 {combo_idx} {mode_str}:[/cyan]")
            console.print(f"  训练模型数: {len(train_models)}, 测试模型数: {len(test_models)}")

            # 构建训练数据
            train_models_data = {}
            for model_name in train_models:
                if model_name in shared_feature_data:
                    shared_data = shared_feature_data[model_name]
                    shared_features_tensor = shared_data["features"]
                    shared_labels_tensor = shared_data["labels"]

                    features = shared_features_tensor.to(self.device)

                    train_models_data[model_name] = {
                        "features": features,
                        "labels": shared_labels_tensor.to(self.device),
                        "ids": shared_data["ids"],
                    }

            # 收集输入维度
            in_dims = set()
            for model_name, model_data in train_models_data.items():
                in_dims.add(model_data["features"].shape[1])
            in_dims = list(in_dims)

            # 加载MLP模型
            self.train_models = train_models
            mlp_model_path = self.get_mlp_model_path(combo_idx, layer_type, operate_type, self.use_expand, self.lambda_max, self.num_models, self.model_strategy, family, strength)

            if not os.path.exists(mlp_model_path):
                console.print(f"[red]错误: 未找到预训练的MLP模型: {mlp_model_path}[/red]")
                raise FileNotFoundError(f"MLP model not found: {mlp_model_path}")

            mlp_model = self.load_trained_mlp(mlp_model_path, in_dims)
            self._current_mlp_model = mlp_model

            # 计算训练集模型嵌入（用于MLP概率计算）
            console.print(f"[yellow]计算训练集模型嵌入 (combo{combo_idx})[/yellow]")
            train_embeddings = {}
            keep_on_gpu = self.device.startswith("cuda") and torch.cuda.is_available()

            for model_name in train_models:
                if model_name in train_models_data:
                    features = train_models_data[model_name]["features"]
                    embeddings = self.get_embeddings_with_mlp(mlp_model, features, keep_on_gpu)
                    train_embeddings[model_name] = embeddings

            # 数据校验
            first_ids = None
            for model_name, model_data in train_models_data.items():
                ids_list = model_data["ids"]
                if first_ids is None:
                    first_ids = ids_list
                else:
                    assert first_ids == ids_list, f"训练模型 {model_name} 的题目ID顺序不一致"

            for model_name, model_data in train_models_data.items():
                assert model_data["labels"].shape[0] == len(model_data["ids"])

            for model_name, model_data in train_models_data.items():
                u = torch.unique(model_data["labels"])
                assert set(u.tolist()).issubset({0, 1}), f"{model_name} 的 labels 存在非0/1值"

            reference_ids = first_ids

            # 重置本次组合的 exclude 统计（避免沿用上一次）
            self._current_exclude_stats = None

            # 识别和过滤极端样本（如果启用任一模式）
            filtered_train_embeddings = train_embeddings
            valid_indices = None  # 用于保存训练集过滤时使用的有效索引

            # 保存全量数据的副本（用于 predict_all 模式）
            import copy
            train_models_data_full = {k: dict(v) for k, v in train_models_data.items()}
            train_embeddings_full = dict(train_embeddings)
            reference_ids_full = first_ids

            # 两种模式都需要识别极端样本用于聚类过滤
            should_filter_for_clustering = exclude_extreme_samples or exclude_extreme_samples_predict_all

            if should_filter_for_clustering:
                console.print(f"[yellow]识别极端样本 (combo{combo_idx})[/yellow]")
                extreme_samples = self.identify_extreme_samples(train_models_data)

                if extreme_samples:
                    # 在过滤任何数据之前，先使用完整的标签计算极端样本准确率
                    first_model = list(train_models_data.keys())[0]
                    labels_full = train_models_data[first_model]["labels"]
                    if hasattr(labels_full, "detach"):
                        labels_full_np = labels_full.detach().cpu().numpy()
                    elif hasattr(labels_full, "cpu"):
                        labels_full_np = labels_full.cpu().numpy()
                    else:
                        labels_full_np = np.array(labels_full)

                    extreme_indices = np.array(list(extreme_samples), dtype=int)
                    extreme_acc_fixed = 0.0
                    if extreme_indices.size > 0:
                        n_all_correct = np.sum(labels_full_np[extreme_indices] == 1)
                        n_all_wrong = np.sum(labels_full_np[extreme_indices] == 0)
                        denom = n_all_correct + n_all_wrong
                        extreme_acc_fixed = float(n_all_correct / denom) if denom > 0 else 0.0

                    console.print(f"[yellow]过滤极端样本 (combo{combo_idx})[/yellow]")
                    filtered_train_embeddings, valid_indices = self.filter_extreme_samples(train_embeddings, extreme_samples)

                    # 同时过滤训练模型数据中的 labels / features / ids，确保对齐
                    for model_name in train_models_data:
                        # labels
                        labels = train_models_data[model_name]["labels"]
                        if hasattr(labels, "detach"):
                            labels_np = labels.detach().cpu().numpy()
                        elif hasattr(labels, "cpu"):
                            labels_np = labels.cpu().numpy()
                        else:
                            labels_np = np.array(labels)

                        filtered_labels = labels_np[valid_indices]
                        dev = labels.device if hasattr(labels, "device") else self.device
                        train_models_data[model_name]["labels"] = torch.as_tensor(filtered_labels, device=dev)

                        # features（若存在）
                        feats = train_models_data[model_name].get("features", None)
                        if feats is not None:
                            idx_device = torch.from_numpy(valid_indices).to(feats.device)
                            train_models_data[model_name]["features"] = feats[idx_device]

                        # ids
                        original_ids = train_models_data[model_name]["ids"]
                        if hasattr(original_ids, "detach"):
                            original_ids = original_ids.detach().cpu().tolist()
                        elif hasattr(original_ids, "tolist"):
                            original_ids = original_ids.tolist()

                        filtered_ids = [original_ids[i] for i in valid_indices]
                        train_models_data[model_name]["ids"] = filtered_ids

                    # 更新 reference_ids 为过滤后的 ids
                    reference_ids = filtered_ids

                    # 计算 R/E 权重，并保存供后续结果合并使用
                    N_R = int(len(valid_indices))
                    N_E = int(len(extreme_samples))
                    N_total = N_R + N_E
                    w_R = (N_R / N_total) if N_total > 0 else 1.0
                    w_E = (N_E / N_total) if N_total > 0 else 0.0

                    self._current_exclude_stats = {
                        "w_R": float(w_R),
                        "w_E": float(w_E),
                        "extreme_acc_fixed": float(extreme_acc_fixed),
                    }
                else:
                    console.print(f"[yellow]未发现极端样本，跳过过滤 (combo{combo_idx})[/yellow]")

            # 准备concat特征（强制按self.train_models顺序，确保与标签矩阵列顺序一致）
            model_names = list(self.train_models)
            missing = [m for m in model_names if m not in filtered_train_embeddings]
            if missing:
                raise RuntimeError(f"filtered_train_embeddings missing models: {missing}")

            # 保存原始32维embeddings副本（供MLP clf获取logits使用）
            import copy
            original_train_embeddings = copy.deepcopy(filtered_train_embeddings)

            # 初始化SVD rank变量
            per_model_svd_ranks = {}  # 记录每个模型的SVD rank

            # =========================================================
            # 新逻辑：SVD 在 mean/concat 之前，对每个模型的 embeddings 单独做 SVD
            # =========================================================
            if self.svd_for_clustering:
                console.print("[yellow]应用SVD降维到每个模型的embeddings (before mean/concat)...[/yellow]")
                if self.svd_exclude_dominant_ratio is not None:
                    console.print(f"[yellow]SVD模式: 排除解释方差 > {self.svd_exclude_dominant_ratio} 的主分量[/yellow]")

                # 对每个模型的 embeddings 单独做 SVD
                svd_processed_embeddings = {}
                for m in model_names:
                    emb = filtered_train_embeddings[m]  # (n_q, 32)
                    E_m = emb.cpu().numpy()
                    original_dim = E_m.shape[1]

                    E_m_reduced, rank_m = self._svd_reduce_matrix(
                        E_m,
                        retain=self.svd_retain,
                        return_rank=True,
                        exclude_dominant_ratio=self.svd_exclude_dominant_ratio,
                        specific_components=self.svd_specific_components
                    )

                    per_model_svd_ranks[m] = rank_m

                    # 转回 tensor
                    device = emb.device
                    svd_processed_embeddings[m] = torch.from_numpy(E_m_reduced).to(device).float()

                # 打印 SVD 统计信息
                ranks = list(per_model_svd_ranks.values())
                console.print(f"[green]SVD降维完成 (per-model):[/green]")
                console.print(f"  原始维度: 32d")
                console.print(f"  处理后维度: 32d (重构回原维度)")
                console.print(f"  各模型保留的主成分数: min={min(ranks)}, max={max(ranks)}, avg={np.mean(ranks):.1f}")

                # 用 SVD 处理后的 embeddings 替换原始的
                filtered_train_embeddings = svd_processed_embeddings

                # 现在用 SVD 处理后的 embeddings 计算 mean/concat 表征
                # mean 表征: (n_q, 32)
                mean_embeddings = torch.stack([filtered_train_embeddings[m] for m in model_names], dim=1).mean(dim=1)

                # concat 表征: (n_q, n_models, 32) - 保持3D格式供后续使用
                concat_features_tensor = torch.stack([filtered_train_embeddings[m] for m in model_names], dim=1)

                self._current_encoded_features = mean_embeddings
                self._current_encoded_features_concat = concat_features_tensor

                # 缓存原始32维concat特征供MLP使用（MLP需要32维输入）
                # 注意：这里使用原始的 embeddings，因为 MLP 是在原始特征上训练的
                original_concat_features_tensor = torch.stack([original_train_embeddings[m] for m in model_names], dim=1)
                self._current_encoded_features_concat_for_mlp = original_concat_features_tensor

                # 标记已经做过 SVD（用于 _get_question_embedding_matrix 中跳过重复处理）
                self._svd_reduced_mean_embeddings = mean_embeddings
                self._svd_reduced_concat_embeddings = concat_features_tensor

                console.print("[green]SVD处理后的表征已缓存，将用于聚类和外推[/green]")
            else:
                # 不使用SVD，保持原始处理
                concat_features_tensor = torch.stack([filtered_train_embeddings[m] for m in model_names], dim=1)  # (n_q, n_models, 32)
                mean_embeddings_tensor = concat_features_tensor.mean(dim=1)  # (n_q, 32)

                self._current_encoded_features = mean_embeddings_tensor
                self._current_encoded_features_concat = concat_features_tensor
                self._current_encoded_features_concat_for_mlp = concat_features_tensor
                self._svd_reduced_mean_embeddings = None
                self._svd_reduced_concat_embeddings = None

            # 基于embeddings计算题目表征（使用基类的compute_question_representations）
            console.print(f"[yellow]计算基于embeddings的题目表征 (combo{combo_idx})[/yellow]")

            # 如果启用SVD，直接使用降维后的表征构建question_reps
            if self.svd_for_clustering and hasattr(self, '_svd_reduced_mean_embeddings') and self._svd_reduced_mean_embeddings is not None:
                # 手动构建降维后的question_reps_mean（跳过compute_question_representations）
                q_ids = list(range(len(self._svd_reduced_mean_embeddings)))

                feats_mean = self._svd_reduced_mean_embeddings
                # 如果不使用欧氏距离 (Cosine模式)，需要手动归一化 SVD 的结果
                if not self.use_euclidean:
                    feats_mean = torch.nn.functional.normalize(feats_mean, p=2, dim=1)
                    dist_mode = "Cosine"
                else:
                    dist_mode = "L2"

                question_reps_mean = {i: feats_mean[i] for i in q_ids}
                console.print(f"[green]使用SVD降维后的表征构建question_reps (mean方式，{dist_mode}距离聚类)[/green]")
            else:
                # 使用原始方法
                question_reps_mean = self.compute_question_representations(filtered_train_embeddings, method="mean")

            # 同样处理concat方式（如果启用）
            if self.svd_for_clustering and hasattr(self, '_svd_reduced_concat_embeddings') and self._svd_reduced_concat_embeddings is not None:
                q_ids = list(range(len(self._svd_reduced_concat_embeddings)))

                feats_concat = self._svd_reduced_concat_embeddings
                # Concat 模式需要 Flatten: (N, M, D) -> (N, M*D)
                if feats_concat.dim() == 3:
                    n_samples, n_models, d_dim = feats_concat.shape
                    feats_concat = feats_concat.reshape(n_samples, n_models * d_dim)

                # 如果不使用欧氏距离 (Cosine模式)，需要手动归一化 SVD 的结果
                if not self.use_euclidean:
                    feats_concat = torch.nn.functional.normalize(feats_concat, p=2, dim=1)
                    dist_mode = "Cosine"
                else:
                    dist_mode = "L2"

                question_reps_concat = {i: feats_concat[i] for i in q_ids}
                console.print(f"[green]使用SVD降维后的表征构建question_reps (concat方式，{dist_mode}距离聚类)[/green]")
            else:
                # 使用原始方法构建concat题目表征
                question_reps_concat = self.compute_question_representations(filtered_train_embeddings, method="concat")

            # 构建 Fused 题目表征 (Mean + Concat)
            console.print(f"[yellow]构建 Fused 题目表征 (combo{combo_idx})[/yellow]")
            # 直接调用 _get_question_embedding_matrix 获取处理好的 numpy 数组
            E_fused, _ = self._get_question_embedding_matrix(clustering_method="fused")
            E_fused_tensor = torch.from_numpy(E_fused).to(self.device)
            # 修改为使用欧氏距离：不进行L2归一化
            # E_fused_norm = torch.nn.functional.normalize(E_fused_tensor, p=2, dim=1)
            q_ids = list(range(len(E_fused_tensor)))
            question_reps_fused = {i: E_fused_tensor[i] for i in q_ids}
            console.print("[green]Fused 题目表征构建完成 (L2距离聚类)[/green]")

            # 加载测试数据
            console.print(f"[yellow]加载测试数据 (combo{combo_idx})[/yellow]")
            test_models_data_full = {}
            for model_name in test_models:
                if model_name in shared_feature_data:
                    shared_data = shared_feature_data[model_name]

                    if test_features_for_similarity:
                        features = shared_data["features"].to(self.device)
                    else:
                        features = None

                    test_models_data_full[model_name] = {
                        "features": features,
                        "labels": shared_data["labels"].to(self.device),
                        "ids": shared_data["ids"],
                    }


            # 同步过滤测试集的极端样本（仅在 exclude_extreme_samples 模式下）
            # predict_all 模式不过滤测试集，使用全量数据
            if exclude_extreme_samples and not exclude_extreme_samples_predict_all and valid_indices is not None:
                console.print(f"[yellow]对测试集同步过滤极端样本 (combo{combo_idx})[/yellow]")
                for model_name in test_models_data_full:
                    # labels
                    labels = test_models_data_full[model_name]["labels"]
                    if hasattr(labels, "detach"):
                        labels_np = labels.detach().cpu().numpy()
                    elif hasattr(labels, "cpu"):
                        labels_np = labels.cpu().numpy()
                    else:
                        labels_np = np.array(labels)

                    filtered_labels = labels_np[valid_indices]
                    dev = labels.device if hasattr(labels, "device") else self.device
                    test_models_data_full[model_name]["labels"] = torch.as_tensor(filtered_labels, device=dev)

                    # features（若存在）
                    feats = test_models_data_full[model_name].get("features", None)
                    if feats is not None:
                        idx_device = torch.from_numpy(valid_indices).to(feats.device)
                        test_models_data_full[model_name]["features"] = feats[idx_device]

                    # ids
                    original_ids = test_models_data_full[model_name]["ids"]
                    if hasattr(original_ids, "detach"):
                        original_ids = original_ids.detach().cpu().tolist()
                    elif hasattr(original_ids, "tolist"):
                        original_ids = original_ids.tolist()

                    test_models_data_full[model_name]["ids"] = [original_ids[i] for i in valid_indices]

                # 为合并结果准备：极端样本部分的固定准确率
                if self._current_exclude_stats is not None:
                    extreme_acc_fixed = float(self._current_exclude_stats.get("extreme_acc_fixed", 0.0))
                    self._current_exclude_stats["acc_on_excluded"] = {
                        m: extreme_acc_fixed for m in test_models_data_full.keys() if m not in train_models
                    }

            # ===== predict_all 模式：需要为外推准备全量数据 =====
            if exclude_extreme_samples_predict_all and valid_indices is not None:
                console.print(f"[yellow]predict_all模式：外推将使用全量数据 (combo{combo_idx})[/yellow]")
                # 保存全量数据供外推使用
                self._train_models_data_full = train_models_data_full
                self._train_embeddings_full = train_embeddings_full
                self._reference_ids_full = reference_ids_full

                # ---------------------------------------------------------
                # 关键修复：在评估前，必须将 self._current_* 特征变量切换为全量数据的特征
                # 这样 evaluate_on_test_set 中的线性回归才能在全量数据上进行预测
                # ---------------------------------------------------------

                # 1. 准备全量 embeddings (如果启用了SVD，需要对全量数据做SVD)
                full_embeddings_map = self._train_embeddings_full
                model_names = list(self.train_models)

                if self.svd_for_clustering:
                    # 对全量数据应用相同的SVD参数
                    # 注意：这里是独立对全量数据做SVD，以获得降维后的特征
                    full_svd_embeddings = {}
                    for m in model_names:
                        emb = full_embeddings_map[m] # tensor
                        E_m = emb.cpu().numpy()

                        E_m_reduced = self._svd_reduce_matrix(
                            E_m,
                            retain=self.svd_retain,
                            exclude_dominant_ratio=self.svd_exclude_dominant_ratio,
                            specific_components=self.svd_specific_components
                        )
                        full_svd_embeddings[m] = torch.from_numpy(E_m_reduced).to(emb.device).float()

                    current_full_embeddings = full_svd_embeddings

                    # 计算 SVD 后的 mean/concat
                    full_mean_emb = torch.stack([current_full_embeddings[m] for m in model_names], dim=1).mean(dim=1)
                    full_concat_emb = torch.stack([current_full_embeddings[m] for m in model_names], dim=1)

                    # 欺骗 _get_question_embedding_matrix，让其认为已经做过 SVD
                    self._svd_reduced_mean_embeddings = full_mean_emb
                    self._svd_reduced_concat_embeddings = full_concat_emb

                else:
                    # 无 SVD
                    current_full_embeddings = full_embeddings_map
                    full_mean_emb = torch.stack([current_full_embeddings[m] for m in model_names], dim=1).mean(dim=1)
                    full_concat_emb = torch.stack([current_full_embeddings[m] for m in model_names], dim=1)

                    self._svd_reduced_mean_embeddings = None
                    self._svd_reduced_concat_embeddings = None

                # 2. 更新供 evaluate_on_test_set 使用的特征变量
                self._current_encoded_features = full_mean_emb
                self._current_encoded_features_concat = full_concat_emb

                # 3. 更新供 MLP 使用的原始特征 (必须是原始维度，无 SVD)
                full_original_concat = torch.stack([full_embeddings_map[m] for m in model_names], dim=1)
                self._current_encoded_features_concat_for_mlp = full_original_concat

                console.print(f"[green]predict_all模式：已切换至全量特征数据 (n={full_mean_emb.shape[0]})[/green]")

            # 评估不同锚点数量
            console.print("[green]评估不同锚点数量的性能[/green]")
            anchor_size_results = {}

            for n_anchors in anchor_points:
                console.print(f"[cyan]测试锚点数量: {n_anchors} (combo{combo_idx})[/cyan]")

                metrics_runs_mean = []
                anchor_questions_mean_list = []
                eval_results_mean_runs = []

                # -------------------------
                # mean 表征进行聚类与评估
                # -------------------------
                for run_idx in range(num_runs_per_anchor):
                    if num_runs_per_anchor > 1:
                        console.print(f"[yellow]  [mean] 运行 {run_idx + 1}/{num_runs_per_anchor}[/yellow]")

                    cluster_dict_mean, kmeans_model_mean = self.cluster_questions(question_reps_mean, n_anchors)

                    anchor_questions_mean = self.select_anchor_questions(
                        question_reps_mean,
                        cluster_dict_mean,
                        kmeans_model_mean,
                        n_anchors,
                    )

                    # ===== predict_all 模式：映射锚点索引并使用全量数据 =====
                    if exclude_extreme_samples_predict_all and valid_indices is not None:
                        # 将过滤后的锚点索引映射回全量数据索引
                        anchor_questions_mean_mapped = [int(valid_indices[i]) for i in anchor_questions_mean]
                        # 映射 cluster_dict 中的索引
                        cluster_dict_mean_mapped = {}
                        for cid, members in cluster_dict_mean.items():
                            cluster_dict_mean_mapped[cid] = [int(valid_indices[i]) for i in members]

                        # 使用全量数据进行评估
                        eval_train_data = self._train_models_data_full
                        eval_reference_ids = self._reference_ids_full
                        eval_anchor_questions = anchor_questions_mean_mapped
                        eval_cluster_dict = cluster_dict_mean_mapped
                    else:
                        eval_train_data = train_models_data
                        eval_reference_ids = reference_ids
                        eval_anchor_questions = anchor_questions_mean
                        eval_cluster_dict = cluster_dict_mean

                    eval_results_mean = self.evaluate_on_test_set(
                        eval_anchor_questions,
                        test_models_data_full,
                        train_models_data=eval_train_data,
                        cluster_dict=eval_cluster_dict,
                        reference_ids=eval_reference_ids,
                        strict_id_alignment=strict_id_alignment,
                        dynamic_source_selection=dynamic_source_selection,
                        clustering_method="mean",
                    )

                    true_accs = true_accuracies if true_accuracies else {}
                    print_method = self._get_print_method(eval_results_mean)
                    metrics_mean = self.calculate_correlation_metrics(eval_results_mean[print_method], true_accs)

                    metrics_runs_mean.append(metrics_mean)
                    anchor_questions_mean_list.append(anchor_questions_mean)
                    eval_results_mean_runs.append(eval_results_mean)

                mean_mean_data = self._aggregate_metrics(metrics_runs_mean)
                mean_mean, method_std_mean = mean_mean_data["mean"], mean_mean_data["method_std"]

                anchor_size_results[n_anchors] = {
                    "mean": {
                        "overall_stats": {
                            "overall_avg_kendall_tau": mean_mean["kendall_tau"],
                            "overall_method_std_kendall_tau": method_std_mean["kendall_tau"],
                            "overall_avg_pearson_r": mean_mean["pearson_r"],
                            "overall_method_std_pearson_r": method_std_mean["pearson_r"],
                            "overall_avg_spearman_r": mean_mean["spearman_r"],
                            "overall_method_std_spearman_r": method_std_mean["spearman_r"],
                            "overall_avg_mae": mean_mean["mae"],
                            "overall_method_std_mae": method_std_mean["mae"],
                            "overall_avg_rmse": mean_mean["rmse"],
                            "overall_method_std_rmse": method_std_mean["rmse"],
                            # 输出 Agreement 统计
                            "overall_avg_agreement": mean_mean.get("agreement", 0.0),
                            "overall_method_std_agreement": method_std_mean.get("agreement", 0.0),
                        },
                        "runs_data": {
                            "anchor_questions": anchor_questions_mean_list,
                            "evaluation_results": eval_results_mean_runs,
                            "metrics_per_run": metrics_runs_mean,
                        },
                    },
                }

                console.print(f"[cyan]combo{combo_idx}-锚点数量 {n_anchors} [mean]:[/cyan]")
                console.print(f"    Kendall Tau: {mean_mean['kendall_tau']:.3f} (method_std: {method_std_mean['kendall_tau']:.3f})")
                console.print(f"    Pearson r: {mean_mean['pearson_r']:.3f} (method_std: {method_std_mean['pearson_r']:.3f})")
                console.print(f"    Spearman r: {mean_mean['spearman_r']:.3f} (method_std: {method_std_mean['spearman_r']:.3f})")
                console.print(f"    MAE: {mean_mean['mae']:.3f} (method_std: {method_std_mean['mae']:.3f})")
                console.print(f"    RMSE: {mean_mean['rmse']:.3f} (method_std: {method_std_mean['rmse']:.3f})")
                # 打印 Agreement
                console.print(f"    Agreement: {mean_mean.get('agreement', 0.0):.3f} (method_std: {method_std_mean.get('agreement', 0.0):.3f})")

                # -------------------------
                # concat 表征进行聚类与评估（如果启用/可用）
                # -------------------------
                if question_reps_concat is not None:
                    metrics_runs_concat = []
                    anchor_questions_concat_list = []
                    eval_results_concat_runs = []

                    for run_idx in range(num_runs_per_anchor):
                        # if num_runs_per_anchor > 1:
                        #     console.print(f"[yellow]  [concat] 运行 {run_idx + 1}/{num_runs_per_anchor}[/yellow]")
                        pass

                        cluster_dict_concat, kmeans_model_concat = self.cluster_questions(question_reps_concat, n_anchors)

                        anchor_questions_concat = self.select_anchor_questions(
                            question_reps_concat,
                            cluster_dict_concat,
                            kmeans_model_concat,
                            n_anchors,
                        )

                        # ===== predict_all 模式：映射锚点索引并使用全量数据 =====
                        if exclude_extreme_samples_predict_all and valid_indices is not None:
                            anchor_questions_concat_mapped = [int(valid_indices[i]) for i in anchor_questions_concat]
                            cluster_dict_concat_mapped = {}
                            for cid, members in cluster_dict_concat.items():
                                cluster_dict_concat_mapped[cid] = [int(valid_indices[i]) for i in members]

                            eval_train_data_c = self._train_models_data_full
                            eval_reference_ids_c = self._reference_ids_full
                            eval_anchor_questions_c = anchor_questions_concat_mapped
                            eval_cluster_dict_c = cluster_dict_concat_mapped
                        else:
                            eval_train_data_c = train_models_data
                            eval_reference_ids_c = reference_ids
                            eval_anchor_questions_c = anchor_questions_concat
                            eval_cluster_dict_c = cluster_dict_concat

                        eval_results_concat = self.evaluate_on_test_set(
                            eval_anchor_questions_c,
                            test_models_data_full,
                            train_models_data=eval_train_data_c,
                            cluster_dict=eval_cluster_dict_c,
                            reference_ids=eval_reference_ids_c,
                            strict_id_alignment=strict_id_alignment,
                            dynamic_source_selection=dynamic_source_selection,
                            clustering_method="concat",
                        )

                        true_accs = true_accuracies if true_accuracies else {}
                        print_method = self._get_print_method(eval_results_concat)
                        metrics_concat = self.calculate_correlation_metrics(eval_results_concat[print_method], true_accs)

                        metrics_runs_concat.append(metrics_concat)
                        anchor_questions_concat_list.append(anchor_questions_concat)
                        eval_results_concat_runs.append(eval_results_concat)

                    concat_mean_data = self._aggregate_metrics(metrics_runs_concat)
                    mean_concat, method_std_concat = concat_mean_data["mean"], concat_mean_data["method_std"]

                    # 注意：保持mean结果结构不变，concat结果新增到 concat 下
                    anchor_size_results[n_anchors]["concat"] = {
                        "overall_stats": {
                            "overall_avg_kendall_tau": mean_concat["kendall_tau"],
                            "overall_method_std_kendall_tau": method_std_concat["kendall_tau"],
                            "overall_avg_pearson_r": mean_concat["pearson_r"],
                            "overall_method_std_pearson_r": method_std_concat["pearson_r"],
                            "overall_avg_spearman_r": mean_concat["spearman_r"],
                            "overall_method_std_spearman_r": method_std_concat["spearman_r"],
                            "overall_avg_mae": mean_concat["mae"],
                            "overall_method_std_mae": method_std_concat["mae"],
                            "overall_avg_rmse": mean_concat["rmse"],
                            "overall_method_std_rmse": method_std_concat["rmse"],
                            "overall_avg_agreement": mean_concat.get("agreement", 0.0),
                            "overall_method_std_agreement": method_std_concat.get("agreement", 0.0),
                        },
                        "runs_data": {
                            "anchor_questions": anchor_questions_concat_list,
                            "evaluation_results": eval_results_concat_runs,
                            "metrics_per_run": metrics_runs_concat,
                        },
                    }

                # -------------------------
                # fused 表征进行聚类与评估
                # -------------------------
                if question_reps_fused is not None:
                    metrics_runs_fused = []
                    anchor_questions_fused_list = []
                    eval_results_fused_runs = []

                    for run_idx in range(num_runs_per_anchor):
                        # if num_runs_per_anchor > 1:
                        #     console.print(f"[yellow]  [fused] 运行 {run_idx + 1}/{num_runs_per_anchor}[/yellow]")
                        pass

                        cluster_dict_fused, kmeans_model_fused = self.cluster_questions(question_reps_fused, n_anchors)

                        anchor_questions_fused = self.select_anchor_questions(
                            question_reps_fused,
                            cluster_dict_fused,
                            kmeans_model_fused,
                            n_anchors,
                        )

                        # ===== predict_all 模式：映射锚点索引并使用全量数据 =====
                        if exclude_extreme_samples_predict_all and valid_indices is not None:
                            anchor_questions_fused_mapped = [int(valid_indices[i]) for i in anchor_questions_fused]
                            cluster_dict_fused_mapped = {}
                            for cid, members in cluster_dict_fused.items():
                                cluster_dict_fused_mapped[cid] = [int(valid_indices[i]) for i in members]

                            eval_train_data_f = self._train_models_data_full
                            eval_reference_ids_f = self._reference_ids_full
                            eval_anchor_questions_f = anchor_questions_fused_mapped
                            eval_cluster_dict_f = cluster_dict_fused_mapped
                        else:
                            eval_train_data_f = train_models_data
                            eval_reference_ids_f = reference_ids
                            eval_anchor_questions_f = anchor_questions_fused
                            eval_cluster_dict_f = cluster_dict_fused

                        eval_results_fused = self.evaluate_on_test_set(
                            eval_anchor_questions_f,
                            test_models_data_full,
                            train_models_data=eval_train_data_f,
                            cluster_dict=eval_cluster_dict_f,
                            reference_ids=eval_reference_ids_f,
                            strict_id_alignment=strict_id_alignment,
                            dynamic_source_selection=dynamic_source_selection,
                            clustering_method="fused",
                        )

                        true_accs = true_accuracies if true_accuracies else {}
                        print_method = self._get_print_method(eval_results_fused)
                        metrics_fused = self.calculate_correlation_metrics(eval_results_fused[print_method], true_accs)

                        metrics_runs_fused.append(metrics_fused)
                        anchor_questions_fused_list.append(anchor_questions_fused)
                        eval_results_fused_runs.append(eval_results_fused)

                    fused_mean_data = self._aggregate_metrics(metrics_runs_fused)
                    mean_fused, method_std_fused = fused_mean_data["mean"], fused_mean_data["method_std"]

                    anchor_size_results[n_anchors]["fused"] = {
                        "overall_stats": {
                            "overall_avg_kendall_tau": mean_fused["kendall_tau"],
                            "overall_method_std_kendall_tau": method_std_fused["kendall_tau"],
                            "overall_avg_pearson_r": mean_fused["pearson_r"],
                            "overall_method_std_pearson_r": method_std_fused["pearson_r"],
                            "overall_avg_spearman_r": mean_fused["spearman_r"],
                            "overall_method_std_spearman_r": method_std_fused["spearman_r"],
                            "overall_avg_mae": mean_fused["mae"],
                            "overall_method_std_mae": method_std_fused["mae"],
                            "overall_avg_rmse": mean_fused["rmse"],
                            "overall_method_std_rmse": method_std_fused["rmse"],
                            "overall_avg_agreement": mean_fused.get("agreement", 0.0),
                            "overall_method_std_agreement": method_std_fused.get("agreement", 0.0),
                        },
                        "runs_data": {
                            "anchor_questions": anchor_questions_fused_list,
                            "evaluation_results": eval_results_fused_runs,
                            "metrics_per_run": metrics_runs_fused,
                        },
                    }

            return {
                "combo_idx": combo_idx,
                "train_models": train_models,
                "test_models": test_models,
                "anchor_size_results": anchor_size_results,
                "svd_ranks": {
                    "per_model": per_model_svd_ranks if self.svd_for_clustering else None,
                    "avg_rank": float(np.mean(list(per_model_svd_ranks.values()))) if self.svd_for_clustering and per_model_svd_ranks else None,
                    "min_rank": int(min(per_model_svd_ranks.values())) if self.svd_for_clustering and per_model_svd_ranks else None,
                    "max_rank": int(max(per_model_svd_ranks.values())) if self.svd_for_clustering and per_model_svd_ranks else None,
                } if self.svd_for_clustering else None,
            }

        except Exception as e:
            console.print(f"[red]处理组合 {combo_idx} 失败: {e}[/red]")
            import traceback
            traceback.print_exc()
            return None

    def run_experiment_parallel(
            self,
            layer_type: str = "last",
            operate_type: str = "last_token",
            anchor_points: List[int] = None,
            num_workers: int = 20,
            strict_id_alignment: bool = True,
            num_runs_per_anchor: int = 1,
            test_features_for_similarity: bool = False,
            exclude_extreme_samples: bool = False,
            exclude_extreme_samples_predict_all: bool = False,
            dynamic_source_selection: bool = False,
            svd_exclude_dominant_ratio: float = None,
            svd_spectral_shrinkage: bool = False,
            svd_specific_components: List[int] = None,
            args=None,
            use_zscore: bool = False,
            use_euclidean: bool = False,
            ridge_alpha: float = 1.0,
            decorr_lambda: float = 0.0,
            model_strategy: str = None,
            family: str = None,
            strength: str = None,
        ) -> None:
            if anchor_points is None:
                anchor_points = [10, 15, 20]

            from baseline import (
                save_multi_method_results_to_excel,
                get_train_test_splits_with_strategy,
                load_model_true_accuracies,
            )

            console.print(f"[bold cyan]开始并行运行 {self.dataset} 数据集的 MLP 外推变体实验[/bold cyan]")
            console.print(f"[green]使用 {num_workers} 个并行进程[/green]")

            label_matrix_path = os.path.join(self.label_data_dir, f"label_matrix_{self.dataset}.json")

            # 使用统一的策略处理接口（支持family/strength参数）
            splits, metadata = get_train_test_splits_with_strategy(
                label_matrix_path=label_matrix_path,
                model_strategy=model_strategy,
                num_models=self.num_models,
                dataset=self.dataset,
                family=family,
                strength=strength,
            )

            # true accuracies
            try:
                true_accuracies_path = os.path.join(self.accuracy_data_dir, f"{self.dataset}_main_experiment_models.xlsx")
                true_accuracies = load_model_true_accuracies(true_accuracies_path)
            except Exception as e:
                console.print(f"[yellow]警告: 无法加载真实准确率数据: {e}[/yellow]")
                true_accuracies = {}

            tasks = []
            for combo_idx, (train_labels_dict, test_labels_dict, train_models, test_models) in enumerate(splits):
                task_args = (
                    combo_idx,
                    train_labels_dict,
                    test_labels_dict,
                    train_models,
                    test_models,
                    self.dataset,
                    self.feats_dir,
                    self.output_dir,
                    self.device,
                    layer_type,
                    operate_type,
                    anchor_points,
                    true_accuracies,
                    self.label_data_dir,
                    self.use_expand,
                    strict_id_alignment,
                    None,  # gpu_id (后面会被覆盖)
                    num_runs_per_anchor,
                    test_features_for_similarity,
                    exclude_extreme_samples,
                    exclude_extreme_samples_predict_all,
                    dynamic_source_selection,
                    self.lambda_max,
                    self.num_models,
                    self.svd_retain,
                    self.svd_for_clustering,
                    svd_exclude_dominant_ratio,
                    self.svd_spectral_shrinkage,
                    svd_specific_components,
                    use_zscore,
                    use_euclidean,
                    ridge_alpha,
                    decorr_lambda,
                    self.model_strategy,
                    family,
                    strength,
                )
                tasks.append(task_args)

            num_gpus = torch.cuda.device_count()
            if num_gpus == 0 and self.device.startswith("cuda"):
                raise RuntimeError("未检测到可用GPU")

            # 为工作进程分配GPU
            gpu_ids = [i % max(1, num_gpus) for i in range(num_workers)] if self.device.startswith("cuda") else [None] * num_workers

            # 预加载共享特征
            console.print("[yellow]预加载特征数据到共享内存...[/yellow]")
            shared_feature_data = self._preload_shared_features(splits, layer_type, operate_type)
            console.print(f"[green]共享内存数据预加载完成，包含 {len(shared_feature_data)} 个模型[/green]")

            ctx = mp.get_context("spawn")
            pid_q = ctx.Queue()
            child_pids = set()

            all_results = []
            completed = 0

            def _kill_children(pids):
                """Best-effort termination of worker processes by PID.

                Sends SIGTERM to every PID, sleeps one second as a grace period,
                then sends a forced kill signal to every PID again. Failures of
                individual ``os.kill`` calls are ignored so that one stale PID
                cannot abort the rest of the cleanup.

                Args:
                    pids: Iterable of process IDs to signal. May be empty, in
                        which case the function returns immediately.
                """
                # Snapshot once so both passes signal exactly the same PIDs.
                targets = list(pids)
                if not targets:
                    # Nothing to signal: do not pay the grace-period sleep.
                    return

                # SIGKILL is not defined on every platform; fall back to SIGTERM
                # so the second pass still sends a signal instead of failing on
                # the attribute lookup.
                force_signal = getattr(signal, "SIGKILL", signal.SIGTERM)

                def _signal_all(sig):
                    for pid in targets:
                        try:
                            os.kill(pid, sig)
                        except Exception:
                            # Deliberate best-effort: the process may already
                            # have exited or may not be ours to signal, and this
                            # runs on the interrupt/cleanup path where raising
                            # would mask the original error.
                            pass

                _signal_all(signal.SIGTERM)
                # Fixed grace period for workers to exit before the forced kill.
                time.sleep(1.0)
                _signal_all(force_signal)

            executor = None
            failed = False
            try:
                executor = ProcessPoolExecutor(
                    max_workers=num_workers,
                    mp_context=ctx,
                    initializer=_child_init_with_shared_local,
                    initargs=(pid_q, shared_feature_data),
                )

                # collect pids
                deadline = time.time() + 5
                while len(child_pids) < num_workers and time.time() < deadline:
                    try:
                        child_pids.add(pid_q.get(timeout=0.2))
                    except Exception:
                        pass

                future_to_task = {}
                for i, task in enumerate(tasks):
                    task = list(task)
                    # 修复: 由于添加了新参数，gpu_id 现在是正向索引 16 (或者倒数第 10 个)
                    # 原来的 -9 会错误覆盖 num_runs_per_anchor
                    task[16] = gpu_ids[i % len(gpu_ids)] if gpu_ids else None
                    task = tuple(task)
                    fut = executor.submit(_process_combo_task_with_shared_variants, task)
                    future_to_task[fut] = task

                for fut in as_completed(future_to_task):
                    while True:
                        try:
                            child_pids.add(pid_q.get_nowait())
                        except Exception:
                            break

                    try:
                        combo_idx, combo_result = fut.result()
                        if combo_result is not None:
                            all_results.append(combo_result)
                            completed += 1
                            console.print(
                                f"[green]进度: {completed}/{len(tasks)} ({completed/len(tasks)*100:.1f}%) - 组合{combo_idx}完成[/green]"
                            )
                        else:
                            console.print(f"[red]任务失败 - 组合{combo_idx}[/red]")
                    except Exception as e:
                        failed = True
                        t = future_to_task[fut]
                        console.print(f"[red]任务异常 - 组合{t[0]}: {e}[/red]")

            except KeyboardInterrupt:
                failed = True
                console.print("[red]检测到键盘中断，正在清理...[/red]")
                if executor:
                    executor.shutdown(wait=False, cancel_futures=True)
                _kill_children(child_pids)
                raise
            finally:
                if executor:
                    try:
                        executor.shutdown(wait=not failed, cancel_futures=failed)
                    except Exception:
                        pass
                if failed and child_pids:
                    _kill_children(child_pids)

            console.print(f"[green]所有并行任务完成! 成功: {len(all_results)}/{len(splits)} 个组合[/green]")
            all_results.sort(key=lambda x: x["combo_idx"])

            # 注意：不传入 metric 参数，使用父类默认值 "cosine"
            # 子类的 _generate_experiment_dir_name 会根据 use_euclidean 追加 _l2 或 _cos
            # 确定实验目录名中的极端样本标记（两种模式都需要标记）
            exclude_flag_for_dir = exclude_extreme_samples or exclude_extreme_samples_predict_all
            experiment_dir_name = self._generate_experiment_dir_name(
                exclude_extreme_samples=exclude_flag_for_dir,
                anchor_points=anchor_points,
                layer_type=layer_type,
                operate_type=operate_type,
            )
            # 如果使用 predict_all 模式，在目录名中添加标记
            # 注意：base类生成的目录名中包含 "exclude" (无下划线)，这里替换为 "exclude_predict_all"
            if exclude_extreme_samples_predict_all:
                experiment_dir_name = experiment_dir_name.replace("exclude", "exclude_predict_all")

            # 在 output_dir 中体现 logistic_C（与 logits 版本保持一致）
            # output_dir 已经在 main() 中包含了 C 值，这里直接使用
            timestamp_dir = os.path.join(self.output_dir, self.dataset, experiment_dir_name)
            os.makedirs(timestamp_dir, exist_ok=True)

            base_excel_filename = f"{self.dataset}.xlsx"
            base_excel_path = os.path.join(timestamp_dir, base_excel_filename)

            # 构建保存的方法列表
            # clustering_methods = ["mean", "concat"]
            clustering_methods = ["mean", "concat", "fused"]

            # 保留这些基线 + 所有变体键
            extrapolation_method_keys = [
                # ("naive_mean", "naive_mean"),  # 注释掉简单平均值基线
                ("apw_weighted", "weight_mean"),
                # ("tailoredbench_scaling", "tailoredbench"),  # 注释掉，该方法未实际运行
            ]
            extrapolation_method_keys += [(k, k) for k in self._variant_method_keys()]

            multi_method_results: Dict[str, Any] = {}
            methods_to_save: List[str] = []

            for clustering_method in clustering_methods:
                for method_key, method_name in extrapolation_method_keys:
                    # 如果方法存在，构建标准结果
                    method_results = self._build_standard_results_from_method(
                        all_results,
                        method_key,
                        anchor_points,
                        num_runs_per_anchor,
                        layer_type,
                        operate_type,
                        metadata,
                        clustering_method=clustering_method,
                    )
                    if method_results:
                        result_key = f"{clustering_method}_{method_name}"
                        # Excel工作表名称限制保护（31字符）
                        if len(result_key) > 31:
                            result_key = result_key[:31]
                        multi_method_results[result_key] = method_results
                        methods_to_save.append(result_key)

            if methods_to_save:
                save_status = save_multi_method_results_to_excel(
                    multi_method_results,
                    base_excel_path,
                    methods_to_save,
                )
                ok = [k for k, v in save_status.items() if v]
                console.print(f"[green]✅ 成功保存 {len(ok)} 个方法到: {timestamp_dir}[/green]")
            else:
                console.print("[yellow]没有可保存的方法结果[/yellow]")

            # ============================================================
            # 保存精简版JSON文件（只保存combo的mean聚类方法结果）
            # ============================================================
            import json
            simplified_json_path = os.path.join(timestamp_dir, f"{base_excel_filename}.json")

            # 只保存每个combo的mean聚类方法的Spearman和MAE均值
            simplified_combo_results = []
            for combo_result in all_results:
                simplified_combo = {
                    "combo_idx": combo_result["combo_idx"],
                    "train_models": combo_result["train_models"],
                }

                # 只保留mean聚类方法，移除concat和fused
                simplified_anchor_results = {}
                for anchor_size, anchor_data in combo_result.get("anchor_size_results", {}).items():
                    # 只保存mean方法
                    if "mean" in anchor_data and "overall_stats" in anchor_data["mean"]:
                        stats = anchor_data["mean"]["overall_stats"]
                        simplified_anchor_results[anchor_size] = {
                            "spearman_r_avg": stats.get("overall_avg_spearman_r"),
                            "mae_avg": stats.get("overall_avg_mae"),
                        }

                simplified_combo["anchor_size_results"] = simplified_anchor_results
                simplified_combo_results.append(simplified_combo)

            simplified_json_data = {
                "dataset": self.dataset,
                "num_samples": metadata.get("num_samples", 0) if "metadata" in locals() else 0,
                "num_combos": len(all_results),
                "anchor_points": anchor_points,
                "num_runs_per_anchor": num_runs_per_anchor,
                "clustering_method": "mean",
                "combo_results": simplified_combo_results,
            }

            with open(simplified_json_path, "w", encoding="utf-8") as f:
                json.dump(simplified_json_data, f, indent=2, ensure_ascii=False)
            console.print(f"[green]✅ 精简JSON结果已保存到: {simplified_json_path}[/green]")

            # 保存实验配置文件
            if args is not None:
                config_data = {
                    "dataset": args.dataset,
                    "feats_dir": args.feats_dir,
                    "output_dir": args.output_dir,
                    "device": args.device,
                    "layer_type": layer_type,
                    "operate_type": operate_type,
                    "anchor_points": anchor_points,
                    "num_workers": args.num_workers,
                    "strict_id_alignment": strict_id_alignment,
                    "num_runs_per_anchor": num_runs_per_anchor,
                    "label_data_dir": args.label_data_dir,
                    "exclude_extreme_samples": exclude_extreme_samples,
                    "exclude_extreme_samples_predict_all": exclude_extreme_samples_predict_all,
                    "test_features_for_similarity": test_features_for_similarity,
                    "dynamic_source_selection": dynamic_source_selection,
                    "metric": "cosine",
                    "experiment_timestamp": experiment_dir_name,
                    "total_combos": len(all_results),
                    "successful_methods": methods_to_save,
                    "svd_for_clustering": self.svd_for_clustering,
                    "svd_retain": self.svd_retain,
                    "ridge_alpha": ridge_alpha,
                    "decorr_lambda": decorr_lambda,
                    "use_zscore": use_zscore,
                    "use_euclidean": use_euclidean,
                }

                config_filename = f"{base_excel_filename}_config.json"
                config_path = os.path.join(timestamp_dir, config_filename)

                with open(config_path, "w", encoding="utf-8") as f:
                    json.dump(config_data, f, indent=2, ensure_ascii=False)
                console.print(f"[green]✅ 实验配置已保存到: {config_path}[/green]")

            # 保存SVD rank信息到JSON文件
            if self.svd_for_clustering:
                svd_ranks_data = []
                for result in all_results:
                    if result.get("svd_ranks") is not None:
                        svd_ranks_data.append({
                            "combo_idx": result["combo_idx"],
                            "train_models": result["train_models"],
                            "per_model_ranks": result["svd_ranks"].get("per_model"),
                            "avg_rank": result["svd_ranks"].get("avg_rank"),
                            "min_rank": result["svd_ranks"].get("min_rank"),
                            "max_rank": result["svd_ranks"].get("max_rank"),
                        })

                if svd_ranks_data:
                    svd_ranks_path = os.path.join(timestamp_dir, f"{self.dataset}_svd_ranks.json")
                    import json
                    with open(svd_ranks_path, "w", encoding="utf-8") as f:
                        json.dump({
                            "dataset": self.dataset,
                            "svd_retain": self.svd_retain,
                            "svd_exclude_dominant_ratio": self.svd_exclude_dominant_ratio,
                            "svd_mode": "per_model_before_aggregation",
                            "num_combos": len(svd_ranks_data),
                            "combos": svd_ranks_data,
                        }, f, indent=2, ensure_ascii=False)
                    console.print(f"[green]✅ 成功保存SVD rank信息到: {svd_ranks_path}[/green]")


def _main_build_parser():
    """Build the command-line parser used by :func:`main`.

    Returns:
        argparse.ArgumentParser: parser with every experiment option registered.
    """
    import argparse

    parser = argparse.ArgumentParser(description="HSBench MLP 外推变体实验（支持多种模型选择策略）")
    parser.add_argument("--dataset", type=str, default="mmlu_pro")
    parser.add_argument(
        "--model_strategy",
        type=str,
        default=None,
        choices=["family_diverse", "strength", "temporal_shift"],
        help="模型选择策略 (可选): "
             "- family_diverse: 模型族多样性组合（3-4个combo） "
             "- strength: 强弱组合（3个combo：最强/中间/最弱） "
             "- temporal_shift: 时序漂移组合（10个combo，最近30个模型作为训练池） "
             "默认: None（使用 train_model_select_num_ablation 的配置）"
    )
    parser.add_argument(
        "--num_models",
        type=int,
        default=10,
        choices=[5, 10, 15, 20],
        help="每个combo的模型数量（用于消融实验，默认: 10）"
    )
    parser.add_argument(
        "--family",
        type=str,
        default=None,
        choices=["llama", "qwen", "yi", "granite", "internvl", "ovis", "llava"],
        help="模型族名称（仅当model_strategy='family_diverse'时有效）"
    )
    parser.add_argument(
        "--strength",
        type=str,
        default=None,
        choices=["strong", "middle", "weak"],
        help="强度等级（仅当model_strategy='strength'时有效）"
    )
    parser.add_argument("--feats_dir", type=str, default=os.path.join(PROJECT_ROOT, "feats"))
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--layer_type", type=str, default="last", choices=["quarter", "middle", "three_quarters", "last", "second_last", "first", "all"], help="隐状态层")
    parser.add_argument("--operate_type", type=str, default="last_token", choices=["prompt_last_token", "answer_first_token", "last_token"], help="特征操作类型")
    parser.add_argument("--anchor_points", type=int, nargs="+", default=[10, 15, 20, 25, 30, 35, 40, 50, 100, 150, 300])
    parser.add_argument("--num_workers", type=int, default=10)
    # Default-on switches: BooleanOptionalAction keeps the original flag working
    # and adds a "--no-<flag>" form, so the option can actually be disabled
    # (a store_true flag with default=True could never become False).
    parser.add_argument("--strict_id_alignment", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--num_runs_per_anchor", type=int, default=50)
    parser.add_argument("--test_features_for_similarity", action="store_true", default=False)
    parser.add_argument("--label_data_dir", type=str, default=None)
    parser.add_argument("--exclude_extreme_samples", action="store_true", default=False,
                        help="排除极端样本（全对/全错）用于聚类和评估，结果做R/E加权合并")
    parser.add_argument("--exclude_extreme_samples_predict_all", action="store_true", default=False,
                        help="排除极端样本仅用于聚类选锚点，但外推时使用全量数据预测，不做R/E加权合并")
    parser.add_argument("--dynamic_source_selection", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use_expand", action="store_true", default=False, help="是否使用expand loss训练的MLP模型（决定从mlp_models还是mlp_models_expand目录加载）(默认: False)")
    parser.add_argument("--lambda_max", type=float, default=None, help="Expand loss的最大权重 (仅在使用use_expand时有效，用于区分不同超参数组合)")
    parser.add_argument("--svd_retain", type=float, default=0.95, help="SVD保留的累计解释方差比例，用于聚类前的embedding降维(默认: 0.95)")
    parser.add_argument(
        "--svd_for_clustering",
        action="store_true",
        default=False,
        help="是否在聚类前对embedding做SVD降维(默认: False，不使用SVD). 使用此标志开启SVD降维",
    )
    parser.add_argument(
        "--svd_exclude_dominant_ratio",
        type=float,
        default=None,
        help="SVD特殊模式：排除解释方差比例大于此阈值的分量，保留其他所有分量(例如0.5)。默认None(不启用)",
    )
    parser.add_argument(
        "--svd_spectral_shrinkage",
        action="store_true",
        default=False,
        help="SVD特殊模式：谱收缩 (Spectral Shrinkage)，根据能量占比自动削弱主成分，不进行硬截断。默认False",
    )
    parser.add_argument(
        "--svd_specific_components",
        type=str,
        default=None,
        help="SVD特殊模式：指定使用的SVD分量（1-based index），例如 '1' 或 '1,3'。默认None",
    )
    parser.add_argument(
        "--use_zscore",
        action="store_true",
        default=False,
        help="是否在聚类/SVD前进行Feature-wise Z-score标准化 (mean=0, std=1)。默认False",
    )
    parser.add_argument(
        "--use_euclidean",
        action="store_true",
        default=False,
        help="是否使用欧氏距离进行聚类和锚点选择（默认使用Cosine相似度）。开启后将跳过最终的L2 Normalization。",
    )
    parser.add_argument(
        "--ridge_alpha",
        type=float,
        default=0.1,
        help="Ridge回归的正则化参数 (默认: 0.1)",
    )
    parser.add_argument(
        "--decorr_lambda",
        type=float,
        default=0.0,
        help="特征去相关权重，用于加载对应的MLP模型 (默认: 0.0)",
    )
    return parser


def _main_validate_args(args):
    """Check cross-option constraints and parse ``--svd_specific_components``.

    Args:
        args: namespace produced by the parser from :func:`_main_build_parser`.

    Returns:
        The list of 1-based SVD component indices parsed from
        ``args.svd_specific_components``, or ``None`` when the option is unset.

    Raises:
        ValueError: if mutually exclusive flags are combined, ``--lambda_max``
            and ``--use_expand`` are inconsistent, ``--svd_retain`` is outside
            (0, 1], or ``--svd_specific_components`` is malformed.
    """
    # 0. The two extreme-sample modes are mutually exclusive.
    if args.exclude_extreme_samples and args.exclude_extreme_samples_predict_all:
        raise ValueError("--exclude_extreme_samples 和 --exclude_extreme_samples_predict_all 互斥，只能选择其一")

    # 1. lambda_max is only meaningful together with use_expand.
    if args.lambda_max is not None and not args.use_expand:
        raise ValueError("指定了 --lambda_max 时，必须同时指定 --use_expand")

    # 2. use_expand requires a strictly positive lambda_max.
    if args.use_expand:
        if args.lambda_max is None:
            raise ValueError("使用 --use_expand 时必须提供 --lambda_max 参数")
        if args.lambda_max <= 0:
            raise ValueError(f"lambda_max 必须大于 0，当前值: {args.lambda_max}")

    # 3. svd_retain must lie in (0, 1].
    if not (0.0 < float(args.svd_retain) <= 1.0):
        raise ValueError(f"svd_retain must be in (0,1], got {args.svd_retain}")

    if not args.svd_specific_components:
        return None

    try:
        # Comma-separated list of positive, 1-based indices.
        components = [int(x.strip()) for x in args.svd_specific_components.split(",")]
        if any(x <= 0 for x in components):
            raise ValueError("SVD components must be positive integers (1-based)")
    except ValueError as e:
        # Keep the original cause attached to the user-facing error.
        raise ValueError(f"Invalid format for --svd_specific_components: {e}") from e
    return components


def _main_default_output_dir(args):
    """Derive the default output directory from the parsed arguments.

    The relative path is built from the model strategy, the model count (only
    outside the family/strength modes, and only when it differs from 10), the
    loss/regularisation hyper-parameters, and finally an optional
    family/strength sub-directory. It is then joined onto ``PROJECT_ROOT``.

    Args:
        args: validated argument namespace.

    Returns:
        str: output directory path under ``PROJECT_ROOT``.
    """
    base_dir = "main_experiment/results/hs_experiment_linear_embedding_exploration"

    family_mode = args.family is not None and args.model_strategy == 'family_diverse'
    strength_mode = args.strength is not None and args.model_strategy == 'strength'

    # The strategy suffix comes first so that hyper-parameter suffixes end up
    # after the strategy and before any family/strength sub-directory.
    if args.model_strategy is not None:
        base_dir += f"_{args.model_strategy}"

    # Model-count suffix (ablation runs) is not used in family/strength modes.
    if not (family_mode or strength_mode):
        if args.num_models is not None and args.num_models != 10:
            base_dir += f"_{args.num_models}models"

    # Hyper-parameter suffix, shared by all modes.
    if args.use_expand:
        # .3g keeps three significant digits of the float.
        lambda_str = f"{args.lambda_max:.3g}"
        base_dir = f"{base_dir}_expand_loss/lambda_{lambda_str}"
    elif args.decorr_lambda > 0:
        lambda_str = f"{args.decorr_lambda:.3g}"
        alpha_str = f"{args.ridge_alpha}"
        base_dir = f"{base_dir}_ridge_alpha{alpha_str}_decorr_loss/lambda_{lambda_str}"
    else:
        # Encode ridge_alpha so that runs with different values do not collide.
        alpha_str = f"{args.ridge_alpha}"
        base_dir = f"{base_dir}_ridge_alpha{alpha_str}"

    if family_mode:
        base_dir = f"{base_dir}/{args.family}"
    elif strength_mode:
        base_dir = f"{base_dir}/{args.strength}"

    return os.path.join(PROJECT_ROOT, base_dir)


def main():
    """Command-line entry point for the MLP extrapolation-variant experiment.

    Parses and validates the CLI options, resolves the output and feature
    directories, runs ``HSExperimentMLPVariants.run_experiment_parallel`` and
    reports the total wall-clock time. ``cleanup_resources`` is always called
    once the experiment object exists, even if the run raises.

    Raises:
        ValueError: if the command-line options fail validation.
    """
    overall_start = time.time()
    console.print("[bold green]=== MLP 外推变体实验开始 ===[/bold green]")
    console.print(f"[cyan]开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(overall_start))}[/cyan]")

    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    args = _main_build_parser().parse_args()
    svd_specific_components_list = _main_validate_args(args)

    # Only derive an output directory when the user did not supply one.
    if args.output_dir is None:
        args.output_dir = _main_default_output_dir(args)

    # NOTE(review): the feature directory is always derived from the dataset
    # name, so any --feats_dir value given on the command line is overwritten
    # here. Kept as-is for compatibility with existing launch scripts.
    args.feats_dir = os.path.join(PROJECT_ROOT, f"feats/{args.dataset}/labeled_data")

    exp = HSExperimentMLPVariants(
        dataset=args.dataset,
        feats_dir=args.feats_dir,
        output_dir=args.output_dir,
        device=args.device,
        label_data_dir=args.label_data_dir,
        use_expand=args.use_expand,
        lambda_max=args.lambda_max,
        num_models=args.num_models,
        svd_retain=args.svd_retain,
        svd_for_clustering=args.svd_for_clustering,
        svd_exclude_dominant_ratio=args.svd_exclude_dominant_ratio,
        svd_spectral_shrinkage=args.svd_spectral_shrinkage,
        svd_specific_components=svd_specific_components_list,
        use_zscore=args.use_zscore,
        use_euclidean=args.use_euclidean,
        ridge_alpha=args.ridge_alpha,
        decorr_lambda=args.decorr_lambda,
        model_strategy=args.model_strategy,
    )

    try:
        exp.run_experiment_parallel(
            layer_type=args.layer_type,
            operate_type=args.operate_type,
            anchor_points=args.anchor_points,
            num_workers=args.num_workers,
            strict_id_alignment=args.strict_id_alignment,
            num_runs_per_anchor=args.num_runs_per_anchor,
            test_features_for_similarity=args.test_features_for_similarity,
            exclude_extreme_samples=args.exclude_extreme_samples,
            exclude_extreme_samples_predict_all=args.exclude_extreme_samples_predict_all,
            dynamic_source_selection=args.dynamic_source_selection,
            svd_exclude_dominant_ratio=args.svd_exclude_dominant_ratio,
            svd_spectral_shrinkage=args.svd_spectral_shrinkage,
            svd_specific_components=svd_specific_components_list,
            args=args,
            use_zscore=args.use_zscore,
            use_euclidean=args.use_euclidean,
            ridge_alpha=args.ridge_alpha,
            decorr_lambda=args.decorr_lambda,
            model_strategy=args.model_strategy,
            family=args.family,
            strength=args.strength,
        )
    finally:
        # Release experiment resources even when the run fails or is interrupted.
        exp.cleanup_resources()

    total = time.time() - overall_start
    console.print("[bold green]=== MLP 外推变体实验结束 ===[/bold green]")
    console.print(f"[bold yellow]总耗时: {total/3600:.2f} h[/bold yellow]")


if __name__ == "__main__":
    def _free_accelerator_memory():
        """Best-effort release of cached torch CUDA memory and the CuPy default pool."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            import cupy as cp

            cp.get_default_memory_pool().free_all_blocks()
        except Exception:
            # CuPy may be missing or unusable; this is not an error here.
            pass

    def signal_handler(signum, _):
        """Report the received signal, try to free GPU memory, then exit with status 1."""
        console.print(f"\n[red]收到中断信号 {signum}[/red]")
        try:
            _free_accelerator_memory()
        except Exception:
            # Cleanup problems must never prevent the process from exiting.
            pass
        sys.exit(1)

    # The same handler serves both Ctrl-C and termination requests.
    for _sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_sig, signal_handler)
    main()
