from openai import AzureOpenAI
import pandas as pd
import openai
import numpy as np
import os
import time
from tqdm import tqdm
import click
from concurrent.futures import ThreadPoolExecutor, as_completed
try:
    from volcenginesdkarkruntime import Ark
except ImportError:
    Ark = None
import json
import requests

# Load environment variables from a local .env file when python-dotenv is
# available. This is best-effort: the script keeps running without it and the
# API keys are then expected to be present in the process environment already.
try:
    from dotenv import load_dotenv
except ImportError:
    # Only a missing package should produce the "install it" hint.
    print("Warning: python-dotenv not installed. Install with: pip install python-dotenv")
else:
    try:
        load_dotenv()  # read .env
    except Exception as exc:
        # Keep the original best-effort behavior, but report the real cause
        # instead of claiming the package is missing.
        print(f"Warning: failed to load .env file: {exc}")

def get_api_keys_from_env(prefix):
    """Return the list of API keys found in the environment for ``prefix``.

    The first key is read from ``{prefix}_API_KEY``; additional keys are read
    from ``{prefix}_API_KEY_2``, ``{prefix}_API_KEY_3``, and so on. Scanning
    stops at the first variable that is not set, so a gap in the numbering
    hides every key after it. An empty list is returned when the un-numbered
    variable is missing.
    """
    base_name = f"{prefix}_API_KEY"
    found = []
    value = os.getenv(base_name)
    while value is not None:
        found.append(value)
        # The next variable is numbered one past the count collected so far.
        value = os.getenv(f"{base_name}_{len(found) + 1}")
    return found

# Per-model settings, keyed by the short model alias that is passed to
# LLMClient(model=...) and to the -m/--model CLI option.
#
# Every entry carries the same fields:
#   api_keys    - list of keys read from the environment at import time via
#                 get_api_keys_from_env(<PREFIX>); one client is built per key.
#   api_version - forwarded to AzureOpenAI(api_version=...) for models that
#                 are routed through the AzureOpenAI client.
#   base_url    - endpoint given to the client. For models listed in
#                 websearch_models the API key is appended directly to this
#                 string, which is why those URLs end in "?ak=".
#   model_name  - value sent as the "model" field of each request.
#   max_tokens  - forwarded as max_tokens on chat.completions.create calls
#                 (the raw HTTP request used for websearch models omits it).
#   max_workers - thread-pool size used by main()/batch_cook() when the caller
#                 passes max_workers=0.
config = {
    "o3": {
        "api_keys": get_api_keys_from_env("O3"),
        "api_version": "2024-03-01-preview",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "o3-2025-04-16",
        "max_tokens": 16384,
        "max_workers": 4,
    },
    # NOTE(review): reads the "O3" key prefix, same as the "o3" entry above -
    # confirm the keys are intentionally shared.
    "o4-mini": {
        "api_keys": get_api_keys_from_env("O3"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "o4-mini-2025-04-16",
        "max_tokens": 16384,
        "max_workers": 64,
    },
    "gpt4o": {
        "api_keys": get_api_keys_from_env("GPT4O"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-4o-2024-11-20",
        "max_tokens": 8192,
        "max_workers": 64,
    },
    # NOTE(review): api_version holds the same string as model_name here,
    # unlike every other entry - looks like a copy/paste slip; confirm.
    "gpt4.1": {
        "api_keys": get_api_keys_from_env("GPT4_1"),
        "api_version": "gpt-4.1-2025-04-14",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-4.1-2025-04-14",
        "max_tokens": 16384,
        "max_workers": 64,
    },
    # NOTE(review): reads the "GPT4O" key prefix, same as the "gpt4o" entry -
    # confirm the keys are intentionally shared.
    "qwen2.5-7b": {
        "api_keys": get_api_keys_from_env("GPT4O"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "openai_qwen2.5-7b-instruct",
        "max_tokens": 8192,
        "max_workers": 16,
    },

    "gpt5": {
        "api_keys": get_api_keys_from_env("GPT5"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-5-2025-08-07",
        "max_tokens": 32768,
        "max_workers": 64,
    },
    "gpt5-mini": {
        "api_keys": get_api_keys_from_env("GPT5_MINI"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-5-mini-2025-08-07",
        "max_tokens": 2048,
        "max_workers": 64,
    },
    "gpt5-nano": {
        "api_keys": get_api_keys_from_env("GPT5_NANO"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-5-nano-2025-08-07",
        "max_tokens": 2048,
        "max_workers": 64,
    },
    "claude": {
        "api_keys": get_api_keys_from_env("CLAUDE"),
        "api_version": "2024-03-01-preview",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "aws_claude35_sdk_sonnet_v2",
        "max_tokens": 8192,
        "max_workers": 64,
    },
    "claude-4-sonnet": {
        "api_keys": get_api_keys_from_env("CLAUDE_4_SONNET"),
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gcp-claude4-sonnet",
        "api_version": "",
        "max_tokens": 16384,
        "max_workers": 64,
    },
    "claude-4-opus": {
        "api_keys": get_api_keys_from_env("CLAUDE_4_OPUS"),
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gcp-claude4-opus",
        "api_version": "",
        "max_tokens": 16384,
        "max_workers": 64,
    },
    "claude-4.1-opus": {
        "api_keys": get_api_keys_from_env("CLAUDE_4_1_OPUS"),
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gcp-claude4.1-opus",
        "api_version": "",
        "max_tokens": 16384,
        "max_workers": 64,
    },

    "gemini": {
        "api_keys": get_api_keys_from_env("GEMINI"),
        "api_version": "2024-03-01-preview",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gemini-1.5-pro-002",
        "max_tokens": 4096,
        "max_workers": 8,
    },
    # The two entries below are listed in websearch_models: LLMClient builds
    # the request URL as base_url + api_key, hence the trailing "?ak=".
    "gemini-2.0": {
        "api_keys": get_api_keys_from_env("GEMINI_2_0"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT?ak=",
        "model_name": "gemini-2.0-flash-001",
        "max_tokens": 8000,
        "max_workers": 8,
    },
    "gemini-2.5-pro": {
        "api_keys": get_api_keys_from_env("GEMINI_2_5_PRO"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT?ak=",
        "model_name": "gemini-2.5-pro-preview-05-06",
        "max_tokens": 8000, 
        "max_workers": 8, 
    },
    "doubao": {
        "api_keys": get_api_keys_from_env("DOUBAO"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "ep-20241224213346-nbxh7",
        "max_tokens": 32768,
        "max_workers": 16,
    },
    "doubao-1.5": {
        "api_keys": get_api_keys_from_env("DOUBAO_1_5"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "ep-20250205170847-6gd9s",
        "max_tokens": 32768,
        "max_workers": 16,
    },
    "r1": {
        "api_keys": get_api_keys_from_env("R1"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "ep-20250207134207-8zx79",
        "max_tokens": 65536,
        "max_workers": 64,
    },
    "deepseekv3-0324": {
        "api_keys": get_api_keys_from_env("DEEPSEEKV3_0324"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "ep-20250328185550-g5zcd",
        "max_tokens": 32768,
        "max_workers": 8,
    }, 
    # NOTE(review): this entry points at the Ark endpoint but its alias is not
    # in fangzhou_models, so LLMClient builds an AzureOpenAI client for it -
    # confirm which side is wrong.
    "openai_qwen3-4b": {
        "api_keys": get_api_keys_from_env("OPENAI_QWEN"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "openai_qwen3-4b",
        "max_tokens": 32768,
        "max_workers": 8,
    },
    "gpt-oss-120b": {
        "api_keys": get_api_keys_from_env("GPT_OSS"),
        "api_version": "",
        "base_url": "https://YOUR_OPENAI_COMPAT_ENDPOINT",
        "model_name": "gpt-oss-120b",
        "max_tokens": 32768,
        "max_workers": 64,
    },
    # NOTE(review): Ark endpoint and an "ep-..." model id like the other Ark
    # entries, but the alias is not in fangzhou_models, so LLMClient builds an
    # AzureOpenAI client for it - confirm.
    "deepseek-3.1": {
        "api_keys": get_api_keys_from_env("DEEPSEEK_3_1"),
        "api_version": "",
        "base_url": "https://YOUR_ARK_ENDPOINT",
        "model_name": "ep-20250826141533-cf9bk",
        "max_tokens": 32768,
        "max_workers": 64,
    },
    


}
# Models called through the Ark ("Fangzhou") SDK client; every model that is
# in neither this list nor websearch_models goes through the AzureOpenAI client.
# NOTE(review): "deepseek-3.1" and "openai_qwen3-4b" are configured with the
# Ark endpoint but are not listed here - confirm whether they should be.
fangzhou_models = [
    "doubao",
    "doubao-1.5",
    "r1",
    "deepseekv3-0324",
]

# Models whose responses carry a separate reasoning_content field; chat()
# returns a dict instead of a plain string for these.
reasoning_models = [
    "r1",
]

# Models that support web search; they are called over raw HTTP with the
# google_search tool enabled.
websearch_models = [
    "gemini-2.0",
    "gemini-2.5-pro",
]

class LLMClient:
    """Chat client that routes a request to the backend configured for a model.

    One underlying client is created per configured API key and a random one
    is picked for every request. The backend depends on the model alias:

    * aliases in ``fangzhou_models`` use the Ark SDK client,
    * aliases in ``websearch_models`` are called with a raw ``requests.post``
      against ``base_url + api_key`` with the ``google_search`` tool enabled,
    * every other alias uses ``AzureOpenAI``.
    """

    def __init__(self, model="gpt4o"):
        """Build one client per API key configured for ``model``.

        Args:
            model: Key into the module-level ``config`` dict.

        Raises:
            KeyError: If ``model`` is not a key of ``config``.
            ValueError: If no API key is configured for ``model``.
            ImportError: If ``model`` needs the Ark SDK but the optional
                ``volcenginesdkarkruntime`` package could not be imported.
        """
        self.conf = config[model]
        self.model = model
        api_keys = self.conf["api_keys"]
        # Log the settings without echoing the secrets themselves.
        print({
            name: (f"<{len(api_keys)} key(s)>" if name == "api_keys" else value)
            for name, value in self.conf.items()
        })
        if not api_keys:
            # Fail fast: without a key every request would only burn through
            # the retry loop and end up as an empty output.
            raise ValueError(
                f"No API keys configured for model '{model}'; "
                "set the corresponding *_API_KEY environment variable(s)"
            )
        if model in fangzhou_models:
            if Ark is None:
                raise ImportError(
                    f"Model '{model}' requires the 'volcenginesdkarkruntime' "
                    "package, which is not installed"
                )
            self.clients = [
                Ark(
                    api_key=api_key,
                    base_url=self.conf["base_url"],
                )
                for api_key in api_keys
            ]
        elif model in websearch_models:
            # No SDK client here: just remember the full URL (key included).
            self.clients = [{
                "url": self.conf["base_url"] + api_key,
            } for api_key in api_keys]
        else:
            self.clients = [
                AzureOpenAI(
                    api_key=api_key,
                    api_version=self.conf["api_version"],
                    azure_endpoint=self.conf["base_url"],
                )
                for api_key in api_keys
            ]

    def call_openai_api_with_retry(self, messages, max_retries=10, delay=1, timeout=600):
        """Send ``messages`` to the model, retrying on any error.

        Args:
            messages: List of chat message dicts (``role`` / ``content``).
            max_retries: Maximum number of attempts.
            delay: Base number of seconds slept between attempts; a random
                0-2 second jitter is added on top.
            timeout: Seconds passed as ``timeout`` to ``requests.post``. Only
                applies to the raw HTTP call used for ``websearch_models``.

        Returns:
            The ``requests`` response object for web-search models, otherwise
            whatever the SDK's ``chat.completions.create`` returns.

        Raises:
            RuntimeError: When every attempt failed; the last underlying error
                is attached as ``__cause__``.
        """
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                # Pick by random index rather than np.random.choice(list):
                # choice() first converts the list into a NumPy array, which
                # is unnecessary work for arbitrary client objects.
                client = self.clients[np.random.randint(len(self.clients))]
                if self.model in websearch_models:
                    response = requests.post(
                        url=client["url"],
                        json={
                            "model": self.conf["model_name"],
                            "messages": messages,
                            "tools": [{"type": "google_search"}]
                        },
                        headers={"Content-Type": "application/json"},
                        timeout=timeout,
                    )
                    if response.status_code >= 400:
                        # Treat HTTP errors as retryable failures. The message
                        # is built by hand so the key-bearing URL is not logged.
                        raise RuntimeError(
                            f"HTTP {response.status_code} from web-search endpoint: "
                            f"{response.text[:200]}"
                        )
                else:
                    response = client.chat.completions.create(
                        model=self.conf["model_name"], messages=messages, max_tokens=self.conf["max_tokens"])
                return response
            except openai.OpenAIError as e:
                last_error = e
                print(f"OpenAI error occurred: {e}")
            except Exception as e:
                last_error = e
                print(f"An error occurred: {e}")
            if attempt == max_retries:
                # No point sleeping when there is no further attempt.
                break
            # Jittered back-off: delay + [0, 2) seconds (1-3s with the default).
            _delay = delay + np.random.uniform(0, 2)
            print(f"Retrying... ({attempt}/{max_retries}), sleep: {_delay:.2f}s")
            time.sleep(_delay)
        raise RuntimeError("API call failed after maximum retries") from last_error

    def chat(self, system_prompt, prompt):
        """Run a single-turn chat and return the model's text.

        Args:
            system_prompt: Optional system message; skipped when empty. Values
                without a length (e.g. ``None`` or NaN from a missing table
                cell) are treated as "no system prompt".
            prompt: The user message.

        Returns:
            For models in ``reasoning_models`` a dict with the keys
            ``"reasoning_content"`` and ``"content"``; for all other models
            the stripped answer string. Errors are printed and the empty
            equivalent (``""`` or a dict of empty strings) is returned.
        """
        messages = []
        try:
            has_system_prompt = len(system_prompt) > 0
        except TypeError:
            has_system_prompt = False
        if has_system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = self.call_openai_api_with_retry(messages)
            if self.model in websearch_models:
                return response.json()["choices"][0]["message"]["content"].strip()
            message = response.choices[0].message
            # content can be None; normalise to "" instead of failing on .strip().
            content = (message.content or "").strip()
            if self.model in reasoning_models:
                # Keep the answer even when the reasoning trace is absent.
                reasoning = (getattr(message, "reasoning_content", None) or "").strip()
                return {"reasoning_content": reasoning, "content": content}
            return content
        except Exception as e:
            print(f"Final error: {e}")
            if self.model in reasoning_models:
                return {"reasoning_content": "", "content": ""}
            else:
                return ""

def process_item(item, annotator, n_outputs=1):
    """Generate ``n_outputs`` completions for one record and attach them.

    The prompt is taken from ``item["cook_prompt"]`` when that key exists,
    otherwise from ``item["prompt"]``; the system prompt comes from
    ``item["system_prompt"]`` and defaults to ``""``. The record is modified
    in place: ``item["output"]`` is set to the single completion when exactly
    one was produced, otherwise to the list of completions. The same ``item``
    object is returned.
    """
    prompt_key = "cook_prompt" if "cook_prompt" in item else "prompt"
    completions = [
        annotator.chat(item.get("system_prompt", ""), item[prompt_key])
        for _ in range(n_outputs)
    ]
    item["output"] = completions[0] if len(completions) == 1 else completions
    return item

def main(input_path, output_path, n_outputs=1, n_sample=-1, max_workers=0, model="gpt4o"):
    """Annotate every record of a parquet file and write the results back out.

    Args:
        input_path: Parquet file to read the records from.
        output_path: Parquet file the annotated records are written to.
        n_outputs: Number of completions requested per record.
        n_sample: When positive and smaller than the record count, only the
            first ``n_sample`` records are processed.
        max_workers: Thread-pool size; ``0`` means "use the model's
            ``max_workers`` from ``config``".
        model: Key into ``config`` selecting the model.

    Results are collected in completion order, not input order.
    """
    frame = pd.read_parquet(input_path)
    records = frame.to_dict("records")
    if 0 < n_sample < len(records):
        records = records[:n_sample]
    annotator = LLMClient(model=model)
    workers = max_workers if max_workers != 0 else config[model]["max_workers"]
    print(f"total: {len(records)}({len(frame)}), max_workers: {workers}")
    finished = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [
            pool.submit(process_item, record, annotator, n_outputs)
            for record in records
        ]
        for done in tqdm(as_completed(pending), total=len(pending)):
            outcome = done.result()
            if not outcome:
                continue
            finished.append(outcome)
            count = len(finished)
            # Echo every 10th output among the first 99 as a sanity check.
            if count < 100 and count % 10 == 0:
                print(outcome["output"])
            # Checkpoint the partial results every 10000 records.
            if count % 10000 == 0:
                pd.DataFrame(finished).to_parquet(output_path)
    # Final write with everything that completed.
    pd.DataFrame(finished).to_parquet(output_path)

def batch_cook(df_in, n_outputs=1, n_sample=-1, max_workers=0, model="gpt4o"):
    """Annotate the rows of an in-memory DataFrame and return them as dicts.

    Args:
        df_in: DataFrame whose rows are converted to record dicts.
        n_outputs: Number of completions requested per record.
        n_sample: When positive and smaller than the record count, only the
            first ``n_sample`` records are processed.
        max_workers: Thread-pool size; ``0`` means "use the model's
            ``max_workers`` from ``config``".
        model: Key into ``config`` selecting the model.

    Returns:
        A list of record dicts (each with an added ``"output"`` entry), in
        completion order rather than input order.
    """
    records = df_in.to_dict("records")
    if 0 < n_sample < len(records):
        records = records[:n_sample]
    annotator = LLMClient(model=model)
    workers = max_workers if max_workers != 0 else config[model]["max_workers"]
    print(f"total: {len(records)}({len(df_in)}), max_workers: {workers}")
    finished = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [
            pool.submit(process_item, record, annotator, n_outputs)
            for record in records
        ]
        for done in tqdm(as_completed(pending), total=len(pending)):
            outcome = done.result()
            if outcome:
                finished.append(outcome)
    return finished




@click.command()
@click.option("-i", "--input-file")
@click.option("-o", "--output-file")
@click.option("-n", "--n_outputs", default=1, show_default=True)
@click.option("-s", "--n_sample", default=-1, show_default=True)
@click.option("-w", "--max_workers", default=0, show_default=True)
@click.option("-m", "--model", default="gpt4o", show_default=True)
def process(input_file, output_file, n_outputs, n_sample, max_workers, model):
    # CLI entry point: forwards the parsed options to main() and reports
    # completion. (Kept as a comment rather than a docstring so the generated
    # --help text stays unchanged.)
    main(
        input_path=input_file,
        output_path=output_file,
        n_outputs=n_outputs,
        n_sample=n_sample,
        max_workers=max_workers,
        model=model,
    )
    click.echo("Processing complete!")

if __name__ == "__main__":
    # Run the click command. The previous body referenced the undefined names
    # `input_file` and `out_file` (NameError on the first line) and never
    # invoked the CLI; input/output paths and the model are now taken from
    # the -i / -o / -m options instead.
    process()
