# encoding:utf-8

from score import calculate_score,calculate_score_cn
import json
import os










def test_en(file):
    """Score English predictions stored in a JSON-lines file.

    Every line of ``file`` is parsed as a JSON object that must provide a
    ``"pred"`` string (the model output) and a ``"ground_truth"`` string
    (the reference). Chat-template end markers and surrounding whitespace
    are removed from both ends of each prediction, newlines in each
    reference are replaced by spaces, and the two lists are handed to
    ``calculate_score``.

    Args:
        file: Path of the JSON-lines file to evaluate (read as UTF-8).

    Raises:
        KeyError: If a record lacks ``"pred"`` or ``"ground_truth"``.
    """
    special_tokens = (
        "<|eot_id|>",
        "<|im_end|>",
        "</s>",
        "<end_of_turn>",
        "<eos>",
        "<|user|>",
    )

    def strip_special_tokens(text):
        # str.strip(chars) removes any of the *characters* in ``chars``, so
        # a chain like .strip("<|eot_id|>") also eats legitimate leading or
        # trailing letters ('e', 'o', 't', 'i', 'd', ...). Match whole
        # tokens instead, peeling them off until none is left at either end.
        text = text.strip()
        changed = True
        while changed:
            changed = False
            for token in special_tokens:
                if text.startswith(token):
                    text = text[len(token):].lstrip()
                    changed = True
                if text.endswith(token):
                    text = text[:-len(token)].rstrip()
                    changed = True
        return text

    pred = []
    truth_answer = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(file, 'r', encoding="utf-8") as f:
        for line in f:
            dic = json.loads(line)
            pred.append(strip_special_tokens(dic["pred"]))
            truth_answer.append(dic["ground_truth"].replace("\n", " "))

    calculate_score(pred, truth_answer)


def test_cn(file):
    """Score Chinese predictions stored in a JSON-lines file.

    Every line of ``file`` is parsed as a JSON object. The model output is
    read from ``"pred"`` if that key exists, otherwise from ``"response"``;
    the reference is read from ``"ground_truth"`` if that key exists,
    otherwise from ``"answer"``. Records whose model output is missing or
    empty are skipped. Chat-template end markers and surrounding whitespace
    are removed from both ends of each prediction, newlines in each
    reference are replaced by spaces, and the two lists are handed to
    ``calculate_score_cn`` with ``use_jieba=True``.

    Args:
        file: Path of the JSON-lines file to evaluate (read as UTF-8).

    Raises:
        KeyError: If a record with a model output has neither a
            ``"ground_truth"`` nor an ``"answer"`` entry.
    """
    special_tokens = (
        "<|eot_id|>",
        "<|im_end|>",
        "</s>",
        "<|user|>",
        "<end_of_turn>",
        "<eos>",
    )

    def strip_special_tokens(text):
        # str.strip(chars) removes any of the *characters* in ``chars``, so
        # a chain like .strip("<|eot_id|>") also eats legitimate leading or
        # trailing letters. Match whole tokens instead, peeling them off
        # until none is left at either end.
        text = text.strip()
        changed = True
        while changed:
            changed = False
            for token in special_tokens:
                if text.startswith(token):
                    text = text[len(token):].lstrip()
                    changed = True
                if text.endswith(token):
                    text = text[:-len(token)].rstrip()
                    changed = True
        return text

    pred = []
    truth_answer = []
    with open(file, 'r', encoding="utf-8") as f:
        for num, line in enumerate(f, start=1):
            dic = json.loads(line)

            # Bind model_output on every iteration so a record without
            # either key cannot reuse the previous record's value.
            if "pred" in dic:
                model_output = dic["pred"]
            elif "response" in dic:
                model_output = dic["response"]
            else:
                model_output = None

            if not model_output:
                continue

            # Check for the key that is actually read ("answer").
            if "ground_truth" in dic:
                label_ = dic["ground_truth"]
            elif "answer" in dic:
                label_ = dic["answer"]
            else:
                raise KeyError(
                    f"line {num} of {file!r} has neither 'ground_truth' "
                    "nor 'answer'"
                )

            # Append both together, only after both are resolved, so the
            # two lists always stay aligned.
            pred.append(strip_special_tokens(model_output))
            truth_answer.append(label_.replace("\n", " "))

    calculate_score_cn(pred, truth_answer, use_jieba=True)






def test_llm_ensemble(file, answer_file=""):
    """Score ensemble responses against answers looked up by prompt.

    ``answer_file`` is a JSON-lines file whose records provide ``"prompt"``
    and ``"answer"``; it is loaded into a prompt -> answer mapping. Every
    line of ``file`` is a JSON object with ``"prompt"`` and ``"response"``.
    Chat-template end markers and surrounding whitespace are removed from
    both ends of each response, the matching answer has its newlines
    replaced by spaces, and the two lists are handed to ``calculate_score``.
    A banner and the file's base name (without a ``.json`` suffix) are
    printed first.

    Args:
        file: Path of the JSON-lines file with model responses.
        answer_file: Path of the JSON-lines file with reference answers.
            The default (empty string) preserves the previous hard-coded
            value, which cannot be opened.

    Raises:
        FileNotFoundError: If ``answer_file`` is empty or does not exist.
        KeyError: If a record lacks a required key, or a prompt from
            ``file`` is not present in ``answer_file``.
    """
    special_tokens = (
        "<|eot_id|>",
        "<|im_end|>",
        "</s>",
        "<end_of_turn>",
        "<eos>",
        "<|user|>",
    )

    def strip_special_tokens(text):
        # str.strip(chars) removes any of the *characters* in ``chars``, so
        # a chain like .strip("<|eot_id|>") also eats legitimate leading or
        # trailing letters. Match whole tokens instead, peeling them off
        # until none is left at either end.
        text = text.strip()
        changed = True
        while changed:
            changed = False
            for token in special_tokens:
                if text.startswith(token):
                    text = text[len(token):].lstrip()
                    changed = True
                if text.endswith(token):
                    text = text[:-len(token)].rstrip()
                    changed = True
        return text

    print("***************************************************")
    # Drop only a literal ".json" suffix; .strip(".json") would remove any
    # of the characters '.', 'j', 's', 'o', 'n' from both ends of the name.
    display_name = os.path.basename(file)
    if display_name.endswith(".json"):
        display_name = display_name[:-len(".json")]
    print(display_name)

    if not answer_file:
        # Same exception type that open("") raised, with a clearer message.
        raise FileNotFoundError(
            "answer_file must be the path of the reference-answer "
            "JSON-lines file"
        )

    answer_dict = {}
    with open(answer_file, 'r', encoding="utf-8") as f2:
        for line in f2:
            an = json.loads(line)
            answer_dict[an["prompt"]] = an["answer"]

    pred = []
    truth_answer = []
    with open(file, 'r', encoding="utf-8") as f:
        for line in f:
            dic = json.loads(line)
            pred.append(strip_special_tokens(dic["response"]))
            truth_answer.append(answer_dict[dic["prompt"]].replace("\n", " "))

    calculate_score(pred, truth_answer)

