# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/alpaca_leaderboard/gen_bt.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, extract_responses
from inference_time_alignment.decoder import BeamTuningPosthocGenerationMixin
from inference_time_alignment.scorer import ImplicitRewardScorer
from utils import load_instruction_dataset


@dataclass
class BeamTuningGenConfig:
    """Generation settings that the sampling loop forwards to ``bon_beam_sample``."""

    # Sampling controls.
    max_new_tokens: Optional[int] = field(default=2048)
    temperature: Optional[float] = field(default=0.6)
    top_p: Optional[float] = field(default=0.9)
    # Beam-tuning controls.
    num_beams: Optional[int] = field(default=2)
    num_candidates: Optional[int] = field(default=4)
    # Very large by default; the sampling loop notes that an effectively
    # infinite block length enables sequence-wise best-of-n.
    block_len: Optional[int] = field(default=100000)


@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line options for this generation script, parsed with ``tyro.cli``."""

    # Generator checkpoint and the template wrapped around each raw instruction.
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    prompt_template: str = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{raw_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # Instruction source and the file the generations are written to.
    dataset_name: str = "tatsu-lab/alpaca_eval"
    output_path: str = "tmp/alpaca_leaderboard/gen_bt.jsonl"
    # When false, the script exits early if output_path already exists.
    overwrite: Optional[bool] = False
    # NOTE(review): rank / world_size / sanity_check are not read in this file;
    # presumably load_instruction_dataset uses them (it receives the whole
    # args object) — confirm there.
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    # Passed to set_seeds at start-up.
    seed: Optional[int] = 1
    sanity_check: Optional[bool] = False
    # Forwarded to every from_pretrained call in this script.
    load_in_4bit: Optional[bool] = False
    generation_configs: BeamTuningGenConfig = field(default_factory=BeamTuningGenConfig)

# Parse the CLI flags into ScriptArguments and seed via set_seeds.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# Skip the run *before* loading anything expensive when the output already
# exists and regeneration was not requested. SystemExit is raised directly
# because the bare `exit()` helper is injected by the `site` module and is
# not guaranteed to be available.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(f"{script_args.output_path} already exists; pass --overwrite=True to regenerate.")
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)

def get_scorer(beta, average_log_prob, reference_free):
    """Build an ``ImplicitRewardScorer`` from the zephyr-7b-beta / mistral-7b-sft-beta pair.

    Both checkpoints are loaded in bfloat16 with ``device_map="auto"`` and the
    script-level ``load_in_4bit`` flag. The zephyr tokenizer (left padding,
    EOS as pad token) supplies the chat-formatted prompt template that is used
    for both the model and the reference model.

    Args:
        beta: Forwarded unchanged to ``ImplicitRewardScorer``.
        average_log_prob: Forwarded unchanged to ``ImplicitRewardScorer``.
        reference_free: Forwarded unchanged to ``ImplicitRewardScorer``.

    Returns:
        The constructed ``ImplicitRewardScorer``.
    """
    tuned_name = "HuggingFaceH4/zephyr-7b-beta"
    reference_name = "HuggingFaceH4/mistral-7b-sft-beta"

    # Identical loading options for both checkpoints.
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        load_in_4bit=script_args.load_in_4bit,
    )
    tuned_model = AutoModelForCausalLM.from_pretrained(tuned_name, **load_kwargs)
    reference_model = AutoModelForCausalLM.from_pretrained(reference_name, **load_kwargs)

    scorer_tokenizer = AutoTokenizer.from_pretrained(tuned_name)
    scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
    scorer_tokenizer.padding_side = "left"

    # Render the chat template with a literal "{raw_prompt}" placeholder so
    # the resulting string can be filled in later with str.format-style keys.
    chat_messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": "{raw_prompt}"},
    ]
    scorer_prompt_template = scorer_tokenizer.apply_chat_template(
        chat_messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    return ImplicitRewardScorer(
        model=tuned_model,
        ref_model=reference_model,
        tokenizer=scorer_tokenizer,
        model_prompt_template=scorer_prompt_template,
        ref_model_prompt_template=scorer_prompt_template,
        beta=beta,
        average_log_prob=average_log_prob,
        reference_free=reference_free,
    )

# Tokenizer and generator checkpoint both come from --model_name.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
# NOTE(review): recent transformers releases deprecate `use_flash_attention_2`
# in favour of `attn_implementation="flash_attention_2"` — confirm against the
# pinned transformers version before switching.
base = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=script_args.load_in_4bit,
    use_flash_attention_2=True,
)

# Wrap model + tokenizer so that `bon_beam_sample` is available.
bt_model = BeamTuningPosthocGenerationMixin(base, tokenizer)

# Disabled: real implicit-reward scorer. As written it would not run, because
# BeamTuningGenConfig in this file defines no beta / average_log_prob /
# reference_free fields.
# scorer   = get_scorer(
#     beta=script_args.generation_configs.beta, 
#     average_log_prob=script_args.generation_configs.average_log_prob,
#     reference_free=script_args.generation_configs.reference_free,
# )
class DummyScorer:
    """Placeholder scorer that gives every response a score of zero."""

    def __call__(self, input):
        # One zero per entry of input["response"], moved onto the CUDA device.
        num_responses = len(input["response"])
        zero_scores = torch.zeros(num_responses)
        return zero_scores.cuda()

# DummyScorer returns a zero score for every response.
scorer = DummyScorer()

# Make sure the destination directory exists *before* the long generation
# loop, so a bad output path fails immediately instead of after sampling.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

gen_cfg = script_args.generation_configs
# The generator tag is the same for every record, so build it once.
generator_tag = f"{script_args.model_name}({str(script_args)})"

# sample
raw_prompts = dataset["raw_prompt"]
dataset_ids = dataset["dataset"]
results = []
# zip() has no length, so pass `total` explicitly to get a real progress bar.
for raw_prompt, ds_id in tqdm.tqdm(zip(raw_prompts, dataset_ids), total=len(raw_prompts)):
    # Format once and reuse for both tokenization and response extraction.
    prompt = script_args.prompt_template.format(raw_prompt=raw_prompt)
    prompt_tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        # The template already contains the special tokens it needs.
        add_special_tokens=False,
    )
    outputs = bt_model.bon_beam_sample(
        input_ids=prompt_tokenized["input_ids"].cuda(),
        attention_mask=prompt_tokenized["attention_mask"].cuda(),
        scorer=scorer,
        #
        max_new_tokens=gen_cfg.max_new_tokens,
        temperature=gen_cfg.temperature,
        top_p=gen_cfg.top_p,
        #
        num_beams=gen_cfg.num_beams,
        num_candidates=gen_cfg.num_candidates,
        block_len=gen_cfg.block_len,  # block_len = inf to enable sequence wise bon
        # return_dict_in_generate=True,
    )
    # Keep only the first extracted response for this prompt.
    response = extract_responses(outputs, prompt, tokenizer)[0]
    results.append({
        "instruction": raw_prompt,
        "output": response,
        "generator": generator_tag,
        "dataset": ds_id,
        "datasplit": "eval",
    })

# Write all records to output_path via Dataset.to_json.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
