import numpy as np
import argparse
import os
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log
from .img_utils import save_images, save_images_1d

#%%
# Command-line interface. Parsing happens at import time, so the positional
# ``dataset`` argument must be present whenever this module is executed.
# NOTE(review): ``args.batch`` / ``args.workers`` are not referenced in the
# visible part of this file (the DataLoader below uses fixed values) - confirm
# whether they are still needed.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
# ``%(default)s`` is expanded by argparse, so the help text always reports the
# real default (the previous text claimed 256 while the default was 64).
parser.add_argument('--batch', default=64, type=int, metavar='N',
                    help='batchsize (default: %(default)s)')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return an ``ImageFolder`` dataset rooted at ``_dir``.

    Every image goes through Resize(256) -> CenterCrop(224) -> ToTensor().
    No normalisation transform is applied here.

    Args:
        _dir: Root directory handed to ``datasets.ImageFolder``.

    Returns:
        The ``ImageFolder`` dataset with the preprocessing pipeline attached.
    """
    preprocessing_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(preprocessing_steps)
    return datasets.ImageFolder(_dir, pipeline)


#%%
# Default scalar bounds for image values.
upper_limit, lower_limit = 1,0
def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` element-wise into ``[lower_limit, upper_limit]``.

    Args:
        X: Tensor to clamp.
        lower_limit: Lower bound; a tensor broadcastable against ``X`` or a
            plain Python/NumPy number.
        upper_limit: Upper bound; same accepted types as ``lower_limit``.

    Returns:
        ``max(min(X, upper_limit), lower_limit)`` computed element-wise.
    """
    # torch.min/torch.max treat a plain number in second position as a
    # reduction dimension, not as a bound, so numeric limits (such as the
    # module-level 1 and 0 above) are converted to tensors on X's device and
    # dtype first. Tensor bounds are passed through untouched.
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

#%%
def attack_pgd(model, X, y, attack_iters=1):
    """Return the cross-entropy loss gradient with respect to the input.

    Despite the name, no step size or projection is applied: on every
    iteration ``delta`` is simply overwritten with the gradient of
    ``cross_entropy(model(X + delta), y)`` taken with respect to ``delta``.
    With the default ``attack_iters=1`` the result is the loss gradient at
    ``X``.

    Args:
        model: Callable returning a 2-tuple whose first element is the logits.
        X: Input batch tensor.
        y: Target class indices for ``X``.
        attack_iters: Number of gradient evaluations; ``0`` returns zeros.

    Returns:
        A leaf tensor shaped like ``X`` (``requires_grad=True``) holding the
        last computed gradient.
    """
    # zeros_like already matches X's device and dtype, so the perturbation
    # lives wherever the input does (no hard-coded CUDA transfer).
    delta = torch.zeros_like(X)
    delta.requires_grad = True
    for _step in range(attack_iters):
        output, _ = model(X + delta)
        loss = F.cross_entropy(output, y)
        # backward() fills delta.grad; it also accumulates into the .grad of
        # any model parameters that require gradients.
        loss.backward()
        # Overwrite the perturbation with the raw gradient, bypassing autograd,
        # then clear the buffer so the next iteration starts from zero.
        delta.data.copy_(delta.grad.detach())
        delta.grad.zero_()

    return delta
#%%
    
def updateModel(model, inputs, nb_iter):
    """Run ``nb_iter`` forward passes of ``inputs`` through ``model`` in train mode.

    The outputs are discarded; the passes exist only for whatever state the
    model updates during a training-mode forward (e.g. batch-norm running
    statistics). The model is switched back to eval mode before returning.

    Args:
        model: Module to update in place.
        inputs: Batch fed to the model on every pass.
        nb_iter: Number of forward passes.

    Returns:
        The same ``model`` object, in eval mode.
    """
    model.train()
    for _pass in range(nb_iter):
        model(inputs)
    model.eval()
    return model

#%%
def process_delta(arr, div=3):
    """Rescale each sample of ``arr`` into ``[0, 1]`` for display.

    Every sample (index along axis 0) is standardised with its own mean and
    standard deviation taken over axes 1-3, divided by ``div`` and shifted by
    0.5, then clipped. A value ``k`` standard deviations from the sample mean
    therefore lands at ``0.5 + k / div``.

    Args:
        arr: 4-D array-like; axis 0 indexes samples.
        div: Divisor applied to the standardised values; larger values
            compress the contrast.

    Returns:
        Array with the same shape as ``arr`` and values in ``[0, 1]``.
    """
    arr = np.asarray(arr)
    per_sample = (1, 2, 3)
    # keepdims lets the statistics broadcast against arr directly, with no
    # need to move the sample axis to the end and back.
    mean = np.mean(arr, axis=per_sample, keepdims=True)
    std = np.std(arr, axis=per_sample, keepdims=True)
    # A constant sample has zero spread; dividing by it would yield NaN.
    # Substituting 1 maps such a sample to the neutral value 0.5.
    std[std == 0] = 1
    scaled = (arr - mean) / (div * std) + 0.5
    return np.clip(scaled, 0.0, 1)
        
#%%
def test(loader, model_path, bn, noise_sd, momentum, nb_iter):
    """Compute input loss gradients on noisy batches and hand them to ``select_imgs``.

    A resnet50 is built with ``make_and_restore_model_3`` and its weights are
    loaded from ``checkpoint['state_dict']`` of the file at ``model_path``.
    For every batch of ``loader`` Gaussian noise scaled by ``noise_sd`` is
    added, the loss gradient is obtained via ``attack_pgd``, and the clean,
    noisy and gradient arrays are passed to ``select_imgs`` in
    channels-last layout.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        model_path: Checkpoint path; also passed to ``ImageNet(...)`` and to
            ``select_imgs``.
        bn: When true, the checkpoint weights are reloaded and the model is
            run through ``updateModel`` on each noisy batch before the
            gradient is computed.
        noise_sd: Multiplier applied to the standard-normal noise.
        momentum: Forwarded to ``make_and_restore_model_3``.
        nb_iter: Number of forward passes used by ``updateModel``.
    """
    # NOTE(review): ImageNet receives the checkpoint path here - confirm that
    # this is what the project helper expects.
    dataset = ImageNet(model_path)
    model = make_and_restore_model_3(arch='resnet50', dataset= dataset, momentum= momentum)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Axis order used to turn (N, C, H, W) batches into (N, H, W, C).
    channels_last = [0, 2, 3, 1]

    for batch_idx, (inputs, targets) in enumerate(loader):
        print(f'Iter{batch_idx}    Size : {inputs.shape}')

        X = inputs.cuda()
        y = targets.cuda()
        noisy_input = X + torch.randn_like(X, device='cuda') * noise_sd

        if bn:
            # Start again from the stored weights, then adapt the model to
            # this particular noisy batch.
            model.load_state_dict(checkpoint['state_dict'])
            model = updateModel(model, noisy_input, nb_iter)

        delta = attack_pgd(model, noisy_input, y)

        # Everything below runs on the CPU with NumPy arrays.
        grad_np = np.transpose(delta.detach().cpu().numpy(), axes=channels_last)
        clean_np = np.transpose(inputs.detach().cpu().numpy(), axes=channels_last)
        noisy_np = np.transpose(noisy_input.detach().cpu().numpy(), axes=channels_last)
        select_imgs(model_path, bn, noise_sd, noisy_np, clean_np, grad_np)
        
        
def select_imgs(model_path, bn, noise_sd, noisy_input, inputs, delta_cpu):
    """Save a fixed pair of noisy images and their rescaled gradient images.

    Samples 0 and 100 of the batch are selected. The noisy images go to
    ``noisy_<sd>.png`` and the gradients, rescaled by ``process_delta``, to
    ``linf_<bn>_<sd>.png`` or ``l2_<bn>_<sd>.png`` depending on whether
    ``model_path`` contains ``'linf'``. ``<sd>`` is ``int(noise_sd * 100)``.

    Args:
        model_path: Checkpoint path; only inspected for the substring 'linf'.
        bn: Flag embedded in the gradient file name.
        noise_sd: Noise level embedded in both file names.
        noisy_input: Batch of noisy images, indexed along axis 0.
        inputs: Batch of clean images (currently unused).
        delta_cpu: Batch of gradients, indexed like ``noisy_input``.

    Raises:
        IndexError: If the batch holds fewer than 101 samples.
    """
    sd_tag = str(int(noise_sd*100))
    noisy_path = 'noisy_'+ sd_tag +'.png'
    if 'linf' in model_path:
        grad_path = 'linf_'+str(bn)+ '_'+ sd_tag +'.png'
        div = 1.5
    else:
        grad_path = 'l2_'+str(bn)+ '_'+ sd_tag +'.png'
        div = 3.5
    r, c = 2, 1
    s = 100
    # Take every s-th sample. Indexing with an integer list returns a fresh
    # copy, so the caller's arrays are left untouched (writing into a basic
    # slice of them would overwrite their leading rows).
    picks = [i * s for i in range(r * c)]
    _input = noisy_input[picks]
    _delta = delta_cpu[picks]

    save_images_1d(noisy_path, _input, r)
    save_images_1d(grad_path, process_delta(_delta, div), r)

#%%   ##Getting the model
# Set IMAGENET_LOC_ENV for the rest of the process.
# NOTE(review): presumably read by the project's dataset helpers - confirm.
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'

# Image folder evaluated by this script; the dataset is built once at import
# time and shared by every get_lossGrads call.
img_dir = '/imagenet/dataset/val_certify_2/'
test_dataset =_imagenet(img_dir)

def get_lossGrads(adv_path):
    """Run ``test`` on the checkpoint at ``adv_path`` for every noise level.

    A shuffled loader over the module-level ``test_dataset`` is created
    (batch size 500, 2 workers, pinned memory). For each noise level in
    0, 0.25, 0.5 and 0.75, ``test`` is called first with ``bn=False`` and then
    with ``bn=True``, always with ``momentum=1`` and ``nb_iter=1``.

    Args:
        adv_path: Path of the model checkpoint forwarded to ``test``.
    """
    test_loader = DataLoader(
        test_dataset,
        shuffle= True,
        batch_size= 500,
        num_workers= 2,
        pin_memory= True,
    )

    for sd in (0., 0.25, 0.5, 0.75):
        for use_bn in (False, True):
            test(test_loader, adv_path, bn= use_bn, noise_sd= sd, momentum= 1, nb_iter= 1)


#%%-------------------------------------------

# Driver: runs at import time. Each checkpoint is evaluated in turn by
# get_lossGrads; select_imgs picks the 'linf' or 'l2' file prefix and scaling
# from whether the checkpoint path contains the substring 'linf'.
print('Adv-Linf ----------------------------------------\n')
adv_linfpath = './model/imagenet_linf_eps4_parallel.pt'
get_lossGrads(adv_linfpath)

# Second checkpoint: its path lacks 'linf', so the 'l2' branch is used.
print('Adv-L2 ----------------------------------------\n')
adv_l2path = './model/imagenet_l2_eps765_parallel.pt'
get_lossGrads(adv_l2path)

print('============= Exit ===================')
