import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AutoTokenizer
from nltk.tokenize import sent_tokenize
from torch.utils.data import Dataset, DataLoader
import numpy as np
from nltk.tokenize import sent_tokenize
import random
import shap
import torch.nn as nn
import torch.nn.functional as F
from nltk import sent_tokenize
import pickle
import torch.optim as optim
import os
# Expose only physical GPU 3 to this process. This must stay above any CUDA use,
# presumably before CUDA is initialised, for the restriction to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import pickle
# Load precomputed token-level XAI (SHAP) values. The pickle is expected to hold a
# dict with key 'token_xai'. mean_pooler indexes it per item and calls .size() on
# each entry, so entries are presumably torch tensors — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/dshome/ddualab/yuho/AAAI/Extract_shap/500trainset_sentence_in_token_xai_value.pickle","rb") as fr: #load shap value
    shap_1 = pickle.load(fr)

    
with open("/dshome/ddualab/yuho/AAAI/Extract_shap/500trainset_sentence_len.pickle","rb") as fr: #load sentence length
    sen_len_1 = pickle.load(fr)

token_shap = shap_1['token_xai']
# NOTE(review): sentence_len is never read anywhere in this file.
sentence_len = sen_len_1['sentence_length']

# NOTE(review): wildcard import; no sklearn.metrics name appears to be used in this file.
from sklearn.metrics import *
# Tokenizer used by dataset.__getitem__.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Build the classifier architecture, then overwrite its weights with the
# fine-tuned checkpoint stored under ckpt['state_dict'].
bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
ckpt = torch.load('/dshome/ddualab/yuho/AAAI/save_model/bert_base/5epoch(87.92)_500train_base.pt', map_location=torch.device('cuda')) 
bert_base.load_state_dict(ckpt['state_dict'])
# Despite its name, this is the underlying BERT encoder (bert_base.bert), not a
# pooler output. mean_pooler calls it and reads .last_hidden_state.
pooler_output = bert_base.bert
pooler_output = pooler_output.cuda()


# Training data. dataset reads text from the first column, and
# doc_mean_embedding additionally expects 'text' and 'label' columns.
train_df = pd.read_csv("/dshome/ddualab/yuho/AAAI/Dataset/semeval_train_500.csv")
class dataset(Dataset):
    """Map-style Dataset that BERT-tokenizes the first column of a DataFrame.

    Each item is the text in column 0 of ``data``, encoded with the module-level
    ``tokenizer`` and padded/truncated to exactly 512 tokens. Items are returned
    as a tuple ``(input_ids, token_type_ids, attention_mask)`` of tensors.
    """

    # Encoder outputs returned by __getitem__, in this order.
    _FIELDS = ('input_ids', 'token_type_ids', 'attention_mask')

    def __init__(self, data, mode):
        # data: DataFrame whose first column holds the raw text.
        # mode: stored on the instance; not read by this class.
        self.data = data
        self.mode = mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raw_text = self.data.iloc[idx, 0]
        encoded = tokenizer(raw_text, padding="max_length", truncation=True, max_length=512)
        return tuple(torch.tensor(encoded[field]) for field in self._FIELDS)
    
def mean_pooling_add_token_xai(hidden_states, mask, shap, token_length):
    """Masked mean-pool one sequence after adding per-token SHAP values.

    For ``i < token_length``, ``shap[i]`` is added to the (masked) hidden state
    at position ``i + 1``. Position 0 is skipped, presumably the [CLS] token.
    The result is then summed over positions and divided by the mask sum.

    Args:
        hidden_states: ``(seq_len, hidden)`` tensor for a single sequence.
        mask: ``(seq_len,)`` mask; each position's hidden state is multiplied by
            it, and the mask sum is the denominator of the mean.
        shap: tensor (or array) indexable along its first axis. Each entry is a
            scalar (added to every hidden unit of its position) or a
            ``(hidden,)`` vector (added element-wise).
        token_length: number of leading ``shap`` entries to apply. It is clamped
            to ``len(shap)`` and to ``seq_len - 1`` so it never indexes past
            either.

    Returns:
        ``(hidden,)`` tensor: the SHAP-shifted masked sum divided by the mask
        sum. The mask sum is clamped to at least 1e-9 to avoid division by zero.

    NOTE(review): SHAP values are added after masking, so values landing on
    padding positions still contribute to the numerator; confirm
    ``token_length`` never exceeds the real token count.
    """
    # Broadcast the mask over the hidden dimension.
    input_mask_expanded = mask.unsqueeze(-1).expand(hidden_states.size()).float()
    # New tensor: the caller's hidden_states is never modified.
    multi_tensor = hidden_states * input_mask_expanded

    # Position 0 is skipped, so at most seq_len - 1 values fit. Never read past shap.
    n = min(int(token_length), len(shap), multi_tensor.size(0) - 1)
    if n > 0:
        offsets = torch.as_tensor(shap[:n]).to(device=multi_tensor.device,
                                               dtype=multi_tensor.dtype)
        if offsets.dim() == 1:
            # One scalar per token: broadcast it across that token's hidden units.
            offsets = offsets.unsqueeze(-1)
        multi_tensor[1:n + 1] = multi_tensor[1:n + 1] + offsets

    # Clamp the denominator so an all-zero mask cannot divide by zero.
    return torch.sum(multi_tensor, 0) / torch.clamp(input_mask_expanded.sum(0), min=1e-9)

def mean_pooling(hidden_states, mask):
    """Return the mask-weighted mean of ``hidden_states`` over its first axis.

    ``mask`` is broadcast across the trailing (hidden) dimension. The
    denominator is clamped to at least 1e-9 so an all-zero mask does not
    divide by zero.
    """
    weights = mask.unsqueeze(-1).expand(hidden_states.size()).float()
    numerator = (hidden_states * weights).sum(dim=0)
    denominator = weights.sum(dim=0).clamp(min=1e-9)
    return numerator / denominator

def mean_pooler(data_loader, shap_value):
    """Encode every item from *data_loader* and return its SHAP-adjusted mean pool.

    Each batch is run through the module-level encoder ``pooler_output`` with
    gradients disabled. Each sequence's last hidden state is then pooled by
    ``mean_pooling_add_token_xai`` together with that item's own SHAP entry.

    Args:
        data_loader: yields ``(input_ids, token_type_ids, attention_mask)``
            batches, as produced by ``dataset``. It must not shuffle, because
            items are matched to ``shap_value`` by their position in the stream.
        shap_value: indexed by each item's dataset-wide position. Each entry
            must support ``.size(0)`` (its number of token values).

    Returns:
        list of pooled tensors, one per item, in loader order.

    Raises:
        IndexError: if the loader yields more items than ``shap_value`` has.
    """
    device = next(pooler_output.parameters()).device
    mean_pool = []
    offset = 0  # dataset-wide index of the first item in the current batch
    for input_ids, token_type_ids, attention_mask in data_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)

        with torch.no_grad():
            outputs = pooler_output(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)
            hidden = outputs.last_hidden_state
            batch_size = input_ids.size(0)
            for row in range(batch_size):
                # Index by position in the whole dataset, not within the batch.
                # Otherwise every batch would reuse the first batch's SHAP values.
                item_idx = offset + row
                if item_idx >= len(shap_value):
                    raise IndexError(
                        "shap_value has {} entries but the loader yielded item #{}".format(
                            len(shap_value), item_idx))
                item_shap = shap_value[item_idx]
                mean_pool.append(mean_pooling_add_token_xai(
                    hidden[row], attention_mask[row], item_shap, item_shap.size(0)))
        offset += batch_size
    return mean_pool

def mean_token_array(pooling_value):
    """Convert each pooled tensor to a NumPy array on the CPU.

    Returns a new list of ``np.ndarray`` in the same order as ``pooling_value``.
    """
    return [np.array(vector.cpu()) for vector in pooling_value]

# Obtain document-level mean-pooled embeddings from the fine-tuned BERT encoder.
def doc_mean_embedding(df_name, file_names, shap):
    """Pool every document in *df_name* and persist the per-document results.

    Side effects:
      * Adds columns 'doc_mean_pooling' (NumPy vector per row) and
        'doc_mean_embedding' (one float per row) to *df_name* in place.
      * Writes "{file_names}_document_mean_embedding.pickle" containing
        {'document_embedding': [float, ...]}.
      * Writes "{file_names}_contained_document_embedding_dataset.csv" with
        columns text, doc_mean_embedding and label.

    Returns:
        The DataFrame slice df_name[['text', 'doc_mean_embedding', 'label']].

    NOTE(review): 'doc_mean_embedding' is the mean over all elements of each
    pooled vector, i.e. a single scalar per document, and only that scalar is
    saved. Confirm this reduction is intended.
    """
    loader = DataLoader(dataset(df_name, 'train'), batch_size=32, shuffle=False, num_workers=2)
    pooled = mean_pooler(loader, shap)
    df_name['doc_mean_pooling'] = mean_token_array(pooled)

    scalar_means = [vec.mean().cpu().tolist() for vec in pooled]

    with open("{}_document_mean_embedding.pickle".format(file_names), "wb") as fw:
        pickle.dump({'document_embedding': scalar_means}, fw)

    df_name['doc_mean_embedding'] = scalar_means
    save_df = df_name[['text', 'doc_mean_embedding', 'label']]
    save_df.to_csv("{}_contained_document_embedding_dataset.csv".format(file_names), index=False)
    return save_df

# Obtain sentence-level mean-pooled embeddings from the fine-tuned BERT encoder.
def sentence_embedding(df_name, file_names, shap):
    """Split every document into sentences, pool each sentence, and persist them.

    Prints the number of documents, then sentence-tokenizes ``df_name.text``.
    Each sentence is pooled via ``mean_pooler`` with *shap*.

    Side effects:
      * Writes "{file_names}_sentence_embedding.pickle" containing
        {'sentence_embedding': [np.ndarray, ...]}.

    Returns:
        (list of per-sentence NumPy vectors, DataFrame with columns
        'sentence' and 'sentence_embedding').

    NOTE(review): *shap* is indexed per sentence by mean_pooler. Confirm it
    holds sentence-level values rather than document-level ones.
    """
    reviews = df_name.text.values
    print(len(reviews))

    sentences = [piece for review in reviews for piece in sent_tokenize(review)]
    sentence_df = pd.DataFrame(sentences, columns=['sentence'])

    loader = DataLoader(dataset(sentence_df, 'train'), batch_size=32, shuffle=False, num_workers=2)
    # Documents are split into sentences; one embedding per sentence.
    vectors = mean_token_array(mean_pooler(loader, shap))
    sentence_df['sentence_embedding'] = vectors

    with open("{}_sentence_embedding.pickle".format(file_names), "wb") as fw:
        pickle.dump({'sentence_embedding': vectors}, fw)

    return vectors, sentence_df

# Document-level run. doc_mean_embedding assigns the columns 'doc_mean_pooling'
# and 'doc_mean_embedding' onto train_df in place.
train_df_document_embedding = doc_mean_embedding(train_df, "sem-eval_500dataset", token_shap)
# Sentence-level run over the sentences of train_df.text.
# NOTE(review): the same token_shap is passed as for the document-level run, but
# mean_pooler indexes it per sentence here — confirm which granularity it holds.
train_df_sentence_embedding = sentence_embedding(train_df, "sem-eval_500dataset", token_shap)