'''
Author: swchen
Date: 2025-09-17 15:07:11
LastEditors: swchen
LastEditTime: 2025-09-25 17:27:24
FilePath: /SupervisorAgent/smolagents/examples/smolagents_benchmark_gsmhard/eval-2.py
Description: 

Copyright (c) 2025 by Shaowen Chen, All Rights Reserved. 
'''
# EXAMPLE COMMAND: python eval-2.py --file-path-1 run_a.jsonl --file-path-2 run_b.jsonl --file-path-res-1 run_a_scored.jsonl --file-path-res-2 run_b_scored.jsonl
import argparse
import json
import os
import threading
from pathlib import Path
from dotenv import load_dotenv
import json
from pathlib import Path
import datasets
import requests


def compare_numbers_simple(input1, input2, timeout=60):
    """Ask an LLM whether two numerical expressions are mathematically equal.

    The endpoint is configured through environment variables (a ``.env``
    file is loaded first): ``OPENAI_LLM_MODEL_NAME`` (default ``gpt-4.1``),
    ``OPENAI_LLM_API_KEY`` and ``OPENAI_LLM_BASE_URL``. The request is sent
    to ``<base url>/chat/completions``.

    Args:
        input1: First expression; interpolated into the prompt as text.
        input2: Second expression; interpolated into the prompt as text.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        True if the model's reply is "true" (case-insensitive, ignoring
        surrounding whitespace, backticks, asterisks, quotes and periods),
        False for any other reply, and None when the request fails, the
        status code is not 200, or the response body does not have the
        expected ``choices[0].message.content`` string.
    """
    prompt = f"""# Numerical Comparison Task

Determine if two numerical expressions are mathematically equal. Consider different formats like fractions, decimals, LaTeX, percentages, and units.

**Rules:**
- Same numerical value = True
- Different units = False (e.g., "3 m" ≠ "3 cm")
- Ignore formatting differences

**Examples:**
- `1/2` vs `0.5` → True
- `\\frac{{3}}{{4}}` vs `75%` → True  
- `\\sqrt{{16}}` vs `4` → True
- `3.5 m` vs `3.5 cm` → False
- `2^3` vs `8` → True
- `π` vs `3.14159` → True

**Output format:** Just return `True` or `False`

**Compare:**
Input1: `{input1}`
Input2: `{input2}`

**Answer:**"""
    load_dotenv()  # Load environment variables from .env file
    model_name = os.environ.get("OPENAI_LLM_MODEL_NAME", "gpt-4.1")
    api_key = os.environ.get("OPENAI_LLM_API_KEY", "")
    # Drop a trailing slash so the joined URL never contains "//chat".
    base_url = os.environ.get("OPENAI_LLM_BASE_URL", "").rstrip("/")

    url = f"{base_url}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.1,
    }

    # Without a timeout a stalled connection would block the whole run.
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        print(f"LLM comparison request failed: {exc}")
        return None

    if response.status_code != 200:
        return None

    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # Body was not JSON or did not have the chat-completions shape.
        return None
    if not isinstance(content, str):
        return None

    # The prompt shows the answer in backticks, so the model may echo
    # "`True`", "**True**" or "True."; strip that decoration before comparing.
    verdict = content.strip(" \t\r\n`*\"'.").lower()
    return verdict == "true"


def load_jsonl_multiline(path: str):
    """Read JSON records from a file in which a record may span several lines.

    Each non-blank line is stripped and appended to a pending buffer; as soon
    as the buffer parses as JSON it is stored as one record and the buffer is
    reset. Text left in the buffer at end of file (an incomplete record) is
    discarded.

    Args:
        path: Path of the UTF-8 text file to read.

    Returns:
        A list of the parsed JSON values, in file order.
    """
    parsed = []
    pending = ""
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # filter(None, ...) drops the blank lines.
        for fragment in filter(None, stripped_lines):
            pending = pending + fragment
            try:
                value = json.loads(pending)
            except json.JSONDecodeError:
                # Not a complete JSON document yet; keep accumulating.
                continue
            else:
                parsed.append(value)
                pending = ""  # start collecting the next object
    return parsed


# Serializes appends so concurrent writers do not interleave their output.
append_answer_lock = threading.Lock()


def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append one record to a results file as indented JSON.

    The parent directory is created if needed. The entry is passed through
    ``to_builtin_type`` and written with ``indent=2`` followed by a newline,
    so each record occupies several lines in the file.

    Args:
        entry: The record to write.
        jsonl_file: Path of the file to append to.
    """
    target = Path(jsonl_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock:
        with open(jsonl_file, "a", encoding="utf-8") as sink:
            serialized = json.dumps(to_builtin_type(entry), ensure_ascii=False, indent=2)
            sink.write(serialized + "\n")
    assert target.exists(), "File not found!"
    print("Answer exported to file:", target.resolve())
  
  
def to_builtin_type(obj):
    """Recursively convert a value into plain Python types for JSON output.

    Dicts (values only), lists and plain tuples are walked recursively.
    Objects whose class is named ``Integer`` are converted with ``int()``.
    When numpy is importable, numpy booleans, integers and floats are
    converted to ``bool``/``int``/``float`` and arrays are converted via
    ``tolist()``. Anything else is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        The converted value; containers are rebuilt, other objects may be
        returned as-is.
    """
    # Dict: convert every value (keys are left untouched).
    if isinstance(obj, dict):
        return {k: to_builtin_type(v) for k, v in obj.items()}
    # List: convert every element.
    if isinstance(obj, list):
        return [to_builtin_type(i) for i in obj]
    # Plain tuple: convert every element but keep the tuple type. Subclasses
    # such as namedtuples are left alone because their constructors differ.
    if type(obj) is tuple:
        return tuple(to_builtin_type(i) for i in obj)
    # Integer wrapper objects (matched by class name) become int.
    if type(obj).__name__ == "Integer":
        return int(obj)
    # numpy support is optional.
    try:
        import numpy as np
    except ImportError:
        return obj
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # e.g. np.float32 is not a Python float and json.dumps rejects it.
        return float(obj)
    if isinstance(obj, np.ndarray):
        return to_builtin_type(obj.tolist())
    return obj


def get_examples_to_answer(answers_file: str, eval_ds: "datasets.Dataset | list[dict]", task_id: "str | None" = None) -> list[dict]:
    """Return the examples from ``eval_ds`` that have no saved answer yet.

    Previously saved records are read from ``answers_file``; every example
    whose ``task_id`` already appears there is skipped. If the file cannot
    be read or processed for any reason, the error is printed and all
    examples are treated as pending.

    Args:
        answers_file: Path of the results file holding earlier answers.
        eval_ds: Iterable of examples; each must have a ``"task_id"`` key.
        task_id: If given (not None and not an empty string), only the
            pending example whose ``task_id`` matches it, compared as
            strings, is kept.

    Returns:
        The list of pending examples, in their original order.
    """
    print(f"Loading answers from {answers_file}...")
    try:
        records = load_jsonl_multiline(answers_file)
        # A set gives O(1) membership tests in the filter below.
        done_questions = {r["task_id"] for r in records if "task_id" in r}
        print(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        # Deliberately broad: a missing or unusable results file just means
        # there is nothing to resume from.
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = set()

    examples = [line for line in eval_ds if line["task_id"] not in done_questions]
    # If a task_id was specified, keep only that one example. The explicit
    # None/"" test (rather than truthiness) lets an id of 0 be selected.
    if task_id is not None and task_id != "":
        wanted = str(task_id)
        examples = [ex for ex in examples if str(ex.get("task_id")) == wanted]
        print(f"Filtered for task_id={task_id}, {len(examples)} tasks found.")
    return examples

    
def remove_repeat_examples(records):
    """Keep the first record with a non-None prediction for each task id.

    Records whose ``"prediction"`` is None are skipped and do not reserve
    their ``"task_id"``; a later record with the same id and a real
    prediction is still kept.

    Args:
        records: Iterable of dicts with ``"prediction"`` and ``"task_id"`` keys.

    Returns:
        A list of the retained records, in their original order.
    """
    seen_ids = set()
    kept = []
    for rec in records:
        if rec["prediction"] is None or rec["task_id"] in seen_ids:
            continue
        seen_ids.add(rec["task_id"])
        kept.append(rec)
    return kept


def parse_args():
    """Parse the command-line options of the evaluation script.

    All four options are strings defaulting to the empty string.

    Returns:
        The ``argparse.Namespace`` with ``file_path_1``, ``file_path_2``,
        ``file_path_res_1`` and ``file_path_res_2``.
    """
    cli = argparse.ArgumentParser()
    option_help = (
        ("--file-path-1", "Path to the input JSONL file."),
        ("--file-path-2", "Path to the second input JSONL file."),
        ("--file-path-res-1", "Path to the output JSONL file with results."),
        ("--file-path-res-2", "Path to the second output JSONL file with results."),
    )
    for flag, description in option_help:
        cli.add_argument(flag, type=str, default="", help=description)
    return cli.parse_args()

def main():
    """Score two prediction files against each other, task by task.

    Both input files are loaded and de-duplicated, tasks already present in
    the corresponding result file are dropped, and for every task id found
    in both remaining sets each prediction is judged against its
    ``true_answer`` with ``compare_numbers_simple``. The verdict is stored
    under ``"is_correct"`` and the record is appended to its result file.
    Finally the correct counts, accuracies and summed ``token_cost`` values
    are printed.
    """
    import time

    args = parse_args()
    res1_path = args.file_path_res_1
    res2_path = args.file_path_res_2

    res1 = remove_repeat_examples(load_jsonl_multiline(args.file_path_1))
    res2 = remove_repeat_examples(load_jsonl_multiline(args.file_path_2))
    # Debug sample; guarded so a file with fewer than two records is fine.
    if len(res1) > 1:
        print(res1[1])
    print(len(res1), len(res2))
    res1 = get_examples_to_answer(res1_path, res1)
    res2 = get_examples_to_answer(res2_path, res2)
    print(len(res1), len(res2))

    # Index the second file by task id (first occurrence wins) so each
    # lookup is O(1) instead of a scan over the whole list.
    res2_by_id = {}
    for r2 in res2:
        res2_by_id.setdefault(r2["task_id"], r2)

    cor1 = 0
    cor2 = 0
    sum1 = 0
    sum2 = 0
    r1_total_tokens = 0
    r2_total_tokens = 0
    for r1 in res1:
        r2 = res2_by_id.get(r1["task_id"])
        if r2 is None:
            # Task only present in the first file; nothing to compare.
            continue
        # A missing or null token_cost counts as zero.
        r1_total_tokens += r1.get("token_cost") or 0
        r2_total_tokens += r2.get("token_cost") or 0
        r1["is_correct"] = compare_numbers_simple(r1["prediction"], r1["true_answer"])
        r2["is_correct"] = compare_numbers_simple(r2["prediction"], r2["true_answer"])
        time.sleep(1)  # avoid sending requests too quickly
        append_answer(r1, res1_path)
        append_answer(r2, res2_path)
        # A None verdict (the comparison call failed) is counted in the
        # total but not as correct.
        if r1["is_correct"]:
            cor1 += 1
        if r2["is_correct"]:
            cor2 += 1
        sum1 += 1
        sum2 += 1

    # Avoid ZeroDivisionError when no task was evaluated in this run.
    acc1 = cor1 / sum1 if sum1 else 0.0
    acc2 = cor2 / sum2 if sum2 else 0.0
    print(f"File1 Correct: {cor1}/{sum1}, Accuracy: {acc1}")
    print(f"File2 Correct: {cor2}/{sum2}, Accuracy: {acc2}")
    print(f"File1 Total Tokens: {r1_total_tokens}")
    print(f"File2 Total Tokens: {r2_total_tokens}")


if __name__ == "__main__":
    main()