import json
import torch
import torch.nn as nn
import torchvision
import numpy as np
from tqdm import tqdm

from utils import load_data, calc_repr
from attack_lib import LinfPGDAttack, L2PGDAttack
from train_imagenet import accuracy
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import datetime


# ImageNet channel statistics used both for input normalization and by the
# PGD attack (which perturbs in un-normalized pixel space).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _transfer_forward(model, x):
    """Run a torchvision ResNet backbone and classify with ``model.transfer_fc``.

    Follows the same stem -> layer1..layer4 -> avgpool -> flatten sequence as
    torchvision's ResNet forward, but ends in the ``transfer_fc`` head instead
    of the ImageNet ``fc`` head. It does not depend on the specific variant, so
    it works for both resnet18 and wide_resnet50_2.
    """
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)

    x = model.avgpool(x)
    x = torch.flatten(x, 1)
    return model.transfer_fc(x)


def _source_accuracy(model, loader):
    """Return the top-1 accuracy (percent) of ``model``'s own head on ``loader``."""
    correct = 0
    total = 0
    with torch.no_grad(), tqdm(loader) as pbar:
        for x, y in pbar:
            x, y = x.to('cuda'), y.to('cuda')
            _, pred_c = model(x).max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Source acc:%.3f%%' % (100. * correct / total))
    return 100. * correct / total


def _random_vjp_norm(output, x):
    """Random-projection estimate of the input-Jacobian norm of ``output``.

    Draws one random unit vector v per row of ``output`` (assumes ``output`` is
    (batch, d) — TODO confirm for ``calc_repr``) and computes J^T v with
    respect to ``x`` via ``torch.autograd.grad``. Unlike ``backward()``, this
    does not accumulate ``.grad`` on model parameters. Returns
    sqrt(||d * J^T v||^2 / batch) as a Python float.
    """
    rv = torch.randn_like(output)
    rv = rv / rv.norm(dim=1, keepdim=True)
    grad, = torch.autograd.grad((rv * output).sum(), x)
    # NOTE(review): for v uniform on the unit sphere, E||J^T v||^2 = ||J||_F^2 / d,
    # so an unbiased ||J||_F^2 estimate is d * ||J^T v||^2. Scaling the gradient
    # by d (as done here) squares that factor and yields ~sqrt(d) * ||J||_F.
    # Kept unchanged so numbers stay comparable with earlier runs — confirm intent.
    sq = (output.shape[1] * grad).norm() ** 2
    return torch.sqrt(sq / len(x)).item()


def _jacobian_estimates(model, loader):
    """Return (repr, logits) Jacobian-norm estimates from the first batch of ``loader``."""
    x, _ = next(iter(loader))
    x = x.to('cuda').requires_grad_()
    jac_repr = _random_vjp_norm(calc_repr(model, x), x)
    jac_all = _random_vjp_norm(model(x), x)
    torch.cuda.empty_cache()
    return jac_repr, jac_all


def _robust_accuracy(model, loader, epsilon=0.25, max_samples=500):
    """Return (top-1, top-5) accuracy under an L2 PGD attack of radius ``epsilon``.

    Evaluation stops after the first batch that brings the sample count to at
    least ``max_samples``. The result is therefore computed over whole batches.
    """
    adversary = L2PGDAttack(model, epsilon=epsilon)
    total = 0
    adv_acc1 = 0
    adv_acc5 = 0
    for x, y in loader:
        x, y = x.to('cuda'), y.to('cuda')
        adv_x = adversary.perturb(x, y, normalize=(IMAGENET_MEAN, IMAGENET_STD)).detach()
        # Only the prediction is needed here, so skip building an autograd graph.
        with torch.no_grad():
            adv_pred = model(adv_x)
        acc1, acc5 = accuracy(adv_pred, y, topk=(1, 5))
        total += y.size(0)
        adv_acc1 += acc1 * y.size(0)
        adv_acc5 += acc5 * y.size(0)
        if total >= max_samples:
            break
    return (adv_acc1 / total).item(), (adv_acc5 / total).item()


def _transfer_accuracy(model, loader, criterion):
    """Return the top-1 accuracy (percent) of the ``transfer_fc`` head on ``loader``."""
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad(), tqdm(loader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to('cuda'), y.to('cuda')
            pred = _transfer_forward(model, x)
            test_loss += criterion(pred, y).item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%' % (test_loss / (batch_idx + 1), 100. * correct / total))
    return 100. * correct / total


def main():
    """Sweep fine-tuning checkpoints and record per-checkpoint metrics.

    For each checkpoint index i in [50, 300], this loads
    ``./saved_model/imagenet_<MODEL_NAME>_tuning_ckpt/<i>.pth.tar`` and measures:
      - ImageNet-val top-1 accuracy of the source head;
      - the spectral norm of ``model.fc.weight``;
      - two random-projection Jacobian-norm estimates, one through ``calc_repr``
        and one through the full model;
      - L2-PGD (eps=0.25) top-1 accuracy on at least 500 val images;
      - CIFAR-10 test accuracy through ``transfer_fc``.

    After every checkpoint, the results so far are written to
    ``figures/tune_<MODEL_NAME>.json``, so a crash keeps partial progress.
    """
    import os

    ARCH = 'resnet18'
    #ARCH = 'wide_resnet50_2'

    TRAIN_MODE = True   # madry's setting
    #TRAIN_MODE = False   # our setting

    DATASET = 'cifar10'
    TRANS_ONLY = False

    #MODEL_NAME = 'jacREPR_0.001'
    MODEL_NAME = 'jacALL_0.0001'
    #MODEL_NAME = 'jacLAST_0.003'

    if ARCH == 'resnet18':
        model = torchvision.models.resnet18()
    elif ARCH == 'wide_resnet50_2':
        model = torchvision.models.wide_resnet50_2()
        MODEL_NAME = MODEL_NAME + '_' + ARCH
    else:
        raise ValueError('Unsupported ARCH: %s' % ARCH)
    print(MODEL_NAME, DATASET, TRANS_ONLY, TRAIN_MODE)
    model = model.to('cuda')
    model.eval()

    # The same evaluation preprocessing is used for ImageNet val and CIFAR-10.
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('/home/xiaojun/imagenet/val/', eval_transform),
        batch_size=64, shuffle=False,
        num_workers=16, pin_memory=True)
    testset = torchvision.datasets.CIFAR10(
        root='./raw_data', train=False, download=True, transform=eval_transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    # 10-class transfer head. The strict load_state_dict below requires each
    # checkpoint to provide (and overwrite) its weights.
    model.transfer_fc = nn.Linear(model.fc.in_features, 10).to('cuda')

    criterion = nn.CrossEntropyLoss()
    out_path = 'figures/tune_%s.json' % MODEL_NAME
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    results = []
    #for i in range(1,401):
    #for i in range(1,301,5):
    for i in range(50, 301):
        ckpt_path = './saved_model/imagenet_%s_tuning_ckpt/%d.pth.tar' % (MODEL_NAME, i)
        model.load_state_dict(torch.load(ckpt_path, map_location='cuda:0')['state_dict'])
        model.eval()

        src_acc = _source_accuracy(model, val_loader)
        # Largest singular value (spectral norm) of the ImageNet head's weight.
        last_layer_norm = np.linalg.norm(model.fc.weight.detach().cpu().numpy(), 2).item()
        tot_jac_repr, tot_jac_all = _jacobian_estimates(model, val_loader)
        adv_acc1, _ = _robust_accuracy(model, val_loader)
        cur_acc = _transfer_accuracy(model, testloader, criterion)

        print("%d, source acc:%.4f, Jac:(%.4f,%.4f,%.4f), rob acc:%.4f, transfer acc %.4f @ %s"%(i, src_acc, last_layer_norm, tot_jac_repr, tot_jac_all, adv_acc1, cur_acc, datetime.datetime.now()))
        results.append((i, src_acc, last_layer_norm, tot_jac_repr, tot_jac_all, adv_acc1, cur_acc))
        # Rewrite after every checkpoint so a failure mid-sweep keeps partial results.
        with open(out_path, 'w') as outf:
            json.dump(results, outf)



if __name__ == "__main__":
    # Run the checkpoint sweep only when executed as a script, not on import.
    main()
