# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/imdb/gen_eft.py --sanity_check=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds
from inference_time_alignment.decoder import EFTPosthocGenerationMixin
from inference_time_alignment.model import PrefixPreTrainedWrapper
from utils import load_instruction_dataset

@dataclass
class EFTGenConfig:
    """Sampling knobs consumed by the generation loop of this script.

    ``max_new_tokens``, ``do_sample`` and ``temperature`` are passed as
    keyword arguments to ``eft_model.generate``; ``beta`` is passed as the
    ``w`` argument of ``EFTPosthocGenerationMixin``.
    """

    # Forwarded to generate(max_new_tokens=...).
    max_new_tokens: Optional[int] = field(default=50)
    # Forwarded to generate(do_sample=...).
    do_sample: Optional[bool] = field(default=True)
    # Forwarded to generate(temperature=...).
    temperature: Optional[float] = field(default=0.7)
    # scorer: handed to EFTPosthocGenerationMixin as `w`
    # (presumably the weight on the tuned-vs-base signal -- confirm in decoder).
    beta: Optional[float] = field(default=1.0)

@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line options for this script, parsed by ``tyro.cli``.

    All fields are keyword-only and every one has a default.
    """

    # Checkpoint id/path used for the large base model and for the tokenizer.
    model_name: str = "openai-community/gpt2-large"
    # Template applied to each raw prompt before it is given to the large
    # base model; must contain the `{raw_prompt}` placeholder.
    prompt_template: str = "Here is a movie review from imdb: {raw_prompt}"
    # Dataset identifier handed to `load_instruction_dataset`.
    dataset_name: str = "anonymized_for_nips/imdb_preference"
    # Destination file for the generated prompt/response records.
    output_path: str = "tmp/imdb/gen_eft.jsonl"
    # When false, the script exits early if `output_path` already exists.
    overwrite: Optional[bool] = False
    # rank / world_size / sanity_check are not read directly in this script;
    # the whole args object is passed to `load_instruction_dataset`
    # (presumably used there for sharding / subsetting -- confirm in utils).
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    # Value passed to `set_seeds`.
    seed: Optional[int] = 1
    sanity_check: Optional[bool] = False
    # Nested generation settings; a fresh EFTGenConfig per instance.
    generation_configs: EFTGenConfig = field(default_factory=EFTGenConfig)

# Parse the command line into ScriptArguments and apply the seed through the
# project's `set_seeds` helper before anything else runs.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# Skip the run when the output file is already present and overwriting was
# not requested. The check happens *before* the dataset is loaded so that a
# no-op rerun does not pay for the load. `SystemExit(0)` is raised directly
# because the bare `exit()` helper is injected by the `site` module and is
# not guaranteed to exist (e.g. under `python -S`); the exit status is the
# same (0).
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(
        f"{script_args.output_path} already exists; "
        "set overwrite=True to regenerate. Nothing to do."
    )
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)
 
# Keyword arguments shared by every checkpoint load below: bfloat16 weights,
# with device placement delegated to `device_map="auto"`.
_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

# Small model pair loaded from fixed cluster paths (presumably the DPO-tuned
# checkpoint and its untuned counterpart, going by the directory names).
tune_gpt2_small = AutoModelForCausalLM.from_pretrained(
    "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb-dpo", **_load_kwargs
)
base_gpt2_small = AutoModelForCausalLM.from_pretrained(
    "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb", **_load_kwargs
)

# Large base model, selected by the `model_name` argument.
base_gpt2_large = AutoModelForCausalLM.from_pretrained(
    script_args.model_name, **_load_kwargs
)

# One tokenizer (the large model's) is shared by all three wrappers below.
# The EOS token doubles as the pad token and padding is applied on the left.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Make sure the destination directory exists *before* the (expensive)
# generation loop, so a missing or unwritable directory is reported up front
# rather than after every prompt has been processed. `dirname` is empty when
# output_path is a bare filename, in which case there is nothing to create.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# NOTE(review): output_path does not depend on `rank`. If
# `load_instruction_dataset` shards the data by rank/world_size, runs with
# world_size > 1 would all target the same file -- confirm and, if so, pass a
# distinct output_path per rank.

# sample
gen_cfg = script_args.generation_configs  # loop-invariant, looked up once
prompt_template = script_args.prompt_template
results = []
for raw_prompt in tqdm.tqdm(dataset["raw_prompt"]):
    # The wrappers are rebuilt for every prompt because each one is
    # constructed with a prompt-specific prefix: the large base model gets the
    # templated prompt, the two small models get the raw prompt.
    eft_model = EFTPosthocGenerationMixin(
        base=PrefixPreTrainedWrapper(
            base_gpt2_large, tokenizer, prompt_template.format(raw_prompt=raw_prompt)
        ),  # Here is a movie review from imdb: I think
        tune_r=PrefixPreTrainedWrapper(tune_gpt2_small, tokenizer, raw_prompt),  # I think
        base_r=PrefixPreTrainedWrapper(base_gpt2_small, tokenizer, raw_prompt),  # I think
        w=gen_cfg.beta,
    )
    outputs = eft_model.generate(
        max_new_tokens=gen_cfg.max_new_tokens,
        do_sample=gen_cfg.do_sample,
        temperature=gen_cfg.temperature,
    )
    # Only the first returned sequence is decoded and kept.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    results.append({
        "prompt": raw_prompt,
        "response": response,
    })

# Persist all prompt/response records in one write at the end of the run.
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
