import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

from collections import OrderedDict
from svhn import SVHN
from vgg import vgg11
from utils import *

import pickle

# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# How many repetitions (REP_0, REP_1, ...) the evaluation loop walks over.
parser.add_argument(
    '--total_reps',
    type=int,
    default=5,
)
# Forwarded to deltaLossAcc as its repeat argument
# (presumably the number of noise draws per model -- confirm in utils).
parser.add_argument(
    '--repeat',
    type=int,
    default=100,
)
# Forwarded to deltaLossAcc; argparse exposes this as args.noise_std.
parser.add_argument(
    '--noise-std',
    type=float,
    default=1e-2,
)
# Base path shared by the log saver, the TensorBoard writer, the
# checkpoints that are read and the sharpness.npy file that is written.
parser.add_argument(
    '--resume',
    type=str,
    default='logs/REPS_SB/',
)
# Path handed to the SVHN dataset wrapper.
parser.add_argument(
    '--datadir',
    type=str,
    default='datasets/SVHN/train25000_test70000',
)
# Training iteration of the evaluated models; used as the TensorBoard
# global step and printed in the per-repetition log line.
parser.add_argument(
    '--source_iter',
    type=int,
    default=30000,
)

args = parser.parse_args()

# Record the exact configuration alongside the other run logs.
logger = LogSaver(args.resume)
logger.save(str(args), 'args')

# Run on the first CUDA device and make it the current device for torch.cuda.
device = torch.device("cuda", 0)
torch.cuda.set_device(device)

# data: only the training split is loaded; test-set evaluation is disabled
# in this script.
dataset = SVHN(args.datadir)
logger.save(str(dataset), 'dataset')
# NOTE(review): 5000 is presumably the per-chunk/batch size used by
# getTrainList -- confirm against svhn.py.
train_list = dataset.getTrainList(5000, device)

# writer: TensorBoard events are written under the same path that was
# handed to LogSaver.
writer = SummaryWriter(args.resume)

# The model and the loss are constructed inside the evaluation loop below.

# Evaluate the flatness/sharpness of every repetition's checkpoint.
#
# For each repetition a fresh VGG-11 is built, its weights are restored from
# '<resume>/REP_<idx>iter-<source_iter>.pth.tar', and deltaLossAcc is called
# on the training data with the configured noise std / repeat count.  The
# resulting [deltaLoss, deltaAcc] pairs are accumulated and re-saved to
# '<resume>/sharpness.npy' after every repetition so partial results survive
# an interrupted run.

# One row [dlossT, daccT] per evaluated repetition.
list_sharpness = []

# The loss module is identical for every repetition, so build it once.
criterion = nn.CrossEntropyLoss().to(device)

# os.path.join keeps the paths valid whether or not --resume ends in '/'.
sharpness_path = os.path.join(args.resume, 'sharpness.npy')

for idx_rep in range(args.total_reps):
    logger.save('       ')
    logger.save('              REP '+str(idx_rep))
    logger.save('       ')

    # model: a fresh network per repetition
    model = vgg11()
    model.to(device)

    # load checkpoint
    # The iteration in the file name follows --source_iter (default 30000)
    # so it always matches the step reported to TensorBoard and the log.
    ckpt_path = os.path.join(
        args.resume,
        'REP_%diter-%d.pth.tar' % (idx_rep, args.source_iter),
    )
    # map_location='cpu' makes loading independent of the GPU the checkpoint
    # was saved from and keeps any non-model entries off the GPU.
    checkpoint = torch.load(ckpt_path, map_location='cpu')['model']
    # Move only the model weights onto the evaluation device.
    state_dict = OrderedDict(
        (key, tensor.to(device)) for key, tensor in checkpoint.items()
    )
    del checkpoint

    # refresh model weights
    model.load_state_dict(state_dict)

    # evaluate flatness
    # state_dict is passed alongside the model (presumably as the clean
    # reference weights to perturb/restore -- confirm in utils.deltaLossAcc);
    # None is given where the test data would go.
    dlossT, daccT = deltaLossAcc(
        train_list, None, model, criterion, state_dict,
        args.noise_std, args.repeat, device,
    )

    # NOTE(review): the 'delatAcc/train' tag is misspelled, but it is kept
    # byte-for-byte so that series already logged under this name are not
    # split.  Every repetition is logged at the same global step.
    writer.add_scalar('delatAcc/train', daccT, args.source_iter)
    writer.add_scalar('deltaLoss/train', dlossT, args.source_iter)
    logger.save('Model:%d, Train [dacc: %.2f, dloss: %.6f]' \
                % (args.source_iter, daccT, dlossT))

    # store the results in numpy array
    list_sharpness.append([dlossT, daccT])
    # Rewrite the file each repetition; the context manager closes the
    # handle instead of leaking one per iteration.
    with open(sharpness_path, 'wb') as sharpness_file:
        np.save(sharpness_file, np.array(list_sharpness))

writer.close()
