from transformers.generation_logits_process import LogitsProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import torch
import tqdm

class OptFreqLogitsProcessor(LogitsProcessor):
    """Restrict sampling to a fixed token-id window and force EOS at the last step.

    On every decoding step all logits whose token id lies outside
    ``[allowed_start, allowed_end)`` are set to ``-inf``. When the current
    sequence length equals ``max_length - 1`` (i.e. exactly one token remains
    before ``max_length``), every logit except ``eos_token_id`` is set to
    ``-inf`` and the EOS logit is set to ``0``, so EOS is the only token that
    can be chosen.

    Args:
        eos_token_id: Token id forced at the final step.
        max_length: Sequence length the forcing is keyed to; it only triggers
            if ``generate`` is actually run up to this length.
        allowed_start: First token id (inclusive) that may be sampled.
        allowed_end: Token id (exclusive) past which nothing may be sampled.
    """

    def __init__(self, eos_token_id: int, max_length: int = 32,
                 allowed_start: int = 4, allowed_end: int = 2004):
        self.eos_token_id = eos_token_id
        self.max_length = max_length
        self.allowed_start = allowed_start
        self.allowed_end = allowed_end

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask ``scores`` in place and return it.

        ``input_ids`` is only used for its last-dimension length;
        ``scores`` is indexed as (batch, vocab) below.
        """
        # Slice over the full width of ``scores`` rather than a hard-coded
        # vocabulary size: the model's logit dimension can be larger than the
        # tokenizer vocabulary (padded columns), and those must be masked too.
        scores[:, :self.allowed_start] = -float("inf")
        scores[:, self.allowed_end:] = -float("inf")

        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            # Final step: only EOS remains selectable.
            scores.fill_(-float("inf"))
            scores[:, self.eos_token_id] = 0
        return scores

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
device = 'cuda:1'
model = model.to(device)

# 4096 parallel prompts, each a single token with id 2.
# NOTE(review): presumably OPT's BOS "</s>"; confirm against tokenizer.bos_token_id.
prompt_text = torch.full((4096, 1), 2, dtype=torch.long, device=device)

# The processor keeps no per-call state, so one instance serves every batch.
freq_processor = OptFreqLogitsProcessor(2)

output_list = []
for _ in tqdm.tqdm(range(10000)):
    output_sequences = model.generate(
        input_ids=prompt_text,
        top_k=0,
        do_sample=True,
        # Without an explicit max_length, generate() uses the config default
        # (20 in transformers unless the model config overrides it). The
        # processor's EOS forcing at length max_length - 1 would then never run.
        max_length=freq_processor.max_length,
        logits_processor=[freq_processor],
        renormalize_logits=True,
    )
    output_list.append(output_sequences.cpu().numpy())

# NOTE: all batches are held in host memory until this single save
# (10000 x 4096 x max_length int64 values).
np.save('opt_2000.npy', np.concatenate(output_list, axis=0))