import numpy as np
import argparse
import os
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime
import sys

from .model_utils import make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log
from .img_utils import save_images, save_images_1d

#%%
# Command-line options.
# NOTE(review): `args` is not read anywhere in the visible code (get_batch
# hard-codes its own batch size and worker count), but the positional
# `dataset` argument is still required on the command line.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
# `%(default)s` keeps the help text in sync with the actual default values.
parser.add_argument('--batch', default=64, type=int, metavar='N',
                    help='batchsize (default: %(default)s)')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Wrap the image-class folder tree at ``_dir`` as an ImageFolder dataset.

    Preprocessing: resize (shorter side to 256), center-crop to 224x224,
    then convert to a tensor. No mean/std normalization is applied here.
    """
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    preprocess = transforms.Compose(pipeline)
    return datasets.ImageFolder(_dir, preprocess)


#%%
# Module-level clamp bounds (presumably the [0, 1] pixel range produced by
# ToTensor); not referenced elsewhere in the visible code.
upper_limit, lower_limit = 1, 0
def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` elementwise into ``[lower_limit, upper_limit]``.

    Each bound may be a Python scalar or a tensor broadcastable to ``X``.
    Bounds are converted to tensors first: passing a plain int straight to
    ``torch.min(X, 1)`` would be parsed as a reduction over dim 1 rather
    than an elementwise minimum.
    """
    upper = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    lower = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    return torch.maximum(torch.minimum(X, upper), lower)

#%%
def attack_pgd(model, X, y, attack_iters=1):
    """Return the gradient of the cross-entropy loss w.r.t. the input batch.

    Despite the name, no projected/signed step is taken. On every iteration
    ``delta`` is *overwritten* with the raw input gradient evaluated at
    ``X + delta``. With the default ``attack_iters=1`` the result is simply
    the gradient of the batch-mean ``F.cross_entropy(model(X)[0], y)``
    with respect to ``X``.

    Args:
        model: callable returning a 2-tuple whose first element is the
            logits (the second element is ignored).
        X: input batch; ``delta`` is allocated with X's shape, dtype and device.
        y: class targets accepted by ``F.cross_entropy``.
        attack_iters: number of gradient evaluations.

    Returns:
        Tensor shaped like ``X`` with ``requires_grad=True`` (callers should
        ``.detach()`` it). All zeros when ``attack_iters == 0``.
    """
    delta = torch.zeros_like(X, requires_grad=True)
    for _ in range(attack_iters):
        logits, _aux = model(X + delta)
        loss = F.cross_entropy(logits, y)
        # autograd.grad differentiates only w.r.t. delta, so the model's
        # parameter .grad buffers are neither allocated nor accumulated.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.copy_(grad)
    return delta
#%%
    
def updateModel(model, inputs, nb_iter):
    """Re-estimate BatchNorm running statistics on ``inputs``.

    Runs ``nb_iter`` forward passes in train mode, discarding the outputs,
    then switches the model back to eval mode (also if a forward pass
    raises). No noise or augmentation is applied here. How far the
    statistics move per pass depends on the BN momentum the model was built
    with.

    Returns:
        The same (mutated) model object.
    """
    model.train()
    try:
        # BN running stats are updated inside forward regardless of grad
        # mode, so there is no need to build an autograd graph here.
        with torch.no_grad():
            for _ in range(nb_iter):
                model(inputs)
    finally:
        model.eval()
    return model

#%%
def process_delta(arr, div=3):
    """Normalize each sample of a batch for display in [0, 1].

    Per sample (axis 0), the values are shifted to zero mean, divided by
    ``div`` times their standard deviation, offset by 0.5 and clipped to
    [0, 1]. A larger ``div`` compresses the contrast. Samples with zero
    variance map to a uniform 0.5 instead of NaN.

    Args:
        arr: array of shape (batch, ...) with at least 2 dimensions.
        div: number of standard deviations mapped onto half the output range.

    Returns:
        Array of the same shape as ``arr`` with values in [0, 1].
    """
    arr = np.asarray(arr)
    sample_axes = tuple(range(1, arr.ndim))
    mean = arr.mean(axis=sample_axes, keepdims=True)
    std = arr.std(axis=sample_axes, keepdims=True)
    # Guard against 0/0 for constant samples (e.g. an all-zero gradient).
    std = np.where(std > 0, std, 1.0)
    out = (arr - mean) / (div * std) + 0.5
    return np.clip(out, 0.0, 1)
        
#%%

def get_batch(test_dataset, n_samples=500, stride=50, num_workers=10, device='cuda'):
    """Load a fixed, evenly strided subset of ``test_dataset`` as one batch.

    Picks dataset indices 0, stride, 2*stride, ... (at most ``n_samples`` of
    them) in order and stacks them. NOTE(review): with stride 50 this
    presumably yields one image per class for an ImageNet-style val set
    stored class-by-class with 50 images each — confirm against the dataset.

    Only the selected items are read. The previous version decoded an extra
    throwaway batch plus ``n_samples * stride`` items to keep ``n_samples``.

    Args:
        test_dataset: map-style dataset yielding (input, target) pairs.
        n_samples: maximum number of samples to return.
        stride: spacing between selected dataset indices.
        num_workers: DataLoader worker processes.
        device: device the returned tensors are moved to.

    Returns:
        (inputs, targets) tensors on ``device``.

    Raises:
        ValueError: if the dataset is empty.
    """
    indices = list(range(0, len(test_dataset), stride))[:n_samples]
    if not indices:
        raise ValueError('get_batch: test_dataset is empty')
    subset = torch.utils.data.Subset(test_dataset, indices)
    # Ceil-divide so the selected items are spread across the workers.
    chunk = -(-len(indices) // max(num_workers, 1))
    loader = DataLoader(subset, shuffle=False, batch_size=chunk,
                        num_workers=num_workers,
                        pin_memory=str(device).startswith('cuda'))
    xs, ys = zip(*loader)
    inputs, targets = torch.cat(xs), torch.cat(ys)
    return inputs.to(device), targets.to(device)
 
#%%
def test(test_dataset, model_path, bn, momentum, nb_iter, img_name):
    """Compute and save input-gradient images for one checkpoint and dataset.

    Loads the checkpoint at ``model_path`` (eval mode) and takes the fixed
    batch from ``get_batch``. When ``bn`` is true, it first re-estimates the
    BatchNorm statistics on that batch with ``updateModel``. It then computes
    the loss gradient w.r.t. the inputs with ``attack_pgd`` and passes
    images and gradients to ``select_imgs``.

    Args:
        test_dataset: dataset handed to ``get_batch``.
        model_path: checkpoint path; also selects the linf/l2 naming used
            by ``select_imgs``.
        bn: if true, update BN statistics on the batch first.
        momentum: forwarded to ``getModel`` (presumably the BN momentum —
            confirm in ``make_and_restore_model_3``).
        nb_iter: number of BN-update forward passes when ``bn`` is true.
        img_name: base file name for the saved images.
    """
    # getModel returns the model already in eval mode with the checkpoint
    # weights loaded; the checkpoint dict itself is not needed afterwards.
    # NOTE(review): ImageNet(model_path) passes the checkpoint path as the
    # dataset argument; presumably only dataset metadata is used — confirm.
    model, _ = getModel(model_path, bn, momentum, ImageNet(model_path))

    inputs, targets = get_batch(test_dataset)

    if bn:
        # Weights are already loaded by getModel, so no reload is needed.
        # (Running the attack with model.train() would be an alternative to
        # a BN update with momentum = 1.)
        model = updateModel(model, inputs, nb_iter)

    delta = attack_pgd(model, inputs, targets)

    # ---- CPU / NumPy post-processing: channels-first -> channels-last.
    delta_np = np.transpose(delta.detach().cpu().numpy(), axes=[0, 2, 3, 1])
    inputs_np = np.transpose(inputs.detach().cpu().numpy(), axes=[0, 2, 3, 1])
    select_imgs(model_path, bn, inputs_np, delta_np, img_name)
        
#%%
def select_imgs(model_path, bn, inputs, delta_cpu, img_name,
                idx=(94, 260, 307), r=3, c=6):
    """Save an input-image grid and the matching normalized-gradient grid.

    The grid uses the first ``r * c`` samples, with the leading slots
    replaced by the hand-picked samples in ``idx``. Two PNGs are written via
    ``save_images_1d``:

    - ``<img_name>.png`` for the inputs;
    - ``<img_name>_<linf|l2>_<bn>.png`` for the gradients after
      ``process_delta`` normalization.

    The caller's arrays are not modified.

    Args:
        model_path: checkpoint path; if it contains 'linf' the linf suffix
            and ``div=1.5`` are used, otherwise the l2 suffix and ``div=3.5``.
        bn: included in the gradient file name.
        inputs: NumPy image batch (channels-last as produced by ``test``).
        delta_cpu: NumPy gradient batch aligned with ``inputs``.
        img_name: base output file name.
        idx: sample indices placed in the first grid slots.
        r: grid rows (forwarded to ``save_images_1d``).
        c: grid columns.
    """
    noisy_name = img_name + '.png'

    if 'linf' in model_path:
        grad_name = img_name + '_linf_' + str(bn) + '.png'
        div = 1.5
    else:
        grad_name = img_name + '_l2_' + str(bn) + '.png'
        div = 3.5

    n_cells = r * c
    # Slicing a NumPy array yields a view; copy so that overwriting the
    # leading slots does not mutate the caller's arrays.
    grid_inputs = inputs[:n_cells].copy()
    grid_delta = delta_cpu[:n_cells].copy()
    for slot, src in enumerate(idx):
        grid_inputs[slot] = inputs[src]
        grid_delta[slot] = delta_cpu[src]

    save_images_1d(noisy_name, grid_inputs, r)
    save_images_1d(grad_name, process_delta(grid_delta, div), r)

#%%   ## Data locations: ImageNet root and ImageNet-C corruption folders
# ImageNet root, exported through the environment (presumably read by the
# project's dataset helpers — confirm).
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'

# ImageNet-C root and the corruption types to visit. get_lossGrads reads
# each one from cp_dir/<corruption>/<severity>/.
cp_dir = '/imagenet-c/'
cp_data = [
    'brightness',
    'contrast',
    'defocus_blur',
    'elastic_transform',
    'fog',
    'frost',
    'gaussian_blur',
    'gaussian_noise',
    'glass_blur',
    'impulse_noise',
    'jpeg_compression',
    'motion_blur',
    'pixelate',
    'saturate',
    'shot_noise',
    'snow',
    'spatter',
    'speckle_noise',
    'zoom_blur',
]

#%%
def getModel(model_path, bn, momentum, ds):
    """Build a ResNet-50 and load the weights stored at ``model_path``.

    Args:
        model_path: checkpoint file; its ``'state_dict'`` entry is loaded.
        bn: not used by this function.
        momentum: forwarded to ``make_and_restore_model_3``.
        ds: forwarded as the ``dataset=`` argument.

    Returns:
        (model, checkpoint), with the model switched to eval mode.
    """
    net = make_and_restore_model_3(arch='resnet50', dataset=ds, momentum=momentum)
    ckpt = torch.load(model_path)
    weights = ckpt['state_dict']
    net.load_state_dict(weights)
    net.eval()
    return net, ckpt

#%%
def get_lossGrads(model_path, norm, severities=(1, 3, 5)):
    """Save gradient visualizations for every ImageNet-C corruption/severity.

    For each corruption in ``cp_data`` and each severity, builds the dataset
    at ``cp_dir/<corruption>/<severity>/`` and calls ``test`` twice:

    - with the checkpoint's own BN statistics (``bn=False``);
    - after one BN-update pass on the batch (``bn=True``, momentum 1,
      one iteration).

    Args:
        model_path: checkpoint to evaluate.
        norm: not used; the linf/l2 label is derived from ``model_path``.
        severities: corruption severities to visit.
    """
    print('Model: ' + model_path)
    for corruption in cp_data:
        for sev in severities:
            print(corruption + '  severity:' + str(sev) + " = = = = = = = = = = = = = ")
            sys.stdout.flush()
            test_dir = cp_dir + corruption + '/' + str(sev) + '/'
            test_dataset = _imagenet(test_dir)

            img_name = corruption + '_' + str(sev)
            test(test_dataset, model_path, False, 0.0, 0, img_name)
            test(test_dataset, model_path, True, 1, 1, img_name)
#        break
            
#%%-------------------------------------------

# Entry point: only run the (long) visualization sweep when executed as a
# script (e.g. `python -m <package>.<module> <dataset>`), not on import.
if __name__ == '__main__':
    print('Adv-Linf ----------------------------------------\n')
    adv_linfpath = './model/imagenet_linf_eps4_parallel.pt'
    get_lossGrads(adv_linfpath, 'linf')

    print('Adv-L2 ----------------------------------------\n')
    adv_l2path = './model/imagenet_l2_eps765_parallel.pt'
    get_lossGrads(adv_l2path, 'l2')

    print('============= Exit ===================')


