from prompt_util import Prompting_roberta
from prompt_config import Prompt_Config
from tqdm import tqdm
import pandas as pd

from transformers import logging
logging.set_verbosity(logging.ERROR)    # transformers: log errors only, hiding its warnings


def _fill_prompt_columns(df, source_col, patterns, prompter):
    """Add one column per pattern holding ``prompter.prompt_pred`` over ``df[source_col]``.

    Args:
        df: DataFrame modified in place.
        source_col: name of the text column fed to the prompter.
        patterns: mapping of output column name -> pattern passed to ``prompt_pred``.
        prompter: object exposing ``prompt_pred(text, pattern)``.
    """
    print(f'infer vocab : {list(patterns.keys())}')
    for col_name, pattern in tqdm(patterns.items()):
        # str() so non-string cells (e.g. NaN from read_csv) are still passed as text.
        # `pattern` is used immediately by apply(), so the lambda's late binding is safe.
        df[col_name] = df[source_col].apply(lambda text: prompter.prompt_pred(str(text), pattern))


def infer_vocab(df_name, df, user_patterns, item_patterns, pickle_path, pickle_file_name, prompter=None) :
    """Fill prompt-prediction columns on ``df`` and pickle the result.

    For each ``(column name, pattern)`` in ``user_patterns``, every value of
    ``df['user_reviews_concat']`` is run through ``prompter.prompt_pred`` and the
    prediction is stored in that column. ``item_patterns`` does the same with
    ``df['item_reviews_concat']``. ``df`` is modified in place and then written to
    ``{pickle_path}/{pickle_file_name}_{df_name}.pkl``.

    Args:
        df_name: split label, used in log output and in the pickle file name.
        df: DataFrame with 'user_reviews_concat' and 'item_reviews_concat' columns.
        user_patterns: mapping of output column name -> pattern for user reviews.
        item_patterns: mapping of output column name -> pattern for item reviews.
        pickle_path: directory the pickle is written into.
        pickle_file_name: file-name prefix of the pickle.
        prompter: object exposing ``prompt_pred(text, pattern)``. Defaults to
            ``None``, in which case the module-level ``prompting`` object (created
            in this file's ``__main__`` block) is used, as before.

    Returns:
        None.

    Raises:
        NameError: if ``prompter`` is None and no module-level ``prompting`` exists.
    """
    if prompter is None:
        # Backward-compatible fallback to the global set up under __main__.
        prompter = prompting

    print(f'infer {df_name} dataset')

    _fill_prompt_columns(df, 'user_reviews_concat', user_patterns, prompter)
    _fill_prompt_columns(df, 'item_reviews_concat', item_patterns, prompter)

    df.to_pickle(f'{pickle_path}/{pickle_file_name}_{df_name}.pkl')

    # Report completion only after the pickle has actually been written.
    print('saving_complete')




if __name__ == '__main__':

    # Only the RoBERTa prompting backend is implemented. Fail loudly instead of
    # silently exiting without doing any work when another model is configured.
    if Prompt_Config.model_path != "roberta-base":
        raise ValueError(
            f"unsupported Prompt_Config.model_path {Prompt_Config.model_path!r}; "
            "only 'roberta-base' is implemented"
        )

    print("prompting - roberta")

    # infer_vocab reads this module-level name, so it must stay `prompting`.
    prompting = Prompting_roberta(device=Prompt_Config.device, k=Prompt_Config.top_k, model=Prompt_Config.model_path)

    # Splits to run inference on, keyed by the name used in the output pickle.
    # To process the other splits too, add
    #   'train': Prompt_Config.train_file, 'valid': Prompt_Config.valid_file
    # Only the CSVs listed here are loaded, one at a time.
    split_files = {'test': Prompt_Config.test_file}

    for df_name, csv_path in split_files.items():
        df = pd.read_csv(csv_path)
        infer_vocab(df_name, df, Prompt_Config.user_patterns, Prompt_Config.item_patterns,
                    Prompt_Config.pickle_path, Prompt_Config.pickle_file_name)


    