import pandas as pd
import re
import json


def normalize(text):
    """Canonicalize an answer string for loose comparison.

    The string is upper-cased and trimmed. Then every character that is not an
    ASCII letter or digit is removed: whitespace, punctuation, and accented or
    non-Latin letters.

    Args:
        text: The raw answer string.

    Returns:
        The canonical form, e.g. ``"  New York! "`` -> ``"NEWYORK"``.
    """
    return re.sub(r'[^a-zA-Z0-9]', '', text.upper().strip())


# Evaluation dataset: keep exactly one `data_path` uncommented.
# data_path="/data/yutao/dataset/wiki_data/bamboogle/test.parquet" # bam
# data_path="/data/yutao/dataset/wiki_data/nq/test.parquet" #nq
data_path="/data/yutao/dataset/wiki_data/popqa/test.parquet" # pop
# data_path="/data/yutao/dataset/wiki_data/hotpot/dev-00000-of-00001.parquet" #hotpot
# data_path="/data/yutao/dataset/wiki_data/musique/dev.parquet" # mus

SAMPLE_SEED = 42    # random_state for the reproducible shuffle below
SAMPLE_SIZE = 1000  # number of shuffled rows that form the evaluation set

data_df = pd.read_parquet(data_path)
# Shuffle reproducibly, then keep the first SAMPLE_SIZE rows.
data_df = data_df.sample(frac=1, random_state=SAMPLE_SEED).reset_index(drop=True)
data_df = data_df.head(SAMPLE_SIZE)

# Map each question to its normalized gold answer. Both fields are read from
# the "extra_info" column (accessed as a mapping with "question" and
# "selected_answer" keys). If a question appears more than once, the last
# occurrence wins, so len(gt_answer) can be smaller than SAMPLE_SIZE.
gt_answer = {
    info["question"]: normalize(info["selected_answer"])
    for info in data_df["extra_info"]
}

# Agent results to score: a JSONL file with one JSON object per line.
# gen_file = 'mus_main_t0.jsonl'
gen_file = '20250921_032223_webarena_results.jsonl'

# Strip each line and skip blank ones. Every remaining line is decoded as JSON.
with open(gen_file, "r", encoding="utf-8") as f:
    gen_data = [
        json.loads(stripped)
        for stripped in (raw.strip() for raw in f)
        if stripped
    ]

# Regexes that parse the final step of a trajectory, compiled once.
# The question text sits between "Objective: " and the next "\nObservation".
_OBJECTIVE_RE = re.compile(r'Objective: (.*?)\nObservation')
# The action is the first ```...``` fenced span. DOTALL lets a fence whose
# content spans several lines still be captured, instead of being miscounted
# as "no answer".
_ACTION_RE = re.compile(r"```(.*?)```", re.DOTALL)
# Inside a `stop` action, the answer is the first [...] bracketed span.
_BRACKET_RE = re.compile(r"\[(.*?)\]")


def _is_correct(ground_truth, prediction):
    """Return True if a normalized prediction matches a normalized gold answer.

    A match is either the gold answer appearing as a substring of the
    prediction, or both strings having the same multiset of characters.
    An empty gold answer never matches. Without this guard, ``"" in prediction``
    would be True for every prediction and inflate accuracy.
    """
    if not ground_truth:
        return False
    return ground_truth in prediction or sorted(ground_truth) == sorted(prediction)


steps = 0  # sum of trajectory_length over correctly answered trajectories
suc = 0    # number of correctly answered trajectories
emp = 0    # trajectories whose final action is not a `stop` action
for data in gen_data:
    content = data['trajectory']
    last_input = content[-1]['input_seq']
    output = content[-1]['output_seq']

    objective = _OBJECTIVE_RE.search(last_input)
    if objective is None:
        # This trajectory cannot be mapped to a dataset question. Skip it,
        # the same way as one whose question is not in gt_answer.
        continue
    question = objective.group(1)
    if question not in gt_answer:
        continue

    action = _ACTION_RE.search(output)
    answer = action.group(1) if action else " "

    if 'stop' not in answer:
        emp += 1
        continue

    # A `stop` action without a [...] answer is scored as an empty prediction.
    bracket = _BRACKET_RE.search(answer)
    ans = normalize(bracket.group(1)) if bracket else ""
    if _is_correct(gt_answer[question], ans):
        suc += 1
        steps += data['trajectory_length']

# Guard the ratios so an empty dataset or zero correct answers prints "nan"
# instead of crashing with ZeroDivisionError.
total = len(gt_answer)
accuracy = suc / total if total else float("nan")
avg_steps = steps / suc if suc else float("nan")

print(f"问题数目：{total}")
print(f"回答正确数目：{suc}")
print(f"正确率：{accuracy}")
print(f"未回答数目：{emp}")
print(f"平均步数：{avg_steps}")