import os
from check_hidden_state import (
    get_args,
    get_hidden_state
)
from utils.utils import (
    read_json,
    write_jsonl
)
from llms import BgeRerankCls, QwenRerankCls, GteRerankCls


def calculate_metrics(pos_by_llm_list, pos_list):
    """Compute precision, recall, F1 and reciprocal rank for a single sample.

    :param pos_by_llm_list: items predicted as positive by the model; their
        order is used as the ranking when computing the reciprocal rank
    :param pos_list: ground-truth positive items
    :return: ``(precision, recall, f1, rr)`` tuple; all zeros when the model
        produced no predictions
    """
    # Build the ground-truth set once; it serves both the overlap count and
    # the membership tests of the reciprocal-rank scan below.
    truth = set(pos_list)
    true_positives = len(set(pos_by_llm_list) & truth)
    if not pos_by_llm_list:
        return 0, 0, 0, 0

    precision = true_positives / len(pos_by_llm_list)
    # An empty ground-truth list would otherwise raise ZeroDivisionError.
    recall = true_positives / len(pos_list) if pos_list else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Reciprocal of the 1-based rank of the first correct prediction, 0 if none.
    rr = next(
        (1 / rank for rank, item in enumerate(pos_by_llm_list, 1) if item in truth),
        0,
    )

    return precision, recall, f1, rr


def eval_rerank(rerank_instance, eval_dataset_path, output_dir, top_k=1):
    """Evaluate a reranker on a dataset and write the averaged metrics to disk.

    The dataset is loaded with ``read_json``, passed to
    ``rerank_instance.generate(eval_dataset, top_k=top_k)``, and each result is
    scored against the ``"pos"`` field of the dataset entry at the same
    position. Precision, recall, F1 and reciprocal rank are averaged over the
    dataset, printed, and written to ``eval_result_top_{top_k}.jsonl``; the raw
    generation results go to ``generate_result_top_{top_k}.jsonl``.

    :param rerank_instance: object exposing a ``generate(dataset, top_k=...)`` method
    :param eval_dataset_path: path of the evaluation dataset file
    :param output_dir: directory for the two output files (created if missing)
    :param top_k: value forwarded to ``generate`` and used in the output file names
    :raises TypeError: if ``rerank_instance`` has no ``generate`` attribute
    :raises ValueError: if the dataset is empty, or ``generate`` returns a
        different number of results than there are dataset entries
    """
    eval_dataset = read_json(eval_dataset_path)
    dataset_len = len(eval_dataset)

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if not hasattr(rerank_instance, "generate"):
        raise TypeError(
            f"{type(rerank_instance).__name__} does not provide a 'generate' method"
        )
    # Averaging over zero samples is undefined; fail before running the model.
    if dataset_len == 0:
        raise ValueError(f"Evaluation dataset is empty: {eval_dataset_path}")

    # Ensure the results of the model run can actually be written afterwards.
    os.makedirs(output_dir, exist_ok=True)

    generate_result = rerank_instance.generate(eval_dataset, top_k=top_k)
    print(f"len generate result: {len(generate_result)}")
    print(f"len eval dataset: {len(eval_dataset)}")
    if len(generate_result) != dataset_len:
        raise ValueError(
            f"generate returned {len(generate_result)} results "
            f"for {dataset_len} dataset entries"
        )

    eval_result = [
        calculate_metrics(predicted, sample["pos"])
        for predicted, sample in zip(generate_result, eval_dataset)
    ]

    # Column-wise sums of the per-sample (precision, recall, f1, rr) tuples.
    eval_scores = tuple(map(sum, zip(*eval_result)))
    print(eval_scores)
    precision, recall, f1, rr = (score / dataset_len for score in eval_scores)
    output = f"The precision is {precision}, the recall is {recall}, the f1 is {f1}, the rr is: {rr}. Dataset len is {dataset_len}"
    print(output)

    # NOTE(review): ``output`` is a plain string; confirm ``write_jsonl``
    # accepts a string rather than a list of records.
    write_jsonl(output, os.path.join(output_dir, f"eval_result_top_{top_k}.jsonl"))
    write_jsonl(generate_result, os.path.join(output_dir, f"generate_result_top_{top_k}.jsonl"))


if __name__ == '__main__':
    # Command-line entry point: build the requested reranker and evaluate it.
    # argparse is imported here because the module-level imports do not
    # provide it; without this the script fails with NameError on startup.
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--llm_weight', type=str, required=True)
    arg_parser.add_argument('--eval_data_path', type=str, required=True)
    arg_parser.add_argument('--output_data_dir', type=str, required=True)
    arg_parser.add_argument('--reranker_cls', type=str, required=True)
    arg_parser.add_argument('--top_k', type=int, required=True)
    args = arg_parser.parse_args()

    # Maps the --reranker_cls value to the class that implements it.
    reranker_classes = {
        'BgeRerankCls': BgeRerankCls,
        'QwenRerankCls': QwenRerankCls,
        'GteRerankCls': GteRerankCls,
    }
    if args.reranker_cls not in reranker_classes:
        raise ValueError(f"Unknown reranker_cls cls: {args.reranker_cls}")
    # Each reranker is constructed from the model weight path alone.
    reranker_cls = reranker_classes[args.reranker_cls](args.llm_weight)

    eval_rerank(
        reranker_cls,
        args.eval_data_path,
        args.output_data_dir,
        top_k=args.top_k
    )
