import pandas as pd
import re
import json
# import jsonlines

def normalize(text):
    """Canonicalise an answer string for loose comparison.

    The text is upper-cased and every character that is not an ASCII
    letter or digit is dropped: whitespace, punctuation, accented and
    non-Latin letters all disappear.

    Args:
        text: The raw answer string.

    Returns:
        The upper-case ASCII-alphanumeric residue (possibly empty).
    """
    return re.sub(r'[^a-zA-Z0-9]', '', text.upper().strip())

# Reference dataset to score against: leave exactly one data_path line
# uncommented.
# data_path="/data/yutao/dataset/wiki_data/bamboogle/test.parquet" # bam
# data_path="/data/yutao/dataset/wiki_data/nq/test.parquet" #nq
data_path="/data/yutao/dataset/wiki_data/popqa/test.parquet" # pop
# data_path="/data/yutao/dataset/wiki_data/hotpot/dev-00000-of-00001.parquet" # hotpot
# data_path="/data/yutao/dataset/wiki_data/musique/dev.parquet" # mus
# data_path="/data/yutao/dataset/wiki_data/2wiki/dev.parquet" # 2wiki

data_df = pd.read_parquet(data_path)

# Map each question's exact text to its normalized reference answer.
# Only the "extra_info" column is needed, so iterate it directly instead of
# building a per-row Series with iterrows(). If a question text repeats,
# the last row wins.
gt_answer = {
    info["question"]: normalize(info["selected_answer"])
    for info in data_df["extra_info"]
}

# Model generations to score: one JSON object per line; swap the active
# gen_file line to evaluate another run.
# gen_file = 'mus_main_sft_t0.jsonl'
gen_file = 'pop_3b_step6_t0.jsonl'

# Parse the JSONL file, skipping lines that are blank after stripping.
with open(gen_file, "r", encoding="utf-8") as f:
    gen_data = [json.loads(record) for record in (raw.strip() for raw in f) if record]

# Score each trajectory by its final turn and accumulate:
#   suc   -- answers judged correct
#   emp   -- trajectories whose final action contains no "stop"
#            (reported as "unanswered")
#   steps -- sum of trajectory_length over the correct answers only
steps = 0
suc = 0
emp = 0
for data in gen_data:
    last_turn = data['trajectory'][-1]
    # Named input_seq rather than `input` so the builtin is not shadowed
    # for the rest of the module.
    input_seq = last_turn['input_seq']
    output = last_turn['output_seq']

    # The question text sits between "Objective: " and the next
    # "\nObservation"; it is the lookup key into gt_answer.
    # Raises IndexError if the prompt does not follow that layout.
    question = re.findall(r'Objective: (.*?)\nObservation', input_seq)[0]

    # The action is the first ```...``` span of the output. `.` does not
    # match newlines here, so a fenced span covering several lines is not
    # found. NOTE(review): confirm actions are always single-line.
    code_blocks = re.findall(r"```(.*?)```", output)
    answer = code_blocks[0] if code_blocks else " "

    if 'stop' not in answer:
        emp += 1
        continue

    # The prediction is the first [...] span inside the stop action. A stop
    # action without brackets counts as an empty prediction.
    bracketed = re.findall(r"\[(.*?)\]", answer)
    ans = normalize(bracketed[0]) if bracketed else ""

    ground_truth = gt_answer[question]
    # Correct if the reference occurs inside the prediction, or both contain
    # exactly the same characters in some order (presumably to tolerate
    # reordered words, since normalize() drops spaces). An empty reference
    # (nothing ASCII-alphanumeric survived normalize()) would be a substring
    # of every prediction, so it is never accepted as a match.
    if ground_truth and (
        ground_truth in ans or sorted(ground_truth) == sorted(ans)
    ):
        suc += 1
        steps += data['trajectory_length']

total = len(gen_data)
# Guard the ratios so an empty generation file, or a run with no correct
# answers, reports nan instead of crashing with ZeroDivisionError.
accuracy = suc / total if total else float("nan")
avg_steps = steps / suc if suc else float("nan")

# Report: question count, correct count, accuracy, unanswered count, and
# average trajectory length over the correct answers.
print(f"问题数目：{total}")
print(f"回答正确数目：{suc}")
print(f"正确率：{accuracy}")
print(f"未回答数目：{emp}")
print(f"平均步数：{avg_steps}")