import os
import sys
import json
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm

def load_csv(file_path:str)->pd.DataFrame:
    """Read the CSV file at ``file_path`` into a DataFrame.

    The file is parsed with pandas' default ``read_csv`` settings.
    """
    frame = pd.read_csv(file_path)
    return frame

def get_args():
    """Parse the grader's command-line options from ``sys.argv``.

    Returns:
        The argparse namespace with ``results_path`` and ``report_path``,
        both of which are required string options.
    """
    cli = argparse.ArgumentParser(
        description="Grader for different models, types, intent_space, and params."
    )
    option_specs = (
        ("--results_path", "Path to the results csv file."),
        ("--report_path", "Path to the report."),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()

class BenchmarkGrader:
    """Score predicted choice indices against ground-truth indices.

    The results CSV is loaded once at construction. ``run`` walks its rows,
    reading the ``action_idx`` and ``ground_truth_idx`` columns for scoring
    and the ``action``, ``ground_truth``, ``choice_space``, ``system_prompt``
    and ``user_prompt`` columns for the report.
    """

    def __init__(self, results_path: str):
        """Load the results CSV at ``results_path`` into ``self.results``."""
        self.results = load_csv(results_path)

    def grade(self, pred: list, label: list):
        """Return the fraction of ``label`` entries that also occur in ``pred``.

        Args:
            pred: Predicted indices.
            label: Ground-truth indices.

        Returns:
            A float in ``[0, 1]``. An empty ``label`` scores ``0.0`` instead of
            raising ``ZeroDivisionError``.
        """
        hits = 0
        for label_idx in label:
            if label_idx in pred:
                print(f"{label_idx} in {pred}")
                hits += 1
        # Guard the division: a row with no ground truth has nothing to match.
        answer_score = hits / len(label) if label else 0.0
        print(f"score: {answer_score}, pred: {pred}, label: {label}")
        return answer_score

    def format_output(self, data):
        """Normalize a raw cell value into a list of indices.

        Lists are returned unchanged. Strings are decoded as JSON first and,
        if that fails, as a plain integer (``"03"`` is not valid JSON but is a
        valid int). Any non-list result, including a decoded JSON scalar, is
        wrapped in a single-element list so callers can always iterate.

        Raises:
            ValueError: If ``data`` is a string that is neither valid JSON nor
                an integer literal.
        """
        if isinstance(data, list):
            return data
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError:
                data = int(data)
            if isinstance(data, list):
                return data
        return [data]

    def run(self, report_path: str = None):
        """Grade every row of ``self.results`` and optionally write a report.

        Args:
            report_path: Where to write the per-row report CSV. When ``None``
                no report is written. Missing parent directories are created.

        Returns:
            The list of per-row scores, in row order.
        """
        score_list = []
        logs = []
        for _, row in self.results.iterrows():
            item = row.to_dict()
            pred_idx = self.format_output(item["action_idx"])
            label_idx = self.format_output(item["ground_truth_idx"])
            score = self.grade(pred_idx, label_idx)
            score_list.append(score)
            logs.append(
                {
                    "score": score,
                    "action": item["action"],
                    "ground_truth": item["ground_truth"],
                    "choice_space": item["choice_space"],
                    "system_prompt": item["system_prompt"],
                    "user_prompt": item["user_prompt"],
                }
            )
        # np.mean([]) yields nan with a RuntimeWarning; report 0.0 instead.
        mean_score = np.mean(score_list) if score_list else 0.0
        total_score = np.sum(score_list)
        print(f"Mean score: {mean_score:.2f}")
        print(f"Total score: {total_score:.2f} / {len(score_list)}")
        if report_path is not None:
            report_dir = os.path.dirname(report_path)
            if report_dir:
                # to_csv does not create missing parent directories.
                os.makedirs(report_dir, exist_ok=True)
            pd.DataFrame(logs).to_csv(report_path, index=False)
        return score_list

def main():
    """Grade one results file and write its report.

    When command-line arguments are supplied they are parsed with
    ``get_args`` and used for both paths. With no arguments the script falls
    back to the built-in default paths below, so a bare invocation behaves as
    it did before.
    """
    if len(sys.argv) > 1:
        args = get_args()
        results_path = args.results_path
        report_path = args.report_path
    else:
        # No CLI arguments given: keep the previously hard-coded run.
        results_path = "/mnt/workspace/workgroup/Benchmark/output/evaluate_results/single_0813_qwen-plus-latest_mc-1_1000/qwen3-4b.txt"
        report_path = "/mnt/workspace/workgroup/Benchmark/output/grader_report/single_0813_qwen-plus-latest_mc-1_1000/qwen3-4b.txt"
    grader = BenchmarkGrader(results_path)
    grader.run(report_path)

# Run the grader only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()