import numpy as np
import torch
import random
from transformers import BartForConditionalGeneration, BartTokenizer

from cell.algorithms.CELL_budget import CELL

# Seed the Python, NumPy and PyTorch (CPU and current CUDA device) random
# number generators with one shared value, in that order.
seed = 3
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)

# Hugging Face Hub checkpoint used for both the tokenizer and the model
# (a BART-large summarisation checkpoint, per its Hub name).
_checkpoint = "facebook/bart-large-xsum"

device = 'cuda'
tokenizer = BartTokenizer.from_pretrained(_checkpoint)
# forced_bos_token_id=0 is forwarded to the model config so that decoding is
# forced to emit token id 0 as its first generated token.
model = BartForConditionalGeneration.from_pretrained(_checkpoint, forced_bos_token_id=0)
model = model.to(device)

class model_cem():
    """Text-in/text-out wrapper around a Hugging Face style seq2seq model.

    Adapts a ``(model, tokenizer)`` pair to a single ``generate(str) -> str``
    call. NOTE(review): presumably this is the only method the explainer it
    is handed to relies on -- confirm against the CELL implementation.
    """

    def __init__(self, model, tokenizer, device='cuda'):
        """Store the model, its tokenizer and the device inputs are sent to.

        Args:
            model: object exposing ``generate(input_ids, attention_mask=...)``
                (e.g. a ``transformers`` generation model).
            tokenizer: matching tokenizer; must be callable (returning
                ``input_ids`` and ``attention_mask``) and provide ``decode``.
            device: device the tokenized inputs are moved to. It should be the
                device ``model`` lives on, otherwise ``generate`` will fail.
        """
        self._model = model
        self._tokenizer = tokenizer
        self._device = device

    def generate(self, input_text):
        """Run the model on ``input_text`` and return the decoded output.

        Args:
            input_text: the prompt string.

        Returns:
            The first generated sequence decoded to text, with special tokens
            removed.
        """
        # truncation=True cuts inputs that exceed the tokenizer's maximum
        # length instead of handing the model more tokens than it accepts
        # (perturbed inputs produced during explanation may grow).
        encoding = self._tokenizer(input_text, return_tensors="pt", truncation=True)
        input_ids = encoding.input_ids.to(self._device)
        # Pass the attention mask explicitly rather than letting generate()
        # infer one from the ids.
        attention_mask = encoding.attention_mask.to(self._device)
        outputs = self._model.generate(input_ids, attention_mask=attention_mask)
        # A single prompt is encoded, so the first returned sequence is the
        # one that belongs to it.
        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)

# Wrap the summariser; pass the same device the model was moved to so the
# tokenized inputs and the model weights always end up on one device.
model_expl = model_cem(model, tokenizer, device=device)

explainer = CELL(model_expl, num_return_sequences=4, infiller='t5', metric='bleu')

# Text taken from the EdinburghNLP/xsum dataset on the Hugging Face Hub
input_text = "Barcelona forward Messi, 29, made his decision in June after missing a penalty in the shootout as Argentina lost to Chile in the Copa America, a fourth major final loss in nine years. Bauza, who succeeded Gerardo Martino, said: My sole intention is to see if I can talk football with Messi. From that will come the possibility of him being called up in our next games. Argentina face 2018 World Cup qualifiers at home to Uruguay and away to Venezuela in the first week of September. They are third in the 10-nation South American group with 11 points from six matches, two points behind leading pair Uruguay and Ecuador. The top four after 18 matches qualify for the finals in Russia, while the fifth-placed team goes into an intercontinental play-off for one more berth. Bauza, 58, is a former central defender who has won the Copa Libertadores South American club competition twice as a coach. Asked about Messi, he added: I want to tell him my idea and for him to tell me how things are with him and then we'll see what comes out of it. I have felt frustrated for losing a match or a final and understand that statement [of quitting] when you are overwhelmed with frustration, but I know it can be reversed."

# Hyper-parameters forwarded unchanged to CELL.explain_instance.
radius = 2
split_k = 2
epsilon_contrastive = 0.5

prompt_contrastive, input_tokens_curr, mask_order, masks_filled = explainer.explain_instance(
    input_text,
    radius=radius,
    split_k=split_k,
    epsilon_contrastive=epsilon_contrastive,
)
