# %%
import shap
import transformers
import nlp
import torch
import numpy as np

# Sentiment classifier: a DistilBERT checkpoint that (per its name) is
# fine-tuned on SST-2 English, paired with the uncased DistilBERT base
# tokenizer. NOTE(review): presumably the two share a vocabulary — confirm.
_TOKENIZER_CHECKPOINT = "distilbert-base-uncased"
_MODEL_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = transformers.DistilBertTokenizerFast.from_pretrained(_TOKENIZER_CHECKPOINT)
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    _MODEL_CHECKPOINT
)


def entropy(x, axis=1):
    """Return the Shannon entropy (in nats) of each distribution in *x*.

    Parameters
    ----------
    x : array_like
        Probabilities. Each slice along *axis* is one distribution; entries
        are assumed to be non-negative and to sum to 1 (not checked).
    axis : int, default 1
        Axis along which each distribution lies. The default treats rows of
        a 2-D ``(n_samples, n_classes)`` array as distributions.

    Returns
    -------
    numpy.ndarray
        ``-sum(p * log(p))`` along *axis*, using the convention
        ``0 * log(0) == 0`` so zero-probability entries contribute nothing
        instead of turning the result into NaN.
    """
    p = np.asarray(x)
    nonzero = p != 0
    # Take the log only of non-zero entries; substitute 1 elsewhere (log 1 == 0)
    # so no -inf is produced and 0 * -inf can never become NaN.
    logp = np.log(np.where(nonzero, p, 1))
    plogp = np.where(nonzero, p * logp, 0)
    return -np.sum(plogp, axis=axis)


def f(x):
    """Return the sentiment model's predictive entropy for each text in *x*.

    This is the model function handed to ``shap.Explainer``, which calls it
    with batches of (possibly masked) strings.

    Parameters
    ----------
    x : iterable of str
        Texts to score.

    Returns
    -------
    numpy.ndarray
        One entropy value (nats) per input text; larger values mean the
        classifier is less certain about the sentiment.
    """
    encoded = tokenizer(
        list(x),
        padding=True,  # pad only to the longest text in this batch
        truncation=True,  # clip long reviews instead of overflowing the model
        max_length=500,
        return_tensors="pt",
    )
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        # The attention mask stops the model from attending to [PAD] tokens,
        # which would otherwise shift the logits.
        logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )[0]
    # torch.softmax is numerically stable (max-subtracted), unlike a raw exp().
    scores = torch.softmax(logits, dim=-1).cpu().numpy()
    # Alternative explanation target: one-vs-rest logit units, e.g.
    # scipy.special.logit(scores[:, 1]).
    return entropy(scores)


# %%
imdb_train = nlp.load_dataset("imdb")["train"]
# %%
# Look at the first `background` training reviews and keep the indices of
# those shorter than 300 characters (presumably to keep explanations small).
background = 100
test_reviews = [
    idx for idx in range(background) if len(imdb_train[idx]["text"]) < 300
]

subset = imdb_train.select(test_reviews)
# %%
# Build an explainer for `f`; passing the tokenizer as the masker makes shap
# mask the input text token by token.
explainer = shap.Explainer(f, tokenizer)
# %%
# Explain the model output on the selected IMDB reviews. Only the "text"
# column is explained, selected explicitly here.
shap_values = explainer(subset["text"])
# %%
# Entropy threshold (nats) above which a review counts as "uncertain" and its
# explanation is saved. Assuming the two-class SST-2 model, the maximum
# possible entropy is ln(2) ~= 0.693, so 0.5 keeps fairly uncertain cases.
UNCERTAINTY_THRESHOLD = 0.5

for i in range(len(shap_values)):
    print(i)
    # Base value plus summed attributions should reconstruct the explained
    # model output (the entropy from `f`) for this review.
    uncertainty = shap_values[i].base_values + np.sum(shap_values[i].values)
    print(uncertainty)
    if uncertainty > UNCERTAINTY_THRESHOLD:
        # `with` guarantees the HTML is flushed and the file closed, even if
        # write() raises; UTF-8 because review text may be non-ASCII.
        with open(f"{i}.html", "w", encoding="utf-8") as html_file:
            html_file.write(shap.plots.text(shap_values[i], display=False))
