# Licensed under the MIT license.

import os
from openai import OpenAI
import time
from tqdm import tqdm
import concurrent.futures


# Shared client used by every request in this module. It is created at import
# time and points the OpenAI SDK at the DashScope OpenAI-compatible endpoint.
client = OpenAI(
    # The key is read from the DASHSCOPE_API_KEY environment variable;
    # os.getenv returns None when the variable is unset.
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

# Module-level worker-count constant. NOTE(review): nothing in the visible part
# of this file reads it -- generate_n_with_OpenAI_model has its own
# ``max_threads`` parameter defaulting to 3, which shadows this name inside
# that function. Presumably imported by callers; confirm before changing.
max_threads = 32


def load_OpenAI_model(model):
    """Return a ``(None, model)`` pair without loading anything.

    The ``model`` argument is passed through unchanged as the second element
    and the first element is always ``None``.

    NOTE(review): presumably this mirrors a loader interface that returns two
    objects (the first having no equivalent for a hosted API) -- confirm
    against the callers.
    """
    placeholder = None
    return (placeholder, model)


def generate_with_OpenAI_model(
    prompt,
    model_ckpt="gpt-35-turbo",
    max_tokens=256,
    temperature=0.8,
    top_k=40,
    top_p=0.95,
    stop=["\n"],
):
    """Request one chat completion for ``prompt`` and return its text.

    The prompt is sent as a single ``user`` message through the module-level
    ``client``. The call is retried until a non-empty answer comes back, so
    this function never returns an empty string and never gives up on its
    own: a permanently failing request loops forever, printing each error.

    Args:
        prompt: Text placed in the single user message.
        model_ckpt: Model name forwarded as ``model``.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        top_k: Accepted for signature compatibility only; it is NOT sent with
            the request.
        top_p: Forwarded as ``top_p``.
        stop: Stop sequence(s) forwarded as ``stop``. The default list is
            shared between calls; it is only read here, never mutated.

    Returns:
        The ``message.content`` of the first choice of the first response
        whose content is non-empty.

    Notes:
        * Every request uses the fixed ``seed`` 1.
          NOTE(review): with a fixed seed, repeated calls with the same prompt
          may yield very similar samples on backends that honour ``seed`` --
          confirm this is intended.
        * The function sleeps before *every* attempt, including the first
          (5 seconds). After a failed or empty attempt the delay doubles, and
          once it would exceed 120 seconds it restarts at 1 second.
    """
    messages = [
        # {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    parameters = {
        "model": model_ckpt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "stop": stop,
        "seed": 1,
    }

    ans, timeout = "", 5
    while not ans:
        try:
            # The delay doubles as crude throttling and as the retry back-off.
            time.sleep(timeout)
            completion = client.chat.completions.create(messages=messages, **parameters)
            # An empty or None content keeps ``ans`` falsy and triggers a retry.
            ans = completion.choices[0].message.content
        except Exception as e:
            # Best-effort: report any API/transport error and retry.
            print(e)
        if not ans:
            timeout = timeout * 2
            if timeout > 120:
                timeout = 1
            try:
                print(f"Will retry after {timeout} seconds ...")
            except Exception:
                # Logging must never break the retry loop, but unlike a bare
                # ``except:`` this lets KeyboardInterrupt/SystemExit through.
                pass
    return ans


def generate_n_with_OpenAI_model(
    prompt,
    n=1,
    model_ckpt="gpt-35-turbo",
    max_tokens=256,
    temperature=0.8,
    top_k=40,
    top_p=0.95,
    stop=["\n"],
    max_threads=3,
    disable_tqdm=True,
):
    """Collect ``n`` completions for the same ``prompt`` concurrently.

    Submits ``n`` independent calls to ``generate_with_OpenAI_model`` on a
    pool of at most ``max_threads`` worker threads and gathers the answers.

    Threads are used rather than processes because each task only waits on a
    network request. Worker threads share the module-level ``client``, so
    nothing is pickled and no child process inherits or re-creates it.

    Args:
        prompt: Text sent as the user message of every request.
        n: Number of completions to request; ``n <= 0`` yields ``[]``.
        model_ckpt, max_tokens, temperature, top_k, top_p, stop: Forwarded
            positionally to ``generate_with_OpenAI_model``.
        max_threads: Maximum number of concurrent workers.
        disable_tqdm: When true (the default) the progress bar is hidden.

    Returns:
        A list of ``n`` answers in completion order, which is not necessarily
        submission order.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor:
        futures = [
            executor.submit(generate_with_OpenAI_model, prompt, model_ckpt, max_tokens, temperature, top_k, top_p, stop)
            for _ in range(n)
        ]
        # as_completed yields each future as soon as it finishes, so the
        # progress bar advances per finished request.
        preds = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="running evaluate",
                disable=disable_tqdm,
            )
        ]
    return preds

if __name__ == "__main__":
    # Manual smoke test: one single-sample call, then one call through the
    # parallel helper with its default n. Both hit the live endpoint.
    demo_prompt = "你好"
    demo_model = "qwen2.5-7b-instruct"

    single_reply = generate_with_OpenAI_model(demo_prompt, model_ckpt=demo_model)
    print(single_reply)

    sampled_replies = generate_n_with_OpenAI_model(demo_prompt, model_ckpt=demo_model)
    print(sampled_replies)