import argparse
import json
import os
import threading
from pathlib import Path
from dotenv import load_dotenv


def compare_numbers_simple(input1, input2, timeout=60):
    """Ask an LLM whether two numerical expressions are mathematically equal.

    A fixed instruction prompt is built around ``input1`` and ``input2`` and
    posted to ``<OPENAI_LLM_BASE_URL>/chat/completions``. The model name, API
    key and base URL are read from the environment variables
    ``OPENAI_LLM_MODEL_NAME`` (default ``"gpt-4.1"``), ``OPENAI_LLM_API_KEY``
    and ``OPENAI_LLM_BASE_URL`` after ``load_dotenv()`` has been called.

    Args:
        input1: First expression; interpolated into the prompt as text.
        input2: Second expression; interpolated into the prompt as text.
        timeout: Seconds handed to ``requests.post`` as its ``timeout``
            argument so a stalled endpoint cannot block the caller forever.

    Returns:
        ``True`` when the model's reply is "true" (case-insensitive, ignoring
        surrounding whitespace, backticks, asterisks, quotes and periods),
        ``False`` for any other reply, and ``None`` when the HTTP status code
        is not 200.

    Raises:
        requests.RequestException: Network failures and timeouts raised by
            ``requests.post`` are not caught here.
    """
    # Imported lazily so the module can be imported without ``requests``.
    import requests

    prompt = f"""# Numerical Comparison Task

Determine if two numerical expressions are mathematically equal. Consider different formats like fractions, decimals, LaTeX, percentages, and units.

**Rules:**
- Same numerical value = True
- Different units = False (e.g., "3 m" ≠ "3 cm")
- Ignore formatting differences

**Examples:**
- `1/2` vs `0.5` → True
- `\\frac{{3}}{{4}}` vs `75%` → True  
- `\\sqrt{{16}}` vs `4` → True
- `3.5 m` vs `3.5 cm` → False
- `2^3` vs `8` → True
- `π` vs `3.14159` → True

**Output format:** Just return `True` or `False`

**Compare:**
Input1: `{input1}`
Input2: `{input2}`

**Answer:**"""
    load_dotenv()  # Load environment variables from .env file
    model_name = os.environ.get("OPENAI_LLM_MODEL_NAME", "gpt-4.1")
    api_key = os.environ.get("OPENAI_LLM_API_KEY", "")
    base_url = os.environ.get("OPENAI_LLM_BASE_URL", "")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Chat-completions request body.
    data = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.1,
    }

    # Drop a trailing slash so the joined URL never contains "//chat/...".
    endpoint = base_url.rstrip("/") + "/chat/completions"
    response = requests.post(endpoint, json=data, headers=headers, timeout=timeout)

    if response.status_code != 200:
        return None

    result = response.json()['choices'][0]['message']['content']
    # The prompt shows the answers inside backticks, so the model may echo
    # that (or add bold markers / a trailing period); peel those off before
    # comparing instead of requiring a bare "True".
    normalized = result.strip("`*\"'. \t\r\n").lower()
    return normalized == 'true'


def load_jsonl_multiline(path: str):
    """Read a JSONL-style file whose JSON records may span several lines.

    Each non-blank line is stripped and appended to a pending buffer. As soon
    as the buffer parses as one JSON value, that value is stored and the
    buffer is reset for the next record.

    Args:
        path: Path of the UTF-8 text file to read.

    Returns:
        A list with the decoded JSON values, in file order.

    Warns:
        RuntimeWarning: If text is still pending at end of file because it
            never formed a complete JSON value (for example a truncated
            record). That text is discarded, as before, but no longer
            silently.
    """
    import warnings

    records = []
    pending = []  # stripped lines of the record currently being assembled
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            pending.append(line)
            try:
                record = json.loads("".join(pending))
            except json.JSONDecodeError:
                # Not a complete JSON value yet; keep accumulating lines.
                continue
            records.append(record)
            pending.clear()  # start collecting the next object

    if pending:
        warnings.warn(
            f"{path}: discarded {len(pending)} trailing line(s) that never "
            "formed a complete JSON value",
            RuntimeWarning,
            stacklevel=2,
        )
    return records


# Held while a record is appended so concurrent callers take turns writing.
append_answer_lock = threading.Lock()


def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append ``entry`` to ``jsonl_file`` as indented JSON followed by a newline.

    The parent directory is created when missing. ``entry`` is passed through
    ``to_builtin_type`` before serialization, and the write happens while
    ``append_answer_lock`` is held. The resolved output path is printed
    afterwards.
    """
    target = Path(jsonl_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock:
        with open(jsonl_file, "a", encoding="utf-8") as out:
            serialized = json.dumps(
                to_builtin_type(entry),
                ensure_ascii=False,
                indent=2,
            )
            out.write(serialized + "\n")
    assert target.exists(), "File not found!"
    print("Answer exported to file:", target.resolve())
    
def to_builtin_type(obj):
    """Recursively convert ``obj`` into plain Python values for JSON output.

    Dict values, list items and tuple items are converted recursively (dict
    keys are left untouched). Objects whose class is named ``Integer`` and
    NumPy integer scalars become ``int``. When NumPy is importable, other
    NumPy scalars are converted with ``.item()`` and arrays with
    ``.tolist()``. Anything else is returned unchanged.

    Args:
        obj: Arbitrary value, possibly nested.

    Returns:
        The converted value; containers are rebuilt as ``dict``, ``list`` or
        ``tuple``.
    """
    # Containers: recurse into every value / element.
    if isinstance(obj, dict):
        return {k: to_builtin_type(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_builtin_type(i) for i in obj]
    if isinstance(obj, tuple):
        return tuple(to_builtin_type(i) for i in obj)
    # Wrapper classes literally named "Integer" are matched by class name so
    # that no import of the defining library is needed.
    if type(obj).__name__ == "Integer":
        return int(obj)
    # NumPy is optional; without it there is nothing more to convert.
    try:
        import numpy as np
    except ImportError:
        return obj
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.generic):
        # Remaining NumPy scalars (float32, bool_, ...) -> Python scalar.
        return obj.item()
    if isinstance(obj, np.ndarray):
        # tolist() yields nested lists; recurse in case of object arrays.
        return to_builtin_type(obj.tolist())
    return obj


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two string options are defined, both defaulting to the empty string:
    ``--file-path`` (input JSONL) and ``--file-path-res`` (output JSONL with
    results).
    """
    cli = argparse.ArgumentParser()
    option_help = (
        ("--file-path", "Path to the input JSONL file."),
        ("--file-path-res", "Path to the output JSONL file with results."),
    )
    for flag, description in option_help:
        cli.add_argument(flag, type=str, default="", help=description)
    namespace = cli.parse_args()
    return namespace

def main():
    """Grade the not-yet-processed input records and print summary statistics.

    Reads the input file given by ``--file-path``, skips every record whose
    ``"question"`` already appears in the results file given by
    ``--file-path-res``, fills in ``"is_correct"`` where it is missing and
    appends each processed record to the results file. Finally the results
    file is re-read to print accuracy and token totals.
    """
    args = parse_args()
    file_path = args.file_path
    file_path_res = args.file_path_res
    records = load_jsonl_multiline(file_path)

    # On a first run the results file does not exist yet; that simply means
    # there is nothing to resume from, not that the run should abort.
    try:
        old_records = load_jsonl_multiline(file_path_res)
    except FileNotFoundError:
        old_records = []

    try:
        done_questions = [r["question"] for r in old_records if "question" in r]
        print(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = []

    examples = [line for line in records if line["question"] not in done_questions]
    for r in examples:
        if "is_correct" not in r:
            # An empty/falsy prediction is scored as wrong without an API call.
            r["is_correct"] = compare_numbers_simple(r["prediction"], r["true_answer"]) if r["prediction"] else False
        append_answer(r, file_path_res)

    # Nothing may have been written (no input and no earlier results), in
    # which case the summary below reports zeros instead of crashing.
    try:
        res = load_jsonl_multiline(file_path_res)
    except FileNotFoundError:
        res = []
    num_correct = sum(1 for r in res if r.get("is_correct", False))
    total_tokens = sum(r.get("tokens", 0) for r in res)
    total = len(res)
    accuracy = num_correct / total if total > 0 else 0.0
    print(f"Final Accuracy: {num_correct}/{total} = {accuracy:.2%}")
    print(f"Total Tokens Used: {total_tokens}")
    print(f"Average Tokens per Example: {total_tokens / total if total > 0 else 0:.2f}")


# Script entry point.
if __name__ == "__main__":
    main()