import pandas as pd
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import numpy as np
import random
from nltk import sent_tokenize
import matplotlib.pyplot as plt
import os
import time
from train_utils import (cal_running_avg_loss, eta, progress_bar, time_since, user_friendly_time)
# Restrict CUDA to device 0 (only effective if set before CUDA is first initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Active inputs are the "sem-eval_500dataset" artifacts. Earlier runs used the
# 407-document variants instead: 407dataset_token_embedding_add_shap_sentence_embedding_real.pickle,
# 407_dataset_all_sentence.csv and correct_dataset_407_doc_emb.csv.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open("sem-eval_500dataset_sentence_embedding.pickle","rb") as fh:
    all_sentences_embedding = pickle.load(fh)

# Sentence-level table; later code expects one row per sentence, in document order.
# The pickled 'sentence_embedding' entries are attached positionally
# (assumes they are in the same row order as the CSV -- TODO confirm).
correct_train_df = pd.read_csv("labeled_sentence_embedding_text.csv")
correct_train_df['sentence_embeddings'] = all_sentences_embedding['sentence_embedding']

# Document-level table; its `text` column holds the full document.
correct_doc_train_df = pd.read_csv("sem-eval_500dataset_contained_document_embedding_dataset.csv")

# Split every document into sentences exactly once. Both the per-document
# sentence counts and the flat sentence list are derived from this single pass
# (previously each document was run through sent_tokenize twice).
tokenized_docs = [sent_tokenize(doc) for doc in correct_doc_train_df.text]

# Number of sentences per document; used later to map sentence rows back to documents.
sentence_length = [len(sents) for sents in tokenized_docs]
correct_doc_train_df['sentence_length'] = sentence_length

# All sentences of all documents, flattened in document order.
all_sentences = [sentence for sents in tokenized_docs for sentence in sents]

# Short alias for the document table (same object, not a copy), with columns
# renamed positionally: text, document-level embedding, label, sentence count.
df_doc = correct_doc_train_df
df_doc.columns = ['review','doc_mean_embedding','sentiment','sentence_length']

# Sentence counts with a leading zero, [0, n_0, n_1, ...]; their prefix sums
# are the row offsets of each document inside correct_train_df.
make_sentence_count = [0] + df_doc.sentence_length.values.tolist()

document_emb = df_doc.doc_mean_embedding.values
sentence_length = df_doc.sentence_length.values

# Repeat each document's embedding once per sentence so it lines up
# row-for-row with the sentence-level table.
doc_emb = [
    emb
    for emb, n_sentences in zip(document_emb, sentence_length)
    for _ in range(n_sentences)
]
correct_train_df['document_embedding'] = doc_emb

# Prepend the document embedding to every sentence vector (np.insert at index 0).
# Presumably doc_mean_embedding is a scalar, giving 1 + 768 = 769 features to match
# LSTM(embedding_dim=769) -- TODO confirm.
full_sequence_input = [
    np.insert(sent_vec, 0, doc_vec)
    for sent_vec, doc_vec in zip(correct_train_df.sentence_embeddings.values, doc_emb)
]
correct_train_df['sequence_embedding'] = full_sequence_input

# Cut the flat per-sentence sequence column back into one chunk per document.
# Document k owns rows [offsets[k], offsets[k + 1]), where offsets are the prefix
# sums of make_sentence_count ([0, n_0, n_1, ...]). Computing the prefix sums once
# replaces the previous per-iteration sum(...) over a growing slice (O(n_docs**2)).
doc_offsets = np.cumsum(make_sentence_count)
sequence_values = correct_train_df.sequence_embedding.values

# `a` holds, per document, the array of its sentence sequence vectors.
a = [
    sequence_values[start:stop]
    for start, stop in zip(doc_offsets[:-1], doc_offsets[1:])
]
df_doc['total_emb'] = a

# Persist the per-sentence [document embedding | sentence embedding] vectors.
sen_embeddings = {'sentence_embedding': full_sequence_input}
with open("500dataset_sentence_embedding_with_doc_emb(769_new_idea_real).pickle","wb") as out_file:
    pickle.dump(sen_embeddings, out_file)
    
# Pad (with float64 zero rows) or truncate every document to exactly
# `sentence_max_length` sentence vectors so documents can be batched together.
document_sentence_embedding = []  # one (sentence_max_length, dim) array per document
sentence_max_length = 7

for doc_sequences in df_doc['total_emb'].values:
    # Stack once per document (the previous code stacked twice).
    # np.stack raises ValueError for a document with zero sentences.
    embedding_stack = np.stack(doc_sequences, axis=0)
    n_sentences = embedding_stack.shape[0]
    if n_sentences <= sentence_max_length:
        padding = np.zeros(
            shape=(sentence_max_length - n_sentences, embedding_stack.shape[-1]),
            dtype=np.float64,
        )
        # Concatenating with float64 zeros may upcast the result to float64, even
        # when no padding rows are needed (n_sentences == sentence_max_length).
        document_sentence_embedding.append(
            np.concatenate((embedding_stack, padding), axis=0)
        )
    else:
        # Longer documents keep only their first `sentence_max_length` sentences.
        document_sentence_embedding.append(embedding_stack[:sentence_max_length])

# Save the padded per-document matrices, then read the file straight back so the
# `total_emb` column holds exactly what was persisted.
doc_embedding_zero_padding = {'doc_embedding': document_sentence_embedding}
with open("500labeled_dataset_total_embedding_zero_padding(7len).pickle","wb") as dump_file:
    pickle.dump(doc_embedding_zero_padding, dump_file)

# After this, `document_sentence_embedding` is the loaded dict
# ({'doc_embedding': [...]}), not the list built above.
with open("500labeled_dataset_total_embedding_zero_padding(7len).pickle","rb") as load_file:
    document_sentence_embedding = pickle.load(load_file)
df_doc['total_emb'] = document_sentence_embedding['doc_embedding']

# Keep only the columns the Dataset reads positionally:
# 0 = review text, 1 = sentiment label, 2 = padded embedding sequence.
train_d = df_doc.loc[:, ['review','sentiment','total_emb']]

# NOTE(review): the evaluation frame is an unfiltered slice of the training frame,
# so "Val Loss"/"Val Acc" measure fit on training data, not generalisation.
# Introduce a held-out split if a real validation metric is intended.
train_df, test_df = train_d, train_d[:]

class dataset(Dataset):
    """Map-style dataset over a document DataFrame.

    Columns are read by position, not by name: column 2 is the model input
    (a padded sentence-embedding matrix) and column 1 is the class label.
    """

    # Positional column indices into the wrapped DataFrame.
    _LABEL_COL = 1
    _INPUT_COL = 2

    def __init__(self, data):
        # `data` must support `len()` and `.iloc[row, col]` (a pandas DataFrame).
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # torch.Tensor(...) builds a float32 tensor from the stored array.
        features = torch.Tensor(self.data.iloc[idx, self._INPUT_COL])
        target = torch.tensor(self.data.iloc[idx, self._LABEL_COL])
        return features, target

def seed_everything(seed):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Seeds PyTorch (CPU, current CUDA device and all CUDA devices), NumPy's
    global RNG and Python's `random` module with the same `seed`.
    """
    # cuDNN: use deterministic algorithms and turn off the autotuner, which may
    # pick different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    for seed_fn in (
        torch.manual_seed,           # torch RNG
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
        np.random.seed,              # NumPy global RNG
        random.seed,                 # Python stdlib RNG
    ):
        seed_fn(seed)


seed_everything(42)

class LSTM(nn.Module):
    """Bidirectional multi-layer LSTM classifier over a sequence of embeddings.

    Input ``x`` is batch-first, ``(batch, seq_len, embedding_dim)``. The BiLSTM
    output at the last time step (``2 * hidden_size`` features) is passed
    through a Tanh MLP head that produces 2 class scores.

    By default ``forward`` returns raw logits, which is what
    ``nn.CrossEntropyLoss`` expects (it applies log-softmax internally).
    Pass ``apply_softmax=True`` to get probabilities instead. The extra
    Softmax has no parameters and is appended last, so state_dict keys are
    identical either way and existing checkpoints load into both.
    """

    def __init__(self, embedding_dim, hidden_size, num_layers, apply_softmax=False):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = embedding_dim
        self.lstm_layer = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=True,
            dropout=0.3,  # applied between stacked LSTM layers, not after the last
            batch_first=True,
        )

        head = [
            nn.Linear(hidden_size * 2, 2048),
            nn.Dropout(0.2),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.Dropout(0.1),
            nn.Tanh(),
            nn.Linear(1024, 512),
            nn.Dropout(0.2),
            nn.Tanh(),
            nn.Linear(512, 2),
        ]
        if apply_softmax:
            head.append(nn.Softmax(dim=-1))
        # Linear layers stay at Sequential indices 0, 3, 6, 9, matching the
        # original layout, so saved state_dict keys remain valid.
        self.fc_layer = nn.Sequential(*head)

    def forward(self, x):
        # Initial (h_0, c_0) are omitted: nn.LSTM then uses zeros on x's device,
        # so the model runs on CPU as well as CUDA (no hard-coded .cuda()).
        output, _ = self.lstm_layer(x)  # output: (batch, seq_len, 2 * hidden_size)
        # NOTE(review): in a bidirectional LSTM, output[:, -1, hidden_size:] is the
        # backward direction after reading only the final (possibly zero-padded)
        # step. Concatenating h_n of both directions is the usual choice; kept as-is
        # so trained checkpoints behave the same.
        last_output = output[:, -1, :]
        return self.fc_layer(last_output)
    
# 769 input features per sequence element (document feature prepended to each
# sentence vector during preprocessing), 4-layer BiLSTM with 1024 hidden units.
lstm_classifier = LSTM(embedding_dim=769, hidden_size=1024, num_layers=4).cuda()
model = lstm_classifier  # alias used by train() / evaluate() / save_model()

train_dataset = dataset(train_df)
test_dataset = dataset(test_df)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# TODO(review): a "label smoothing" note was left here; it is not implemented.
optimizer = torch.optim.Adam(model.parameters(), lr=0.000015)
criterion = torch.nn.CrossEntropyLoss().cuda()

def train(num_epochs):
    """Train the module-level ``model`` for ``num_epochs`` epochs.

    Uses the module-level ``train_loader``, ``optimizer`` and ``criterion``.
    After each epoch it runs ``evaluate`` and writes one checkpoint via
    ``save_model`` to ./save_model/lstm/ (named ``{epoch}_{val_loss:.4f}.pt``).

    Args:
        num_epochs: number of passes over ``train_loader``.
    """
    # Smoothed loss via cal_running_avg_loss; carried across epochs.
    running_avg_loss = 0.0
    batch_nb = len(train_loader)
    model.zero_grad()
    for epoch in range(1, num_epochs + 1):
        start = time.time()
        model.train()
        # Accuracy is tracked per epoch. Previously these counters were never
        # reset, so "Train Acc" was a cumulative average since epoch 1.
        total_correct = 0
        total_len = 0
        train_acc = 0.0
        msg = ""  # latest progress line; also handed to evaluate() for its status output
        for batch_idx, (input_embedding, label) in enumerate(train_loader, start=1):
            input_embedding = input_embedding.cuda()
            label = label.cuda()

            output = model(input_embedding)
            loss = criterion(output, label)

            pred = torch.argmax(output, dim=-1)
            total_correct += pred.eq(label).sum().item()
            total_len += len(label)

            loss.backward()
            optimizer.step()
            # self.lr_scheduler.step()
            model.zero_grad()

            running_avg_loss = cal_running_avg_loss(loss.item(), running_avg_loss)
            train_acc = total_correct / total_len

            msg = "{}/{} {} - ETA : {} - Loss: {:.4f}, Acc: {:.4f}".format(
                batch_idx, batch_nb,
                progress_bar(batch_idx, batch_nb),
                eta(start, batch_idx, batch_nb),
                running_avg_loss, train_acc)
            print(msg, end="\r")

        val_loss, val_acc = evaluate(msg)

        save_model(model, val_loss, epoch, model_dir='./save_model/lstm/')
        print("Epoch {} took {} - Train Loss: {:.4f} - Val Loss: "
              "{:.4f} - Train Acc: {:.4f} - Val Acc: {:.4f}".format(epoch,
                                                                    user_friendly_time(time_since(start)),
                                                                    running_avg_loss,
                                                                    val_loss,
                                                                    train_acc,
                                                                    val_acc
                                                                    ))
        
from sklearn.metrics import *

def evaluate(msg):
    """Evaluate the module-level ``model`` on ``test_loader``.

    Args:
        msg: progress text from the training loop; it is prefixed to the
            per-batch evaluation status line.

    Returns:
        ``(val_loss, val_acc)``: the per-example mean loss and the accuracy
        over every example yielded by ``test_loader``.
    """
    val_batch_nb = len(test_loader)
    loss_sum = 0.0
    total_correct = 0
    total_len = 0
    model.eval()
    # No gradients are needed anywhere during evaluation.
    with torch.no_grad():
        for i, (input_embedding, label) in enumerate(test_loader, start=1):
            input_embedding = input_embedding.cuda()
            label = label.cuda()

            output = model(input_embedding)
            loss = criterion(output, label)
            pred = torch.argmax(output, dim=-1)

            batch_len = len(label)
            # `criterion` averages within a batch (CrossEntropyLoss default), so
            # re-weight by batch size; a plain mean of batch means over-weights
            # a smaller final batch.
            loss_sum += loss.item() * batch_len
            total_correct += pred.eq(label).sum().item()
            total_len += batch_len

            print("{} =>   Evaluating : {}/{}".format(msg, i, val_batch_nb), end="\r")

    val_acc = total_correct / total_len
    val_loss = loss_sum / total_len
    return val_loss, val_acc

def save_model(model, loss, epoch, model_dir):
    """Save ``model``'s weights to ``{model_dir}/{epoch}_{loss:.4f}.pt``.

    The checkpoint is a dict ``{"state_dict": ...}``. If the model has a
    ``.module`` attribute (as wrappers such as ``nn.DataParallel`` do), the
    inner module is saved so keys carry no wrapper prefix.

    Args:
        model: module (optionally wrapped) whose weights are saved.
        loss: value embedded in the file name with 4 decimals.
        epoch: epoch number embedded in the file name.
        model_dir: target directory; created if it does not exist.

    Returns:
        Path of the written checkpoint file.
    """
    model_to_save = model.module if hasattr(model, "module") else model
    # torch.save does not create missing parent directories.
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, "{}_{:.4f}.pt".format(epoch, loss))
    torch.save({"state_dict": model_to_save.state_dict()}, model_save_path)
    return model_save_path
    
train(10)

# Reload trained weights into the classifier.
# NOTE(review): save_model() is called with model_dir='./save_model/lstm/' and
# writes '{epoch}_{loss}.pt' files, but the path below is never written by this
# script -- confirm which checkpoint is meant to be loaded.
model = lstm_classifier
ckpt = torch.load('./train_model/save/lstm/lstm.pt', map_location=torch.device('cuda'))
model.load_state_dict(ckpt['state_dict'])
