#!/usr/bin/env python3
"""
独立的MLP训练脚本，用于预先训练所有模型组合的MLP权重

这个脚本负责：
1. 加载所有模型组合的训练数据
2. 为每个组合训练对应的MLP模型
3. 保存训练好的模型权重，供主实验脚本使用

使用方法：
python train_mlp_models.py --dataset mmlu_pro --num_workers 4
"""

import os
import sys
import json
import uuid
import shutil
import signal
import time
import argparse
import random
import numpy as np
import multiprocessing as mp
from typing import List, Dict, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime

# 设置环境变量（需要在导入torch之前）
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import torch
import torch.nn as nn
from rich.console import Console

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, ".."))

# 添加项目根目录到Python路径（必须在导入项目模块之前）
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# 导入项目特定模块（需要在设置路径之后）
from baseline import get_train_test_splits
from src.train_utils.train_faclens_linear import load_model_data, train_dataset_model, save_results
from src.train_utils.save_config import save_run_metadata

# Shared rich console used for all status output in this module.
console = Console()


# Seed the torch, NumPy and stdlib RNGs with the same fixed value at import time.
# With the "spawn" start method each worker re-imports this module, so these
# lines also run in every worker process.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


def _child_init(pid_q, gpu_id=None):
    """子进程初始化：把PID回传给父进程；设置父进程死我也死；分配GPU"""
    try:
        import ctypes

        libc = ctypes.CDLL("libc.so.6")
        PR_SET_PDEATHSIG = 1
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM)
    except Exception:
        pass

    # 设置GPU设备
    if gpu_id is not None and torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        console.print(f"[green]Worker {os.getpid()} 绑定到 GPU {gpu_id}[/green]")

    pid_q.put(os.getpid())


def _process_training_task(args_tuple):
    """Train (or reuse) the MLP for a single model combination.

    Args:
        args_tuple: Positional task description with 25 fields; the unpacking
            below lists them in order. The label dictionaries and the test-model
            list are accepted for layout compatibility but are not used here.

    Returns:
        A ``(combo_idx, success, model_path)`` tuple. ``model_path`` is the final
        weight file on success and ``None`` on failure. Exceptions are caught,
        reported on the console, and turned into a failed result.
    """
    (
        combo_idx,
        _train_labels_dict,
        _test_labels_dict,
        train_models,
        _test_models,
        dataset,
        feats_dir,
        output_dir,
        device,
        layer_type,
        operate_type,
        n_clusters,
        batch_size,
        lr,
        weight_decay,
        max_samples,
        expand_yaml,
        force_retrain,
        gpu_id,
        use_expand,
        lambda_max,
        num_models,
        difficulty_batching,
        difficulty_seed,
        difficulty_ascending,
    ) = args_tuple

    # Pin the thread-pool env vars to one thread in this (possibly spawned) process.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    # Bind a GPU only when one was assigned, CUDA is usable, and the caller did
    # not explicitly ask for CPU execution.
    cpu_requested = str(device).startswith("cpu")
    if gpu_id is not None and not cpu_requested and torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        actual_device = f"cuda:{gpu_id}"
    else:
        actual_device = device

    try:
        # The target path depends only on the task parameters, so the
        # "already trained" check can run before any feature file is read.
        target_model_path = get_mlp_model_path(
            output_dir, combo_idx, dataset, layer_type, operate_type, use_expand, lambda_max, num_models
        )
        if os.path.exists(target_model_path) and not force_retrain:
            console.print(f"[green]Combo {combo_idx}: MLP模型已存在，跳过训练[/green]")
            return combo_idx, True, target_model_path

        # Resolve the feature files of every training model for this dataset.
        train_models_data = load_model_data(
            feats_dir,
            layer_type,
            [dataset],
            train_models,
            operate_type,
        )

        # Collect the distinct feature widths; this needs the real tensors.
        in_dims = set()
        for model_name, model_data in train_models_data.items():
            if dataset not in model_data:
                continue
            train_path = model_data[dataset]["train"]
            if not os.path.exists(train_path):
                console.print(f"[red]警告: 训练文件不存在 {train_path}[/red]")
                continue
            # Keep only the width so the loaded tensor can be released right away.
            feature_dim = torch.load(train_path, map_location="cpu")["features"].shape[1]
            in_dims.add(feature_dim)
            console.print(f"[green]Combo {combo_idx}: {model_name} 特征维度: {feature_dim}[/green]")

        if not in_dims:
            console.print(f"[red]Combo {combo_idx}: 未找到有效的特征数据[/red]")
            return combo_idx, False, None

        in_dims = sorted(in_dims)
        os.makedirs(os.path.dirname(target_model_path), exist_ok=True)

        if force_retrain:
            # Remove the old weights together with every sidecar file that a
            # previous run may have written next to them.
            model_stem = os.path.splitext(target_model_path)[0]
            stale_artifacts = (
                (target_model_path, "删除旧模型文件"),
                (f"{model_stem}_config.json", "删除旧配置文件"),
                (f"{model_stem}_expand_curve.png", "删除旧expand loss图片"),
                (f"{model_stem}_min_similarity_pair_distribution.png", "删除旧最小相似度对分布直方图"),
                (f"{model_stem}_all_dataset_results.json", "删除旧all_dataset_results.json"),
                (f"{model_stem}_training_summary.json", "删除旧training_summary.json"),
            )
            for stale_path, message in stale_artifacts:
                if os.path.exists(stale_path):
                    os.remove(stale_path)
                    console.print(f"[yellow]Combo {combo_idx}: {message}[/yellow]")

            console.print(f"[yellow]Combo {combo_idx}: 强制重新训练[/yellow]")

        console.print(f"[cyan]Combo {combo_idx}: 开始训练MLP模型 (GPU: {actual_device})[/cyan]")
        success = train_single_mlp(
            dataset,
            train_models_data,
            train_models,
            combo_idx,
            in_dims,
            target_model_path,
            output_dir,
            actual_device,
            layer_type,
            operate_type,
            n_clusters,
            batch_size,
            lr,
            weight_decay,
            max_samples,
            expand_yaml,
            lambda_max,
            difficulty_batching,
            difficulty_seed,
            difficulty_ascending,
        )

        if success:
            console.print(f"[green]Combo {combo_idx}: MLP模型训练完成[/green]")
            return combo_idx, True, target_model_path

        console.print(f"[red]Combo {combo_idx}: MLP模型训练失败[/red]")
        return combo_idx, False, None

    except Exception as e:
        console.print(f"[red]Combo {combo_idx} 训练失败: {str(e)}[/red]")
        import traceback

        traceback.print_exc()
        return combo_idx, False, None


def get_mlp_model_path(
    output_dir: str,
    combo_idx: int,
    dataset: str = None,
    layer_type: str = "last",
    operate_type: str = "last_token",
    use_expand: bool = False,
    lambda_max: float = None,
    num_models: int = None,
) -> str:
    """
    Build the weight-file path ``combo<N>.pth`` for one model combination.

    Layout:
        expand mode:  <output_dir>/mlp_models_expand/lambda_<X>/[<dataset>/<layer_type>/<operate_type>/]combo<N>.pth
        normal mode:  <output_dir>/mlp_models[_<num_models>models]/[<dataset>/<layer_type>/<operate_type>/]combo<N>.pth

    The bracketed part is present only when ``dataset`` is truthy.

    Args:
        output_dir: Root output directory.
        combo_idx: Index of the combination.
        dataset: Dataset name; when falsy the legacy flat layout is used.
        layer_type: Hidden-state layer type (path component).
        operate_type: Feature operation type (path component).
        use_expand: Whether the expand-loss directory tree is used.
        lambda_max: Expand-loss weight, formatted with ``.3g``; required when
            ``use_expand`` is true.
        num_models: Optional model count; only affects the normal-mode directory name.

    Returns:
        Path of the MLP weight file.

    Raises:
        ValueError: If ``use_expand`` is true and ``lambda_max`` is ``None``.
    """
    model_filename = f"combo{combo_idx}.pth"
    has_dataset = bool(dataset)

    if use_expand:
        # The expand-loss tree is keyed by the hyper-parameter value.
        if lambda_max is None:
            raise ValueError("使用 use_expand=True 时，必须提供 lambda_max 参数")
        lambda_str = f"{lambda_max:.3g}"

        if has_dataset:
            console.print(f"[blue]生成MLP模型路径: 数据集={dataset}, use_expand=True, lambda={lambda_str}[/blue]")
        else:
            console.print(f"[blue]生成MLP模型路径: combo索引={combo_idx}, use_expand=True, lambda={lambda_str}[/blue]")
        path_parts = [output_dir, "mlp_models_expand", f"lambda_{lambda_str}"]
    else:
        # Normal training; the directory name carries the model count when given.
        base_dir = "mlp_models" if num_models is None else f"mlp_models_{num_models}models"

        if has_dataset:
            console.print(f"[blue]生成MLP模型路径: 数据集={dataset}, num_models={num_models}, use_expand=False[/blue]")
        else:
            console.print(f"[blue]生成MLP模型路径: combo索引={combo_idx}, num_models={num_models}, use_expand=False[/blue]")
        path_parts = [output_dir, base_dir]

    if has_dataset:
        path_parts.extend([dataset, layer_type, operate_type])
    path_parts.append(model_filename)

    return os.path.join(*path_parts)


def load_trained_mlp(dataset: str, model_path: str, in_dims: List[int], device: str) -> nn.Module:
    """Instantiate the dataset-specific MLP, load its weights, and return it in eval mode.

    Args:
        dataset: Dataset name; selects which project model class is built.
        model_path: Checkpoint file, either a raw state dict or a dict holding
            it under ``"model_state_dict"``.
        in_dims: Input feature widths passed to the model constructor.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, on ``device`` and switched to evaluation mode.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    supported = ("bbh", "gsm8k", "math", "mmlu_pro", "arc_challenge", "seedbench_plus2")
    if dataset not in supported:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Import lazily so that only the selected dataset package is loaded.
    if dataset == "bbh":
        from src.bbh.train_linear import bbh_linear as model_cls
    elif dataset == "gsm8k":
        from src.gsm8k.train_linear import gsm8k_linear as model_cls
    elif dataset == "math":
        from src.math.train_linear import math_linear as model_cls
    elif dataset == "mmlu_pro":
        from src.mmlu_pro.train_linear import mmlu_pro_linear as model_cls
    elif dataset == "arc_challenge":
        from src.arc_challenge.train_linear import arc_challenge_linear as model_cls
    else:
        from src.seedbench_plus2.train_linear import seedbench_plus2_linear as model_cls

    model = model_cls(target_dim=4096, in_dims=in_dims)

    # Accept both checkpoint layouts: wrapped dict or bare state dict.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model


def train_single_mlp(
    dataset: str,
    train_models_data: Dict,
    train_models: List[str],
    combo_idx: int,
    in_dims: List[int],
    target_model_path: str,
    output_dir: str,
    device: str,
    layer_type: str = "last",
    operate_type: str = "last_token",
    n_clusters: int = 50,
    batch_size: int = 64,
    lr: float = 1e-3,
    weight_decay: float = 1e-5,
    max_samples: int = None,
    expand_yaml: str = "",
    lambda_max: float = None,
    difficulty_batching: str = 'none',
    difficulty_seed: int = 42,
    difficulty_ascending: bool = True,
) -> bool:
    """Train one MLP in a private temp directory and publish its artifacts.

    Training runs under ``<output_dir>/temp_training/<unique id>``. On success the
    best checkpoint is copied to ``target_model_path`` and the sidecar files
    (``_config.json``, result/summary JSON, plots) are written next to it, all
    sharing the weight file's stem. The temp directory is always removed.

    Args:
        dataset: Dataset name; selects the project model class.
        train_models_data: Feature-file mapping handed to ``train_dataset_model``.
        train_models: Names of the models in this combination.
        combo_idx: Index of the combination (used in paths and log messages).
        in_dims: Input feature widths for the model constructor.
        target_model_path: Final location of the weight file.
        output_dir: Root directory under which the temp directory is created.
        device: Device string passed to training.
        layer_type: Recorded in the metadata/config.
        operate_type: Recorded in the metadata/config.
        n_clusters: Accepted for interface compatibility; not used by this function.
        batch_size: Training batch size.
        lr: Learning rate.
        weight_decay: Weight decay.
        max_samples: Optional cap forwarded to training.
        expand_yaml: Path of the expand-loss YAML; recorded in the config when
            the file exists.
        lambda_max: Expand-loss weight recorded in the config when given.
        difficulty_batching: Forwarded to ``train_dataset_model``.
        difficulty_seed: Forwarded to ``train_dataset_model``.
        difficulty_ascending: Forwarded to ``train_dataset_model``.

    Returns:
        ``True`` when the best checkpoint was produced and copied, else ``False``.
        Exceptions are caught, reported, and mapped to ``False``.
    """
    temp_output_dir = None
    try:
        # Pick the dataset-specific model class (imported lazily).
        if dataset == "bbh":
            from src.bbh.train_linear import bbh_linear

            model = bbh_linear(target_dim=4096, in_dims=in_dims)
        elif dataset == "gsm8k":
            from src.gsm8k.train_linear import gsm8k_linear

            model = gsm8k_linear(target_dim=4096, in_dims=in_dims)
        elif dataset == "math":
            from src.math.train_linear import math_linear

            model = math_linear(target_dim=4096, in_dims=in_dims)
        elif dataset == "mmlu_pro":
            from src.mmlu_pro.train_linear import mmlu_pro_linear

            model = mmlu_pro_linear(target_dim=4096, in_dims=in_dims)
        elif dataset == "arc_challenge":
            from src.arc_challenge.train_linear import arc_challenge_linear

            model = arc_challenge_linear(target_dim=4096, in_dims=in_dims)
        elif dataset == "seedbench_plus2":
            from src.seedbench_plus2.train_linear import seedbench_plus2_linear

            model = seedbench_plus2_linear(target_dim=4096, in_dims=in_dims)
        else:
            raise ValueError(f"Unsupported dataset: {dataset}")

        # Unique per combo/process/run so concurrent workers never share a directory.
        unique_id = f"combo{combo_idx}_pid{os.getpid()}_{uuid.uuid4().hex[:8]}"
        temp_output_dir = os.path.join(output_dir, "temp_training", unique_id)
        os.makedirs(temp_output_dir, exist_ok=True)

        # Minimal argparse-like object carrying the attributes that
        # save_run_metadata / save_results read.
        class Args:
            def __init__(self):
                self.feats_dir = ""  # intentionally blank
                self.output_dir = temp_output_dir
                self.device = device
                self.layer_type = layer_type
                self.operate_type = operate_type
                self.datasets = [dataset]  # single dataset passed as a list
                self.batch_size = batch_size
                self.lr = lr
                self.weight_decay = weight_decay
                self.expand_yaml = expand_yaml

                # Optional attributes
                self.level_filter = None
                self.max_samples = max_samples
                self.filter_incomplete = False

        args = Args()
        save_run_metadata(args, model, train_models, temp_output_dir)

        results = train_dataset_model(
            dataset,
            train_models_data,
            temp_output_dir,
            device,
            train_models,
            batch_size,
            lr,
            weight_decay,
            None,
            max_samples,
            False,
            model,
            expand_yaml=expand_yaml,
            difficulty_batching=difficulty_batching,
            difficulty_seed=difficulty_seed,
            difficulty_ascending=difficulty_ascending,
        )

        # Produce the two summary JSON files (single-dataset form). A failure
        # here is only a warning: the weights can still be published.
        try:
            save_results(args, {dataset: results})
        except Exception as e:
            console.print(f"[yellow]警告: 生成汇总json失败（将继续尝试打包拷贝）: {e}[/yellow]")

        best_model_path = os.path.join(temp_output_dir, dataset, f"best_{dataset}_model.pth")
        if not os.path.exists(best_model_path):
            console.print(f"[red]警告: 未找到最佳模型文件 {best_model_path}[/red]")
            return False

        mlp_config = {
            "dataset": dataset,
            "train_models": train_models,
            "experiment_config": {
                "layer_type": layer_type,
                "operate_type": operate_type,
                "batch_size": batch_size,
                "lr": lr,
                "weight_decay": weight_decay,
                "max_samples": max_samples,
            },
            "training_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "model_final_path": target_model_path,
            "in_dims": in_dims,
        }

        # Record the expand-loss settings only when a YAML file was really used.
        if expand_yaml and os.path.exists(expand_yaml):
            mlp_config["expand_config"] = {
                "yaml_path": expand_yaml,
                "lambda_max": lambda_max,  # None unless given by the caller
            }

        # Publish the weights first, then the config, so a config never exists
        # without its weight file.
        os.makedirs(os.path.dirname(target_model_path) or ".", exist_ok=True)
        shutil.copy2(best_model_path, target_model_path)

        # Derive sidecar names from the stem: unlike str.replace(".pth", ...)
        # this touches only the final extension and can never yield the weight
        # file's own path.
        model_stem = os.path.splitext(target_model_path)[0]
        with open(f"{model_stem}_config.json", "w", encoding="utf-8") as f:
            json.dump(mlp_config, f, indent=2, default=str)

        # Optional artifacts: (source in temp dir, target suffix, log message).
        optional_artifacts = (
            (
                os.path.join(temp_output_dir, "all_dataset_results.json"),
                "_all_dataset_results.json",
                "all_dataset_results.json 已保存",
            ),
            (
                os.path.join(temp_output_dir, "training_summary.json"),
                "_training_summary.json",
                "training_summary.json 已保存",
            ),
            (
                os.path.join(temp_output_dir, dataset, "expand_curve.png"),
                "_expand_curve.png",
                "Expand loss图片已保存",
            ),
            (
                os.path.join(temp_output_dir, dataset, "min_similarity_pair_distribution.png"),
                "_min_similarity_pair_distribution.png",
                "最小相似度对分布直方图已保存",
            ),
        )
        for source_path, suffix, message in optional_artifacts:
            if os.path.exists(source_path):
                shutil.copy2(source_path, f"{model_stem}{suffix}")
                console.print(f"[green]Combo {combo_idx}: {message}[/green]")

        return True

    except Exception as e:
        console.print(f"[red]训练combo{combo_idx}时出错: {e}[/red]")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Single cleanup point for every exit path, including interruptions.
        if temp_output_dir is not None:
            shutil.rmtree(temp_output_dir, ignore_errors=True)


def cleanup_resources():
    """Release cached GPU memory held by PyTorch and, when installed, CuPy.

    Every failure is reported as a console warning and never propagated.
    """
    try:
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.empty_cache()
            console.print("[green]PyTorch GPU缓存已清理[/green]")

        # CuPy is optional: a missing package is silently skipped, any other
        # problem is downgraded to a warning.
        try:
            import cupy as cp

            memory_pool = cp.get_default_memory_pool()
            memory_pool.free_all_blocks()
            console.print("[green]CuPy GPU内存池已清理[/green]")
        except ImportError:
            pass
        except Exception as cupy_error:
            console.print(f"[yellow]CuPy内存池清理失败: {str(cupy_error)}[/yellow]")

    except Exception as cleanup_error:
        console.print(f"[yellow]GPU资源清理失败: {str(cleanup_error)}[/yellow]")


def _kill_children(child_pids):
    """Terminate worker processes: SIGTERM first, SIGKILL for survivors.

    Args:
        child_pids: Collection of process IDs to terminate.
    """

    def _signal_delivered(pid, sig):
        # True when the signal could be sent; any error (gone, not permitted) -> False.
        try:
            os.kill(pid, sig)
        except Exception:
            return False
        return True

    console.print(f"[yellow]正在终止 {len(child_pids)} 个子进程...[/yellow]")

    # Round 1: polite request.
    for worker_pid in child_pids:
        _signal_delivered(worker_pid, signal.SIGTERM)

    # Grace period before escalating.
    time.sleep(1.0)

    # Round 2: signal 0 only probes for existence; kill what is still there.
    for worker_pid in child_pids:
        if _signal_delivered(worker_pid, 0):
            _signal_delivered(worker_pid, signal.SIGKILL)

    console.print("[green]子进程清理完成[/green]")


def train_all_mlp_models(
    dataset: str,
    feats_dir: str,
    output_dir: str,
    device: str = "cuda",
    layer_type: str = "last",
    operate_type: str = "last_token",
    n_clusters: int = 50,
    batch_size: int = 64,
    lr: float = 1e-3,
    weight_decay: float = 1e-5,
    max_samples: int = None,
    expand_yaml: str = "",
    force_retrain: bool = False,
    num_workers: int = 1,
    use_expand: bool = False,
    lambda_max: float = None,
    num_models: int = None,
    difficulty_batching: str = 'none',
    difficulty_seed: int = 42,
    difficulty_ascending: bool = True,
) -> bool:
    """Build one training task per train/test split and run them all.

    Args:
        dataset: Dataset name; selects the label matrix and model-selection file.
        feats_dir: Directory holding the extracted feature files.
        output_dir: Root directory for trained weights and temp files.
        device: ``"cuda"`` or ``"cpu"``. With ``"cpu"`` no GPU is assigned to any
            task, even when CUDA devices are present.
        layer_type: Hidden-state layer type.
        operate_type: Feature operation type.
        n_clusters: Forwarded to each task unchanged.
        batch_size: Training batch size.
        lr: Learning rate.
        weight_decay: Weight decay.
        max_samples: Optional cap on training samples.
        expand_yaml: Path of the expand-loss YAML config.
        force_retrain: Retrain even when a weight file already exists.
        num_workers: Values above 1 select the process-pool runner.
        use_expand: Save under the expand-loss directory tree.
        lambda_max: Expand-loss weight (part of the expand-mode path).
        num_models: Model count for the ablation setup; ``None`` uses the default
            selection file.
        difficulty_batching: Difficulty-aware batching strategy.
        difficulty_seed: Seed for difficulty batching.
        difficulty_ascending: Sort direction for difficulty batching.

    Returns:
        ``True`` when every task succeeded. ``False`` when any task failed or
        when preparing/running the tasks raised an exception.
    """
    console.print(f"[bold cyan]开始训练 {dataset} 数据集的所有MLP模型[/bold cyan]")

    # Tell the user which directory tree the weights will land in.
    if use_expand:
        console.print("[green]✓ --use_expand 已指定，将保存到 mlp_models_expand/ 目录[/green]")
    elif num_models is not None:
        console.print(f"[green]✓ --num_models {num_models} 已指定，将保存到 mlp_models_{num_models}models/ 目录[/green]")
    else:
        console.print("[yellow]✗ 使用普通训练，将保存到 mlp_models/ 目录[/yellow]")

    detected_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    console.print(f"[cyan]检测到 {detected_gpus} 个GPU设备[/cyan]")

    # An explicit CPU request must win over GPU auto-assignment; otherwise the
    # workers would receive a gpu_id and train on CUDA anyway.
    if str(device).startswith("cpu"):
        num_gpus = 0
        if detected_gpus > 0:
            console.print("[yellow]已指定 device=cpu，将忽略GPU并在CPU上训练[/yellow]")
    else:
        num_gpus = detected_gpus

    # Step 1: resolve the train/test splits.
    console.print("[yellow]步骤1: 获取训练/测试划分[/yellow]")

    label_matrix_path = os.path.join(
        PROJECT_ROOT,
        "main_experiment/label_data",
        f"label_matrix_{dataset}.json",
    )

    # Both the default and the ablation selection files live in the same folder;
    # only the file name depends on num_models.
    model_select_subdir = "utils/train_model_select_num_ablation"
    if num_models is not None:
        select_models_file = f"{dataset}_select_models_{num_models}_ablation.py"
    else:
        select_models_file = f"{dataset}_select_models.py"

    model_select_path = os.path.join(
        PROJECT_ROOT,
        model_select_subdir,
        f"{dataset}/{select_models_file}",
    )

    console.print(f"[cyan]使用训练配置: {select_models_file}[/cyan]")

    # For ablation runs with a count other than 20, the 20-model config is also
    # passed as test_exclude_models_path.
    test_exclude_models_path = None
    if num_models is not None and num_models != 20:
        test_exclude_models_path = os.path.join(
            PROJECT_ROOT,
            model_select_subdir,
            f"{dataset}/{dataset}_select_models_20_ablation.py",
        )
        console.print(f"[cyan]使用测试集刨除配置: {dataset}_select_models_20_ablation.py[/cyan]")

    try:
        splits, metadata = get_train_test_splits(
            label_matrix_path,
            model_select_path,
            num_models=num_models,
            test_exclude_models_path=test_exclude_models_path
        )
        console.print(f"数据集: {metadata['dataset']}, 样本数: {metadata['num_samples']}")
        console.print(f"找到 {len(splits)} 个训练/测试组合")

        # Step 2: one positional task tuple per split. The field order is part
        # of the contract with the worker function and must not change.
        console.print(f"[yellow]步骤2: 生成 {len(splits)} 个训练任务[/yellow]")
        tasks = []
        for combo_idx, (
            train_labels_dict,
            test_labels_dict,
            train_models,
            test_models,
        ) in enumerate(splits):
            # Round-robin over the usable GPUs; None means "run on CPU".
            gpu_id = combo_idx % num_gpus if num_gpus > 0 else None

            tasks.append(
                (
                    combo_idx,
                    train_labels_dict,
                    test_labels_dict,
                    train_models,
                    test_models,
                    dataset,
                    feats_dir,
                    output_dir,
                    device,
                    layer_type,
                    operate_type,
                    n_clusters,
                    batch_size,
                    lr,
                    weight_decay,
                    max_samples,
                    expand_yaml,
                    force_retrain,
                    gpu_id,
                    use_expand,
                    lambda_max,
                    num_models,
                    difficulty_batching,
                    difficulty_seed,
                    difficulty_ascending,
                )
            )

        # Step 3: run the tasks in a process pool or one after another.
        if num_workers > 1:
            console.print(f"[yellow]步骤3: 并行训练MLP模型 (使用{num_workers}个进程)[/yellow]")
            return _run_parallel_training(tasks, num_workers)

        console.print("[yellow]步骤3: 串行训练MLP模型[/yellow]")
        return _run_serial_training(tasks)

    except Exception as e:
        console.print(f"[red]MLP模型训练失败: {e}[/red]")
        import traceback

        traceback.print_exc()
        return False


def _run_parallel_training(tasks: List[Tuple], num_workers: int) -> bool:
    """Run all training tasks in a spawn-based process pool.

    Args:
        tasks: Task tuples, each passed as-is to ``_process_training_task``.
        num_workers: Maximum number of worker processes.

    Returns:
        ``True`` when every task reported success; ``False`` otherwise, including
        when the run was interrupted or the pool raised.
    """
    completed_tasks = 0
    successful_tasks = 0

    # "spawn" keeps workers from inheriting the parent's CUDA context.
    ctx = mp.get_context("spawn")
    pid_q = ctx.Queue()
    child_pids = set()

    def _drain_reported_pids():
        # Absorb every PID already reported by a worker initializer, without blocking.
        while True:
            try:
                child_pids.add(pid_q.get_nowait())
            except Exception:
                break

    executor = None
    # Flipped only after the result loop ends normally, so ANY early exit
    # (including SystemExit raised by a signal handler) takes the forced path.
    finished_cleanly = False

    try:
        executor = ProcessPoolExecutor(
            max_workers=num_workers,
            mp_context=ctx,
            initializer=_child_init,
            initargs=(pid_q, None),  # the GPU is selected per task, not per worker
        )

        # Submit first: the pool starts its workers on submit, so waiting for
        # PIDs before this point could never collect any.
        future_to_task = {executor.submit(_process_training_task, task): task for task in tasks}

        # Give the workers a short window to report their PIDs; the tasks are
        # already running meanwhile, and late PIDs are picked up in the loop below.
        expected_workers = min(num_workers, len(tasks))
        deadline = time.time() + 5
        while len(child_pids) < expected_workers and time.time() < deadline:
            try:
                child_pids.add(pid_q.get(timeout=0.2))
            except Exception:
                pass

        console.print(f"[green]已收集 {len(child_pids)} 个子进程PID[/green]")

        for future in as_completed(future_to_task):
            _drain_reported_pids()

            try:
                combo_idx, success, _model_path = future.result()

                if success:
                    successful_tasks += 1

                completed_tasks += 1
                progress = completed_tasks / len(tasks) * 100
                status = "成功" if success else "失败"
                console.print(f"[green]进度: {completed_tasks}/{len(tasks)} ({progress:.1f}%) - 组合{combo_idx} {status}[/green]")

            except Exception as e:
                # A crashed task counts as completed but not successful.
                failed_combo_idx = future_to_task[future][0]
                console.print(f"[red]任务异常 - 组合{failed_combo_idx}: {e}[/red]")
                completed_tasks += 1

        finished_cleanly = True

    except KeyboardInterrupt:
        console.print("[red]\n检测到键盘中断，正在清理...[/red]")
        return False
    except Exception as e:
        console.print(f"[red]并行执行出现异常: {e}[/red]")
        return False
    finally:
        # All teardown lives here so it runs exactly once on every exit path.
        _drain_reported_pids()

        if executor:
            try:
                if finished_cleanly:
                    executor.shutdown(wait=True)
                    console.print("[green]优雅关闭进程池[/green]")
                else:
                    executor.shutdown(wait=False, cancel_futures=True)
                    console.print("[yellow]强制关闭进程池[/yellow]")
            except Exception as e:
                console.print(f"[yellow]进程池关闭异常: {e}[/yellow]")
                if finished_cleanly:
                    _kill_children(child_pids)

        if not finished_cleanly and child_pids:
            _kill_children(child_pids)

        cleanup_resources()

    console.print(
        f"[green]所有MLP训练任务完成! 成功: {successful_tasks}/{len(tasks)} 个模型[/green]"
    )
    return successful_tasks == len(tasks)


def _run_serial_training(tasks: List[Tuple]) -> bool:
    """Run the training tasks one after another in the current process.

    Args:
        tasks: Task tuples, each passed as-is to ``_process_training_task``.

    Returns:
        ``True`` when every task reported success.
    """
    # Position of gpu_id in the task tuple. It must be addressed from the front:
    # fields were appended after it, so a negative index picks up a different
    # field (task_args[-3] is difficulty_batching).
    gpu_id_index = 18

    successful_tasks = 0

    for i, task_args in enumerate(tasks):
        combo_idx = task_args[0]
        gpu_id = task_args[gpu_id_index] if len(task_args) > gpu_id_index else None
        gpu_info = f" (GPU: {gpu_id})" if gpu_id is not None else " (CPU)"
        console.print(f"[cyan]训练组合 {combo_idx}/{len(tasks)}{gpu_info}[/cyan]")

        try:
            combo_idx, success, _model_path = _process_training_task(task_args)

            if success:
                successful_tasks += 1

            status = "成功" if success else "失败"
            console.print(f"[green]组合{combo_idx}训练{status} ({i+1}/{len(tasks)})[/green]")

        except Exception as e:
            # Keep going: one failing combination must not stop the rest.
            console.print(f"[red]组合{combo_idx}训练失败: {e}[/red]")

    console.print(f"[green]所有MLP训练任务完成! 成功: {successful_tasks}/{len(tasks)} 个模型[/green]")
    return successful_tasks == len(tasks)


def main():
    """Command-line entry point: parse arguments, validate them, and train all MLPs.

    Exits with status 1 when any model failed to train or on keyboard interrupt;
    other exceptions are reported and re-raised. GPU caches are released on
    every exit path.

    Raises:
        ValueError: For inconsistent ``--use_expand`` / ``--lambda_max`` options.
    """
    # Use "spawn" so workers do not inherit a CUDA context from the parent.
    try:
        mp.set_start_method("spawn", force=True)
        console.print("[green]使用spawn启动方式初始化多进程环境[/green]")
    except RuntimeError:
        console.print("[yellow]多进程启动方式已经设置，继续使用现有配置[/yellow]")

    parser = argparse.ArgumentParser(description="训练所有MLP模型")
    parser.add_argument("--dataset", type=str, default="mmlu_pro", help="数据集名称")
    parser.add_argument("--num_models", type=int, default=None, help="模型数量（用于区分不同数量的模型组合，如 5/10/15/20）")
    parser.add_argument(
        "--feats_dir",
        type=str,
        default=None,
        help="特征文件目录 (默认: <项目根目录>/feats/<dataset>/<dataset>_feats_split)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.path.join(PROJECT_ROOT, "main_experiment/results"),
        help="输出目录",
    )
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="运行设备")
    parser.add_argument(
        "--layer_type",
        type=str,
        default="last",
        choices=["quarter", "middle", "three_quarters", "last", "second_last", "first", "all"],
        help="隐状态层",
    )
    parser.add_argument(
        "--operate_type",
        type=str,
        default="last_token",
        choices=["prompt_last_token", "answer_first_token", "last_token"],
        help="特征操作类型",
    )
    parser.add_argument("--n_clusters", type=int, default=50, help="聚类数量")
    parser.add_argument("--batch_size", type=int, default=128, help="MLP训练批量大小")
    parser.add_argument("--lr", type=float, default=5e-4, help="MLP学习率")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="权重衰减")
    parser.add_argument("--max_samples", type=int, default=None, help="最大训练样本数量")
    parser.add_argument(
        "--expand_yaml",
        type=str,
        default=os.path.join(PROJECT_ROOT, "main_experiment/mlp_training.yaml"),
        help="Expand loss 的 YAML 配置路径；默认为 mlp_training.yaml；为空字符串 '' 则关闭 expand loss",
    )
    parser.add_argument("--use_expand", action="store_true", default=False, help="是否使用expand loss训练（决定保存到mlp_models还是mlp_models_expand目录）(默认: False)")
    parser.add_argument("--lambda_max", type=float, default=None, help="Expand loss的最大权重 (仅在使用use_expand时有效，用于区分不同超参数组合)")
    parser.add_argument(
        "--difficulty_batching",
        type=str,
        default="none",
        choices=["none", "blocks_sync", "blocks_async"],
        help=(
            "难度感知批处理策略: "
            "none=普通shuffle; blocks_sync=按难度分块+块级shuffle(所有模型同序); "
            "blocks_async=按难度分块+块级shuffle(每个模型不同序)"
        ),
    )
    parser.add_argument(
        "--difficulty_seed",
        type=int,
        default=42,
        help="难度批处理中块级shuffle使用的随机种子 (默认: 42)",
    )
    parser.add_argument(
        "--difficulty_descending",
        action="store_true",
        default=False,
        help="按难度分数降序排列（如果分数是正确率，则为简单→困难）。默认为升序（困难→简单）。",
    )
    parser.add_argument("--force_retrain", action="store_true", help="强制重新训练MLP模型")
    # The help text states the real default (20).
    parser.add_argument("--num_workers", type=int, default=20, help="并行进程数 (默认: 20，>1时启用并行模式)")

    args = parser.parse_args()

    # --lambda_max and --use_expand must be given together.
    if args.lambda_max is not None and not args.use_expand:
        raise ValueError("指定了 --lambda_max 时，必须同时指定 --use_expand")

    if args.use_expand:
        if args.lambda_max is None:
            raise ValueError("使用 --use_expand 时必须提供 --lambda_max 参数")
        if args.lambda_max <= 0:
            raise ValueError(f"lambda_max 必须大于 0，当前值: {args.lambda_max}")

    # Derive the dataset-specific feature directory only when the user did not
    # pass --feats_dir; an explicit value is honoured instead of being overwritten.
    if args.feats_dir is None:
        args.feats_dir = os.path.join(PROJECT_ROOT, f"feats/{args.dataset}/{args.dataset}_feats_split")

    try:
        success = train_all_mlp_models(
            dataset=args.dataset,
            feats_dir=args.feats_dir,
            output_dir=args.output_dir,
            device=args.device,
            layer_type=args.layer_type,
            operate_type=args.operate_type,
            n_clusters=args.n_clusters,
            batch_size=args.batch_size,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_samples=args.max_samples,
            expand_yaml=args.expand_yaml,
            force_retrain=args.force_retrain,
            num_workers=args.num_workers,
            use_expand=args.use_expand,
            lambda_max=args.lambda_max,
            num_models=args.num_models,
            difficulty_batching=args.difficulty_batching,
            difficulty_seed=args.difficulty_seed,
            difficulty_ascending=(not args.difficulty_descending),
        )

        if success:
            console.print("[bold green]所有MLP模型训练完成！[/bold green]")
            # Report the same directory that get_mlp_model_path writes to,
            # including the lambda_* level used in expand mode.
            if args.use_expand:
                model_dir = os.path.join("mlp_models_expand", f"lambda_{args.lambda_max:.3g}")
            elif args.num_models is not None:
                model_dir = f"mlp_models_{args.num_models}models"
            else:
                model_dir = "mlp_models"
            console.print(f"[green]模型保存在: {os.path.join(args.output_dir, model_dir)}[/green]")
        else:
            console.print("[bold red]部分MLP模型训练失败！[/bold red]")
            sys.exit(1)

    except KeyboardInterrupt:
        # Resource cleanup happens once, in the finally clause below.
        console.print("\n[red]检测到键盘中断，正在清理资源...[/red]")
        sys.exit(1)
    except Exception as e:
        console.print(f"[red]训练执行出错: {str(e)}[/red]")
        raise
    finally:
        cleanup_resources()
        console.print("[green]训练完成，资源已清理[/green]")


if __name__ == "__main__":

    def signal_handler(signum, _frame):
        """Free GPU caches (best effort) and exit with status 1 on SIGINT/SIGTERM."""
        console.print(f"\n[red]收到中断信号 {signum}[/red]")
        try:
            cuda_ready = torch.cuda.is_available()
            if cuda_ready:
                torch.cuda.empty_cache()
                console.print("[green]PyTorch GPU缓存已清理[/green]")
            try:
                import cupy as cp

                memory_pool = cp.get_default_memory_pool()
                memory_pool.free_all_blocks()
                console.print("[green]CuPy GPU内存池已清理[/green]")
            except ImportError:
                # CuPy is optional.
                pass
        except Exception:
            # Never let cleanup problems block the exit below.
            pass
        console.print("[green]资源清理完成[/green]")
        sys.exit(1)

    # Install the same handler for Ctrl-C and for termination requests.
    for _handled_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_handled_signal, signal_handler)

    main()
