import argparse
import os
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime
import sys

from .model_utils import make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
# Command-line interface.
# NOTE(review): within this file `args` is never read after parsing -- batch
# size and worker count are hard-coded further down -- but parsing still
# requires the positional `dataset` argument to be one of DATASETS.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument(
    'dataset',
    choices=DATASETS,
    type=str,
)
parser.add_argument(
    '--batch',
    type=int,
    default=256,
    metavar='N',
    help='batchsize (default: 256)',
)
parser.add_argument(
    '--workers',
    type=int,
    default=4,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Build a torchvision ``ImageFolder`` dataset rooted at ``_dir``.

    Each image is resized (``Resize(256)``), center-cropped to 224x224 and
    converted with ``ToTensor``.
    NOTE(review): no mean/std normalization is applied here -- presumably the
    model normalizes internally; confirm.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return datasets.ImageFolder(_dir, pipeline)

def updateModel(model, inputs, nb_iter):
    """Refresh a model's BatchNorm running statistics on ``inputs``.

    The model is switched to training mode and ``inputs`` is pushed through it
    ``nb_iter`` times. In training mode, BatchNorm layers fold each batch's
    statistics into their running buffers; how strongly depends on the
    layers' ``momentum``. The outputs are discarded. The model is then left
    in eval mode, so later forward passes use the refreshed running
    statistics.

    Args:
        model: ``torch.nn.Module`` to adapt (modified in place).
        inputs: batch fed unchanged to ``model`` on every pass. No noise is
            added here; the caller in this file passes an already-noisy batch.
        nb_iter: number of forward passes. 0 only switches the model to eval.

    Returns:
        The same ``model`` object, now in eval mode.
    """
    model.train()
    # Only the BatchNorm buffer updates matter. Skip autograd bookkeeping so
    # the passes don't retain activations for a backward that never happens.
    with torch.no_grad():
        for _ in range(nb_iter):
            model(inputs)
    model.eval()
    return model

#%%
def one_hot(a, num_class=10):
    """Return a float one-hot matrix for integer class labels.

    Args:
        a: integer labels as any array-like (NumPy array, list, tuple). It is
            flattened, so the i-th label in C order maps to row i.
        num_class: number of columns (classes). Labels outside the valid
            column range make NumPy raise ``IndexError``.

    Returns:
        ``np.ndarray`` of shape ``(number_of_labels, num_class)`` and dtype
        float64, with a single 1.0 per row.
    """
    # np.asarray lets plain Python sequences through: the original relied on
    # `a.size`, which only exists on NumPy arrays.
    labels = np.asarray(a).ravel()
    encoded = np.zeros((labels.size, num_class))
    encoded[np.arange(labels.size), labels] = 1
    return encoded

#%%
    
def test(loader, model_path, bn, noise_sd, momentum, nb_iter):
    """Measure top-1 accuracy of a checkpoint on Gaussian-noised inputs.

    Every batch is evaluated ``r_iter`` times (a module-level global). Each
    time, fresh ``noise_sd``-scaled ``torch.randn_like`` noise is added to the
    images, and a separate running top-1 meter is kept per repetition.

    BatchNorm handling is decided once, on the very first noisy batch
    (``i == 0 and j == 0``):
      * ``bn`` False: the model stays in eval mode throughout.
      * ``bn`` True and ``momentum < 1``: running statistics are adapted once
        with ``updateModel`` on that first noisy batch (``nb_iter`` passes).
        Evaluation then continues in eval mode.
      * ``bn`` True and ``momentum >= 1``: the model is put in training mode
        for the rest of the run. Every forward pass then normalizes with the
        current batch's own statistics.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        model_path: path to a checkpoint containing a ``'state_dict'`` entry.
        bn: whether to adapt BatchNorm to the noisy inputs (see above).
        noise_sd: multiplier applied to the standard-normal noise.
        momentum: forwarded to ``make_and_restore_model_3``; also selects the
            adaptation mode above.
        nb_iter: forward passes used by ``updateModel`` when ``momentum < 1``.

    Returns:
        List of length ``r_iter`` with each repetition's average top-1
        accuracy, as reported by ``accuracy``.

    Requires CUDA, plus the module-level globals ``r_iter`` and ``ds``.
    """
    top1 = [AverageMeter() for _ in range(r_iter)]

    # load-- model (default model loading)
    model = make_and_restore_model_3(arch='resnet50', dataset=ds, momentum=momentum)
    # Load inline so the checkpoint dict (possibly holding GPU tensors) is
    # released right away instead of staying alive for the whole loop. The
    # original reload on the first batch was redundant: the weights are
    # still untouched at that point.
    model.load_state_dict(torch.load(model_path)['state_dict'])
    model.eval()

    # Pure inference plus BN-statistics refresh: no autograd graph is needed.
    # Building one for ResNet-50 on large batches, especially in training
    # mode, keeps every activation alive and can exhaust GPU memory.
    with torch.no_grad():
        for i, (X, y) in enumerate(loader):
            inputs, targets = X.cuda(), y.cuda()

            for j in range(r_iter):
                noisy_input = inputs + torch.randn_like(inputs, device='cuda') * noise_sd
                if bn and i == 0 and j == 0:
                    if momentum < 1.:
                        # adapt running stats on this first noisy batch
                        model = updateModel(model, noisy_input, nb_iter)
                    else:
                        # use per-batch statistics from here on
                        model.train()

                outputs, _ = model(noisy_input)

                # measure accuracy
                acc1, _ = accuracy(outputs, targets, topk=(1, 5))
                top1[j].update(acc1.item(), inputs.size(0))

            if i % 100 == 0:
                print('Test: [{0}/{1}]\t'
                      'Acc@1 {top.val:.3f} ({top.avg:.3f})\t'.format(
                    i, len(loader), top=top1[0]))
                sys.stdout.flush()

    return [meter.avg for meter in top1]


#%%   ## Run configuration used by the evaluation functions
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'

# Validation images, loaded through `_imagenet` (torchvision ImageFolder).
_dir = '/imagenet/dataset/val/'
test_dataset = _imagenet(_dir)

# Number of independent noise draws per batch in `test`. The value printed
# after "\pm" by `eval_me` is twice the std across these draws.
# (A dead `r_iter = 3` assignment that was immediately overwritten has been
# removed.)
r_iter = 5
# DataLoader worker processes used by `eval_me`.
# NOTE(review): the parsed `--workers` CLI flag is not consulted here.
workers = 10
# NOTE(review): not read anywhere in this file; `eval_me` uses its own local
# batch size. Kept so the module-level name still exists.
b_size = 500


def eval_me(model_path, bn, momentum, nb_iter= 5):
    r"""Print one LaTeX table row of noisy-input accuracies for a checkpoint.

    First prints a header naming the checkpoint, momentum and BN flag. Then,
    for each batch size in the sweep (currently only 512):
      * build a shuffled ``DataLoader`` over the module-level
        ``test_dataset``;
      * run ``test`` at noise levels 0.25, 0.5 and 0.75;
      * print each result as ``mean{\tiny $\pm 2*std$}&`` over the
        per-repetition accuracies, all on one line.
    """
    header = str(model_path) + ' Momentum: ' + str(momentum) + ' BN:' + str(bn)
    print(header)

    for batch_size in [512]:
        print('Batch Size: ' + str(batch_size))
        loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size,
                            num_workers=workers, pin_memory=True)

        for sigma in [0.25, 0.5, 0.75]:
            accs = test(loader, model_path, bn, sigma, momentum=momentum, nb_iter=nb_iter)
            mean_txt = str(np.around(np.mean(accs), decimals=1))
            spread_txt = str(np.around(np.std(accs) * 2, decimals=2))
            print(mean_txt + "{\\tiny $\\pm" + spread_txt + "$}&", end=' ')
        print()


def eval_model(model_path):
    """Evaluate ``model_path`` without, then with, BatchNorm adaptation.

    Calls ``eval_me`` once per entry in ``settings``:
      * BN adaptation off (momentum 0.0, no update passes);
      * ``bn=True`` with momentum 1.01. Because that momentum is >= 1, the
        model runs in training mode ("full adaptation").

    Partial-adaptation runs (``bn=True``, momentum 0.1-0.9, 5 update passes)
    can be swept by adding entries to ``settings``.
    """
    settings = (
        (False, 0.0, 0),   # plain eval-mode BatchNorm
        (True, 1.01, 5),   # full adaptation
    )
    for bn, momentum, nb_iter in settings:
        eval_me(model_path, bn, momentum, nb_iter)



# Evaluate each adversarially trained checkpoint in turn.
for banner, adv_path in (
    ('Adv-Linf ----------------------------------------\n',
     './model/imagenet_linf_eps4_parallel.pt'),
    ('ADV-L2: -----------------------------------------\n',
     './model/imagenet_l2_eps765_parallel.pt'),
):
    print(banner)
    # `test` reads the module-level `ds`, so it is rebound per checkpoint.
    # NOTE(review): ImageNet(...) receives the checkpoint path; if its
    # argument is meant to be the dataset root, confirm this is intentional.
    ds = ImageNet(adv_path)
    eval_model(model_path=adv_path)
print('============= Exit ===================')


