import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AutoTokenizer
from nltk.tokenize import sent_tokenize
from torch.utils.data import Dataset, DataLoader
import numpy as np
from nltk.tokenize import sent_tokenize
import random
import torch.nn as nn
import torch.nn.functional as F
from nltk import sent_tokenize
import pickle
import torch.optim as optim
import time
from train_utils import (cal_running_avg_loss, eta, progress_bar,
                        time_since, user_friendly_time)
import os as os

# Read the training and test splits from CSV.
# Disabled alternative: also read ./dataset/trust_unlabeled_df.csv and
# concatenate it onto the training frame.
train_df = pd.read_csv("./dataset/semeval_train_500.csv")
test_df = pd.read_csv("./dataset/semeval17_test.csv")
print(train_df)

# Disabled alternative training file: dataset_500.csv
# Keep only the text and label columns, in that order; the dataset class
# below reads them by position.
train_df = train_df[["text", "label"]]

# Load the pretrained classifier onto the GPU and wrap it in DataParallel.
# No label count is passed to from_pretrained here.
bert_base = nn.DataParallel(
    BertForSequenceClassification.from_pretrained("bert-base-uncased").cuda()
)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class dataset(Dataset):
    """Dataset that tokenizes one row of a two-column table per lookup.

    ``data`` is indexed positionally with ``iloc``: column 0 is read as the
    text and column 1 as the label. Tokenization is done with the
    module-level ``tokenizer``.
    """

    # Keys of the tokenizer output, in the order they are returned.
    _FEATURE_KEYS = ('input_ids', 'token_type_ids', 'attention_mask')

    def __init__(self, data, mode):
        """Store the table and the mode string ('train' adds a label axis)."""
        self.mode = mode
        self.data = data

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, token_type_ids, attention_mask, label)`` tensors.

        The text is lower-cased, then padded/truncated to 512 tokens. When
        ``mode`` is ``'train'`` the label tensor gets a trailing dimension
        of size 1; otherwise it is returned as built.
        """
        lowered = self.data.iloc[idx, 0].lower()
        raw_label = self.data.iloc[idx, 1]
        encoded = tokenizer(
            lowered,
            padding="max_length",
            truncation=True,
            max_length=512,
        )
        features = tuple(torch.tensor(encoded[key]) for key in self._FEATURE_KEYS)
        label = torch.tensor(raw_label)
        if self.mode == 'train':
            label = label.unsqueeze(-1)
        return (*features, label)
# Both datasets are built with mode 'train', so labels from either loader
# carry the trailing size-1 dimension added in dataset.__getitem__.
train_dataset = dataset(train_df, 'train')
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
test_dataset = dataset(test_df, 'train')
test_loader = DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=2,
)

model = bert_base

optimizer = optim.Adam(model.parameters(), lr=2e-5)
# Label smoothing (disabled): CrossEntropyLoss(label_smoothing=0.2)
criterion = nn.CrossEntropyLoss()

# def weighted_binary_cross_entropy(output, target, weights=None):
        
#     if weights is not None:
#         assert len(weights) == 2
        
#         loss = weights[0] * (target * torch.log(output)) + \
#                weights[1] * ((1 - target) * torch.log(1 - output))
#     else:
#         loss = target * torch.log(output) + (1 - target) * torch.log(1 - output)

#     return torch.neg(torch.mean(loss))

def train(num_epochs):
    """Fine-tune the module-level ``model`` for ``num_epochs`` epochs.

    Each epoch iterates over ``train_loader``, optimizes with the
    module-level ``optimizer`` and ``criterion``, then calls ``evaluate``
    and writes a checkpoint with ``save_model`` under
    ``./save_model/bert_base/``. Progress and per-epoch summaries are
    printed to stdout.

    Args:
        num_epochs: Number of passes over ``train_loader``.

    Notes:
        Batches are moved with ``.cuda()``, so a CUDA device is required.
        The running average loss carries across epochs; the accuracy
        counters restart every epoch so the reported training accuracy
        describes that epoch only.
    """
    running_avg_loss = 0.0
    batch_nb = len(train_loader)
    model.zero_grad()

    for epoch in range(1, num_epochs + 1):
        start = time.time()
        model.train()

        # Per-epoch accuracy bookkeeping. msg/train_acc are pre-bound so the
        # code after the loop still works if the loader yields no batches.
        total_correct = 0.0
        total_len = 0.0
        train_acc = 0.0
        msg = ""

        for batch_idx, batch in enumerate(train_loader, start=1):
            # One transfer per tensor; the batch is
            # (input_ids, token_type_ids, attention_mask, labels).
            input_ids, token_type_ids, attention_mask, labels = (
                v.cuda() for v in batch
            )

            # Labels are still passed so outputs.loss is populated, but the
            # loss that is back-propagated is the one from `criterion`.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            logits = outputs.logits
            flat_labels = labels.view(-1)
            loss = criterion(logits, flat_labels)

            # argmax over logits equals argmax over softmax(logits); compare
            # against flattened labels so shapes can never broadcast.
            pred = torch.argmax(logits, dim=1)
            total_correct += pred.eq(flat_labels).sum().item()
            total_len += len(labels)

            loss.backward()
            optimizer.step()
            model.zero_grad()

            running_avg_loss = cal_running_avg_loss(loss.item(), running_avg_loss)
            train_acc = total_correct / total_len

            msg = "{}/{} {} - ETA : {} - Loss: {:.4f}, Acc: {:.4f}".format(
                batch_idx, batch_nb,
                progress_bar(batch_idx, batch_nb),
                eta(start, batch_idx, batch_nb),
                running_avg_loss, train_acc)
            print(msg, end="\r")

        # Evaluate on the held-out loader and checkpoint after every epoch.
        val_loss, val_acc = evaluate(msg)
        save_model(model, val_loss, epoch, model_dir='./save_model/bert_base/')

        print("Epoch {} took {} - Train Loss: {:.4f} - Val Loss: "
              "{:.4f} - Train Acc: {:.4f} - Val Acc: {:.4f}".format(
                  epoch,
                  user_friendly_time(time_since(start)),
                  running_avg_loss,
                  val_loss,
                  train_acc,
                  val_acc))
from sklearn.metrics import *

def evaluate(msg):
    """Measure loss and accuracy of the module-level ``model`` on ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    An evaluation progress line, prefixed with ``msg``, is printed to
    stdout for every batch.

    Args:
        msg: Progress text from the caller, prepended to the evaluation
            progress line.

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses and the
        fraction of predictions that match the labels.

    Notes:
        Batches are moved with ``.cuda()``, so a CUDA device is required.
    """
    val_batch_nb = len(test_loader)
    val_losses = []
    total_correct = 0.0
    total_len = 0.0
    model.eval()

    for i, batch in enumerate(test_loader, start=1):
        # Batch layout: (input_ids, token_type_ids, attention_mask, labels).
        input_ids, token_type_ids, attention_mask, labels = (
            v.cuda() for v in batch
        )

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            logits = outputs.logits
            # argmax over logits equals argmax over softmax(logits); compare
            # against flattened labels so shapes can never broadcast.
            pred = torch.argmax(logits, dim=1)
            correct = pred.eq(labels.view(-1))
            # .mean() reduces the model-computed loss to a scalar
            # (presumably one value per replica under DataParallel).
            loss = outputs.loss.mean()

        msg2 = "{} =>   Evaluating : {}/{}".format(msg, i, val_batch_nb)
        print(msg2, end="\r")
        val_losses.append(loss.item())
        total_correct += correct.sum().item()
        total_len += len(labels)

    val_acc = total_correct / total_len
    val_loss = np.mean(val_losses)
    return val_loss, val_acc

def save_model(model, loss, epoch, model_dir):
    """Write the model's ``state_dict`` to ``<model_dir>/<epoch>_<loss>.pt``.

    Args:
        model: Module to checkpoint. If it has a ``module`` attribute (as a
            wrapper such as ``nn.DataParallel`` does), the wrapped module's
            weights are saved instead.
        loss: Loss value embedded in the file name with four decimals.
        epoch: Epoch number embedded in the file name.
        model_dir: Target directory; created (with parents) if missing.

    The saved object is a dict with the single key ``"state_dict"``.
    """
    model_to_save = model.module if hasattr(model, "module") else model
    ckpt = {"state_dict": model_to_save.state_dict()}
    # torch.save does not create parent directories, so do it here.
    # An empty model_dir means "current directory" and needs no creation.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, "{}_{:.4f}.pt".format(epoch, loss))
    torch.save(ckpt, model_save_path)
    
    
    
# Script entry point: run the training loop for 5 epochs.
train(5)