# Extract per-word output probabilities from sentences generated by an LLM.
# import nltk
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
# from minigpt4.conversation.output_probs import output_probs



def output_probs(output_idss, probs, tokenizer, *, threshold=0.9):
    """Collect per-word generation probabilities for a batch of generated texts.

    Each generated id sequence is decoded to text, split into words with NLTK
    and POS-tagged. Every word is re-encoded on its own with ``tokenizer``.
    For each of its sub-tokens, the first position in the generated id
    sequence holding that id is located. The value stored at that position in
    ``probs`` is then recorded.

    Args:
        output_idss: 2-D tensor of generated token ids, one row per sample
            (row ``j`` is ``output_idss[j]``). Presumably ``outputs.sequences``
            from ``generate`` -- TODO confirm whether prompt tokens are
            included, because positions in a row are used directly as step
            indices into ``probs``.
        probs: per-step sequence indexed as ``probs[step][j][token_id]``.
            Assumed to hold probabilities (e.g. softmaxed ``outputs.scores``),
            not raw logits -- confirm with the caller.
        tokenizer: HF-style tokenizer providing ``batch_decode`` and ``encode``.
        threshold: a sub-token whose negative log-probability exceeds this
            value marks its noun as uncertain. The default of 0.9 corresponds
            to a probability below roughly 0.41.

    Returns:
        A list with one dict per sample:
          - ``"caption"``: the decoded text.
          - ``"objs"``: distinct words tagged ``NN*``, in order of first
            appearance.
          - ``"plist"``: for each noun newly flagged as uncertain, the
            sub-token probability that flagged it. The flagged nouns
            themselves are not included in the result.
          - ``"p_all"``: maps each word to the list of sub-token
            probabilities that could be looked up for it.

    Note:
        Matching is by token id, not by position. A token that occurs several
        times therefore always maps to its first occurrence. A word encoded in
        isolation may also tokenize differently than it did in context; any
        sub-token not found in the sequence is skipped.
    """
    import logging

    logger = logging.getLogger(__name__)

    output_texts = tokenizer.batch_decode(output_idss, skip_special_tokens=True)

    results = []
    for j, (output_text, output_ids) in enumerate(zip(output_texts, output_idss)):
        pos_tags = pos_tag(word_tokenize(output_text.strip()))

        wordlist = []    # distinct nouns, in order of first appearance
        u_wordlist = []  # distinct nouns flagged as uncertain
        p_list = []      # probability that flagged each entry of u_wordlist
        p_all = {}       # word -> probabilities of its matched sub-tokens

        for word0, pos0 in pos_tags:
            is_noun = pos0.startswith('NN')
            p_all.setdefault(word0, [])
            if is_noun and word0 not in wordlist:
                wordlist.append(word0)

            token_ids = tokenizer.encode(
                word0, return_tensors="pt", add_special_tokens=False)
            # encode() on a single string yields a (1, n_subtokens) tensor:
            # walk every sub-token of the word (row 0), not the batch axis.
            for token in token_ids[0].tolist():
                matches = torch.where(output_ids == token)[0]
                if matches.numel() == 0:
                    continue
                prob0 = probs[matches[0].item()][j][token].cpu().item()
                p_all[word0].append(prob0)
                if (-np.log(prob0) > threshold
                        and is_noun and word0 not in u_wordlist):
                    u_wordlist.append(word0)
                    p_list.append(prob0)
                    # The noun is already flagged; skip its remaining sub-tokens.
                    break

        logger.debug("%s %s", output_text, p_all)
        results.append({"caption": output_text, "objs": wordlist,
                        "plist": p_list, "p_all": p_all})

    return results

# output, _,  _, u_wordlist, wordlist, plist, p_all = chat.answer(chat_state, img_list)
#                     result = {"id": filename, "question": prompt, "caption": output,"objs": wordlist, "plist": float_list, "p_all": p_all, "model": "MiniGPT-4_13b"}
