# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:2 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/gen_bt.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, extract_responses
from inference_time_alignment.decoder import BeamTuningPosthocGenerationMixin
from scripts.alpaca_leaderboard.src.utils import get_chat_prompt_template, get_scorer
from utils import load_instruction_dataset, get_local_model_name


@dataclass
class BeamTuningGenConfig:
    """Generation settings for the beam-tuning sampler and its scorer.

    The sampling and search fields are forwarded as keyword arguments to
    ``bon_beam_sample``; the scorer fields are forwarded to ``get_scorer``.
    """

    # Forwarded to `bon_beam_sample` as `max_new_tokens`.
    max_new_tokens: Optional[int] = field(default=2048)
    # Forwarded to `bon_beam_sample` as `temperature`.
    temperature: Optional[float] = field(default=0.6)
    # Forwarded to `bon_beam_sample` as `top_p`.
    top_p: Optional[float] = field(default=0.9)
    # Forwarded to `bon_beam_sample` as `num_beams`
    # (presumably the number of beams kept per step -- confirm in the decoder).
    num_beams: Optional[int] = field(default=2)
    # Forwarded to `bon_beam_sample` as `num_candidates`
    # (presumably candidates sampled per beam -- confirm in the decoder).
    num_candidates: Optional[int] = field(default=4)
    # Forwarded to `bon_beam_sample` as `block_len`; per the call-site note,
    # an infinite block length enables sequence-wise best-of-n.
    block_len: Optional[int] = field(default=30)
    # Forwarded to `get_scorer` as `beta`.
    beta: Optional[float] = field(default=1)
    # Forwarded to `get_scorer` as `average_log_prob`.
    average_log_prob: Optional[bool] = field(default=False)
    # Forwarded to `get_scorer` as `reference_free`.
    reference_free: Optional[bool] = field(default=False)


@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line arguments for this generation script (parsed by tyro)."""

    # Generator model; resolved through `get_local_model_name` before loading.
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    # Name handed to `get_scorer` to build the scorer.
    scorer_name: str = "zephyr-7b-beta"
    # Instruction dataset handed to `load_instruction_dataset`.
    dataset_name: str = "tatsu-lab/alpaca_eval"
    # Where the generated records are written as JSON lines.
    output_path: Text = "tmp/alpaca_leaderboard/gen_bt.jsonl"
    # When False and `output_path` already exists, the script exits early.
    overwrite: Optional[bool] = False
    # rank / world_size / sanity_check are not read directly by this script;
    # they reach `load_instruction_dataset` through the whole args object
    # (presumably for sharding and a reduced debug run -- confirm there).
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    # Seed passed to `set_seeds`.
    seed: Optional[int] = 1
    sanity_check: Optional[bool] = False
    # Applied to both the generator model and the scorer.
    load_in_4bit: Optional[bool] = False
    # Nested sampler/scorer settings; a fresh instance is built per parse.
    generation_configs: BeamTuningGenConfig = field(default_factory=BeamTuningGenConfig)


# Parse the CLI into ScriptArguments and seed the RNGs via the project helper.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# init datasets
# NOTE(review): the loader receives the whole `script_args` object, so the
# output-exists check is intentionally kept *after* it in case the loader
# adjusts `script_args.output_path` (e.g. per rank) -- confirm in utils.
dataset = load_instruction_dataset(script_args.dataset_name, script_args)
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    # Results already exist and --overwrite was not requested: stop with a
    # zero exit status. `SystemExit` is raised directly because the bare
    # `exit()` helper is injected by `site` and is not guaranteed to exist.
    raise SystemExit(0)

# Make sure the destination directory exists *before* the expensive
# generation loop, so a bad output path fails fast instead of after the run.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Load the generator model and its tokenizer from the locally resolved path.
local_model_name = get_local_model_name(script_args.model_name)
base = AutoModelForCausalLM.from_pretrained(
    local_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # NOTE(review): newer transformers releases prefer
    # `attn_implementation="flash_attention_2"`; kept as-is for compatibility
    # with the version this project pins -- confirm before changing.
    use_flash_attention_2=True,
    load_in_4bit=script_args.load_in_4bit,
)
tokenizer = AutoTokenizer.from_pretrained(local_model_name)
prompt_template = get_chat_prompt_template(script_args.model_name, tokenizer)

# Wrap the base model with the beam-tuning sampler and build the scorer.
gen_cfg = script_args.generation_configs
bt_model = BeamTuningPosthocGenerationMixin(base, tokenizer)
scorer = get_scorer(
    scorer_name=script_args.scorer_name,
    load_in_4bit=script_args.load_in_4bit,
    beta=gen_cfg.beta,
    average_log_prob=gen_cfg.average_log_prob,
    reference_free=gen_cfg.reference_free,
)

# sample
raw_prompts = dataset["raw_prompt"]
# The generator tag embeds the full argument dump; it is the same for every
# record, so build it once instead of once per prompt.
generator_tag = f"{script_args.model_name}({str(script_args)})"

results = []
# `zip` has no length, so pass `total` explicitly to get a real progress bar.
for raw_prompt, ds_id in tqdm.tqdm(
    zip(raw_prompts, dataset["dataset"]),
    total=len(raw_prompts),
):
    prompt = prompt_template.format(raw_prompt=raw_prompt)
    # The chat template already contains any special tokens it needs.
    prompt_tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    outputs = bt_model.bon_beam_sample(
        input_ids=prompt_tokenized["input_ids"].cuda(),
        attention_mask=prompt_tokenized["attention_mask"].cuda(),
        scorer=scorer.set_raw_prompt(raw_prompt),
        split_by_prompt_text=False,
        # sampling
        max_new_tokens=gen_cfg.max_new_tokens,
        temperature=gen_cfg.temperature,
        top_p=gen_cfg.top_p,
        # search
        num_beams=gen_cfg.num_beams,
        num_candidates=gen_cfg.num_candidates,
        block_len=gen_cfg.block_len,  # block_len = inf to enable sequence wise bon
    )
    # Drop the prompt tokens and keep the first returned response.
    response = extract_responses(
        outputs,
        tokenizer,
        prompt_len=prompt_tokenized["input_ids"].size(1),
    )[0]
    results.append({
        "instruction": raw_prompt,
        "output": response,
        "generator": generator_tag,
        "dataset": ds_id,
        "datasplit": "eval",
    })

# Persist all generated records as JSON lines at the requested path.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
