import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad
from tensorboardX import SummaryWriter

from collections import OrderedDict

from svhn import SVHN
from vgg import vgg11
from utils import *

# Command-line configuration for the repeated SVHN / VGG11 training runs.
parser = argparse.ArgumentParser()
parser.add_argument('--total_reps', type=int, default=5,
                    help='number of independent repetitions; repetition k loads '
                         'initialized_weight_k.pth.tar from --model_weight_dir')
parser.add_argument('--maxiter', type=int, default=int(3e4 + 1),
                    help='number of SGD iterations per repetition')
parser.add_argument('--lr', type=float, default=0.05,
                    help='SGD learning rate')
# '--batch-size' is the historical spelling; '--batch_size' is accepted too so
# the flag matches the underscore convention used by every other option.
parser.add_argument('--batch-size', '--batch_size', dest='batch_size',
                    type=int, default=100,
                    help='training mini-batch size')
parser.add_argument('--big_batch_size', type=int, default=1000,
                    help='not used by the training loop in this script')
parser.add_argument('--datadir', type=str,
                    default='datasets/SVHN/train25000_test70000',
                    help='directory passed to SVHN(...)')
parser.add_argument('--logdir', type=str, default='logs/REPS_SB',
                    help='output directory for logs, histories and checkpoints')
parser.add_argument('--model_weight_dir', type=str,
                    default='logs/VGG11_initialization',
                    help='directory holding initialized_weight_<rep>.pth.tar files')
parser.add_argument('--seed', type=int, default=1,
                    help='seed for NumPy and PyTorch (CPU and CUDA) RNGs')
parser.add_argument('--gradient_clip', type=float, default=20,
                    help='parsed but not applied by the training loop in this script')
parser.add_argument('--eps', type=float, default=1e-4,
                    help='not used by the training loop in this script')
parser.add_argument('--iter_save_model', type=int, default=2000,
                    help='save a model/optimizer checkpoint every N iterations')
parser.add_argument('--iter_eval', type=int, default=100,
                    help='evaluate on the test set every N iterations')

args = parser.parse_args()
# Record the exact run configuration first, so every run is reproducible.
logger = LogSaver(args.logdir)
logger.save(str(args), 'args')

# Pin all work to the first CUDA device.
device = torch.device("cuda:0")
torch.cuda.set_device(device)

# NOTE(review): gradient_clip is read from the CLI but is not applied by the
# training loop visible in this file.
gradient_clip = args.gradient_clip

# Seed NumPy, the PyTorch CPU generator, and every CUDA generator.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)


# Dataset: the training batches are drawn on the fly inside the loop; the test
# split is pre-batched once (batches of 1000) onto the device.
dataset = SVHN(args.datadir)
logger.save(str(dataset), 'dataset')
test_list = dataset.getTestList(1000, device)

# TensorBoard writer (opened here, closed after all repetitions finish).
writer = SummaryWriter(args.logdir)

def _evaluate(model, criterion, batches):
    """Return the per-batch mean ``(accuracy, loss)`` of ``model`` over ``batches``.

    ``batches`` is iterated as ``(x, y)`` pairs (here, ``test_list``). Metrics
    are averaged over batches, so they equal per-sample means only when all
    batches have the same size.
    """
    model.eval()
    total_acc, total_loss = 0.0, 0.0
    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.no_grad():
        for x, y in batches:
            out = model(x)
            total_acc += accuracy(out, y).item()
            total_loss += criterion(out, y).item()
    return total_acc / len(batches), total_loss / len(batches)


def _rep_path(idx_rep, suffix):
    """Return ``<logdir>/REP_<idx_rep><suffix>`` (file names kept as before)."""
    return os.path.join(args.logdir, 'REP_' + str(idx_rep) + suffix)


for idx_rep in range(args.total_reps):
    logger.save('       ')
    logger.save('              REP '+str(idx_rep))
    logger.save('       ')
    # model
    start_iter = 0
    model = vgg11()
    # Load this repetition's initial weights. The model is still on the CPU
    # here, so map the checkpoint to the CPU regardless of the device it was
    # saved from; the model is moved to `device` right after.
    checkpoint = torch.load(
        os.path.join(args.model_weight_dir,
                     'initialized_weight_' + str(idx_rep) + '.pth.tar'),
        map_location='cpu')
    model.load_state_dict(checkpoint)
    del checkpoint
    model.to(device)
    logger.save(str(model), 'classifier')

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logger.save(str(optimizer), 'optimizer')

    # optimization
    torch.backends.cudnn.benchmark = True
    train_history = []  # rows of [iteration, train accuracy, train loss]
    test_history = []   # rows of [iteration, test accuracy, test loss]
    for i in range(start_iter, args.maxiter):
        model.train()
        optimizer.zero_grad()

        x, y = dataset.getTrainBatch(args.batch_size, device)
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        # NOTE(review): args.gradient_clip is parsed but never applied; if
        # clipping is intended, it belongs here (before optimizer.step()).
        optimizer.step()

        # Training metrics on the pre-update forward pass. Convert to Python
        # floats (as the evaluation path does) so the history holds no device
        # tensors and np.array() below yields a plain float array.
        acc_train = accuracy(out, y).item()
        loss_train = loss.item()
        train_history.append([i, acc_train, loss_train])

        # evaluate
        if i % args.iter_eval == 0:
            test_acc, test_loss = _evaluate(model, criterion, test_list)
            test_history.append([i, test_acc, test_loss])

            # np.save with a path opens and closes the file itself.
            np.save(_rep_path(idx_rep, 'train_history.npy'), np.array(train_history))
            np.save(_rep_path(idx_rep, 'test_history.npy'), np.array(test_history))

            logger.save('Iter:%d, Test [acc: %.2f, loss: %.4f], Train [acc: %.2f, loss: %.4f]'
                        % (i, test_acc, test_loss, acc_train, loss_train))

        if i % args.iter_save_model == 0:
            state = {'iter': i, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(state, _rep_path(idx_rep, 'iter-' + str(i) + '.pth.tar'))

writer.close()
