# PYTHONPATH=. srun -p llm-safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 python3 scripts/summarize_from_feedback/gen_eft.py --sanity_check=True --rank=0 --world_size=2
import os
from dataclasses import dataclass, field
from typing import Optional, Text

import tyro
import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from datasets import Dataset

from inference_time_alignment.utils import set_seeds, StopOnStringCriteria
from inference_time_alignment.decoder import EFTPosthocGenerationMixin
from inference_time_alignment.model import PrefixPreTrainedWrapper
from utils import load_instruction_dataset

# Two-shot prompt for the base (large) model. It contains two r/relationships
# examples, each a SUBREDDIT/TITLE/POST block followed by a "TL;DR:" summary.
# The target post is then substituted for `{raw_prompt}` via str.format.
# Notes on the literal:
#   * The backslash after the opening quotes suppresses the first newline, so
#     the string starts with "SUBREDDIT".
#   * The backslash after the final "TL;DR:" suppresses the last newline, so
#     the string ends with "TL;DR:" (no trailing space or newline).
#   * The `\n` sequences inside the example lines are escape sequences and
#     become real newlines.
#   * Because the template goes through str.format, `{raw_prompt}` is the only
#     placeholder; any literal `{` or `}` added here must be doubled.
#   * The example text is model input and is kept verbatim; editing it changes
#     the generations.
DEFAULT_PROMPT_TEMPLATE = """\
SUBREDDIT: r/relationships\nTITLE: Me [21F] with my boyfriend [19M] of 2months has a gaming habit.\nPOST: I tried to start a convo with the boyfriend everyday but it seems to be making me a little depress because he's always playing video games than paying attention to me. I'm not trying to be an attention but it's seems to be a bad habit of his. I don't know what to do or how to even confront him about it. Any IDEAS?
TL;DR: Boyfriend has a gaming habit and I don't know how to confront him about it. I'm not trying to be an attention whore but it's making me feel bad.

SUBREDDIT: r/relationships\nTITLE: Me [21 M] with my...idk [19 F] Just not sure what to do.\nPOST: Went on vacation 1 1/2 years ago\n\nMet an amazing girl\n\nSpent a lot of time together\n\nHad to leave\n\nWe had agreed it would be ok to see other people\n\nBut we keep in contact and talk about how much we miss each other all the time\n\nStill have feelings for her\n\nShe just entered a relationship recently\n\nIt bothers me\n\nIdk if I should tell her how I feel or if I am just idealizing something we had and should move on.
TL;DR: Had an amazing time with this girl before we had to leave for summer vacation 1 1/2 years ago. Still have feelings for her and want to pursue relationship w/ her. Don't know whether to tell her or not.

{raw_prompt}
TL;DR:\
"""

@dataclass
class EFTGenConfig:
    """Sampling and scoring settings used by the EFT generation loop.

    The values are forwarded to the EFT model's ``generate()`` call, used to
    build the string stopping criteria, and passed as the EFT weight ``w``.
    """

    # Maximum number of tokens generated per prompt.
    max_new_tokens: Optional[int] = field(default=128)
    # Forwarded unchanged as ``do_sample`` to generate().
    do_sample: Optional[bool] = field(default=True)
    # Forwarded unchanged as ``temperature`` to generate().
    temperature: Optional[float] = field(default=0.7)
    # Strings that stop generation; the decoded response is also cut at the
    # first occurrence of each one. A factory gives every instance its own list.
    eos_strings: Optional[list] = field(default_factory=lambda: ["\n"])
    # --- scorer ---
    # Weight passed as ``w`` to EFTPosthocGenerationMixin.
    beta: Optional[float] = field(default=1.0)

@dataclass(kw_only=True)
class ScriptArguments:
    """Command-line arguments for this script, parsed with ``tyro.cli``.

    Every field is keyword-only (``kw_only=True``).
    """

    # Large base model. The script also loads its tokenizer from this name and
    # shares that tokenizer with every model.
    model_name: str = field(default="openai-community/gpt2-large")
    # Few-shot template for the base model; filled via str.format(raw_prompt=...).
    prompt_template: str = field(default=DEFAULT_PROMPT_TEMPLATE)
    # Dataset identifier handed to load_instruction_dataset.
    dataset_name: str = field(default="anonymized_for_nips/openai_summarize_comparisons_relabel")
    # Destination of the generated {"prompt", "response"} records (Dataset.to_json).
    output_path: Text = field(default="tmp/summarize_from_feedback/gen_eft.jsonl")
    # Regenerate even when output_path already exists.
    overwrite: Optional[bool] = field(default=False)
    # Not read directly by this script. The whole ScriptArguments object is
    # passed to load_instruction_dataset, which presumably uses these to shard
    # the data (TODO confirm).
    rank: Optional[int] = field(default=0)
    world_size: Optional[int] = field(default=1)
    # Seed handed to set_seeds.
    seed: Optional[int] = field(default=1)
    # Not read directly by this script; presumably consumed by
    # load_instruction_dataset (TODO confirm).
    sanity_check: Optional[bool] = field(default=False)
    # Nested sampling settings; each instance gets its own EFTGenConfig.
    generation_configs: EFTGenConfig = field(default_factory=EFTGenConfig)

# Parse the CLI (fields and defaults come from ScriptArguments) and seed the RNGs.
script_args = tyro.cli(ScriptArguments)
set_seeds(script_args.seed)
gen_configs = script_args.generation_configs

# Skip the run when the output already exists, unless --overwrite is given.
# This is checked before the dataset and the three models are loaded, so a
# skipped run does no expensive work. `raise SystemExit` is used instead of
# the `exit()` helper: that helper is injected by the `site` module for
# interactive use and is missing when Python runs without `site`
# (e.g. `python -S`).
# NOTE(review): output_path does not depend on rank. If
# load_instruction_dataset shards by rank/world_size (not visible here), give
# each rank a distinct --output_path; otherwise ranks will skip or overwrite
# each other's results.
if os.path.exists(script_args.output_path) and not script_args.overwrite:
    print(
        f"{script_args.output_path} already exists; "
        "skipping (pass --overwrite to regenerate)."
    )
    raise SystemExit(0)

# init datasets
dataset = load_instruction_dataset(script_args.dataset_name, script_args)


def _load_causal_lm(name_or_path):
    """Load a causal LM in bfloat16 with ``device_map="auto"``."""
    return AutoModelForCausalLM.from_pretrained(
        name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )


# EFT components. The small tuned/untuned pair (cluster-specific paths) is
# combined with the large base model (--model_name) by
# EFTPosthocGenerationMixin, using weight w=beta.
tune_gpt2_small = _load_causal_lm("/mnt/petrelfs/share_data/llm-safety/models/gpt2-summarize-dpo-v2")
base_gpt2_small = _load_causal_lm("/mnt/petrelfs/share_data/llm-safety/models/gpt2-summarize-v2")
base_gpt2_large = _load_causal_lm(script_args.model_name)

# One tokenizer (loaded from --model_name) is shared by all three models.
# NOTE(review): this assumes the small models share the base vocabulary;
# re-check if --model_name is set to a non-GPT-2 model.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
tokenizer.padding_side = "left"
# Reuse EOS as the pad token (presumably because the tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token
stopping_criteria = StoppingCriteriaList([
    StopOnStringCriteria(start_length=0, eos_string=eos_string, tokenizer=tokenizer)
    for eos_string in gen_configs.eos_strings
])

# sample
results = []
for raw_prompt in tqdm.tqdm(dataset["raw_prompt"]):
    # NOTE(review): the default base prefix ends "...{raw_prompt}\nTL;DR:"
    # (newline before, no trailing space). The small-model prefix below is
    # "{raw_prompt}TL;DR: " (no newline before, trailing space after).
    # Presumably all three models score the same continuation tokens, so
    # confirm both formats match what each model was trained on.
    small_prefix = f"{raw_prompt}TL;DR: "
    eft_model = EFTPosthocGenerationMixin(
        base=PrefixPreTrainedWrapper(
            base_gpt2_large, tokenizer, script_args.prompt_template.format(raw_prompt=raw_prompt)
        ),
        tune_r=PrefixPreTrainedWrapper(tune_gpt2_small, tokenizer, small_prefix),
        base_r=PrefixPreTrainedWrapper(base_gpt2_small, tokenizer, small_prefix),
        w=gen_configs.beta,
    )
    outputs = eft_model.generate(
        max_new_tokens=gen_configs.max_new_tokens,
        do_sample=gen_configs.do_sample,
        temperature=gen_configs.temperature,
        stopping_criteria=stopping_criteria,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Cut the decoded text at the first occurrence of each stop string.
    for eos_string in gen_configs.eos_strings:
        response = response.split(eos_string)[0]
    results.append({
        "prompt": raw_prompt,
        "response": response,
    })

# Create the parent directory first so writing does not depend on it already
# existing (e.g. tmp/summarize_from_feedback/ on a fresh checkout).
output_dir = os.path.dirname(script_args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
dataset = Dataset.from_list(results)
dataset.to_json(script_args.output_path)
