import json
import torch
import torch.nn as nn
import torchvision
import numpy as np
from tqdm import tqdm

from utils import load_data, load_svhn_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import datetime


def _accuracy(predict, loader, desc, device='cuda'):
    """Return the top-1 accuracy of ``predict`` over all of ``loader``, in percent.

    Args:
        predict: callable mapping an input batch to per-class scores whose
            class dimension is dim 1 (argmax is taken over dim 1).
        loader: iterable of ``(inputs, labels)`` batches.
        desc: label shown on the tqdm progress bar as ``"<desc>:<acc>%"``.
        device: device the batches are moved to.
    """
    correct = 0
    total = 0
    with torch.no_grad(), tqdm(loader) as pbar:
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            _, pred_c = predict(x).max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('%s:%.3f%%' % (desc, 100. * correct / total))
    return 100. * correct / total


def _estimate_jacobian_norms(model, loader, device='cuda'):
    """Estimate input-Jacobian Frobenius norms on the first batch of ``loader``.

    For each sample a single random direction ``v``, uniform on the unit sphere
    of the output space, is drawn, and ``J^T v`` is obtained with one backward
    pass. Since ``E[v v^T] = I / d``, ``d * ||J^T v||^2`` is an unbiased
    estimate of ``||J||_F^2``.

    Returns:
        ``(jac_repr, jac_all)``: the root-mean-square (over the batch) estimated
        Frobenius norm of the Jacobian of ``model.calc_representation`` and of
        ``model`` itself, respectively.
    """
    x, _ = next(iter(loader))
    x = x.to(device).requires_grad_()
    num = len(x)

    def frob_sq_sum(outputs):
        # Assumes outputs are (batch, d): normalisation and d are taken on dim 1.
        v = torch.randn_like(outputs)
        v = v / v.norm(dim=1, keepdim=True)
        # autograd.grad differentiates w.r.t. x only, so parameter .grad
        # buffers are left untouched (backward() would accumulate into them).
        (grad,) = torch.autograd.grad((v * outputs).sum(), x)
        # The batch objective is a sum of per-sample terms, so grad[i] is
        # J_i^T v_i (assuming no cross-sample coupling, e.g. BN in eval mode).
        # Scale the *squared* norm by d; scaling the gradient by d would square
        # the factor and overestimate ||J||_F by sqrt(d).
        return outputs.shape[1] * grad.pow(2).sum()

    repr_sq = frob_sq_sum(model.calc_representation(x))
    all_sq = frob_sq_sum(model(x))
    return torch.sqrt(repr_sq / num).item(), torch.sqrt(all_sq / num).item()


def _robust_accuracy(model, loader, epsilon=0.5, max_samples=500, device='cuda'):
    """Return the accuracy of ``model`` on L2-PGD adversarial examples.

    Whole batches from ``loader`` are attacked until at least ``max_samples``
    examples have been evaluated, so slightly more than ``max_samples`` may be
    used.

    NOTE: the result is a fraction in [0, 1], unlike the percentages returned
    by ``_accuracy``; kept that way so existing result files stay comparable.
    """
    adversary = L2PGDAttack(model, epsilon=epsilon)
    correct = 0
    total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        # The attack presumably needs gradients, so only the evaluation
        # forward pass below runs under no_grad.
        adv_x = adversary.perturb(x, y).detach()
        with torch.no_grad():
            _, adv_pred_c = model(adv_x).max(1)
        total += y.size(0)
        correct += adv_pred_c.eq(y).sum().item()
        if total >= max_samples:
            break
    return correct / total


def main(model_name='alpha-approxrepr-tune-10.0', checkpoints=range(1, 1600, 5), device='cuda'):
    """Evaluate every saved checkpoint of a fine-tuning run and dump the metrics.

    For each checkpoint index ``i`` in ``checkpoints``, weights are loaded from
    ``./saved_model/<model_name>/<model_name>_<i>.pth`` and the following are
    recorded:

    * accuracy (percent) on the test loader returned by ``load_data``;
    * spectral norm (largest singular value) of ``model.linear.weight``;
    * estimated RMS Frobenius norms of the input Jacobian of the
      representation and of the full model (see ``_estimate_jacobian_norms``);
    * L2-PGD (epsilon=0.5) robust accuracy as a fraction, on ~500 test samples;
    * accuracy (percent) of ``model.transfer_linear`` on top of the
      representation, on the 4th loader returned by ``load_svhn_data``.

    Results are written to ``figures/tune_<model_name>.json`` as a list of
    ``[i, src_acc, last_layer_norm, jac_repr, jac_all, adv_acc, transfer_acc]``
    rows. The file is rewritten after every checkpoint, so a sweep that crashes
    part-way still keeps the results computed so far.

    Args:
        model_name: run name, e.g. ``'alpha-approxrepr-tune-10.0'`` or
            ``'alpha-approx-tune-3.0'``.
        checkpoints: iterable of checkpoint indices to evaluate.
        device: torch device used for the model and data.
    """
    import os

    test_batch_size = 100
    _, _, _, testloader, normalizer = load_data(test_batch_size=test_batch_size)
    _, _, _, transfer_loader, _ = load_svhn_data()

    model = ResNet18(normalizer).to(device)
    model.eval()
    # Added before load_state_dict so the (strict) load can fill it in; this
    # presumes the checkpoints contain transfer_linear.* entries — TODO confirm.
    model.transfer_linear = nn.Linear(model.linear.in_features, 10).to(device)

    def transfer_head(x):
        return model.transfer_linear(model.calc_representation(x))

    os.makedirs('figures', exist_ok=True)
    out_path = 'figures/tune_%s.json' % model_name

    results = []
    for i in checkpoints:
        state = torch.load('./saved_model/%s/%s_%d.pth' % (model_name, model_name, i))
        model.load_state_dict(state)
        model.eval()

        src_acc = _accuracy(model, testloader, 'Source acc', device)
        # ord=2 on a 2-D array is the spectral norm.
        last_layer_norm = np.linalg.norm(model.linear.weight.detach().cpu().numpy(), 2).item()
        tot_jac_repr, tot_jac_all = _estimate_jacobian_norms(model, testloader, device)
        adv_acc = _robust_accuracy(model, testloader, device=device)
        cur_acc = _accuracy(transfer_head, transfer_loader, 'Acc', device)

        print("%d, source acc:%.4f, Jac:(%.4f,%.4f,%.4f), rob acc:%.4f, transfer acc %.4f @ %s" % (
            i, src_acc, last_layer_norm, tot_jac_repr, tot_jac_all, adv_acc, cur_acc,
            datetime.datetime.now()))
        results.append((i, src_acc, last_layer_norm, tot_jac_repr, tot_jac_all, adv_acc, cur_acc))

        # Checkpoint the results so far; the final file equals a single end-of-run dump.
        with open(out_path, 'w') as outf:
            json.dump(results, outf)


if __name__ == '__main__':
    main()
