import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import re
import ast
import json
from operator import itemgetter
from llama_scorer import Llama_Scorer


def load_json_data(file_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The deserialized JSON value, exactly as ``json.load`` produces it.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale,
    # which would break on non-ASCII question text under e.g. a C/latin-1 locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def extract_related_info(top_entries, problems_path, captions_path):
    """Join selected question IDs with their problem records and captions.

    Steps:
      1. Load the problems JSON file and the captions JSON file.
      2. Look up each entry's "Question ID" (as a string key) in both.
      3. Return one flat dict per question found in the problems file.

    Args:
        top_entries: Iterable of dicts, each with a "Question ID" key
            (e.g. the output of ``extract_top_variance_entries``).
        problems_path: Path to a JSON object mapping question-ID strings to
            problem dicts. Fields read: question, answer, choices, hint,
            lecture, solution.
        captions_path: Path to a JSON object whose "captions" key maps
            question-ID strings to captions.

    Returns:
        List of dicts with keys "Question ID", "Answers", "question",
        "choices", "hint", "caption", "lecture", "solution", in input order.
        "Answers" holds the problem record's "answer" field; any "Answers"
        carried on the input entry is not used. "caption" is None when the
        ID has no caption. IDs missing from the problems file are reported
        and skipped.
    """
    problems_data = load_json_data(problems_path)
    captions_data = load_json_data(captions_path)
    # Loop-invariant: resolve the captions mapping once instead of per entry.
    captions = captions_data.get("captions", {})

    results = []
    for entry in top_entries:
        # JSON object keys are always strings, so normalise the ID for lookup.
        qid_str = str(entry["Question ID"])

        if qid_str not in problems_data:
            print(f"Warning: Question ID {qid_str} not found in problems.json")
            continue  # Skip questions with no problem record.

        problem = problems_data[qid_str]
        results.append({
            "Question ID": entry["Question ID"],
            "Answers": problem.get("answer"),
            "question": problem.get("question"),
            "choices": problem.get("choices", []),
            "hint": problem.get("hint", ""),
            "caption": captions.get(qid_str),
            "lecture": problem.get("lecture", ""),
            "solution": problem.get("solution", ""),
        })

    return results

# NOTE: The string below is a disabled, earlier version of
# extract_top_variance_entries, kept for reference only; it is never executed.
# That version kept only the leading entries tied for the maximum uncertainty
# and stopped reading at the first smaller value (it assumed the summary file
# was sorted). The active definition follows this string.
# The r-prefix keeps the string value unchanged while avoiding
# invalid-escape warnings for the regex text (\d, \[, \]) it contains.
r'''
def extract_top_variance_entries(file_path):
    # 正则表达式匹配Question条目结构
    # pattern = re.compile(r'Question ID: (\d+), Variance: ([\d.]+), Answers: (\[.*\])')
    # pattern = re.compile(r'Question ID: (\d+), Entropy: ([\d.]+), Answers: (\[.*\])')
    pattern = re.compile(r'Question ID: (\d+), Uncertainty: ([\d.]+), Answers: (\[.*\])')
    top_entries = []
    max_Uncertainty = None

    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            # 解析行数据
            match = pattern.match(line)
            if not match:
                print(f"格式不匹配，跳过该行: {line}")
                continue

            # 提取数据
            question_id = int(match.group(1))
            Uncertainty = float(match.group(2))
            answers = ast.literal_eval(match.group(3))  # 安全转换列表字符串

            # 初始化最大方差值
            if max_Uncertainty is None:
                max_Uncertainty = Uncertainty
                top_entries.append({'Question ID': question_id, 'Answers': answers})
                continue

            # 只收集并列第一的条目
            if Uncertainty == max_Uncertainty:
                top_entries.append({'Question ID': question_id, 'Answers': answers})
            else:
                # 后续方差更小，直接停止读取（因文件已排序）
                break

    return top_entries
'''
def extract_top_variance_entries(file_path, target_uncertainty=1.0):
    """Collect summary entries whose uncertainty equals ``target_uncertainty``.

    Each non-blank line of ``file_path`` is expected to look like::

        Question ID: <int>, Uncertainty: <float>, Answers: [<python literal list>]

    A line is reported and skipped if it does not match that layout, if its
    uncertainty is not a valid float (e.g. ``1.2.3``), or if its selected
    answer list is not a valid Python literal.

    Args:
        file_path: Path to the results summary text file.
        target_uncertainty: Uncertainty value to keep, compared exactly.
            Defaults to 1.0, the value this script has always filtered on.

    Returns:
        List of ``{'Question ID': int, 'Answers': list}`` dicts in file order.
    """
    pattern = re.compile(r'Question ID: (\d+), Uncertainty: ([\d.]+), Answers: (\[.*\])')
    top_entries = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            match = pattern.match(line)
            if not match:
                print(f"格式不匹配，跳过该行: {line}")
                continue

            # [\d.]+ also accepts strings like "1.2.3", which float() rejects.
            try:
                uncertainty = float(match.group(2))
            except ValueError:
                print(f"格式不匹配，跳过该行: {line}")
                continue

            # Filter on the uncertainty value itself. The old comment said
            # "variance == 0", but the check has always been uncertainty == 1.0.
            if uncertainty != target_uncertainty:
                continue

            question_id = int(match.group(1))
            # literal_eval only parses literals, so it is safe on file text;
            # a malformed list skips this line instead of aborting the run.
            try:
                answers = ast.literal_eval(match.group(3))
            except (ValueError, SyntaxError, TypeError):
                print(f"格式不匹配，跳过该行: {line}")
                continue

            top_entries.append({'Question ID': question_id, 'Answers': answers})

    return top_entries


def build_input_text(item):
    """Render the prompt text handed to the scorer for one item.

    The result always starts with a question line and a comma-separated
    choices line. Hint and image-caption lines follow only when the
    corresponding field is truthy.

    Args:
        item: Dict with keys 'question', 'choices' (iterable of strings),
            'hint' and 'caption'.

    Returns:
        The lines joined with newlines.
    """
    question_line = f"Question: {item['question']}"
    choices_line = f"Choices: {', '.join(item['choices'])}"
    lines = [question_line, choices_line]

    hint, caption = item['hint'], item['caption']
    if hint:
        lines.append(f"Hint: {hint}")
    if caption:
        lines.append(f"Image Caption: {caption}")

    return "\n".join(lines)


def build_output_text(item):
    """Build the answer-side ("output") text that pairs with an item's prompt.

    The question and answer lines are always present. Lecture and solution
    lines are appended only when those fields are truthy.

    Args:
        item: Dict with at least the keys 'question', 'Answers', 'lecture'
            and 'solution' (as produced by ``extract_related_info``).

    Returns:
        The lines joined with newlines.
    """
    components = [
        # The comma between these two entries was previously missing, so
        # Python silently concatenated them into a single line.
        f"Question: {item['question']}",
        f"Answers: {item['Answers']}",
    ]

    # Optional fields.
    if item['lecture']:
        components.append(f"Lecture: {item['lecture']}")
    if item['solution']:
        components.append(f"Solution: {item['solution']}")

    return "\n".join(components)

def process_scoring(final_data, output_file="/home/test/yxl/MCoT/sqa/results/mistral-small3.1/complexity_scores_1.txt"):
    """Score every item's prompt complexity and write a ranked TSV report.

    Each item is rendered with ``build_input_text`` and scored with
    ``Llama_Scorer.infer_complexity`` using the
    "hkust-nlp/deita-complexity-scorer" model. If building the prompt or
    scoring raises for an item, that item is reported and skipped, so one
    bad item does not abort the run.

    Args:
        final_data: Iterable of item dicts (as produced by
            ``extract_related_info``). Each needs a "Question ID" key plus
            the fields ``build_input_text`` reads.
        output_file: Destination path for the tab-separated report. Its
            parent directory is created if it does not exist.

    Returns:
        List of ``{"question_id": ..., "score": float}`` dicts sorted by
        descending score, or None if the scorer failed to load. In that case
        no file is written.
    """
    # Load the scorer once up front; without it there is nothing to do.
    try:
        # Use "hkust-nlp/deita-quality-scorer" here for quality scoring.
        scorer = Llama_Scorer("hkust-nlp/deita-complexity-scorer")
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return None

    scored_data = []
    for item in final_data:
        try:
            # Prompt building is inside the try so a malformed item is
            # skipped rather than discarding every score computed so far.
            input_text = build_input_text(item)
            # Complexity depends on the prompt only. Quality scoring would be:
            #   scorer.infer_quality(input_text, build_output_text(item))
            score = scorer.infer_complexity(input_text)
            scored_data.append({
                "question_id": item["Question ID"],
                "score": float(score),
            })
            print(f"Question ID: {item['Question ID']}, Score: {score}")
            print("-" * 50)
        except Exception as e:
            print(f"评分失败（ID {item['Question ID']}）: {str(e)}")

    # Highest score first.
    sorted_scores = sorted(scored_data, key=itemgetter('score'), reverse=True)

    # Make sure the destination directory exists, so a long scoring run is
    # not lost at the final write.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Question ID\tComplexity Score\n")
        for entry in sorted_scores:
            f.write(f"{entry['question_id']}\t{entry['score']:.4f}\n")

    print(f"评分完成！结果已保存至 {output_file}")
    return sorted_scores

# Script entry point (usage example).
if __name__ == "__main__":
    # Input locations on the original experiment machine.
    summary_path = '/home/test/yxl/MCoT/sqa/results/mistral-small3.1/results_summary.txt'
    problems_path = "/home/test/yxl/MCoT/data/scienceqa/problems.json"
    captions_path = "/home/test/yxl/MCoT/data/captions.json"

    # Pick the entries that extract_top_variance_entries selects from the summary.
    selected_entries = extract_top_variance_entries(summary_path)

    # Attach question text, choices, hints, captions, lectures and solutions.
    enriched_items = extract_related_info(selected_entries, problems_path, captions_path)

    print("开始评分")
    print("-" * 50)
    process_scoring(enriched_items)

