import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
from tqdm import tqdm
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct")
args = parser.parse_args()

# Root directory under which evaluation/<model>/results_<dataset>.json files live.
STORAGE_PATH = os.getenv("STORAGE_PATH")
if not STORAGE_PATH:
    # Fail fast: otherwise the result paths silently become "None/evaluation/..."
    # and the first open() fails with a confusing FileNotFoundError.
    parser.error("the STORAGE_PATH environment variable must be set")

# Load API configuration (API_KEY, API_BASE) from tokens.json in the parent
# directory of this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
tokens_path = os.path.join(script_dir, "..", "tokens.json")
with open(tokens_path, 'r', encoding="utf-8") as f:
    tokens = json.load(f)

# Create one OpenAI client at module level and share it across all requests
# (so its connection pool is reused).
client = OpenAI(
    api_key=tokens["API_KEY"],
    base_url=tokens["API_BASE"],
)

# Maximum number of worker threads used for parallel re-checking.
MAX_WORKERS = 20


def process_example(idx, answer, response, max_retries=5, timeout=120):
    """
    Ask the judge model (gpt-4o) whether a generated answer matches the ground truth.

    Args:
        idx: Sample index, passed through unchanged so that callers running this
            in parallel can map each result back to its sample.
        answer: Ground-truth answer.
        response: Answer produced by the evaluated model.
        max_retries: Maximum number of API attempts (default 5).
        timeout: Per-request timeout in seconds (default 120).

    Returns:
        An ``(idx, gpt_check_result)`` tuple, where ``gpt_check_result`` is the
        judge's raw reply text. ``"No"`` is returned when every attempt fails or
        the reply carries no text content.
    """
    messages = [
        {"role": "system", "content": "You are a math answer checker."},
        # The model's output is the answer under test; `answer` is the ground truth.
        {"role": "user", "content": f"Hi, there is an answer: {response}\n\n, and the ground truth answer is: {answer}\n\n, please check whether the answer is correct or not, and return the **only** Yes or No."},
    ]

    for attempt in range(max_retries):
        try:
            response_obj = client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.1,
                timeout=timeout,
            )
            content = response_obj.choices[0].message.content
            # `content` may be None (e.g. an empty or refused reply); callers call
            # .lower() on the result, so fall back to a definite "No".
            return (idx, content if content is not None else "No")
        except Exception as e:
            # Broad catch is deliberate: any API/network failure is retried and,
            # ultimately, treated as "not correct" rather than aborting the run.
            print(f"[idx={idx}] API request failed (attempt {attempt + 1}/{max_retries}): {type(e).__name__}: {e}")
            if attempt < max_retries - 1:
                # Exponential backoff: wait 2^attempt seconds before retrying.
                wait_time = 2 ** attempt
                print(f"[idx={idx}] Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"[idx={idx}] All {max_retries} attempts failed, returning 'No'")
                return (idx, "No")
    # Only reached when max_retries < 1 (no attempt was made).
    return (idx, "No")

new_results = []
for model_name in [args.model_name]:
    # Result files live in a directory named after the model with "/" replaced by "_".
    model_dir = model_name.replace("/", "_")
    for dataset in [
        "math",
        "gsm8k",
        "amc",
        "minerva",
        "olympiad",
        "aime2024",
        "aime2025",
    ]:
        with open(f'{STORAGE_PATH}/evaluation/{model_dir}/results_{dataset}.json', 'r', encoding="utf-8") as f:
            results = json.load(f)

        # The last entry of each results file is excluded from scoring
        # (presumably a summary record — TODO confirm against the writer).
        # The slice is a new list but holds the same dicts, so score updates
        # below are applied to `results` as well.
        samples = results[:-1]
        if not samples:
            # Avoid a ZeroDivisionError that would abort the remaining datasets.
            print(f"[{dataset}] No samples to score, skipping")
            continue

        # Collect samples whose score is below 0.5 so the LLM judge can re-check them.
        to_recheck = [
            (i, sample['answer'], sample['response'])
            for i, sample in enumerate(samples)
            if sample['score'] < 0.5
        ]

        print(f"[{dataset}] Total samples: {len(samples)}, need recheck: {len(to_recheck)}")

        # Re-check in parallel; each future yields (idx, judge_reply).
        if to_recheck:
            with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                futures = [
                    executor.submit(process_example, idx, answer, response)
                    for idx, answer, response in to_recheck
                ]

                for future in tqdm(as_completed(futures), total=len(futures), desc=f"Rechecking {dataset}"):
                    idx, gpt_check = future.result()
                    # Guard against a None reply before lower-casing it.
                    if gpt_check and "yes" in gpt_check.lower():
                        samples[idx]['score'] = 1

        # Mean score over the scored samples, as a percentage rounded to 2 places.
        record = {
            'model': model_name,
            'dataset': dataset,
            'score': round(sum(sample['score'] for sample in samples) / len(samples) * 100, 2),
        }
        new_results.append(record)
        print(new_results)
        with open('final_results.jsonl', 'a', encoding="utf-8") as f:
            json.dump(record, f)
            f.write('\n')





