import os
import json
from utils.utils import (
    write_jsonl,
    write_json,
    read_json,
    has_answer,
    has_answer_by_llm
)
from llms import llm_client_another, llm_client, BgeRerankCls, QwenRerankCls, GteRerankCls
import argparse


def generate_query_by_rerank_nq(input_data_path, output_data_path, reranker, top_k=1):
    """Rerank each NQ question's candidate contexts and save the results.

    Every record loaded from ``input_data_path`` contributes its
    ``question``, ``context`` and ``reference`` fields. The contexts are
    passed to ``reranker.generate`` as the ``pos`` candidates (with an empty
    ``neg`` list). One output record per question is written to
    ``output_data_path`` with ``write_jsonl``. Each record holds the
    ``query``, the reranker's result as ``context``, and the untouched
    ``reference``.

    Args:
        input_data_path: Path loaded with ``read_json``.
        output_data_path: Destination passed to ``write_jsonl``.
        reranker: Object exposing ``generate(samples, top_k=...)``; it must
            return exactly one result per sample (checked below).
        top_k: Forwarded unchanged to ``reranker.generate``.
    """
    print(f"output_data_path: {output_data_path}")
    print(f"reranker: {reranker}")
    records = read_json(input_data_path)

    samples = [
        {
            "query": record["question"],
            "pos": record["context"],
            "neg": [],
            "reference": record["reference"],
        }
        for record in records
    ]

    ranked = reranker.generate(samples, top_k=top_k)
    # The reranker must return one result per input record, in input order.
    assert len(ranked) == len(records)

    results = [
        {
            "query": sample["query"],
            "context": ranked_context,
            "reference": sample["reference"],
        }
        for sample, ranked_context in zip(samples, ranked)
    ]

    write_jsonl(results, output_data_path)


def generate_query_by_rerank_hotpot(input_data_path, output_data_path, reranker, top_k=1):
    """Rerank each HotpotQA question's context paragraphs and save the results.

    ``input_data_path`` is parsed as a single JSON document: a list of
    records with ``_id``, ``question`` and ``context`` fields. For every
    context entry, the second element (presumably the paragraph's sentence
    list; confirm against the dataset) is joined with spaces into one
    passage. The passages are handed to ``reranker.generate`` as ``pos``
    candidates. One record per question is written to ``output_data_path``
    with ``write_jsonl``. Each record holds ``qid``, ``query`` and the
    reranker's result as ``context``.

    Args:
        input_data_path: Path of the raw HotpotQA JSON file.
        output_data_path: Destination passed to ``write_jsonl``.
        reranker: Object exposing ``generate(samples, top_k=...)``; it must
            return exactly one result per sample (checked below).
        top_k: Forwarded unchanged to ``reranker.generate``.
    """
    print(f"output_data_path: {output_data_path}")
    print(f"reranker: {reranker}")

    # A context manager closes the handle promptly. An explicit UTF-8
    # encoding avoids depending on the platform's locale default.
    with open(input_data_path, encoding="utf-8") as f:
        input_data = json.load(f)

    rerank_meta_data = []
    for item in input_data:
        rerank_meta_data.append({
            "qid": item["_id"],
            "query": item["question"],
            "pos": [" ".join(entry[1]) for entry in item["context"]],
            "neg": []
        })

    rerank_result = reranker.generate(rerank_meta_data, top_k=top_k)
    # The reranker must return one result per input record, in input order.
    assert len(rerank_result) == len(input_data)

    output_data = [
        {
            "qid": meta["qid"],
            "query": meta["query"],
            "context": ranked_context,
        }
        for meta, ranked_context in zip(rerank_meta_data, rerank_result)
    ]

    write_jsonl(output_data, output_data_path)


def generate_result_by_llm_hotpot(input_data_path, output_data_path, llm_ins):
    """Ask the LLM to answer every reranked HotpotQA query and save answers by qid.

    Each record loaded from ``input_data_path`` supplies ``qid``, ``query``
    and ``context``. Both are substituted into a RAG prompt, which is sent
    as a single user message to ``llm_ins.chat``. The collected responses
    are written with ``write_json`` as ``{"answer": {qid: response, ...}}``.

    Args:
        input_data_path: Path loaded with ``read_json`` (the reranked records).
        output_data_path: Destination passed to ``write_json``.
        llm_ins: Client exposing ``chat(messages)`` that returns the reply.
    """
    template = "You are a rigorous language model. Please answer the question based on the provided context. If the context does not support reasoning about the answer, please answer the question based on your own knowledge. \n Context: {context} \nQuestion: {question}"
    records = read_json(input_data_path)
    total = len(records)
    answers = {}

    for position, record in enumerate(records):
        print(f"index: {position}/{total}")
        qid = record["qid"]
        request = template.format(question=record["query"], context=record["context"])
        answers[qid] = llm_ins.chat([{"role": "user", "content": request}])

    write_json({"answer": answers}, output_data_path)


def generate_result_by_llm_nq(input_data_path, output_data_path, llm_ins):
    """Ask the LLM to answer every reranked NQ query and save one record per query.

    Each record loaded from ``input_data_path`` supplies ``query``,
    ``context`` and ``reference``. The query and context fill a RAG prompt
    that is sent as a single user message to ``llm_ins.chat``. The output
    list, written with ``write_json``, holds ``question``, ``answer`` (the
    LLM reply), ``context`` and ``reference`` for each record.

    Args:
        input_data_path: Path loaded with ``read_json`` (the reranked records).
        output_data_path: Destination passed to ``write_json``.
        llm_ins: Client exposing ``chat(messages)`` that returns the reply.
    """
    template = "You are a rigorous language model. Please answer the question based on the provided context. If the context does not support reasoning about the answer, please answer the question based on your own knowledge. \n Context: {context} \nQuestion: {question}"
    records = read_json(input_data_path)
    total = len(records)
    results = []

    for position, record in enumerate(records):
        print(f"index: {position}/{total}")
        query_text = record["query"]
        passages = record["context"]
        message = {"role": "user", "content": template.format(context=passages, question=query_text)}
        reply = llm_ins.chat([message])
        results.append({
            "question": query_text,
            "answer": reply,
            "context": passages,
            "reference": record["reference"],
        })

    write_json(results, output_data_path)


def generate_result_by_llm_nq_dynamic(hidden_prob_data_path, input_data_path, output_data_path, llm_ins, threshold):
    """Answer NQ queries, choosing per query between a closed-book and a RAG prompt.

    ``hidden_prob_data_path`` is loaded with ``read_json``. It is indexed
    in step with the records of ``input_data_path``, and element ``[0]`` of
    each entry is unpacked as a ``(neg, pos)`` pair. When ``float(pos)``
    is at least ``threshold``, the question alone is sent (tag ``"meta"``).
    Otherwise the question plus its context is sent (tag ``"rag"``). One
    record per query is written with ``write_json``, holding ``tag``,
    ``question``, ``answer``, ``context`` and ``reference``.

    Args:
        hidden_prob_data_path: Path of the per-query probability data; it
            must have at least as many entries as the input records.
        input_data_path: Path loaded with ``read_json``; records need
            ``query``, ``context`` and ``reference``.
        output_data_path: Destination passed to ``write_json``.
        llm_ins: Client exposing ``chat(messages)`` that returns the reply.
        threshold: Cut-off compared against the positive probability.
    """
    rag_template = "You are a rigorous language model. Please answer the question based on the provided context. If the context does not support reasoning about the answer, please answer the question based on your own knowledge. \n Context: {context} \nQuestion: {question}"
    closed_book_template = "You need to read the question carefully and answer it based on your own knowledge.\nQuestion: {question}"
    probabilities = read_json(hidden_prob_data_path)
    records = read_json(input_data_path)
    total = len(records)
    results = []

    for position, record in enumerate(records):
        print(f"index: {position}/{total}")
        _, pos_prob = probabilities[position][0]
        query_text = record["query"]
        passages = record["context"]

        if float(pos_prob) >= threshold:
            tag, request = "meta", closed_book_template.format(question=query_text)
        else:
            tag, request = "rag", rag_template.format(context=passages, question=query_text)

        reply = llm_ins.chat([{"role": "user", "content": request}])
        results.append({
            "tag": tag,
            "question": query_text,
            "answer": reply,
            "context": passages,
            "reference": record["reference"],
        })

    write_json(results, output_data_path)


def eval_result_for_llm_hotpot(prediction_file, gold_file):
    """Score HotpotQA predictions against the gold file and print the precision.

    ``prediction_file`` is a JSON object whose ``"answer"`` entry maps a
    question id to the predicted text. ``gold_file`` is a JSON list of
    records with ``_id`` and ``answer``. Gold ids with no prediction are
    reported and excluded from the denominator. For the rest, the result
    of ``has_answer([gold_answer], prediction)`` (from utils; presumably a
    0/1 or boolean hit) is summed.

    Args:
        prediction_file: Path of the predictions JSON.
        gold_file: Path of the gold HotpotQA JSON.

    Returns:
        The precision over answered questions, or 0.0 when no gold id has
        a prediction (previously this raised ``ZeroDivisionError``).
    """
    with open(prediction_file, encoding="utf-8") as f:
        prediction = json.load(f)
    with open(gold_file, encoding="utf-8") as f:
        gold = json.load(f)

    answers = prediction['answer']
    hits = 0
    answered = 0
    for dp in gold:
        cur_id = dp['_id']
        if cur_id not in answers:
            print('missing answer {}'.format(cur_id))
            continue
        hits += has_answer([dp["answer"]], answers[cur_id])
        answered += 1

    # Guard the empty denominator: an empty gold file, or one where every
    # answer is missing, must not crash the evaluation.
    if answered == 0:
        print("No gold question has a prediction; reporting precision 0.0")
        precision = 0.0
    else:
        precision = hits / answered
    print(f"Prediction file is : {prediction_file}")
    print(f"Precision: {precision}")
    return precision


def eval_result_for_llm_nq(prediction_file):
    """Score NQ predictions and print the precision.

    ``prediction_file`` is a JSON list of records with ``reference`` and
    ``answer``. The result of ``has_answer(reference, answer)`` (from
    utils; presumably a 0/1 or boolean hit) is averaged over all records.

    Args:
        prediction_file: Path of the predictions JSON.

    Returns:
        The precision, or 0.0 for an empty prediction list (previously
        this raised ``ZeroDivisionError``).
    """
    with open(prediction_file, encoding="utf-8") as f:
        prediction = json.load(f)

    hits = sum(has_answer(dp["reference"], dp["answer"]) for dp in prediction)

    # Guard the empty denominator so an empty prediction file does not crash.
    if not prediction:
        print("Prediction file is empty; reporting precision 0.0")
        precision = 0.0
    else:
        precision = hits / len(prediction)
    print(f"Prediction file is : {prediction_file}")
    print(f"Precision: {precision}")
    return precision


if __name__ == '__main__':
    # CLI name -> reranker class. This also restricts --reranker_cls.
    reranker_classes = {
        'BgeRerankCls': BgeRerankCls,
        'QwenRerankCls': QwenRerankCls,
        'GteRerankCls': GteRerankCls,
    }

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--llm_weight', type=str, required=True)
    arg_parser.add_argument('--input_data_path', type=str, required=True)
    arg_parser.add_argument('--output_data_path', type=str, required=True)
    arg_parser.add_argument('--eval_data_path', type=str, required=True)
    arg_parser.add_argument('--reranker_cls', type=str, required=True,
                            choices=list(reranker_classes))
    arg_parser.add_argument('--top_k', type=int, required=True)
    arg_parser.add_argument('--data_type', type=str, required=True,
                            choices=['nq', 'hotpot_qa'])
    arg_parser.add_argument('--gold_data_path', type=str, required=False)
    arg_parser.add_argument('--dynamic', type=int, required=False)
    arg_parser.add_argument('--threshold', type=float, required=False)
    arg_parser.add_argument('--hidden_prob_data_path', type=str, required=False)
    args = arg_parser.parse_args()

    # Reject incomplete argument combinations before any model is loaded.
    # Otherwise they would only fail after the expensive reranking step.
    if (args.data_type == "nq" and args.dynamic == 1
            and (args.threshold is None or args.hidden_prob_data_path is None)):
        arg_parser.error("--dynamic 1 requires --threshold and --hidden_prob_data_path")
    if args.data_type == "hotpot_qa" and args.gold_data_path is None:
        arg_parser.error("--data_type hotpot_qa requires --gold_data_path")

    print(f"LLM weight: {args.llm_weight}")
    print(f"Input data path: {args.input_data_path}")
    print(f"Output data path: {args.output_data_path}")
    print(f"Eval data path: {args.eval_data_path}")

    print(f"Will generate query {args.llm_weight}")
    reranker_cls = reranker_classes[args.reranker_cls](args.llm_weight)

    if args.data_type == "nq":
        # Stage 1: rerank raw input -> output_data_path.
        generate_query_by_rerank_nq(
            args.input_data_path,
            args.output_data_path,
            reranker_cls,
            top_k=args.top_k
        )
        print(f"Finish generate query {args.llm_weight}")

        # Stage 2: answer the reranked records -> eval_data_path.
        # Both branches must read output_data_path (records carrying the
        # "query" key the generators expect) and write eval_data_path,
        # because that is the file the evaluation stage reads.
        print(f"Will process {args.llm_weight}")
        if args.dynamic == 1:
            generate_result_by_llm_nq_dynamic(
                args.hidden_prob_data_path,
                args.output_data_path,
                args.eval_data_path,
                llm_client,
                args.threshold
            )
        else:
            generate_result_by_llm_nq(
                args.output_data_path,
                args.eval_data_path,
                llm_client)
        print(f"Success finish {args.llm_weight}")

        # Stage 3: score eval_data_path.
        print(f"Will eval {args.llm_weight}")
        eval_result_for_llm_nq(args.eval_data_path)
        print(f"Finish eval {args.llm_weight}")

    elif args.data_type == "hotpot_qa":
        generate_query_by_rerank_hotpot(
            args.input_data_path,
            args.output_data_path,
            reranker_cls,
            top_k=args.top_k
        )
        print(f"Finish generate query {args.llm_weight}")

        print(f"Will process {args.llm_weight}")
        generate_result_by_llm_hotpot(
            args.output_data_path,
            args.eval_data_path,
            llm_client)
        print(f"Success finish {args.llm_weight}")

        print(f"Will eval {args.llm_weight}")
        eval_result_for_llm_hotpot(args.eval_data_path, args.gold_data_path)
        print(f"Finish eval {args.llm_weight}")

