import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from utils import load_stl_data, load_svhn_data, load_pets_data, load_flowers_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18, ResNet50, ResNet152, CIFAR_CNN
import matplotlib.pyplot as plt

def _parse_dropout(model_name):
    """Return the dropout rate encoded in *model_name*.

    Only the 'dropout0.2' and 'dropout0.5' tags are recognised; a name
    without any 'dropout' tag yields 0.0.

    Raises:
        ValueError: if the name mentions 'dropout' with an unrecognised rate.
    """
    if 'dropout0.2' in model_name:
        return 0.2
    if 'dropout0.5' in model_name:
        return 0.5
    if 'dropout' in model_name:
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        raise ValueError('Unsupported dropout setting in model name %r' % model_name)
    return 0.0


def _build_backbone(model_name, normalizer):
    """Instantiate the architecture selected by a tag in *model_name*.

    'resnet50', 'resnet152' and 'cifarcnn' select the matching class; any
    other name falls back to ResNet18.  The dropout rate is parsed from the
    same name.
    """
    dropout = _parse_dropout(model_name)
    if 'resnet50' in model_name:
        return ResNet50(normalizer, dropout)
    if 'resnet152' in model_name:
        return ResNet152(normalizer, dropout)
    if 'cifarcnn' in model_name:
        return CIFAR_CNN(normalizer, dropout)
    return ResNet18(normalizer, dropout)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over *loader* and return the accuracy in percent.

    When *optimizer* is given, every batch performs a backward pass and an
    optimizer step; otherwise the pass runs with gradients disabled.  The
    running mean loss and accuracy are shown in the tqdm progress bar.
    The module's train/eval mode is left untouched.

    Returns:
        Accuracy in percent over all samples seen, or 0.0 for an empty loader.
    """
    training = optimizer is not None
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training), tqdm(loader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            if training:
                optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(running_loss/(batch_idx+1), 100.*correct/total))
    # Guard against an empty loader instead of dividing by zero.
    return 100.*correct/total if total else 0.0


def main():
    """Fit a new linear head on a saved checkpoint and keep the best one.

    Loads './saved_model/<MODEL_NAME>.pth', replaces `model.linear` with a
    freshly initialised layer sized for DATASET, trains only that layer for
    40 epochs with SGD, and saves the full state dict to
    './saved_model/<MODEL_NAME>-transfer-<DATASET>.pth' whenever the test
    accuracy improves.

    DATASET and MODEL_NAME are selected by (un)commenting the lines below.

    Raises:
        ValueError: if DATASET is not a known dataset, or MODEL_NAME carries
            an unsupported dropout tag.
    """
    #DATASET = 'stl'
    DATASET = 'svhn'
    #DATASET = 'pets'
    #DATASET = 'flowers'

    #MODEL_NAME = 'naive'
    #MODEL_NAME = 'advT-l2-1.0000'
    #MODEL_NAME = 'advT-linf-0.0314'
    #MODEL_NAME = 'advTaug-linf-0.0314'
    #MODEL_NAME = 'advT-l2-0.1000'
    #MODEL_NAME = 'alpha-10.0000'
    #MODEL_NAME = 'alpha-100.0000'
    #MODEL_NAME = 'alpha-10000.0000'
    #MODEL_NAME = 'alpha-out-100.0000'
    #MODEL_NAME = 'alpha-out-1000.0000'
    #MODEL_NAME = 'alpha-approx-3.0000'
    #MODEL_NAME = 'alpha-approx-6.0000'
    #MODEL_NAME = 'alpha-approx-10.0000'
    #MODEL_NAME = 'alpha-approx-30.0000'
    #MODEL_NAME = 'alpha-approx-60.0000'
    #MODEL_NAME = 'alpha-approx-100.0000'
    #MODEL_NAME = 'alpha-approx-10.0000-adv-l2-1.0000'
    #MODEL_NAME = 'alpha-approx-100.0000-adv-l2-1.0000'
    #MODEL_NAME = 'alpha-approxrepr-3.0000'
    #MODEL_NAME = 'alpha-approxrepr-6.0000'
    #MODEL_NAME = 'alpha-approxrepr-10.0000'
    #MODEL_NAME = 'alpha-approxrepr-30.0000'
    #MODEL_NAME = 'alpha-approxrepr-60.0000'
    #MODEL_NAME = 'alpha-approxrepr-100.0000'
    #MODEL_NAME = 'alpha-approxrepr-3.0000-wd1e-06'
    MODEL_NAME = 'alpha-approxrepr-10.0000-wd1e-06'
    #MODEL_NAME = 'alpha-last-1.0000'
    #MODEL_NAME = 'alpha-last-0.3000'
    #MODEL_NAME = 'alpha-last-0.1000'
    #MODEL_NAME = 'alpha-last-0.0300'
    #MODEL_NAME = 'alpha-last--0.1000'
    #MODEL_NAME = 'alpha-last--0.3000'
    #MODEL_NAME = 'gauss-0.25'
    #MODEL_NAME = 'advTrepr-l2-1.0000'
    #MODEL_NAME = 'advTreprSSL-l2-1.0000'

    #MODEL_NAME = 'gauss-1.0'
    #MODEL_NAME = 'gauss-0.5'
    #MODEL_NAME = 'gauss-0.25'
    #MODEL_NAME = 'gauss-0.125'
    #MODEL_NAME = 'gauss-0.05'
    #MODEL_NAME = 'gauss-0.025'

    #MODEL_NAME = 'simclr'
    #MODEL_NAME = 'simclradv'
    #MODEL_NAME = 'simclradv-l2-0.1000'
    #MODEL_NAME = 'simclradv-l2-1.0000'
    #MODEL_NAME = 'simclradv-l2-0.1000-lamda0.2500'
    #MODEL_NAME = 'simclradv-l2-0.1000-lamda0.7500'
    #MODEL_NAME = 'reglast_0.01'
    #MODEL_NAME = 'reglast_0.03'
    #MODEL_NAME = 'reglast_0.1'
    #MODEL_NAME = 'reglast_0.3'
    #MODEL_NAME = 'reglast_1.0'
    #MODEL_NAME = 'reglast_3.0'
    #MODEL_NAME = 'reglast_10.0'
    #MODEL_NAME = 'reglast_30.0'
    #MODEL_NAME = 'reglast_100.0'
    #MODEL_NAME = 'reglast_0.1-adv-l2-1.0000'
    #MODEL_NAME = 'reglast_10.0-adv-l2-1.0000'
    #MODEL_NAME = 'rocl-linf-0.0157'
    #MODEL_NAME = 'rocl-linf-0.0314'
    #MODEL_NAME = 'rocl-linf-0.0627'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.0000'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.0001'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.1250'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.2500'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.7500'
    #MODEL_NAME = 'rocl-linf-0.0314-lamda0.9999'
    #MODEL_NAME = 'roclweak-linf-0.0314'

    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.0000-complexcomplex'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.0000-complexcomplex_2'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.0000-complexcomplex_3'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.5000-complexcomplex'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.5000-complexcomplex_2'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.5000-complexcomplex_3'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.7500-complexcomplex'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.7500-complexcomplex_2'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.7500-complexcomplex_3'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.8750-complexcomplex'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.8750-complexcomplex_2'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.8750-complexcomplex_3'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda1.0000-complexcomplex'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda1.0000-complexcomplex_2'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda1.0000-complexcomplex_3'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.0000-simplesimple'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.5000-simplesimple'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda0.7500-simplesimple'
    #MODEL_NAME = 'mixing-l2-0.5000-lamda1.0000-simplesimple'

    #MODEL_NAME = 'mixing-None-0.0000-lamda0.0000-complexcomplex'
    #MODEL_NAME = 'mixing-None-0.0000-lamda0.5000-complexcomplex'
    #MODEL_NAME = 'mixing-None-0.0000-lamda1.0000-complexcomplex'
    #MODEL_NAME = 'mixing-None-0.0000-lamda0.0000-simplesimple'
    #MODEL_NAME = 'mixing-None-0.0000-lamda0.5000-simplesimple'
    #MODEL_NAME = 'mixing-None-0.0000-lamda1.0000-simplesimple'

    #MODEL_NAME = 'debug1'
    #MODEL_NAME = 'debug2'
    #MODEL_NAME = 'debug3'
    #MODEL_NAME = 'debug4'
    #MODEL_NAME = 'debug5'
    #MODEL_NAME = 'debug6'
    #MODEL_NAME = 'debug7'
    #MODEL_NAME = 'debug8'
    #MODEL_NAME = 'debug9'
    #MODEL_NAME = 'debug10'
    #MODEL_NAME = 'debug11'
    #MODEL_NAME = 'debug12'
    #MODEL_NAME = 'debug13'

    #MODEL_NAME = 'weightdecay-0.0005'
    #MODEL_NAME = 'weightdecay-0.0001'
    #MODEL_NAME = 'weightdecay-1e-05'
    #MODEL_NAME = 'weightdecay-1e-06'
    #MODEL_NAME = 'weightdecay-0.0005-lastdiff'
    #MODEL_NAME = 'weightdecay-0.0001-lastdiff'
    #MODEL_NAME = 'weightdecay-1e-05-lastdiff'
    #MODEL_NAME = 'weightdecay-1e-06-lastdiff'
    #MODEL_NAME = 'weightdecay-0.0005-dropout0.2'
    #MODEL_NAME = 'weightdecay-0.0005-dropout0.5'
    #MODEL_NAME = 'weightdecay-1e-06-dropout0.2'
    #MODEL_NAME = 'weightdecay-1e-06-dropout0.5'
    #MODEL_NAME = 'weightdecay-1e-06-l1reg0.001'
    #MODEL_NAME = 'weightdecay-1e-06-l1reg0.0001'
    #MODEL_NAME = 'weightdecay-1e-06-l1reg1e-05'
    #MODEL_NAME = 'weightdecay-1e-06-l1reg1e-06'
    #MODEL_NAME = 'weightdecay-0.0005-resnet50'
    #MODEL_NAME = 'weightdecay-1e-06-resnet50'
    #MODEL_NAME = 'weightdecay-0.0005-resnet152'
    #MODEL_NAME = 'weightdecay-1e-06-resnet152'
    #MODEL_NAME = 'weightdecay-0.0005-cifarcnn'
    #MODEL_NAME = 'weightdecay-1e-06-cifarcnn'

    #MODEL_NAME = 'aug_standard'
    #MODEL_NAME = 'aug_rotate5'
    #MODEL_NAME = 'aug_rotate10'
    #MODEL_NAME = 'aug_rotate15'
    #MODEL_NAME = 'aug_rotate30'
    #MODEL_NAME = 'aug_rotate45'
    #MODEL_NAME = 'aug_erase0.2'
    #MODEL_NAME = 'aug_erase0.33'
    #MODEL_NAME = 'aug_perspective0.25'
    #MODEL_NAME = 'aug_perspective0.5'
    #MODEL_NAME = 'aug_affine15'
    #MODEL_NAME = 'aug_affine30'
    #MODEL_NAME = 'aug_posterize2'
    #MODEL_NAME = 'aug_posterize3'
    #MODEL_NAME = 'aug_posterize4'

    device = 'cuda'

    # Dataset name -> (loader function, number of target classes).
    dataset_table = {
        'stl': (load_stl_data, 10),
        'svhn': (load_svhn_data, 10),
        'pets': (load_pets_data, 37),
        'flowers': (load_flowers_data, 102),
    }
    if DATASET not in dataset_table:
        # Fail early with a clear message rather than a NameError further down.
        raise ValueError('Unknown DATASET %r; expected one of %s' % (DATASET, sorted(dataset_table)))
    load_data, out_dim = dataset_table[DATASET]
    trainset, testset, trainloader, testloader, normalizer = load_data()
    print(MODEL_NAME, len(trainset), len(testset))

    model = _build_backbone(MODEL_NAME, normalizer).to(device)
    model.load_state_dict(torch.load('./saved_model/%s.pth'%MODEL_NAME))
    print(model.linear)

    # Only the new head is handed to the optimizer below, so freeze every
    # loaded parameter: otherwise each backward pass would still compute and
    # accumulate gradients for weights that are never updated or zeroed.
    for param in model.parameters():
        param.requires_grad_(False)
    # The replacement head is created afterwards and is therefore trainable.
    model.linear = nn.Linear(model.linear.in_features, out_dim).to(device)
    model.eval()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.linear.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Alternative setting tried previously: lr=0.01 with the same momentum/weight decay.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,30], gamma=0.1)

    best_acc = 0.0
    for epoch in range(40):
        print("Epoch %d"%epoch)
        # The model is deliberately kept in eval mode while the head is being
        # trained (the original code had model.train() disabled here).
        model.eval()

        # Train
        _run_epoch(model, trainloader, criterion, device, optimizer)
        scheduler.step()

        # Test
        cur_acc = _run_epoch(model, testloader, criterion, device)

        # Save the full state dict whenever test accuracy improves.
        if cur_acc > best_acc:
            best_acc = cur_acc
            torch.save(model.state_dict(), './saved_model/%s-transfer-%s.pth'%(MODEL_NAME, DATASET))


# Run main() only when this file is executed as a script, not when imported.
if "__main__" == __name__:
    main()
