import os
# Download Hugging Face models through this mirror endpoint instead of the
# default hub URL. NOTE(review): huggingface_hub reads HF_ENDPOINT when it is
# first imported, so this must stay above any import that pulls it in
# (presumably llama_scorer, which is given "hkust-nlp/..." model IDs) — confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Restrict the process to physical GPUs 1 and 2. CUDA only honours this
# variable if it is set before CUDA is initialised, which is presumably why it
# is set here, ahead of the imports below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

import re
import ast
import json
from operator import itemgetter
from llama_scorer import Llama_Scorer


def load_json_data(file_path):
    """Load and return the parsed contents of a JSON file.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    specified to use) instead of the platform's locale-dependent default,
    so non-ASCII captions/questions decode the same on every machine.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The deserialized JSON value (dict, list, ...).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def extract_related_info(top_entries, problems_path, captions_path):
    """Join selected question entries with their problem record and image captions.

    Steps:
      1. Load the problems JSON (a list of records carrying ``question_id``)
         and the captions JSON (a dict with an ``annotations`` list).
      2. Look up each entry's problem by its ``"Question ID"``.
      3. Attach every caption whose ``image_id`` equals the problem's
         ``image_id``, in annotation-file order.

    Args:
        top_entries: Iterable of dicts with ``"Question ID"`` and ``"Answers"``.
        problems_path: Path to the problems JSON file.
        captions_path: Path to the captions JSON file.

    Returns:
        A list of dicts with keys ``"Question ID"``, ``"Answers"``,
        ``"question"``, ``"choices"`` and ``"captions"``, in input order.
        Entries whose ID is not present in the problems file are skipped
        after printing a warning.
    """
    problems_data = load_json_data(problems_path)
    captions_data = load_json_data(captions_path)

    # Key by the string form of the ID: entry IDs are compared as strings
    # below, and this keeps lookups working even if the problems file stores
    # IDs as numbers.
    problems_dict = {str(item['question_id']): item for item in problems_data}

    # Pass 1: resolve each entry's problem; remember which image it needs.
    matched = []  # (result_item, image_id) pairs, in input order
    for entry in top_entries:
        qid_str = str(entry["Question ID"])
        problem = problems_dict.get(qid_str)
        if problem is None:
            print(f"Warning: Question ID {qid_str} not found in problems.json")
            continue  # skip entries without a matching problem
        result_item = {
            "Question ID": entry["Question ID"],
            "Answers": entry["Answers"],
            "question": problem.get("question"),
            "choices": problem.get("choices", []),
            "captions": [],
        }
        matched.append((result_item, problem.get("image_id")))

    # Pass 2: one scan over the (large) annotation list, instead of rescanning
    # it for every entry, collecting captions only for the images we need.
    wanted_images = {image_id for _, image_id in matched}
    captions_by_image = {}
    for ann in captions_data.get("annotations", []):
        image_id = ann.get("image_id")
        if image_id in wanted_images:
            captions_by_image.setdefault(image_id, []).append(ann["caption"])

    results = []
    for result_item, image_id in matched:
        # Copy so entries that share an image do not share one list object.
        result_item["captions"] = list(captions_by_image.get(image_id, []))
        results.append(result_item)

    return results

# NOTE: The triple-quoted string below is a disabled earlier version of
# extract_top_variance_entries, kept only for reference. It is a bare string
# expression and has no effect at runtime. That version assumed the summary
# file is sorted by uncertainty in descending order and returned every entry
# tied with the first line's value, stopping at the first smaller one.
'''
def extract_top_variance_entries(file_path):
    # 正则表达式匹配Question条目结构
    pattern = re.compile(r'Question ID: ([A-Za-z0-9]+), Uncertainty: ([\d.]+), Answers: (\[.*\])')
    top_entries = []
    max_variance = None

    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            # 解析行数据
            match = pattern.match(line)
            if not match:
                print(f"格式不匹配，跳过该行: {line}")
                continue

            # 提取数据
            question_id = match.group(1)
            variance = float(match.group(2))
            answers = ast.literal_eval(match.group(3))  # 安全转换列表字符串

            # 初始化最大方差值
            if max_variance is None:
                max_variance = variance
                top_entries.append({'Question ID': question_id, 'Answers': answers})
                continue

            # 只收集并列第一的条目
            if variance == max_variance:
                top_entries.append({'Question ID': question_id, 'Answers': answers})
            else:
                # 后续方差更小，直接停止读取（因文件已排序）
                break

    return top_entries
'''


def extract_top_variance_entries(file_path, target_uncertainty=1.0):
    """Collect the summary entries whose uncertainty equals a target value.

    Each non-empty line of the summary file is expected to look like::

        Question ID: <alnum id>, Uncertainty: <number>, Answers: [<python list>]

    Lines that do not match that shape, or whose number / answer list cannot
    be parsed, are reported and skipped rather than aborting the whole run.

    Args:
        file_path: Path to the results-summary text file (read as UTF-8).
        target_uncertainty: Uncertainty value to select. The default ``1.0``
            selects the maximally uncertain questions. Values are compared
            with a tiny tolerance so float-formatting noise such as
            ``0.9999999999999999`` still counts as a match.

    Returns:
        A list of ``{'Question ID': str, 'Answers': list}`` dicts, in file order.
    """
    import math

    # Matches one summary line; groups are (id, uncertainty, answers list).
    pattern = re.compile(r'Question ID: ([A-Za-z0-9]+), Uncertainty: ([\d.]+), Answers: (\[.*\])')
    top_entries = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            match = pattern.match(line)
            if not match:
                print(f"格式不匹配，跳过该行: {line}")
                continue

            question_id = match.group(1)
            try:
                # [\d.]+ also admits strings like "1.2.3" that float() rejects.
                uncertainty = float(match.group(2))
                # literal_eval only accepts Python literals, so the answer
                # list is parsed without executing arbitrary code.
                answers = ast.literal_eval(match.group(3))
            except (ValueError, SyntaxError, TypeError) as e:
                print(f"Failed to parse line, skipping: {line} ({e})")
                continue

            # Keep only entries at the target uncertainty (1.0 = maximal by default).
            if math.isclose(uncertainty, target_uncertainty, rel_tol=1e-9, abs_tol=1e-9):
                top_entries.append({'Question ID': question_id, 'Answers': answers})

    return top_entries

def build_input_text(item):
    """Render one item as the prompt text passed to the scorer.

    The result is a ``Question:`` line followed by a ``Choices:`` line
    (choices joined with ``", "``). If the item has any captions, a third
    ``Image Captions:`` line is added with the captions joined by ``"; "``.
    Lines are separated by newlines.
    """
    question_line = f"Question: {item['question']}"
    choices_line = f"Choices: {', '.join(item['choices'])}"

    captions = item['captions']
    if not captions:
        return "\n".join((question_line, choices_line))

    # Several captions may describe the same image; fold them into one line.
    caption_line = f"Image Captions: {'; '.join(captions)}"
    return "\n".join((question_line, choices_line, caption_line))


def process_scoring(final_data,
                    output_file="/home/test/yxl/MCoT/aokvqa/results/mistral/complexity_scores_1.txt",
                    metric="complexity"):
    """Score every item with a Deita scorer and write the scores as a TSV file.

    Args:
        final_data: Records shaped like the output of ``extract_related_info``.
            Each needs ``"Question ID"``, ``"question"``, ``"choices"`` and
            ``"captions"``. ``metric="quality"`` additionally needs
            ``"Answers"`` (a list joined with newlines as the response).
        output_file: Destination path. Missing parent directories are
            created, so a long scoring run is not lost at the final write.
        metric: ``"complexity"`` (default) scores the prompt alone with
            ``infer_complexity``. ``"quality"`` scores the prompt plus the
            answers with ``infer_quality``.

    Returns:
        None. If the scorer model fails to load, a message is printed and
        nothing is written. Items that fail to score are reported and
        omitted. The rest are written sorted by score, highest first, under
        a header naming the metric (e.g. ``Complexity Score``).

    Raises:
        ValueError: If ``metric`` is not ``"complexity"`` or ``"quality"``.
    """
    model_names = {
        "complexity": "hkust-nlp/deita-complexity-scorer",
        "quality": "hkust-nlp/deita-quality-scorer",
    }
    if metric not in model_names:
        raise ValueError(f"Unknown metric {metric!r}; expected one of {sorted(model_names)}")

    # Best-effort: a model-load failure is reported and aborts the run.
    try:
        scorer = Llama_Scorer(model_names[metric])
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return

    scored_data = []
    for item in final_data:
        input_text = build_input_text(item)

        # Best-effort per item: one failure must not abort the whole batch.
        try:
            if metric == "quality":
                # Quality is judged on the (prompt, response) pair.
                output_text = "\n".join(map(str, item['Answers']))
                score = scorer.infer_quality(input_text, output_text)
            else:
                # Complexity is judged on the prompt alone.
                score = scorer.infer_complexity(input_text)
            scored_data.append({
                "question_id": item["Question ID"],
                "score": float(score)
            })
            print(f"Question ID: {item['Question ID']}, Score: {score}")
            print("-" * 50)
        except Exception as e:
            print(f"评分失败（ID {item['Question ID']}）: {str(e)}")

    # Highest score first.
    sorted_scores = sorted(scored_data, key=itemgetter('score'), reverse=True)

    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        # The header names the metric actually computed. It previously always
        # said "Quality Score", even for complexity scores.
        f.write(f"Question ID\t{metric.capitalize()} Score\n")
        f.writelines(
            f"{entry['question_id']}\t{entry['score']:.4f}\n" for entry in sorted_scores
        )

    print(f"评分完成！结果已保存至 {output_file}")

# Example usage: select the maximally uncertain questions from the results
# summary, join them with the A-OKVQA problems and COCO captions, then score
# them.
if __name__ == "__main__":
    results_summary = '/home/test/yxl/MCoT/aokvqa/results/mistral/results_summary.txt'
    aokvqa_train = "/home/test/yxl/MCoT/data/aokvqa/aokvqa_v1p0_train.json"
    coco_captions = "/home/test/yxl/MCoT/data/COCO/annotations/captions_train2017.json"

    selected = extract_top_variance_entries(results_summary)
    # Attach question text, choices and image captions to each selected entry.
    enriched = extract_related_info(selected, aokvqa_train, coco_captions)

    print("开始评分")
    print("-" * 50)
    process_scoring(enriched)

