import sys
import argparse
import os
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
# Command-line interface. `dataset` must be one of DATASETS.
parser = argparse.ArgumentParser(description='PyTorch ImageNet PGD robustness evaluation')
parser.add_argument('dataset', type=str, choices=DATASETS)
# %(default)s keeps the help text in sync with the actual default value.
parser.add_argument('--batch', default=64, type=int, metavar='N',
                    help='batchsize (default: %(default)s)')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
# NOTE(review): args.batch / args.workers are not consulted when test_loader
# is built further down (it hard-codes batch_size and num_workers).
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return an ``ImageFolder`` dataset rooted at ``_dir``.

    Each image goes through the usual ImageNet evaluation pipeline: resize
    (shorter side to 256), center-crop to 224x224, then ``ToTensor``. No
    mean/std normalization is applied here.
    """
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(eval_steps)
    return datasets.ImageFolder(_dir, pipeline)


#%%
lower_limit, upper_limit = 0, 1  # valid input range; adversarial images are clamped into it
def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp ``X`` into ``[lower_limit, upper_limit]``.

    The bounds may be tensors that broadcast against ``X`` (e.g. per-element
    bounds such as ``lower_limit - X``) or plain Python numbers.

    Scalar bounds are first converted to tensors on ``X``'s device and dtype.
    Without this, ``torch.min(X, 1)`` would select the reduction-over-dimension
    overload instead of an element-wise minimum. Tensor bounds are used
    unchanged.
    """
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

def init_delta(X, norm, epsilon):
    """Draw a random starting perturbation for PGD inside the ``norm`` ball.

    Args:
        X: clean input batch; dimension 0 is treated as the batch dimension.
        norm: ``"l_inf"`` or ``"l_2"``.
        epsilon: radius of the perturbation ball.

    Returns:
        A tensor with the same shape, dtype and device as ``X``. It is clamped
        so that ``X + delta`` stays within ``[lower_limit, upper_limit]``.

    Raises:
        ValueError: if ``norm`` is not one of the supported norms.
    """
    # zeros_like already places delta on X's device, so both CPU and GPU
    # inputs work. Forcing .cuda() here would mismatch CPU inputs in the clamp.
    delta = torch.zeros_like(X)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    elif norm == "l_2":
        delta.normal_()
        # Per-sample L2 norm, shaped to broadcast over every non-batch
        # dimension (works for any input rank, not only 4-D images).
        per_sample = (delta.size(0),) + (1,) * (delta.dim() - 1)
        n = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(per_sample)
        # Gaussian direction rescaled to a random radius drawn uniformly in [0, epsilon).
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon
    else:
        raise ValueError(f"unsupported norm {norm!r}; expected 'l_inf' or 'l_2'")

    # Keep X + delta inside the valid input range.
    delta = clamp(delta, lower_limit - X, upper_limit - X)
    return delta

#%%
def attack_pgd(model, X, y, epsilon, alpha, attack_iters, norm):
    """Run an untargeted PGD attack against ``model`` and return the perturbation.

    The attack starts from a random point inside the ``norm`` ball (see
    ``init_delta``). Each step:

    1. ascends the cross-entropy loss w.r.t. the perturbation,
    2. projects back onto the ``epsilon`` ball,
    3. clamps ``X + delta`` into ``[lower_limit, upper_limit]``.

    Args:
        model: callable whose output is a pair; the first element is used as
            the class scores.
        X: clean input batch (dimension 0 is the batch).
        y: class labels for ``X``, as accepted by ``F.cross_entropy``.
        epsilon: radius of the perturbation ball.
        alpha: step size.
        attack_iters: number of PGD steps (0 returns the random start).
        norm: ``"l_inf"`` or ``"l_2"``.

    Returns:
        ``delta`` with the same shape as ``X``. It is a leaf tensor with
        ``requires_grad=True``; callers should ``detach()`` it.

    Raises:
        ValueError: if ``norm`` is not supported.
    """
    delta = init_delta(X, norm, epsilon)
    delta.requires_grad = True
    # Shape that broadcasts a per-sample scalar over all non-batch dims.
    per_sample = (X.size(0),) + (1,) * (X.dim() - 1)

    for _ in range(attack_iters):
        output, _ = model(X + delta)
        loss = F.cross_entropy(output, y)
        # Only the gradient w.r.t. delta is needed. autograd.grad computes just
        # that, instead of also accumulating into every model parameter's
        # .grad buffer the way loss.backward() would.
        (grad,) = torch.autograd.grad(loss, delta)

        # The projection step is not part of the attack's differentiable graph.
        with torch.no_grad():
            if norm == "l_inf":
                d = torch.clamp(delta + alpha * torch.sign(grad),
                                min=-epsilon, max=epsilon)
            elif norm == "l_2":
                g_norm = grad.reshape(grad.size(0), -1).norm(p=2, dim=1).view(per_sample)
                scaled_g = grad / (g_norm + 1e-10)
                d = ((delta + scaled_g * alpha)
                     .reshape(delta.size(0), -1)
                     .renorm(p=2, dim=0, maxnorm=epsilon)
                     .view_as(delta))
            else:
                raise ValueError(f"unsupported norm {norm!r}; expected 'l_inf' or 'l_2'")
            d = clamp(d, lower_limit - X, upper_limit - X)
            delta.copy_(d)

    return delta
#%%
def test(loader, model, criterion, noise_sd):
    """Evaluate ``model`` on PGD-perturbed batches from ``loader``.

    The attack configuration is read from the module-level globals
    ``epsilon``, ``pgd_alpha``, ``attack_iters`` and ``norm``. They must be
    assigned before this function is called.

    Args:
        loader: iterable of ``(inputs, targets)`` batches; it must support ``len``.
        model: callable whose output is a pair; the first element is used as
            the class scores.
        criterion: loss applied to the adversarial outputs.
        noise_sd: unused; kept for signature compatibility.

    Returns:
        ``(average robust loss, average robust top-1 accuracy)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # Evaluation mode for the whole run, including the attack's forward passes.
    model.eval()

    for i, (inputs, targets) in enumerate(loader):
        # Time spent waiting for the data loader.
        data_time.update(time.time() - end)

        X = inputs.cuda()
        y = targets.cuda()

        # Crafting the perturbation needs gradients w.r.t. the input.
        delta = attack_pgd(model, X, y, epsilon, pgd_alpha, attack_iters, norm).detach()
        X_adv = torch.clamp(X + delta, min=lower_limit, max=upper_limit)

        # Scoring the adversarial batch needs no autograd graph.
        with torch.no_grad():
            robust_output, _ = model(X_adv)
            robust_loss = criterion(robust_output, y)

        # Record loss and top-1 / top-5 accuracy, weighted by batch size.
        acc1, acc5 = accuracy(robust_output, y, topk=(1, 5))
        batch_size = X.size(0)
        losses.update(robust_loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        print('Test: [{0}/{1}]\t'
              'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
              'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                  i, len(loader), top1=top1, top5=top5), flush=True)

    # Do not reference the loop variable here: it is unbound when loader is empty.
    print('Final \t'
          'Acc@1 {top1.avg:.3f}\t'
          'Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5), flush=True)
    return (losses.avg, top1.avg)

#%%   ##Getting the model

def evalMe(adv_path):
    """Load a ResNet-50 checkpoint from ``adv_path`` and evaluate it with ``test``.

    The module-level ``test_loader`` is used as data. The PGD settings come
    from the module-level globals read by ``test``.

    Args:
        adv_path: path to a checkpoint whose ``'state_dict'`` entry holds the
            model weights.

    Returns:
        ``(test_loss, test_acc)`` as returned by ``test``.
    """
    print(adv_path)
    # NOTE(review): ImageNet() receives the checkpoint path as its data path;
    # presumably only its dataset metadata is used here -- confirm.
    ds = ImageNet(adv_path)
    model = make_and_restore_model_3(arch='resnet50', dataset=ds, momentum=0.)
    # torch.load unpickles the file: only load trusted checkpoints.
    checkpoint = torch.load(adv_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    criterion = CrossEntropyLoss().cuda()
    noise_sd = 0  # unused by test(); passed for signature compatibility
    test_loss, test_acc = test(test_loader, model, criterion, noise_sd)
    return test_loss, test_acc

#%%
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'
img_dir = '/imagenet/dataset/val_attack/'
test_dataset = _imagenet(img_dir)

# Checkpoints to evaluate.
adv_linf_overfit = './model/imagenet_linf_eps4_parallel.pt'   # overfit
adv_l2_overfit = './model/imagenet_l2_eps765_parallel.pt'  # overfit


#%%
# Evaluation does not need shuffling. A fixed order keeps the per-batch log
# lines comparable across runs.
# NOTE(review): batch size and worker count are hard-coded here; the
# --batch / --workers command-line flags are not used.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=500,
                         num_workers=10, pin_memory=True)

#%%
# test() reads epsilon / attack_iters / norm / pgd_alpha as module globals,
# so they are (re)assigned here before each evaluation.
print('ADV-Linf: -----------------------------------------\n')
epsilon, attack_iters, norm = 4./255., 100, 'l_inf'
pgd_alpha = epsilon/4.
evalMe(adv_linf_overfit)

print('ADV-L2: -----------------------------------------\n')
epsilon, attack_iters, norm = 765./255., 100, 'l_2'  # 765/255 == 3.0
pgd_alpha = epsilon/8.5
evalMe(adv_l2_overfit)
print('============= Exit ===================')
