from transformers.generation_logits_process import LogitsProcessor
import numpy as np
import torch
import tqdm
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the OPT-125M tokenizer and causal language model once, at import time.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the script does not crash in `model.to(device)` on GPU-less hosts.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
if device != 'cpu':
    # Half precision is only enabled on the GPU path (the original behaviour).
    # NOTE(review): fp16 kernels are not guaranteed for every CPU op across
    # torch versions, so the CPU fallback keeps the weights in float32.
    model.half()

# Read the corpus: one sentence per line. Lines that consist solely of a
# newline are skipped, and newline characters are stripped from the rest.
with open('opt_2000.txt') as f:
    x = [line.strip('\n') for line in f if line != '\n']

# Upper bound on token ids used when masking logits
# (presumably the OPT tokenizer's vocabulary size -- TODO confirm).
num_tokens = 50265

import copy

# Candidate replacement tokens are the ids in [VOCAB_START, VOCAB_END).
VOCAB_START = 4
VOCAB_END = 2004
# At most this many corpus lines are examined.
MAX_SENTENCES = 100000
# Only sentences that encode to exactly this many ids are scored.
SENTENCE_LEN = 33
# Sentences containing any id >= MAX_TOKEN_ID are skipped.
# NOTE(review): this bound admits ids 2004-2005 (and ids below VOCAB_START),
# which lie outside the scored range [VOCAB_START, VOCAB_END). A scored
# position holding such an id gets probability 0 for every candidate, so the
# normalisation below becomes 0/0. Confirm whether VOCAB_END was intended.
MAX_TOKEN_ID = 2006


def compute_prob(o_input):
    """Return, per row, the product of next-token probabilities at positions 2..31.

    For each position ``q`` in ``range(2, 32)`` the model is run on columns
    ``1..q-1`` of ``o_input`` and the probability of the token in column
    ``q`` is read from a softmax restricted to ids in
    ``[VOCAB_START, VOCAB_END)``. A target id outside that range contributes
    probability 0, as it did when the excluded logits were set to ``-inf``.

    Args:
        o_input: 2-D integer tensor of token ids, one sequence per row, with
            at least 32 columns. It must live on the same device as ``model``.

    Returns:
        1-D float32 tensor of length ``o_input.shape[0]`` on ``o_input``'s
        device, holding the product of the 30 per-position probabilities.
    """
    with torch.no_grad():
        batch = o_input.shape[0]
        rows = torch.arange(batch, device=o_input.device)
        prod_prob = torch.ones(batch, device=o_input.device)
        for q in range(2, 32):
            logits = model(o_input[:, 1:q]).logits[:, -1]
            # Slice out the allowed ids instead of masking a Python-built
            # index list: the result is mathematically the same softmax, and
            # it also drops any logit columns that lie beyond `num_tokens`
            # (e.g. padding rows of the embedding matrix), which the previous
            # index-based mask left untouched.
            probs = F.softmax(logits[:, VOCAB_START:VOCAB_END], dim=-1,
                              dtype=torch.float32)
            targets = o_input[:, q]
            in_vocab = (targets >= VOCAB_START) & (targets < VOCAB_END)
            # Clamp so out-of-range targets index safely; they are zeroed by
            # `in_vocab` right after.
            cols = (targets - VOCAB_START).clamp(0, probs.shape[1] - 1)
            # Per-row lookup; avoids building a (batch x batch) matrix only
            # to take its diagonal.
            prod_prob *= probs[rows, cols] * in_vocab
    # NOTE(review): a product of 30 probabilities can underflow float32;
    # consider accumulating log-probabilities if zeros/NaNs show up.
    return prod_prob


batch_input_id = []  # token ids of every sentence that passed the filter
prod_list = []       # per sentence: 31 normalised candidate distributions, concatenated
candidate_ids = torch.arange(VOCAB_START, VOCAB_END, device=device)
with torch.no_grad():
    # Slicing (rather than indexing up to MAX_SENTENCES) keeps the loop from
    # running past the end of a shorter corpus.
    for line in tqdm.tqdm(x[:MAX_SENTENCES]):
        btid = tokenizer.encode(tokenizer.tokenize(line))
        if len(btid) != SENTENCE_LEN or max(btid) >= MAX_TOKEN_ID:
            continue
        batch_input_id.append(btid)
        base = torch.tensor(btid, dtype=torch.int64, device=device)
        prod_vec = []
        for kk in range(1, 32):
            # One row per candidate token: a fresh copy of the sentence with
            # position kk replaced. `repeat` allocates new storage, so rows
            # do not alias each other.
            o_input = base.repeat(candidate_ids.shape[0], 1)
            o_input[:, kk] = candidate_ids
            prod_prob = compute_prob(o_input)
            # Normalise over the candidates to get a distribution for kk.
            prod_vec.append((prod_prob / torch.sum(prod_prob)).cpu().numpy())
        prod_list.append(np.concatenate(prod_vec))