import torch
import re, os
import datetime
import numpy as np
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from model import ResNet5, VGG8
from dataloader import get_train_valid_loader
import shutil
import random


def save_model(state_dict, is_best, log_dir):
    """Serialize a checkpoint into ``log_dir``.

    The checkpoint is always written to ``latest.pth``; when ``is_best`` is
    truthy it is additionally written to ``best.pth``.

    Args:
        state_dict: Object handed to ``torch.save`` (the caller in this file
            passes ``net.state_dict()``).
        is_best: Whether this checkpoint should also overwrite ``best.pth``.
        log_dir: Directory path as a string; file names are appended to it.
    """
    # 'latest' is written first so it exists even if the 'best' write fails.
    suffixes = ['/latest.pth']
    if is_best:
        suffixes.append('/best.pth')
    for suffix in suffixes:
        torch.save(state_dict, log_dir + suffix)


def save_py(log_dir):
    """Snapshot the ``.py`` files of the current working directory.

    Every regular file directly inside ``./`` whose name ends with ``.py`` is
    copied into ``log_dir`` under the same name. Sub-directories are not
    traversed.

    Args:
        log_dir: Destination directory. It is created (including parents)
            when it does not exist yet, so the function no longer depends on
            another component having created it beforehand.
    """
    os.makedirs(log_dir, exist_ok=True)
    for filename in os.listdir('./'):
        src_path = os.path.join('./', filename)
        # Skip directories (or other non-files) that merely carry a ".py"
        # suffix; shutil.copy would raise on them.
        if filename.endswith(".py") and os.path.isfile(src_path):
            dst_path = os.path.join(log_dir, filename)
            shutil.copy(src_path, dst_path)


if __name__ == '__main__':
    # Training entry point: trains a VGG8 network on the loaders returned by
    # get_train_valid_loader, validates after every epoch, and writes
    # checkpoints plus TensorBoard scalars into a per-run log directory.

    # True  -> "LR" mode: gradients are produced by the model's own
    #          net.backward(loss) under torch.no_grad().
    # False -> "BP" mode: standard autograd back-propagation.
    is_lr = True    # use LR or BP

    # Seed every RNG used by the script for reproducible runs.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    device = torch.device('cuda:2')

    net = VGG8().to(device)
    # Month/day/hour/minute/second as ten digits (e.g. '0102030405').
    # strftime is used instead of slicing str(datetime.now()), because the
    # string form drops the fractional part when microsecond == 0 and the
    # fixed slice then truncates the timestamp.
    current_time = datetime.datetime.now().strftime('%m%d%H%M%S')
    mode_tag = 'LR' if is_lr else 'BP'
    log_dir = './logs/' + type(net).__name__ + '/' + mode_tag + '_' + current_time
    writer = SummaryWriter(log_dir=log_dir)
    # Keep a copy of the source files next to the logs of this run.
    save_py(log_dir)

    train_loader, valid_loader = get_train_valid_loader(data_dir='./data/cifar10', batch_size=100)

    # Per-layer replication factor of the batch in LR mode: entry i is how
    # many times each batch is repeated when layer i is the perturbed one.
    repeat_n = [100, 200, 400, 800, 400, 200, 100, 50]
    epochs = 100

    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    # Per-sample losses are needed: LR mode hands them to net.backward and
    # the statistics below sum them explicitly.
    criterion = nn.CrossEntropyLoss(reduction='none')

    best_accuracy = -1.
    layer_n = len(net.module_w_para)
    if is_lr and len(repeat_n) < layer_n:
        # Fail before training starts instead of with an IndexError mid-epoch.
        raise ValueError(
            f'repeat_n has {len(repeat_n)} entries but the network has {layer_n} layers')

    for epoch in range(epochs):
        net.train()
        train_loss, train_accuracy, train_count = 0., 0., 0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            if is_lr:
                with torch.no_grad():
                    # this part can be done on multiple GPUs in parallel
                    for i in range(layer_n):
                        inputs_ = inputs.repeat(repeat_n[i], 1, 1, 1)
                        labels_ = labels.repeat(repeat_n[i])
                        # Enable the noise flag for layer i only.
                        add_noise = [j == i for j in range(layer_n)]
                        outputs = net(inputs_, add_noise)
                        loss = criterion(outputs, labels_)
                        # Model-defined backward pass; presumably fills the
                        # parameter gradients consumed by optimizer.step()
                        # -- TODO confirm against the model implementation.
                        net.backward(loss)
                        train_accuracy += (torch.argmax(outputs, 1) == labels_).sum().item()
                        train_loss += torch.sum(loss).item()
                        train_count += len(labels_)
            else:
                # BP mode: every noise flag is off. The original code
                # referenced `add_noise` here although it is only defined in
                # the LR branch, which raised NameError.
                outputs = net(inputs, [False] * layer_n)
                loss = criterion(outputs, labels)
                loss_ = torch.mean(loss)
                loss_.backward()
                train_accuracy += (torch.argmax(outputs, 1) == labels).sum().item()
                train_loss += torch.sum(loss).item()
                train_count += len(labels)
            optimizer.step()
        scheduler.step()
        train_loss /= train_count
        train_accuracy /= train_count

        valid_loss, valid_accuracy, valid_count = 0., 0., 0
        net.eval()
        # No gradients are needed for validation; skipping the autograd
        # graph saves memory and time without changing the outputs.
        with torch.no_grad():
            for inputs, labels in tqdm(valid_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                valid_accuracy += (torch.argmax(outputs, dim=1) == labels).sum().item()
                valid_loss += torch.sum(loss).item()
                valid_count += inputs.shape[0]
        valid_loss /= valid_count
        valid_accuracy /= valid_count

        print(f'Train Epoch:{epoch:3d} || train loss:{train_loss:.2e} train accuracy:{train_accuracy*100:.2f}% ' +
              f'valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% lr:{scheduler.get_last_lr()[0]:.2e}')

        # Ties count as "best" so the most recent equally good model wins.
        is_best = valid_accuracy >= best_accuracy
        save_model(net.state_dict(), is_best, log_dir)
        torch.save(optimizer.state_dict(), log_dir + '/optimizer.pth')
        torch.save(scheduler.state_dict(), log_dir + '/scheduler.pth')
        if is_best:
            best_accuracy = valid_accuracy

        writer.add_scalar('loss/train_loss', train_loss, epoch)
        writer.add_scalar('loss/valid_loss', valid_loss, epoch)
        writer.add_scalar('accuracy/train_accuracy', train_accuracy, epoch)
        writer.add_scalar('accuracy/valid_accuracy', valid_accuracy, epoch)

    # Flush pending events and release the event file.
    writer.close()
    print('Finished Training')