#!/usr/bin/env python3
"""
从指定配置（model/dataset/layer/pos）中提取0/1标签矩阵

基于已有的labeled_data文件，提取每个模型在数据集上的0/1标签，
每一行代表单个模型在当前数据集上的运行结果。
最终返回一个字典，格式为 {"model": [0,1,0,...,1]}（同时给出与标签顺序一一对应的 question_id 列表），并保存为 JSON 文件。

使用示例:
    python get_gt_label.py --dataset mmlu_pro --layer last --pos prompt_last_token
    python get_gt_label.py --dataset bbh --layer middle --pos last_token
    python get_gt_label.py --dataset math --layer second_last --pos answer_first_token
    python get_gt_label.py --dataset arc_challenge --layer last --pos last_token

支持的层配置: quarter, middle, three_quarters, last, second_last, first
支持的位置配置: prompt_last_token, answer_first_token, last_token（兼容旧格式: avg_without_prompt, avg_with_prompt）
"""

import argparse
import os
import sys
from pathlib import Path
import torch
import json
import pandas as pd

# Make the project root (the parent of this script's directory) importable by
# putting it at the front of sys.path, unless it is already present.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


def safe_torch_load(p):
    """Load a ``torch.save``-serialized object from ``p`` onto the CPU.

    First calls ``torch.load`` with ``weights_only=False``. If this torch
    version's ``torch.load`` rejects that keyword (``TypeError``), it retries
    without it.

    NOTE(review): ``weights_only=False`` permits arbitrary unpickling, which
    can execute code. Only load files from trusted sources.
    """
    cpu_kwargs = {"map_location": "cpu"}
    try:
        return torch.load(p, weights_only=False, **cpu_kwargs)
    except TypeError:
        # Older torch versions do not know the ``weights_only`` keyword.
        return torch.load(p, **cpu_kwargs)


def load_models_from_excel(excel_path):
    """Return the non-null entries of the "模型名称" column of an Excel sheet.

    Args:
        excel_path: path to an ``.xlsx`` file, read with the openpyxl engine.

    Returns:
        list: model names in sheet order, with empty cells dropped. If the
        column is missing, or reading the workbook fails for any reason, an
        error is printed and ``[]`` is returned.
    """
    name_column = "模型名称"
    try:
        sheet = pd.read_excel(excel_path, engine="openpyxl")
        if name_column in sheet.columns:
            model_names = sheet[name_column].dropna().tolist()
            print(f"从 {excel_path} 加载了 {len(model_names)} 个模型")
            return model_names
        print(f"错误: Excel文件中缺少'模型名称'列。可用列: {sheet.columns.tolist()}")
        return []
    except Exception as exc:
        print(f"加载Excel文件失败: {exc}")
        return []


def get_labeled_data_path(feats_root, dataset, layer, pos):
    """Return the labeled-data file path for one (dataset, layer, pos) config.

    The layout is ``<feats_root>/<dataset>/labeled_data/<dataset>_<layer>_<pos>.pt``.
    """
    dataset_dir = Path(feats_root) / "{}".format(dataset)
    file_name = "{}_{}_{}.pt".format(dataset, layer, pos)
    return dataset_dir / "labeled_data" / file_name


def _check_id_order(labeled_data, available_models):
    """Return the base model's question IDs if every model lists them in the same order.

    The first entry of ``available_models`` is the reference ("base") model.
    IDs are compared through ``str()``, so an int key and a str key with the
    same text are treated as the same question. On the first mismatch,
    diagnostics are printed and ``None`` is returned.
    """
    base_model = available_models[0]
    base_ids = list(labeled_data[base_model].keys())  # keys are question_ids
    base_ids_str = [str(qid) for qid in base_ids]

    for model in available_models[1:]:
        current_ids = list(labeled_data[model].keys())
        if [str(qid) for qid in current_ids] != base_ids_str:
            print(f"错误: ID 顺序或集合与基准模型不一致: {model}")
            print(f"基准模型 {base_model} 有 {len(base_ids)} 个ID")
            print(f"当前模型 {model} 有 {len(current_ids)} 个ID")
            print(
                f"前5个基准ID: {base_ids[:5]} (类型: {[type(qid).__name__ for qid in base_ids[:5]]})"
            )
            print(
                f"前5个当前ID: {current_ids[:5]} (类型: {[type(qid).__name__ for qid in current_ids[:5]]})"
            )
            return None
    return base_ids


def _extract_model_labels(model_data, base_ids):
    """Collect ``int(question_label)`` for each ID in ``base_ids`` from one model's data.

    Each ID is first looked up directly. If that fails, it is looked up by its
    string form, which resolves int/str key mismatches.

    Returns:
        tuple: ``(labels, missing_ids)``. ``labels`` follows ``base_ids``
        order; ``missing_ids`` lists the IDs that could not be found.
    """
    # Build the string-form index once per model, so each fallback lookup is
    # O(1) instead of a scan over all keys. Real str keys take precedence
    # when another key shares their string form.
    str_index = {str(key): key for key in model_data}
    str_index.update({key: key for key in model_data if isinstance(key, str)})

    labels = []
    missing_ids = []
    for question_id in base_ids:
        if question_id in model_data:
            key = question_id
        elif str(question_id) in str_index:
            key = str_index[str(question_id)]
        else:
            missing_ids.append(question_id)
            continue
        labels.append(int(model_data[key]["question_label"]))
    return labels, missing_ids


def extract_label_matrix(feats_root, dataset, layer, pos):
    """
    Extract the per-model 0/1 label matrix from a labeled-data file.

    Reads ``get_labeled_data_path(feats_root, dataset, layer, pos)`` and
    ``<PROJECT_ROOT>/get_label_data/<dataset>_main_experiment_models.xlsx``.
    The Excel file supplies the model list and its order.

    Args:
        feats_root: feature root directory.
        dataset: dataset name (e.g. math, mmlu_pro, bbh, arc_challenge).
        layer: layer config (quarter, middle, three_quarters, last, second_last, first).
        pos: position config (prompt_last_token, answer_first_token, last_token;
            legacy: avg_without_prompt, avg_with_prompt).

    Returns:
        tuple | None: ``(label_matrix, question_ids)`` on success, where
            - label_matrix: ``{"model_name": [0, 1, 0, ..., 1]}``
            - question_ids: base-model question IDs as strings, aligned with
              every label list.
        ``None`` on any failure. The reason is printed.
    """
    labeled_data_path = get_labeled_data_path(feats_root, dataset, layer, pos)

    if not labeled_data_path.exists():
        print(f"错误: 标记数据文件不存在: {labeled_data_path}")
        return None

    print(f"加载标记数据: {labeled_data_path}")

    try:
        # Presumably {model_name: {question_id: {"question_label": ...}}};
        # this shape is inferred from the accesses below.
        labeled_data = safe_torch_load(labeled_data_path)
        print(f"成功加载标记数据，包含 {len(labeled_data)} 个模型")

        excel_path = (
            Path(PROJECT_ROOT)
            / "get_label_data"
            / f"{dataset}_main_experiment_models.xlsx"
        )
        if not excel_path.exists():
            print(f"错误: 模型列表Excel文件不存在: {excel_path}")
            return None

        models = load_models_from_excel(excel_path)
        if not models:
            return None

        # Lightweight check: every model must list its question IDs in the same order.
        print("  验证样本ID顺序一致性...")
        available_models = [m for m in models if m in labeled_data]
        if not available_models:
            print("错误: 没有可用的模型数据")
            return None

        original_available_count = len(available_models)
        base_ids = _check_id_order(labeled_data, available_models)
        if base_ids is None:
            return None
        if not base_ids:
            # Without this check, the accuracy computation below would divide by zero.
            print("错误: 基准模型不包含任何样本ID")
            return None

        print(f"  ✓ ID顺序一致性验证通过 ({len(base_ids)} 个样本)")

        label_matrix = {}
        skipped_models = []
        for model in models:
            if model not in labeled_data:
                print(f"警告: 模型 {model} 不在标记数据中")
                continue

            model_data = labeled_data[model]
            labels, missing_ids = _extract_model_labels(model_data, base_ids)

            if missing_ids:
                current_model_ids = list(model_data.keys())
                print(f"警告: 模型 {model} 缺失 {len(missing_ids)} 个ID")
                print(
                    f"  缺失的ID示例: {missing_ids[:3]} (类型: {[type(qid).__name__ for qid in missing_ids[:3]]})"
                )
                print(
                    f"  当前模型ID示例: {current_model_ids[:3]} (类型: {[type(qid).__name__ for qid in current_model_ids[:3]]})"
                )
                print(
                    f"  当前模型ID总数: {len(current_model_ids)}, 基准模型ID总数: {len(base_ids)}"
                )
                # Strict mode: a model with any missing ID is skipped entirely.
                skipped_models.append(model)
                continue

            label_matrix[model] = labels
            print(
                f"  模型 {model}: {len(labels)} 个样本，正确率 {sum(labels)/len(labels):.3f}"
            )

        if not label_matrix:
            print(
                f"错误: 所有 {original_available_count} 个可用模型均因缺失ID被跳过，标签矩阵为空"
            )
            return None

        # Every model must have the same number of samples to form a matrix.
        lengths = {m: len(v) for m, v in label_matrix.items()}
        if len(set(lengths.values())) > 1:
            max_len = max(lengths.values())
            bad_models = {
                m: length for m, length in lengths.items() if length != max_len
            }
            raise ValueError(
                f"样本数不一致，无法形成矩阵。最大样本数={max_len}，不一致模型及样本数={bad_models}"
            )

        print(
            f"  ✓ 样本数一致性验证通过 (所有模型均为 {next(iter(lengths.values()))} 个样本)"
        )

        if skipped_models:
            print(
                f"\n  ⚠️  由于缺失ID被跳过的模型: {len(skipped_models)}/{original_available_count}"
            )
            print("     跳过的模型:", ", ".join(skipped_models))
        else:
            print("\n  ✓ 所有模型都包含完整的ID数据")

        # Return the base-model ID order as strings, aligned with every label list.
        question_ids = [str(qid) for qid in base_ids]
        return label_matrix, question_ids

    except Exception as e:
        # Top-level boundary: report with a traceback and signal failure via None.
        print(f"错误: 处理标记数据失败: {e}")
        import traceback

        traceback.print_exc()
        return None


def save_label_matrix(label_matrix, output_path, dataset, layer, pos, question_ids):
    """
    Save the label matrix, its metadata and the question-ID order as JSON.

    The file is written to ``output_path`` with its suffix replaced by
    ``.json`` (UTF-8, ``ensure_ascii=False``, indent 2). Only a JSON file is
    written; no CSV is produced.

    JSON layout::

        {
            "metadata": {
                "dataset": "mmlu_pro",
                "layer": "last",
                "pos": "last_token",
                "num_models": 45,
                "num_samples": 12032
            },
            "question_ids": ["10024", "5088", "7070", ...],
            "label_matrix": {
                "model_name_1": [0, 1, 0, 1, ...],
                "model_name_2": [1, 0, 1, 0, ...],
                ...
            }
        }

    ``question_ids`` is in the same order as every label list in ``label_matrix``.

    Args:
        label_matrix: dict mapping model name to its list of 0/1 labels.
        output_path: target path (``pathlib.Path`` or ``str``); its suffix is
            replaced by ``.json``.
        dataset: dataset name recorded in the metadata.
        layer: layer config recorded in the metadata.
        pos: position config recorded in the metadata.
        question_ids: question IDs aligned with each label list.

    Returns:
        bool: True if the file was written. False if anything failed; the
        error is printed.
    """
    try:
        # num_samples is the first model's label count (0 for an empty matrix).
        # extract_label_matrix already rejects inconsistent counts; warn here
        # instead of silently ignoring them if a caller passes such data.
        num_samples = len(next(iter(label_matrix.values()))) if label_matrix else 0
        distinct_counts = {len(labels) for labels in label_matrix.values()}
        if len(distinct_counts) > 1:
            print(
                f"警告: 各模型样本数不一致 {sorted(distinct_counts)}，num_samples 取第一个模型的样本数 {num_samples}"
            )

        output_data = {
            "metadata": {
                "dataset": dataset,
                "layer": layer,
                "pos": pos,
                "num_models": len(label_matrix),
                "num_samples": num_samples,
            },
            "question_ids": question_ids,
            "label_matrix": label_matrix,
        }

        # Path() allows plain string paths as well as Path objects.
        json_path = Path(output_path).with_suffix(".json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        print(f"标签矩阵已保存到: {json_path}")

        return True

    except Exception as e:
        print(f"错误: 保存标签矩阵失败: {e}")
        return False


def main():
    """Command-line entry point: extract the 0/1 label matrix and save it as JSON.

    Reads features from ``<PROJECT_ROOT>/feats``. Writes
    ``label_matrix_<dataset>.json`` into ``--output_dir``, which defaults to
    ``<BASE_DIR>/label_data``.

    Returns:
        int: exit status. 0 on success; 1 if the feature directory is
        missing, extraction fails, or saving fails.
    """
    parser = argparse.ArgumentParser(description="从指定配置中提取0/1标签矩阵")
    parser.add_argument(
        "--dataset",
        required=True,
        choices=["math", "mmlu_pro", "bbh", "arc_challenge", "seedbench_plus2", "gsm8k"],
        help="数据集名称",
    )
    parser.add_argument(
        "--layer",
        choices=["quarter", "middle", "three_quarters", "last", "second_last", "first"],
        default="last",
        help="层配置",
    )
    parser.add_argument(
        "--pos",
        choices=["prompt_last_token", "answer_first_token", "last_token",
                 "avg_without_prompt", "avg_with_prompt"],  # legacy names kept for compatibility
        default="last_token",
        help="位置配置",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        help="输出目录，默认为 main_experiment/label_data",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="详细输出",
    )

    args = parser.parse_args()

    feats_root = Path(PROJECT_ROOT) / "feats"
    if args.output_dir is None:
        output_dir = Path(BASE_DIR) / "label_data"
    else:
        output_dir = Path(args.output_dir)

    if not feats_root.exists():
        print(f"错误: 特征目录不存在: {feats_root}")
        return 1

    print(f"开始处理 {args.dataset} 数据集")
    print(f"配置: layer={args.layer}, pos={args.pos}")

    result = extract_label_matrix(feats_root, args.dataset, args.layer, args.pos)

    if not result:
        print("错误: 无法提取标签矩阵 (可能所有模型都因缺失ID被跳过)")
        return 1

    label_matrix, question_ids = result

    print("\n成功提取标签矩阵:")
    print(f"  最终有效模型数量: {len(label_matrix)}")

    if label_matrix:
        sample_count = len(next(iter(label_matrix.values())))
        print(f"  样本数量: {sample_count}")

        # Compute each model's accuracy once; reused for summary and verbose output.
        accuracies = {
            model: sum(labels) / len(labels) for model, labels in label_matrix.items()
        }
        values = list(accuracies.values())

        print(f"  平均准确率: {sum(values)/len(values):.3f}")
        print(f"  最高准确率: {max(values):.3f}")
        print(f"  最低准确率: {min(values):.3f}")

        if args.verbose:
            print("\n各模型准确率:")
            for model, accuracy in accuracies.items():
                print(f"  {model}: {accuracy:.3f}")

    # Create the output directory only once there is something to save, so
    # failed runs do not leave empty directories behind.
    output_dir.mkdir(parents=True, exist_ok=True)

    # The file name carries no layer/pos suffix: the labels themselves do not
    # depend on the layer/pos feature config.
    output_path = output_dir / f"label_matrix_{args.dataset}"

    if save_label_matrix(
        label_matrix, output_path, args.dataset, args.layer, args.pos, question_ids
    ):
        print(f"\n处理完成! 结果已保存到: {output_path}")
        return 0
    else:
        print("\n保存失败")
        return 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
