import sys
import argparse
import os
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model,model_dataset_from_store
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
# Command-line interface. Defaults are rendered through argparse's
# ``%(default)s`` placeholder so the help text cannot drift from the real
# default value (the --workers help previously claimed 4 while the default is 10).
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Positional dataset name, restricted to the names listed in DATASETS.
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: %(default)s)')
parser.add_argument('--workers', default=10, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
# Parsed at import time; the script reads the result through the global `args`.
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return an ``ImageFolder`` dataset rooted at ``_dir``.

    Every image is resized to 256, center-cropped to 224 and converted to a
    tensor. No normalization step is part of this pipeline.
    """
    preprocessing_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    preprocessing = transforms.Compose(preprocessing_steps)
    return datasets.ImageFolder(_dir, transform=preprocessing)

def updateModel(model, loader, noise_sd, momentum):
    """Run one pass of noisy forward passes through ``model`` in train mode.

    Each batch from ``loader`` is perturbed with Gaussian noise scaled by
    ``noise_sd`` and fed through the model while it is in training mode; the
    outputs are discarded. No gradients are recorded because nothing is
    back-propagated here. Layers that update internal state during a
    training-mode forward pass (e.g. BatchNorm running statistics) still do
    so under ``torch.no_grad()``.

    Args:
        model: ``torch.nn.Module`` to run. Inputs are moved to the device of
            its first parameter; a model without parameters falls back to
            CUDA, matching the previous hard-coded behaviour.
        loader: Iterable yielding ``(inputs, targets)`` pairs. Targets are
            ignored.
        noise_sd: Multiplier applied to standard-normal noise added to inputs.
        momentum: Not used by this function; kept so existing callers keep
            working.

    Returns:
        The same ``model`` object, switched back to eval mode (also when a
        batch raises, and when ``loader`` is empty).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    model.train()
    try:
        with torch.no_grad():
            for inputs, _ in loader:
                inputs = inputs.to(device)

                # augment inputs with noise (randn_like inherits device/dtype)
                inputs = inputs + torch.randn_like(inputs) * noise_sd
                model(inputs)
    finally:
        # Never leave the model in train mode, even if a batch fails.
        model.eval()
    return model

#%%
def one_hot(a, num_class=10):
    """Return a ``(a.size, num_class)`` float array with a 1 at ``[i, a[i]]``.

    ``a`` is a NumPy integer array of class indices; all other entries of
    the result are 0.
    """
    count = a.size
    encoded = np.zeros((count, num_class))
    row_index = np.arange(count)
    encoded[row_index, a] = 1
    return encoded

#%%
    
def test(loader, adv_path, bn, noise_sd, momentum, nb_iter, save_path):
    """Count per-class predictions of a restored model under Gaussian noise.

    Restores a ResNet-50 checkpoint, optionally runs ``updateModel`` over
    ``loader``, then classifies every image of ``test_dataset`` ``N`` times
    with fresh input noise and stores how often each of the 1000 classes was
    predicted for each image.

    Reads module-level state: ``ds``, ``img_dir``, ``test_dataset``, ``args``,
    ``N``, ``nb_data`` and ``print_freq``.

    Args:
        loader: Batches of ``(inputs, targets)`` used only for the optional
            ``updateModel`` passes; evaluation uses a separate loader built
            from ``test_dataset``.
        adv_path: Checkpoint path, passed on as ``resume_path``.
        bn: When truthy, call ``updateModel`` ``nb_iter`` times before
            evaluating.
        noise_sd: Multiplier applied to standard-normal noise added to inputs.
        momentum: Forwarded to ``make_and_restore_model`` and ``updateModel``.
        nb_iter: Number of ``updateModel`` passes when ``bn`` is set.
        save_path: Target of ``np.savez``; the archive holds ``y_true``
            (labels, one per image) and ``y_pred`` (``(nb_data, 1000)``
            prediction counts).

    Raises:
        ValueError: If ``test_dataset`` does not contain exactly ``nb_data``
            images. Previously such a mismatch either failed later on a shape
            error or silently summed different images into the same rows.
    """
    if len(test_dataset) != nb_data:
        raise ValueError(
            'test_dataset holds %d samples but nb_data is %d'
            % (len(test_dataset), nb_data))

    # load model (default model loading)
    model, _ = make_and_restore_model(arch='resnet50', dataset=ds, resume_path=adv_path,
                                      parallel=True, pytorch_pretrained=False,
                                      add_custom_forward=False, momentum=momentum)
    model.eval()

    # print model information
    print(img_dir)
    print('Model: ' + adv_path + '   BN: ' + str(bn) +
          '   Noise: ' + str(noise_sd) + '   Momentum: ' + str(momentum)
          )
    print(save_path)

    # Optionally run noisy train-mode passes before evaluating
    # (presumably to recalibrate BatchNorm statistics; see updateModel).
    if bn:
        for _ in range(nb_iter):
            model = updateModel(model, loader, noise_sd, momentum)

    # Evaluation always uses its own fixed-order loader over test_dataset.
    eval_loader = DataLoader(test_dataset, shuffle=False, batch_size=500,
                             num_workers=args.workers, pin_memory=True)

    y_pred = np.zeros(shape=(nb_data, 1000))
    y_true = np.zeros(nb_data, dtype=np.int64)

    offset = 0
    # Inference only: skip autograd bookkeeping for the N forward passes.
    with torch.no_grad():
        for inputs, targets in eval_loader:
            # Each batch owns its own rows, so several batches never mix.
            rows = slice(offset, offset + targets.shape[0])
            offset = rows.stop
            y_true[rows] = targets.cpu().numpy()
            inputs = inputs.cuda()

            for n in range(N):
                checkpoint = n % print_freq == 0
                if checkpoint:
                    print('Iter: ' + str(n))

                X = inputs + torch.randn_like(inputs) * noise_sd
                # The model returns a pair; only the first element is used.
                outputs, _ = model(X)
                votes = outputs.argmax(1).cpu().numpy()
                y_pred[rows] += one_hot(votes, 1000)

                # Periodically persist partial counts so a long run can be
                # inspected or recovered.
                if checkpoint:
                    np.savez(save_path, y_true=y_true, y_pred=y_pred)
                    print('saving...  ' + str(n))
                    sys.stdout.flush()

    np.savez(save_path, y_true=y_true, y_pred=y_pred)

#%%

# Number of noisy forward passes test() runs for every evaluation batch.
N = 100000
# Number of rows of the prediction-count array allocated in test().
nb_data= 500
# test() logs progress and writes a checkpoint every `print_freq` passes.
print_freq = 5000

#%%   ##Getting the model
# Export the ImageNet root through the environment (presumably consumed by the
# imported project dataset helpers - TODO confirm who reads IMAGENET_LOC_ENV).
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'
# Checkpoint that test() restores via make_and_restore_model(resume_path=...).
adv_path = './model/imagenet_linf_eps4.pt'
# NOTE(review): ImageNet is constructed with the checkpoint path rather than an
# image directory - confirm this is intended. `ds` is only passed on as
# `dataset=` inside test().
ds = ImageNet(adv_path)
# Image folder that is evaluated; loaded with the preprocessing of _imagenet().
img_dir = '/imagenet/dataset/val_certify/'
test_dataset =_imagenet(img_dir)


# Loader handed to test(), where it only feeds the optional updateModel passes.
# Batch size and worker count are hard-coded here, so --batch / --workers do
# not affect this loader.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=500, num_workers= 15, pin_memory= True)

#s_path = 'linf_overfit_False_mom_5_noise_005.npz'
#test(test_loader, adv_path, bn= False, noise_sd = 0.05, momentum= 0.0, nb_iter=0, save_path=s_path)
#
# Active run: checkpoint at adv_path, one updateModel pass, noise_sd = 0.75.
# NOTE(review): the file name says "mom_5" while momentum=1.0 is passed -
# confirm the naming before comparing result files.
s_path = 'linf_overfit_True_mom_5_noise_075.npz'
test(test_loader, adv_path, bn= True, noise_sd = 0.75, momentum= 1.0, nb_iter=1, save_path=s_path)
#
#s_path = 'linf_overfit_True_mom_5_noise_025.npz'
#test(test_loader, adv_path, bn= True, noise_sd = 0.25, momentum= 1.0, nb_iter=1, save_path=s_path)



#%%------- L2
#adv_path = './model/imagenet_l2_3_0.pt'
#adv_path = './model/imagenet_l2_eps765.pt'
#ds = ImageNet(adv_path)
#test_loader = DataLoader(test_dataset, shuffle=False, batch_size=500, num_workers= 4, pin_memory= True)

#s_path = 'l2_overfit_False_mom_5_noise_005.npz'
#test(test_loader, adv_path, bn= False, noise_sd = 0.05, momentum= 0.0, nb_iter=0, save_path=s_path)

#s_path = 'l2_overfit_True_mom_5_noise_025.npz'
#test(test_loader, adv_path, bn= True, noise_sd = 0.25, momentum= 1.0, nb_iter=1, save_path=s_path)

#s_path = 'l2_overfit_True_mom_5_noise_075.npz'
#test(test_loader, adv_path, bn= True, noise_sd = 0.75, momentum= 1.0, nb_iter=1, save_path=s_path)


# Banner marking the end of the script in the log output.
print('============= Exit ===================')


