
import torch
import torchvision

from ADMM.admm import *
from ADMM.nn_models_lib import *
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import time
import gc
from tqdm import tqdm

torch.set_grad_enabled(False)
#
import sys

'''change this to the LiRPA installation directory'''
sys.path.append(r'xxxxx')

from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm


class MyModel(nn.Module):
    """Truncated view of a ResNet-18-style network.

    The wrapped ``resnet`` is evaluated only up to the point selected by
    ``output_layer_num`` and the result is flattened to ``(batch, -1)``:

    * ``0``  -- output of the stem ``conv1`` + ``bn1`` (before the ReLU).
    * ``1`` .. ``16`` -- a point inside one of the residual stages. With
      ``k = output_layer_num - 1``: ``k // 4`` selects the stage
      (``layer1`` .. ``layer4``), ``(k % 4) // 2`` selects the block inside
      that stage and ``k % 2`` selects the point inside that block:
      ``0`` is the output of the block's ``bn1`` (before its ReLU), ``1`` is
      the output of ``bn2`` plus the shortcut (before the block's final ReLU).
    * ``17`` -- the full network output, i.e. the result of ``resnet.fc``.

    Note that the ``17`` branch feeds the flattened ``layer4`` output directly
    into ``resnet.fc``; ``resnet.avgpool`` is never called.

    Args:
        resnet: network exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1`` .. ``layer4`` and ``fc``; every stage must be indexable
            and hold blocks with ``conv1``, ``bn1``, ``relu``, ``conv2``,
            ``bn2`` and ``downsample`` attributes.
        output_layer_num: truncation index in ``[0, 17]`` as described above.
    """

    # Index that selects the complete network (output of ``resnet.fc``).
    FINAL_LAYER_INDEX = 17

    def __init__(self, resnet, output_layer_num = 17):
        super().__init__()
        self.resnet = resnet
        self.output_layer_num = output_layer_num

    @staticmethod
    def _flatten(x):
        """Collapse every dimension except the batch dimension."""
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Run ``x`` through the network up to ``self.output_layer_num``.

        Returns:
            The selected activation flattened to ``(batch, -1)``.

        Raises:
            ValueError: if ``self.output_layer_num`` is outside ``[0, 17]``.
        """
        output_layer_num = self.output_layer_num
        # Without this check a negative index would silently wrap around
        # (Python negative indexing) and return activations of the wrong layer.
        if not 0 <= output_layer_num <= self.FINAL_LAYER_INDEX:
            raise ValueError(
                'output_layer_num must be in [0, %d], got %r'
                % (self.FINAL_LAYER_INDEX, output_layer_num)
            )

        resnet = self.resnet

        # Stem: every truncation point starts with conv1 + bn1.
        x = resnet.conv1(x)
        x = resnet.bn1(x)
        if output_layer_num == 0:
            return self._flatten(x)

        x = resnet.relu(x)
        x = resnet.maxpool(x)

        stages = [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]

        if output_layer_num == self.FINAL_LAYER_INDEX:
            for stage in stages:
                x = stage(x)
            x = resnet.fc(self._flatten(x))
            return self._flatten(x)

        # Decode the index into (stage, block inside stage, point inside block).
        stage_idx, within_stage = divmod(output_layer_num - 1, 4)
        block_idx, point_in_block = divmod(within_stage, 2)

        # Run all complete stages that precede the target stage.
        for stage in stages[:stage_idx]:
            x = stage(x)

        # Run all complete blocks that precede the target block.
        target_stage = stages[stage_idx]
        for j in range(block_idx):
            x = target_stage[j](x)

        block = target_stage[block_idx]

        if point_in_block == 0:
            # Stop right after the block's first batch norm.
            x = block.conv1(x)
            x = block.bn1(x)
            return self._flatten(x)

        # Stop after the residual addition, before the block's final ReLU.
        identity = x

        out = block.conv1(x)
        out = block.bn1(out)
        out = block.relu(out)

        out = block.conv2(out)
        out = block.bn2(out)

        if block.downsample is not None:
            identity = block.downsample(x)

        out += identity
        return self._flatten(out)

def _load_cifar10_resnet18(checkpoint_path, num_classes=10):
    """Build the modified ResNet-18 and load its weights from ``checkpoint_path``.

    The max-pool layer is replaced by a 1x1 stride-2 convolution before the
    state dict is loaded, and afterwards the classifier is cut down to its
    first ``num_classes`` output rows. The returned model is in eval mode.
    """
    resnet18 = models.resnet18()
    resnet18.maxpool = nn.Conv2d(64, 64, kernel_size=1, stride=2, padding=0, bias=False)
    # map_location keeps the load working on machines without CUDA; this
    # script runs everything on the CPU anyway.
    resnet18.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    old_fc = resnet18.fc
    new_fc = nn.Linear(old_fc.weight.size(1), num_classes)
    new_fc.weight.data = old_fc.weight[:num_classes, :]
    new_fc.bias.data = old_fc.bias[:num_classes]
    resnet18.fc = new_fc
    resnet18.eval()
    return resnet18


def _lirpa_bounds(nn_model, images, eps_input):
    """Return ``(lb, ub)`` computed by auto_LiRPA (``method='backward'``).

    ``eps_input`` is the element-wise radius of the L-inf perturbation that is
    applied around ``images``. The bounds are sanity-checked against both the
    plain forward pass and the forward pass of the bounded module.
    """
    reference = nn_model(images)

    # Wrap the model with auto_LiRPA
    model = BoundedModule(nn_model, images)
    ptb = PerturbationLpNorm(norm=np.inf, eps=eps_input)

    # Make the input a BoundedTensor with perturbation
    my_input = BoundedTensor(images, ptb)

    # Regular forward propagation using BoundedTensor works as usual.
    prediction = model(my_input)
    # Compute LiRPA bounds
    lb, ub = model.compute_bounds(x=(my_input,), method='backward')

    flat_lb = lb.view(lb.size(0), -1)
    flat_ub = ub.view(ub.size(0), -1)
    assert torch.all(prediction > flat_lb) & torch.all(prediction < flat_ub), \
        'bounded-module prediction lies outside the LiRPA bounds'
    assert torch.all(reference > flat_lb) & torch.all(reference < flat_ub), \
        'plain forward pass lies outside the LiRPA bounds'
    return lb, ub


def _is_certified_robust(lb, ub, class_num, tol=1e-6):
    """Return True if ``lb[class_num]`` exceeds every other ``ub`` by >= ``tol``.

    ``lb`` and ``ub`` are 1-D tensors holding the output bounds of one sample.
    """
    margins = lb[class_num] - ub
    # exclude the prediction class
    margins[class_num] = float('inf')
    return margins.min().item() >= tol


if __name__ == '__main__':

    # NOTE(review): `os`, `np` and `cifar10_loaders` are not imported explicitly
    # in this file; presumably they come from the `from ADMM... import *`
    # lines -- confirm.

    # =============================================================================
    #     test and plot on a small batch
    # =============================================================================

    # load the resnet
    resnet18 = _load_cifar10_resnet18('resnet18.pth')

    batch_size = 1
    num_examples = 15

    # Build the loader once and walk it in a single pass instead of recreating
    # it and skipping `example_num` batches for every example.
    # NOTE(review): assumes the test loader yields batches in a fixed order, so
    # example k is still the k-th test batch -- confirm.
    test_loader, cifar10_std = cifar10_loaders(batch_size)

    # Define perturbation: the pixel-space radius divided by the per-channel std.
    eps = 1/255
    eps_normalized = eps / torch.tensor(cifar10_std)
    eps_input = eps_normalized.reshape(3, 1, 1).repeat(batch_size, 1, 1, 1)

    # Shape that the bounds of truncation point i (= output_layer_num) are
    # reshaped to; the last entry belongs to the network output.
    size_list = [[batch_size, 64, 16, 16]] + [[batch_size, 64, 8, 8]]*4 + [[batch_size, 128, 4, 4]]*4 + [[batch_size, 256, 2, 2]]*4 + [[batch_size, 512, 1,1]]*4 + [[batch_size, 10]]

    cur_dir = os.path.dirname(os.path.abspath(__file__))

    # =============================================================================
    #     find LiRPA bounds
    # =============================================================================
    accuracy_list = []
    for example_num, data in tqdm(zip(range(num_examples), test_loader),
                                  total=num_examples, desc = 'outer_loop'):

        # create directory to save the data
        id_str = 'id_' + str(example_num)
        save_dir = os.path.join(cur_dir, 'ADMM_LiRPA_comparison', id_str)
        os.makedirs(save_dir, exist_ok=True)

        # test_images = data[0].cuda()
        test_images = data[0]

        # extract lbs and ubs
        bds_dict = {'lb': [], 'ub': []}
        output_bds = {'lb': 0, 'ub': 0}

        for i, out_shape in enumerate(size_list):

            print(str(i) + ' iteration \n')

            nn_model = MyModel(resnet18, i)
            nn_model.eval()

            lb, ub = _lirpa_bounds(nn_model, test_images, eps_input)

            if i < len(size_list) - 1:
                bds_dict['lb'].append(lb.reshape(out_shape))
                bds_dict['ub'].append(ub.reshape(out_shape))
            else:
                output_bds['lb'] = lb.reshape(out_shape)
                output_bds['ub'] = ub.reshape(out_shape)

        bds_dir = os.path.join(save_dir, 'LiRPA_intermediate_bds_id_' + str(example_num) + '_eps_1.pt')
        output_bds_dir = os.path.join(save_dir, 'LiRPA_output_bds_id_'  + str(example_num) + '_eps_1.pt')
        torch.save(bds_dict, bds_dir)
        torch.save(output_bds, output_bds_dir)


        # =============================================================================
        #     verify robustness through LiRPA
        # =============================================================================
        # find objective
        resnet18.eval()
        score = resnet18(test_images)
        _, class_num = torch.max(score, dim=1)
        class_num = class_num.item()

        # Use the bounds that were just computed; reloading the file that was
        # written a few lines above would give the same tensors.
        LiRPA_lb = output_bds['lb'][0].to(torch.device('cpu'))
        LiRPA_ub = output_bds['ub'][0].to(torch.device('cpu'))
        LiRPA_robust = 1 if _is_certified_robust(LiRPA_lb, LiRPA_ub, class_num) else 0

        accuracy_list.append(LiRPA_robust)
        # Printed after the append so the result of the last example shows up too.
        print(accuracy_list)


        # =============================================================================
        #     verify robustness through ADMM
        # =============================================================================
        # x = test_images.cuda()
        # resnet18.eval()
        # resnet18.cuda()
        #
        # eps_normalized = eps/torch.tensor(cifar10_std)
        # eps_normalized = eps_normalized.reshape(3, 1, 1)
        # eps_normalized = eps_normalized.to(x.device)
        #
        # x0_lb = x - eps_normalized
        # x0_ub = x + eps_normalized
        #
        # obj_options = {'rho': 1.0, 'c': 0}
        # resnet18_nn_sections = resnet18_act_decomposition(resnet18)
        # layer_sections = generate_layer_sections(resnet18_nn_sections, x, x0_lb, x0_ub, obj_options)
        # output_dim = layer_sections[-1].output.size()
        #
        # num_classes = 10
        # verify_dim = [num_classes] + list(output_dim)[1:]
        # input_x = x.repeat(num_classes, 1, 1, 1)
        # lb = x0_lb.repeat(num_classes, 1, 1, 1)
        # ub = x0_ub.repeat(num_classes, 1, 1, 1)
        #
        # num_samples = input_x.size(0)
        #
        # # resnet18.eval()
        # # score = resnet18(test_images.cuda())
        # # _, class_num = torch.max(score, dim=1)
        # #
        # # class_num = class_num.item()
        #
        # c = -torch.eye(num_samples)
        # c[:, class_num] += 1.0
        #
        # c = c.to(next(resnet18.parameters()).device)
        #
        # alg_options = {'rho': 0.1, 'eps_abs': 1e-5, 'eps_rel': 1e-4, 'residual_balancing': True, 'max_iter': 2000,
        #                'record': True, 'verbose': True}
        #
        # pre_act_bds_dir = bds_dir
        # pre_act_bds = torch.load(pre_act_bds_dir)
        # pre_act_bds_list = [{'lb':[pre_act_bds['lb'][i]], 'ub':[pre_act_bds['ub'][i]]} for i in range(len(pre_act_bds['lb']))]
        #
        # sol, ADMM_sess = ADMM_custom_objective(resnet18, input_x, lb, ub, c, alg_options, pre_act_bds=pre_act_bds_list)
        # gc.collect()
        #
        # ADMM_sol_dir = os.path.join(cur_dir, 'ADMM_LiRPA_comparison', id_str, 'ADMM_robustness_sol_id_' + str(example_num) + '_eps_0dot8.pt')
        # torch.save(sol, ADMM_sol_dir)
        #
        # obj = sol['obj']
        # obj[class_num] = 1.0
        # if obj.min() >= 1e-6:
        #     ADMM_robust = 1
        # else:
        #     ADMM_robust = 0
        #
        # comparison_result = {'LiRPA': LiRPA_robust, 'ADMM': ADMM_robust}
        # result_dir = os.path.join(cur_dir, 'ADMM_LiRPA_comparison', id_str, 'comparison_result_id_' + str(example_num) + '_eps_0dot8.pt')
        # torch.save(comparison_result, result_dir)
