import re
import string
import json
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from enum import Enum, auto
import argparse
import math


from src.utils import json_parser, list_parser, logger


# Translation table that deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)
# Whole-word English articles removed during normalization.
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")


def normalize(s: str) -> str:
    """Lower text and remove ``<pad>`` tokens, punctuation, articles and extra whitespace.

    ``<pad>`` tokens are removed *before* punctuation is stripped: once ``<``
    and ``>`` are deleted the token turns into the ordinary word ``pad`` and
    can no longer be distinguished from real text.

    Args:
        s: Raw text to normalize.

    Returns:
        The lower-cased text with ``<pad>`` tokens, ASCII punctuation and the
        articles a/an/the removed, and whitespace runs collapsed to single
        spaces (no leading/trailing space).
    """
    s = s.lower()
    s = s.replace("<pad>", " ")
    s = s.translate(_PUNCTUATION_TABLE)
    s = _ARTICLES_RE.sub(" ", s)
    return " ".join(s.split())


def match(s1: str, s2: str) -> bool:
    """Return True if ``s2`` occurs inside ``s1`` once both are normalized.

    Both arguments go through :func:`normalize` first. The comparison is plain
    substring containment, not token-aware ("cat" is found in "concatenate").
    A ``s2`` that normalizes to the empty string is found in anything.
    """
    haystack, needle = normalize(s1), normalize(s2)
    return needle in haystack

def eval_acc(prediction, answer):
    """Return the fraction of gold answers found in the joined prediction text.

    All predicted strings are joined with single spaces, and each gold answer
    is checked against that text with :func:`match` (normalized substring
    containment).

    Args:
        prediction: Predicted answer strings.
        answer: Gold answer strings.

    Returns:
        ``matched / len(answer)``. Returns 0.0 for an empty ``answer`` instead of
        raising ZeroDivisionError, which matches how :func:`eval_f1` treats
        empty inputs.
    """
    if len(answer) == 0:
        return 0.0
    prediction_str = " ".join(prediction)
    matched = sum(1 for a in answer if match(prediction_str, a))
    return matched / len(answer)

def eval_hit(prediction, answer):
    """Return 1 if any gold answer is found in the joined prediction text, else 0.

    The predicted strings are joined with single spaces, and each gold answer
    is tested with :func:`match`. The check stops at the first hit.
    """
    joined = " ".join(prediction)
    return int(any(match(joined, gold) for gold in answer))

def eval_f1(prediction, answer):
    """Return ``(f1, precision, recall)`` of predicted answers against gold answers.

    A prediction/answer pair matches when :func:`match` finds the normalized
    answer inside the normalized prediction.

    * precision = fraction of predictions that match at least one gold answer
    * recall    = fraction of gold answers matched by at least one prediction

    Precision counts matched *predictions*, not matched answers. Dividing the
    matched-answer count by ``len(prediction)`` can exceed 1 whenever a single
    prediction contains several gold answers.

    Args:
        prediction: Predicted answer strings.
        answer: Gold answer strings.

    Returns:
        ``(f1, precision, recall)``. Returns ``(0, 0, 0)`` when either list is
        empty.
    """
    if len(prediction) == 0 or len(answer) == 0:
        return 0, 0, 0

    # hits[i][j] is True when prediction i contains gold answer j.
    hits = [[match(p, a) for a in answer] for p in prediction]
    matched_predictions = sum(any(row) for row in hits)
    matched_answers = sum(any(column) for column in zip(*hits))

    precision = matched_predictions / len(prediction)
    recall = matched_answers / len(answer)
    if precision + recall == 0:
        return 0.0, precision, recall
    f1 = 2 * precision * recall / (precision + recall)
    return f1, precision, recall



def evaluate(predictions, verbose=False):
    """Average the answer-level metrics over a list of prediction records.

    Each record must provide ``'pred'`` (predicted answers, best first) and
    ``'answer'`` (gold answers). For every record this computes
    F1/precision/recall, hit@1 (first prediction only), hits (any prediction)
    and accuracy. It returns their means under the keys ``all_hit1``,
    ``all_hits``, ``all_acc``, ``all_f1``, ``all_precision`` and ``all_recall``.
    An empty ``predictions`` list yields NaN means. ``verbose`` is accepted but
    not used.

    If a metric function raises, the offending prediction and answer are
    printed and the exception is re-raised.
    """
    metric_names = ("f1", "p", "r", "hit1", "hits", "acc")
    collected = {name: [] for name in metric_names}

    for record in predictions:
        gold = record['answer']
        predicted = record['pred']

        try:
            f1, precision, recall = eval_f1(predicted, gold)
            row = (
                f1,
                precision,
                recall,
                eval_hit(predicted[:1], gold),
                eval_hit(predicted, gold),
                eval_acc(predicted, gold),
            )
        except Exception as err:
            print(f"Error: {err}")
            print(f"Prediction: {predicted}")
            print(f"Answer: {gold}")
            raise

        for name, value in zip(metric_names, row):
            collected[name].append(value)

    summary = (
        ("all_hit1", "hit1"),
        ("all_hits", "hits"),
        ("all_acc", "acc"),
        ("all_f1", "f1"),
        ("all_precision", "p"),
        ("all_recall", "r"),
    )
    return {out_key: np.mean(collected[src_key]) for out_key, src_key in summary}


def categorize_by_hops(prediction):
    """根据truth_paths计算hop数并分类"""
    mean_path_len = prediction.get("mean_path_len", 0)
    mean_path_len = math.ceil(mean_path_len)

    # 将hop数分类为1-hop, 2-hop, 3-hop, 4+hop
    match mean_path_len:
        case 1:
            return "1-hop"
        case 2:
            return "2-hop"
        case 3:
            return "3-hop"
        case _:
            return "4+hop"


def categorize_by_answer_count(prediction):
    """根据答案数量进行分类"""
    answer_count = len(prediction["answer"])

    # 将答案数量分类为#Ans=1, 2≤#Ans≤4, 5≤#Ans≤9, #Ans≥10
    match answer_count:
        case 1:
            return "#Ans=1"
        case 2 | 3 | 4:
            return "2≤#Ans≤4"
        case 5 | 6 | 7 | 8 | 9:
            return "5≤#Ans≤9"
        case _:
            return "#Ans≥10"


def evaluate_by_hops(predictions, verbose=False):
    """Evaluate predictions separately for each hop-count bucket.

    Records are grouped with :func:`categorize_by_hops`, and each non-empty
    group is scored with :func:`evaluate`. For a bucket label ``L`` the result
    holds ``L_count`` followed by ``L_<metric>`` for every metric ``evaluate``
    returns. Buckets appear in the order their first record was seen.
    ``verbose`` is accepted for API compatibility and ignored.
    """
    buckets = defaultdict(list)
    for record in predictions:
        buckets[categorize_by_hops(record)].append(record)

    hop_results = {}
    for label, members in buckets.items():
        if not members:
            continue
        hop_results[f"{label}_count"] = len(members)
        hop_results.update(
            (f"{label}_{metric}", value)
            for metric, value in evaluate(members, verbose=False).items()
        )

    return hop_results


def evaluate_by_answer_count(predictions, verbose=False):
    """Evaluate predictions separately for each gold-answer-count bucket.

    Records are grouped with :func:`categorize_by_answer_count`, and each
    non-empty group is scored with :func:`evaluate`. For a bucket label ``L``
    the result holds ``L_count`` followed by ``L_<metric>`` for every metric
    ``evaluate`` returns. Buckets appear in the order their first record was
    seen. ``verbose`` is accepted for API compatibility and ignored.
    """
    grouped = defaultdict(list)
    for record in predictions:
        grouped[categorize_by_answer_count(record)].append(record)

    answer_results = {}
    for label in grouped:
        members = grouped[label]
        if len(members) == 0:
            continue
        answer_results[f"{label}_count"] = len(members)
        group_scores = evaluate(members, verbose=False)
        answer_results.update({f"{label}_{name}": value for name, value in group_scores.items()})

    return answer_results


class ParseError(Enum):
    """Failure categories used when tallying unusable result records.

    Member names serve as the keys of the per-run error statistics. A record's
    ``error_type`` field is expected to hold one of these names.
    """

    NO_VALID_RESPONSE = 1
    NO_ANSWER_FOUND = 2
    NO_ANSWER_IN_GRAPH = 3
    NO_VALID_FINAL_RESPONSE = 4
    INVALID_RESULT_FORMAT = 5
    MISSING_FIELDS = 6
    OTHER = 7
    HUMAN_CHECK = 8


# Line numbers of the results file after which a partial evaluation snapshot
# is stored as ``results["part_<n>"]``.
_PARTIAL_EVAL_CHECKPOINTS = (100, 300, 1000)


def _percent(count, total):
    """Return ``count / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def _classify_parse_error(exc):
    """Map a response-parsing exception to a ParseError member name.

    Known failures carry a ParseError name in the exception message. Anything
    else is reported as ``OTHER``.
    """
    message = str(exc)
    for kind in (
        ParseError.INVALID_RESULT_FORMAT,
        ParseError.MISSING_FIELDS,
        ParseError.NO_VALID_FINAL_RESPONSE,
    ):
        if kind.name in message:
            return kind.name
    return ParseError.OTHER.name


def _parse_predicted_answers(res, final_prompt_type):
    """Parse ``res["final_response_raw"]`` into an ordered list of answers.

    Side effect: stores the parsed response in ``res["final_response"]``.

    For ``final_prompt_type == "rog_final"`` the raw text is parsed with
    ``list_parser`` and used as-is. Otherwise it is parsed with ``json_parser``
    and must contain ``possible_answers`` (a list) and ``most_possible_answer``
    (a string, or a list whose first item is used). The result puts the top
    answer first, drops ``"Unknown"`` placeholders, and removes duplicates
    while keeping first-occurrence order.

    Raises:
        ValueError: message starting with ``NO_VALID_FINAL_RESPONSE``,
            ``INVALID_RESULT_FORMAT`` or ``MISSING_FIELDS``. These are explicit
            raises rather than ``assert`` so the checks survive ``python -O``.
        Exception: anything raised by the parsers, or by indexing a malformed
            response (e.g. ``KeyError``), propagates unchanged.
    """
    raw = res.get("final_response_raw", "")

    if final_prompt_type == "rog_final":
        res["final_response"] = list_parser(raw)
        if res["final_response"] is None:
            raise ValueError(f"NO_VALID_FINAL_RESPONSE: {raw}")
        pred_answers = res["final_response"]
    else:
        res["final_response"] = json_parser(raw)
        if res["final_response"] is None:
            raise ValueError(f"NO_VALID_FINAL_RESPONSE: {raw}")

        possible_answers = res["final_response"]["possible_answers"]
        top_answer = res["final_response"]["most_possible_answer"]
        if isinstance(top_answer, list):
            top_answer = top_answer[0]
        if not isinstance(possible_answers, list):
            raise ValueError(f"INVALID_RESULT_FORMAT: {possible_answers!r}")

        candidates = [
            a for a in [top_answer, *possible_answers]
            if a != "Unknown" and a != ["Unknown"]
        ]
        for candidate in candidates:
            if not isinstance(candidate, str):
                raise ValueError(f"INVALID_RESULT_FORMAT: {candidate}, {candidates}")
        # dict.fromkeys de-duplicates deterministically; iteration order of a
        # set of str changes between runs because of hash randomization.
        pred_answers = list(dict.fromkeys(candidates))

    if len(pred_answers) == 0:
        raise ValueError(f"MISSING_FIELDS: {raw}")
    return pred_answers


def _fill_prediction(res, final_prompt_type, error_counts):
    """Set ``res["pred"]`` for one record and tally any error in ``error_counts``.

    A record that already carries a truthy ``error_type`` gets an empty
    prediction. Labels that are not ParseError names are counted as ``OTHER``
    instead of raising KeyError. Any exception while parsing the response is
    counted, not re-raised, and also yields an empty prediction.
    """
    error_type = res.get("error_type")
    if error_type:
        if error_type not in error_counts:
            error_type = ParseError.OTHER.name
        error_counts[error_type] += 1
        res["pred"] = []
        return

    try:
        res["pred"] = _parse_predicted_answers(res, final_prompt_type)
    except Exception as e:  # deliberate: one bad response must not abort the run
        error_counts[_classify_parse_error(e)] += 1
        res["pred"] = []


def _intermediate_scores(predictions):
    """Score the intermediate node sets of every record against its gold answers.

    Three stages are scored:
      * ``llm``: ``candidate_answers``
      * ``inter``: ``sub_nodes``
      * ``rerank``: the ``node`` field of each entry in ``path_top_precent``

    Returns a dict of per-record lists keyed ``<stage>_f1``,
    ``<stage>_precision``, ``<stage>_recall``, ``<stage>_hit`` and
    ``<stage>_nodes`` (node count).
    """
    stages = ("llm", "inter", "rerank")
    fields = ("f1", "precision", "recall", "hit", "nodes")
    scores = {f"{stage}_{field}": [] for stage in stages for field in fields}

    for res in tqdm(predictions, desc="计算中间分数"):
        answer = res["answer"]
        stage_nodes = {
            "llm": res.get("candidate_answers", []),
            "inter": res.get("sub_nodes", []),
            "rerank": [item["node"] for item in res.get("path_top_precent", [])],
        }
        for stage, nodes in stage_nodes.items():
            f1, p, r = eval_f1(nodes, answer)
            scores[f"{stage}_f1"].append(f1)
            scores[f"{stage}_precision"].append(p)
            scores[f"{stage}_recall"].append(r)
            scores[f"{stage}_hit"].append(eval_hit(nodes, answer))
            scores[f"{stage}_nodes"].append(len(nodes))

    return scores


def _print_group_report(results, categories):
    """Print hit@1 / hits / F1 figures for each category present in ``results``."""
    for cat in categories:
        if f"{cat}_count" not in results:
            continue
        n = results[f"{cat}_count"]
        hit1 = results.get(f"{cat}_all_hit1", 0)
        hits = results.get(f"{cat}_all_hits", 0)
        f1 = results.get(f"{cat}_all_f1", 0)
        precision = results.get(f"{cat}_all_precision", 0)
        recall = results.get(f"{cat}_all_recall", 0)

        print(f"\n{cat}: {n} 个样本")
        print(f"  hit1: {hit1*100:.2f}% ({hit1*n:.0f}/{n})")
        print(f"  hits: {hits*100:.2f}% ({hits*n:.0f}/{n})")
        print(f"  f1: {f1*100:.2f}%; precision: {precision*100:.2f}%; recall: {recall*100:.2f}%")


def _print_report(results, scores, errors):
    """Print the human-readable evaluation report."""
    count = len(scores["llm_hit"])
    print("=" * 80)
    print("整体评估结果:")
    print("=" * 80)
    for stage, label in (("llm", "llm_hit:    "), ("inter", "inter_hit:  "), ("rerank", "rerank_hit: ")):
        hit = results[f"{stage}_hit"]
        recall = results[f"{stage}_recall"]
        precision = results[f"{stage}_precision"]
        f1 = results[f"{stage}_f1"]
        n_hit = sum(scores[f"{stage}_hit"])
        mean_nodes = np.mean(scores[f"{stage}_nodes"])
        print(f"{label}{hit*100:.2f}% ({n_hit}/{count}) / recall: {recall*100:.2f}% / precision: {precision*100:.2f}% / f1: {f1*100:.2f}% ({mean_nodes:.0f}) ")
    print("-" * 80)
    print(f"all_hit1: {results['all_hit1']*100:.2f}% ({results['all_hit1']*count:.0f}/{count})")
    print(f"all_hits: {results['all_hits']*100:.2f}% ({results['all_hits']*count:.0f}/{count})")
    print(f"all_f1:   {results['all_f1']*100:.2f}%; precision: {results['all_precision']*100:.2f}%; recall: {results['all_recall']*100:.2f}%")

    print("\n" + "=" * 80)
    print("按Hop数分组的评估结果:")
    print("=" * 80)
    _print_group_report(results, ["1-hop", "2-hop", "3-hop", "4+hop"])

    print("\n" + "-" * 80)
    print(f"error_rate: {errors['error_rate']}")
    for error_type in ParseError:
        print(f"{error_type.name}: {errors[error_type.name]}")

    print("\n" + "=" * 80)
    print("按答案数量分组的评估结果:")
    print("=" * 80)
    _print_group_report(results, ["#Ans=1", "2≤#Ans≤4", "5≤#Ans≤9", "#Ans≥10"])

    print("=" * 80)


def eval_from_config(config_path, verbose=False, exclude_missing=False):
    """Evaluate the results file named by a JSON config and write the scores back.

    Steps:
      1. Load ``config_path`` (UTF-8) into ``src.config.ConfigBase``. The config
         must provide ``results_path`` and ``final_prompt_type``.
      2. Read ``results_path`` as JSON Lines, skipping blank lines. Each record
         needs ``answer``; ``pred`` is filled in from its parsed final
         response. Records with a pre-set ``error_type`` or an unparseable
         response get ``pred = []`` and are tallied per ParseError.
      3. Compute overall, per-hop and per-answer-count metrics, plus scores for
         the intermediate node sets (``llm_*``, ``inter_*``, ``rerank_*``).
      4. Optionally print a report, then overwrite ``config_path`` with the
         config plus ``results`` and ``errors``.

    Args:
        config_path: Path to the JSON config. The file is rewritten in place.
        verbose: Print a human-readable report when True.
        exclude_missing: Skip records whose ``answer_not_in_graph`` is truthy.
            Each skipped id is printed.

    Returns:
        The ``results`` dict that is also stored on the config. It contains:
          * the ``all_*`` metrics
          * ``count``: non-blank lines read, including excluded ones
          * ``error_count``
          * per-group ``<label>_*`` entries
          * the ``llm_*`` / ``inter_*`` / ``rerank_*`` means
          * ``part_<n>`` snapshots taken after lines 100, 300 and 1000
    """
    from src.config import ConfigBase
    with open(config_path, "r", encoding="utf-8") as f:
        conf = ConfigBase(**json.load(f))

    conf.exclude_missing = exclude_missing
    final_prompt_type = conf.final_prompt_type

    results = {}
    line_count = 0
    predictions = []
    error_counts = {error_type.name: 0 for error_type in ParseError}

    with open(conf.results_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            line_count += 1
            res = json.loads(line)

            if conf.exclude_missing and res.get("answer_not_in_graph"):
                print(f"exclude_missing: {res['id']}")
            else:
                _fill_prediction(res, final_prompt_type, error_counts)
                predictions.append(res)

            # Snapshot on every checkpoint line, including excluded or
            # pre-errored ones, as long as there is something to score.
            if line_count in _PARTIAL_EVAL_CHECKPOINTS and predictions:
                results[f"part_{line_count}"] = evaluate(predictions, verbose)

    # Overall and grouped metrics.
    results.update(evaluate(predictions, verbose))
    results["count"] = line_count
    results["error_count"] = sum(error_counts.values())
    results.update(evaluate_by_hops(predictions, verbose))
    results.update(evaluate_by_answer_count(predictions, verbose))

    # Error statistics; _percent avoids ZeroDivisionError on an empty file.
    error_count = results["error_count"]
    errors = {
        "error_rate": f"{error_count}/{line_count} ({_percent(error_count, line_count):.2f}%)"
    }
    for error_type in ParseError:
        count = error_counts[error_type.name]
        errors[error_type.name] = f"{count} ({_percent(count, line_count):.2f}%)"

    # Mean scores of the intermediate node sets.
    scores = _intermediate_scores(predictions)
    for key, values in scores.items():
        results[key] = np.mean(values)

    if verbose:
        _print_report(results, scores, errors)

    conf.results = results
    conf.errors = errors
    # Serialize before opening the file: opening with "w" truncates it, so a
    # failure during serialization would otherwise leave the config destroyed.
    payload = json.dumps(conf.to_dict(), ensure_ascii=False, indent=4)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(payload)

    return results


def check_error(config_path):
    """Print every result record that already carries an ``error_type``.

    The config at ``config_path`` is loaded into ``ConfigBase`` only to locate
    ``results_path``, a JSON Lines file. For each flagged record the id, the
    error type and the raw final response are printed. Both files are read as
    UTF-8, and blank lines in the results file are skipped.
    """
    from src.config import ConfigBase
    with open(config_path, "r", encoding="utf-8") as f:
        conf = ConfigBase(**json.load(f))

    with open(conf.results_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            res = json.loads(line)
            if res.get("error_type"):
                print(f"{res['id']}: {res['error_type']}\nFinal Response: {res.get('final_response_raw', '')}")


if __name__ == "__main__":
    # CLI: evaluate the results referenced by a config file (the config is
    # updated in place with the scores), optionally listing flagged records.
    parser = argparse.ArgumentParser(description="Evaluate results file")
    parser.add_argument("path", help="Path to the evaluation config JSON file")
    parser.add_argument(
        "--exclude-missing",
        action="store_true",
        help="Exclude samples flagged answer_not_in_graph",
    )
    parser.add_argument(
        "--check-error",
        action="store_true",
        help="After evaluation, print records that carry an error_type",
    )
    args = parser.parse_args()
    eval_from_config(args.path, verbose=True, exclude_missing=args.exclude_missing)
    if args.check_error:
        check_error(args.path)
