# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/gen_eft.py --overwrite=True --sanity_check=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds
from inference_time_alignment.decoder import EFTPosthocGenerationMixin
from inference_time_alignment.model import PrefixPreTrainedWrapper
from scripts.alpaca_leaderboard.src.utils import get_chat_prompt_template, get_scorer
from utils import load_instruction_dataset, get_local_model_name

@dataclass
class EFTGenConfig:
    """Generation settings that the script forwards to `eft_model.generate` and the scorer."""

    # Forwarded as `max_new_tokens` to `eft_model.generate`.
    max_new_tokens: Optional[int] = field(default=2048)
    # Forwarded as `do_sample` to `eft_model.generate`.
    do_sample: Optional[bool] = field(default=True)
    # Forwarded as `temperature` to `eft_model.generate`.
    temperature: Optional[float] = field(default=0.7)
    # scorer: passed as `beta` to `get_scorer` and as `w` to `EFTPosthocGenerationMixin`.
    beta: Optional[float] = field(default=1.0)

@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line options for this script, parsed with `tyro.cli`."""

    # Base model identifier; resolved through `get_local_model_name` before loading.
    model_name: str = "meta-llama/Llama-2-7b-chat-hf"
    # Name handed to `get_scorer`.
    scorer_name: str = "zephyr-7b-beta"
    # Dataset identifier handed to `load_instruction_dataset`.
    dataset_name: str = "tatsu-lab/alpaca_eval"
    # File the generated records are written to via `Dataset.to_json`.
    output_path: Text = "tmp/alpaca_leaderboard/gen_eft.jsonl"
    # When False, the script exits early if `output_path` already exists.
    overwrite: Optional[bool] = False
    # Not read directly in this script; presumably the shard index used by
    # `load_instruction_dataset`, which receives the whole object (TODO confirm).
    rank: Optional[int] = 0
    # Not read directly in this script; presumably the shard count (TODO confirm).
    world_size: Optional[int] = 1
    # Seed passed to `set_seeds`.
    seed: Optional[int] = 1
    # Not read directly in this script; presumably handled by the dataset loader (TODO confirm).
    sanity_check: Optional[bool] = False
    # Forwarded as `load_in_4bit` to the base model loader and to `get_scorer`.
    load_in_4bit: Optional[bool] = False
    # Nested generation settings; a fresh `EFTGenConfig` is built per instance.
    generation_configs: EFTGenConfig = field(default_factory=EFTGenConfig)

# Parse the command line into `ScriptArguments` and apply the requested seed.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# Skip the run when the output file is already there and overwriting was not
# requested. This check happens before the dataset is loaded so a skipped run
# does not pay for the load.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(f"{script_args.output_path} already exists; skipping (pass --overwrite=True to regenerate).")
    # `raise SystemExit` does not depend on the `site`-provided `exit()` helper;
    # the exit status stays 0 as before.
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)

# Base model: loaded in bfloat16 with `device_map="auto"`.
# NOTE(review): passing `use_flash_attention_2` / `load_in_4bit` directly is
# reportedly deprecated in newer transformers releases; confirm against the
# version this project pins before changing it.
base_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_flash_attention_2=True,
    load_in_4bit=script_args.load_in_4bit,
)
base = AutoModelForCausalLM.from_pretrained(
    get_local_model_name(script_args.model_name), **base_load_kwargs
)

# Tokenizer and chat prompt template that belong to the base model.
tokenizer = AutoTokenizer.from_pretrained(
    get_local_model_name(script_args.model_name)
)
prompt_template = get_chat_prompt_template(script_args.model_name, tokenizer)

# Scorer: the sampling loop reads `.model`, `.ref_model`, `.tokenizer` and
# `.model_prompt_template` from the returned object.
scorer_kwargs = {
    "scorer_name": script_args.scorer_name,
    "load_in_4bit": script_args.load_in_4bit,
    "beta": script_args.generation_configs.beta,
    "average_log_prob": False,
    "reference_free": False,
}
scorer = get_scorer(**scorer_kwargs)

# sample
gen_cfg = script_args.generation_configs
# The generator label is identical for every record, so build it once.
generator_label = f"{script_args.model_name}({script_args})"
raw_prompts = dataset["raw_prompt"]
dataset_ids = dataset["dataset"]

# Create the destination directory up front: a bad output location then fails
# before any generation work, and the save step does not depend on the writer
# creating missing directories.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

results = []
# zip() has no length, so pass `total` explicitly; otherwise tqdm cannot show
# a percentage or ETA.
progress = tqdm.tqdm(
    zip(raw_prompts, dataset_ids),
    total=min(len(raw_prompts), len(dataset_ids)),
)
for raw_prompt, ds_id in progress:
    # Both scorer-side wrappers use the same formatted prompt; format it once.
    scorer_prompt = scorer.model_prompt_template.format(raw_prompt=raw_prompt)
    eft_model = EFTPosthocGenerationMixin(
        base=PrefixPreTrainedWrapper(base, tokenizer, prompt_template.format(raw_prompt=raw_prompt)),
        tune_r=PrefixPreTrainedWrapper(scorer.model, scorer.tokenizer, scorer_prompt),
        base_r=PrefixPreTrainedWrapper(scorer.ref_model, scorer.tokenizer, scorer_prompt),
        w=gen_cfg.beta,
    )
    outputs = eft_model.generate(
        max_new_tokens=gen_cfg.max_new_tokens,
        do_sample=gen_cfg.do_sample,
        temperature=gen_cfg.temperature,
    )
    # Only the first returned sequence is decoded, using the base model's tokenizer.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    results.append({
        "instruction": raw_prompt,
        "output": response,
        "generator": generator_label,
        "dataset": ds_id,
        "datasplit": "eval",
    })

# Persist all records in one go once every prompt has been processed.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
