# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/imdb/gen_bt.py --sanity_check=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, extract_responses
from inference_time_alignment.decoder import BeamTuningPosthocGenerationMixin
from inference_time_alignment.scorer import ImplicitRewardScorer
from utils import load_instruction_dataset

@dataclass
class BeamTuningGenConfig:
    """Decoding and scoring hyperparameters for beam-tuning generation.

    Instances are nested under ``ScriptArguments.generation_configs``; the
    sampling/search fields are forwarded to ``bon_beam_sample`` and the scorer
    fields to ``get_scorer`` by the script below.
    """

    # --- sampling ---
    max_new_tokens: Optional[int] = field(default=50)
    temperature: Optional[float] = field(default=0.7)

    # --- beam search ---
    num_beams: Optional[int] = field(default=4)
    num_candidates: Optional[int] = field(default=4)
    block_len: Optional[int] = field(default=5)

    # --- implicit-reward scorer ---
    beta: Optional[float] = field(default=1)
    average_log_prob: Optional[bool] = field(default=False)
    reference_free: Optional[bool] = field(default=False)
    sft_free: Optional[bool] = field(default=False)

@dataclass(kw_only=True)
class ScriptArguments:
    """Top-level command-line options for IMDB beam-tuning generation.

    Parsed with ``tyro.cli``; every field must be passed by keyword.
    """

    # model and prompt formatting
    model_name: str = "meta-llama/Llama-2-7b-hf"
    prompt_template: str = "Here is a movie review from imdb: {raw_prompt}"
    # data in / results out
    dataset_name: str = "anonymized_for_nips/imdb_preference"
    output_path: str = "tmp/imdb/gen_bt.jsonl"
    overwrite: Optional[bool] = False
    # sharding, reproducibility, debugging
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    seed: Optional[int] = 1
    sanity_check: Optional[bool] = False
    # nested decoding / scorer settings (fresh instance per ScriptArguments)
    generation_configs: BeamTuningGenConfig = field(default_factory=BeamTuningGenConfig)

script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# reference_free and sft_free may not be combined. Raise explicitly rather than
# `assert`, which is stripped when Python runs with -O.
if script_args.generation_configs.reference_free and script_args.generation_configs.sft_free:
    raise ValueError("reference_free and sft_free cannot both be True")

# Skip the run if its output already exists (unless overwrite is set). This is
# checked before loading the dataset so a skipped run does no needless work.
# `raise SystemExit` is used instead of the interactive-only `exit()` helper.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(f"{script_args.output_path} already exists; set overwrite=True to regenerate it.")
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)

def get_scorer(
    beta,
    average_log_prob,
    reference_free,
    sft_free,
    *,
    dpo_model_path="/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb-dpo",
    sft_model_path="/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb",
    pretrained_model_path="openai-community/gpt2",
):
    """Build the ImplicitRewardScorer used to steer generation on IMDB prompts.

    The scorer pairs a DPO-tuned GPT-2 IMDB model with a reference model.
    The reference is the SFT checkpoint by default, or the plain pretrained
    GPT-2 when ``sft_free`` is set. Both models are loaded in bfloat16 with
    ``device_map="auto"``. The tokenizer is loaded from the DPO checkpoint.

    Args:
        beta: Forwarded to ImplicitRewardScorer. The script's own comment
            describes beta>0 as steering positive and beta<0 as negative.
        average_log_prob: Forwarded to ImplicitRewardScorer.
        reference_free: Forwarded to ImplicitRewardScorer. The reference model
            is loaded regardless of this flag.
        sft_free: If True, use ``pretrained_model_path`` as the reference
            instead of ``sft_model_path``.
        dpo_model_path: Checkpoint (local path or hub id) of the DPO model;
            also the tokenizer source.
        sft_model_path: Reference checkpoint used when ``sft_free`` is False.
        pretrained_model_path: Reference checkpoint used when ``sft_free`` is True.

    The three path arguments default to the locations this script originally
    hard-coded, so existing callers are unaffected.

    Returns:
        The configured ImplicitRewardScorer.
    """
    # scorer that encourages positive movie review
    model = AutoModelForCausalLM.from_pretrained(
        dpo_model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    ref_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_path if sft_free else sft_model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(dpo_model_path)
    # Reuse EOS as the pad token and pad on the left. Presumably the scorer
    # batches sequences and needs these settings — confirm against ImplicitRewardScorer.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return ImplicitRewardScorer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        beta=beta,
        average_log_prob=average_log_prob,
        reference_free=reference_free,
    )

# Steer the base model named by ScriptArguments.model_name (default:
# meta-llama/Llama-2-7b-hf) toward positive (beta>0) / negative (beta<0)
# movie reviews.
base = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

# Decoder wrapper that provides bon_beam_sample over the base model.
bt_model = BeamTuningPosthocGenerationMixin(base, tokenizer)

# Implicit-reward scorer built from the nested generation settings.
gen_cfg = script_args.generation_configs
scorer = get_scorer(
    beta=gen_cfg.beta,
    average_log_prob=gen_cfg.average_log_prob,
    reference_free=gen_cfg.reference_free,
    sft_free=gen_cfg.sft_free,
)

# Sample one steered response per prompt and save all results as JSON lines.
gen_cfg = script_args.generation_configs
results = []
for raw_prompt in tqdm.tqdm(dataset["raw_prompt"]):
    prompt = script_args.prompt_template.format(raw_prompt=raw_prompt)
    # NOTE(review): add_special_tokens=False means no BOS token is prepended to
    # the prompt; confirm this is intended for the base model.
    prompt_tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    outputs = bt_model.bon_beam_sample(
        input_ids=prompt_tokenized["input_ids"].cuda(),
        attention_mask=prompt_tokenized["attention_mask"].cuda(),
        # The scorer is re-bound to the current raw prompt on every iteration.
        scorer=scorer.set_raw_prompt(raw_prompt),
        # sampling
        max_new_tokens=gen_cfg.max_new_tokens,
        temperature=gen_cfg.temperature,
        # search
        num_beams=gen_cfg.num_beams,
        num_candidates=gen_cfg.num_candidates,
        block_len=gen_cfg.block_len,  # block_len = inf to enable sequence wise bon
    )
    response = extract_responses(outputs, tokenizer, prompt=prompt)[0]
    results.append({"prompt": raw_prompt, "response": response})

# The default output_path lives under tmp/imdb/, which may not exist yet;
# create the parent directory before writing.
# NOTE(review): every rank writes to the same output_path. With world_size > 1
# the ranks overwrite each other unless the path is made rank-specific elsewhere.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
