import json
import pandas as pd
import re
import string


def get_accuracy_gqa(path):
    """Return the fraction of rows whose ``label`` occurs inside ``pred``.

    Reads a JSON-lines file with ``pred`` and ``label`` columns. A row counts
    as correct when ``label in pred`` holds (a substring test for strings),
    so extra words around the answer are tolerated.

    Args:
        path: Path (or file-like object) accepted by ``pd.read_json``.

    Returns:
        Accuracy as a float in [0, 1].
    """
    frame = pd.read_json(path, lines=True)
    hits = sum(label in pred for pred, label in zip(frame["pred"], frame["label"]))
    return hits / len(frame)


def get_accuracy_expla_graphs(path):
    """Return stance accuracy for ExplaGraphs-style predictions.

    For each row, the first occurrence of "support", "Support", "counter" or
    "Counter" in the stripped ``pred`` string is lower-cased and compared with
    ``label``. Rows whose prediction contains none of these words count as
    wrong. Other capitalisations (e.g. "SUPPORT") are not recognised.

    Args:
        path: Path (or file-like object) accepted by ``pd.read_json``.

    Returns:
        Accuracy as a float in [0, 1].
    """
    stance = re.compile(r"support|Support|Counter|counter")

    def _is_correct(pred, label):
        # Leftmost stance keyword decides the predicted class.
        found = stance.search(pred.strip())
        return found is not None and found.group(0).lower() == label

    frame = pd.read_json(path, lines=True)
    correct = sum(
        1 for pred, label in zip(frame["pred"], frame["label"]) if _is_correct(pred, label)
    )
    return correct / len(frame)


def normalize(s: str) -> str:
    """Lowercase text; drop ``<pad>`` tokens, punctuation, articles and extra whitespace.

    Used to compare predicted and gold answers loosely.

    Args:
        s: Text to normalize.

    Returns:
        The lower-cased text with every ``<pad>`` token, every character in
        ``string.punctuation`` and every standalone article ("a", "an", "the")
        removed, and runs of whitespace collapsed to single spaces.
    """
    s = s.lower()
    # Remove <pad> tokens BEFORE stripping punctuation: '<' and '>' are in
    # string.punctuation, so afterwards the token would linger as the word "pad".
    s = s.replace("<pad>", " ")
    s = s.translate(str.maketrans("", "", string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    s = " ".join(s.split())
    return s


def match(s1: str, s2: str) -> bool:
    """Return True if normalized ``s2`` is a substring of normalized ``s1``.

    Both strings are passed through ``normalize`` first. An ``s2`` that
    normalizes to the empty string therefore matches any ``s1``.
    """
    haystack, needle = normalize(s1), normalize(s2)
    return needle in haystack


def eval_f1(prediction, answer):
    """Compute (f1, precision, recall) for a list of predictions.

    Each gold answer counts as matched when ``match`` finds it anywhere in the
    space-joined predictions. Precision divides the matched-answer count by
    the number of prediction items, so it can exceed 1.0 when a single item
    contains several gold answers.

    Args:
        prediction: Sequence of predicted answer strings.
        answer: Sequence of gold answer strings.

    Returns:
        Tuple ``(f1, precision, recall)``; ``(0, 0, 0)`` for an empty prediction.

    Raises:
        ZeroDivisionError: If ``prediction`` is non-empty and ``answer`` is empty.
    """
    if len(prediction) == 0:
        return 0, 0, 0
    joined = " ".join(prediction)
    hits = sum(1 for gold in answer if match(joined, gold))
    precision = hits / len(prediction)
    recall = hits / len(answer)
    denom = precision + recall
    f1 = 0 if denom == 0 else 2 * precision * recall / denom
    return f1, precision, recall


def eval_acc(prediction, answer):
    """Return the fraction of gold answers that ``match`` finds in ``prediction``.

    Args:
        prediction: Predicted text.
        answer: Sequence of gold answer strings.

    Returns:
        A float in [0, 1].

    Raises:
        ZeroDivisionError: If ``answer`` is empty.
    """
    hits = sum(1.0 for gold in answer if match(prediction, gold))
    return hits / len(answer)


def eval_hit(prediction, answer):
    """Return 1 if any gold answer is found in ``prediction`` by ``match``, else 0.

    Args:
        prediction: Predicted text.
        answer: Sequence of gold answer strings.
    """
    return int(any(match(prediction, gold) for gold in answer))


def get_accuracy_webqsp(path):
    """Score multi-answer KGQA predictions and print aggregate metrics.

    Reads a JSON-lines file with string ``pred`` and ``label`` columns.
    - Gold answers in ``label`` are separated by ``|``.
    - Predicted items in ``pred`` are separated by ``|`` or newlines.

    Per row it computes:
    - F1, precision and recall via ``eval_f1`` on the list of predicted items;
    - accuracy (``eval_acc``) and hit (``eval_hit``) on the space-joined
      prediction.

    The mean of each metric, scaled to a percentage, is printed.

    Args:
        path: Path (or file-like object) accepted by ``pd.read_json``.

    Returns:
        The mean hit rate as a percentage in [0, 100].
    """
    df = pd.read_json(path, lines=True)

    def _pct_mean(values):
        # Mean of a per-row metric list, expressed as a percentage.
        return sum(values) * 100 / len(values)

    scores = {"acc": [], "hit": [], "f1": [], "precision": [], "recall": []}
    for raw_pred, raw_label in zip(df.pred.tolist(), df.label.tolist()):
        pred_items = raw_pred.replace("|", "\n").split("\n")
        gold_items = raw_label.split("|")

        row_f1, row_precision, row_recall = eval_f1(pred_items, gold_items)
        scores["f1"].append(row_f1)
        scores["precision"].append(row_precision)
        scores["recall"].append(row_recall)

        joined = " ".join(pred_items)
        scores["acc"].append(eval_acc(joined, gold_items))
        scores["hit"].append(eval_hit(joined, gold_items))

    acc = _pct_mean(scores["acc"])
    hit = _pct_mean(scores["hit"])
    f1 = _pct_mean(scores["f1"])
    precision = _pct_mean(scores["precision"])
    recall = _pct_mean(scores["recall"])

    print(f"Accuracy: {acc:.4f}")
    print(f"Hit: {hit:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1: {f1:.4f}")

    return hit

from sklearn.metrics import f1_score, accuracy_score, classification_report
def get_accuracy_glbench(path):
    """Score GLBench-style classification predictions and print a report.

    Reads a JSON-lines file whose ``pred`` and ``label`` columns hold the
    predicted and ground-truth class of each example. Predictions are compared
    to labels by exact equality.

    Args:
        path: Path (or file-like object) accepted by ``pd.read_json``.

    Returns:
        Tuple ``(accuracy, macro_f1)``. Macro F1 is averaged over the classes
        that appear in ``label`` only. A prediction outside that set lowers the
        recall of its true class but is not scored as a class of its own.
    """
    # Read the JSON-lines file.
    df = pd.read_json(path, lines=True)

    # Ground-truth labels and predictions.
    y_true = df["label"].tolist()
    y_pred = df["pred"].tolist()

    # Distinct ground-truth labels in first-appearance order. Unlike set()
    # (whose order for strings varies between runs under hash randomization),
    # this keeps the printed label list and report rows reproducible.
    labels = list(dict.fromkeys(y_true))
    print(f"Unique labels: {labels}")

    # Overall accuracy.
    accuracy = accuracy_score(y_true, y_pred)

    # Multi-class F1, macro-averaged over the ground-truth classes.
    f1 = f1_score(y_true, y_pred, labels=labels, average="macro")

    # Detailed per-class report (precision, recall, F1, support). A class's
    # "accuracy" (share of its samples predicted correctly) is exactly its
    # recall, so it is not computed separately.
    print("\nClassification Report:")
    print(
        classification_report(
            y_true, y_pred, labels=labels, target_names=[str(label) for label in labels]
        )
    )

    return accuracy, f1

# Dispatch table: dataset name -> function that scores that dataset's
# JSON-lines prediction file. Several dataset variants share one scorer.
# Return types differ: get_accuracy_glbench returns (accuracy, macro_f1),
# the other scorers return a single number.
eval_funcs = {
    # Stance classification.
    "expla_graphs": get_accuracy_expla_graphs,
    # Scene-graph QA: substring accuracy.
    "scene_graphs": get_accuracy_gqa,
    "scene_graphs_baseline": get_accuracy_gqa,
    # Multi-answer KGQA: hit rate (other metrics are printed).
    "webqsp": get_accuracy_webqsp,
    "webqsp-gr": get_accuracy_webqsp,
    "webqsp_baseline": get_accuracy_webqsp,
    # Node/graph classification benchmark.
    "glbench": get_accuracy_glbench,
    "cwq": get_accuracy_webqsp,
}
