import torch
import torch.nn as nn
import torchvision
import numpy as np
from tqdm import tqdm

from utils import load_data, calc_repr
from attack_lib import LinfPGDAttack, L2PGDAttack
from train_imagenet import accuracy
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.datasets as datasets


# ImageNet per-channel statistics. Used by the validation transform, passed to
# the attack, and baked into the normalisation layer of the transfer model.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Transfer datasets read through ImageFolder:
# name -> (train directory, test directory, number of classes).
_IMAGEFOLDER_DATASETS = {
    'flowers': ('./raw_data/flowers_new/train',
                './raw_data/flowers_new/test', 102),
    'tiny': ('./raw_data/tiny-imagenet-200/train',
             './raw_data/tiny-imagenet-200/val/organized', 200),
}


class _TransformedDataset(torch.utils.data.Dataset):
    """Wrap a (sample, label) dataset and apply ``transform`` on access.

    Defined at module level (rather than inside ``main``) so that DataLoader
    worker processes can pickle it regardless of the start method.
    """

    def __init__(self, ds, transform=None):
        self.ds = ds
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample, label = self.ds[idx]
        if self.transform:
            sample = self.transform(sample)
            # Tile single-channel samples to three channels so every sample
            # has the same channel count as the backbone input.
            if sample.shape[0] == 1:
                sample = sample.repeat(3, 1, 1)
        return sample, label


def _build_backbone(arch):
    """Instantiate the (randomly initialised) torchvision backbone for *arch*."""
    if arch == 'resnet18':
        return torchvision.models.resnet18()
    if arch == 'wide_resnet50_2':
        return torchvision.models.wide_resnet50_2()
    raise ValueError('Unsupported ARCH: %r' % (arch,))


def _load_checkpoint(model, model_name):
    """Load the weights stored under ``./saved_model`` for *model_name* into *model*."""
    if model_name.startswith('madry'):
        ckpt = torch.load('./saved_model/%s.ckpt' % model_name)
        # Keep only the wrapped backbone's entries and strip their prefix so
        # the keys match a plain torchvision model.
        prefix = 'module.model.'
        state_dict = {
            name[len(prefix):]: p
            for name, p in ckpt['model'].items()
            if name.startswith(prefix)
        }
        model.load_state_dict(state_dict)
    else:
        model.load_state_dict(torch.load(
            './saved_model/imagenet_%s.pth' % model_name, map_location='cuda:0'))


def _build_imagenet_val_loader():
    """Build the ImageNet validation loader (resize 256, centre crop 224, normalised)."""
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
    ])
    return torch.utils.data.DataLoader(
        datasets.ImageFolder('/home/xiaojun/imagenet/val/', val_transform),
        batch_size=64, shuffle=False,
        num_workers=16, pin_memory=True)


def _estimate_feature_stats(model, val_loader, max_batches=1):
    """Estimate the mean feature norm and a Jacobian-norm statistic.

    For each batch the representation ``calc_repr(model, x)`` is projected onto
    one random unit vector per sample and differentiated w.r.t. the input, i.e.
    one vector-Jacobian product instead of the full Jacobian.

    Args:
        model: backbone in eval mode, on the GPU.
        val_loader: loader yielding ``(x, y)`` batches.
        max_batches: number of batches to use; ``None`` means the whole loader.

    Returns:
        ``(avg_feat_norm, jac_stat)`` as Python floats.

    Raises:
        RuntimeError: if the loader yields no samples.
    """
    feat_norm_sum = 0.0
    jac_sq_sum = 0.0
    num_seen = 0
    with tqdm(val_loader) as pbar:
        for batch_idx, (x, _) in enumerate(pbar):
            x = x.to('cuda')
            x.requires_grad_()
            # assumes calc_repr returns (batch, dim) features -- TODO confirm
            pred = calc_repr(model, x)

            # Accumulate the per-sample SUM (as a float, detached from the
            # graph) so that dividing by the sample count gives a true mean.
            flat = pred.detach().view(pred.shape[0], -1)
            feat_norm_sum += flat.norm(2, dim=1).sum().item()
            num_seen += len(x)

            rv = torch.randn_like(pred)
            rv = rv / rv.norm(dim=1, keepdim=True)
            # autograd.grad returns the input gradient directly and does not
            # populate .grad on the model parameters.
            grad_x, = torch.autograd.grad((rv * pred).sum(), x)
            # NOTE(review): for v uniform on the unit sphere in R^d,
            # E||J^T v||^2 = ||J||_F^2 / d, so scaling the *gradient* by d
            # (d^2 on the square) looks too large by a factor d. Kept as-is so
            # numbers stay comparable with earlier runs -- confirm intent.
            jac_sq_sum += ((pred.shape[1] * grad_x).view(-1).norm() ** 2).item()

            del grad_x
            torch.cuda.empty_cache()
            pbar.set_description(
                'Avg feat norm: %.4f; Avg jac frob norm: %.4f'
                % (feat_norm_sum / num_seen, (jac_sq_sum / num_seen) ** 0.5))
            if max_batches is not None and batch_idx + 1 >= max_batches:
                break
    if num_seen == 0:
        raise RuntimeError('validation loader produced no samples')
    return feat_norm_sum / num_seen, (jac_sq_sum / num_seen) ** 0.5


def _eval_adversarial_accuracy(model, val_loader):
    """Return (top-1, top-5) accuracy on L2-PGD adversarial validation images."""
    from attack_lib import L2PGDAttack
    # Other settings tried: L2 epsilon=3.0; L2 epsilon=1e-8 with num_steps=0;
    # L2 num_steps=3, epsilon=1.0, alpha=2.0/3.0; Linf epsilon=4./255.
    adversary = L2PGDAttack(model, epsilon=0.25)
    total = 0
    adv_acc1 = 0
    adv_acc5 = 0
    # NOTE(review): perturb() is called under no_grad, as in the original
    # script; presumably attack_lib re-enables gradients internally -- confirm.
    with torch.no_grad(), tqdm(val_loader) as pbar:
        for x, y in pbar:
            x, y = x.to('cuda'), y.to('cuda')
            adv_x = adversary.perturb(
                x, y, normalize=(_IMAGENET_MEAN, _IMAGENET_STD)).detach()

            adv_pred = model(adv_x)
            acc1, acc5 = accuracy(adv_pred, y, topk=(1, 5))
            total += y.size(0)
            adv_acc1 += acc1 * y.size(0)
            adv_acc5 += acc5 * y.size(0)
            pbar.set_description(
                'Adv acc1: %.3f; adv acc5:%.3f' % (adv_acc1 / total, adv_acc5 / total))
    return (adv_acc1 / total).item(), (adv_acc5 / total).item()


def _transfer_transforms():
    """Return the (train, test) transforms shared by every transfer dataset."""
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    return train_tf, test_tf


def _build_transfer_loaders(dataset_name):
    """Build the loaders for a transfer dataset.

    Args:
        dataset_name: one of 'cifar10', 'dtd', 'flowers', 'tiny', 'cal101'.

    Returns:
        ``(trainloader, testloader, out_dim)`` where ``out_dim`` is the number
        of classes used for the new linear head.

    Raises:
        ValueError: if *dataset_name* is not one of the supported names.
    """
    train_tf, test_tf = _transfer_transforms()
    train_bs, test_bs, workers = 64, 128, 16

    if dataset_name == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(
            root='./raw_data', train=True, download=True, transform=train_tf)
        testset = torchvision.datasets.CIFAR10(
            root='./raw_data', train=False, download=True, transform=test_tf)
        out_dim = 10
        # CIFAR-10 uses its own test batch size and worker count.
        test_bs, workers = 100, 2
    elif dataset_name == 'dtd':
        from dtd_dataset import DTD
        trainset = DTD(train=True, transform=train_tf)
        testset = DTD(train=False, transform=test_tf)
        out_dim = 47
    elif dataset_name in _IMAGEFOLDER_DATASETS:
        train_dir, test_dir, out_dim = _IMAGEFOLDER_DATASETS[dataset_name]
        trainset = torchvision.datasets.ImageFolder(train_dir, transform=train_tf)
        testset = torchvision.datasets.ImageFolder(test_dir, transform=test_tf)
    elif dataset_name == 'cal101':
        from caltech import Caltech101
        dataset = Caltech101('./raw_data/caltech101')
        samples_per_class = 30
        labels = dataset.y
        # A class starts wherever the label increases by one relative to the
        # previous sample.
        class_starts = [0] + [
            i for i in range(1, len(dataset)) if labels[i] == labels[i - 1] + 1]
        # The first `samples_per_class` samples of each class are used for
        # training; every remaining index is used for testing (sorted order).
        train_indices = [
            idx
            for start in class_starts
            for idx in range(start, start + samples_per_class)
        ]
        train_index_set = set(train_indices)
        test_indices = [i for i in range(len(dataset)) if i not in train_index_set]
        trainset = _TransformedDataset(
            torch.utils.data.Subset(dataset, train_indices), transform=train_tf)
        testset = _TransformedDataset(
            torch.utils.data.Subset(dataset, test_indices), transform=test_tf)
        out_dim = 101
    else:
        raise ValueError('Unsupported DATASET: %r' % (dataset_name,))

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=train_bs, shuffle=True, num_workers=workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_bs, shuffle=False, num_workers=workers)
    return trainloader, testloader, out_dim


def _fit_linear_head(model, trainloader, testloader, train_mode, save_path, epochs=40):
    """Train only the new ``fc`` layer and return the best test accuracy (percent).

    Args:
        model: ``nn.Sequential(normalizer, backbone)``; only ``model[1].fc`` is
            handed to the optimizer.
        trainloader: training batches.
        testloader: test batches, evaluated after every epoch.
        train_mode: if true the whole model is put in ``train()`` mode while
            fitting, otherwise it stays in ``eval()`` mode.
        save_path: where the state dict is written whenever accuracy improves.
        epochs: number of passes over the training set.
    """
    criterion = nn.CrossEntropyLoss()
    # Other settings tried: lr=0.1 with weight_decay=5e-4, and lr=0.1 with
    # weight_decay=0.
    optimizer = torch.optim.SGD(
        model[1].fc.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[20, 30], gamma=0.1)

    best_acc = 0.0
    for epoch in range(epochs):
        # Train
        print ("Epoch %d"%epoch)
        if train_mode:
            model.train()
        else:
            model.eval()
        train_loss = 0
        correct = 0
        total = 0
        with tqdm(trainloader) as pbar:
            for batch_idx, (x, y) in enumerate(pbar):
                x, y = x.to('cuda'), y.to('cuda')
                optimizer.zero_grad()
                pred = model(x)
                loss = criterion(pred, y)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, pred_c = pred.max(1)
                total += y.size(0)
                correct += pred_c.eq(y).sum().item()
                pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(train_loss/(batch_idx+1), 100.*correct/total))
        scheduler.step()

        # Test
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad(), tqdm(testloader) as pbar:
            for batch_idx, (x, y) in enumerate(pbar):
                x, y = x.to('cuda'), y.to('cuda')
                pred = model(x)
                loss = criterion(pred, y)

                test_loss += loss.item()
                _, pred_c = pred.max(1)
                total += y.size(0)
                correct += pred_c.eq(y).sum().item()
                pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(test_loss/(batch_idx+1), 100.*correct/total))
        cur_acc = 100.*correct/total

        # Keep only the checkpoint with the best test accuracy so far.
        if cur_acc > best_acc:
            best_acc = cur_acc
            torch.save(model.state_dict(), save_path)
    return best_acc


# Checkpoint names that appeared as alternatives for MODEL_NAME in this script
# (loaded from ./saved_model/imagenet_<name>.pth, or ./saved_model/<name>.ckpt
# for names starting with 'madry'):
#   naive, labelmap, advt, madry_naive, madry_advt
#   wd0.001, wd0.000010, wd0.000001, wd1e-06,
#   wd0.000100_caug, wd0.0001_caug, wd0.000001_caug, wd0.0001_caug_32
#   perspective0.25, perspective, perspective0.75
#   mixup0.1, mixup0.2, mixup0.2_ratio0.5,
#   cifar10_mixup0.1, cifar10_mixup0.1_nolab, cifar10_mixup0.2, cifar10_mixup0.2_nolab
#   affine15, wd0.0001_affine, affine45
#   rotate5, rotate10, wd0.0001_rotate, rotate30, rotate45
#   wd0.0001_erase, erase0.33, erase0.5
#   wd0.0001_equalize
#   posterize1, wd0.0001_posterize, posterize3, posterize4
#   blur3, wd0.0001_blur, blur7, blur9
#   gauss0.0125, gauss0.025, gauss0.05, gauss0.125, gauss0.25
#   rescale112, rescale56, rescale28, rescale14
#   partial0.05, partial0.05_caug
#   reglast_{100.0, 30.0, 20.0, 10.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.1, 0.01},
#   reglast_10.0_wd1e-06
#   jacALL_{0.1, 0.01, 0.003, 0.001, 0.0001}
#   jacREPR_{0.001, 0.01, 0.1, 1.0}
#   jacLAST_{0.003, 0.01, 0.03, 0.1, 1.0, -0.0002, -0.01, -0.1, -0.3}
#   reglast20.0_jac0.001, reglast30.0_jac0.001, reglast30.0_jac0.01
def main(*, arch='resnet18', model_name='reglast_3.0', dataset='cifar10',
         trans_only=False, train_mode=True):
    """Evaluate an ImageNet checkpoint and its linear-probe transfer accuracy.

    Unless ``trans_only`` is set, the function first reports, on ImageNet
    validation data, the spectral norm of the last layer, the average feature
    norm, a Jacobian-norm statistic and L2-PGD adversarial accuracy. It then
    replaces ``fc`` with a fresh linear layer, trains that layer on the
    transfer dataset and prints a summary line.

    Args:
        arch: 'resnet18' or 'wide_resnet50_2'. For the latter the architecture
            name is appended to ``model_name`` when locating the checkpoint.
        model_name: checkpoint identifier (see the list above this function).
        dataset: 'cifar10', 'flowers', 'dtd', 'tiny' or 'cal101'.
        trans_only: skip the ImageNet evaluations and only run the transfer.
        train_mode: True keeps the model in ``train()`` mode while fitting the
            head ("madry's setting"); False keeps it in ``eval()`` mode
            ("our setting").

    Raises:
        ValueError: for an unsupported ``arch`` or ``dataset``.
    """
    ###
    # One-off conversion of an ONI checkpoint to a plain linear head:
    #model = torchvision.models.resnet18()
    #from models.ONI_lib import ONI_Linear
    #model.fc = ONI_Linear(model.fc.in_features, model.fc.out_features, scale=30).to('cuda')
    #model.load_state_dict(torch.load('saved_model/imagenet_%s_tmp.pth.tar'%MODEL_NAME, map_location='cuda:0')['state_dict'])
    #w = model.fc.normed_weight()
    #model.fc = nn.Linear(model.fc.in_features, model.fc.out_features).to('cuda')
    #model.fc.weight.data = w
    #torch.save(model.state_dict(), './saved_model/imagenet_%s.pth'%MODEL_NAME)
    ###

    model = _build_backbone(arch)
    if arch == 'wide_resnet50_2':
        model_name = model_name + '_' + arch
    print (model_name, dataset, trans_only, train_mode)
    model = model.to('cuda')
    _load_checkpoint(model, model_name)
    print (model.fc)
    model.eval()

    if not trans_only:
        # The ImageNet loader is only needed for these evaluations.
        val_loader = _build_imagenet_val_loader()

        ### Eval jac
        # ord=2 on a 2-D array is the largest singular value.
        last_layer_norm = np.linalg.norm(model.fc.weight.detach().cpu().numpy(),2)
        print (last_layer_norm)

        # Only the first batch is used, matching the original early exit.
        tot_feat_norm, tot_jac = _estimate_feature_stats(
            model, val_loader, max_batches=1)
        print (tot_feat_norm,tot_jac)

        ### Eval attack
        adv_acc1, adv_acc5 = _eval_adversarial_accuracy(model, val_loader)

    ### Eval transfer
    trainloader, testloader, out_dim = _build_transfer_loaders(dataset)

    model.fc = nn.Linear(model.fc.in_features, out_dim).to('cuda')
    # Include normalize in the model for the convenience of adv attack
    from advertorch.utils import NormalizeByChannelMeanStd
    normalizer = NormalizeByChannelMeanStd(
        mean=torch.tensor(_IMAGENET_MEAN, dtype=torch.float32).cuda(),
        std=torch.tensor(_IMAGENET_STD, dtype=torch.float32).cuda())
    model = nn.Sequential(normalizer, model)

    best_acc = _fit_linear_head(
        model, trainloader, testloader, train_mode,
        './saved_model/%s-transfer-%s.pth'%(model_name, dataset))

    if not trans_only:
        print ('%.4f | %.2f | %.2f | %.2f | %.2f | %.2f'%(tot_feat_norm, last_layer_norm, tot_jac, adv_acc1, adv_acc5, best_acc))
    else:
        print (best_acc)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
