import gym
import json
import torch
import numpy as np

import asyncio

import openai
from openai import OpenAI, AsyncOpenAI

import time
from tqdm import tqdm

import argparse
# Command-line options for this script. Only --model_name and --seed are
# consumed in this part of the file; the remaining options are presumably
# used further down (TODO confirm).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_name', type=str,
    default="/home/extra_scratch/wentsec/Meta-Llama-3.1-70B-Instruct",
    help='name of the model to request from the OpenAI-compatible server',
)
parser.add_argument('--node_id', type=str, default="26", help='id of the gpu node, used as data prefix')
parser.add_argument('--file_name', type=str, default="", help='name of the saved file')
parser.add_argument('--rank', type=int, default=1, help='rank of the process')
parser.add_argument('--file_path', type=str, default="", help='where to save the trajectory')
parser.add_argument('--batch_size', type=int, default=256, help='batch size for each iteration')
parser.add_argument('--max_rank', type=int, default=4, help='maximum number of ranks')
parser.add_argument('--seed', type=int, default=42, help='random seed')
args = parser.parse_args()

# Seed the numpy and torch RNGs from --seed so runs are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Connection settings for the OpenAI-compatible endpoint on localhost.
# The key is the placeholder "EMPTY"; presumably the local server does not
# validate it (NOTE(review): confirm).
api_base = "http://localhost:8000/v1"
api_key = "EMPTY"
client = AsyncOpenAI(base_url=api_base, api_key=api_key)
model = args.model_name  # model identifier sent with every request
max_tokens = 8 * 1024    # per-request completion cap (8192)

async def async_gpt(msg, temperature):
    """Issue a single chat-completion request and return the reply text.

    Args:
        msg: chat message list, forwarded unchanged as ``messages``.
        temperature: sampling temperature for this request.

    Returns:
        The ``content`` of the first choice's message.
        NOTE(review): the SDK types this as optional, so it may be None.
    """
    request_kwargs = dict(
        model=model,
        messages=msg,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    response = await client.chat.completions.create(**request_kwargs)
    first_choice = response.choices[0]
    return first_choice.message.content


async def async_batch_gpt(msgs_list, temperature=0.0):
    """Run one chat completion per request concurrently.

    Args:
        msgs_list: iterable of request dicts; each must hold its chat message
            list under the ``"prompt"`` key.
        temperature: sampling temperature applied to every request. Defaults
            to 0.0, the value previously hard-coded here.

    Returns:
        List of reply strings in the same order as ``msgs_list``
        (``asyncio.gather`` preserves input order).
    """
    return await asyncio.gather(
        *(async_gpt(msgs["prompt"], temperature) for msgs in msgs_list)
    )


def batch_gpt(msgs_list):
    """Blocking wrapper around ``async_batch_gpt``.

    Args:
        msgs_list: request dicts, each carrying its messages under "prompt".

    Returns:
        The list of reply strings, in request order.

    NOTE(review): every call starts and closes a new event loop via
    ``asyncio.run``, while ``async_gpt`` reuses the module-level ``client``.
    Async HTTP clients can keep pooled connections bound to a closed loop, so
    confirm that repeated calls do not fail with "Event loop is closed".
    """
    return asyncio.run(async_batch_gpt(msgs_list))


# Block until the server answers a trivial request. Any exception is treated
# as "not ready yet": it is printed, then the request is retried after a short
# pause. There is no upper bound on the wait.
test_msg = dict(
    idx=0,
    type="feedback",
    prompt=[
        dict(role="system", content="You are a helpful assistant."),
        dict(role="user", content="Reply to this message with Hello."),
    ],
)
while True:
    try:
        result = batch_gpt([test_msg])
    except Exception as err:
        print("error:", err)
        time.sleep(.25)
        print("Waiting for the server to start...")
    else:
        break
