import ast
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face checkpoint used for both the tokenizer and the language model.
model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Load the weights in 4-bit via bitsandbytes; matmuls are computed in fp16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" lets the library decide where each layer is placed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)


def _parse_variants(text):
    """Parse a model reply into a list of paraphrase strings.

    The reply is expected to contain a bracketed list literal such as
    ``["a", "b"]``; any text before the first ``[`` or after the last ``]``
    is ignored.

    Raises:
        ValueError: if no bracketed list is present, or the literal is not a
            flat list/tuple of strings.
        SyntaxError: if the bracketed span is not a valid Python literal.
    """
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end < start:
        raise ValueError("no bracketed list found in model reply")
    # literal_eval, never eval: the reply is model-generated text and must be
    # parsed as data, not executed as code.
    parsed = ast.literal_eval(text[start : end + 1])
    if not isinstance(parsed, (list, tuple)) or not all(
        isinstance(item, str) for item in parsed
    ):
        raise ValueError("model reply is not a flat list of strings")
    return list(parsed)


def paraphrase(prompt, k=10, *, max_attempts=None):
    """Ask the chat model for up to ``k`` alternative phrasings of ``prompt``.

    Generation is sampled, so a reply that cannot be parsed is printed and the
    request is retried with a fresh sample.

    Args:
        prompt: Text to paraphrase.
        k: Number of paraphrases requested; the result is truncated to ``k``.
        max_attempts: Maximum number of generate/parse attempts. ``None``
            (the default) retries indefinitely.

    Returns:
        A list of at most ``k`` strings.

    Raises:
        RuntimeError: if ``max_attempts`` is set and every attempt failed.
    """
    system = "You are a paraphrasing engine. Preserve every tag and keyword."
    chat = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system},
            {
                "role": "user",
                "content": f"Write {k} alternative phrasings of:\n'''{prompt}'''. Return only the paraphrasings, no other text. The format should be [\"prompt1\", \"prompt2\", ...]",
            },
        ],
        # Open the assistant turn in the prompt itself so that everything the
        # model generates is reply text only.
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = chat.shape[-1]

    attempt = 0
    last_error = None
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        txt = None  # reset so a failed generate never prints a stale reply
        try:
            out = model.generate(
                input_ids=chat,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.9,
                top_p=0.95,
            )
            # Decode only the newly generated tokens. Splitting the full
            # transcript on the word "assistant" breaks whenever the prompt
            # or the reply contains that word.
            txt = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
            return _parse_variants(txt)[:k]
        except Exception as e:
            last_error = e
            print(e)
            if txt is not None:
                print(txt)

    raise RuntimeError(
        f"paraphrase failed after {attempt} attempt(s)"
    ) from last_error


import pandas as pd, tqdm, json, numpy as np

# Paraphrase every caption in the prompt CSV and store the variants as JSON,
# keyed by the original caption.
all_prompts = pd.read_csv(
    "prompts/memorized_laion_prompts.csv", sep=";"
).Caption.tolist()

# Create the output directory before generating anything, so a missing folder
# cannot discard the whole run at the final write.
os.makedirs("results", exist_ok=True)

results = {}
for prompt in tqdm.tqdm(all_prompts):
    variants = paraphrase(prompt)
    results[prompt] = variants[:10]  # keep at most 10 variants per prompt

# The context manager guarantees the file is flushed and closed.
with open("results/paraphrases.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
