import sys, os
sys.path.append('./..')
from model import *
from dataloader import *
import argparse
import numpy as np
import torch.optim as optim
from time import time
def _str2bool(value):
    """Parse a command-line boolean, case-insensitively.

    Accepts 'true'/'false', 'yes'/'no' and '1'/'0'. Anything else raises
    argparse.ArgumentTypeError so that a typo such as ``--lowrank ture`` is
    reported to the user instead of silently becoming False.
    """
    text = str(value).strip().lower()
    if text in ('true', 'yes', '1'):
        return True
    if text in ('false', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

parser.add_argument('--datapath', type=str, default='./data', help='data path')
parser.add_argument('--dataset', type=str, default='svhn',
                    help='dataset name; this script only loads SVHN via get_svhn()')
parser.add_argument('--batch-size', type=int, default=100,
                    help='batch size passed to get_svhn()')
parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')

parser.add_argument('--method', type=str, default='ours')
parser.add_argument('--seed', type=int, default=1,
                    help='seed for Python, NumPy and PyTorch RNGs')
# metavar was 'N' (suggesting an integer); the learning rate is a float.
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')
parser.add_argument('--MC-train', type=int, default=1)
parser.add_argument('--MC-test', type=int, default=50,
                    help='number of stochastic forward passes averaged per test batch')
parser.add_argument('--droprate', type=float, default=0.7)
parser.add_argument('--klw', type=float, default=0.2, help='KL annealing')
parser.add_argument('--num-HH', type=int, default=2)

# Strict boolean parsing: previously any value other than 'true' meant False.
parser.add_argument('--lowrank', default=False, type=_str2bool)
parser.add_argument('--v0-option', type=int, default=0)

parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')

args = parser.parse_args()

import random

# Explicit import: `torch` was previously only in scope via `from model import *`.
import torch

args.cuda = not args.no_cuda and torch.cuda.is_available()

# get_svhn() below is the only data loader this script calls, so any other
# --dataset value would silently train/evaluate on SVHN under the wrong name
# (and 'mnist' would crash later when SVHN batches are viewed as 28*28).
if args.dataset != 'svhn':
    parser.error("--dataset %r is not supported: this script always loads SVHN "
                 "via get_svhn()" % (args.dataset,))

# --- Reproducibility --------------------------------------------------------
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here does
# not change hashing in this process; it only affects child processes that
# read the environment when they start.
os.environ['PYTHONHASHSEED'] = str(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)  # seeds every visible GPU; lazy/no-op without CUDA
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may select non-deterministic kernels

device = torch.device("cuda:0" if args.cuda else "cpu")

model = LeNetVSD(args).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# LR is multiplied by 0.3 at each milestone; scheduler.step() is called once
# per epoch in the training loop, so milestones are epoch numbers.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30, 40, 50, 60, 70, 80], gamma=0.3)
train_loader, test_loader = get_svhn(batch_size=args.batch_size, datapath=args.datapath)
learner = Learner(model, len(train_loader), len(train_loader.dataset))

def _train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Reads the module-level ``model``, ``optimizer``, ``learner``, ``device``
    and ``args`` created earlier in this script.

    Returns:
        int: number of training examples whose prediction (taken from the same
        forward pass used for the gradient step) matched the target.
    """
    model.train()  # nothing inside the loop switches the mode, so set it once
    correct = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        if args.dataset == 'mnist':
            data = data.view(-1, 28 * 28)  # flatten each sample to 784 features

        optimizer.zero_grad()
        output = model(data)
        pred = output.detach().max(1)[1]
        # learner returns (nll, loss, elbo, alpha); only `loss` drives the update.
        _nll, loss, elbo, _alpha = learner(output, target, args.klw)
        loss.backward()
        optimizer.step()

        correct += (pred == target).sum().item()
        del loss, elbo  # drop graph references before the next forward pass
    return correct


def _evaluate():
    """Evaluate on ``test_loader`` with Monte-Carlo averaging.

    Each batch is passed through the model ``args.MC_test`` times and the
    outputs are averaged before computing predictions and the learner's NLL.
    NOTE(review): this presumes the model stays stochastic in eval mode
    (otherwise all samples are identical) — confirm for LeNetVSD.

    Returns:
        tuple: (number of correct predictions, mean of the per-batch NLL
        values returned by ``learner``).
    """
    model.eval()
    correct = 0
    batch_nlls = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if args.dataset == 'mnist':
                data = data.view(-1, 28 * 28)
            # Shape (MC_test, batch, num_outputs), built directly on `device`;
            # stacking avoids hard-coding the number of classes.
            samples = torch.stack([model(data) for _ in range(args.MC_test)])
            output = samples.mean(0)
            pred = output.max(1)[1]
            nll_test, _loss, _elbo, _alpha = learner(output, target, args.klw)
            batch_nlls.append(nll_test.item())
            correct += (pred == target).sum().item()
    # NOTE(review): unweighted mean over batches; if the learner's NLL is a
    # per-batch mean, a smaller final batch is over-weighted.
    return correct, np.mean(batch_nlls)


for epoch in range(1, args.epoch + 1):
    t1 = time()
    train_correct = _train_epoch()
    print('Time/epoch: ', time() - t1)
    scheduler.step()

    train_accuracy = train_correct / len(train_loader.dataset) * 100
    print(f'epoch {epoch} train_acc {train_accuracy}\n')

    # Return unused cached CUDA memory once per epoch instead of on every batch
    # (per-batch calls force the allocator to re-acquire memory each step).
    torch.cuda.empty_cache()
    test_correct, nll_final = _evaluate()
    test_accuracy = test_correct / len(test_loader.dataset) * 100
    print(f'Epoch {epoch} test_acc {test_accuracy} test_nll {nll_final}\n')


