# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/summarize_from_feedback/gen_tune.py --sanity_check=True --overwrite=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field, asdict
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, extract_responses
from utils import load_instruction_dataset


@dataclass
class GenConfig:
    """Decoding options for ``generate``.

    The script expands an instance with ``dataclasses.asdict`` and passes every
    field as a keyword argument to ``generate``, so each field name must be a
    valid ``generate`` / ``GenerationConfig`` parameter.
    """

    # Maximum number of tokens generated after the prompt.
    max_new_tokens: Optional[int] = field(default=128)
    # Sample instead of decoding greedily.
    do_sample: Optional[bool] = field(default=True)
    # Sampling temperature.
    temperature: Optional[float] = field(default=0.7)
    # Nucleus-sampling mass (see transformers GenerationConfig); 1.0 keeps all tokens.
    top_p: Optional[float] = field(default=1.0)
    # Number of beams (see transformers GenerationConfig); 1 means no beam search.
    num_beams: Optional[int] = field(default=1)

@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line options for sampling summaries from a causal LM.

    Parsed by ``tyro.cli``; all options are keyword-only.
    """

    # Path or identifier of the model (and tokenizer) to load.
    model_name: str = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-summarize-dpo-v2"
    # Dataset identifier handed to ``load_instruction_dataset``.
    dataset_name: str = "anonymized_for_nips/openai_summarize_comparisons_relabel"
    # JSONL file that receives the generated prompt/response records.
    output_path: str = "tmp/summarize_from_feedback/gen_tune.jsonl"
    # Regenerate even when ``output_path`` already exists.
    overwrite: Optional[bool] = False
    # Shard index and shard count; presumably consumed by
    # load_instruction_dataset (it receives the whole args object) — TODO confirm.
    rank: Optional[int] = 0
    world_size: Optional[int] = 1
    # Seed passed to ``set_seeds`` before anything else runs.
    seed: Optional[int] = 1
    # Also forwarded to load_instruction_dataset via the args object.
    sanity_check: Optional[bool] = False
    # Forwarded to ``from_pretrained`` as ``load_in_4bit``.
    load_in_4bit: Optional[bool] = False
    # Decoding options forwarded to ``generate``.
    generation_configs: GenConfig = field(default_factory=GenConfig)

script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)

# Skip the run before any expensive work (dataset load, model load) when the
# output already exists and overwriting was not requested.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(f"{script_args.output_path} already exists; pass --overwrite=True to regenerate.")
    # SystemExit does not depend on the `site`-injected `exit()` helper.
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)

# Keyword arguments for loading the policy model: bfloat16 weights, automatic
# device placement, optional 4-bit loading. Flash attention 2 is left disabled
# (previously: use_flash_attention_2=True).
_model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "load_in_4bit": script_args.load_in_4bit,
}
tune = AutoModelForCausalLM.from_pretrained(script_args.model_name, **_model_load_kwargs)
# The tokenizer comes from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

# sample
# NOTE(review): output_path is not rank-specific. When launching several ranks
# (--rank/--world_size), give each a distinct --output_path, or they will
# overwrite one another.

# Create the output directory before generating, so a bad path fails fast
# instead of after the whole sampling loop. dirname is "" for a bare filename.
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Move inputs to the device holding the model's first parameters; this works
# with device_map="auto" and does not hard-code CUDA.
input_device = tune.device

results = []
for raw_prompt in tqdm.tqdm(dataset["raw_prompt"]):
    # Append the summary cue; the model continues after "TL;DR: ".
    prompt = raw_prompt + "TL;DR: "
    prompt_tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    outputs = tune.generate(
        input_ids=prompt_tokenized["input_ids"].to(input_device),
        attention_mask=prompt_tokenized["attention_mask"].to(input_device),
        **asdict(script_args.generation_configs),
    )
    # extract_responses is given the prompt; presumably it strips the prompt
    # text from the decoded output — TODO confirm.
    response = extract_responses(outputs, tokenizer, prompt=prompt)[0]
    # Keep only the first line of the continuation as the summary.
    response = response.split('\n')[0]
    results.append({
        "prompt": raw_prompt,
        "response": response,
    })

dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
