
import os

import torch
import torchvision

from ADMM.admm import *
from ADMM.nn_models_lib import *
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import time
from tqdm import tqdm
import gc
torch.set_grad_enabled(False)  # inference only: no autograd graphs are needed in this script


def _load_resnet18(device, num_classes=10):
    """Build the modified ResNet-18, load ``resnet18.pth`` and keep only ``num_classes`` logits.

    The stock ``maxpool`` layer is replaced by a 1x1, stride-2 convolution before
    the checkpoint is loaded, so the checkpoint must match that layout. The
    checkpoint path is relative to the current working directory.

    Args:
        device: ``torch.device`` the model is placed on.
        num_classes: number of leading rows of the original ``fc`` layer to keep.

    Returns:
        The model on ``device``, in eval mode.
    """
    model = models.resnet18()
    model.maxpool = nn.Conv2d(64, 64, kernel_size=1, stride=2, padding=0, bias=False)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load('resnet18.pth', map_location=device))

    old_fc = model.fc
    new_fc = nn.Linear(old_fc.in_features, num_classes)
    with torch.no_grad():
        # Copy (instead of aliasing through ``.data``) the first ``num_classes`` rows.
        new_fc.weight.copy_(old_fc.weight[:num_classes, :])
        new_fc.bias.copy_(old_fc.bias[:num_classes])
    model.fc = new_fc

    # Move/eval after the head swap so the new layer ends up on the same device.
    model.to(device)
    model.eval()
    return model


def _get_test_example(example_num, batch_size=1):
    """Return ``(images, labels, cifar10_std)`` for the ``example_num``-th (0-based) test batch.

    A fresh loader is created on every call, as in the original script.
    NOTE(review): if ``cifar10_loaders`` does not shuffle, the loader could be built
    once and iterated directly instead of being re-created and skipped through.
    """
    test_loader, cifar10_std = cifar10_loaders(batch_size)
    data_iter = iter(test_loader)
    # DataLoader iterators only implement ``__next__``; ``.next()`` is not available
    # in current PyTorch, so use the builtin ``next``.
    for _ in range(example_num):
        next(data_iter)
    batch = next(data_iter)
    return batch[0], batch[1], cifar10_std


def _compare_bounds_for_example(model, example_num, cur_dir, device, num_classes=10):
    """Compute ADMM output bounds for one test image and save them next to the LiRPA bounds.

    Reads ``LiRPA_intermediate_bds_id_<n>_eps_1.pt`` and ``LiRPA_output_bds_id_<n>_eps_1.pt``
    from ``<cur_dir>/ADMM_LiRPA_comparison/id_<n>/``. Writes ``ADMM_output_bds_id_<n>_eps_1.pt``
    and ``output_bds_comparison_result_id_<n>_eps_1.pt`` into the same directory, then
    prints both sets of bounds.
    """
    x, _, cifar10_std = _get_test_example(example_num, batch_size=1)
    x = x.to(device)

    # Output locations (directory is created if missing).
    id_str = 'id_' + str(example_num)
    out_dir = os.path.join(cur_dir, 'ADMM_LiRPA_comparison', id_str)
    os.makedirs(out_dir, exist_ok=True)
    suffix = str(example_num) + '_eps_1.pt'
    bds_path = os.path.join(out_dir, 'LiRPA_intermediate_bds_id_' + suffix)
    output_bds_path = os.path.join(out_dir, 'LiRPA_output_bds_id_' + suffix)
    admm_sol_path = os.path.join(out_dir, 'ADMM_output_bds_id_' + suffix)
    result_path = os.path.join(out_dir, 'output_bds_comparison_result_id_' + suffix)

    # Perturbation radius eps = 1/255 divided by the per-channel CIFAR-10 std
    # (reshaped to (3, 1, 1) so it broadcasts over the image channels).
    eps = 1 / 255
    eps_normalized = (eps / torch.tensor(cifar10_std)).reshape(3, 1, 1).to(x.device)
    x0_lb = x - eps_normalized
    x0_ub = x + eps_normalized

    # NOTE(review): the layer sections are not used below (the original only read
    # their output size, which was then discarded). The calls are kept as a check
    # that the decomposition runs on this input; drop them if confirmed unnecessary.
    obj_options = {'rho': 1.0, 'c': 0}
    resnet18_nn_sections = resnet18_act_decomposition(model)
    generate_layer_sections(resnet18_nn_sections, x, x0_lb, x0_ub, obj_options)

    # Objective rows: +I gives the lower bound of each logit; -I gives the negated
    # upper bound (negated back below). One input copy per objective row.
    c = torch.cat((torch.eye(num_classes), -torch.eye(num_classes))).to(device)
    input_x = x.repeat(2 * num_classes, 1, 1, 1)
    lb = x0_lb.repeat(2 * num_classes, 1, 1, 1)
    ub = x0_ub.repeat(2 * num_classes, 1, 1, 1)

    # ADMM parameters
    alg_options = {'rho': 1.0, 'eps_abs': 1e-5, 'eps_rel': 1e-4, 'residual_balancing': False,
                   'max_iter': 5000, 'record': True, 'verbose': True}

    # Plug in the intermediate pre-activation bounds generated by LiRPA.
    pre_act_bds = torch.load(bds_path, map_location=device)
    if len(pre_act_bds['lb']) != len(pre_act_bds['ub']):
        raise ValueError(f"{bds_path}: 'lb' and 'ub' have different numbers of layers")
    pre_act_bds_list = [{'lb': [layer_lb], 'ub': [layer_ub]}
                        for layer_lb, layer_ub in zip(pre_act_bds['lb'], pre_act_bds['ub'])]

    # Solve the LP through ADMM.
    sol, _ = ADMM_custom_objective(model, input_x, lb, ub, c, alg_options,
                                   pre_act_bds=pre_act_bds_list)
    gc.collect()
    torch.save(sol, admm_sol_path)

    # Save the comparison results.
    obj = sol['obj']
    admm_lb = obj[:num_classes]
    admm_ub = -obj[num_classes:]
    lirpa_bds = torch.load(output_bds_path, map_location=device)
    comparison_result = {'LiRPA': lirpa_bds, 'ADMM': {'lb': admm_lb, 'ub': admm_ub},
                         'ADMM_running_time': sol['running_time']}
    torch.save(comparison_result, result_path)

    print('ADMM lb: ', admm_lb)
    print('LiRPA lb: ', lirpa_bds['lb'])
    print('ADMM ub: ', admm_ub)
    print('LiRPA ub: ', lirpa_bds['ub'])


if __name__ == '__main__':

    # =============================================================================
    #     compare ADMM and LiRPA output bounds on the first 15 CIFAR-10 test images
    # =============================================================================

    # Use the GPU when available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 10
    num_examples = 15

    resnet18 = _load_resnet18(device, num_classes)
    cur_dir = os.path.dirname(os.path.abspath(__file__))

    for example_num in tqdm(range(num_examples)):
        _compare_bounds_for_example(resnet18, example_num, cur_dir, device, num_classes)
