
from transformers import AutoModelForMaskedLM , AutoTokenizer
import torch

# Prompting_roberta
class Prompting_roberta:
  """Cloze-style predictor built on a masked language model.

  A review is joined with a prompt pattern containing the tokenizer's mask
  token; the model's ``top_k`` highest-scoring vocabulary entries for that
  masked slot are returned as strings.
  """

  # init
  def __init__(self, device, k=20, **kwargs):
    """Load the masked-LM and its tokenizer.

    Args:
      device: torch device (or device string) the model and inputs are moved to.
      k: number of candidates returned by ``prompt_pred``.
      **kwargs: must contain ``model`` (a name or path accepted by
        ``from_pretrained``); may contain ``tokenizer`` to load the tokenizer
        from a different name/path (defaults to ``model``).

    Raises:
      KeyError: if ``model`` is not supplied.
    """
    model_path = kwargs['model']
    # Use the model's own tokenizer unless a separate one is given.
    tokenizer_path = kwargs.get('tokenizer') or model_path

    self.device = device

    self.model = AutoModelForMaskedLM.from_pretrained(model_path).to(self.device)
    # Load from tokenizer_path: previously model_path was used here, so the
    # 'tokenizer' kwarg had no effect.
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    self.mask_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

    self.top_k = k
    print(f'top_k : {self.top_k}')

  @staticmethod
  def _build_input_ids(review_ids, pattern_ids, max_length=512):
    """Join a tokenized review and pattern into one id sequence of length <= max_length.

    Both inputs are ``(1, seq)`` id tensors, each produced by tokenizing a
    single string. The review's last id and the pattern's first id are
    dropped (the original code describes the latter as "[CLS]"; presumably
    these are the tokenizer's end/start special tokens — TODO confirm for
    non-RoBERTa tokenizers), so the joined sequence keeps one leading and
    one trailing special token. The review is truncated from the right so
    the result fits in ``max_length``.

    Returns:
      A ``(1, seq)`` tensor with ``seq <= max_length``.

    Raises:
      ValueError: if the pattern leaves no room for even the review's first id.
    """
    review = review_ids[0][:-1]   # drop the review's trailing id
    pattern = pattern_ids[0][1:]  # drop the pattern's leading id
    # Room left for the review once the whole pattern is placed.
    budget = max_length - len(pattern)
    if budget < 1:
      raise ValueError(
          f"pattern is {len(pattern)} tokens long and does not fit in "
          f"max_length={max_length}")
    return torch.cat([review[:budget], pattern]).unsqueeze(0)

  @staticmethod
  def _top_k_ids(scores, k):
    """Return the ids of the ``k`` largest entries of a 1-D score tensor, highest first.

    ``k`` is clamped to the number of entries, matching the behaviour of
    slicing a fully sorted index list.
    """
    k = min(k, scores.size(-1))
    return torch.topk(scores, k).indices.tolist()

  # predict
  def prompt_pred(self, reviews, pattern, max_length=512):
    """Return the top-k candidate strings for the mask token in ``pattern``.

    Args:
      reviews: review text. Only the first sequence of the tokenizer output is
        used, so pass a single string.
      pattern: prompt text appended after the review; the combined sequence
        must contain exactly one mask token.
      max_length: maximum number of token ids fed to the model (512 by
        default, the previously hard-coded limit). The review is truncated
        to fit.

    Returns:
      list[str]: up to ``self.top_k`` decoded candidates, most likely first.

    Raises:
      ValueError: if ``pattern`` alone does not fit within ``max_length``.
      An error from ``Tensor.item`` is raised if the combined sequence does
      not contain exactly one mask token.
    """
    review_ids = self.tokenizer(reviews, return_tensors="pt").input_ids.to(self.device)
    pattern_ids = self.tokenizer(pattern, return_tensors="pt").input_ids.to(self.device)

    input_ids = self._build_input_ids(review_ids, pattern_ids, max_length)

    # Position of the single mask token in the joined sequence.
    mask_pos = (input_ids[0] == self.mask_id).nonzero().item()

    self.model.eval()
    with torch.no_grad():
      # First model output holds the per-position vocabulary scores
      # (presumably (batch, seq, vocab) logits — indexed as such below).
      logits = self.model(input_ids)[0]

    top_ids = self._top_k_ids(logits[0, mask_pos], self.top_k)
    tokens = self.tokenizer.convert_ids_to_tokens(top_ids)

    # Decode each token on its own so sub-word markers become plain text.
    return [self.tokenizer.convert_tokens_to_string([tok]).strip() for tok in tokens]