import argparse
import json
import os
import re
import sys
import datasets
from openai import OpenAI
from dotenv import load_dotenv
# from vllm import LLM, SamplingParams
from tqdm import tqdm
# import torch

# Prompt template with a `{Question}` placeholder for an operations-research
# question.
# NOTE(review): main() builds its prompts from its own local template and does
# not read this constant — confirm which template is intended.
TEMPLATE_q2mc_en = r"""
Below is an operations research question. Build a mathematical model and corresponding python code using `coptpy` that appropriately addresses the question.

# Question:
{Question}

# Response:
"""

# Populate os.environ from a .env file so the getenv() calls below can see
# OPENAI_API_KEY / OPENAI_API_BASE defined there.
load_dotenv()

# Connection settings for the OpenAI-compatible endpoint. Either value is None
# when the corresponding environment variable is unset.
openai_api_data = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_API_BASE"),
}

def _load_samples(dataset_name, dataset_split):
    """Load one split of the local dataset ``./dataset/<dataset_name>``.

    Each returned dict is a copy of a dataset row whose ``"prompt"`` key is
    (re)set to the generation prompt built from the row's ``"en_question"``.

    Raises:
        FileNotFoundError: if the local dataset directory does not exist.
        ValueError: if a row has no ``"en_question"`` field.
    """
    # NOTE(review): this differs from the module-level TEMPLATE_q2mc_en, which
    # is not used anywhere visible — confirm which template is intended.
    prompt_template = "Please generate a coptpy code for the following question.\nQuestion: {Question}\nCoptpy Code:"
    local_dataset_dir = "./dataset"
    local_path = os.path.join(local_dataset_dir, dataset_name)
    if not os.path.exists(local_path):
        raise FileNotFoundError(f"⚠️ 警告: 目录 '{local_path}' 不存在，无法加载 '{dataset_name}'。")
    print(f"加载: {local_path}...")
    ds = datasets.load_dataset(local_path, trust_remote_code=True)[dataset_split]

    samples = []
    for example in ds:
        if "en_question" not in example:
            raise ValueError("dataset example is missing the 'en_question' field")
        # str.replace rather than str.format so braces inside the question
        # text are left untouched.
        prompt = prompt_template.replace("{Question}", example["en_question"].strip()).strip()
        # Drop any existing "prompt" key first so the new one is appended last.
        record = {k: v for k, v in example.items() if k != "prompt"}
        record["prompt"] = prompt
        samples.append(record)
    print(f"load dataset from `{dataset_name}` done. sample size: {len(ds)}")
    return samples


def _init_client():
    """Create an OpenAI client from ``openai_api_data``; return None on failure."""
    try:
        client = OpenAI(
            api_key=openai_api_data["api_key"],
            # An empty OPENAI_API_BASE is treated the same as an unset one.
            base_url=openai_api_data["base_url"] or None,
        )
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        print("Please make sure your OPENAI_API_KEY is set correctly as an environment variable or in a .env file.")
        return None
    # No request is sent here, so a bad key or endpoint only shows up on the
    # first completion call.
    print("OpenAI client initialized.")
    return client


def _unique_completions(choices):
    """Return ``(texts, num_dup, num_empty)`` for the choices of one response.

    ``texts`` holds the stripped completion texts in order, without duplicates
    and without empty/missing contents (``message.content`` may be None).
    """
    texts = []
    seen = set()
    num_dup = 0
    num_empty = 0
    for choice in choices:
        content = choice.message.content
        text = content.strip() if content else ""
        if not text:
            num_empty += 1
        elif text in seen:
            num_dup += 1
        else:
            seen.add(text)
            texts.append(text)
    return texts, num_dup, num_empty


def main(args):
    """Generate coptpy code for every example of a local dataset split.

    Args:
        args: parsed CLI namespace (see ``parse_args``). Reads
            ``dataset_name``, ``dataset_split``, ``openai_model``,
            ``save_dir`` and ``verbose``.

    Writes ``<save_dir>/generated.jsonl`` with one JSON object per distinct,
    non-empty completion. Each object contains the example's fields plus
    ``q2mc_en_prompt`` and ``en_math_model_coptpy_code``. A prompt whose
    request or post-processing fails is reported on stderr and skipped.
    Returns early, before creating the output file, if the OpenAI client
    cannot be created.

    Raises:
        ValueError: if a required option is missing.
        FileNotFoundError: if the local dataset directory does not exist.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for name in ("dataset_name", "dataset_split", "openai_model", "save_dir"):
        if getattr(args, name) is None:
            raise ValueError(f"--{name} is required")

    os.makedirs(args.save_dir, exist_ok=True)
    samples = _load_samples(args.dataset_name, args.dataset_split)

    client = _init_client()
    if client is None:
        return

    save_file = os.path.join(args.save_dir, "generated.jsonl")
    num_total = 0
    num_skip_for_dup = 0
    num_skip_for_empty = 0

    # `with` guarantees the output is flushed and closed even if the run is
    # interrupted or an unexpected exception escapes the loop.
    with open(save_file, "w", encoding="utf-8") as fw:
        for example in tqdm(samples, desc="Generating with OpenAI"):
            prompt = example["prompt"]
            try:
                # One request per prompt (the API is called sequentially).
                response = client.chat.completions.create(
                    model=args.openai_model,
                    messages=[{"role": "user", "content": prompt}],
                )
                outputs, num_dup, num_empty = _unique_completions(response.choices)
                num_total += len(response.choices)
                num_skip_for_dup += num_dup
                num_skip_for_empty += num_empty

                for output in outputs:
                    record = dict(example)
                    record["q2mc_en_prompt"] = prompt
                    record["en_math_model_coptpy_code"] = output
                    if args.verbose:
                        print("-" * 20 + "prompt" + "-" * 20)
                        print(prompt)
                        print("-" * 20 + "completion" + "-" * 20)
                        print(output)
                        print("-" * 80)
                    fw.write(json.dumps(record, ensure_ascii=False) + "\n")
            except Exception as e:
                # Best effort: report and continue so one failed request does
                # not abort the whole run. tqdm.write keeps the bar intact.
                tqdm.write(f"An error occurred while processing a prompt: {e}", file=sys.stderr)

    print(f"Generation complete. Results saved to {save_file}")
    print(f"num_total: {num_total}; num_skip_for_dup: {num_skip_for_dup}")
    print(f"num_skip_for_empty: {num_skip_for_empty}")

def parse_args():
    """Build the CLI parser and return the parsed options namespace."""
    parser = argparse.ArgumentParser()
    # All string options default to None; main() checks that they were given.
    string_options = ("--openai_model", "--dataset_name", "--dataset_split", "--save_dir")
    for option in string_options:
        parser.add_argument(option, type=str, default=None)
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: read CLI options, then run generation.
    main(parse_args())