# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval
import os
import argparse
import pandas as pd
import torch
import json
from llama2_evaluator import Llama_Evaluator
from glob import glob
import time
from collections import defaultdict
from categories import name_en2zh, subcategories, categories
choices = ["A", "B", "C", "D"]

category2subject = defaultdict(list)
for k,v in categories.items():
    for subject, subcat in subcategories.items():
        for c in subcat:
            if c in v:
                category2subject[k].append(subject)
category2subject_list = defaultdict(list)
for key,value in category2subject.items():
      for val in value:
          category2subject_list[val]=[val,name_en2zh[val],key]
category2subject=category2subject_list
choices = ["A", "B", "C", "D"]

def main(args, evaluator,take):

    subject_mapping = category2subject #json.load(f)
    filenames = [s.split('/')[-1] for s in glob(args.input_dir+"/test/*csv")]
    subject_list = [val_file.replace(".csv","") for val_file in filenames]
    accuracy, summary = {}, {}

    run_date=time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time()))
    output_dir = args.output_dir
    save_result_dir=os.path.join(output_dir,f"take{take}")
    if not os.path.exists(save_result_dir):
        os.makedirs(save_result_dir,exist_ok=True)

    all_answers = {}
    for index,subject_name in enumerate(subject_list):
        print(f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_path} with subject of {subject_name}!")
        val_file_path=os.path.join(args.input_dir+'/test',f'{subject_name}.csv')
        dev_file_path=os.path.join(args.input_dir+'/dev',f'{subject_name}.csv')

        val_df=pd.read_csv(val_file_path)
        dev_df=pd.read_csv(dev_file_path) if args.few_shot else None

        correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df,
            save_result_dir=save_result_dir if args.do_save_csv else None,
            few_shot=args.few_shot,
            cot=args.cot,
            with_prompt=args.with_prompt,
            constrained_decoding=args.constrained_decoding,
            do_test=False)
        print(f"Subject: {subject_name}")
        print(f"Acc: {correct_ratio}")
        accuracy[subject_name] = correct_ratio
        summary[subject_name] = {"score":correct_ratio,
                                 "num":len(val_df),
                                 "correct":correct_ratio*len(val_df)/100}
        all_answers[subject_name] = answers

    json.dump(all_answers,open(save_result_dir+'/submission.json','w'),ensure_ascii=False,indent=4)
    print("\n\nModel:",args.model_path)
    print("Accuracy:")
    for k, v in accuracy.items():
        print(k, ": ", v)


    total_num = 0
    total_correct = 0
    summary['grouped'] = {
        "China specific": {"correct": 0.0, "num": 0},
        "STEM": {"correct": 0.0, "num": 0}, 
        "Social Science": {"correct": 0.0, "num": 0}, 
        "Humanities": {"correct": 0.0, "num": 0}, 
        "Other": {"correct": 0.0, "num": 0}
        }
    for subj, info in subject_mapping.items():
        group = info[2]
        summary['grouped'][group]["num"]   += summary[subj]['num']
        summary['grouped'][group]["correct"] += summary[subj]['correct']
    for group, info in summary['grouped'].items():
        info['score'] = info["correct"] / info["num"]
        total_num += info["num"]
        total_correct += info["correct"]
    summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct}

    json.dump(summary,open(save_result_dir+'/summary.json','w'),ensure_ascii=False,indent=2)




if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--cot",choices=["False","True"], default="False")
    parser.add_argument("--few_shot", choices=["False","True"], default="True")
    parser.add_argument("--with_prompt", choices=["False","True"], default="False")
    parser.add_argument("--constrained_decoding", choices=["False","True"], default="False")
    parser.add_argument("--temperature",type=float,default=0.2)
    parser.add_argument("--n_times", default=1,type=int)
    parser.add_argument("--do_save_csv", choices=["False","True"], default="False")
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--input_dir", type=str)
    parser.add_argument("--verbose", action="store_true", help="Print detailed information of each example.")

    args = parser.parse_args()

    args.cot = args.cot == "True"
    args.few_shot = args.few_shot == "True"
    args.with_prompt = args.with_prompt == "True"
    args.do_save_csv = args.do_save_csv == "True"
    args.constrained_decoding = args.constrained_decoding == "True"
    if args.constrained_decoding is True:
        args.n_times=max(args.n_times,1)
    print(args)

    device = torch.device(0)
    print(device)
    evaluator=Llama_Evaluator(
        choices=choices,
        k=args.ntrain,
        model_path=args.model_path,
        device=device,
        temperature = args.temperature,
        verbose = args.verbose
    )
    for i in range(args.n_times):
        main(args,evaluator=evaluator,take=i)
