from helper import *
from torch import tensor, nn, optim
from model import xresnet18, xresnet50, xresnet101, xresnet100
from modelNAS import xresnet_fast18 as DAN18, xresnet_fast34 as DAN34, xresnet_fast50 as DAN50 ,xresnet_fast100 as DAN100, xresnet_fast101 as DAN101
from modelAdjoint import xresnet_fast18,xresnet_fast50,xresnet_fast101,xresnet_fast100
from adjointNetworkNAS import AdjointLoss as AdjointLossDAN, BlockwiseAdjointLoss
from adjointNetwork import AdjointLoss, TeacherStudentLoss

import torchattacks
from torchattacks import PGD

def load_model(model, state_dict_file_path=None):
    """Optionally restore ``model``'s weights from a saved ``state_dict`` file.

    Args:
        model: module whose parameters are overwritten in place.
        state_dict_file_path: path passed to ``torch.load``; when ``None`` the
            model is returned untouched.

    Returns:
        The same ``model`` object, for call chaining.
    """
    if state_dict_file_path is None:
        return model
    weights = torch.load(state_dict_file_path)
    model.load_state_dict(weights)
    return model

def accuracy(model, val_dl, standard=True, device="cuda"):
    """Compute, print and return the top-1 accuracy of ``model`` over ``val_dl``.

    Args:
        model: classifier whose output has class scores along dim 1; it is
            switched to eval mode.
        val_dl: iterable of ``(images, labels)`` batches.
        standard: selects the printed label -- "Standard accuracy" when True,
            "Robust accuracy" otherwise (used for adversarial inputs).
        device: where images and labels are moved before the forward pass;
            defaults to ``"cuda"`` (same as the previous ``.cuda()`` calls).

    Returns:
        Accuracy as a percentage in [0, 100], or ``nan`` when ``val_dl``
        yields no samples.
    """
    model.eval()

    correct = 0
    total = 0

    # Evaluation only: disable autograd so no computation graph is built for
    # each batch, saving memory and time.
    with torch.no_grad():
        for images, labels in val_dl:
            outputs = model(images.to(device))
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    pct = 100.0 * correct / total if total else float("nan")
    label = "Standard" if standard else "Robust"
    print('%s accuracy: %.2f %%' % (label, pct))
    return pct

def dataset_resize(image_size, x):
    """Reinterpret ``x`` as a batch of 3-channel ``image_size`` x ``image_size`` images.

    Uses ``Tensor.view``, so the result shares storage with ``x`` and the
    leading (batch) dimension is inferred from the element count.
    """
    target_shape = (-1, 3, image_size, image_size)
    return x.view(*target_shape)

if __name__ == "__main__":
    import os

    # NOTE(review): positional args presumably are (batch size, image size,
    # CIFAR variant / class count) -- confirm against helper.load_cifar_data.
    data = load_cifar_data(64, 32, 100)
    train_dl, val_dl = data.train_dl, data.valid_dl
    # The model reshapes its input to (N, 3, 32, 32) via this callable.
    data_resize = partial(dataset_resize, 32)
    #model = xresnet50(c_out=100,resize=data_resize,compression_factor=1)
    model = xresnet_fast50(c_out=100, resize=data_resize, compression_factor=4)
    # Wrapped before loading -- presumably because the checkpoints were saved
    # from a DataParallel model (keys prefixed with "module."); confirm.
    model = nn.DataParallel(model)
    #model = load_model(model, state_dict_file_path="/scratch/un270/model/Adjoint-Experiments/Oct2021/Individual-R50-Cifar100/239.pt")
    #model = load_model(model, state_dict_file_path="/scratch/un270/model/Adjoint-Experiments/Oct2021/AN-50-CIFAR-2/214.pt")
    model = load_model(model, state_dict_file_path="/scratch/un270/model/Adjoint-Experiments/Oct2021/AN-50-CIFAR-4/233.pt")

    model = model.eval()

    # This run uses CIFAR-100 (c_out=100), so name the adversarial cache
    # accordingly; the former "cifar10" name could clobber a CIFAR-10 cache.
    adv_path = "data/cifar100_pgd.pt"
    os.makedirs(os.path.dirname(adv_path), exist_ok=True)

    # NOTE(review): alpha == eps lets a single step reach the eps-ball
    # boundary; confirm this step size is intended.
    atk = PGD(model, eps=2/255, alpha=2/255, steps=7)
    # Store adversarial images as integers; rescaled with /255 on reload below.
    # NOTE(review): set_return_type and the tuple-shaped save file match older
    # torchattacks releases -- confirm against the installed version.
    atk.set_return_type('int')
    atk.save(data_loader=val_dl, save_path=adv_path, verbose=True)

    adv_images, adv_labels = torch.load(adv_path)
    adv_data = TensorDataset(adv_images.float() / 255, adv_labels)
    adv_loader = DataLoader(adv_data, batch_size=128, shuffle=False)

    accuracy(model, val_dl, True)
    accuracy(model, adv_loader, False)
