# Adapted from https://github.com/THUDM/LongBench
import os
import json
import argparse
import numpy as np
import re 

from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)


# Question-answering datasets; every one of them is scored with ``qa_f1_score``.
QA_TASKS = [
    "narrativeqa",
    "qasper",
    "multifieldqa_en",
    "hotpotqa",
    "2wikimqa",
    "musique",
]

# Dataset name -> metric callable used by ``scorer``.
# The QA entries are derived from QA_TASKS so the two tables cannot drift apart.
dataset2metric = {task: qa_f1_score for task in QA_TASKS}
dataset2metric.update({
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
})

def parse_args(args=None):
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv``.

    Returns:
        The parsed namespace with ``model_name`` and ``directory`` attributes.
    """
    cli = argparse.ArgumentParser()
    string_options = (
        ('--model_name', 'llama3_8b'),
        ('--directory', 'lb1_table'),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args(args)


def _best_score(metric, candidates, ground_truths, all_classes, stop_at_perfect):
    """Return the highest metric value over all (ground truth, candidate) pairs."""
    best = 0.
    for ground_truth in ground_truths:
        for candidate in candidates:
            best = max(best, metric(candidate, ground_truth, all_classes=all_classes))
            if stop_at_perfect and best == 1.0:
                return best
    return best


def scorer(dataset, predictions, answers, all_classes, best_at_k=False):
    """Compute the average score (0-100) of ``predictions`` on one dataset.

    Each prediction is compared only against the ground truths of its own
    sample (``predictions[i]`` vs. ``answers[i]``); the sample score is the
    maximum metric value over those comparisons.

    Args:
        dataset: Dataset name. If the full name is not a key of
            ``dataset2metric``, everything from the first ``'-'`` on is
            dropped before the lookup (so ``"repobench-p"`` is kept intact).
        predictions: One entry per sample. Either a string, or a list of
            candidate strings when ``best_at_k`` is true.
        answers: One list of ground-truth strings per sample, aligned with
            ``predictions``.
        all_classes: Forwarded unchanged to the metric as ``all_classes``.
        best_at_k: When true and a prediction is a list, the best-scoring
            candidate in that list is used for the sample.

    Returns:
        ``0`` for empty ``predictions``; otherwise the mean sample score
        times 100, rounded to two decimals.

    Raises:
        ValueError: If ``predictions`` and ``answers`` differ in length.
        KeyError: If no metric is registered for ``dataset``.
    """
    if len(predictions) == 0:
        return 0
    if len(predictions) != len(answers):
        raise ValueError(
            f"predictions ({len(predictions)}) and answers ({len(answers)}) "
            "must have the same length"
        )

    # Only strip a '-suffix' when the full name is not itself a registered
    # dataset; otherwise "repobench-p" would be mangled into "repobench".
    if dataset not in dataset2metric and '-' in dataset:
        dataset = dataset.split('-')[0]
    metric = dataset2metric[dataset]

    # Early exit is a pure shortcut for QA tasks: once a sample reaches 1.0
    # the remaining comparisons are skipped (qa_f1_score is presumably
    # bounded by 1.0 -- confirm in metrics).
    stop_at_perfect = dataset in QA_TASKS
    first_line_only = dataset in ["trec", "triviaqa", "samsum", "lsht"]

    total_score = 0.
    for prediction, ground_truths in zip(predictions, answers):
        if best_at_k and isinstance(prediction, list):
            candidates = prediction
        else:
            if first_line_only:
                # Keep only the first non-empty leading line of the output.
                prediction = prediction.lstrip('\n').split('\n')[0]
            candidates = [prediction]
        total_score += _best_score(
            metric, candidates, ground_truths, all_classes, stop_at_perfect
        )

    return round(100 * total_score / len(predictions), 2)

def _extract_predictions(raw_pred, use_tags):
    """Return the candidate answers contained in one raw model output.

    With ``use_tags`` set, every ``<answer>...</answer>`` span is returned
    (stripped). If no span is found, or tags are not used, the whole stripped
    output is the single candidate.
    """
    if use_tags:
        # DOTALL so an answer spanning several lines is still captured.
        tagged = re.findall(r'<answer>(.*?)</answer>', raw_pred, flags=re.DOTALL)
        if tagged:
            return [answer.strip() for answer in tagged]
    return [raw_pred.strip()]


def _load_predictions(file_path, use_tags):
    """Read one ``.jsonl`` result file into aligned prediction/answer lists.

    Returns:
        ``(predictions, answers, all_classes)`` where ``predictions[i]`` is a
        list of candidate strings for the sample whose ground truths are
        ``answers[i]``, and ``all_classes`` is the value from the last line
        that carried that key (``None`` if no line did).
    """
    predictions, answers = [], []
    all_classes = None
    filename = os.path.basename(file_path)
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Skip the malformed line but keep evaluating the rest.
                print(f"DECODING ERROR {file_path}")
                continue

            try:
                # Read every required field before appending anything, so a
                # bad record can never leave the two lists misaligned.
                pred_list = _extract_predictions(data["pred"], use_tags)
                ground_truths = data["answers"]
            except (KeyError, TypeError, AttributeError):
                print(f"Key error occur for {filename}")
                print(data)
                continue

            predictions.append(pred_list)
            answers.append(ground_truths)
            if "all_classes" in data:
                all_classes = data["all_classes"]
    return predictions, answers, all_classes


def _mean_or_nan(values):
    """Return the mean of ``values``, or NaN (without a warning) when empty."""
    return float(np.mean(values)) if values else float('nan')


def main():
    """Score every ``.jsonl`` file under ``<directory>/<model_name>/<sub_model>/``.

    For each sub-model directory the per-dataset scores are printed, a
    ``' & '``-separated row of per-dataset means plus the overall mean is
    printed, and the scores are written to ``result.json`` in that directory.
    """
    args = parse_args()

    # The path contains files for each dataset and seed.
    path = os.path.join(args.directory, args.model_name)
    # NOTE(review): 40 is presumably the expected number of samples per
    # file; files with a different count are reported -- confirm.
    expected_samples = 40

    print('=' * 100)
    for sub_model in sorted(os.listdir(path)):
        sub_dir = os.path.join(path, sub_model)
        if not os.path.isdir(sub_dir):
            continue
        print('=' * 100)
        print("Evaluating on:", sub_model)

        scores = dict()
        for filename in sorted(os.listdir(sub_dir)):
            if not filename.endswith("jsonl"):
                continue
            dataset = filename.split('.')[0]
            print('EVAL on ', dataset)

            predictions, answers, all_classes = _load_predictions(
                os.path.join(sub_dir, filename),
                use_tags='no_tag' not in filename,
            )

            try:
                score = scorer(dataset, predictions, answers, all_classes, best_at_k=True)
            except Exception as err:
                # Best effort: report the failure and move on to the next
                # file instead of silently dropping the dataset.
                print(f"SCORING ERROR for {filename}: {err!r}")
                continue

            scores[dataset] = score
            if len(predictions) != expected_samples:
                print(f"filename: {filename}, length: {len(predictions)}, score: {score}")

        # Group scores by base dataset name (files may carry a '-suffix');
        # QA datasets come first, in a fixed order, even when absent.
        gathered_dict = {d: [] for d in QA_TASKS}
        for name, value in scores.items():
            key = name if name in dataset2metric else name.split('-')[0]
            gathered_dict.setdefault(key, []).append(value)

        print(gathered_dict)

        columns = ['{:.2f}'.format(_mean_or_nan(sc)) for sc in gathered_dict.values()]
        overall = _mean_or_nan(list(scores.values()))
        print(' & '.join(columns), ' & {:.2f}'.format(overall))

        with open(os.path.join(sub_dir, 'result.json'), "w", encoding="utf-8") as f:
            json.dump(scores, f, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    main()



