# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/gen_base.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field, asdict
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, extract_responses
from scripts.alpaca_leaderboard.src.utils import get_chat_prompt_template
from utils import load_instruction_dataset, get_local_model_name


@dataclass
class GenConfig:
    """Decoding options for sampling responses.

    The script converts an instance to a dict with ``asdict`` and expands it
    as keyword arguments of the model's ``generate`` call, so each field name
    must be a keyword that ``generate`` accepts. Every field is ``Optional``,
    so ``None`` can be supplied in place of the default.
    """

    max_new_tokens: Optional[int] = field(default=2048)
    do_sample: Optional[bool] = field(default=True)
    temperature: Optional[float] = field(default=0.6)
    top_p: Optional[float] = field(default=0.9)
    num_beams: Optional[int] = field(default=1)

@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line options of the generation script (parsed with ``tyro.cli``).

    Attributes:
        model_name: Model identifier; used to load the model and tokenizer and
            to pick the chat prompt template.
        dataset_name: Name handed to ``load_instruction_dataset``.
        output_path: File the generated records are written to as JSON.
        overwrite: When false and ``output_path`` already exists, the script
            exits without generating anything.
        rank, world_size, sanity_check: Not read directly by the script; the
            whole argument object is passed to ``load_instruction_dataset``,
            which presumably uses them (sharding / small debug run) -- confirm
            against that helper.
        seed: Value passed to ``set_seeds``.
        load_in_4bit: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        generation_configs: Decoding options forwarded to ``generate``.
    """

    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    dataset_name: str = "tatsu-lab/alpaca_eval"
    output_path: str = "tmp/alpaca_leaderboard/gen_base.jsonl"
    overwrite: Optional[bool] = False
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    seed: Optional[int] = 1
    sanity_check: Optional[bool] = False
    load_in_4bit: Optional[bool] = False
    generation_configs: GenConfig = field(default_factory=GenConfig)

# Parse the command line into ScriptArguments and seed before any sampling.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    # Keep existing results unless --overwrite is set. SystemExit is raised
    # directly because the bare `exit()` helper is injected by the `site`
    # module and is not guaranteed to be available in every interpreter setup.
    print(f"{script_args.output_path} already exists; pass --overwrite=True to regenerate.")
    raise SystemExit(0)

# Make sure the output directory exists before the expensive generation run,
# rather than relying on the writer to create it at the very end.
os.makedirs(os.path.dirname(script_args.output_path) or ".", exist_ok=True)

local_model_name = get_local_model_name(script_args.model_name)
# NOTE(review): newer transformers releases deprecate `use_flash_attention_2`
# in favour of `attn_implementation="flash_attention_2"`; left unchanged here
# to stay compatible with the installed version.
base = AutoModelForCausalLM.from_pretrained(
    local_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_flash_attention_2=True,
    load_in_4bit=script_args.load_in_4bit,
)
tokenizer = AutoTokenizer.from_pretrained(local_model_name)

# Stop on the regular EOS token and, when the vocabulary has it, on the
# end-of-turn marker "<|eot_id|>". For a tokenizer without that token the
# lookup can yield None or the unknown-token id, neither of which should be
# treated as a stop token.
terminators = [tokenizer.eos_token_id]
eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
if (
    eot_token_id is not None
    and eot_token_id != tokenizer.unk_token_id
    and eot_token_id not in terminators
):
    terminators.append(eot_token_id)
prompt_template = get_chat_prompt_template(script_args.model_name, tokenizer)

# Loop invariants: generation kwargs and the "generator" label are the same
# for every record, so build them once.
generation_kwargs = asdict(script_args.generation_configs)
generator_label = f"{script_args.model_name}({script_args})"

# sample
raw_prompts = dataset["raw_prompt"]
dataset_ids = dataset["dataset"]
results = []
# `total` is given explicitly because tqdm cannot infer a length from zip().
for raw_prompt, ds_id in tqdm.tqdm(zip(raw_prompts, dataset_ids), total=len(raw_prompts)):
    prompt = prompt_template.format(raw_prompt=raw_prompt)
    # The template output is tokenized as-is; no extra special tokens are added.
    prompt_tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    outputs = base.generate(
        input_ids=prompt_tokenized["input_ids"].cuda(),
        attention_mask=prompt_tokenized["attention_mask"].cuda(),
        eos_token_id=terminators,
        **generation_kwargs,
    )
    # prompt_len is the number of prompt tokens; one prompt per call, so take
    # the first (only) extracted response.
    prompt_len = prompt_tokenized["input_ids"].size(1)
    response = extract_responses(outputs, tokenizer, prompt_len=prompt_len)[0]
    results.append({
        "instruction": raw_prompt,
        "output": response,
        "generator": generator_label,
        "dataset": ds_id,
        "datasplit": "eval"
    })

# Write all records in one go.
generations = Dataset.from_list(results)
generations.to_json(script_args.output_path)
