import argparse
import json
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import os
from glossary import *
from utils import *

def normalize_text(text):
    """Return ``text`` with punctuation removed, lower-cased and trimmed.

    "Punctuation" here means any character that is neither a word character
    (``\\w``) nor whitespace. Inner whitespace is left untouched.
    """
    without_punct = re.sub(r'[^\w\s]', '', text)
    lowered = without_punct.lower()
    return lowered.strip()

def match_keyword(true_answer, pred_answer):
    """Return True if the normalized ground truth occurs as a whole-word phrase in the prediction.

    Both strings go through ``normalize_text`` first; the ground truth is then
    regex-escaped and wrapped in ``\\b`` word boundaries before searching.
    """
    pattern = r'\b' + re.escape(normalize_text(true_answer)) + r'\b'
    return bool(re.search(pattern, normalize_text(pred_answer)))

def extract_single_word_answer(text):
    """Return the single-letter answer inside ``<answer>…</answer>``, upper-cased.

    The tag match is case-insensitive and whitespace around the letter is
    ignored. Returns "" when no tag pair wrapping exactly one ASCII letter is found.
    """
    found = re.search(r'<answer>\s*([A-Za-z])\s*</answer>', text, re.IGNORECASE)
    if found is None:
        return ""
    return found.group(1).upper()

def extract_answer_string(text):
    """Return the trimmed content of the first ``<answer>…</answer>`` span.

    Matching is case-insensitive, non-greedy and may span newlines, so
    multi-word / multi-line answers are supported. Returns "" when no complete
    tag pair is present.
    """
    found = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.IGNORECASE | re.DOTALL)
    return "" if found is None else found.group(1).strip()

def calculate_ece(data, num_bins=10):
    """Compute the Expected Calibration Error (ECE) for closed-ended answers.

    Items whose ``confidence`` is missing or None are ignored. Every other item
    is placed in one of ``num_bins`` equal-width confidence bins over [0, 1]
    (values outside that range are clamped into the first/last bin). Per bin,
    accuracy is compared with mean confidence, and the ECE is the
    sample-weighted mean absolute gap, weighted by the number of *binned* items.

    Correctness uses the same rule as ``calculate_accuracy``: the prediction is
    the ``<answer>…</answer>`` span when the text has a closing ``</answer>`` tag
    (otherwise the whole text), and an item is correct when the upper-cased,
    stripped ``answer`` is a substring of the upper-cased prediction.

    Args:
        data: iterable of dicts with ``answer``, ``pred`` and optional ``confidence``.
        num_bins: number of confidence bins.

    Returns:
        The ECE as a float; 0.0 when no item carries a confidence.
    """
    bin_size = 1.0 / num_bins
    bin_correct_counts = [0] * num_bins
    bin_total_counts = [0] * num_bins
    bin_conf_sums = [0.0] * num_bins

    for item in data:
        confidence = item.get("confidence")
        if confidence is None:
            continue

        true_answer = item.get("answer", "").upper().strip()
        pred_text = item.get("pred", "")

        # Match calculate_accuracy: only extract when a closing tag exists, so a
        # plain "The answer is yes" is not reduced to "" and marked wrong.
        if "</answer>" in pred_text:
            pred_answer = extract_answer_string(pred_text).upper()
        else:
            pred_answer = pred_text.upper().strip()
        correct = true_answer in pred_answer

        # Clamp both ends so an out-of-range confidence cannot wrap around via
        # negative indexing.
        bin_index = min(max(int(confidence / bin_size), 0), num_bins - 1)
        bin_total_counts[bin_index] += 1
        bin_conf_sums[bin_index] += confidence
        if correct:
            bin_correct_counts[bin_index] += 1

    # Weight by the number of binned samples, not len(data): items skipped for
    # lacking a confidence must not shrink the ECE.
    total_samples = sum(bin_total_counts)
    ece = 0.0
    for correct_count, count, conf_sum in zip(bin_correct_counts, bin_total_counts, bin_conf_sums):
        if count == 0:
            continue
        acc = correct_count / count
        avg_conf = conf_sum / count
        ece += (count / total_samples) * abs(acc - avg_conf)

    return ece

def calculate_ece_with_recall(data, num_bins=10):
    """Compute an ECE variant for open-ended answers, using token recall as accuracy.

    Items whose ``confidence`` is missing or None are ignored. Every other item
    is placed in one of ``num_bins`` equal-width confidence bins over [0, 1]
    (out-of-range values are clamped into the first/last bin). Per bin, the mean
    recall (from ``calculate_recall_only``, clipped to [0, 1]) is compared with
    the mean confidence. The ECE is the sample-weighted mean absolute gap over
    non-empty bins.

    Recall is computed from the lower-cased full prediction text and the
    lower-cased answer, exactly as ``calculate_recall`` does. The ECE therefore
    refers to the same recall values that are reported.

    Args:
        data: iterable of dicts with ``answer``, ``pred`` and optional ``confidence``.
        num_bins: number of confidence bins.

    Returns:
        The ECE as a float; 0.0 when no item carries a confidence.
    """
    bin_size = 1.0 / num_bins
    bin_recall_sums = [0.0] * num_bins
    bin_conf_sums = [0.0] * num_bins
    bin_total_counts = [0] * num_bins

    for item in data:
        confidence = item.get("confidence")
        if confidence is None:
            continue

        true_answer = item.get("answer", "").lower().strip()
        pred_text = item.get("pred", "")

        # Same inputs and casing as calculate_recall.
        recall = calculate_recall_only(pred_text.lower().strip(), true_answer)
        recall = max(0.0, min(1.0, recall))

        # Clamp both ends so an out-of-range confidence cannot wrap around via
        # negative indexing.
        bin_index = min(max(int(confidence / bin_size), 0), num_bins - 1)
        bin_total_counts[bin_index] += 1
        bin_recall_sums[bin_index] += recall
        bin_conf_sums[bin_index] += confidence

    # Recall-based ECE, weighted by the number of binned samples.
    total_samples = sum(bin_total_counts)
    ece = 0.0
    for count, recall_sum, conf_sum in zip(bin_total_counts, bin_recall_sums, bin_conf_sums):
        if count == 0:
            continue
        avg_recall = recall_sum / count
        avg_conf = conf_sum / count
        ece += (count / total_samples) * abs(avg_recall - avg_conf)

    return ece

def calculate_recall(data):
    """Aggregate token recall over open-ended predictions.

    For each item, the whole prediction text and the answer are lower-cased and
    stripped, then scored with ``calculate_recall_only``.

    Returns:
        A tuple ``(recall_sum, total, hit_list, confidence_list)``:
        ``recall_sum`` is the sum of the positive per-item recalls, ``total`` the
        number of items, ``hit_list[i]`` is True when item i has non-zero recall,
        and ``confidence_list[i]`` is the item's ``confidence`` (None when the key
        is missing, mirroring how the ECE functions skip such items instead of
        raising KeyError).
    """
    correct_ta = 0
    total = 0
    accuracy_list = []
    confidence_list = []

    for item in data:
        true_answer = item.get("answer", "").lower().strip()
        pred_text = item.get("pred", "")

        # NOTE(review): recall is scored against the full prediction text, not
        # only the <answer>…</answer> span. If predictions contain reasoning,
        # words from it count toward recall. Confirm this is intended.
        recall_ta = calculate_recall_only(pred_text.lower().strip(), true_answer)

        if recall_ta > 0:
            correct_ta += recall_ta
        accuracy_list.append(recall_ta > 0)
        confidence_list.append(item.get("confidence"))
        total += 1

    return correct_ta, total, accuracy_list, confidence_list

def calculate_recall_only(candidate, reference):
    """Return the token-level recall of ``candidate`` with respect to ``reference``.

    Both strings are normalized with ``normalize_word`` and split with
    ``split_sentence(..., 1)``. Both helpers come from ``utils``; the result is
    assumed to be a word -> count mapping, as the arithmetic below requires.
    Each reference token can be matched at most as many times as it occurs in
    the candidate. Recall is ``matched reference tokens / total reference tokens``.

    Returns:
        A float in [0, 1]; 0.0 when the reference has no tokens.
    """
    candidate_words = split_sentence(normalize_word(candidate), 1)
    reference_words = split_sentence(normalize_word(reference), 1)

    tp = 0  # reference tokens matched by the candidate
    fn = 0  # reference tokens left unmatched

    for word in reference_words:
        ref_count = reference_words[word]
        matched = min(candidate_words[word], ref_count) if word in candidate_words else 0
        tp += matched
        # Count the unmatched remainder as well. Otherwise a word that occurs
        # more often in the reference than in the candidate would be treated as
        # fully recalled (e.g. "left left lung" vs "left lung" would give 1.0).
        fn += ref_count - matched

    if tp + fn == 0:
        return 0.0
    return tp / (tp + fn)

def calculate_accuracy(data):
    """Score closed-ended predictions by substring match.

    The prediction is the ``<answer>…</answer>`` span when the text contains a
    closing ``</answer>`` tag; otherwise it is the whole text. An item is correct
    when the upper-cased, stripped ``answer`` occurs anywhere in the upper-cased
    prediction.
    NOTE(review): substring matching is lenient. For example, "NO" matches
    inside "NORMAL" or "UNKNOWN", and an empty answer always matches.

    Returns:
        A tuple ``(correct, total, accuracy_list, confidence_list)``.
        ``confidence_list[i]`` is the item's ``confidence``, or None when the key
        is missing. The ECE functions skip such items rather than raising
        KeyError.
    """
    correct = 0
    accuracy_list = []
    confidence_list = []

    for item in data:
        true_answer = item.get("answer", "").upper().strip()
        pred_text = item.get("pred", "")

        if "</answer>" in pred_text:
            pred_answer = extract_answer_string(pred_text).upper()
        else:
            pred_answer = pred_text.upper().strip()

        is_correct = true_answer in pred_answer
        if is_correct:
            correct += 1
        accuracy_list.append(is_correct)
        confidence_list.append(item.get("confidence"))

    total = len(accuracy_list)
    return correct, total, accuracy_list, confidence_list

def split_data_by_type(data):
    """Partition records into (open_items, closed_items), preserving order.

    A record is "open" when its ``type`` field equals "OPEN" case-insensitively.
    Every other record, including one without a ``type`` key, is "closed".
    """
    open_items, closed_items = [], []
    for entry in data:
        is_open = entry.get("type", "").upper() == "OPEN"
        (open_items if is_open else closed_items).append(entry)
    return open_items, closed_items

# (open_weight, closed_weight) used to combine the per-type ECEs into one
# overall ECE. Entries are keyed by a substring of the dataset name / input path
# and checked in this order.
# NOTE(review): each pair sums to 1 and looks like the OPEN/CLOSED share of the
# corresponding benchmark's test split. Confirm against the dataset statistics.
_DATASET_ECE_WEIGHTS = (
    ('slake', 0.6079, 0.3921),
    ('vqa-rad', 0.3969, 0.6031),
    ('path', 0.4996, 0.5004),
)


def _known_ece_weights(*sources):
    """Return (open_weight, closed_weight) for the first known dataset named in ``sources``, else None."""
    for source in sources:
        if not source:
            continue
        for name, open_weight, closed_weight in _DATASET_ECE_WEIGHTS:
            if name in source:
                return open_weight, closed_weight
    return None


def main():
    """Command-line entry point: report accuracy/recall and ECE for a prediction file.

    Loads a JSON file of prediction records (``-i``) and splits them into OPEN
    and CLOSED questions. It prints token recall + ECE for OPEN items, accuracy +
    ECE for CLOSED items, and a weighted overall ECE.

    Overall-ECE weights come from ``_DATASET_ECE_WEIGHTS``, matched by substring
    on ``--dataset`` first and then on the input path. When the dataset is not
    recognized, or one split is empty, each type is weighted by its share of the
    loaded samples. The previous behavior was to silently report 0 or crash.
    """
    parser = argparse.ArgumentParser(description='计算预测准确率')
    parser.add_argument('-i', '--input_file', required=True, help='输入JSON文件路径')
    parser.add_argument('-m', '--model')
    parser.add_argument('-s', '--stage')
    parser.add_argument('-d', '--dataset')
    args = parser.parse_args()

    try:
        # Read explicitly as UTF-8 instead of the platform's default encoding.
        with open(args.input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        print(f"错误: {str(e)}")
        return

    if not data:
        print("错误: 文件为空")
        return

    # Split the records into OPEN and CLOSED questions.
    open_data, closed_data = split_data_by_type(data)

    # Make sure the output directory exists.
    os.makedirs("figs", exist_ok=True)

    # Defaults keep the overall ECE defined when one of the splits is empty.
    open_ece = 0.0
    closed_ece = 0.0

    # OPEN questions: token recall and recall-based ECE.
    if open_data:
        print("\nOPEN类型结果:")
        open_correct_ta, open_total, _, _ = calculate_recall(open_data)
        print(f"OPEN召回率: {open_correct_ta/open_total:.4f} ({open_correct_ta}/{open_total})")
        open_ece = calculate_ece_with_recall(open_data)
        print(f"OPEN ECE: {open_ece:.4f}")
    else:
        print("\n警告: 没有找到OPEN类型数据")

    # CLOSED questions: accuracy and ECE.
    if closed_data:
        print("\nCLOSED类型结果:")
        closed_correct, closed_total, _, _ = calculate_accuracy(closed_data)
        print(f"CLOSED准确率: {closed_correct/closed_total:.4f} ({closed_correct}/{closed_total})")
        closed_ece = calculate_ece(closed_data)
        print(f"CLOSED ECE: {closed_ece:.4f}")
    else:
        print("\n警告: 没有找到CLOSED类型数据")

    # Overall ECE: weighted average of the per-type ECEs. Fixed weights apply only
    # when both splits are present and the dataset is recognized. Otherwise fall
    # back to the observed sample proportions rather than weights of 0.
    weights = None
    if open_data and closed_data:
        weights = _known_ece_weights(args.dataset, args.input_file)
    if weights is None:
        print("\n警告: 未识别数据集或缺少某类数据, 按样本比例加权计算整体ECE")
        weights = (len(open_data) / len(data), len(closed_data) / len(data))
    w1, w2 = weights
    ece = w1 * open_ece + w2 * closed_ece
    print(f"整体ECE: {ece:.4f}")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()