# OPENAI_API_KEY="..." PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:0 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/sh/openai/gpt-3.5-turbo-instruct/gen_base.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
import time
from dataclasses import dataclass, field
from typing import Optional, Text

import tqdm
import tyro
from datasets import Dataset
from openai import APIConnectionError, InternalServerError, OpenAI, RateLimitError
# Shared OpenAI client used for every completion request in this script; the
# SDK reads credentials such as OPENAI_API_KEY from the environment (see the
# example command at the top of this file).
client = OpenAI()

from inference_time_alignment.utils import set_seeds
from utils import load_instruction_dataset



@dataclass(kw_only=True)
class ScriptArguments:
    model_name: str = field(default="openai/gpt-3.5-turbo-instruct")
    dataset_name: str = field(default="tatsu-lab/alpaca_eval")
    output_path: Text = field(default="tmp/alpaca_leaderboard/gen_gpt_3.5_turbo_instruct_base.jsonl")
    overwrite: Optional[bool] = field(default=False)
    rank: Optional[int] = field(default=0)
    world_size: Optional[int] = field(default=1)
    seed: Optional[int] = field(default=1)
    sanity_check: Optional[bool] = field(default=False)
    wait_time: Optional[float] = field(default=1.0)

# Parse CLI flags (see ScriptArguments) and seed local RNGs.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)
# Drop any "provider/" prefix so the bare model id is sent to the API and
# recorded as the generator (e.g. "openai/gpt-3.5-turbo-instruct" ->
# "gpt-3.5-turbo-instruct").
script_args.model_name = script_args.model_name.split('/')[-1]

# init datasets
# NOTE(review): script_args is handed to the loader, which may read or adjust
# fields such as rank/world_size/sanity_check/output_path; the existence check
# therefore stays after this call -- confirm before reordering.
dataset = load_instruction_dataset(script_args.dataset_name, script_args)
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    # Skip regenerating (and re-paying for) existing results unless asked to.
    # `exit()` is a site-module convenience for the interactive shell and is
    # unavailable under `python -S`; raising SystemExit is the program-safe
    # equivalent (exit status 0).
    print(
        f"{script_args.output_path} already exists; "
        "pass --overwrite=True to regenerate.",
        flush=True,
    )
    raise SystemExit(0)


# sample
# Upper bound (seconds) on the exponential backoff between retries; a larger
# --wait_time is still honoured as the cap.
_MAX_WAIT_TIME = 60.0


def _complete_with_retry(prompt: str) -> str:
    """Return the text of one completion for `prompt`, retrying transient errors.

    Rate-limit errors, connection errors (including timeouts) and 5xx server
    errors are retried indefinitely. The first retry waits
    `script_args.wait_time` seconds and each later wait doubles, capped at
    max(`script_args.wait_time`, `_MAX_WAIT_TIME`). Any other exception
    propagates to the caller.
    """
    delay = script_args.wait_time
    while True:
        try:
            response = client.completions.create(
                model=script_args.model_name,
                prompt=prompt,
                seed=script_args.seed,
                max_tokens=2048,
            )
        except (RateLimitError, APIConnectionError, InternalServerError) as e:
            print(f"{type(e).__name__}: {e} -- retrying in {delay:.1f}s", flush=True)
            time.sleep(delay)
            delay = min(delay * 2, max(script_args.wait_time, _MAX_WAIT_TIME))
        else:
            return response.choices[0].text


# Create the output directory before any API calls are spent: results are only
# written at the very end, and to_json may not create missing parent dirs.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Pull both columns out once so their length can be given to tqdm as `total`
# (a bare zip has no len, so the bar would otherwise show no progress/ETA).
raw_prompts = dataset["raw_prompt"]
ds_ids = dataset["dataset"]

results = []
for raw_prompt, ds_id in tqdm.tqdm(zip(raw_prompts, ds_ids), total=len(raw_prompts)):
    # NOTE(review): "." is appended unconditionally, so prompts that already
    # end in punctuation become e.g. "...?." -- kept byte-for-byte so outputs
    # stay comparable with earlier runs; confirm this prompt format is intended.
    output = _complete_with_retry(f"{raw_prompt}.\n\n")
    results.append({
        "instruction": raw_prompt,
        "output": output,
        "generator": script_args.model_name,
        "dataset": ds_id,
        "datasplit": "eval",
    })

# Dataset.to_json writes JSON Lines (one record per line) by default.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
