import sys
import argparse
import os
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model, make_and_restore_model_2, make_and_restore_model_3
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Positional dataset name, restricted to the names the helper module exports.
parser.add_argument('dataset', type=str, choices=DATASETS)
# Integer tuning knobs: (flag, default, help text).
for _flag, _default, _help in (
    ('--batch', 256, 'batchsize (default: 256)'),
    ('--workers', 4, 'number of data loading workers (default: 4)'),
):
    parser.add_argument(_flag, default=_default, type=int, metavar='N', help=_help)
# NOTE(review): `args` is not referenced elsewhere in this file -- the driver
# below hard-codes its own batch size / worker count. Confirm before relying
# on --batch / --workers.
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return the clean ImageNet validation set stored under ``_dir``.

    Each image is resized (shorter side -> 256 px), center-cropped to 224x224
    and converted to a tensor. No normalization is applied in this pipeline.
    """
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    return datasets.ImageFolder(_dir, transform=transforms.Compose(pipeline))


def _imagenet_cp(_dir) -> Dataset:
    """Return one ImageNet-C corruption/severity split stored under ``_dir``.

    Only ``ToTensor`` is applied: unlike ``_imagenet`` there is no resize or
    crop. This presumably relies on the stored images already being 224x224
    (as in the ImageNet-C release) -- TODO confirm for the local copy.
    """
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    return datasets.ImageFolder(_dir, transform=to_tensor_only)

#%%
    
def test(loader, model, checkpoint, bn, momentum, nb_iter, print_freq=75):
    """Evaluate ``model`` on ``loader``; print and return loss / accuracy.

    Args:
        loader: iterable of ``(inputs, targets)`` batches. Both are copied to
            the current CUDA device.
        model: network whose forward returns a 2-tuple. Only the first element
            (the logits) is used.
        checkpoint: unused. It is kept for call compatibility; the per-batch
            state-dict reload it once supported is disabled.
        bn: if truthy, the model runs in ``train()`` mode, so BatchNorm layers
            normalize with each test batch's own statistics (and update their
            running buffers as a side effect). Otherwise it runs in ``eval()``.
        momentum: unused here. The caller passes it to the model constructor.
        nb_iter: unused. The extra re-estimation passes are disabled.
        print_freq: print a progress line every ``print_freq`` batches.

    Returns:
        ``(average loss, average top-1 accuracy)`` over all evaluated samples.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    batch_time = AverageMeter()

    criterion = CrossEntropyLoss().cuda()

    # Choose the mode once instead of re-entering train() on every batch.
    if bn:
        model.train()
    else:
        model.eval()

    end = time.time()
    # Pure evaluation: no autograd graph is needed. BatchNorm running buffers
    # are still updated in train mode, because that update is an in-place
    # buffer write, independent of grad mode.
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader):
            # non_blocking pairs with pin_memory=True in the callers' loaders.
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            outputs, _ = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            n = inputs.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1.item(), n)
            top5.update(acc5.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Iter: [{0}/{1}]   '
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})   '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})   '
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})   '
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))
                sys.stdout.flush()

    print('Final Loss  ({loss.avg:.4f})   '
          'Final Acc@1 ({top1.avg:.3f})   '
          'Final Acc@5 ({top5.avg:.3f})'.format(
              loss=losses, top1=top1, top5=top5))
    return (losses.avg, top1.avg)

#%%

img_dir = '/imagenet/dataset/val/'

cp_dir = '/imagenet-c/'
# The 19 ImageNet-C corruption types evaluated by this script, in alphabetical
# order; each is expected as a sub-directory of ``cp_dir`` holding one folder
# per severity level.
cp_data = """
    brightness contrast defocus_blur elastic_transform fog
    frost gaussian_blur gaussian_noise glass_blur impulse_noise
    jpeg_compression motion_blur pixelate saturate shot_noise
    snow spatter speckle_noise zoom_blur
""".split()

#%%


def _eval_loader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in the DataLoader configuration shared by all evaluations."""
    # shuffle=True is likely intentional. ImageFolder orders samples by class,
    # so unshuffled batches would be (nearly) single-class, and with bn=True
    # the model normalizes with per-batch statistics.
    return DataLoader(dataset, shuffle=True, batch_size=batch_size,
                      num_workers=num_workers, pin_memory=True)


def eval_cp(model_path, bn, save_path, momentum, nb_iter,
            dataset=None, batch_size=None, num_workers=None):
    """Evaluate one checkpoint on clean ImageNet-val and on every ImageNet-C split.

    A ResNet-50 is built with ``make_and_restore_model_3`` and the weights in
    ``model_path`` are loaded into it. Top-1 accuracy is measured in two ways:

    - on the clean validation set under ``img_dir``;
    - for each corruption in ``cp_data``, on severities 1-5 under ``cp_dir``.
      The per-corruption score is the mean over those five severities.

    Args:
        model_path: checkpoint file containing a ``'state_dict'`` entry.
        bn: forwarded to ``test``; truthy runs the model in ``train()`` mode.
        save_path: destination for ``np.savez`` (array stored under ``acc``).
        momentum: forwarded to ``make_and_restore_model_3`` (presumably the
            BatchNorm momentum) and to ``test``.
        nb_iter: forwarded to ``test``.
        dataset: dataset descriptor for the model constructor; defaults to the
            module-level ``ds``.
        batch_size: loader batch size; defaults to the module-level ``b_size``.
        num_workers: loader worker count; defaults to module-level ``workers``.

    Returns:
        The saved accuracy array: clean top-1 first, then one mean top-1 per
        corruption, in ``cp_data`` order.
    """
    # Fall back to the module-level globals the original implicitly relied on.
    if dataset is None:
        dataset = ds
    if batch_size is None:
        batch_size = b_size
    if num_workers is None:
        num_workers = workers

    #   print model information
    print('Model: ' + model_path + '   BN: '+ str(bn) + '   Momentum: ' + str(momentum))

    # load model (default model loading)
    model = make_and_restore_model_3(arch='resnet50', dataset=dataset, momentum=momentum)
    print('Found model --------------------')
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    print('Found model weights --------------------')

    acc_list = []

    # clean data
    clean_loader = _eval_loader(_imagenet(img_dir), batch_size, num_workers)
    _, top1_acc = test(clean_loader, model, checkpoint, bn, momentum, nb_iter)
    acc_list.append(top1_acc)

    # corrupted data: average top-1 over the five severity levels
    severities = range(1, 6)
    for corruption in cp_data:
        print()
        sev_accs = []
        for sev in severities:
            print(corruption + '  sevearity:' + str(sev) + "   = = = = = = = = = = = = = = = = = = = = = = = = = = ")
            sev_dir = os.path.join(cp_dir, corruption, str(sev))
            sev_loader = _eval_loader(_imagenet_cp(sev_dir), batch_size, num_workers)
            _, cp1_acc = test(sev_loader, model, checkpoint, bn, momentum, nb_iter)
            sev_accs.append(cp1_acc)
        mean_acc = sum(sev_accs) / len(sev_accs)
        acc_list.append(mean_acc)
        print(mean_acc)

    acc = np.array(acc_list)
    np.savez(save_path, acc=acc)
    print('Saving:....')
    print(acc)
    print(np.mean(acc))
    sys.stdout.flush()
    return acc
    
#%%    

os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'

# Loader settings; eval_cp reads these as module globals.
b_size = 500
workers = 15

# Evaluate each robust checkpoint in two configurations:
#   - once with bn=False;
#   - three times with bn=True (momentum=1, nb_iter=1), presumably to average
#     over the shuffled batch composition.
# ``ds`` is rebound per checkpoint because eval_cp reads it as a module global.
for norm, adv_path in (
    ('linf', './model/imagenet_linf_eps4_parallel.pt'),
    ('l2', './model/imagenet_l2_eps765_parallel.pt'),
):
    ds = ImageNet(adv_path)

    save_path = 'CP_' + norm + '_overfit_False.npz'
    eval_cp(adv_path, False, save_path, momentum=0.0, nb_iter=0)

    for i in range(3):
        save_path = 'CP_' + norm + '_overfit_True_mom_1_iter_1_rand_iter_' + str(i) + '_.npz'
        eval_cp(adv_path, True, save_path, momentum=1, nb_iter=1)

print('============= Exit ===================')


