import os
import json
import torch
import numpy as np

import asyncio

import openai
from openai import OpenAI, AsyncOpenAI

import time
from tqdm import tqdm

import argparse
# Command-line options for the feedback-collection script.
parser = argparse.ArgumentParser(
    description="Query an LLM for feedback on the prompts in <file_path>/reflect.json.")
parser.add_argument('--model_name', type=str, default="/home/extra_scratch/wentsec/Meta-Llama-3.1-70B-Instruct", help='name or local path of the model')
parser.add_argument('--node_id', type=str, default="26", help='id of the gpu node, used as data prefix')
parser.add_argument('--file_path', type=str, default="", help='directory holding reflect.json and agent.json; feedback.json is written here')
parser.add_argument('--batch_size', type=int, default=256, help='number of prompts that must receive well-formed feedback')
parser.add_argument('--seed', type=int, default=42, help='random seed')
args = parser.parse_args()

# Seed the torch and numpy global RNGs from --seed so repeated runs match.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(args.seed)

# API client configuration.
#
# Credentials come from the environment, never from source: a key committed
# to a file is exposed to everyone who can read the repository.
# NOTE(review): the key that used to be hard-coded here has been leaked and
# must be revoked.
#
# To target a local OpenAI-compatible server instead of the hosted API, set
# OPENAI_BASE_URL (e.g. http://localhost:8000/v1) and OPENAI_API_KEY=EMPTY.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise SystemExit("OPENAI_API_KEY is not set; export it before running this script.")
# base_url=None lets the client fall back to its default endpoint.
client = AsyncOpenAI(api_key=api_key, base_url=os.environ.get("OPENAI_BASE_URL"))
model = "gpt-4.1-mini"
max_tokens = 1024

async def async_gpt(msg, temperature):
    """Send a single chat-completion request and return its reply text.

    Args:
        msg: chat messages, passed through as ``messages``.
        temperature: sampling temperature for this request.

    Returns:
        ``message.content`` of the first choice. The OpenAI SDK types this
        as optional, so callers may see ``None``.
    """
    request = {
        "model": model,
        "messages": msg,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    completion = await client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content
async def async_batch_gpt(msgs_list):
    """Issue one request per item concurrently; replies keep input order.

    Each item must provide its chat messages under the ``"prompt"`` key.
    Every request uses temperature 0.0. The first exception raised by any
    request propagates to the caller.
    """
    coroutines = [async_gpt(entry["prompt"], 0.0) for entry in msgs_list]
    return await asyncio.gather(*coroutines)
def batch_gpt(msgs_list):
    """Blocking wrapper: run ``async_batch_gpt`` in a fresh event loop."""
    return asyncio.run(async_batch_gpt(msgs_list))


# Wait until the endpoint answers a trivial request.
#
# Transient failures (e.g. connection refused while a local server boots,
# timeouts, rate limits) are retried indefinitely and printed with their
# cause. Errors that waiting cannot fix re-raise immediately instead of
# spinning forever: a rejected key, missing permissions, an unknown model,
# or a malformed request.
test_msg = {
    "idx": 0,
    "type": "feedback",
    "prompt": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Reply to this message with Hello."}
    ]
}
while True:
    try:
        result = batch_gpt([test_msg])
        break
    except (
        openai.AuthenticationError,
        openai.PermissionDeniedError,
        openai.NotFoundError,
        openai.BadRequestError,
    ):
        raise
    except Exception as e:
        time.sleep(.25)
        print(f"Waiting for the server to start... ({type(e).__name__}: {e})")

# Load the prompts from <file_path>/reflect.json. Entry i is later sent
# unchanged as the chat `messages` of the i-th feedback request.
file_name = os.path.join(args.file_path, "reflect.json")
with open(file_name) as f:
    feedback_prompts = json.loads(f.read())

# Requests waiting to be sent as the next mini-batch.
gpt_cache = []

# Hyper-parameters.
max_retry_time = 3       # API attempts per mini-batch before giving up
mini_batch_size = 32     # prompts sent concurrently per batch_gpt call
num_data = len(feedback_prompts)

# One slot per prompt: the parsed feedback text (None until obtained) and a
# flag recording whether a well-formed reply has been received.
feedback_responses = [None] * num_data
checked = [False] * num_data

# Collect feedback until `target` prompts have a well-formed reply, i.e. a
# reply containing the marker "Feedback: ". Unanswered prompts are queried
# in repeated passes over the data set, in mini-batches of
# `mini_batch_size`. This guarantees that:
#   * no prompt is queued twice in the same mini-batch;
#   * the loop ends when batch_size exceeds the number of prompts, because
#     the target is capped at num_data;
#   * the loop ends when some prompts never yield the expected format,
#     because each prompt is tried at most `max_attempts_per_item` times.
max_attempts_per_item = 5
target = min(args.batch_size, num_data)
attempts = [0] * num_data

if args.batch_size > num_data:
    print(f"Only {num_data} prompts available; collecting feedback for all of them.")


def _query_with_retry(batch):
    """Return ``batch_gpt(batch)``, trying at most ``max_retry_time`` times.

    Any exception from the API counts as a failed attempt and is followed by
    a 1 s pause. Once every attempt has failed, raises RuntimeError chained
    to the last error. This is an explicit raise rather than ``assert`` so
    the check survives ``python -O``.
    """
    last_error = None
    for _ in range(max_retry_time):
        try:
            return batch_gpt(batch)
        except Exception as e:
            last_error = e
            time.sleep(1.)
    raise RuntimeError("Failed to get GPT response") from last_error


while sum(checked) < target:
    pending = [i for i in range(num_data)
               if not checked[i] and attempts[i] < max_attempts_per_item]
    if not pending:
        print(f"Giving up: only {sum(checked)}/{target} prompts produced "
              f"well-formed feedback after {max_attempts_per_item} attempts each.")
        break
    for start in range(0, len(pending), mini_batch_size):
        if sum(checked) >= target:
            break
        gpt_cache = [{"idx": i, "prompt": feedback_prompts[i]}
                     for i in pending[start:start + mini_batch_size]]
        responses = _query_with_retry(gpt_cache)
        for item, text in zip(gpt_cache, responses):
            data_idx = item["idx"]
            attempts[data_idx] += 1
            # The SDK types message content as optional; treat a None reply
            # like any other malformed reply.
            if text is None or "Feedback: " not in text:
                continue
            # Keep everything after the first marker; split(...)[1] would
            # silently drop text following a second occurrence.
            feedback_responses[data_idx] = text.partition("Feedback: ")[2].strip()
            checked[data_idx] = True
        gpt_cache = []
        
# Resolve agent.json under file_path with os.path.join, as reflect.json is.
# Plain concatenation turned the default file_path="" into "/agent.json",
# i.e. the filesystem root.
# NOTE(review): agent_prompt is not used in the visible part of the script.
file_name = os.path.join(args.file_path, "agent.json")
with open(file_name, "r") as f:
    agent_prompt = json.load(f)

def deepcopy_list(list_msgs):
    """Return a new list of new message dicts holding only role and content.

    Each output dict is freshly created, so mutating it does not touch the
    input. Keys other than ``"role"`` and ``"content"`` are dropped. The
    values themselves are shared, not copied.
    """
    return [{"role": message["role"], "content": message["content"]}
            for message in list_msgs]

# Save one feedback entry per prompt (None where no well-formed reply was
# obtained) to <file_path>/feedback.json.
# Mode "w" rather than "a": appending to a file left by an earlier run would
# concatenate two JSON documents, which json.load cannot parse.
# os.path.join keeps the default file_path="" in the working directory
# instead of the filesystem root.
file_name = os.path.join(args.file_path, "feedback.json")
with open(file_name, "w") as f:
    json.dump(feedback_responses, f)

