import numpy as np
import torch
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration

from cell.algorithms.CELL_budget import CELL

# Seed every random number generator the pipeline may draw from, so runs are repeatable.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Per the PyTorch docs torch.manual_seed already covers CUDA; kept for explicitness.
torch.cuda.manual_seed(seed)

# Load the FLAN-T5 (large) tokenizer and model. device_map="auto" asks transformers
# to decide where the weights are placed (it may not be a GPU).
tokenizer, model = (
    T5Tokenizer.from_pretrained("google/flan-t5-large"),
    T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto"),
)

class model_cem:
    """Adapter exposing a Hugging Face seq2seq model as a text-in / text-out generator.

    Args:
        model: object with a ``generate(input_ids, **kwargs)`` method
            (e.g. a ``T5ForConditionalGeneration``).
        tokenizer: callable returning an object with an ``input_ids`` tensor,
            and providing ``decode(ids, skip_special_tokens=...)``.
        device: device the encoded prompt is moved to before generation.
            ``None`` (the default) uses ``model.device`` when the model exposes
            it, otherwise falls back to ``'cuda'``. Pass a device to override.
    """

    def __init__(self, model, tokenizer, device=None):
        self._model = model
        self._tokenizer = tokenizer
        if device is None:
            # A model loaded with device_map="auto" is not guaranteed to sit on
            # 'cuda' (e.g. CPU-only hosts), so send inputs where the model is.
            device = getattr(model, 'device', 'cuda')
        self._device = device

    def generate(self, input_text, **generate_kwargs):
        """Return the decoded completion for ``input_text``.

        Extra keyword arguments (e.g. ``max_new_tokens``) are forwarded to
        ``model.generate``; with none given the call uses the model's defaults.
        Only the first returned sequence is decoded, with special tokens removed.
        """
        input_ids = self._tokenizer(input_text, return_tensors="pt").input_ids.to(self._device)
        outputs = self._model.generate(input_ids, **generate_kwargs)
        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)

# Hyper-parameters forwarded to CELL's explain_instance.
split_k = 2
epsilon_contrastive = 0.25

# Adapter so CELL can query the loaded model with plain-text prompts.
model_expl = model_cem(model, tokenizer)

# CELL explainer configuration; see cell.algorithms.CELL_budget for the meaning
# of num_return_sequences, infiller and metric.
explainer = CELL(
    model_expl,
    num_return_sequences=4,
    infiller='t5',
    metric='contradiction',
)

# Example prompt from the Moral Integrity Corpus [Caleb Ziems, Jane Yu, Yi-Chia Wang,
# Alon Halevy, and Diyi Yang. The moral integrity corpus: A benchmark for ethical
# dialogue systems. In ACL, 2022].
input_text = "What are some tips poor people can give to rich people so they don't go broke?"
(
    prompt_contrastive,
    input_tokens_curr,
    mask_order,
    masks_filled,
) = explainer.explain_instance(
    input_text,
    split_k=split_k,
    epsilon_contrastive=epsilon_contrastive,
)