from transformers import AutoModelForCausalLM, AutoTokenizer


class PROMPTIST():
    """Prompt rewriter backed by the ``microsoft/Promptist`` causal language model.

    The model is fed ``"<text> Rephrase:"`` and its continuation is returned as
    the rewritten prompt. The tokenizer is the stock ``gpt2`` tokenizer with the
    EOS token reused as the padding token and left-side padding.
    """

    def __init__(self, device='cpu'):
        """Load the model and tokenizer and move the model to ``device``.

        Args:
            device: Device passed to the model's ``.to()`` and later used for
                the input ids in :meth:`generate` (default ``'cpu'``).
        """
        self.prompter_model, self.prompter_tokenizer = self._load_prompter()
        self.prompter_model = self.prompter_model.to(device)
        self.device = device

    def generate(self, plain_text, n=10):
        """Return ``n`` rewritten versions of ``plain_text``.

        Decoding is deterministic beam search (``do_sample=False``) with ``n``
        beams, all ``n`` of which are returned, and at most 75 new tokens per
        candidate.

        Args:
            plain_text: The prompt to rewrite. Surrounding whitespace is ignored.
            n: Number of beams and number of candidates to return.

        Returns:
            A list of ``n`` strings holding only the generated continuations,
            stripped of surrounding whitespace.
        """
        # Build the prompt once so the text sent to the model and the part
        # removed from the output can never drift apart.
        prompt = plain_text.strip() + " Rephrase:"
        input_ids = self.prompter_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        eos_id = self.prompter_tokenizer.eos_token_id
        outputs = self.prompter_model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=75,
            num_beams=n,
            num_return_sequences=n,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            # Negative value; per the transformers generation docs this biases
            # beam search towards shorter sequences.
            length_penalty=-1.0,
        )
        # Each returned sequence starts with the prompt tokens. Drop them by
        # length rather than by string replacement: a string match silently
        # fails whenever the decoded text does not reproduce the caller's
        # input exactly (untrimmed whitespace, tokenizer space clean-up).
        prompt_length = input_ids.shape[1]
        continuations = outputs[:, prompt_length:]
        output_texts = self.prompter_tokenizer.batch_decode(continuations, skip_special_tokens=True)
        return [output_text.strip() for output_text in output_texts]

    def _load_prompter(self):
        """Load the Promptist model and a GPT-2 tokenizer set up for generation.

        Returns:
            A ``(model, tokenizer)`` tuple. The tokenizer pads on the left and
            uses its EOS token as the padding token.
        """
        prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        return prompter_model, tokenizer


# examples = ["A rabbit is wearing a space suit", "Several railroad tracks with one train passing by", "The roof is wet from the rain", "Cats dancing in a space club"]

# # res = generate(examples[0])

# promptist = PROMPTIST()

# opt_prompt = promptist.generate(examples[0])
# print(opt_prompt)
