import os
import json
import argparse
import numpy as np
import openai
from datasets import load_dataset
from alpaca_farm.auto_annotations import alpaca_leaderboard
import datasets
from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)
# Tell the openai client which file holds the API key. The path is relative,
# so it is resolved against the current working directory at run time.
openai.api_key_path = "data/openai_api_key.txt"

# Maps a dataset name to the metric function used to score it. `scorer` looks
# the metric up by the dataset name, which the main script derives from the
# prediction file name (the part before the first "."), and calls it as
# `metric(prediction, ground_truth, all_classes=...)`.
# NOTE(review): the "konwledge_*" keys look misspelled, but they are runtime
# keys that presumably must match the prediction file names -- confirm before
# renaming.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
    "konwledge_memorization": qa_f1_score,
    "konwledge_understanding": qa_f1_score,
    "longform_qa": rouge_score,
    "finance_qa": rouge_score,
}

def parse_args(args=None):
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: namespace with a single attribute ``input_dir``,
        the directory that holds the ``*.jsonl`` prediction files.
    """
    parser = argparse.ArgumentParser(description="Evaluate texts generated by every method")

    parser.add_argument(
        "--input_dir",
        type=str,
        default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_no_g0.5_d5.0",
        help="directory containing the *.jsonl prediction files to evaluate",
    )
    # Forward the caller-supplied list. It used to be ignored, which made the
    # `args` parameter useless and the function impossible to drive from code.
    return parser.parse_args(args)

# def scorer_e(dataset, predictions, answers, lengths, all_classes):
#     scores = {"0-4k": [], "4-8k": [], "8k+": []}
#     for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
#         score = 0.
#         if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
#             prediction = prediction.lstrip('\n').split('\n')[0]
#         for ground_truth in ground_truths:
#             score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
#         if length < 4000:
#             scores["0-4k"].append(score)
#         elif length < 8000:
#             scores["4-8k"].append(score)
#         else:
#             scores["8k+"].append(score)
#     for key in scores.keys():
#         scores[key] = round(100 * np.mean(scores[key]), 2)
#     return scores

def scorer(dataset, predictions, answers, all_classes):
    """Return the average best-match score of ``predictions``, scaled to 0-100.

    Each prediction is compared against every one of its ground truths with
    the metric registered for ``dataset`` in ``dataset2metric``; the highest
    value (never below 0) is that prediction's score.

    Args:
        dataset: Dataset name, used as the key into ``dataset2metric``.
        predictions: Sequence of predicted strings.
        answers: Sequence parallel to ``predictions``; each item is an
            iterable of ground-truth answers for that prediction.
        all_classes: Passed through unchanged to the metric as the
            ``all_classes`` keyword argument.

    Returns:
        float: ``100 * sum(per-prediction scores) / len(predictions)`` rounded
        to two decimals, or ``0.0`` when ``predictions`` is empty.
    """
    # An empty prediction list used to raise ZeroDivisionError below.
    if not predictions:
        return 0.0

    # For these datasets only the first non-empty-prefixed line of the
    # prediction is scored; the test is loop-invariant, so compute it once.
    first_line_only = dataset in ("trec", "triviaqa", "samsum", "lsht")

    total_score = 0.
    for prediction, ground_truths in zip(predictions, answers):
        if first_line_only:
            prediction = prediction.lstrip('\n').split('\n')[0]
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)

def alpacafarm_score(prompts, predictions, model_name):
    """Score ``predictions`` with the AlpacaFarm auto-annotation leaderboard.

    The "eval" split of ``tatsu-lab/alpaca_farm`` / ``alpaca_farm_evaluation``
    is loaded and each example is paired, by position, with the prediction at
    the same index. The resulting records have the shape
    ``{"instruction": ..., "input": ..., "generator": ..., "output": ...}``
    and are handed to ``alpaca_leaderboard``.

    Args:
        prompts: Not used by this function.
        predictions: Indexable collection of model outputs; must have an
            entry for every example of the eval split.
        model_name: Used both as each record's ``generator`` and as the
            leaderboard entry name.

    Returns:
        str: the leaderboard result rendered with ``to_string`` using a
        ``"%.2f"`` float format.
    """
    eval_split = load_dataset("tatsu-lab/alpaca_farm", "alpaca_farm_evaluation")["eval"]
    my_outputs = [
        {
            "instruction": example["instruction"],
            "input": example["input"],
            "generator": model_name,
            "output": predictions[idx],
        }
        for idx, example in enumerate(eval_split)
    ]
    print("my_outputs[0] is:", my_outputs[0])

    leaderboard = alpaca_leaderboard(
        path_or_all_outputs=my_outputs,
        name=model_name,
        is_add_reference_methods=False,
        annotators_config="greedy_gpt4/configs.yaml",
    )
    return leaderboard.to_string(float_format="%.2f")


if __name__ == '__main__':
    # Script entry point: score every "<dataset>.jsonl" prediction file found
    # in --input_dir and write all scores to "<input_dir>/eval/result.json".
    args = parse_args()
    scores = dict()
    # get all files from input_dir
    files = os.listdir(args.input_dir)
    # normpath drops a trailing separator so "pred/model/" still yields
    # "model" rather than an empty model name.
    model_name = os.path.basename(os.path.normpath(args.input_dir))
    # Keep only prediction files; sorted so runs are deterministic regardless
    # of the order os.listdir happens to return.
    json_files = sorted(f for f in files if f.endswith(".jsonl"))
    save_dir = os.path.join(args.input_dir, "eval")
    os.makedirs(save_dir, exist_ok=True)
    print("Evaluating on:", files)
    for json_file in json_files:
        print(f"{json_file} has began.........")
        # The dataset name is the file name up to the first "."
        dataset = json_file.split(".")[0]
        # Explicit UTF-8: the files may contain non-ASCII text and the
        # platform default encoding is locale-dependent.
        with open(os.path.join(args.input_dir, json_file), "r", encoding="utf-8") as f:
            # Decode each line once; blank lines are not valid JSON, skip them.
            records = [json.loads(line) for line in f if line.strip()]
        if not records:
            # Nothing to score; previously this crashed on lines[0].
            print(f"{json_file} has no records, skipped.")
            continue
        prompts = [record["prompt"] for record in records]
        predictions = [record["pred"] for record in records]
        answers = [record["answers"] for record in records]
        # all_classes is taken from the first record only.
        all_classes = records[0]["all_classes"]
        print(f"predictions[0] is: {predictions[0]}")
        if dataset == "alpacafarm":
            score = alpacafarm_score(prompts, predictions, model_name)
        else:
            score = scorer(dataset, predictions, answers, all_classes)
        scores[dataset] = score
    # save
    out_path = os.path.join(save_dir, "result.json")
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)
