import torch
from torch.utils.data import Dataset
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch.backends.cudnn as cudnn

import os

import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import DataLoader
from utils import *

from model import LeNet

#==========
# arguments
#==========

# Number of checkpoints to evaluate: REP_0 ... REP_{total_reps-1}.
total_reps = 5
# Prefer the first GPU, but fall back to CPU so the script still runs on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Mini-batch size of the training-set loader used for the sharpness estimate.
batch_size_train = 1000

# Forwarded to loop_evaluate_expected_sharpness (defined in utils).
# Presumably the std of the weight-perturbation noise and the number of
# perturbation draws per model — confirm against utils.
noise_std = 0.01
repeat = 1000

# Path prefix of the corrupted training set: <prefix>data.npy and <prefix>label.npy.
load_path_dataset = 'data/FashionMNIST_corrupted_dataset_'
# Directory holding the REP_<i>model_iter10000.pth checkpoints; sharpness.npy
# with the results is also written here.
load_path_model_weights = 'checkpoints/REPS_LB/'

#================
# dataset
#================

# Normalize computes (x - mean) / std per channel; with 0.5 / 0.5 the [0, 1]
# output of ToTensor is mapped to [-1, 1]. FashionMNIST has a single channel.
stats = dict(mean=[0.5], std=[0.5])


def _cast_to_default_dtype(tensor):
    """Return *tensor* converted to torch's current default floating dtype."""
    return tensor.type(torch.get_default_dtype())


# Transform pipeline applied to every FashionMNIST image (train and test).
trans = [transforms.ToTensor(), _cast_to_default_dtype, transforms.Normalize(**stats)]


def _fashion_mnist_split(is_train):
    """Build the FashionMNIST train (True) or test (False) split under ./data.

    Downloads the dataset if missing and applies the ``trans`` pipeline.
    """
    return datasets.FashionMNIST(
        root="data",
        train=is_train,
        download=True,
        transform=transforms.Compose(trans),
    )


training_data = _fashion_mnist_split(True)
test_data = _fashion_mnist_split(False)

#===========================================================
# modify training data: load the ones generated by GD script
#===========================================================
# Passing a path (not an open file object) lets NumPy open and close the
# file itself, so no file handles are leaked.
np_data = np.load(load_path_dataset + 'data.npy')
np_label = np.load(load_path_dataset + 'label.npy')

# Images and labels are paired by position, so a length mismatch would only
# surface later as an IndexError (or silently truncated data) mid-epoch.
if len(np_data) != len(np_label):
    raise ValueError(
        f'corrupted dataset mismatch: {len(np_data)} images vs {len(np_label)} labels'
    )

# Replace the downloaded FashionMNIST training arrays with the loaded ones.
# NOTE(review): FashionMNIST.__getitem__ presumably turns each sample into a
# PIL grayscale image, so np_data is assumed to be uint8 of shape (N, 28, 28)
# and np_label an integer array of shape (N,) — confirm against the GD script.
training_data.data = torch.from_numpy(np_data)
training_data.targets = torch.from_numpy(np_label)

#============
# data loader
#============
# Fixed mini-batch size for the (unshuffled) test-set evaluation.
_test_batch_size = 100

# Shuffled loader over the corrupted training set, used for the sharpness estimate.
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size_train,
    shuffle=True,
)
# Deterministic-order loader over the test split, used for the sanity check.
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=_test_batch_size,
    shuffle=False,
)

#===================
# evaluate sharpness
#===================


def _load_checkpoint_state_dict(path, map_device):
    """Load the state dict saved at *path* with every entry moved to *map_device*.

    The checkpoint is deserialized onto the CPU first, so it loads even if the
    device it was saved from is not present on this machine. Entries are then
    moved to *map_device*, matching the original ``.cpu().to(device)`` result.
    """
    checkpoint = torch.load(path, map_location='cpu')
    return OrderedDict((key, value.to(map_device)) for key, value in checkpoint.items())


list_sharpness = []
cudnn.benchmark = True
# CrossEntropyLoss holds no per-model state, so one instance serves every repetition.
loss_fn = nn.CrossEntropyLoss().to(device)
for idx_rep in range(total_reps):
    net = LeNet().to(device)
    # load checkpoint and refresh model weights
    state_dict = _load_checkpoint_state_dict(
        load_path_model_weights + 'REP_' + str(idx_rep) + 'model_iter10000.pth',
        device,
    )
    net.load_state_dict(state_dict)

    # sanity check: performance of the loaded weights on the test set
    check_loaded_weight_loss, check_loaded_weight_acc = eval_loss_and_acc_on_valid_set(
        net, test_dataloader, loss_fn, device=device
    )
    # `!s` keeps the exact str() rendering used before (also for tensor results).
    print(
        f'PERFORMANCE OF PRETRAINED MODEL: Loss {check_loaded_weight_loss!s}, '
        f'accuracy {check_loaded_weight_acc!s}'
    )

    # evaluate sharpness; state_dict is forwarded, presumably so the helper can
    # restore the unperturbed weights between noise draws — confirm in utils.
    dLoss, dAcc = loop_evaluate_expected_sharpness(
        net, state_dict, noise_std, repeat, train_dataloader, loss_fn, device
    )

    # Rewrite the results after every repetition so partial results survive an
    # interrupted run; the `with` block guarantees the file is flushed and closed.
    list_sharpness.append([dLoss, dAcc])
    with open(load_path_model_weights + 'sharpness.npy', 'wb') as f_out:
        np.save(f_out, np.array(list_sharpness))
