import os
import json
import re
from tqdm import tqdm
from argparse import ArgumentParser
import sys
import pandas as pd

def compute_cost(path_list, prompt_price=3.5, candidates_price=12, thoughts_price=12):
    """Sum token usage over all tasks and print the estimated total/average cost.

    Each entry of ``path_list`` is a directory containing a ``tasks.txt``
    file that lists one task directory per line. Every task directory must
    contain a ``messages.json`` whose ``"costs"`` object holds the integer
    fields ``prompt_token_count``, ``candidates_token_count`` and
    ``thoughts_token_count``.

    Args:
        path_list: Directories to scan; each one counts as one "web".
        prompt_price: Price per one million prompt tokens.
        candidates_price: Price per one million candidate (output) tokens.
        thoughts_price: Price per one million thought tokens.
            (The currency/model behind the default prices is not stated
            here — confirm before relying on the numbers.)

    Raises:
        FileNotFoundError / KeyError / json.JSONDecodeError if a listed file
        is missing or malformed.
    """
    total_cost = {
        "prompt_token_count": 0,
        "candidates_token_count": 0,
        "thoughts_token_count": 0,
    }
    webs_count = len(path_list)
    tasks_count = 0
    for path in tqdm(path_list, desc="Computing cost"):
        with open(os.path.join(path, "tasks.txt"), "r", encoding="utf-8") as f:
            # Skip blank lines: an empty entry would resolve to
            # "messages.json" in the current directory, not a task directory.
            tasks = [line.strip() for line in f if line.strip()]
        tasks_count += len(tasks)
        for task in tasks:
            with open(os.path.join(task, "messages.json"), "r", encoding="utf-8") as f:
                costs = json.load(f)["costs"]
            for key in total_cost:
                total_cost[key] += costs[key]
    cost = (
        total_cost["prompt_token_count"] * prompt_price / 1e6
        + total_cost["candidates_token_count"] * candidates_price / 1e6
        + total_cost["thoughts_token_count"] * thoughts_price / 1e6
    )
    print(f"Total cost: {cost}")
    # Guard the averages: an empty path list or empty task lists would
    # otherwise raise ZeroDivisionError after all the work has been done.
    if webs_count:
        print(f"Average cost per web: {cost / webs_count}")
    else:
        print("Average cost per web: n/a (no webs)")
    if tasks_count:
        print(f"Average cost per task: {cost / tasks_count}")
    else:
        print("Average cost per task: n/a (no tasks)")


def _percent(numerator, denominator):
    """Format numerator/denominator as a one-decimal percentage, or "n/a" if the denominator is 0."""
    if denominator == 0:
        return "n/a"
    return f"{round((numerator / denominator) * 100, 1)}%"


def eval_results(path_list, label_path):
    """Compare agent outcomes with ground-truth labels and print metrics.

    Each entry of ``path_list`` is a directory containing ``tasks.txt``,
    which lists one task directory per line. Each task directory holds a
    ``messages.json`` with a ``"final_result"`` status string and a
    ``metadata.json`` with ``question_id`` and ``task_id``. A task is
    predicted positive (1) iff its final result is ``"DONE"``.

    ``label_path`` is a JSON object mapping ``"<question_id>_<task_id>"``
    to the gold label. Labels other than 0/1 are not counted in any cell.

    Prints the per-status counts, the confusion matrix, and precision,
    recall, F1 and accuracy. A metric whose denominator is zero prints
    "n/a".

    Raises:
        KeyError: if a evaluated task has no entry in the label file.
    """
    results_count = {
        "DONE": 0, "FAILED": 0, "PARSING RESPONSE ERROR": 0, "UNRECOGNIZED ACTION TYPE": 0,
        "SERVER ERROR": 0, "INITIAL_ERROR": 0, "MAX ROUNDS": 0,
    }
    res = {}
    for path in tqdm(path_list, desc="Evaluating results"):
        with open(os.path.join(path, "tasks.txt"), "r", encoding="utf-8") as f:
            # Skip blank lines so a trailing newline does not yield a bogus "" task.
            tasks = [line.strip() for line in f if line.strip()]
        for task in tasks:
            with open(os.path.join(task, "messages.json"), "r", encoding="utf-8") as f:
                final_result = json.load(f)["final_result"]
            # Tally statuses outside the predefined set instead of crashing.
            results_count[final_result] = results_count.get(final_result, 0) + 1
            with open(os.path.join(task, "metadata.json"), "r", encoding="utf-8") as f:
                metadata = json.load(f)
            key = f"{metadata['question_id']}_{metadata['task_id']}"
            res[key] = 1 if final_result == "DONE" else 0
    print(results_count)

    with open(label_path, "r", encoding="utf-8") as f:
        label = json.load(f)

    tp = fp = tn = fn = 0
    for key, pred in res.items():
        try:
            gold = label[key]
        except KeyError:
            raise KeyError(f"No label for task {key!r} in {label_path}") from None
        if pred == 1 and gold == 1:
            tp += 1
        elif pred == 1 and gold == 0:
            fp += 1
        elif pred == 0 and gold == 0:
            tn += 1
        elif pred == 0 and gold == 1:
            fn += 1
    print(f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")
    print(f"Precision: {_percent(tp, tp + fp)}")
    print(f"Recall: {_percent(tp, tp + fn)}")
    print(f"F1-score: {_percent(2 * tp, 2 * tp + fp + fn)}")
    print(f"Accuracy: {_percent(tp + tn, tp + tn + fp + fn)}")

if __name__ == "__main__":
    parser = ArgumentParser(description="Evaluate agent results against ground-truth labels.")
    parser.add_argument("--path_list", type=str, default="web_unit.txt",
                        help="Text file listing one result directory per line.")
    parser.add_argument("--label_path", type=str, default="WebDevJudge_Unit/data/label.json",
                        help="JSON file mapping '<question_id>_<task_id>' to a 0/1 label.")
    parser.add_argument("--compute_cost", action="store_true",
                        help="Also report estimated token cost (off by default).")
    args = parser.parse_args()
    # Use a context manager so the list file is closed, and drop blank lines,
    # which would otherwise become "" paths resolving to the current directory.
    with open(args.path_list, "r", encoding="utf-8") as f:
        path_list = [line.strip() for line in f if line.strip()]
    eval_results(path_list, args.label_path)
    if args.compute_cost:
        compute_cost(path_list)