import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from utils import load_data, load_svhn_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18, ResNet50, ResNet152
import matplotlib.pyplot as plt

def _evaluate(model, loader, criterion, device='cuda'):
    """Return ``(avg_loss, accuracy)`` of ``model`` over every batch of ``loader``.

    Args:
        model: network already placed on ``device`` and switched to eval mode.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss returning a batch mean (the default reduction of
            ``nn.CrossEntropyLoss``); it is re-weighted by batch size so the
            result is a per-sample average even when the last batch is short.
        device: device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)``; ``accuracy`` is a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples (averages are undefined).
    """
    correct = 0
    tot_loss = 0.0
    total = 0
    with torch.no_grad(), tqdm(loader) as pbar:
        for x, y in pbar:
            x, y = x.to(device), y.to(device)

            pred = model(x)
            _, pred_c = pred.max(1)
            batch = y.size(0)
            total += batch
            # criterion gives a batch mean; scale back to a sum for exact averaging.
            tot_loss += criterion(pred, y).item() * batch
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.4f; Acc: %.3f' % (tot_loss / total, 100. * correct / total))
    if total == 0:
        raise ValueError('loader yielded no samples; cannot compute loss/accuracy')
    return tot_loss / total, correct / total


def main(model_name='weightdecay-1e-05', test_batch_size=100):
    """Compare one model's fit on its training data with its SVHN-transfer checkpoint.

    The run has two phases. First, ``./saved_model/<model_name>.pth`` is loaded and
    evaluated on the train loader returned by ``load_data``. Second,
    ``./saved_model/<model_name>-transfer-svhn.pth`` is loaded into the same
    architecture and evaluated on the loader that ``load_svhn_data`` returns in
    third position (presumably its train loader; confirm against ``utils``).
    At the end it prints, in this order: train loss, transfer loss, train
    accuracy and transfer accuracy.

    The model name selects the architecture and dropout:

    - a name containing ``resnet50`` or ``resnet152`` uses that network;
      any other name uses ResNet18;
    - a name containing ``dropout0.2`` or ``dropout0.5`` sets that dropout rate;
      any other name uses 0.0.

    Example names: ``naive``, ``weightdecay-1e-05``,
    ``weightdecay-0.0005-lastdiff``, ``weightdecay-1e-06-l1reg0.001``,
    ``weightdecay-0.0005-dropout0.5`` and ``weightdecay-1e-06-resnet152``.

    Args:
        model_name: checkpoint stem under ``./saved_model``.
        test_batch_size: forwarded to ``load_data``.

    Raises:
        ValueError: if ``model_name`` mentions an unsupported dropout rate, or a
            loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()

    trainset, testset, trainloader, testloader, normalizer = load_data(test_batch_size=test_batch_size)
    print(model_name, len(trainset), len(testset))

    if 'dropout0.2' in model_name:
        dropout = 0.2
    elif 'dropout0.5' in model_name:
        dropout = 0.5
    elif 'dropout' in model_name:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError('unsupported dropout setting in model name: %r' % model_name)
    else:
        dropout = 0.0

    if 'resnet50' in model_name:
        model = ResNet50(normalizer, dropout)
    elif 'resnet152' in model_name:
        model = ResNet152(normalizer, dropout)
    else:
        model = ResNet18(normalizer, dropout)
    model = model.to('cuda')

    # map_location: restore onto the current CUDA device, even if the checkpoint
    # was saved on a device index that does not exist on this machine.
    model.load_state_dict(torch.load('./saved_model/%s.pth' % model_name, map_location='cuda'))
    model.eval()
    train_loss, train_acc = _evaluate(model, trainloader, criterion)

    _, _, transfer_trainloader, _, _ = load_svhn_data()
    model.load_state_dict(torch.load('./saved_model/%s-transfer-svhn.pth' % model_name, map_location='cuda'))
    model.eval()
    transfer_loss, transfer_acc = _evaluate(model, transfer_trainloader, criterion)

    print(train_loss, transfer_loss, train_acc, transfer_acc)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
