
import torch
import torchvision

from ADMM.admm import *
from ADMM.nn_models_lib import *
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import time
from tqdm import tqdm
import gc
torch.set_grad_enabled(False)

# Number of saved comparison results (examples id_0 .. id_{N-1}) to load.
NUM_EXAMPLES = 15
# Epsilon tag that appears in the saved result filenames ("..._eps_<tag>.pt").
EPS_TAG = '1'
# How many of the smallest lbs / largest ubs to show in the zoomed-in plots.
NUM_ZOOM = 50


def _load_bound_results(num_examples=NUM_EXAMPLES, eps_tag=EPS_TAG):
    """Load the saved ADMM-vs-LiRPA bound comparison for each example.

    Files are read from
    ``<script dir>/ADMM_LiRPA_comparison/id_<k>/output_bds_comparison_result_id_<k>_eps_<tag>.pt``
    for k in ``range(num_examples)``, and returned in that order.
    """
    from pathlib import Path  # local import: keeps the file's import block untouched

    root = Path(__file__).resolve().parent / 'ADMM_LiRPA_comparison'
    results = []
    for example_num in range(num_examples):
        path = (root / f'id_{example_num}'
                / f'output_bds_comparison_result_id_{example_num}_eps_{eps_tag}.pt')
        # map_location='cpu' so the script runs on machines without CUDA; the
        # bounds are only plotted, so they never need to live on the GPU.
        # NOTE: torch.load unpickles the file -- only load trusted results.
        results.append(torch.load(path, map_location='cpu'))
    return results


def _gather_bounds(bds_data_list, bound):
    """Collect one bound type from every example into a (2, total) CPU tensor.

    ``bound`` is the key used in the saved dicts (``'lb'`` or ``'ub'``).
    Row 0 holds the ADMM bounds and row 1 the LiRPA bounds, each flattened
    per example and concatenated in load order, so column j of both rows
    refers to the same output neuron.
    """
    admm = torch.cat([d['ADMM'][bound].reshape(-1).cpu() for d in bds_data_list])
    lirpa = torch.cat([d['LiRPA'][bound].reshape(-1).cpu() for d in bds_data_list])
    return torch.stack((admm, lirpa))


def _order_by_admm(bounds):
    """Reorder both rows by ascending ADMM value (row 0), keeping columns paired."""
    order = torch.argsort(bounds[0])
    return bounds[:, order]


def _plot_bounds(admm_vals, lirpa_vals, kind, *, marker_size=10,
                 legend_size=24, label_size=24, tick_size=None):
    """Plot ADMM vs LiRPA bounds on a new 12x6 figure and return it.

    ``kind`` is the suffix used in the legend labels (e.g. ``'lbs'``).
    ``tick_size`` sets the major tick label font size when given.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    if tick_size is not None:
        ax.tick_params(axis='x', which='major', labelsize=tick_size)
        ax.tick_params(axis='y', which='major', labelsize=tick_size)
    ax.plot(admm_vals.numpy(), 'o', ms=marker_size, label=f'ADMM {kind}')
    ax.plot(lirpa_vals.numpy(), '^', ms=marker_size, label=f'LiRPA {kind}')
    ax.legend(fontsize=legend_size)
    ax.set_xlabel(r'bound index', fontsize=label_size)
    ax.set_ylabel(r'value', fontsize=label_size)
    ax.grid()
    return fig


if __name__ == '__main__':

    # =============================================================================
    #     plot the lower and upper bounds obtained from ADMM and LiRPA
    # =============================================================================
    bds_data_list = _load_bound_results()

    # Both rows are ordered by the ADMM bound so each column compares the two
    # methods on the same output neuron.
    lb_sort = _order_by_admm(_gather_bounds(bds_data_list, 'lb'))
    ub_sort = _order_by_admm(_gather_bounds(bds_data_list, 'ub'))

    # the NUM_ZOOM smallest lbs for comparison (figure 2 in the paper)
    _plot_bounds(lb_sort[0][:NUM_ZOOM], lb_sort[1][:NUM_ZOOM], 'lbs', tick_size=18)

    # the NUM_ZOOM largest ubs for comparison (figure 2 in the paper)
    _plot_bounds(ub_sort[0][-NUM_ZOOM:], ub_sort[1][-NUM_ZOOM:], 'ubs', tick_size=18)

    # all the lower bounds
    _plot_bounds(lb_sort[0], lb_sort[1], 'lbs', legend_size=18)

    # all the upper bounds
    _plot_bounds(ub_sort[0], ub_sort[1], 'ubs',
                 marker_size=6, legend_size=14, label_size=16)

    plt.show()

