import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from tqdm import tqdm
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
import re, os, shutil
import datetime
from torchtext.datasets import AG_NEWS

# Raw AG_NEWS splits, created once at import time. AGNewsDataset consumes every
# element as a (label, text) pair and stores the label shifted down by one.
# NOTE(review): this runs on every import of the module — including, presumably,
# in DataLoader worker processes started with the 'spawn' method. In some
# torchtext releases these objects are single-pass iterators, so each split can
# then only be read once per process — confirm for the pinned torchtext version.
train_iter = AG_NEWS(split='train')
test_iter = AG_NEWS(split='test')

from model import RNN, GRU, LSTM


# Width of each GloVe word vector; main() also passes it to the model as glove_dim.
GLOVE_DIM = 100
# Pre-trained GloVe '6B' vectors cached under ./.vector_cache (presumably
# downloaded there by torchtext on first use). Loaded once at import time and
# used by AGNewsDataset.__getitem__ to embed tokens.
GLOVE = GloVe(name='6B', dim=GLOVE_DIM, cache='./.vector_cache')


class AGNewsDataset(Dataset):
    """One AG_NEWS split as a map-style dataset of ``(embeddings, label)`` pairs.

    Raw texts and labels are copied into Python lists once, at construction
    time, from the module-level ``train_iter`` / ``test_iter`` iterables.
    Tokenisation and GloVe lookup happen lazily, per item, in ``__getitem__``.
    """

    def __init__(self, is_train=True):
        super().__init__()
        self.tokenizer = get_tokenizer('basic_english')
        self.x, self.y = self.load_to_local(is_train)

    def load_to_local(self, is_train=True):
        """Read the chosen split into parallel ``(texts, labels)`` lists.

        Args:
            is_train: read ``train_iter`` when true, ``test_iter`` otherwise.

        Returns:
            ``(texts, labels)`` where ``labels[i]`` is the raw label of item
            ``i`` minus one (presumably mapping 1-based AG_NEWS classes to
            0-based indices — confirm against the torchtext version).
        """
        source = train_iter if is_train else test_iter
        texts, labels = [], []
        # NOTE(review): with torchtext releases where the split is a
        # single-pass iterator, building the same split twice yields an
        # empty dataset the second time.
        for label, text in source:
            texts.append(text)
            labels.append(label - 1)
        return texts, labels

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(embeddings, label)`` for item *index*.

        ``embeddings`` is ``GLOVE.get_vecs_by_tokens`` applied to the
        ``basic_english`` tokens of the text, i.e. one GloVe vector per token
        (presumably shape ``(num_tokens, GLOVE_DIM)``).
        """
        tokens = self.tokenizer(self.x[index])
        embeddings = GLOVE.get_vecs_by_tokens(tokens)
        return embeddings, self.y[index]


def _collate_fn(batch):
    """Pad a list of ``(embeddings, label)`` samples into one batch.

    Defined at module level rather than nested inside ``get_dataloader`` so
    it can be pickled when DataLoader workers are started with the
    ``spawn`` method (a local closure cannot be pickled).

    Args:
        batch: sequence of ``(tensor, label)`` pairs; tensors are zero-padded
            along their first dimension to the longest one in the batch.

    Returns:
        ``(x_pad, y)``: ``x_pad`` stacks the padded tensors along a new
        leading batch dimension; ``y`` is a 1-D tensor of the labels in the
        default floating-point dtype (callers cast it with ``.long()``).
    """
    x, y = zip(*batch)
    x_pad = pad_sequence(x, batch_first=True)
    # Same dtype the legacy torch.Tensor(...) constructor produced.
    y = torch.tensor(y, dtype=torch.get_default_dtype())
    return x_pad, y


def get_dataloader(batch_size, num_workers=4):
    """Build the AG_NEWS train and test DataLoaders.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes for the training loader (the test
            loader loads in the main process, as before).

    Returns:
        ``(train_dataloader, test_dataloader)``. Only the training loader
        shuffles; evaluation sums metrics over the whole split, so order
        there is irrelevant.
    """
    train_dataloader = DataLoader(AGNewsDataset(True), batch_size=batch_size, shuffle=True,
                                  collate_fn=_collate_fn, num_workers=num_workers)
    test_dataloader = DataLoader(AGNewsDataset(False), batch_size=batch_size, shuffle=False,
                                 collate_fn=_collate_fn)
    return train_dataloader, test_dataloader


def save_model(state_dict, is_best, log_dir):
    """Checkpoint *state_dict* as the latest model, and also as the best when *is_best*.

    Writes ``log_dir + '/latest.pth'`` and, if *is_best* is truthy,
    ``log_dir + '/best.pth'``; existing files are overwritten.
    """
    suffixes = ['/latest.pth'] + (['/best.pth'] if is_best else [])
    for suffix in suffixes:
        torch.save(state_dict, log_dir + suffix)


def save_pys(log_dir, src_dir='./'):
    """Copy the ``*.py`` files of *src_dir* into *log_dir*.

    Keeps a snapshot of the code that produced a run next to its logs.

    Args:
        log_dir: destination directory; created (with parents) if missing.
        src_dir: directory scanned non-recursively; defaults to ``'./'``,
            the current working directory, as before.
    """
    os.makedirs(log_dir, exist_ok=True)
    for filename in os.listdir(src_dir):
        src_path = os.path.join(src_dir, filename)
        # Skip directories whose names happen to end in ".py";
        # shutil.copy cannot copy a directory.
        if filename.endswith(".py") and os.path.isfile(src_path):
            shutil.copy(src_path, os.path.join(log_dir, filename))


def main():
    """Train the LSTM classifier on AG_NEWS, evaluating and checkpointing every epoch.

    Side effects: creates ``./lr_logs/<MMDDhhmmss>_lr`` containing
    TensorBoard events, copies of the working directory's ``*.py`` files, and
    ``latest.pth`` / ``best.pth`` / ``optimizer.pth`` / ``scheduler.pth``
    checkpoints. Prints one summary line per epoch.
    """
    # NOTE(review): autograd is disabled for the whole run, including the
    # training loop, so parameter updates rely on ``model.backward``
    # populating gradients itself (presumably a hand-written backward pass in
    # the ``model`` module) — confirm before removing this context manager.
    with torch.no_grad():
        # Fall back to CPU so the script still starts on machines without CUDA.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Month/day/hour/minute/second, e.g. '0314152653'. Replaces the old
        # str(now())[4:-7] slice, which mis-cut the string whenever
        # microsecond == 0 (str() then omits the '.ffffff' suffix).
        current_time = datetime.datetime.now().strftime('%m%d%H%M%S')
        log_dir = './lr_logs/' + current_time + '_lr'
        writer = SummaryWriter(log_dir=log_dir)
        try:
            save_pys(log_dir)

            # NOTE(review): each training batch is tiled repeat_n times along
            # dim 0 (effective batch 10 * 200); confirm this is intended.
            repeat_n = 200
            epochs = 100
            train_dataloader, test_dataloader = get_dataloader(batch_size=10)
            model = LSTM(glove_dim=GLOVE_DIM, hidden_units=64, num_classes=4).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            # NOTE(review): scheduler.step() is never called, so the learning
            # rate stays at 1e-3 for the whole run; step_size=5000 suggests a
            # per-iteration step was intended — confirm.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(5e3), gamma=0.9)

            # reduction="none": per-sample losses, handed as-is to model.backward.
            citerion = torch.nn.CrossEntropyLoss(reduction="none")
            best_accuracy = -1.
            for epoch in range(epochs):
                train_loss, train_accuracy, train_count = 0., 0., 0

                with tqdm(train_dataloader) as tqdm_range:
                    for x, y in tqdm_range:
                        x = x.to(device).repeat(repeat_n, 1, 1)
                        y = y.long().to(device).repeat(repeat_n)
                        hat_y = model(x, True)
                        hat_y = hat_y.squeeze(-1)
                        loss = citerion(hat_y, y)

                        predictions = torch.argmax(hat_y, dim=1)
                        optimizer.zero_grad()
                        model.backward(loss)

                        optimizer.step()

                        # .item() accumulates in Python floats/ints rather than
                        # NumPy scalars (which may stay float32 under NumPy 2).
                        train_accuracy += (predictions == y).sum().item()
                        train_loss += loss.sum().item()
                        train_count += y.shape[0]
                        tqdm_range.set_description(f"acc {100*train_accuracy/train_count:.2f}%")

                train_loss /= train_count
                train_accuracy /= train_count

                valid_loss, valid_accuracy, valid_count = 0., 0., 0
                model.eval()
                for x, y in tqdm(test_dataloader):
                    x = x.to(device)
                    y = y.long().to(device)
                    with torch.no_grad():
                        hat_y = model(x)
                    # NOTE(review): training squeezes dim -1 but evaluation
                    # squeezes dim 1; both agree only for particular output
                    # shapes — confirm against the model's output.
                    hat_y.squeeze_(1)

                    loss = citerion(hat_y, y)
                    predictions = torch.argmax(hat_y, dim=1)

                    valid_accuracy += (predictions == y).sum().item()
                    valid_loss += loss.sum().item()
                    valid_count += y.shape[0]

                valid_loss /= valid_count
                valid_accuracy /= valid_count
                model.train()

                # Ties count as an improvement, so best.pth tracks the latest best.
                is_best = valid_accuracy >= best_accuracy
                print(f'Train Epoch:{epoch:3d} || train loss:{train_loss:.2e} train accuracy:{train_accuracy * 100:.2f}% ' +
                      f'valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy * 100:.2f}% lr:{scheduler.get_last_lr()[0]:.2e}')
                save_model(model.state_dict(), is_best, log_dir)
                torch.save(optimizer.state_dict(), log_dir + '/optimizer.pth')
                torch.save(scheduler.state_dict(), log_dir + '/scheduler.pth')
                if is_best:
                    best_accuracy = valid_accuracy
                writer.add_scalar('loss/train_loss', train_loss, epoch)
                writer.add_scalar('loss/valid_loss', valid_loss, epoch)
                writer.add_scalar('accuracy/train_accuracy', train_accuracy, epoch)
                writer.add_scalar('accuracy/valid_accuracy', valid_accuracy, epoch)
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            writer.close()


if __name__ == '__main__':
    # Fix the NumPy and torch RNG seeds before any data loading or training.
    np.random.seed(0)
    torch.manual_seed(0)
    main()