# OPENAI_API_KEY="..." PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/sh/openai/gpt-3.5-turbo-instruct/gen_bt.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field, asdict
from typing import Optional, Text

import tyro
import tqdm
import torch
from datasets import Dataset 

import time
from openai import OpenAI, RateLimitError, APIConnectionError
client = OpenAI()

from inference_time_alignment.utils import set_seeds
from scripts.alpaca_leaderboard.src.utils import get_scorer
from utils import load_instruction_dataset


@dataclass
class BeamTuningGenConfig:
    """Hyper-parameters for block-wise beam search and its scorer."""

    # Token budget: a beam is marked finished once the API-reported token
    # usage of its request reaches this value.
    max_new_tokens: Optional[int] = field(default=2048)
    # Sampling temperature.
    # NOTE(review): not forwarded to the completions request in the code visible here.
    temperature: Optional[float] = field(default=0.6)
    # Nucleus-sampling threshold.
    # NOTE(review): not forwarded to the completions request in the code visible here.
    top_p: Optional[float] = field(default=0.9)

    # Number of beams kept after every scoring round.
    num_beams: Optional[int] = field(default=2)
    # Number of continuations requested for each unfinished beam per round.
    num_candidates: Optional[int] = field(default=2)
    # Maximum number of tokens generated per continuation (one "block").
    block_len: Optional[int] = field(default=100)

    # The three options below are passed unchanged to get_scorer().
    beta: Optional[float] = field(default=1)
    average_log_prob: Optional[bool] = field(default=False)
    reference_free: Optional[bool] = field(default=False)


@dataclass(kw_only=True)
class ScriptArguments:
    model_name: str = field(default="openai/gpt-3.5-turbo-instruct")
    scorer_name: str = field(default="zephyr-7b-beta")
    dataset_name: str = field(default="tatsu-lab/alpaca_eval")
    output_path: Text = field(default="tmp/alpaca_leaderboard/gen_gpt_3.5_turbo_instruct_bt.jsonl")
    overwrite: Optional[bool] = field(default=False)
    rank: Optional[int] = field(default=0)
    world_size: Optional[int] = field(default=1)
    seed: Optional[int] = field(default=1)
    sanity_check: Optional[bool] = field(default=False)
    wait_time: Optional[float] = field(default=1.0)
    load_in_4bit: Optional[bool] = field(default=False)
    generation_configs: BeamTuningGenConfig = field(default_factory=lambda: BeamTuningGenConfig())

# Parse the command line and seed the RNGs before anything stochastic runs.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)
# Keep only the last path component,
# e.g. "openai/gpt-3.5-turbo-instruct" -> "gpt-3.5-turbo-instruct".
script_args.model_name = script_args.model_name.split('/')[-1]

# Stop before loading the dataset or the scorer when the output already
# exists and overwriting was not requested. SystemExit is raised directly
# because the bare exit() helper is injected by the `site` module for
# interactive use and is not guaranteed to exist.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(
        f"{script_args.output_path} already exists; "
        "pass --overwrite=True to regenerate.",
        flush=True,
    )
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)

# Scorer used to rank candidate beams during generation.
scorer = get_scorer(
    scorer_name=script_args.scorer_name,
    load_in_4bit=script_args.load_in_4bit,
    beta=script_args.generation_configs.beta,
    average_log_prob=script_args.generation_configs.average_log_prob,
    reference_free=script_args.generation_configs.reference_free,
)

def cbs_gpt_completion_response(model="gpt-3.5-turbo-instruct", raw_prompt=None, num_beams=2, num_candidates=2, block_len=100, seed=0, scorer=None, max_new_tokens=None):
    """Generate one response by block-wise beam search over OpenAI completions.

    Every round extends each unfinished beam with ``num_candidates``
    continuations of at most ``block_len`` tokens, scores all candidates
    (together with the finished beams carried over unchanged) with
    ``scorer`` and keeps the ``num_beams`` highest-scoring ones.  The search
    ends once every kept beam is finished.

    Args:
        model: Model name sent to the completions endpoint.
        raw_prompt: Instruction text.  The request prompt is the instruction
            followed by a period, a blank line and the beam generated so far.
        num_beams: Number of beams kept after each round (must be >= 1).
        num_candidates: Continuations requested per unfinished beam
            (must be >= 1).
        block_len: ``max_tokens`` of every completion request.
        seed: Base seed; the beam at position ``i`` is requested with
            ``seed + i``.
        scorer: Callable taking ``{"response": [...], "eos": [...]}`` and
            returning scores that ``torch.topk`` can rank along dim 0, one
            per candidate.  Required.
        max_new_tokens: Token budget.  When the ``usage.total_tokens`` of a
            request reaches it, all candidates of that request are marked
            finished.  ``None`` falls back to
            ``script_args.generation_configs.max_new_tokens``; if that is
            ``None`` as well, no budget is applied.

    Returns:
        The text of the highest-scoring beam.

    Raises:
        ValueError: If ``scorer`` is missing, or ``num_beams`` or
            ``num_candidates`` is smaller than 1.

    Errors raised by the OpenAI client are not caught here.
    """
    if scorer is None:
        raise ValueError("scorer is required")
    if num_beams < 1 or num_candidates < 1:
        raise ValueError("num_beams and num_candidates must be >= 1")
    if max_new_tokens is None:
        # Backward-compatible default: the budget configured on the CLI.
        max_new_tokens = script_args.generation_configs.max_new_tokens

    beams = [""] * num_beams
    finisheds = [False] * num_beams

    # num_beams >= 1, so the loop body runs at least once.
    while not all(finisheds):

        new_beams = []
        new_finisheds = []
        for i, (beam, finished) in enumerate(zip(beams, finisheds)):
            if finished:
                # Finished beams compete unchanged in the next scoring round.
                new_beams.append(beam)
                new_finisheds.append(finished)
                continue
            response = client.completions.create(
                model=model,
                prompt=f"{raw_prompt}.\n\n{beam}",
                n=num_candidates,
                max_tokens=block_len,
                seed=seed+i,
            )
            # The usage figure is per request, so it is identical for every
            # choice of this response.
            # NOTE(review): total_tokens is the request-wide figure, not a
            # per-beam count of generated tokens -- confirm this is intended.
            over_budget = (
                max_new_tokens is not None
                and response.usage.total_tokens >= max_new_tokens
            )
            for choice in response.choices:
                new_beams.append(beam + choice.text)
                new_finisheds.append(choice.finish_reason in ("stop", "content_filter") or over_budget)

        beam_scores = scorer(
            {
                "response": new_beams,
                "eos": new_finisheds,
            },
        )

        # There may be fewer candidates than beams, hence the min().
        _, beam_idx = torch.topk(
            beam_scores,
            min(num_beams, len(beam_scores)), dim=0, largest=True, sorted=True
        )

        beams = [new_beams[idx] for idx in beam_idx]
        finisheds = [new_finisheds[idx] for idx in beam_idx]

    # topk returns sorted results, so the first beam is the best one.
    return beams[0]

# sample
# Create the output directory up front so that a bad path fails before any
# API call is made instead of after the whole run.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

raw_prompts = dataset["raw_prompt"]
ds_ids = dataset["dataset"]
results = []
# zip() has no length, so the total is passed explicitly for the progress bar.
for raw_prompt, ds_id in tqdm.tqdm(zip(raw_prompts, ds_ids), total=min(len(raw_prompts), len(ds_ids))):
    print(len(results), flush=True)
    # Retry without limit on rate-limit / connection errors; every retry
    # restarts the beam search for this prompt from scratch.
    while True:
        try:
            response = cbs_gpt_completion_response(
                model=script_args.model_name,
                raw_prompt=raw_prompt,
                num_beams=script_args.generation_configs.num_beams,
                num_candidates=script_args.generation_configs.num_candidates,
                block_len=script_args.generation_configs.block_len, # block_len = inf to enable sequence wise bon
                seed=script_args.seed,
                scorer=scorer.set_raw_prompt(raw_prompt),
            )
        except (RateLimitError, APIConnectionError) as e:
            print(e, flush=True)
            time.sleep(script_args.wait_time)
        else:
            break
    results.append({
        "instruction": raw_prompt,
        "output": response,
        "generator": script_args.model_name,
        "dataset": ds_id,
        "datasplit": "eval"
    })

# Write all results at once after every prompt has been processed.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
