import pandas as pd
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import numpy as np
import random
from nltk import sent_tokenize
import matplotlib.pyplot as plt
import os
import time
import math
from train_utils import (cal_running_avg_loss, eta, progress_bar, time_since, user_friendly_time)
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity


# Restrict CUDA to device 0. This is set before the first `.cuda()` call further down in this file.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): `tokenizer` is not referenced anywhere else in this file — confirm it is still needed.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# NOTE(review): duplicate of the top-of-file `import pickle`.
import pickle

# Pickled embeddings for the labeled set; the loaded object is indexed with the
# 'sentence_embedding' key when it is attached to `correct_train_df` below.
with open("./pickle/sem-eval_500dataset_sentence_embedding.pickle","rb") as fr: # with cls, sep token sentence embedding
    all_sentences_embedding_1 = pickle.load(fr)  

# Pickled embeddings for the unlabeled set; the loaded object is attached as a column as-is.
# presumably already 769-dimensional per the file name — TODO confirm
with open("./pickle/unlabeled_dataset_total_embedding_zero_padding(769).pickle","rb") as fr:
    all_sentences_embedding_2 = pickle.load(fr) 

# Sentence-level text tables; the embeddings above are attached as a new
# 'sentence_embeddings' column on each frame.
# assumes the CSV rows are in the same order as the pickled embeddings — TODO confirm
correct_train_df = pd.read_csv("labeled_sentence_embedding_text.csv")
correct_train_df1 = pd.read_csv("unlabeled_sentence_embedding_text.csv")
correct_train_df['sentence_embeddings'] = all_sentences_embedding_1['sentence_embedding']
correct_train_df1['sentence_embeddings'] = all_sentences_embedding_2

# --- Document-level table ---------------------------------------------------
# One row per labeled review; the CSV's three columns are renamed positionally.
correct_doc_train_df = pd.read_csv("sem-eval_500dataset_contained_document_embedding_dataset.csv")
correct_doc_train_df.columns = ['review','doc_mean_embedding','label']

# Tokenize every review once and reuse the result for both the per-review
# sentence counts and the flat list of all sentences.
tokenized_reviews = [sent_tokenize(doc) for doc in correct_doc_train_df.review]
sentence_length = [len(sentences) for sentences in tokenized_reviews]
correct_doc_train_df['sentence_length'] = sentence_length
all_sentences = [sentence for sentences in tokenized_reviews for sentence in sentences]

# Alias, not a copy: columns added through `df_doc` also appear on
# `correct_doc_train_df`.
df_doc = correct_doc_train_df

# Sentence counts with a leading 0, so that their running sum yields the
# start/stop offsets of each review inside the flat sentence table.
make_sentence_count = np.insert(df_doc.sentence_length.values,0,0).tolist()

# Repeat each review's document value once per sentence so that it lines up
# row-for-row with the sentence-level frame.
# assumes `correct_train_df` has exactly one row per tokenized sentence, in
# review order — TODO confirm against the CSV
document_emb = df_doc.doc_mean_embedding.values
sentence_length = df_doc.sentence_length.values
doc_emb = [
    document_emb[num]
    for num, count in enumerate(sentence_length)
    for _ in range(count)
]
correct_train_df['doc_emb'] = doc_emb

# Prepend the document value to every sentence embedding (position 0).
full_sequence_input = [
    np.insert(sentence_emb, 0, doc_emb[num])
    for num, sentence_emb in enumerate(correct_train_df.sentence_embeddings.values)
]
correct_train_df['sequence_embedding'] = full_sequence_input

# Slice the flat sentence table back into one group per review. The running sum
# gives consecutive [start, stop) offsets in a single pass instead of re-summing
# the count prefix for every review.
sentence_offsets = np.cumsum(make_sentence_count)
sequence_values = correct_train_df.sequence_embedding.values
a = [
    sequence_values[start:stop]
    for start, stop in zip(sentence_offsets[:-1], sentence_offsets[1:])
]

df_doc['total_emb']= a

# Original labeled training reviews; the two CSV columns are renamed positionally.
all_doc_train_df = pd.read_csv("./dataset/semeval_train_500.csv")
all_doc_train_df.columns = ['review','sentiment']
# NOTE(review): this repeats the per-review sentence count that was already
# computed for `correct_doc_train_df` earlier in this file (same frame, same
# column). `all_doc_train_df` is not touched here — confirm which frame was intended.
sentence_length = []
for doc in correct_doc_train_df.review:
    docu = sent_tokenize(doc)
    sentence_length.append(len(docu))
correct_doc_train_df['sentence_length'] = sentence_length

# Positional rename: this requires `correct_train_df` to have exactly five columns here.
# NOTE(review): 'emebdding' is spelled this way in the original; it is dropped on the next line.
correct_train_df.columns = ['sentence','emebdding','sentence_embedding','doc_emb','sequence_embedding']

# Keep only the sentence text and its document-prefixed embedding.
correct_train_df = correct_train_df[['sentence','sequence_embedding']]

# The unlabeled frame keeps its raw 'sentence_embeddings' column instead.
correct_train_df1 = correct_train_df1[['sentence','sentence_embeddings']]

# Give both frames the same column names, then stack labeled rows on top of unlabeled rows.
correct_train_df.columns = ['review','sequence_embedding']
correct_train_df1.columns = ['review','sequence_embedding']
correct_train_df_2 = pd.concat([correct_train_df,correct_train_df1], axis = 0).reset_index(drop = True)


# Flat list of every sentence, taken before de-duplication.
add_amazon_all_sentences = []
for i in correct_train_df_2.review.values:
    add_amazon_all_sentences.append(i)
    
# Bare expression (no effect when run as a script).
correct_train_df_2
# Drop repeated sentence texts, keeping the first occurrence, and renumber the index.
correct_train_df_2 = correct_train_df_2.drop_duplicates(subset = 'review', keep = 'first', inplace = False, ignore_index = True)


correct_train_df_2.columns = ['sentence','sentence_embeddings']
sent_df = correct_train_df_2

# Full-range slice of the frame.
sent_df = sent_df[:]

# Map each sentence text to its row index in `sent_df`.
sentence_to_dic = dict(zip(sent_df['sentence'], sent_df.index))

sentence_embeddings = sent_df.sentence_embeddings.values

# Bare expression (no effect when run as a script).
len(sentence_embeddings)

# NOTE(review): these two dumps write to the current working directory, while
# the loads right below read files of the same name from ./pickle/. As written,
# the freshly built `sentence_to_dic` and `sentence_embeddings` are overwritten
# by whatever is stored under ./pickle/ — confirm this is intentional.
with open('sentence_to_dic.pickle', 'wb') as fw:
    pickle.dump(sentence_to_dic, fw)
    
with open('sentence_embeddings_for_augmentation.pickle', 'wb') as fw:
    pickle.dump(sentence_embeddings, fw)
    
# Reload the lookup table and embeddings from ./pickle/ (rebinding the names
# computed above), plus `co_mat`.
# presumably `co_mat` is the pairwise cosine-similarity matrix of the sentence
# embeddings (see the commented-out line below) — TODO confirm
with open("./pickle/sentence_to_dic.pickle","rb") as fr:
    sentence_to_dic = pickle.load(fr)  
with open("./pickle/sentence_embeddings_for_augmentation.pickle","rb") as fr: # with cls, sep token sentence embedding
    sentence_embeddings = pickle.load(fr)  
with open("./pickle/co_mat.pickle","rb") as fr:
    co_mat = pickle.load(fr)  

# Round-trip the loaded embeddings through nested lists into a fresh array.
sentence_matrix = sentence_embeddings.tolist()
sentence_matrix = np.array(sentence_matrix)
# co_mat = cosine_similarity(sentence_matrix, sentence_matrix)

# Per-review sentence counts that are at least 1; the augmentation routine
# samples its review lengths from this array.
sentence_length_array = df_doc.sentence_length.values
valuable_sentence_length_list = []
for i in sentence_length_array:
    if i >= 1:
        valuable_sentence_length_list.append(i)
valuable_sentence_length_array = np.array(valuable_sentence_length_list)

with open("./pickle/critical_idx.pickle","rb") as fr: # index of high shap value sentences 
    critical_idx = pickle.load(fr)
critical_idx = critical_idx['idx']

# Sentence texts taken from the frame built above, not from the reloaded pickles.
# assumes the ./pickle/ files were produced from the same data so that indices
# line up — TODO confirm
add_all_sentences = list(sent_df.sentence.values)

class dataset(Dataset):
    """Labeled dataset backed by a DataFrame.

    Each item is a ``(features, label)`` pair read positionally from the frame:
    column 2 holds the features and column 1 holds the label.

    Note: a second class with the same name is defined further down in this
    file and replaces this one from that point on.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        frame = self.data
        # Local names avoid shadowing the builtin `input`.
        features = torch.Tensor(frame.iloc[idx, 2])
        target = torch.tensor(frame.iloc[idx, 1])
        return features, target

def seed_everything(seed):
    """Seed every random number generator this script uses, for reproducible runs."""
    # Python's own `random` module and NumPy's global generator.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: the CPU generator, then the CUDA generators (seeded separately).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case

    # Make cuDNN pick deterministic algorithms and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

class LSTM(nn.Module):
    """Bidirectional multi-layer LSTM followed by an MLP head with two outputs.

    Args:
        embedding_dim: Size of each time step's input vector.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.

    The sub-module names and the order of the layers inside ``fc_layer`` are
    kept unchanged so that existing checkpoints still load.
    """

    def __init__(self, embedding_dim, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = embedding_dim
        self.lstm_layer = nn.LSTM(input_size=self.input_size,
                                  hidden_size=self.hidden_size,
                                  num_layers=self.num_layers,
                                  bidirectional=True,
                                  dropout=0.3,
                                  batch_first=True)

        # The head consumes the concatenated forward/backward features (2 * hidden_size).
        self.fc_layer = nn.Sequential(nn.Linear(hidden_size * 2, 2048),
                                      nn.Dropout(0.2),
                                      nn.Tanh(),
                                      nn.Linear(2048, 1024),
                                      nn.Dropout(0.1),
                                      nn.Tanh(),
                                      nn.Linear(1024, 512),
                                      nn.Dropout(0.2),
                                      nn.Tanh(),
                                      nn.Linear(512, 2),
                                      # The head input is 2-D (batch, features), for which the
                                      # implicit softmax dim was 1; spelling it out avoids the
                                      # deprecation warning without changing the result.
                                      nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities of shape (batch, 2) for ``x`` of shape (batch, seq_len, embedding_dim)."""
        # nn.LSTM initialises h0/c0 to zeros on the input's device when no state
        # is passed, so the model runs on whichever device it was moved to
        # instead of requiring CUDA.
        output, _ = self.lstm_layer(x)
        # Features at the last time step, both directions concatenated.
        last_output = output[:, -1, :]
        return self.fc_layer(last_output)
    
    
# Build the classifier and move it to the GPU in one step.
lstm_classifier = LSTM(embedding_dim=769, hidden_size=1024, num_layers=4).cuda()

# NOTE(review): CrossEntropyLoss expects raw logits, while LSTM.forward ends in
# a Softmax layer. `criterion` is not used elsewhere in this file — confirm how
# it is used during training.
criterion = torch.nn.CrossEntropyLoss().cuda()

def _redistribute_selected_mass(row, selected):
    """Zero ``row`` at the ``selected`` indices and spread the removed total evenly over the other entries (in place)."""
    chosen_mask = np.zeros(row.shape[0], dtype=bool)
    chosen_mask[selected] = True
    remaining = row.shape[0] - len(selected)
    removed_mass = row[chosen_mask].sum()
    row[chosen_mask] = 0.0
    # Nothing is left to receive the mass once every entry has been selected.
    if remaining > 0:
        row[~chosen_mask] += removed_mass / remaining


def new_augmenatation(count_of_review, cut_off_standard, a, m):
    """Generate synthetic reviews as lists of sentence indices.

    Each review starts from a random index in the module-level ``critical_idx``
    and is extended using that sentence's row of the module-level ``co_mat``.
    The second sentence is the second-ranked entry of that row. Every later
    sentence is drawn uniformly from the top-ranked remaining candidates whose
    score, relative to the current row maximum, exceeds ``cut_off_standard``.
    The size of that candidate window depends on the previous pick's relative
    score and on ``m``. After every pick, the scores of all selected sentences
    are zeroed and their total is spread evenly over the remaining entries.

    Args:
        count_of_review: Number of reviews to generate.
        cut_off_standard: Minimum relative score a candidate must exceed to be
            eligible from the third sentence onwards.
        a: Accepted for backward compatibility. The original implementation
            overwrote it before any use, so it has no effect.
        m: Offset inside the sigmoid-like expression that sizes the candidate window.

    Returns:
        A list with one entry per generated review, each a list of ``int``
        sentence indices. A review is shorter than its sampled length when no
        remaining candidate exceeds the cut-off (the original looped forever
        in that case).
    """
    similarity = np.asarray(co_mat)
    total_sentences = similarity.shape[0]
    index_cut_off = cut_off_standard
    use_sentence_index_zip = []

    for _ in range(count_of_review):
        this_turn_review_sentence_length = np.random.choice(
            valuable_sentence_length_array, 1, replace=True).item()
        select_all_idx = []
        row = None
        sorting_idx_array = None
        sentences_similarity = 0.0

        for count_of_sentences in range(this_turn_review_sentence_length):
            print("start")
            print(this_turn_review_sentence_length)

            if count_of_sentences == 0:
                review_start_idx = int(np.random.choice(critical_idx, 1, replace=False).item())
                select_all_idx.append(review_start_idx)
                # Work on a private copy of this one row; the full matrix is
                # never modified, so it does not need to be copied per review.
                row = np.array(similarity[review_start_idx], copy=True)
                continue

            if count_of_sentences == 1:
                # Candidates ranked by descending score. Entry 0 is normally the
                # start sentence itself, so the runner-up is taken.
                sorting_idx_array = np.argsort(-row)
                if sorting_idx_array.size < 2:
                    break
                chosen = int(sorting_idx_array[1])
                max_prob = row.max()
                # Relative score of this pick; it sizes the next candidate window.
                sentences_similarity = row[chosen] / max_prob
                select_all_idx.append(chosen)
            else:
                max_prob = row.max()
                print("sentences_similarity :", sentences_similarity)
                alpha = sentences_similarity

                # Fraction of the corpus to consider, from the previous pick's
                # relative score squared, shifted by m.
                growth = math.e ** (10 * (alpha * alpha - m))
                key_idx = math.ceil((growth / (10 + growth)) * total_sentences)
                cutting_sorting_idx_array = sorting_idx_array[:key_idx]

                print(index_cut_off)

                # Drawing uniformly from the candidates above the cut-off is
                # equivalent to redrawing until one passes, but it cannot spin
                # forever when none qualifies.
                with np.errstate(divide='ignore', invalid='ignore'):
                    candidate_similarity = row[cutting_sorting_idx_array] / max_prob
                eligible = cutting_sorting_idx_array[candidate_similarity > index_cut_off]
                if eligible.size == 0:
                    break

                print('len(select_all_idx) :', len(select_all_idx))
                chosen = int(np.random.choice(eligible, 1, replace=False).item())
                sentences_similarity = row[chosen] / max_prob
                print('sentences_similarity :', sentences_similarity)
                select_all_idx.append(chosen)

            # Remove everything picked so far from the ranking (order is
            # preserved because assume_unique=True skips the sort).
            sorting_idx_array = np.setdiff1d(sorting_idx_array, select_all_idx, assume_unique=True)
            _redistribute_selected_mass(row, select_all_idx)

            if count_of_sentences >= 2:
                print("------------------------------------------------------------------------")

        # The original never stored the finished review, so it always returned [].
        if select_all_idx:
            use_sentence_index_zip.append(select_all_idx)

    return use_sentence_index_zip


def get_recommendatinos(augmentation_index_list, all_embedding, all_sentences, sequence_max_length=7):
    """Turn lists of sentence indices into review texts and zero-padded embedding stacks.

    Args:
        augmentation_index_list: One list of sentence indices per review.
        all_embedding: Indexable collection of 1-D sentence embeddings. Entry 0
            is used to read the embedding size.
        all_sentences: Indexable collection of sentence strings, aligned with
            ``all_embedding``.
        sequence_max_length: Number of rows in every returned embedding stack
            (default 7, the value previously hard-coded). Reviews with more
            sentences are truncated to this length, in both text and
            embeddings. Previously they raised a negative-dimension
            ``ValueError``.

    Returns:
        A tuple ``(review_list, sentence_embedding, np_sentence_embedding)``:
        ``review_list`` holds one single-element list per review, containing
        the sentences joined with a trailing space each. ``sentence_embedding``
        is a list of ``(sequence_max_length, emb_size)`` float64 arrays, and
        ``np_sentence_embedding`` is the same data as one array.
    """
    emb_size = all_embedding[0].shape[0]

    review_list = []
    sentence_embedding = []
    for review_index in augmentation_index_list:
        kept_indices = list(review_index)[:sequence_max_length]

        # Each sentence is followed by a single space, including the last one.
        review_text = ''.join(all_sentences[index] + ' ' for index in kept_indices)

        # Start from an all-zero stack and fill the leading rows; this also
        # covers an empty review, which used to fail in np.concatenate.
        padded = np.zeros(shape=(sequence_max_length, emb_size), dtype=np.float64)
        if kept_indices:
            padded[:len(kept_indices)] = np.array([all_embedding[index] for index in kept_indices])

        review_list.append([review_text])
        sentence_embedding.append(padded)

    np_sentence_embedding = np.array(sentence_embedding)

    return review_list, sentence_embedding, np_sentence_embedding

def save_pickle_dataset(augment_data,a,m):
    """Materialise augmented reviews as a DataFrame, pickle it, and return it.

    The index lists in ``augment_data`` are resolved through
    ``get_recommendatinos`` using the module-level ``sentence_embeddings`` and
    ``add_all_sentences``. The resulting frame has a 'review' column and a
    'total_emb' column holding the padded embedding stack of each review.

    ``a`` and ``m`` are only used to build the output file name.
    NOTE(review): the "(0.5)" part of the file name is a fixed string.
    """
    reviews, embeddings, _ = get_recommendatinos(augment_data, sentence_embeddings, add_all_sentences)

    ablation = pd.DataFrame(reviews, columns=['review'])
    ablation['total_emb'] = embeddings

    # The text in parentheses names the dataset.
    out_path = './pickle/ablation(sem)(0.5)_a{}_m{}.pickle'.format(a,m)
    with open(out_path, mode='wb') as out_file:
        pickle.dump(ablation, out_file)

    return ablation

# Generate 1000 synthetic reviews with a cut-off of 0.5, then build and pickle
# the matching DataFrame. The same a/m values are used for the file name.
augmentation_f = new_augmenatation(count_of_review=1000, cut_off_standard=0.5, a=3, m=1)
augdataset = save_pickle_dataset(augment_data=augmentation_f, a=3, m=1)

# Same object under a second name: later column assignments also affect `augdataset`.
Augmented_dataset = augdataset
class dataset(Dataset):
    """Unlabeled dataset backed by a DataFrame.

    Each item is the content of column 1 converted to a float tensor; no label
    is returned. ``mode`` is stored on the instance, but items are the same for
    every mode.
    """

    def __init__(self, data, mode):
        self.data, self.mode = data, mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Both the 'train' branch and the fallback returned the same tensor.
        return torch.Tensor(self.data.iloc[idx, 1])
        
# --- Pseudo-label the augmented reviews with the trained LSTM ----------------
aug_dataset = dataset(Augmented_dataset,'train')
# shuffle=False keeps predictions in the same order as the rows of Augmented_dataset.
aug_loader = DataLoader(aug_dataset, batch_size = 32, shuffle = False, num_workers=4)

ckpt = torch.load('./save/lstm/lstm.pt', map_location=torch.device('cuda'))
model = lstm_classifier
model.load_state_dict(ckpt['state_dict'])

model.eval()
pred_label_lstm = []
with torch.no_grad():
    for batch in aug_loader:
        input_embedding = batch.cuda()
        output = model(input_embedding)
        # Threshold per sample: a review keeps its argmax class only when its
        # own top probability reaches 0.99, otherwise it is labeled -1.
        # The original compared the maximum over the whole batch and appended a
        # bare int for low-confidence batches, which broke the flattening below
        # and could not match the number of rows.
        confidence, pred = output.max(dim=-1)
        pred[confidence < 0.99] = -1
        pred_label_lstm.append(pred.tolist())

# Flatten the per-batch lists into one label per augmented review.
pred_all_label_lstm = [label for batch_labels in pred_label_lstm for label in batch_labels]

Augmented_dataset['sentiment'] = pred_all_label_lstm
Augmented_dataset = Augmented_dataset[['review','sentiment']]


# Rename by name rather than by position: the frame has gained a 'total_emb'
# column through the `df_doc` alias, so assigning a four-name list raised a
# length-mismatch error. In-place, so the alias sees the new name as well.
correct_doc_train_df.rename(columns={'label': 'sentiment'}, inplace=True)
train_doc = correct_doc_train_df[['review','sentiment']]
all_doc = all_doc_train_df
# Original labeled reviews first, then the pseudo-labeled augmented ones
# (including those labeled -1).
Augmented_dataset_lstm = pd.concat([all_doc,Augmented_dataset],axis = 0, ignore_index = True)

Augmented_dataset_lstm.to_csv("./dataset/augmentation/augmented_dataset.csv", index = False)

print(Augmented_dataset_lstm.sentiment.value_counts())
