import json
import os
import time
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor
from evaluate import run_evaluation
import argparse

# Connection settings for each supported API provider, keyed by the
# identifier accepted by --api_provider. The key is read from the
# environment and falls back to an empty string when the variable is unset.
API_CONFIGS = {
    "77": dict(
        api_key=os.getenv("OPENAI_API_KEY_77", ""),
        base_url="https://api.key77qiqi.com/v1",
    ),
}

# Short aliases accepted by --model_name, mapped to the full model
# identifiers that are sent to the API.
MODEL_MAP = {
    "4.1mini": "gpt-4.1-mini-2025-04-14",
    "4omini": "gpt-4o-mini-2024-07-18",
}

def parse_args():
    """Parse the command-line options for an evaluation run.

    Returns:
        argparse.Namespace with the attributes ``src_file``, ``model_name``,
        ``api_provider``, ``split``, ``subset_num``, ``temperature``,
        ``top_p``, ``max_tokens`` and ``max_workers``.

    Raises:
        SystemExit: raised by argparse when ``--src_file`` is missing or an
            option has an invalid value.
    """
    parser = argparse.ArgumentParser(description="Evaluate model on Math dataset using API.")
    # The data file is mandatory: without it the run has nothing to load, so
    # fail here with a usage message instead of crashing later on ``None``.
    parser.add_argument('--src_file', type=str, required=True, help="data file path")
    parser.add_argument('--model_name', type=str, default="4omini", help="API model name (e.g., 4omini, 4.1mini)")
    parser.add_argument('--api_provider', type=str, default="77", choices=["77"], help="API provider")
    parser.add_argument('--split', type=str, default='test', choices=['test', 'diamond', 'main', 'extended'], help="Dataset split to use.")
    parser.add_argument('--subset_num', type=int, default=-1, help="Number of examples to process. Defaults to all.")
    parser.add_argument('--temperature', type=float, default=0.3, help="Sampling temperature.")
    parser.add_argument('--top_p', type=float, default=0.7, help="Top-p sampling parameter.")
    parser.add_argument('--max_tokens', type=int, default=2048, help="Maximum number of tokens to generate.")
    parser.add_argument('--max_workers', type=int, default=128, help="Maximum number of parallel API calls.")
    return parser.parse_args()

def get_task_instruction_math(question):
    """Build the user prompt that asks the model to solve one math question.

    The prompt consists of a step-by-step instruction, the required
    ``\\boxed{...}`` answer format, and the question text itself.
    """
    instruction = 'Please answer the following math question. You should think step by step to solve it.\n\n'
    answer_format = 'Provide your final answer in the format \\boxed{YOUR_ANSWER}.\n\n'
    return f'{instruction}{answer_format}Question:\n{question}\n\n'

def load_json_or_jsonl(data_path):
    """Load a ``.json`` or ``.jsonl`` file and return its records as a list.

    For ``.jsonl`` every non-blank line is parsed as one JSON value; for
    ``.json`` the top-level value must be a list. The extension check is
    case-insensitive.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
        ValueError: if the extension is neither ``.json`` nor ``.jsonl``, the
            content cannot be parsed, or a ``.json`` file is not a list.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"文件未找到: {data_path}")

    suffix = os.path.splitext(data_path)[-1].lower()

    with open(data_path, mode='r', encoding='utf-8') as handle:
        if suffix == '.jsonl':
            records = []
            # Line numbers count blank lines too, so errors point at the
            # real position in the file.
            for lineno, raw in enumerate(handle, start=1):
                text = raw.strip()
                if text:
                    try:
                        records.append(json.loads(text))
                    except json.JSONDecodeError as e:
                        raise ValueError(f"第 {lineno} 行 JSON 解析失败: {e}")
            return records

        if suffix == '.json':
            try:
                payload = json.load(handle)
            except json.JSONDecodeError as e:
                raise ValueError(f"JSON 文件解析失败: {e}")
            if not isinstance(payload, list):
                raise ValueError("JSON 文件必须是列表格式（即以 `[` 开头的 JSON 文件）")
            return payload

        raise ValueError("文件扩展名必须是 .json 或 .jsonl")

def generate_api_single(client, model, prompt, max_tokens, temperature, top_p, stop=None):
    """Generate one completion for ``prompt`` through the chat-completions API.

    Args:
        client: Client object exposing ``chat.completions.create``.
        model: Model identifier sent to the API.
        prompt: Text sent as the single user message.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Top-p sampling parameter.
        stop: Optional stop sequence(s) forwarded to the API.

    Returns:
        The generated text. An empty string is returned when the call fails
        or when the response carries no text content, so the result is
        always a ``str``.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
        )
        # The API may return a message whose content is None; normalise it
        # to "" so downstream code never receives a non-string answer.
        return response.choices[0].message.content or ""
    except Exception as e:
        # Best-effort: a failed request is reported and yields an empty
        # answer instead of raising.
        print(f"API 调用失败: {e}")
        return ""

def generate_api(client, model, prompts, max_tokens, temperature, top_p, stop, max_workers=128):
    """Generate completions for a batch of prompts with parallel API calls.

    Each prompt is sent through ``generate_api_single`` on a thread pool of
    at most ``max_workers`` threads. The returned list has one entry per
    prompt, in the same order as ``prompts``.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # ``map`` submits every prompt up front and yields results in
        # submission order, regardless of which call finishes first.
        completions = list(
            pool.map(
                lambda text: generate_api_single(
                    client, model, text, max_tokens, temperature, top_p, stop
                ),
                prompts,
            )
        )
    return completions

def _resolve_output_dir(src_file, model_short_name):
    """Pick the output directory for a run from markers in the data file path."""
    # Checked in order; the first marker contained in the path wins.
    dataset_markers = (
        ("math500", "math500"),
        ("amc23.jsonl", "amc23"),
        ("amc.jsonl", "amc"),
        ("minervamath", "minervamath"),
        ("olympiad", "olympiad"),
    )
    for marker, dataset in dataset_markers:
        if marker in src_file:
            return f'../outputs/{dataset}/{model_short_name}.vanilla'
    return f'../outputs/other/{model_short_name}.vanilla'


def main():
    """Run one evaluation: load questions, query the API, evaluate the answers.

    Steps: parse the command line, validate the model alias, create the
    OpenAI client for the chosen provider, load the data file (optionally
    truncated to ``--subset_num`` items), build one prompt per item from its
    ``'question'`` field, generate all answers in parallel, and pass the
    results to ``run_evaluation``.

    Raises:
        SystemExit: if ``--src_file`` is missing or ``--model_name`` is not a
            key of ``MODEL_MAP``.
    """
    args = parse_args()

    if not args.src_file:
        raise SystemExit("--src_file is required")

    # Resolve the model alias before any work is done, so a typo fails
    # immediately with a readable message rather than after loading data.
    model_name = args.model_name
    if model_name not in MODEL_MAP:
        raise SystemExit(
            f"Unknown --model_name {model_name!r}; expected one of: {', '.join(MODEL_MAP)}"
        )
    model_id = MODEL_MAP[model_name]

    # API configuration
    api_provider = args.api_provider
    api_config = API_CONFIGS[api_provider]
    if not api_config["api_key"]:
        # Failed calls are recorded as empty answers, so make the likely
        # cause visible up front.
        print(f"Warning: empty API key for provider {api_provider!r}; API calls are likely to fail.")

    # Initialize OpenAI client
    client = OpenAI(
        api_key=api_config["api_key"],
        base_url=api_config["base_url"]
    )

    # Data path and output path
    data_path = args.src_file
    output_dir = _resolve_output_dir(data_path, model_name)
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    data = load_json_or_jsonl(data_path)

    if args.subset_num != -1:
        data = data[:args.subset_num]

    # Prepare input prompts
    input_list = [get_task_instruction_math(item['question']) for item in data]

    # Generate answers using API; only the generation phase is timed.
    t_start = time.time()
    print(f"开始生成答案，共 {len(input_list)} 个问题...")

    output_contents = generate_api(
        client=client,
        model=model_id,
        prompts=input_list,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        stop=None,
        max_workers=args.max_workers
    )

    total_time = time.time() - t_start
    print(f"生成完成，用时 {total_time:.2f} 秒")

    # Wrap each answer in an object exposing ``.prompt`` and
    # ``.outputs[0].text`` — the output format the original code expects.
    class APIOutput:
        def __init__(self, prompt, output):
            self.prompt = prompt
            self.outputs = [type('obj', (object,), {'text': output})]

    output_list = [APIOutput(prompt, output) for prompt, output in zip(input_list, output_contents)]

    # Run evaluation
    # NOTE(review): the last argument is hard-coded to 'test' and --split is
    # never used; presumably this should be args.split — confirm against
    # run_evaluation's signature before changing.
    run_evaluation(
        data,
        input_list,
        output_list,
        output_dir,
        total_time,
        'test',
    )

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()