import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from utils import load_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18
import matplotlib.pyplot as plt

def _mean_jacobian_frob_norm(outputs, inputs, keep_graph):
    """Return the batch-mean Frobenius norm of d(outputs)/d(inputs).

    One backward pass is run per column of ``outputs``. The column is summed
    over the batch before differentiating, so the gradient w.r.t. ``inputs``
    is read as one Jacobian row per sample; this assumes samples in a batch
    do not influence each other's outputs (e.g. the model is in eval mode).

    Args:
        outputs: 2-D tensor (batch, num_outputs) computed from ``inputs``.
        inputs: input tensor with ``requires_grad`` enabled.
        keep_graph: if True the autograd graph is retained after the last
            column as well, so a later call can differentiate through the
            same forward pass again.

    Returns:
        Python float: mean over the batch of the per-sample Frobenius norm.
    """
    last = outputs.shape[1] - 1
    rows = []
    for i in range(outputs.shape[1]):
        # autograd.grad only differentiates w.r.t. `inputs`, so no gradients
        # are accumulated on the model parameters (unlike .backward()).
        (grad,) = torch.autograd.grad(
            outputs[:, i].sum(), inputs,
            retain_graph=keep_graph or i != last,
        )
        rows.append(grad)
    jac = torch.stack(rows, 1)
    # Flatten everything but the batch axis: the L2 norm of the flattened
    # Jacobian is its Frobenius norm.
    return jac.reshape(jac.shape[0], -1).norm(2, dim=1).mean().item()


def main(model_name='alpha-approx-10.0000', device='cuda',
         compute_jacobian=False, max_batches=1):
    """Print diagnostics for a saved ResNet18 checkpoint.

    Always prints the train/test set sizes and the spectral norm (largest
    singular value) of the final linear layer's weight matrix. Optionally
    also reports, on the test loader, the running average Frobenius norm of
    the input Jacobian of the logits and of the representation.

    Checkpoint names previously used with this script: 'naive',
    'advT-l2-1.0000', 'advT-l2-0.1000', 'advT-linf-0.0314', 'alpha-10.0000',
    'alpha-100.0000', 'alpha-10000.0000', 'alpha-out-100.0000',
    'alpha-out-1000.0000', 'alpha-approx-10.0000', 'alpha-approx-100.0000',
    'alpha-approxrepr-10.0000', 'alpha-approxrepr-100.0000'.

    Args:
        model_name: checkpoint stem; the file read is
            ``./saved_model/<model_name>.pth``.
        device: torch device the model and batches are placed on.
        compute_jacobian: when False (default) the function returns right
            after printing the spectral norm. This replaces the former
            ``assert 0`` stop, which crashed with a traceback and was
            silently skipped under ``python -O``.
        max_batches: number of test batches used for the Jacobian averages;
            ``None`` means the whole loader. The default of 1 matches the
            previous unconditional ``break`` (one backward pass per output
            dimension makes each batch expensive).
    """
    test_batch_size = 100
    trainset, testset, _, testloader, normalizer = load_data(test_batch_size=test_batch_size)
    print(len(trainset), len(testset))

    model = ResNet18(normalizer).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location makes the load independent of the device it was saved on.
    state_dict = torch.load('./saved_model/%s.pth' % model_name, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # ord=2 on a 2-D array is the spectral norm of the classifier weights.
    print(np.linalg.norm(model.linear.weight.detach().cpu().numpy(), 2))
    if not compute_jacobian:
        return

    tot_jac = 0.0
    tot_jac_repr = 0.0
    tot_num = 0
    with tqdm(testloader) as pbar:
        for batch_idx, (x, _) in enumerate(pbar, start=1):
            x = x.to(device)
            x.requires_grad_()
            features = model.calc_representation(x)
            pred = model.linear(features)

            batch_size = len(x)
            tot_num += batch_size

            # Logits first: the graph must survive for the representation
            # pass below, which shares the same forward computation.
            tot_jac += _mean_jacobian_frob_norm(pred, x, keep_graph=True) * batch_size
            # Representation last: the graph is released on its final column.
            tot_jac_repr += _mean_jacobian_frob_norm(features, x, keep_graph=False) * batch_size

            pbar.set_description('Avg jac frob norm: %.4f; Avg repr jac frob norm: %.4f'%(tot_jac/tot_num, tot_jac_repr/tot_num))
            if max_batches is not None and batch_idx >= max_batches:
                break


# Script entry point: run the diagnostics only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
