import pandas as pd
from nltk.tokenize import sent_tokenize
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import transformers
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
import torch
import shap
import pickle
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import os
# Expose only CUDA device 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Load the CSV and give its two columns fixed names.
train = pd.read_csv("semeval_train_500.csv")
train.columns = ['review', 'sentiment']

# Sentence-split every review, and record how many sentences each one has.
doc_split_sen = [sent_tokenize(review) for review in train.review.values]
doc_split_sen_len = [len(sentences) for sentences in doc_split_sen]

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

from sklearn.metrics import *
# Start from the pretrained base classifier, then overwrite its weights with
# the 'state_dict' entry stored in the fine-tuned checkpoint.
bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
ckpt = torch.load(
    '.save_model/bert_base/3epoch(87.92)_500train_base.pt',
    map_location=torch.device('cuda'),
)
bert_base.load_state_dict(ckpt['state_dict'])

# `model` and `shap_use_model` both refer to the same (GPU-resident) module.
model = bert_base
shap_use_model = model.cuda()

# Label names ordered by their integer id in the model config.
label2id = shap_use_model.config.label2id
labels = sorted(label2id, key=label2id.get)
def model_prediction_gpu(x, max_length=512):
    """Score a batch of texts and return per-class log-odds as a NumPy array.

    Each element of ``x`` is encoded with the module-level ``tokenizer``
    (padded/truncated to ``max_length``), run through ``shap_use_model`` on the
    GPU, soft-maxed over the last dimension and converted to logit
    (log-odds) units.

    Args:
        x: Iterable of inputs accepted by ``tokenizer.encode`` (presumably the
            text strings handed over by the SHAP explainer).
        max_length: Padding/truncation length in tokens. Defaults to 512, the
            value that was previously hard-coded.

    Returns:
        numpy.ndarray with one row per element of ``x`` and one column per
        model output (assumes the model's first output is the logits tensor).
    """
    input_ids = torch.tensor(
        [
            tokenizer.encode(v, padding='max_length', max_length=max_length, truncation=True)
            for v in x
        ]
    ).cuda()
    # Mask out padding positions; the comparison result already lives on the
    # same device as `input_ids`, so no second .cuda() call is needed.
    attention_mask = (input_ids != tokenizer.pad_token_id).type(torch.int64)

    # Inference only: skip autograd bookkeeping so no graph is built (and
    # held in GPU memory) for each of the many calls SHAP makes.
    with torch.no_grad():
        outputs = shap_use_model(input_ids, attention_mask=attention_mask)[0]
        scores = torch.softmax(outputs, dim=-1)
        # Clamp probabilities away from exactly 0/1 so that a saturated
        # softmax yields a large finite log-odds instead of +/-inf.
        val = torch.logit(scores, eps=1e-7)

    return val.cpu().numpy()
# Wrap the prediction function in a SHAP explainer; the tokenizer is passed as
# the masker and `labels` supplies the output names.
gpu_explainer = shap.Explainer(model_prediction_gpu, tokenizer, output_names=labels)


def _group_token_values_by_sentence(tokens, values, sentences):
    """Split per-token SHAP values into one list per sentence.

    Tokens are stripped and concatenated, starting at index 1 (index 0 is
    skipped; presumably it is the placeholder for the leading special token),
    until the running string equals the next sentence with its spaces removed.
    The values at exactly those token positions then form that sentence's
    group.

    Args:
        tokens: Sequence of token strings for one document.
        values: 1-D NumPy array of SHAP values, aligned index-for-index with
            ``tokens``.
        sentences: The document's sentences, in order.

    Returns:
        A list with one list of floats per matched sentence. Matching is
        strictly sequential: once a sentence fails to match (for example
        because the text was truncated), it and every later sentence are
        omitted, so the result may be shorter than ``sentences``.
    """
    targets = [sentence.strip().replace(" ", '') for sentence in sentences]
    groups = []
    if not targets:
        # Nothing to align against; avoids indexing into an empty list below.
        return groups

    assembled = ''
    start = 1  # index (into tokens AND values) of the current sentence's first token
    for idx, token in enumerate(tokens[1:], start=1):
        assembled += token.strip()
        if assembled == targets[len(groups)]:
            # `idx` indexes the full token list, so the same bounds select the
            # matching values without a one-position shift.
            groups.append(values[start:idx + 1].tolist())
            assembled = ''
            start = idx + 1
            if len(groups) == len(targets):
                break
    return groups


token_shap_value_scarce = []    # per document: class-1 values for positions 1..511
sentence_in_token_values = []   # flat list: one FloatTensor per matched sentence
sentence_token_zip = []         # per document: list of per-sentence value lists
success_count = 0               # number of sentences matched in the latest document
for review_text, review_sentences in zip(train['review'], doc_split_sen):
    shap_values = gpu_explainer([review_text])
    shap_value_token = shap_values.data[0]  # SHAP result: token strings
    shap_value_values = shap_values.values[:, :, 1]  # SHAP result: values for output index 1
    token_shap_value_scarce.append(shap_value_values[0][1:512])

    sen_list = _group_token_values_by_sentence(
        shap_value_token, shap_value_values[0], review_sentences
    )
    success_count = len(sen_list)

    # Convert each new sentence exactly once, instead of re-wrapping the whole
    # accumulated list after every document.
    sentence_in_token_values.extend(torch.FloatTensor(vals) for vals in sen_list)
    sentence_token_zip.append(sen_list)
    
# Persist the flat list of per-sentence token SHAP values under 'token_xai'.
sentence_in_token_xai_dict_scarce = {'token_xai': sentence_in_token_values}
with open("500trainset_sentence_in_token_xai_value.pickle", "wb") as fw:
    pickle.dump(sentence_in_token_xai_dict_scarce, fw)

# Persist, per document, the number of sentence groups that were collected.
s_len = [len(doc_groups) for doc_groups in sentence_token_zip]
sentence_length = {'sentence_length': s_len}
with open("500trainset_sentence_len.pickle", "wb") as fw:
    pickle.dump(sentence_length, fw)