import sys
import argparse
import os
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
# Command-line interface.  The parsed values end up in ``args``.
# NOTE(review): nothing in the visible part of this file reads ``args`` after
# parsing, so --batch/--workers currently have no effect on the DataLoader.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
# ``%(default)s`` makes argparse print the real default; the previous help
# text claimed 256 while the actual default is 64.
parser.add_argument('--batch', default=64, type=int, metavar='N',
                    help='batchsize (default: %(default)s)')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return an ``ImageFolder`` dataset rooted at ``_dir``.

    Every image goes through ``Resize(256)``, ``CenterCrop(224)`` and
    ``ToTensor()``, in that order.  No normalization transform is applied
    here.
    """
    preprocessing_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(preprocessing_steps)
    return datasets.ImageFolder(_dir, pipeline)


#%%
# Bounds that perturbed images (X + delta) are kept within elsewhere in
# this file.
upper_limit = 1
lower_limit = 0


def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise between ``lower_limit`` and ``upper_limit``.

    Both bounds are handed to ``torch.min`` / ``torch.max`` as their second
    argument, so they are expected to be tensors (the callers in this file
    pass ``lower_limit - X`` and ``upper_limit - X``).  The upper bound is
    applied first and the lower bound second.
    """
    capped = torch.min(X, upper_limit)
    return torch.max(capped, lower_limit)

#%%

def init_delta(X, norm):
    """Draw a random starting perturbation for the batch ``X``.

    The radius is the module-level ``epsilon`` (this function takes no radius
    argument), and the result is clipped so that ``X + delta`` stays inside
    ``[lower_limit, upper_limit]``.

    Args:
        X: batch of inputs; the first dimension is treated as the batch axis.
        norm: ``"l_inf"`` draws every element uniformly from
            ``[-epsilon, epsilon]``; ``"l_2"`` draws a Gaussian direction per
            sample and rescales it to a length drawn uniformly from
            ``[0, epsilon]``.  Any other value leaves the perturbation at
            zero before clipping.

    Returns:
        A tensor with the same shape, dtype and device as ``X``.
    """
    # zeros_like already matches X's device and dtype, so no explicit .cuda()
    # is needed; forcing the GPU made the function unusable for CPU inputs.
    delta = torch.zeros_like(X)

    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    elif norm == "l_2":
        delta.normal_()
        batch = delta.size(0)
        # One norm per sample, shaped to broadcast over the trailing dims
        # whatever the rank of X is.
        per_sample_shape = (batch,) + (1,) * (delta.dim() - 1)
        n = delta.reshape(batch, -1).norm(p=2, dim=1).view(per_sample_shape)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon

    # Keep X + delta inside the valid range.
    return clamp(delta, lower_limit - X, upper_limit - X)

def attack_pgd_eot(model_list, X, y, epsilon, alpha, attack_iters, norm):
    """Iteratively build a perturbation against the averaged model outputs.

    At every step the first output of each model in ``model_list`` is
    averaged, the cross-entropy against ``y`` is back-propagated to the
    perturbation, and the perturbation takes one ascent step followed by a
    projection onto the ``epsilon`` ball and onto the range that keeps
    ``X + delta`` within ``[lower_limit, upper_limit]``.

    Args:
        model_list: non-empty sequence of callables; each must return a pair
            whose first element is the tensor fed to the loss.
        X: input batch.  The ``"l_2"`` branch reshapes with
            ``view(-1, 1, 1, 1)``, i.e. it expects 4-D input.
        y: target labels passed to ``F.cross_entropy``.
        epsilon: radius used for the projection in every step.
        alpha: step size.
        attack_iters: number of steps.
        norm: ``"l_inf"`` (signed-gradient step, element-wise clamp) or
            ``"l_2"`` (normalized-gradient step, per-sample renorm).  Any
            other value only applies the range clipping.

    Returns:
        The perturbation tensor (same shape as ``X``, ``requires_grad=True``).

    Raises:
        ValueError: if ``model_list`` is empty.
    """
    if not model_list:
        raise ValueError("model_list must contain at least one model")

    # NOTE(review): init_delta reads the module-level ``epsilon`` rather than
    # the ``epsilon`` argument of this function; the two agree only while the
    # caller passes the global value.
    delta = init_delta(X, norm)
    delta.requires_grad = True

    for _ in range(attack_iters):
        X_adv = X + delta

        # Average the first output of every model.
        output, _unused = model_list[0](X_adv)
        for model in model_list[1:]:
            member_out, _unused = model(X_adv)
            output = output + member_out
        output = output / len(model_list)

        loss = F.cross_entropy(output, y)
        loss.backward()
        grad = delta.grad.detach()

        # The step and projection are plain tensor arithmetic; no autograd
        # graph is needed for them.
        with torch.no_grad():
            if norm == "l_inf":
                stepped = torch.clamp(delta + alpha * torch.sign(grad),
                                      min=-epsilon, max=epsilon)
            elif norm == "l_2":
                g_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                # 1e-10 guards against division by a zero gradient norm.
                scaled_g = grad / (g_norm + 1e-10)
                stepped = ((delta + scaled_g * alpha)
                           .view(delta.size(0), -1)
                           .renorm(p=2, dim=0, maxnorm=epsilon)
                           .view_as(delta))
            else:
                stepped = delta

            # Keep X + delta inside [lower_limit, upper_limit].
            delta.copy_(clamp(stepped, lower_limit - X, upper_limit - X))
        delta.grad.zero_()

    return delta

#%%

def updateModel(model, inputs, nb_iter, norm, noise=True):
    """Run ``nb_iter`` forward passes in train mode, then switch to eval mode.

    The outputs of the forward passes are discarded; only whatever state the
    model updates during a train-mode forward is kept (presumably batch-norm
    running statistics - confirm against the model definition).

    Args:
        model: module to update in place.
        inputs: batch fed to the model on every pass.
        nb_iter: number of forward passes.
        norm: forwarded to ``init_delta`` when ``noise`` is true.
        noise: if true, a fresh ``init_delta(inputs, norm)`` perturbation is
            added to ``inputs`` on every pass.

    Returns:
        The same ``model`` object, left in eval mode.
    """
    model.train()
    # Nothing is back-propagated from these passes, so skip recording the
    # autograd graph instead of holding activations for a backward pass that
    # never happens.
    with torch.no_grad():
        for _ in range(nb_iter):
            if noise:
                batch = inputs + init_delta(inputs, norm)
            else:
                batch = inputs
            model(batch)
    model.eval()
    return model

def create_model_list(nb_model, checkpoint, norm, ds, momentum, X, nb_iter):
    """Build ``nb_model`` models from one checkpoint, each updated on ``X``.

    Every model is created with ``make_and_restore_model_3`` (arch
    ``'resnet50'``), loaded from ``checkpoint['state_dict']`` and then passed
    through ``updateModel`` with its default noisy inputs, so each one sees a
    different random perturbation of ``X``.

    Returns:
        A list with the ``nb_model`` updated models, in creation order.
    """

    def _build(idx):
        # Fresh copy of the checkpointed weights for each model.
        net = make_and_restore_model_3(arch='resnet50', dataset=ds, momentum=momentum)
        net.load_state_dict(checkpoint['state_dict'])

        print('Model: -----------------------------  ' + str(idx))
        sys.stdout.flush()

        return updateModel(net, X, nb_iter, norm)

    return [_build(idx) for idx in range(nb_model)]

#%%
def get_b_attack(model_list, X, y):
    """Generate adversarial versions of ``X`` in chunks of ``nb_attack_bSize``.

    The attack parameters (``epsilon``, ``pgd_alpha``, ``attack_iters``,
    ``norm``) and the chunk size are read from module-level globals.

    Args:
        model_list: models handed to ``attack_pgd_eot``.
        X: input batch; chunked along the first dimension.
        y: labels aligned with ``X``.

    Returns:
        A tensor shaped like ``X`` holding ``X + delta`` clamped to
        ``[lower_limit, upper_limit]``.
    """
    # zeros_like already lives on X's device.
    adv_X = torch.zeros_like(X)

    # Chunk by the actual size of X, not by module-level batch constants:
    # a short batch would otherwise send empty slices to the attack, and a
    # long one would leave its trailing rows as zeros.
    n_samples = X.size(0)
    for j, start in enumerate(range(0, n_samples, nb_attack_bSize)):
        t_start = time.time()
        end = min(start + nb_attack_bSize, n_samples)

        X_batch, y_batch = X[start:end], y[start:end]
        delta = attack_pgd_eot(model_list, X_batch, y_batch, epsilon,
                               pgd_alpha, attack_iters, norm).detach()

        adv_X[start:end] = torch.clamp(X_batch + delta, min=lower_limit, max=upper_limit)
        t_end = time.time()
        print("Iter (Adv Gen): " + str(j) + "   Time:" + str(t_end - t_start))
        sys.stdout.flush()
    return adv_X
#%%
def test(loader, criterion, checkpoint, ds):
    """Evaluate robust accuracy of a checkpoint against the model-list attack.

    Steps:
      1. Take the last batch produced by one pass over ``loader`` and use it
         to build the list of attack models (``create_model_list``).
      2. For every batch of a second pass, generate adversarial inputs with
         ``get_b_attack``, reset a separate test model to the checkpoint,
         update it on the adversarial batch (``updateModel`` without noise)
         and score it on that same batch.

    ``nb_model``, ``norm``, ``momentum`` and ``nb_iter`` are read from
    module-level globals.

    Args:
        loader: iterable of ``(inputs, targets)`` batches supporting ``len``.
        criterion: loss applied to the test model's first output.
        checkpoint: mapping with a ``'state_dict'`` entry.
        ds: dataset object forwarded to ``make_and_restore_model_3``.

    Returns:
        ``(losses.avg, top1.avg)`` accumulated over all batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Tracked per batch but not currently reported in the log lines.
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # Keep only the final batch of this pass; it is moved to the GPU once,
    # after the loop, instead of copying every batch over.
    # NOTE(review): any single batch would probably do here - confirm whether
    # the last one is required before shortening this pass.
    warmup_inputs = None
    for warmup_inputs, _ in loader:
        pass
    if warmup_inputs is None:
        raise ValueError("loader yielded no batches")

    model_list = create_model_list(nb_model, checkpoint, norm, ds,
                                   momentum, warmup_inputs.cuda(), nb_iter)

    test_model = make_and_restore_model_3(arch='resnet50', dataset=ds, momentum=momentum)
    test_model.load_state_dict(checkpoint['state_dict'])

    for i, (inputs, targets) in enumerate(loader):
        X = inputs.cuda()
        y = targets.cuda()

        adv_X = get_b_attack(model_list, X, y)

        # Start from the checkpointed weights for every batch, then update
        # the test model on the adversarial batch itself.
        test_model.load_state_dict(checkpoint['state_dict'])
        updateModel(test_model, adv_X, nb_iter, norm, False)

        # Scoring only: no gradients are needed.
        with torch.no_grad():
            robust_output, _ = test_model(adv_X)
            robust_loss = criterion(robust_output, y)
        acc1, acc5 = accuracy(robust_output, y, topk=(1, 5))
        losses.update(robust_loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        print('Test: [{0}/{1}]\t'
              'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
              'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                  i, len(loader), top1=top1, top5=top5))
        sys.stdout.flush()

    print('Final \t'
          'Acc@1 {top1.avg:.3f}\t'
          'Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return (losses.avg, top1.avg)

#%%   ##Getting the model

def evalMe(adv_path):
    """Load the checkpoint at ``adv_path`` and run ``test`` on ``test_loader``.

    Prints the path first.  The loss and accuracy returned by ``test`` are
    not used; this function returns ``None``.
    """
    print(adv_path)

    # NOTE(review): the checkpoint path is also what is handed to
    # ``ImageNet(...)`` - confirm that this is the intended argument.
    dataset_spec = ImageNet(adv_path)
    state = torch.load(adv_path)

    loss_fn = CrossEntropyLoss().cuda()
    test(test_loader, loss_fn, state, dataset_spec)

#%%
# Script body: everything below runs at import time and defines the globals
# that the functions above read.
# IMAGENET_LOC_ENV is set for this process but is not read back anywhere in
# the visible part of this file - presumably consumed by the imported dataset
# helpers (TODO confirm).
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'
# Image folder evaluated by this script.
img_dir = '/imagenet/dataset/val_attack/'
test_dataset =_imagenet(img_dir)

# Checkpoints passed to evalMe below, one per attack norm.
adv_linf_overfit = './model/imagenet_linf_eps4_parallel.pt'   # evaluated with the l_inf settings
adv_l2_overfit = './model/imagenet_l2_eps765_parallel.pt'  # evaluated with the l_2 settings

#%%

# Batch bookkeeping.
total_imgs = 2000   # presumably the number of images under img_dir - TODO confirm
bSize = 400         # DataLoader batch size
nb_attack_bSize = 80   # chunk size used when generating adversarial inputs
nb_attack_batches = int(bSize/nb_attack_bSize)   # attack chunks per loader batch (5 here)

# Model-list settings, read as globals inside test().
momentum = 1.0   # forwarded to make_and_restore_model_3
nb_model = 10    # number of models built by create_model_list
nb_iter = 1      # forward passes per updateModel call

#%%
# NOTE(review): batch size and worker count are hard-coded here; the --batch
# and --workers command-line options are not used.
test_loader = DataLoader(test_dataset, shuffle=True, batch_size= bSize,
                         num_workers= 15, pin_memory= True)

#%%
# The attack settings are module-level globals read inside the functions
# above, so they have to be rebound before each evalMe call.
print('ADV-Linf: -----------------------------------------\n')
epsilon, attack_iters, norm = 4./255., 100, 'l_inf'
pgd_alpha = epsilon/4.   # step size: a quarter of the radius
evalMe(adv_linf_overfit)

print('ADV-L2: -----------------------------------------\n')
epsilon, attack_iters, norm = 765./255., 100, 'l_2'   # 765/255 == 3.0
pgd_alpha = epsilon/8.5
evalMe(adv_l2_overfit)

print('============= Exit ===================')
