
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import warnings
warnings.simplefilter("always")

import numpy as np
import matplotlib.pyplot as plt
import time
from .fft_convolutions import FFTConv2d

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable

from torch.utils.data import DataLoader, Subset
import pickle

from tqdm import tqdm
torch.set_grad_enabled(False)

import gc
import os.path

from .admm_layers import *

# =============================================================================
# load neural network models
# =============================================================================

def mnist_loaders(batch_size, shuffle_test=False):
    """Build MNIST train/test ``DataLoader``s backed by ``./data``.

    The dataset is downloaded on first use and converted with ``ToTensor``.
    The training loader always shuffles; the test loader shuffles only when
    ``shuffle_test`` is true. Both loaders use pinned memory.

    Returns:
        ``(train_loader, test_loader)``
    """
    def _build(is_train, shuffle):
        # one dataset + loader pair for the requested split
        split = datasets.MNIST("./data", train=is_train, download=True, transform=transforms.ToTensor())
        return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

    train_loader = _build(True, True)
    test_loader = _build(False, shuffle_test)
    return train_loader, test_loader


def cifar10_loaders(batch_size, shuffle_test=False):
    """Build the CIFAR-10 *test* ``DataLoader`` with per-channel normalization.

    Only the test split is loaded (downloaded to ``./data`` on first use).
    Images are converted with ``ToTensor`` and normalized with hard-coded
    CIFAR-10 channel statistics.

    Returns:
        ``(testloader, cifar10_std)``. The std tuple is returned alongside the
        loader, presumably so callers can express input bounds in normalized
        units (TODO confirm against callers).
    """
    # data set normalization parameters are hard coded
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2471, 0.2435, 0.2616)

    pipeline = [transforms.ToTensor(), transforms.Normalize(cifar10_mean, cifar10_std)]
    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transforms.Compose(pipeline))

    loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=shuffle_test, pin_memory=True)
    return loader, cifar10_std


# =============================================================================
# Define the ADMM layers
# =============================================================================

class ADMMBase():
    """Interface shared by the ADMM building blocks.

    Concrete subclasses implement the three ADMM steps; every hook defined here
    only raises ``NotImplementedError``.
    """

    def update_x(self, input_z_pre, input_mu_pre):
        """x-update, given the preceding block's ``z`` and ``mu``."""
        raise NotImplementedError

    def update_yz(self, input_x_next):
        """Joint (y, z) projection, given the following block's ``x``."""
        raise NotImplementedError

    def update_dual(self, x_next):
        """Dual-variable update, given the following block's ``x``."""
        raise NotImplementedError


class ADMMModule(ADMMBase):
    """Common base class of the concrete ADMM layer types (input, hidden and output layers)."""

class ADMMLayer(ADMMModule):
    """One network layer inside the ADMM splitting.

    Variables (tensors whose first dimension is the batch):

    * ``x`` -- copy of the layer input, tied to ``y`` by ``x = y`` (scaled dual ``lam``);
    * ``y``/``z`` -- input/output pair kept on the layer's graph ``z = layer(y)`` by
      :meth:`project_yz` (for ReLU: on its convex relaxation between ``lb`` and ``ub``);
    * ``z`` is tied to the next layer's ``x`` by ``x_next = z`` (scaled dual ``mu``).

    ``label`` records the layer type and selects the projection. Unsupported layer
    types get ``'undecided'``; projecting such a layer raises ``ValueError``.
    """

    def __init__(self, x_init, y_init, z_init, layer, lb=None, ub=None):
        self.x = x_init
        self.y = y_init
        self.z = z_init

        self.layer = layer

        self.num_batches = self.x.size(0)

        # Bounds on y; within this class only the ReLU projection reads them.
        self.lb = lb
        self.ub = ub

        # Scaled dual variables of the constraints x = y and x_next = z.
        self.lam = torch.zeros_like(self.x)
        self.mu = torch.zeros_like(self.z)

        self.label = self._layer_label(layer)
        self._cache_layer_data(layer)

    @staticmethod
    def _layer_label(layer):
        """Return the string tag that selects the projection used for ``layer``."""
        if isinstance(layer, nn.Linear):
            return 'linear'
        if isinstance(layer, nn.ReLU):
            return 'relu'
        if isinstance(layer, nn.Conv2d):
            return 'conv'
        if isinstance(layer, Down_Sampling):
            return 'downsampling'
        if isinstance(layer, FFT_Padding):
            return 'padding'
        if isinstance(layer, FFT_Cropping):
            return 'cropping'
        if isinstance(layer, Bias):
            return 'bias'
        if isinstance(layer, Conv_Post_Processing):
            return 'conv_post_processing'
        if isinstance(layer, nn.AdaptiveAvgPool2d):
            # avg_pooling_proj reduces (N, C, H, W) to (N, C, 1, 1), so only global
            # pooling is supported. output_size may be an int or a pair (entries may be None).
            out_size = layer.output_size
            if not isinstance(out_size, (tuple, list)):
                out_size = (out_size, out_size)
            if any(s != 1 for s in out_size):
                raise ValueError(
                    f"only AdaptiveAvgPool2d with a 1x1 output is supported, got {layer.output_size!r}")
            return 'avg_pooling'
        if isinstance(layer, nn.BatchNorm2d):
            return 'batch_norm'
        if isinstance(layer, Identity):
            return 'identity'
        return 'undecided'

    def _cache_layer_data(self, layer):
        """Store the parameters/helpers that project_yz needs for this layer type."""
        if isinstance(layer, nn.Linear):
            self.W = layer.weight
            # nn.Linear(bias=False) stores None; substitute an explicit zero bias.
            self.b = layer.bias if layer.bias is not None else self.W.new_zeros(self.W.size(0))
            eye = torch.eye(self.W.size(1), dtype=self.W.dtype, device=self.W.device)
            # (I + W^T W)^{-1} is reused by every linear projection.
            self.inv_mat = torch.inverse(eye + self.W.t().mm(self.W))

        elif isinstance(layer, nn.Conv2d):
            self.fft_conv = FFTConv2d(layer, self.x.size())

        elif isinstance(layer, Down_Sampling):
            self.stride = layer.stride

        elif isinstance(layer, FFT_Padding):
            self.conv = layer.conv

        elif isinstance(layer, Bias):
            self.b = layer.b

        elif isinstance(layer, Conv_Post_Processing):
            self.conv = layer.conv
            self.b = layer.b
            self.stride = layer.stride
            self.x_init = layer.X
            self.output_init = layer.output_init

        elif isinstance(layer, nn.BatchNorm2d):
            self.bn_layer = layer

    def update_x(self, z_pre, mu_pre):
        """x-update: average of ``y - lam`` and ``z_pre - mu_pre`` (reshaped to ``y``'s shape)."""
        shape = self.y.size()
        self.x = 1 / 2 * (self.y - self.lam + z_pre.reshape(shape) - mu_pre.reshape(shape))

    def update_yz(self, x_next):
        """Project ``(x + lam, x_next + mu)`` onto the layer's constraint set.

        Returns:
            ``(y_new - y_old, z_new - z_old)``, used for the dual residual.
        """
        y_old = self.y
        z_old = self.z

        y_before_proj = self.x + self.lam
        z_before_proj = x_next.reshape(self.mu.size()) + self.mu

        self.y, self.z = self.project_yz(y_before_proj, z_before_proj)

        return self.y - y_old, self.z - z_old

    def update_dual(self, x_next):
        """Scaled dual ascent on ``x = y`` (lam) and ``x_next = z`` (mu)."""
        self.lam = self.lam + self.x - self.y
        self.mu = self.mu + x_next.reshape(self.mu.size()) - self.z

    def project_yz(self, y_before_proj, z_before_proj):
        """Project ``(y_before_proj, z_before_proj)`` onto the constraint set of this layer.

        Raises:
            ValueError: if the layer type is unsupported (e.g. label ``'undecided'``).
        """
        label = self.label
        if label == 'linear':
            # Closed-form projection onto {z = W y + b} (row-vector convention):
            #   y = (y0 + (z0 - b) W) (I + W^T W)^{-1},   z = y W^T + b
            b = self.b.reshape(1, -1)  # broadcasts over the batch dimension
            y_proj = (y_before_proj + (z_before_proj - b).matmul(self.W)).matmul(self.inv_mat)
            z_proj = y_proj.matmul(self.W.t()) + b

        elif label == 'relu':
            y_proj, z_proj = relu_projection(y_before_proj, z_before_proj, self.lb, self.ub)

        elif label == 'conv':
            y_proj, z_proj = conv_proj(y_before_proj, z_before_proj, self.fft_conv)

        elif label == 'downsampling':
            y_proj, z_proj = down_sampling_proj(y_before_proj, z_before_proj, self.stride)

        elif label == 'padding':
            y_proj, z_proj = fft_padding_proj(y_before_proj, z_before_proj, self.conv)

        elif label == 'cropping':
            y_proj, z_proj = fft_cropping_proj(y_before_proj, z_before_proj)

        elif label == 'bias':
            y_proj, z_proj = bias_proj(y_before_proj, z_before_proj, self.b)

        elif label == 'conv_post_processing':
            y_proj, z_proj = conv_post_processing_proj(y_before_proj, z_before_proj, self.output_init,
                                                       self.stride, self.b)
        elif label == 'batch_norm':
            y_proj, z_proj = batch_norm_proj(y_before_proj, z_before_proj, self.bn_layer)

        elif label == 'avg_pooling':
            y_proj, z_proj = avg_pooling_proj(y_before_proj, z_before_proj)

        elif label == 'identity':
            y_proj, z_proj = identity_proj(y_before_proj, z_before_proj)

        else:
            raise ValueError(f"no (y, z) projection is defined for layer label {label!r}")

        return y_proj, z_proj


class ADMMInput(ADMMLayer):
    """First layer of the network; its x-update is a projection onto an input box.

    ``lb``/``ub`` are mandatory here and act as the default box for :meth:`update_x`.
    """

    def __init__(self, x_init, y_init, z_init, layer, lb, ub):
        super().__init__(x_init, y_init, z_init, layer, lb, ub)

    def update_x(self, x0_lb = None, x0_ub = None):
        """Set ``x`` to ``y - lam`` clipped elementwise into ``[x0_lb, x0_ub]``.

        Missing bounds fall back to ``self.lb`` / ``self.ub``. The upper bound is
        applied first, so the lower bound wins wherever ``x0_lb > x0_ub``.
        """
        lower = self.lb if x0_lb is None else x0_lb
        upper = self.ub if x0_ub is None else x0_ub

        unclipped = self.y - self.lam
        self.x = torch.max(torch.min(unclipped, upper), lower)


class ADMMOutput(ADMMModule):
    """Terminal node of the splitting, carrying the linear objective ``c``.

    ``c`` must have a leading batch dimension plus at least one more dimension.
    :meth:`update_x` is the closed-form minimiser of
    ``<c, x> + (rho / 2) * ||x - (z_pre - mu_pre)||^2`` with one ``rho`` per batch element.
    """

    def __init__(self, x_init, c, rho, lb = None, ub = None):
        self.x = x_init
        self.c = c
        self.lb = lb
        self.ub = ub
        # One penalty per batch element (a tensor so it could be adapted per sample).
        # NOTE: adaptive rho is not used yet.
        self.rho = rho * torch.ones(c.size(0), device=c.device)

    def update_x(self, z_pre, mu_pre):
        """x-update: ``x = z_pre - mu_pre - c / rho`` with ``rho`` broadcast per sample.

        Raises:
            ValueError: if ``c`` has no dimension besides the batch dimension.
        """
        dim = self.c.dim()
        if dim < 2:
            raise ValueError(f"c must have at least 2 dimensions (batch + features), got {dim}")

        # Append trailing singleton dims so the per-sample rho broadcasts against c.
        rho = self.rho.reshape(tuple(self.rho.shape) + (1,) * (dim - 1))

        self.x = -1 / rho * self.c + z_pre - mu_pre


# =============================================================================
#     projection functions
# =============================================================================
def relu_projection(y_0, z_0, y_min, y_max):
    """Elementwise projection onto the ReLU relaxation for tensors of any shape.

    Flattens the four inputs, delegates to ``relu_proj_cvx`` and restores the shape
    of ``y_0``. ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. sliced bound tensors) are accepted.

    Returns:
        ``(y_proj, z_proj)`` with the shape of ``y_0``.
    """
    shape = y_0.size()
    yp, zp = relu_proj_cvx(y_0.reshape(-1), z_0.reshape(-1), y_min.reshape(-1), y_max.reshape(-1))
    return yp.reshape(shape), zp.reshape(shape)


def relu_proj_cvx(y_before_proj, z_before_proj, y_lb, y_ub):
    """Project points (y, z) onto the convex relaxation of the ReLU graph, elementwise.

    For entry i the target set is the convex hull of
    {(y, max(y, 0)) : y_lb[i] <= y <= y_ub[i]}, handled in three cases:

    1. ``y_lb >= 0``: the segment z = y on [y_lb, y_ub];
    2. ``y_ub <= 0``: the segment z = 0 on [y_lb, y_ub];
    3. ``y_lb < 0 < y_ub``: the triangle with vertices (y_lb, 0), (0, 0), (y_ub, y_ub).

    All four arguments must have the same shape (``relu_projection`` passes flattened
    1-D tensors). Entries with ``y_lb == y_ub == 0`` belong to both case 1 and case 2;
    case 2 runs second and overwrites them (both give (0, 0)).

    Returns:
        ``(y_proj, z_proj)``: new tensors of the input shape; the inputs are not modified.
    """
    # projected vectors
    y_proj = torch.zeros_like(y_before_proj)
    z_proj = torch.zeros_like(z_before_proj)

    # bounds on z = relu(y)
    z_lb = torch.max(y_lb, torch.zeros_like(y_lb))
    z_ub = torch.max(y_ub, torch.zeros_like(y_ub))

    # case 1: y_lb >= 0 -> segment z = y; project onto the line, then clamp to the segment
    indices = (y_lb >= 0)

    mid = (y_before_proj[indices] + z_before_proj[indices])/2
    mid = torch.min(torch.max(mid, y_lb[indices]), y_ub[indices])

    y_proj[indices] = mid
    z_proj[indices] = mid

    # case 2: y_ub <= 0 -> segment z = 0; clamp y and zero z
    indices = (y_ub <= 0)
    mid = torch.min(torch.max(y_before_proj[indices], y_lb[indices]), y_ub[indices])
    # (redundant: mid is already clamped by the previous line)
    mid = torch.min(torch.max(mid, y_lb[indices]), y_ub[indices])

    y_proj[indices] = mid
    z_proj[indices] = 0

    # case 3: y_lb < 0 < y_ub -> triangle; multiplying the two boolean masks is an elementwise AND
    indices = (y_lb < 0) * (y_ub > 0)
    if torch.any(indices):
        y0 = y_before_proj[indices]
        z0 = z_before_proj[indices]

        y_min = y_lb[indices]
        y_max = y_ub[indices]
        z_min = z_lb[indices]
        z_max = z_ub[indices]

        N = y0.size(0)

        # candidate projections onto the three triangle edges (one column per edge)
        yp = y0.new_zeros(N, 3)
        zp = y0.new_zeros(N, 3)

        dist = y0.new_zeros(N, 3)

        # project onto z = y, 0 <= y <= y_max
        yp[:, 0] = torch.min(torch.max(0.5 * (y0 + z0), torch.zeros_like(y0)), y_max)
        zp[:, 0] = yp[:, 0]

        dist[:, 0] = 0.5 * (yp[:, 0] - y0) ** 2 + 0.5 * (zp[:, 0] - z0) ** 2

        # project onto z = 0, y_min <= y <= 0
        yp[:, 1] = torch.min(torch.max(y0, y_min), torch.zeros_like(y0))
        zp[:, 1] = 0.0

        dist[:, 1] = 0.5 * (yp[:, 1] - y0) ** 2 + 0.5 * (zp[:, 1] - z0) ** 2

        # project onto the upper chord z = z_min + s*(y - y_min) through (y_min, z_min) and
        # (y_max, z_max), y_min <= y <= y_max; the unclamped y solves the 1-D least squares problem
        s = (z_max - z_min) / (y_max - y_min)

        yp[:, 2] = torch.min(torch.max((y0 + s ** 2 * y_min - s * (z_min - z0)) / (s ** 2 + 1), y_min), y_max)
        zp[:, 2] = z_min + s * (yp[:, 2] - y_min)

        dist[:, 2] = 0.5 * (yp[:, 2] - y0) ** 2 + 0.5 * (zp[:, 2] - z0) ** 2

        # keep the closest of the three edge projections
        index = torch.argmin(dist, dim=1)

        yp = yp[torch.arange(0, N), index]
        zp = zp[torch.arange(0, N), index]

        # points already inside the triangle (z >= y, z >= 0, below the chord) are their own projection
        index = (z0 >= y0) * (z0 >= 0) * (z0 <= z_min + s * (y0 - y_min))

        yp[index] = y0[index]
        zp[index] = z0[index]

        # plug in the values
        y_proj[indices] = yp
        z_proj[indices] = zp

    return y_proj, z_proj


def down_sampling_proj(y_before_proj, z_before_proj, stride):
    """Project onto ``{z = y[:, :, ::stride, ::stride]}``.

    The sampled pixels of ``y`` and all of ``z`` are replaced by their average; the
    remaining pixels of ``y`` are left as they are. The inputs are not modified.
    """
    grid = (slice(None), slice(None), slice(None, None, stride), slice(None, None, stride))
    averaged = (y_before_proj[grid] + z_before_proj) / 2

    y_proj = y_before_proj.clone()
    y_proj[grid] = averaged
    return y_proj, averaged


def fft_padding_proj(y_before_proj, z_before_proj, conv):
    """Project onto ``{z = zero-pad(y)}`` for the input padding of the FFT convolution.

    ``find_input_padding_tuple(conv)`` yields 4 pads; as indexed here they are
    (left, right) along the last dim and (top, bottom) along dim 2.
    The interior of ``z`` and ``y`` are averaged; the padded border of ``z`` is zeroed.

    Returns:
        ``(y_proj, z_proj)``; ``z_proj`` has the dtype/device of ``z_before_proj``.
    """
    pad_left, pad_right, pad_top, pad_bottom = find_input_padding_tuple(conv)
    height, width = z_before_proj.shape[-2:]

    # Explicit end indices: a slice end of "-pad" becomes ":-0" (an empty slice)
    # whenever a single pad is zero, which the old form only handled when all were zero.
    interior = (slice(None), slice(None),
                slice(pad_top, height - pad_bottom),
                slice(pad_left, width - pad_right))

    y_proj = (y_before_proj + z_before_proj[interior]) / 2

    z_proj = torch.zeros_like(z_before_proj)
    z_proj[interior] = y_proj

    return y_proj, z_proj


def fft_cropping_proj(y_before_proj, z_before_proj):
    """Project onto ``{z = y[:, :, k:n-k, k:n-k]}`` (centre crop after the FFT convolution).

    ``(n, k)`` come from ``find_output_cropping_tuple``. The cropped window of ``y``
    and all of ``z`` become their average; the rest of ``y`` is unchanged.
    """
    n, k = find_output_cropping_tuple(z_before_proj, y_before_proj)
    window = (slice(None), slice(None), slice(k, n - k), slice(k, n - k))

    z_proj = (y_before_proj[window] + z_before_proj) / 2

    y_proj = y_before_proj.clone()
    y_proj[window] = z_proj
    return y_proj, z_proj


def bias_proj(y_before_proj, z_before_proj, b):
    """Project onto ``{z = y + b}``.

    The midpoint of ``(y0, z0)`` is shifted by ``-b/2`` for ``y`` and ``+b/2`` for ``z``.
    ``generate_bias_tensor`` presumably expands ``b`` to the shape of ``z_before_proj``
    (TODO confirm).
    """
    half_bias = generate_bias_tensor(b, z_before_proj) / 2
    midpoint = (y_before_proj + z_before_proj) / 2
    return midpoint - half_bias, midpoint + half_bias


def conv_proj(y_before_proj, z_before_proj, fft_conv):
    """Project onto the convolution graph ``{z = conv(y)}``.

    Computes ``y = ridge_inverse(y0 + conv^T(z0))`` and ``z = conv(y)``.
    ``fft_conv.ridge_inverse`` presumably applies ``(I + W^T W)^{-1}`` (see FFTConv2d).
    """
    y_proj = fft_conv.ridge_inverse(y_before_proj + fft_conv.t(z_before_proj))
    return y_proj, fft_conv(y_proj)


def conv_post_processing_proj(y_before_proj, z_before_proj, output_init, stride=1, b=None):
    """Project onto ``{z = y[:, :, k:n-k:stride, k:n-k:stride] + b}``.

    Combines the FFT crop, the stride sub-sampling and the bias of a convolution.
    ``(n, k)`` come from ``find_output_cropping_tuple(output_init, y)``. Only the kept
    pixels of ``y`` change; the inputs are not modified.
    """
    n, k = find_output_cropping_tuple(output_init, y_before_proj)
    # cropping then striding is the same as one stepped slice
    kept = (slice(None), slice(None), slice(k, n - k, stride), slice(k, n - k, stride))
    y_kept = y_before_proj[kept]

    bias_tensor = generate_bias_tensor(b, z_before_proj)
    z_proj = (y_kept + z_before_proj + bias_tensor) / 2

    y_proj = y_before_proj.clone()
    y_proj[kept] = (y_kept + z_before_proj - bias_tensor) / 2

    return y_proj, z_proj

def batch_norm_proj(y_before_proj, z_before_proj, m):
    """Project onto the graph of a batch-norm layer evaluated with its running statistics.

    With running statistics the layer is the per-channel affine map ``z = alpha * y + beta``:
    ``alpha = weight / sqrt(running_var + eps)`` and
    ``beta = bias - running_mean * weight / sqrt(running_var + eps)``.
    The projection of ``(y0, z0)`` onto that line is
    ``y = (y0 + alpha * (z0 - beta)) / (1 + alpha^2)``, ``z = alpha * y + beta``.

    Args:
        y_before_proj, z_before_proj: tensors of shape ``(N, C, ...)``.
        m: batch-norm layer with running statistics. ``affine=False`` layers
            (``weight``/``bias`` None) are treated as weight 1 and bias 0.

    Returns:
        ``(y_proj, z_proj)``.
    """
    num_channels = y_before_proj.size(1)
    # (1, C, 1, ..., 1): per-channel coefficients broadcast over batch and spatial dims
    coeff_shape = (1, num_channels) + (1,) * (y_before_proj.dim() - 2)

    inv_std = 1.0 / torch.sqrt(m.running_var + m.eps)
    weight = m.weight if m.weight is not None else torch.ones_like(inv_std)
    bias = m.bias if m.bias is not None else torch.zeros_like(inv_std)

    alpha = (inv_std * weight).reshape(coeff_shape)
    beta = (-m.running_mean * inv_std * weight + bias).reshape(coeff_shape)

    denom = 1 + alpha ** 2
    y_proj = (1.0 / denom) * y_before_proj + (alpha / denom) * z_before_proj - alpha * beta / denom
    z_proj = alpha * y_proj + beta

    return y_proj, z_proj

def avg_pooling_proj(y_before_proj, z_before_proj):
    """Project onto ``{z = spatial mean of y}`` (global average pooling).

    Expects ``y`` of shape ``(N, C, H, W)``; ``z`` is reduced to ``(N, C, 1, 1)``.
    Closed form: every pixel of ``y`` is shifted by ``(z0 - mean(y0)) / (H*W + 1)``.
    """
    n, c, h, w = y_before_proj.size()
    denom = h * w + 1

    mean_before = y_before_proj.mean(dim=(2, 3)).reshape(n, c, 1, 1)
    y_proj = y_before_proj - mean_before / denom + z_before_proj / denom

    z_proj = y_proj.mean(dim=(2, 3)).reshape(n, c, 1, 1)
    return y_proj, z_proj

def identity_proj(y_before_proj, z_before_proj):
    """Project onto ``{z = y}``: both outputs are the midpoint (returned as distinct tensors)."""
    total = y_before_proj + z_before_proj
    y_proj = total / 2
    z_proj = (y_before_proj + z_before_proj) / 2
    return y_proj, z_proj

# =============================================================================
# Define ADMM blocks for ADMM implementation
# =============================================================================

class ADMM_forward_block:
    """A chain of ADMM layers updated in order.

    ``ADMM_section`` is a list of ADMM layers (``ADMMInput``, ``ADMMLayer`` or
    ``ADMMOutput``); the loops below assume only the last entry can be an
    ``ADMMOutput``. The block talks to its neighbours through the first layer's
    ``x``/``y``/``lam`` and the last layer's ``z``/``mu``:

    * if the first layer is not an ``ADMMInput``, its x-update reads ``pre_block.z``/``pre_block.mu``;
    * if the last layer is not an ``ADMMOutput``, its ``z`` is coupled to ``post_block.x``.
    """

    def __init__(self, ADMM_section, pre_block = None, post_block = None, options = None):
        self.section = ADMM_section
        self.pre_block = pre_block
        self.post_block = post_block
        self.options = options  # options to control the ADMM update (stored, not read here)
        self.length = len(ADMM_section)

        # per-layer changes of y and z from the most recent update_yz()
        self.y_diff_list = None
        self.z_diff_list = None

        self.num_batches = self.x.size(0)

    # ---- interface variables read by neighbouring blocks ----
    @property
    def x(self):
        """``x`` of the first layer."""
        return self.section[0].x

    @property
    def y(self):
        """``y`` of the first layer."""
        return self.section[0].y

    @property
    def lam(self):
        """``lam`` of the first layer."""
        return self.section[0].lam

    @property
    def z(self):
        """``z`` of the last layer, or None when the block ends with an ADMMOutput."""
        if not isinstance(self.section[-1], ADMMOutput):
            return self.section[-1].z
        else:
            return None

    @property
    def mu(self):
        """``mu`` of the last layer, or None when the block ends with an ADMMOutput."""
        if not isinstance(self.section[-1], ADMMOutput):
            return self.section[-1].mu
        else:
            return None

    def set_pre_block(self, pre_block):
        self.pre_block = pre_block

    def set_post_block(self, post_block):
        self.post_block = post_block

    def update_x(self):
        """Run the x-update of every layer, first to last.

        The first layer projects onto its input box if it is an ADMMInput, otherwise it
        reads ``pre_block.z``/``pre_block.mu``; later layers read the previous layer's z/mu.
        """
        for i in range(self.length):
            cur_layer = self.section[i]
            if i > 0:
                pre_layer = self.section[i - 1]
                cur_layer.update_x(pre_layer.z, pre_layer.mu)
            elif isinstance(cur_layer, ADMMInput):
                cur_layer.update_x()
            else:
                cur_layer.update_x(self.pre_block.z, self.pre_block.mu)

    def update_yz(self):
        """Run every layer's (y, z)-projection and record the per-layer changes.

        Each layer is coupled to the next layer's ``x``; the last layer (unless it is an
        ADMMOutput, which has no y/z) is coupled to ``post_block.x``.

        Returns:
            ``(y_diff_list, z_diff_list)``, also stored on the block for dual_residual_sq.
        """
        y_diff_list = []
        z_diff_list = []
        for i in range(self.length - 1):
            y_diff, z_diff = self.section[i].update_yz(self.section[i + 1].x)
            y_diff_list.append(y_diff)
            z_diff_list.append(z_diff)

        cur_layer = self.section[-1]
        if not isinstance(cur_layer, ADMMOutput):
            y_diff, z_diff = cur_layer.update_yz(self.post_block.x)
            y_diff_list.append(y_diff)
            z_diff_list.append(z_diff)

        self.y_diff_list = y_diff_list
        self.z_diff_list = z_diff_list

        return y_diff_list, z_diff_list

    def update_dual(self):
        """Run every layer's dual update (the last one uses ``post_block.x``)."""
        for i in range(self.length - 1):
            self.section[i].update_dual(self.section[i + 1].x)

        cur_layer = self.section[-1]
        if not isinstance(cur_layer, ADMMOutput):
            cur_layer.update_dual(self.post_block.x)

    def primal_residual_sq(self):
        """Per-sample squared primal residual: sum of ``||y - x||^2 + ||x_next - z||^2`` over layers."""
        ADMM_section = self.section
        num_batches = ADMM_section[0].num_batches

        rp = 0
        length = len(ADMM_section)
        for i in range(length - 1):
            cur_layer = ADMM_section[i]
            next_layer = ADMM_section[i + 1]
            rp += torch.norm(cur_layer.y.reshape(num_batches, -1) - cur_layer.x.reshape(num_batches, -1), 2, dim=1) ** 2
            rp += torch.norm(next_layer.x.reshape(num_batches, -1) - cur_layer.z.reshape(num_batches, -1), 2, dim=1) ** 2

        if not isinstance(ADMM_section[-1], ADMMOutput):
            cur_layer = ADMM_section[-1]
            next_layer = self.post_block
            rp += torch.norm(cur_layer.y.reshape(num_batches, -1) - cur_layer.x.reshape(num_batches, -1), 2, dim=1) ** 2
            # the last layer's z is coupled to the x of the post block
            rp += torch.norm(next_layer.x.reshape(num_batches, -1) - cur_layer.z.reshape(num_batches, -1), 2, dim=1) ** 2
        return rp

    def dual_residual_sq(self):
        """Per-sample squared dual residual built from the recorded y/z changes.

        Each x contributes ``||z_diff_prev + y_diff||^2``; the first layer uses the
        pre_block's last z change (or only its own y change for an ADMMInput), and an
        ADMMOutput contributes only the previous layer's z change.
        """
        rd = 0.0
        num_batches = self.x.size(0)
        length = len(self.y_diff_list)
        # dual residual from the first layer
        if isinstance(self.section[0], ADMMInput):
            rd += torch.norm(self.y_diff_list[0].reshape(num_batches, -1), 2, dim=1)**2
        else:
            pre_layer = self.pre_block
            rd += torch.norm(pre_layer.z_diff_list[-1].reshape(num_batches, -1) + self.y_diff_list[0].reshape(num_batches, -1), 2, dim=1)**2

        # dual residual from the middle layers
        for i in range(length - 1):
            rd += torch.norm(self.z_diff_list[i].reshape(num_batches, -1) + self.y_diff_list[i + 1].reshape(num_batches, -1), 2, dim=1)**2

        # the output layer's x only appears in x_out = z_prev
        if isinstance(self.section[-1], ADMMOutput):
            rd += torch.norm(self.z_diff_list[-1].reshape(num_batches, -1), 2, dim = 1)**2

        return rd

    def stopping_primal(self):
        """Quantities for the primal stopping threshold.

        Returns:
            ``(norm_A_sq, norm_B_sq, p_count)``: per-sample squared norms of the x-side
            and (y, z)-side terms, and the number of primal constraint entries.
        """
        norm_A_sq = 0.0
        p_count = 0
        for block in self.section:
            x = block.x
            if isinstance(block, (ADMMInput, ADMMOutput)):
                # x appears in a single constraint
                norm_A_sq += torch.norm(x.reshape(x.size(0), -1), 2, dim=1) ** 2
                p_count += x.reshape(x.size(0), -1).size(-1)
            else:
                # x appears in x = y and in x = z_prev
                norm_A_sq += 2*torch.norm(x.reshape(x.size(0), -1), 2, dim=1) ** 2
                p_count += 2*x.reshape(x.size(0), -1).size(-1)

        norm_B_sq = 0.0
        for block in self.section:
            if isinstance(block, ADMMOutput):
                norm_B_sq += 0.0
            else:
                y = block.y
                z = block.z

                norm_B_sq += torch.norm(y.reshape(y.size(0), -1), 2, dim=1) ** 2 + torch.norm(z.reshape(z.size(0), -1), 2, dim=1) ** 2

        return norm_A_sq, norm_B_sq, p_count

    def stopping_dual(self):
        """Quantities for the dual stopping threshold.

        Every x contributes ``||lam + mu_prev||^2`` (its duals for ``x = y`` and
        ``x = z_prev``); an ADMMInput contributes only ``lam`` and an ADMMOutput only
        the previous layer's ``mu``.

        Returns:
            ``(norm_dual_sq, n_count)``: per-sample squared norm and number of x entries.
        """
        n_count = 0
        length = self.length

        for i in range(length):
            x = self.section[i].x
            n_count += x.reshape(x.size(0), -1).size(-1)

        norm_dual_sq = 0.0
        # first layer
        if isinstance(self.section[0], ADMMInput):
            lam = self.section[0].lam
            norm_dual_sq += torch.norm(lam.reshape(lam.size(0), -1), 2, dim = 1)**2
        else:
            lam = self.section[0].lam
            mu = self.pre_block.mu
            dual_sum = lam + mu.reshape(lam.size())
            norm_dual_sq += torch.norm(dual_sum.reshape(dual_sum.size(0), -1), 2, dim = 1)**2

        if length > 1:
            # middle layers
            for i in range(1, length - 1):
                lam = self.section[i].lam
                mu = self.section[i - 1].mu
                dual_sum = lam + mu.reshape(lam.size())
                norm_dual_sq += torch.norm(dual_sum.reshape(dual_sum.size(0), -1), 2, dim=1) ** 2

            # last layer
            if isinstance(self.section[length - 1], ADMMOutput):
                # The output layer has no lam: only the previous layer's mu
                # (dual of x_out = z_prev) contributes.
                mu = self.section[length - 2].mu
                norm_dual_sq += torch.norm(mu.reshape(mu.size(0), -1), 2, dim = 1)**2
            else:
                lam = self.section[length - 1].lam
                mu = self.section[length - 2].mu
                dual_sum = lam + mu.reshape(lam.size())
                norm_dual_sq += torch.norm(dual_sum.reshape(dual_sum.size(0), -1), 2, dim=1) ** 2

        return norm_dual_sq, n_count

    @staticmethod
    def _rescale_in_place(t, incr_idx, incr, decr_idx, decr):
        """Divide ``t[incr_idx]`` by ``incr``, then multiply ``t[decr_idx]`` by ``decr``; return ``t``."""
        t[incr_idx] = t[incr_idx] / incr
        t[decr_idx] = t[decr_idx] * decr
        return t

    def adjust_dual_variable(self, incr_idx, incr, decr_idx, decr):
        """Rescale ``lam``/``mu`` (in place) of every layer that has them.

        Entries selected by ``incr_idx`` are divided by ``incr``; entries selected by
        ``decr_idx`` are multiplied by ``decr``. Presumably called after rho was scaled
        for those batch entries (scaled duals move inversely) -- confirm with the caller.
        """
        for i in range(self.length):
            cur_layer = self.section[i]
            if i == self.length - 1 and isinstance(cur_layer, ADMMOutput):
                continue  # the output layer carries no lam/mu
            cur_layer.lam = self._rescale_in_place(cur_layer.lam, incr_idx, incr, decr_idx, decr)
            cur_layer.mu = self._rescale_in_place(cur_layer.mu, incr_idx, incr, decr_idx, decr)

class ADMM_res_block:
    '''Residual block: the ResNet18 basic block up to the residual summation (the last ReLU excluded).

    It combines two ADMM_forward_blocks that share their input ``x_0`` and an
    ADMM_sum_block that adds their outputs.

    NOTE: the branch names are swapped -- ``main_block`` refers to the residual
    connection and ``res_block`` to the identity/down-sampling connection. This does not
    affect the algorithm.
    '''
    def __init__(self, main_block, res_block, sum_block, pre_block = None, post_block = None, options = None):
        self.x = main_block.section[0].x

        # both branches must start from the same shared input
        assert torch.all(main_block.section[0].x == res_block.section[0].x)
        self.num_batches = self.x.size(0)

        self.main_block = main_block
        self.res_block = res_block

        self.sum_block = sum_block

        self.pre_block = pre_block
        self.post_block = post_block
        self.options = options

        self.y_diff_list = []
        self.z_diff_list = []

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = value

    @property
    def z(self):
        """Output ``z`` of the block (the sum node's ``z``)."""
        return self.sum_block.z

    @property
    def mu(self):
        """Output ``mu`` of the block (the sum node's ``mu``)."""
        return self.sum_block.mu

    def set_pre_block(self, pre_block):
        # update_x, dual_residual_sq and stopping_dual read self.pre_block, so keep it in sync
        self.pre_block = pre_block
        self.main_block.pre_block = pre_block
        self.res_block.pre_block = pre_block

    def set_post_block(self, post_block):
        self.post_block = post_block
        self.sum_block.post_block = post_block

    def update_x(self):
        """x-update of the block.

        The shared input ``x_0`` appears in three constraints, so its update averages
        ``z_pre - mu_pre`` and ``y - lam`` of the first layer of each branch. The branches
        and the sum node are updated first; then their first-layer x is overwritten by ``x_0``.
        """
        y_m = self.main_block.y
        lam_m = self.main_block.lam
        y_r = self.res_block.y
        lam_r = self.res_block.lam
        # same reshaping convention as ADMMLayer.update_x
        z_pre = self.pre_block.z.reshape(y_m.size())
        mu_pre = self.pre_block.mu.reshape(y_m.size())

        x_0_new = (z_pre - mu_pre + y_m - lam_m + y_r - lam_r)/3
        self.main_block.update_x()
        self.res_block.update_x()
        self.sum_block.update_x()

        self.x = x_0_new
        self.main_block.section[0].x = x_0_new
        self.res_block.section[0].x = x_0_new

    def update_yz(self):
        """Run the (y, z)-updates of both branches and the sum node.

        Returns:
            ``(y_diff_list, z_diff_list)``, each ``[main, res, sum]``; also stored on the block.
        """
        y_main_diff_list, z_main_diff_list = self.main_block.update_yz()
        y_res_diff_list, z_res_diff_list = self.res_block.update_yz()
        y_sum_diff_list, z_sum_diff_list = self.sum_block.update_yz()

        self.y_diff_list = [y_main_diff_list, y_res_diff_list, y_sum_diff_list]
        self.z_diff_list = [z_main_diff_list, z_res_diff_list, z_sum_diff_list]

        return list(self.y_diff_list), list(self.z_diff_list)

    def update_dual(self):
        self.main_block.update_dual()
        self.res_block.update_dual()
        self.sum_block.update_dual()

    def primal_residual_sq(self):
        """Per-sample squared primal residual: sum over both branches and the sum node."""
        rp_sq_main = self.main_block.primal_residual_sq()
        rp_sq_res = self.res_block.primal_residual_sq()
        rp_sq_sum = self.sum_block.primal_residual_sq()

        return rp_sq_main + rp_sq_res + rp_sq_sum

    def dual_residual_sq(self):
        """Per-sample squared dual residual of the block.

        Each branch counted ``||y_branch_diff + z_pre_diff||^2`` for the shared ``x_0``;
        its true term is ``||y_m_diff + y_r_diff + z_pre_diff||^2``, so the two branch
        terms are swapped for the combined one.
        """
        rd_sq_main = self.main_block.dual_residual_sq()
        rd_sq_res = self.res_block.dual_residual_sq()
        rd_sq_sum = self.sum_block.dual_residual_sq()

        # correction term
        z_pre_diff = self.pre_block.z_diff_list[-1]
        y_m_diff = self.main_block.y_diff_list[0]
        y_r_diff = self.res_block.y_diff_list[0]

        num_batches = self.x.size(0)

        term_1 = torch.norm(y_r_diff.reshape(num_batches, -1) + z_pre_diff.reshape(num_batches, -1) + y_m_diff.reshape(num_batches, -1), 2, dim=1) ** 2
        term_2 = torch.norm(y_m_diff.reshape(num_batches, -1) + z_pre_diff.reshape(num_batches, -1), 2, dim = 1)**2
        term_3 = torch.norm(y_r_diff.reshape(num_batches, -1) + z_pre_diff.reshape(num_batches, -1), 2, dim = 1)**2
        corr = term_1 - term_2 - term_3

        return rd_sq_main + rd_sq_res + rd_sq_sum + corr

    def stopping_primal(self):
        """Primal stopping quantities; the shared ``x_0`` is counted once, not twice."""
        norm_A_sq_m, norm_B_sq_m, p_count_m = self.main_block.stopping_primal()
        norm_A_sq_r, norm_B_sq_r, p_count_r = self.res_block.stopping_primal()
        norm_A_sq_sum, norm_B_sq_sum, p_count_sum = self.sum_block.stopping_primal()

        # remove the duplicate contribution of the shared input
        x = self.res_block.section[0].x
        minus_term_A = torch.norm(x.reshape(x.size(0),-1), 2, dim = 1)**2
        minus_term_p_count = x.reshape(x.size(0), -1).size(-1)

        norm_A_sq_total = norm_A_sq_m + norm_A_sq_r + norm_A_sq_sum - minus_term_A
        norm_B_sq_total = norm_B_sq_m + norm_B_sq_r + norm_B_sq_sum
        p_count = p_count_m + p_count_r + p_count_sum - minus_term_p_count
        return norm_A_sq_total, norm_B_sq_total, p_count

    def stopping_dual(self):
        """Dual stopping quantities of the block.

        Each branch counted ``||lam_branch + mu_pre||^2`` for the shared ``x_0``; its true
        term is ``||lam_m + lam_r + mu_pre||^2``, so the branch terms are replaced.

        Returns:
            ``(dual_sq_total, n_count)`` with ``x_0`` counted once in ``n_count``.
        """
        dual_sq_m, n_count_m = self.main_block.stopping_dual()
        dual_sq_r, n_count_r = self.res_block.stopping_dual()
        dual_sq_sum, n_count_sum = self.sum_block.stopping_dual()

        # the shared x_0 is counted by both branches
        common_x = self.main_block.x
        n_count_corr = common_x.reshape(common_x.size(0), -1).size(1)
        n_count = n_count_m + n_count_r + n_count_sum - n_count_corr

        lam_m = self.main_block.lam
        lam_r = self.res_block.lam
        mu = self.pre_block.mu.reshape(lam_m.size())

        dual_sum_1 = lam_m + lam_r + mu
        term_1 = torch.norm(dual_sum_1.reshape(dual_sum_1.size(0), -1), 2, dim = 1)**2

        dual_sum_2 = lam_m + mu
        term_2 = torch.norm(dual_sum_2.reshape(dual_sum_2.size(0), -1), 2, dim = 1)**2

        dual_sum_3 = lam_r + mu
        term_3 = torch.norm(dual_sum_3.reshape(dual_sum_3.size(0), -1), 2, dim=1) ** 2

        dual_sq_corr = term_1 - term_2 - term_3
        dual_sq_total = dual_sq_m + dual_sq_r + dual_sq_sum + dual_sq_corr

        return dual_sq_total, n_count

    def adjust_dual_variable(self, incr_idx, incr, decr_idx, decr):
        self.main_block.adjust_dual_variable(incr_idx, incr, decr_idx, decr)
        self.res_block.adjust_dual_variable(incr_idx, incr, decr_idx, decr)
        self.sum_block.adjust_dual_variable(incr_idx, incr, decr_idx, decr)


class mini_block:
    """Minimal holder exposing a tensor through an ``x`` attribute.

    Used as the interface between ADMM_sum_block and the main/res blocks inside an
    ADMM_res_block, which read ``post_block.x``.
    """

    def __init__(self, x):
        self.x = x

class ADMM_sum_block:
    '''Sum node of an ADMM_res_block implementing z = x_m + x_r for the outputs of the two passes.

    Variables (tensors with a leading batch dimension):
      x_m, x_r     : copies of the main / residual pass outputs; the x-update couples them to
                     ``pre_block_main.z`` and ``pre_block_res.z``
      y_m, y_r, z  : auxiliary variables kept on the subspace y_m + y_r = z by projection
      lam_m, lam_r : dual variables of x_m = y_m and x_r = y_r
      mu           : dual variable of post_block.x = z
    '''
    def __init__(self, x_m, x_r, pre_block_main, pre_block_res, post_block = None):
        '''
        :param x_m: initial output of the main pass
        :param x_r: initial output of the residual pass
        :param pre_block_main: block producing the main pass output (reads its ``z``, ``mu``, ``z_diff_list``)
        :param pre_block_res: block producing the residual pass output (same attributes)
        :param post_block: block consuming the sum (reads its ``x``); may be attached later
        '''
        self.x_m = x_m
        self.x_r = x_r

        self.y_m = x_m
        self.y_r = x_r

        self.z = x_m + x_r
        self.lam_m = torch.zeros_like(self.x_m)
        self.lam_r = torch.zeros_like(self.x_r)
        self.mu = torch.zeros_like(self.z)

        self.pre_block_main = pre_block_main
        self.pre_block_res = pre_block_res
        self.post_block = post_block

        # NOTE(review): these interface objects are what the main/res blocks see as their post
        # block; update_x rebinds their ``x`` so the new iterates carry through
        x_m_block = mini_block(self.x_m)
        x_r_block = mini_block(self.x_r)
        self.x_m_block = x_m_block
        self.x_r_block = x_r_block

        # differences of the last (y, z) update, consumed by dual_residual_sq
        self.y_m_diff = None
        self.y_r_diff = None
        self.z_diff_list = []

    def update_x(self):
        '''x-update: average the incoming pass output (z - mu of the pre block) with (y - lam).'''
        x_m_new = (self.pre_block_main.z - self.pre_block_main.mu + self.y_m - self.lam_m)/2
        x_r_new = (self.pre_block_res.z - self.pre_block_res.mu + self.y_r - self.lam_r)/2

        self.x_m = x_m_new
        self.x_r = x_r_new

        # keep the interface objects in sync with the new iterates
        self.x_m_block.x = self.x_m
        self.x_r_block.x = self.x_r

    def update_yz(self):
        '''(y, z)-update: project (x_m + lam_m, x_r + lam_r, mu + post_block.x) onto y_m + y_r = z.

        :return: (y_m_diff + y_r_diff, z_diff), the changes of the auxiliary variables
        '''
        old_y_m = self.y_m
        old_y_r = self.y_r
        old_z = self.z

        y_m_before_proj = self.x_m + self.lam_m
        y_r_before_proj = self.x_r + self.lam_r

        # the downstream block may store its input in another shape (as in update_dual and
        # primal_residual_sq); align it with mu before combining
        z_before_proj = self.mu + self.post_block.x.reshape(self.mu.size())

        # orthogonal projection of (a, b, c) onto {a + b = c}: subtract ((a + b - c) / 3) * (1, 1, -1)
        y_m_proj = y_m_before_proj*2/3 - y_r_before_proj/3 + z_before_proj/3
        y_r_proj = -y_m_before_proj/3 + y_r_before_proj*2/3 + z_before_proj/3
        z_proj = y_m_before_proj/3 + y_r_before_proj/3 + z_before_proj*2/3

        self.y_m = y_m_proj
        self.y_r = y_r_proj
        self.z = z_proj

        y_m_diff = y_m_proj - old_y_m
        y_r_diff = y_r_proj - old_y_r
        z_diff = z_proj - old_z

        self.y_m_diff = y_m_diff
        self.y_r_diff = y_r_diff
        self.z_diff_list = [z_diff]

        return y_m_diff + y_r_diff, z_diff

    def update_dual(self):
        '''Dual update: add the current constraint violations to lam_m, lam_r and mu.'''
        lam_m_new = self.lam_m + self.x_m - self.y_m
        lam_r_new = self.lam_r + self.x_r - self.y_r
        mu_new = self.mu + self.post_block.x.reshape(self.mu.size()) - self.z

        self.lam_m = lam_m_new
        self.lam_r = lam_r_new
        self.mu = mu_new

    def primal_residual_sq(self):
        '''Per-example squared primal residual of x_m = y_m, x_r = y_r and z = post_block.x.'''
        num_batches = self.x_m.size(0)
        rp = 0.0
        rp += torch.norm(self.x_m.reshape(num_batches, -1)  - self.y_m.reshape(num_batches, -1), 2, dim = 1)**2
        rp += torch.norm(self.x_r.reshape(num_batches, -1) - self.y_r.reshape(num_batches, -1), 2, dim = 1 )**2

        next_layer = self.post_block
        rp += torch.norm(self.z.reshape(num_batches, -1) - next_layer.x.reshape(num_batches, -1), 2, dim = 1 )**2
        return rp

    def dual_residual_sq(self):
        '''Per-example squared dual residual (without the rho factor) for x_m and x_r.

        Combines the last y differences with the last z differences of the pre blocks; requires
        update_yz to have been called at least once.
        '''
        rd = 0.0
        num_batches = self.x_m.size(0)

        rd += torch.norm(self.y_m_diff.reshape(num_batches, -1) + self.pre_block_main.z_diff_list[-1].reshape(num_batches, -1), 2, dim=1) ** 2
        rd += torch.norm(self.y_r_diff.reshape(num_batches, -1) + self.pre_block_res.z_diff_list[-1].reshape(num_batches, -1), 2, dim=1) ** 2
        return rd

    def stopping_primal(self):
        '''Return (||A x||^2, ||B (y, z)||^2, constraint count) contributions for the primal tolerance.'''
        norm_A_sq = 2*torch.norm(self.x_m.reshape(self.x_m.size(0), -1), 2, dim = 1)**2 + 2*torch.norm(self.x_r.reshape(self.x_r.size(0), -1), 2, dim = 1)**2
        norm_B_sq = torch.norm(self.y_m.reshape(self.y_m.size(0), -1), 2, dim = 1)**2 + torch.norm(self.y_r.reshape(self.y_r.size(0),-1), 2, dim = 1)**2 + torch.norm(self.z.reshape(self.z.size(0),-1), 2, dim = 1)**2
        p_count = 2*self.x_m.reshape(self.x_m.size(0),-1).size(-1) + 2*self.x_r.reshape(self.x_r.size(0), -1).size(-1)
        return norm_A_sq, norm_B_sq, p_count

    def stopping_dual(self):
        '''Return (per-example squared dual norm for x_m and x_r, element count) for the dual tolerance.'''
        norm_dual_sq = 0.0
        lam_r = self.lam_r
        lam_m = self.lam_m
        mu_m = self.pre_block_main.mu
        mu_r = self.pre_block_res.mu

        # x_m is coupled both to y_m (via lam_m) and to the main pass output (via its mu)
        dual_sum_m = lam_m + mu_m.reshape(lam_m.size())
        dual_sum_r = lam_r + mu_r.reshape(lam_r.size())

        norm_dual_sq += torch.norm(dual_sum_m.reshape(dual_sum_m.size(0), -1), 2, dim = 1)**2
        norm_dual_sq += torch.norm(dual_sum_r.reshape(dual_sum_r.size(0), -1), 2, dim = 1)**2

        x_m = self.x_m
        x_r = self.x_r
        n_count = x_m.reshape(x_m.size(0), -1).size(-1)
        n_count += x_r.reshape(x_r.size(0), -1).size(-1)
        return norm_dual_sq, n_count

    def adjust_dual_variable(self, incr_idx, incr, decr_idx, decr):
        '''Rescale the duals after residual balancing (in place on the selected examples).

        Where rho was multiplied by ``incr`` the duals are divided by it, and where rho was
        divided by ``decr`` they are multiplied by it.
        '''
        lam_m = self.lam_m
        lam_r = self.lam_r
        mu = self.mu

        lam_m[incr_idx] = lam_m[incr_idx] / incr
        lam_m[decr_idx] = lam_m[decr_idx] * decr

        lam_r[incr_idx] = lam_r[incr_idx] / incr
        lam_r[decr_idx] = lam_r[decr_idx] * decr

        mu[incr_idx] = mu[incr_idx] / incr
        mu[decr_idx] = mu[decr_idx] * decr

        self.lam_m = lam_m
        self.lam_r = lam_r
        self.mu = mu

class ADMM_session:
    '''ADMM session models a given neural network initialized by a list of ADMM_blocks (either ADMM_forward_block or ADMM_res_block) arranged sequentially.

    The session links the blocks (see ``connect``) and runs the x / (y, z) / dual updates over all
    of them. It sums the per-block residuals and stopping quantities into per-example values.
    The penalty ``rho`` and the objective ``c`` are read from the final ADMMOutput layer.
    '''
    def __init__(self, ADMM_blocks):
        # list of ADMM blocks, sequential connection of forward block and res block
        self.ADMM_session = ADMM_blocks
        self.length = len(ADMM_blocks)
        # the chain must start with the input layer and end with the output layer
        assert isinstance(self.ADMM_session[0].section[0], ADMMInput)
        assert isinstance(self.ADMM_session[-1].section[-1], ADMMOutput)

        self.connect()
        self.y_diff_list = []
        self.z_diff_list = []
        self.num_batches = ADMM_blocks[0].num_batches

        # most recent residuals, set by primal_residual / dual_residual
        self.rp = None
        self.rd = None
        self.primal_stopping_threshold = None
        self.dual_stopping_threshold = None

    @property
    def rho(self):
        '''Penalty parameter stored on the final ADMMOutput layer (None if there is none).'''
        if isinstance( self.ADMM_session[-1].section[-1], ADMMOutput):
            return self.ADMM_session[-1].section[-1].rho
        else:
            return None

    @rho.setter
    def rho(self, value):
        if isinstance(self.ADMM_session[-1].section[-1], ADMMOutput):
            self.ADMM_session[-1].section[-1].rho = value

    @property
    def c(self):
        '''Objective tensor stored on the final ADMMOutput layer (None if there is none).'''
        if isinstance(self.ADMM_session[-1].section[-1], ADMMOutput):
            return self.ADMM_session[-1].section[-1].c
        else:
            return None

    def connect(self):
        '''Set pre/post block links between consecutive blocks (residual blocks also via set_*_block).'''
        for i in range(self.length - 1):
            self.ADMM_session[i+1].pre_block = self.ADMM_session[i]
            if isinstance(self.ADMM_session[i+1], ADMM_res_block):
                self.ADMM_session[i + 1].set_pre_block(self.ADMM_session[i])

        for i in range(self.length - 1):
            self.ADMM_session[i].post_block = self.ADMM_session[i+1]
            if isinstance(self.ADMM_session[i], ADMM_res_block):
                self.ADMM_session[i].set_post_block(self.ADMM_session[i+1])

    def update_x(self):
        for block in self.ADMM_session:
            block.update_x()

    def update_yz(self):
        '''Run every block's (y, z)-update and keep the returned differences.'''
        y_diff_list = []
        z_diff_list = []
        for block in self.ADMM_session:
            y_diff, z_diff = block.update_yz()
            y_diff_list.append(y_diff)
            z_diff_list.append(z_diff)

        self.y_diff_list = y_diff_list
        self.z_diff_list = z_diff_list

    def update_dual(self):
        for block in self.ADMM_session:
            block.update_dual()

    def run(self):
        '''One full ADMM iteration.'''
        self.update_x()
        self.update_yz()
        self.update_dual()

    def primal_residual(self):
        '''Per-example primal residual norm (square root of the summed block contributions).'''
        rp_sq = 0.0
        for block in self.ADMM_session:
            rp_sq += block.primal_residual_sq()

        rp = torch.sqrt(rp_sq)
        self.rp = rp
        return rp

    def dual_residual(self):
        '''Per-example dual residual norm, scaled by the current rho.'''
        rd_sq = 0.0
        for block in self.ADMM_session:
            rd_sq += block.dual_residual_sq()

        rho = self.rho

        rd = rho*torch.sqrt(rd_sq)
        self.rd = rd
        return rd

    def stopping_primal(self, eps_abs = 1e-3, eps_rel = 1e-2):
        '''Primal tolerance sqrt(p) * eps_abs + eps_rel * max(||Ax||, ||B(y,z)||), per example.'''
        session = self.ADMM_session
        norm_A_sq_sum = 0.0
        norm_B_sq_sum = 0.0
        p_count_sum = 0
        for block in session:
            norm_A_sq, norm_B_sq, p_count = block.stopping_primal()
            norm_A_sq_sum += norm_A_sq
            norm_B_sq_sum += norm_B_sq
            p_count_sum += p_count

        max_norm_sq = torch.max(norm_A_sq_sum, norm_B_sq_sum)
        max_norm = torch.sqrt(max_norm_sq)

        primal_threshold = (torch.sqrt(torch.tensor([p_count_sum], dtype=torch.float32).to(max_norm.device)) * eps_abs) + eps_rel * max_norm

        return primal_threshold

    def stopping_dual(self, eps_abs = 1e-3, eps_rel = 1e-2):
        '''Dual tolerance sqrt(n) * eps_abs + eps_rel * rho * ||dual||, per example.'''
        session = self.ADMM_session
        norm_dual_sq_sum = 0.0
        n_count_sum = 0.0

        for block in session:
            norm_dual_sq, n_count = block.stopping_dual()
            norm_dual_sq_sum += norm_dual_sq
            n_count_sum += n_count

        norm_dual = torch.sqrt(norm_dual_sq_sum)

        rho = self.rho

        dual_threshold = torch.sqrt(torch.tensor([n_count_sum], dtype = torch.float32).to(norm_dual.device))*eps_abs + eps_rel*norm_dual*rho

        return dual_threshold

    @staticmethod
    def _balancing_masks(rp, rd, mu):
        '''Return (incr_idx, decr_idx) masks for residual balancing.

        Both tests are strict (Boyd et al. 2011, Sec. 3.4.1): rho grows where rp > mu * rd and
        shrinks where rd > mu * rp. With a non-strict test, rho would shrink whenever both
        residuals are exactly zero.
        '''
        incr_idx = (rp > mu * rd)
        decr_idx = (rd > mu * rp)
        return incr_idx, decr_idx

    def residual_balancing(self, rp = None, rd = None, mu = 10, incr = 2, decr = 2):
        '''Adapt rho per example and rescale the duals accordingly (rho is modified in place).

        :param rp: primal residuals (defaults to the last computed ones)
        :param rd: dual residuals (defaults to the last computed ones)
        :param mu: imbalance ratio that triggers a change of rho
        :param incr: factor rho is multiplied by when the primal residual dominates
        :param decr: factor rho is divided by when the dual residual dominates
        '''
        if rp is None:
            rp = self.rp

        if rd is None:
            rd = self.rd

        rho = self.rho
        incr_idx, decr_idx = self._balancing_masks(rp, rd, mu)
        rho[incr_idx] = rho[incr_idx] * incr
        rho[decr_idx] = rho[decr_idx] / decr

        self.adjust_dual_variable(incr_idx, incr, decr_idx, decr)
        self.rho = rho

    def adjust_dual_variable(self, incr_idx, incr, decr_idx, decr):
        for block in self.ADMM_session:
            block.adjust_dual_variable(incr_idx, incr, decr_idx, decr)

##################################################################
# Define classes that help initialize the ADMM blocks
##################################################################

class Layer_section:
    '''initialize ADMM blocks from a list of neural network layers'''
    def __init__(self, nn_layer_list, x, label = 'middle', lb = None, ub = None, options = None,  pre_act_bds = None):
        '''
        :param nn_layer_list: a list of neural network layers arranged sequentially
        :param x: input tensor
        :param label:
                'middle': does not contain ADMMInput or ADMMOutput layer
                'input':  contains ADMMInput layer
                'output':  contains ADMMOutput layer
                'complete': contains both ADMMInput and ADMMOutput layers
        :param lb: input lower bound
        :param ub: input upper bound
        :param options: {'rho': the rho parameter, 'c': the objective tensor}
        :param pre_act_bds: manually given pre-activation bounds, {'lb': [...], 'ub': [...]}
                indexed by the order of the ReLU layers in this section
        '''
        self.nn_list = nn_layer_list
        self.input = x
        self.length = len(nn_layer_list)
        self.intermediate_states = []
        self.propagate()
        # interface
        self.output = self.intermediate_states[-1]

        self.label = label
        # input lower and upper bounds
        self.lb = lb
        self.ub = ub
        self.options = options

        # given pre activation bounds
        self.pre_act_bds = pre_act_bds

        self.lbs = None
        self.ubs = None
        self.IBP()
        # interface
        self.output_lb = self.lbs[-1]
        self.output_ub = self.ubs[-1]

        self.ADMM_section = None
        self.batch_num = self.input.size(0)

    def propagate(self):
        '''Forward pass through the layers; returns [input, output of layer 1, ..., output of layer n].'''
        x = self.input
        output_list = [x]
        for i in range(self.length):
            layer = self.nn_list[i]
            if isinstance(layer, nn.ReLU):
                y = F.relu(x)
            else:
                y = layer(x)

            output_list.append(y)
            x = y

        self.intermediate_states = output_list
        return output_list

    def IBP(self):
        '''Intermediate layer bounds via interval bound propagation from (self.lb, self.ub).'''
        lb = self.lb
        ub = self.ub

        lbs, ubs = compute_bounds_interval_arithmetic(self.nn_list, lb, ub)

        self.lbs = lbs
        self.ubs = ubs

        return lbs, ubs

    def _load_pre_act_bounds(self, act_idx, device, check_singleton=False):
        '''Fetch the manually given bounds of ReLU number ``act_idx`` and tile them over the batch.

        :param act_idx: index into ``self.pre_act_bds['lb']`` / ``['ub']``
        :param device: device to move the bounds to
        :param check_singleton: assert that the stored bounds have batch dimension 1
        :return: (lb, ub), each repeated ``self.batch_num`` times along dim 0
        '''
        lb = self.pre_act_bds['lb'][act_idx].to(device)
        ub = self.pre_act_bds['ub'][act_idx].to(device)
        if check_singleton:
            assert (lb.size(0), ub.size(0)) == (1, 1)
        # repeat only the batch dimension so any rank works, e.g. (1, C, H, W) or (1, N)
        lb = lb.repeat(self.batch_num, *([1] * (lb.dim() - 1)))
        ub = ub.repeat(self.batch_num, *([1] * (ub.dim() - 1)))
        return lb, ub

    def init_ADMM_section(self):
        '''Build the list of ADMM layers modelling this section (plus ADMMOutput if labelled so).'''
        ADMM_section = []
        pre_act_layer_num = 0
        is_input_section = self.label == 'input' or self.label == 'complete'
        for i in range(self.length):
            layer = self.nn_list[i]
            x_init = self.intermediate_states[i]
            y_init = x_init
            z_init = self.intermediate_states[i+1]

            # the first layer of an input/complete section becomes an ADMMInput layer
            layer_num = 0 if (i == 0 and is_input_section) else 1

            if isinstance(layer, nn.ReLU) and self.pre_act_bds is not None:
                # ReLU layers take the manually given pre-activation bounds; the batch-size-1
                # check is only enforced after the first layer of the section
                lb, ub = self._load_pre_act_bounds(pre_act_layer_num, x_init.device, check_singleton=(i > 0))
                pre_act_layer_num += 1
            elif layer_num == 0:
                # the input layer is bounded by the section's input bounds
                lb, ub = self.lb, self.ub
            elif self.pre_act_bds is None:
                lb, ub = self.lbs[i], self.ubs[i]
            else:
                lb, ub = None, None

            ADMM_layer = add_ADMM_layers(x_init, y_init, z_init, layer, x_lb=lb, x_ub=ub, layer_num=layer_num)
            ADMM_section += ADMM_layer

        if self.label == 'output' or self.label == 'complete':
            # initialize the output layer
            c = self.options['c']
            rho = self.options['rho']
            # extract the output
            x_init = self.intermediate_states[-1]
            ADMM_layer = [ADMMOutput(x_init, c, rho)]

            ADMM_section += ADMM_layer

        self.ADMM_section = ADMM_section

        return ADMM_section

    def init_ADMM_block(self):
        '''Wrap the (lazily built) ADMM section into an ADMM_forward_block.'''
        if self.ADMM_section is None:
            self.init_ADMM_section()

        block = ADMM_forward_block(self.ADMM_section)

        return block


def add_ADMM_layers(x_init, y_init, z_init, layer, x_lb=None, x_ub=None, layer_num=1):
    '''Translate one neural network layer into the list of ADMM layers that models it.

    :param x_init, y_init, z_init: initial values of the layer's input copies and output
    :param layer: the nn layer to translate
    :param x_lb, x_ub: bounds on the layer input (passed to the first ADMM layer only)
    :param layer_num: 0 to create an ADMMInput layer, anything else for ADMMLayer
    :return: list of ADMM layers (empty for Flatten)
    :raises ValueError: for unsupported layer types
    '''
    layer_type = ADMMInput if layer_num == 0 else ADMMLayer

    # these layers map one-to-one onto a single ADMM layer
    single_layer_types = (nn.ReLU, Identity, nn.Linear, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)
    if isinstance(layer, single_layer_types):
        return [layer_type(x_init, y_init, z_init, layer, x_lb, x_ub)]

    if isinstance(layer, nn.Conv2d):
        new_conv, stride, bias = conv_layer_params(layer)

        # stage 1: padding; this is the only stage that receives the input bounds
        pad_layer = FFT_Padding(new_conv)
        padded = pad_layer(y_init)
        stages = [layer_type(y_init, y_init, padded, pad_layer, x_lb, x_ub)]

        # stage 2: stride-1, bias-free convolution evaluated with FFTs
        fft_conv = FFTConv2d(new_conv, padded.size())
        conv_out = fft_conv(padded)
        stages.append(ADMMLayer(padded, padded, conv_out, new_conv))

        # stage 3: post processing (receives the stride and bias of the original conv)
        post_layer = Conv_Post_Processing(new_conv, x_init, stride, bias)
        post_out = post_layer(conv_out)

        # sanity check against the original convolution
        reference = layer(x_init)
        assert post_out.size() == reference.size()

        stages.append(ADMMLayer(conv_out, conv_out, post_out, post_layer))
        return stages

    if isinstance(layer, Flatten):
        # flattening contributes no ADMM layer
        return []

    raise ValueError('Current layer type unsupported.')


def decompose_conv_layer(conv, X):
    '''Decompose a Conv2d into a list of simpler layers, built for input X.

    Order: padding -> stride-1 FFT convolution -> cropping -> down sampling (if stride > 1)
    -> bias (if the conv has one).

    :param conv: the nn.Conv2d to decompose
    :param X: an input tensor (used to size the FFT convolution and the cropping)
    :return: list of layers
    '''
    new_conv, stride, bias = conv_layer_params(conv)

    padding_layer = FFT_Padding(new_conv)
    X_0 = padding_layer(X)
    layers = [padding_layer, FFTConv2d(new_conv, X_0.size()), FFT_Cropping(new_conv, X)]

    # a stride-s convolution equals the stride-1 convolution sampled at every s-th output,
    # so any stride > 1 needs down sampling (not only stride 2)
    if stride > 1:
        layers.append(Down_Sampling(stride))

    if bias is not None:
        layers.append(Bias(bias))

    return layers


def conv_layer_params(conv):
    '''Split a Conv2d into a bias-free stride-1 convolution, its stride and its bias.

    The returned convolution shares the original weight tensor and keeps its padding.

    :param conv: nn.Conv2d with square kernel/padding/stride, no dilation, no groups, zero padding
    :return: (new_conv, stride, bias) with stride as an int and bias possibly None
    :raises ValueError: for configurations the stride-1 rewrite cannot represent (they would
        otherwise be silently truncated to their first component or ignored)
    '''
    if isinstance(conv.padding, str):
        raise ValueError('String padding modes are not supported; use explicit integer padding.')
    for attr in ('kernel_size', 'padding', 'stride'):
        values = getattr(conv, attr)
        if any(v != values[0] for v in values):
            raise ValueError(f'Only square convolutions are supported, got {attr}={values}.')
    if any(d != 1 for d in conv.dilation) or conv.groups != 1 or conv.padding_mode != 'zeros':
        raise ValueError('Dilated, grouped or non-zero-padded convolutions are not supported.')

    kernel_size = conv.kernel_size[0]
    padding = conv.padding[0]
    stride = conv.stride[0]
    weight = conv.weight
    bias = conv.bias
    in_ch = conv.in_channels
    out_ch = conv.out_channels

    new_conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=padding, bias=False)
    new_conv.weight = torch.nn.Parameter(weight)

    return new_conv, stride, bias


class Residual_layer_section:
    '''A residual block: a main pass and a residual (shortcut) pass over the same input, summed.'''
    def __init__(self, main_nn_section, res_nn_section, x, lb = None, ub = None, pre_act_bds = None):
        '''
        :param main_nn_section: list of nn layers in the main pass
        :param res_nn_section: list of nn layers in the residual pass
        :param x: input tensor shared by both passes
        :param lb, ub: bounds on the input
        :param pre_act_bds: pre-activation bounds, applied only to ReLU layers of the main pass
        '''
        self.main_nn_section = main_nn_section
        self.res_nn_section = res_nn_section

        self.main_section = Layer_section(main_nn_section, x, label='middle', lb=lb, ub=ub, pre_act_bds=pre_act_bds)
        # the residual pass has no ReLU layers, so it gets no pre-activation bounds
        self.res_section = Layer_section(res_nn_section, x, label='middle', lb=lb, ub=ub)

        self.input = x
        self.lb = lb
        self.ub = ub

        self.main_states_list = None
        self.res_states_list = None
        self.propagate()
        # interface: the block output is the sum of both pass outputs
        self.output = self.main_states_list[-1] + self.res_states_list[-1]

        self.lbs = None
        self.ubs = None
        self.IBP()
        # interface: interval bounds of a sum are the sums of the bounds
        main_sec, res_sec = self.main_section, self.res_section
        self.output_lb = main_sec.lbs[-1] + res_sec.lbs[-1]
        self.output_ub = main_sec.ubs[-1] + res_sec.ubs[-1]

    def propagate(self):
        '''Forward both passes and store their intermediate state lists.'''
        main_states, res_states = self.main_section.propagate(), self.res_section.propagate()
        self.main_states_list = main_states
        self.res_states_list = res_states

    def IBP(self):
        '''Interval bound propagation through both passes.'''
        for section in (self.main_section, self.res_section):
            section.IBP()

    def init_ADMM_block(self):
        '''Create the ADMM_res_block: both forward blocks joined by an ADMM_sum_block.'''
        main_block = self.main_section.init_ADMM_block()
        res_block = self.res_section.init_ADMM_block()
        sum_block = ADMM_sum_block(self.main_states_list[-1], self.res_states_list[-1], main_block, res_block)

        # each pass feeds its own interface object of the sum block
        main_block.set_post_block(sum_block.x_m_block)
        res_block.set_post_block(sum_block.x_r_block)

        return ADMM_res_block(main_block, res_block, sum_block)

def contain_relu_layer(nn_section):
    '''Return True if an nn section contains a ReLU layer.

    A two-part section (main pass, residual pass) is treated as a ResNet basic block and always
    counts as containing ReLU; a one-part section is scanned; anything else returns False.
    '''
    if len(nn_section) == 2:
        return True
    if len(nn_section) != 1:
        return False
    return any(isinstance(layer, nn.ReLU) for layer in nn_section[0])

def generate_layer_sections(nn_sections, x, lb, ub, obj_options, pre_act_bds_list = None):
    '''Build a chain of layer sections, feeding each section's output and bounds into the next.

    nn_section = [ section_1, section_2, ... ]
    feedforward layers: section_1 = [[conv, bn, relu, ...]], residual block layers: section_2 = [[conv1, bn1, relu, conv2, bn2], [downsample layers]], ...

    The first section is labelled 'input' and the last 'output' (it receives ``obj_options``).
    A single section is labelled 'complete'. Residual blocks are only allowed strictly inside
    the chain. ``pre_act_bds_list`` is indexed by the running count of ReLU-containing sections.

    :raises ValueError: ('case not defined') for a section shape not allowed at its position
    '''
    num_blocks = len(nn_sections)
    if pre_act_bds_list is None:
        pre_act_bds_list = [None] * num_blocks

    if num_blocks == 1:
        only = Layer_section(nn_sections[0][0], x, label='complete', lb=lb, ub=ub,
                             options=obj_options, pre_act_bds=pre_act_bds_list[0])
        return [only]

    sections = []
    act_idx = 0
    last_idx = num_blocks - 1
    for idx, cur in enumerate(nn_sections):
        is_inner = 0 < idx < last_idx
        if len(cur) == 1:
            # feedforward section
            if idx == 0:
                label, options = 'input', None
            elif idx == last_idx:
                label, options = 'output', obj_options
            else:
                label, options = 'middle', None

            if contain_relu_layer(cur):
                bds = pre_act_bds_list[act_idx]
                act_idx += 1
            else:
                bds = None

            section = Layer_section(cur[0], x, label=label, lb=lb, ub=ub, options=options, pre_act_bds=bds)
        elif len(cur) == 2 and is_inner:
            # residual block
            section = Residual_layer_section(cur[0], cur[1], x, lb, ub, pre_act_bds=pre_act_bds_list[act_idx])
            assert contain_relu_layer(cur)
            act_idx += 1
        else:
            raise ValueError('case not defined')

        sections.append(section)
        lb, ub, x = section.output_lb, section.output_ub, section.output

    return sections


def init_ADMM_session(nn_sections, x, lb, ub, obj_options, pre_act_bds_list = None):
    '''Build one ADMM block per layer section.

    :return: (list of ADMM blocks, output tensor of the last section)
    '''
    sections = generate_layer_sections(nn_sections, x, lb, ub, obj_options, pre_act_bds_list)
    blocks = [section.init_ADMM_block() for section in sections]
    return blocks, sections[-1].output

########################################################################################
# Run ADMM
#########################################################################################

def run_ADMM(ADMM_sess, eps_abs = 1e-5, eps_rel = 1e-4, residual_balancing = False, max_iter = 10000, record = False, verbose = False):
    '''Iterate ADMM until every example meets the stopping criterion or max_iter is reached.

    :param ADMM_sess: ADMM_session to run
    :param eps_abs, eps_rel: absolute / relative tolerances of the stopping criteria
    :param residual_balancing: adapt rho by residual balancing (only during the first 3000 iterations)
    :param max_iter: maximum number of iterations (must be >= 1)
    :param record: keep the per-iteration history in the result dict
    :param verbose: print residuals and tolerances of example ``view_id`` every iteration
    :return: (obj, running_time, result, termination_example_id, latest_peak, latest_valley), where
        obj is c . x_last evaluated at the start of the final iteration. latest_peak and
        latest_valley are the last local max / min of the objective trace over a 3-iterate window;
        examples that never had one get obj instead.
    :raises ValueError: if max_iter < 1 (there would be no iterate to report)
    '''
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1.')

    start_time = time.time()
    rp_list = []
    rd_list = []
    obj_list = []
    p_tol_list = []
    d_tol_list = []
    rho_list = []
    succ_ratio_list = []

    num = ADMM_sess.num_batches
    c = ADMM_sess.c

    latest_peak = torch.zeros(num).to(ADMM_sess.ADMM_session[0].x.device)
    latest_valley = torch.zeros(num).to(ADMM_sess.ADMM_session[0].x.device)
    # initialize update indices to be false
    valley_accumulated_update_idx = (latest_valley <= -1.0)
    peak_accumulated_update_idx = (latest_peak <= -1.0)

    window_length = 3
    moving_window = []

    view_id = 0
    status = 'max_iter'

    for i in range(max_iter):
        # objective of the current iterate (before this iteration's update)
        last_state = ADMM_sess.ADMM_session[-1].section[-1].x
        product = last_state.reshape(num, -1) * c.reshape(num, -1)
        obj = product.sum(dim=1)

        moving_window.append(obj)
        if len(moving_window) > window_length:
            moving_window = moving_window[1:]

        if len(moving_window) == window_length:
            # the middle value is a local peak / valley of the objective trace
            peak_index = (moving_window[1] >= moving_window[0]) & (moving_window[1] >= moving_window[2])
            valley_index = (moving_window[1] <= moving_window[0]) & (moving_window[1] <= moving_window[2])

            valley_accumulated_update_idx = valley_accumulated_update_idx | valley_index
            peak_accumulated_update_idx = peak_accumulated_update_idx | peak_index

            latest_peak[peak_index] = moving_window[1][peak_index]
            latest_valley[valley_index] = moving_window[1][valley_index]

        ADMM_sess.update_x()
        ADMM_sess.update_yz()
        ADMM_sess.update_dual()
        rp = ADMM_sess.primal_residual()
        rd = ADMM_sess.dual_residual()

        if residual_balancing:
            # stop applying residual balancing after 3000 iterations
            if i < 3000:
                ADMM_sess.residual_balancing()

        p_tol = ADMM_sess.stopping_primal(eps_abs, eps_rel)
        d_tol = ADMM_sess.stopping_dual(eps_abs, eps_rel)
        rho = ADMM_sess.rho

        if verbose:
            print('iter no. ', str(i) ,'primal (res. and tol.): ', rp[view_id], p_tol[view_id], ' dual (res. and tol.): ', rd[view_id], d_tol[view_id], ' obj: ', obj[view_id])

        if record:
            rp_list.append(rp)
            rd_list.append(rd)
            obj_list.append(obj)
            p_tol_list.append(p_tol)
            d_tol_list.append(d_tol)
            rho_list.append(ADMM_sess.rho.clone())

        if torch.all(rp <= p_tol) and torch.all(rd <= d_tol):
            status = 'meet_tolerance'
            break

        if i % 100 == 0:
            succ_ratio = ((rp <= p_tol) & (rd <= d_tol)).sum().item() / num
            succ_ratio_list.append(succ_ratio)
            print('success ratio', succ_ratio)

    if status == 'meet_tolerance':
        print('All examples meet stopping criterion.')
    else:
        print('Maximum number of iteration is reached.')

    # examples without any detected peak / valley fall back to the final objective value
    latest_valley[~valley_accumulated_update_idx] = obj[~valley_accumulated_update_idx]
    latest_peak[~peak_accumulated_update_idx] = obj[~peak_accumulated_update_idx]

    termination_example_id = ((rp <= p_tol) & (rd <= d_tol))
    running_time = time.time() - start_time
    if record:
        result = {'obj_list': obj_list, 'rp_list': rp_list, 'rd_list': rd_list, 'p_tol_list': p_tol_list, 'd_tol_list': d_tol_list,
                  'rho_list': rho_list, 'running_time': running_time, 'latest_peak': latest_peak, 'latest_valley': latest_valley, 'status': status}
    else:
        result = {'obj_list': obj, 'rp_list': rp, 'rd_list': rd, 'p_tol_list': p_tol, 'd_tol_list': d_tol, 'rho_list': rho,
                  'running_time': running_time, 'latest_peak': latest_peak, 'latest_valley': latest_valley, 'status': status}

    return obj, running_time, result, termination_example_id, latest_peak, latest_valley


def ADMM_custom_objective(resnet18, x, x0_lb, x0_ub, c, alg_options, pre_act_bds = None):
    '''Run ADMM with the custom objective ``c`` on a ResNet-18 model.

    :param pre_act_bds: None, a path to a file loadable by torch.load, or a list of bounds
    :param alg_options: dict with keys 'rho', 'eps_abs', 'eps_rel', 'residual_balancing',
        'max_iter', 'record', 'verbose'
    :return: (solution dict, ADMM_session)
    :raises ValueError: if ``pre_act_bds`` has an unsupported type
    '''
    nn_sections = resnet18_act_decomposition(resnet18)
    obj_opt = {'rho': alg_options['rho'], 'c': c}

    if pre_act_bds is None:
        ADMM_blocks, output_x = init_ADMM_session(nn_sections, x, x0_lb, x0_ub, obj_opt)
    else:
        if isinstance(pre_act_bds, str):
            # NOTE(review): torch.load unpickles the file; only load bound files from trusted sources
            bds_list = torch.load(pre_act_bds)
        elif isinstance(pre_act_bds, list):
            bds_list = pre_act_bds
        else:
            raise ValueError('Preactivation bounds input not suported.')
        ADMM_blocks, output_x = init_ADMM_session(nn_sections, x, x0_lb, x0_ub, obj_opt, pre_act_bds_list=bds_list)

    ADMM_sess = ADMM_session(ADMM_blocks)

    option_keys = ('eps_abs', 'eps_rel', 'residual_balancing', 'max_iter', 'record', 'verbose')
    run_kwargs = {key: alg_options[key] for key in option_keys}

    obj, running_time, result, termination_example_id, latest_peak, latest_valley = run_ADMM(ADMM_sess, **run_kwargs)
    sol = {'obj': obj, 'running_time': running_time, 'result': result, 'termination_example_id': termination_example_id, 'latest_peak': latest_peak, 'latest_valley': latest_valley}

    return sol, ADMM_sess



def ADMM_CNN_custom_objective(nn_model, x, x0_lb, x0_ub, c, alg_options, pre_act_bds = None):
    """Run ADMM with a custom objective on a feedforward CNN.

    The model is wrapped into a single section by ``network_decomposition``.

    Args:
        nn_model: iterable of layers (e.g. an ``nn.Sequential``).
        x: nominal input forwarded to ``init_ADMM_session``.
        x0_lb, x0_ub: input lower / upper bounds.
        c: objective term, forwarded as ``obj_opt['c']``.
        alg_options: dict read for 'rho', 'eps_abs', 'eps_rel',
            'residual_balancing', 'max_iter', 'record' and 'verbose'.
        pre_act_bds: optional pre-activation bounds, a file path (str)
            loaded with ``torch.load`` or an in-memory list.

    Returns:
        (sol, ADMM_sess) where ``sol`` collects the outputs of ``run_ADMM``.

    Raises:
        ValueError: if ``pre_act_bds`` is not None, a str, or a list.
    """
    nn_sections = network_decomposition(nn_model)
    obj_opt = {'rho': alg_options['rho'], 'c': c}

    if pre_act_bds is not None and not isinstance(pre_act_bds, (str, list)):
        raise ValueError('Preactivation bounds input not suported.')

    if pre_act_bds is None:
        ADMM_blocks, _ = init_ADMM_session(nn_sections, x, x0_lb, x0_ub, obj_opt)
    else:
        bounds_list = torch.load(pre_act_bds) if isinstance(pre_act_bds, str) else pre_act_bds
        ADMM_blocks, _ = init_ADMM_session(nn_sections, x, x0_lb, x0_ub, obj_opt, pre_act_bds_list=bounds_list)

    ADMM_sess = ADMM_session(ADMM_blocks)

    (obj, running_time, result, termination_example_id,
     latest_peak, latest_valley) = run_ADMM(
        ADMM_sess,
        eps_abs=alg_options['eps_abs'],
        eps_rel=alg_options['eps_rel'],
        residual_balancing=alg_options['residual_balancing'],
        max_iter=alg_options['max_iter'],
        record=alg_options['record'],
        verbose=alg_options['verbose'],
    )

    sol = dict(obj=obj, running_time=running_time, result=result,
               termination_example_id=termination_example_id,
               latest_peak=latest_peak, latest_valley=latest_valley)
    return sol, ADMM_sess



def _conv2d_abs_weight_zero_bias(layer, r):
    """Apply ``layer`` to ``r`` with its weights replaced by |W| and its bias by zeros.

    The layer's own forward is reused, so subclasses and non-default padding
    modes behave exactly like the layer. Its parameters are swapped in place
    and always restored in ``finally``, even if the forward pass raises.
    """
    original_weight = layer.weight
    original_bias = layer.bias
    try:
        layer.weight = nn.Parameter(original_weight.abs())
        if original_bias is not None:
            # zeros_like keeps the bias dtype/device consistent with the layer.
            layer.bias = nn.Parameter(torch.zeros_like(original_bias))
        return layer(r)
    finally:
        layer.weight = original_weight
        if original_bias is not None:
            layer.bias = original_bias


def compute_bounds_interval_arithmetic(model, lb0, ub0):
    """ Calculate upper and lower bounds with IBP (loose calculation).

    Elementwise bounds are pushed through ``model`` layer by layer.

    * Affine layers (Linear, Conv2d, BatchNorm2d) map the box centre through
      the layer and the box radius through the absolute-valued weights.
    * Monotone layers (ReLU, adaptive average pooling, flatten, identity) are
      applied to both bounds directly.

    Args:
        model: iterable of layers (e.g. an ``nn.Sequential``).
        lb0, ub0: elementwise lower / upper bounds of the input.

    Returns:
        (lbs, ubs): lists of length ``len(model) + 1``. Entry 0 is the input
        bounds and entry ``i + 1`` the bounds after the i-th layer.

    Raises:
        ValueError: for a layer type that is not supported.
    """
    lbs, ubs = [lb0], [ub0]
    for layer in model:
        lb, ub = lbs[-1], ubs[-1]
        if isinstance(layer, nn.Linear):
            mu = (ub + lb) / 2
            r = (ub - lb) / 2

            mu = F.linear(mu, layer.weight, layer.bias)
            r = F.linear(r, layer.weight.abs())

            lbs.append(mu - r)
            ubs.append(mu + r)

        elif isinstance(layer, nn.Conv2d):
            mu = layer((ub + lb) / 2)
            # Radius: |W| * r with no bias, since the bias only shifts the centre.
            r = _conv2d_abs_weight_zero_bias(layer, (ub - lb) / 2)

            lbs.append(mu - r)
            ubs.append(mu + r)

        elif isinstance(layer, nn.ReLU):
            # Do not use layer(lb): the given relu layer may have inplace=True
            # and would overwrite the stored bounds.
            lbs.append(F.relu(lb))
            ubs.append(F.relu(ub))

        elif isinstance(layer, nn.BatchNorm2d):
            mu = (ub + lb) / 2
            r = (ub - lb) / 2

            # Batch norm with running statistics is a per-channel affine map
            # y = alpha * x + beta. affine=False layers have weight/bias = None.
            std = torch.sqrt(layer.running_var + layer.eps)
            gamma = layer.weight if layer.weight is not None else torch.ones_like(std)
            shift = layer.bias if layer.bias is not None else torch.zeros_like(std)
            alpha = 1.0 / std * gamma
            beta = -layer.running_mean / std * gamma + shift

            _, C, _, _ = mu.size()
            alpha = alpha.reshape(1, C, 1, 1)
            beta = beta.reshape(1, C, 1, 1)

            mu = alpha * mu + beta
            r = alpha.abs() * r

            lbs.append(mu - r)
            ubs.append(mu + r)

        elif isinstance(layer, nn.AdaptiveAvgPool2d):
            # Averaging uses non-negative weights, so it is monotone and the
            # bounds stay valid for any output size.
            lbs.append(layer(lb))
            ubs.append(layer(ub))

        elif isinstance(layer, (nn.Flatten, nn.Identity)) or isinstance(layer, (Flatten, Identity)):
            # Pure reshapes / pass-through: bounds map elementwise.
            lbs.append(layer(lb))
            ubs.append(layer(ub))

        else:
            raise ValueError("Unknown layer type for IBP")

    return lbs, ubs

# decompose resnet18 in order to construct the ADMM modules
def resnet18_act_decomposition(resnet18):
    """Split a ResNet-18 into the sections used to build the ADMM modules.

    Each section is a list of branches and each branch is a list of layers.
    Residual blocks are two-branch sections ``[main_path, shortcut]``:

    * the shortcut is the block's ``downsample`` pair for block 0 of
      layer2-layer4;
    * otherwise the shortcut is ``Identity()``.

    Every residual block is followed by a ``[[resnet18.relu]]`` section.
    Only blocks 0 and 1 of each stage are used. This manual decomposition is
    meant to be made automatic in future updates.
    """
    shared_relu = resnet18.relu
    sections = [
        [[resnet18.conv1, resnet18.bn1]],
        [[shared_relu]],
        [[resnet18.maxpool]],
    ]

    stages = (resnet18.layer1, resnet18.layer2, resnet18.layer3, resnet18.layer4)
    for stage_idx, stage in enumerate(stages):
        for block_idx in range(2):
            block = stage[block_idx]
            main_path = [block.conv1, block.bn1, block.relu, block.conv2, block.bn2]
            if stage_idx > 0 and block_idx == 0:
                shortcut = [block.downsample[0], block.downsample[1]]
            else:
                shortcut = [Identity()]
            sections.append([main_path, shortcut])
            sections.append([[shared_relu]])

    # For resnet18 this head is stated to be equivalent to [[Flatten(), resnet18.fc]].
    sections.append([[resnet18.avgpool, Flatten(), resnet18.fc]])
    return sections

def resnet18_act_truncation(nn_sections, layer_num):
    """Truncate a ResNet-18 decomposition just before ReLU layer ``layer_num``.

    ReLU layers are counted in section order. Both standalone
    ``[[nn.ReLU]]`` sections and two-branch residual sections count as one,
    the latter for the ReLU inside their main path.

    * When the target ReLU sits inside a residual block, only the first two
      layers of its main path are kept, as a single-branch section.
    * If the kept sections end with a residual section, ``[[Identity()]]``
      is appended.

    Returns:
        The kept sections, as a new list.
    """
    kept = []
    relu_seen = 0
    for section in nn_sections:
        is_residual = len(section) == 2
        is_relu = len(section) == 1 and isinstance(section[0][0], nn.ReLU)
        if not (is_relu or is_residual):
            kept.append(section)
            continue

        if relu_seen >= layer_num:
            if is_residual:
                kept.append([section[0][:2]])
            break

        kept.append(section)
        relu_seen += 1

    # A trailing residual section is closed with an identity section.
    if len(kept[-1]) == 2:
        kept.append([[Identity()]])

    return kept

def network_decomposition(nn_model):
    """Wrap a feedforward model as one single-branch section for the ADMM modules.

    Returns ``[[layers]]`` where ``layers`` lists the model's layers in
    iteration order.
    """
    layers = list(nn_model)
    return [[layers]]

##############################################################################################################
# The following code is designed mainly to compute all the pre-activation bounds of ResNet-18 through ADMM
##############################################################################################################

def find_starting_layer_num(file_head):
    """Return the first ReLU layer (1..16) whose cumulative bounds file is missing.

    Looks for ``<file_head>_layer_up_to_<i>.pt`` for i = 1..16 and returns the
    first ``i`` without a file. When all 16 files exist, returns 17 so that a
    caller iterating ``range(start, 17)`` has nothing left to compute. The
    previous version returned 16 here and recomputed the last layer.
    """
    for layer_num in range(1, 17):
        if not os.path.exists(file_head + '_layer_up_to_' + str(layer_num) + '.pt'):
            return layer_num
    return 17

def IBP_for_first_ReLU_layer(resnet18, x, x0_lb, x0_ub, file_head):
    """Save the pre-activation bounds of ReLU layer 0 of ``resnet18``.

    The bounds are the ``output_lb`` / ``output_ub`` that
    ``generate_layer_sections`` attaches to the last section of the network
    truncated before ReLU layer 0; they are used here as the IBP bounds.

    Writes two files:

    * ``<file_head>_layer_0.pt``: ``{'lb': [lb], 'ub': [ub]}`` with CPU tensors.
    * ``<file_head>_layer_up_to_0.pt``: a one-element list of that dict.
    """
    sections = resnet18_act_truncation(resnet18_act_decomposition(resnet18), 0)
    layer_sections = generate_layer_sections(sections, x, x0_lb, x0_ub, {'rho': 1.0, 'c': 0})
    last_section = layer_sections[-1]

    cpu = torch.device('cpu')
    first_layer_bds = {'lb': [last_section.output_lb.to(cpu)], 'ub': [last_section.output_ub.to(cpu)]}
    torch.save(first_layer_bds, file_head + '_layer_0.pt')
    torch.save([first_layer_bds], file_head + '_layer_up_to_0.pt')

def IBP_for_first_two_ReLU_layers(resnet18, x, x0_lb, x0_ub, file_head):
    """Compute and save the pre-activation bounds for ReLU layers 0 and 1.

    For each layer ``k``, the output bounds of the last section produced by
    ``generate_layer_sections`` on the network truncated before ReLU ``k``
    are saved to ``<file_head>_layer_<k>.pt`` as ``{'lb': [lb], 'ub': [ub]}``
    (CPU tensors). Both dicts are then saved, in order, to
    ``<file_head>_layer_up_to_1.pt``.
    """
    nn_sections = resnet18_act_decomposition(resnet18)
    cpu = torch.device('cpu')
    collected = []
    for relu_idx in (0, 1):
        truncated = resnet18_act_truncation(nn_sections, relu_idx)
        last_section = generate_layer_sections(truncated, x, x0_lb, x0_ub, {'rho': 1.0, 'c': 0})[-1]
        layer_bds = {'lb': [last_section.output_lb.to(cpu)], 'ub': [last_section.output_ub.to(cpu)]}
        torch.save(layer_bds, file_head + '_layer_' + str(relu_idx) + '.pt')
        collected.append(layer_bds)

    torch.save(collected, file_head + '_layer_up_to_1.pt')

def compute_layer_wise_bounds_automatic(resnet18, x, x0_lb, x0_ub, batch_size, file_head, alg_options):
    """Resume ``compute_layer_wise_bounds`` from the first layer without saved bounds.

    The starting layer is found with ``find_starting_layer_num(file_head)``.

    Returns:
        The ``(recent_pre_act_bds_up_to, running_time_list)`` pair from
        ``compute_layer_wise_bounds``.
    """
    resume_layer = find_starting_layer_num(file_head)
    return compute_layer_wise_bounds(resnet18, x, x0_lb, x0_ub, batch_size, file_head, alg_options,
                                     start_layer=resume_layer)

def compute_layer_wise_bounds(resnet18, x, x0_lb, x0_ub, batch_size, file_head, alg_options, start_layer):
    """Compute and checkpoint pre-activation bounds for ReLU layers ``start_layer`` .. 16.

    * The bounds of ReLU layer 0 are recomputed and saved on every call
      via ``IBP_for_first_ReLU_layer``.
    * The cumulative list ``<file_head>_layer_up_to_<start_layer - 1>.pt`` is
      then loaded and extended layer by layer.
    * After each layer, the extended list is saved as
      ``<file_head>_layer_up_to_<layer>.pt``.

    Returns:
        (recent_pre_act_bds_up_to, running_time_list): the cumulative bounds
        list and the running time of each newly computed layer.
    """
    IBP_for_first_ReLU_layer(resnet18, x, x0_lb, x0_ub, file_head)

    def up_to_path(idx):
        # Path of the cumulative bounds checkpoint for layers 0..idx.
        return file_head + '_layer_up_to_' + str(idx) + '.pt'

    recent_pre_act_bds_up_to = torch.load(up_to_path(start_layer - 1))
    running_time_list = []

    for layer_num in tqdm(range(start_layer, 17), desc = 'laywer_wise_bds'):
        layer_bds, running_time = compute_layer_bounds(resnet18, x, x0_lb, x0_ub, layer_num, batch_size,
                                                       file_head, alg_options)
        gc.collect()
        recent_pre_act_bds_up_to = recent_pre_act_bds_up_to + [layer_bds]
        torch.save(recent_pre_act_bds_up_to, up_to_path(layer_num))
        running_time_list.append(running_time)
        print( f'Layer {layer_num} bounds computation finished. Running time {running_time}, accumulated running time {sum(running_time_list)}')

    return recent_pre_act_bds_up_to, running_time_list

def compute_layer_bounds(resnet18, x, x0_lb, x0_ub, layer_num, batch_size, file_head, alg_options):
    """Compute ADMM pre-activation bounds for ReLU layer ``layer_num`` (>= 1).

    * The network is truncated before ReLU ``layer_num``.
    * Bounds of all earlier layers are loaded from
      ``<file_head>_layer_up_to_<layer_num - 1>.pt``.
    * The output bounds are computed with ``compute_output_bounds`` using
      checkpoint prefix ``<file_head>_layer_<layer_num>``.

    Layer 0 is not handled here; its bounds come from ``IBP_for_first_ReLU_layer``.

    Returns:
        (pre_act_bds, running_time) as returned by ``compute_output_bounds``.

    Raises:
        ValueError: if ``layer_num < 1``. Previously this failed with an
            UnboundLocalError on ``pre_act_bds_list``.
    """
    if layer_num < 1:
        raise ValueError('compute_layer_bounds requires layer_num >= 1; '
                         'layer 0 bounds are produced by IBP_for_first_ReLU_layer.')

    resnet18_nn_sections = resnet18_act_decomposition(resnet18)
    truncated_nn_sections = resnet18_act_truncation(resnet18_nn_sections, layer_num)

    pre_act_bds_list = torch.load(file_head + '_layer_up_to_' + str(layer_num - 1) + '.pt')

    new_file_head = file_head + '_layer_' + str(layer_num)
    pre_act_bds, running_time = compute_output_bounds(truncated_nn_sections, x, x0_lb, x0_ub, pre_act_bds_list,
                                                      batch_size, new_file_head, alg_options)

    # compute_output_bounds already saves '<new_file_head>.pt'; this extra copy
    # without the '.pt' suffix is kept so existing outputs are unchanged.
    torch.save(pre_act_bds, new_file_head)
    print('total running time:' + str(running_time))
    return pre_act_bds, running_time

def compute_output_bounds_ADMM_wrap(nn_section, x, lb, ub, c, rho, alg_options, pre_act_bds_list):
    """Build an ADMM session for ``nn_section`` with objective ``c`` and run it.

    Returns:
        The 6-tuple ``(obj, running_time, result, termination_example_id,
        latest_peak, latest_valley)`` from ``compute_output_bounds_ADMM``.
    """
    ADMM_blocks, _ = init_ADMM_session(nn_section, x, lb, ub, {'rho': rho, 'c': c},
                                       pre_act_bds_list=pre_act_bds_list)
    ADMM_sess = ADMM_session(ADMM_blocks)

    # Manual garbage collection is necessary since the ADMM objects contain
    # reference cycles.
    gc.collect()
    return compute_output_bounds_ADMM(ADMM_sess, alg_options)

class Layerwise_bounds_ADMM(nn.Module):
    """Module wrapper whose ``forward`` runs one batched ADMM bound computation.

    Wrapping the call in an ``nn.Module`` lets callers apply ``nn.DataParallel``.

    NOTE(review): ``nn_section`` is stored as a plain list, so its layers are
    not registered as submodules. ``.cuda()`` does not move them, and
    DataParallel replicas presumably share the same layer objects; confirm
    this is intended.
    """

    def __init__(self, nn_section, rho, alg_options, pre_act_bds_list):
        super().__init__()
        self.nn_section = nn_section
        self.rho = rho
        self.alg_options = alg_options
        self.pre_act_bds_list = pre_act_bds_list

    def forward(self, x, lb, ub, c):
        """Run ADMM on ``(x, lb, ub)`` with objective ``c`` and return its ``obj``.

        Warns when ADMM stops because it hit the maximum iteration count.
        """
        obj, _, result, _, _, _ = compute_output_bounds_ADMM_wrap(
            self.nn_section, x, lb, ub, c, self.rho, self.alg_options, self.pre_act_bds_list)

        if result['status'] == 'max_iter':
            warnings.warn('ADMM maximum iteration number reached.')
        return obj

def detect_layer_bounds_starting_number(file_head_layer, total_iter, bound_type):
    """Return the index of the first missing batch checkpoint for ``bound_type``.

    Looks for ``<file_head_layer>_<bound_type>_<i>_of_<total_iter>.pt`` for
    i = 0..total_iter-1 and returns the first ``i`` without a file. When every
    checkpoint exists, or ``total_iter == 0``, returns ``total_iter``.

    The previous version returned ``total_iter - 1`` when all files existed,
    which was indistinguishable from "only the last batch is missing". It also
    raised UnboundLocalError when ``total_iter == 0``.
    """
    for i in range(total_iter):
        if not os.path.exists(file_head_layer + '_' + bound_type + '_' + str(i) + '_of_' + str(total_iter) + '.pt'):
            return i
    return total_iter

def compute_output_lower_bounds(nn_section, input_x, x0_lb, x0_ub, pre_act_bds_list, batch_size, file_head, alg_options):
    """Compute elementwise lower bounds on the output of ``nn_section`` with ADMM.

    Batching and objective:

    * The flattened per-example output coordinates are split into batches of
      at most ``batch_size``.
    * Within a batch, example ``i`` gets the one-hot objective ``c`` selecting
      coordinate ``accumulated_size + i``, and the ADMM ``obj`` is used
      directly as the lower bound.
    * The upper-bound routine uses ``-c`` and negates the result.

    Checkpointing and resume:

    * Each batch is saved to ``<file_head>_lb_<k>_of_<num_batches>.pt`` as
      ``{'bounds', 'running_time'}``.
    * Batches whose checkpoint already exists are skipped, so an interrupted
      run resumes where it stopped.

    Returns:
        (pre_act_lb, running_time_list): bounds reshaped to the output size of
        ``nn_section`` and the per-batch running times.
    """
    temp_opt = {'rho': 1.0, 'c': 0}
    layer_sections = generate_layer_sections(nn_section, input_x, x0_lb, x0_ub, temp_opt, pre_act_bds_list)

    output_dim = layer_sections[-1].output.size()

    # Full batches, plus a trailing partial batch only when there is a
    # remainder. The previous version appended a 0-sized batch whenever
    # num_bds was divisible by batch_size.
    num_bds = layer_sections[-1].output.view(output_dim[0], -1).size(1)
    batch_list = [batch_size] * (num_bds // batch_size)
    if num_bds % batch_size:
        batch_list.append(num_bds % batch_size)
    outer_iter_num = len(batch_list)

    def checkpoint_path(k):
        # Checkpoint file of lower-bound batch k.
        return file_head + '_lb_' + str(k) + '_of_' + str(outer_iter_num) + '.pt'

    rho = alg_options['rho']
    start_iter = detect_layer_bounds_starting_number(file_head, outer_iter_num, 'lb')
    # Compute every missing batch, including the last one. The previous
    # "< outer_iter_num - 1" guard skipped a missing final batch, and the
    # summary load below then failed.
    if start_iter < outer_iter_num:
        for k in tqdm(range(start_iter, outer_iter_num), desc='lower_bounds'):
            save_path = checkpoint_path(k)
            if os.path.exists(save_path):
                continue

            b_size = batch_list[k]
            accumulated_size = sum(batch_list[:k])
            verify_dim = [b_size] + list(output_dim)[1:]
            x = input_x.repeat(b_size, 1, 1, 1)
            lb = x0_lb.repeat(b_size, 1, 1, 1)
            ub = x0_ub.repeat(b_size, 1, 1, 1)

            # Lower-bound objectives: row i is the one-hot vector of flattened
            # output coordinate accumulated_size + i.
            c = torch.zeros(verify_dim)
            rows = torch.arange(b_size)
            c.view(b_size, -1)[rows, accumulated_size + rows] = 1.0
            c = c.to(input_x.device)

            ADMM_bounds_parallel = Layerwise_bounds_ADMM(nn_section, rho, alg_options, pre_act_bds_list)
            if torch.cuda.device_count() > 1:
                ADMM_bounds_parallel = nn.DataParallel(ADMM_bounds_parallel)
            ADMM_bounds_parallel.cuda()
            gc.collect()

            start_time = time.time()
            output_lb = ADMM_bounds_parallel(x, lb, ub, c)
            running_time = time.time() - start_time

            print('running time:' + str(running_time))

            data_to_save = {'bounds': output_lb.to(torch.device('cpu')), 'running_time': running_time}
            torch.save(data_to_save, save_path)

    # Summarize the results from the per-batch checkpoints.
    lb_list = []
    running_time_list = []
    for k in range(outer_iter_num):
        data = torch.load(checkpoint_path(k))
        lb_list.append(data['bounds'])
        running_time_list.append(data['running_time'])

    pre_act_lb = torch.cat(lb_list).reshape(output_dim)
    return pre_act_lb, running_time_list


def compute_output_upper_bounds(nn_section, input_x, x0_lb, x0_ub, pre_act_bds_list, batch_size, file_head, alg_options):
    """Compute elementwise upper bounds on the output of ``nn_section`` with ADMM.

    Batching and objective:

    * The flattened per-example output coordinates are split into batches of
      at most ``batch_size``.
    * Within a batch, example ``i`` gets the objective ``-e_j`` with
      ``j = accumulated_size + i``, and the upper bound is ``-obj``.

    Checkpointing and resume:

    * Each batch is saved to ``<file_head>_ub_<k>_of_<num_batches>.pt`` as
      ``{'bounds', 'running_time'}``.
    * Batches whose checkpoint already exists are skipped, so an interrupted
      run resumes where it stopped.

    Returns:
        (pre_act_ub, running_time_list): bounds reshaped to the output size of
        ``nn_section`` and the per-batch running times.
    """
    temp_opt = {'rho': 1.0, 'c': 0}
    layer_sections = generate_layer_sections(nn_section, input_x, x0_lb, x0_ub, temp_opt, pre_act_bds_list)

    output_dim = layer_sections[-1].output.size()

    # Full batches, plus a trailing partial batch only when there is a
    # remainder. The previous version appended a 0-sized batch whenever
    # num_bds was divisible by batch_size.
    num_bds = layer_sections[-1].output.view(output_dim[0], -1).size(1)
    batch_list = [batch_size] * (num_bds // batch_size)
    if num_bds % batch_size:
        batch_list.append(num_bds % batch_size)
    outer_iter_num = len(batch_list)

    def checkpoint_path(k):
        # Checkpoint file of upper-bound batch k.
        return file_head + '_ub_' + str(k) + '_of_' + str(outer_iter_num) + '.pt'

    rho = alg_options['rho']
    start_iter = detect_layer_bounds_starting_number(file_head, outer_iter_num, 'ub')
    # Compute every missing batch of upper bounds, including the last one. The
    # previous "< outer_iter_num - 1" guard skipped a missing final batch.
    if start_iter < outer_iter_num:
        for k in tqdm(range(start_iter, outer_iter_num), desc='upper_bounds'):
            save_path = checkpoint_path(k)
            if os.path.exists(save_path):
                continue

            b_size = batch_list[k]
            accumulated_size = sum(batch_list[:k])
            verify_dim = [b_size] + list(output_dim)[1:]
            x = input_x.repeat(b_size, 1, 1, 1)
            lb = x0_lb.repeat(b_size, 1, 1, 1)
            ub = x0_ub.repeat(b_size, 1, 1, 1)

            # Upper-bound objectives: row i is minus the one-hot vector of
            # flattened output coordinate accumulated_size + i.
            c = torch.zeros(verify_dim)
            rows = torch.arange(b_size)
            c.view(b_size, -1)[rows, accumulated_size + rows] = -1.0
            c = c.to(input_x.device)

            ADMM_bounds_parallel = Layerwise_bounds_ADMM(nn_section, rho, alg_options, pre_act_bds_list)
            if torch.cuda.device_count() > 1:
                ADMM_bounds_parallel = nn.DataParallel(ADMM_bounds_parallel)
            ADMM_bounds_parallel.cuda()
            gc.collect()

            start_time = time.time()
            neg_obj = ADMM_bounds_parallel(x, lb, ub, c)
            running_time = time.time() - start_time

            print('running time:' + str(running_time))

            output_ub = -neg_obj
            data_to_save = {'bounds': output_ub.to(torch.device('cpu')), 'running_time': running_time}
            torch.save(data_to_save, save_path)

    # Summarize the results from the per-batch checkpoints.
    ub_list = []
    running_time_list = []
    for k in range(outer_iter_num):
        data = torch.load(checkpoint_path(k))
        ub_list.append(data['bounds'])
        running_time_list.append(data['running_time'])

    pre_act_ub = torch.cat(ub_list).reshape(output_dim)
    return pre_act_ub, running_time_list

def compute_output_bounds(nn_section, input_x, x0_lb, x0_ub, pre_act_bds_list, batch_size, file_head, alg_options):
    """Compute ADMM lower and upper bounds on the output of ``nn_section``.

    Wherever the ADMM results contradict each other (``ub <= lb``), both bounds
    are replaced by the IBP bounds of the last layer section.

    Side effects:

    * Saves ``{'lb': [lb], 'ub': [ub]}`` (CPU tensors) to ``<file_head>.pt``.
    * Saves diagnostics (per-batch running times, the bounds, and a 0/1
      ``lb_ub_contradiction`` flag) to ``<file_head>_diagnostics.pt``.

    Returns:
        (pre_act_bds, total_running_time)
    """
    temp_opt = {'rho': 1.0, 'c': 0}
    # Only needed for the IBP fallback bounds below; the batching is done
    # inside the lower/upper routines.
    layer_sections = generate_layer_sections(nn_section, input_x, x0_lb, x0_ub, temp_opt, pre_act_bds_list)

    pre_act_lb, running_time_list_lb = compute_output_lower_bounds(nn_section, input_x, x0_lb, x0_ub, pre_act_bds_list, batch_size, file_head, alg_options)
    pre_act_ub, running_time_list_ub = compute_output_upper_bounds(nn_section, input_x, x0_lb, x0_ub, pre_act_bds_list, batch_size, file_head, alg_options)

    # Check if the lower bound is greater than or equal to the upper bound.
    violation_idx = pre_act_ub <= pre_act_lb
    lb_ub_contradiction = 0
    if violation_idx.any():
        # Candidate bounds obtained by IBP.
        IBP_lb = layer_sections[-1].output_lb.to(pre_act_lb.device)
        IBP_ub = layer_sections[-1].output_ub.to(pre_act_ub.device)

        lb_ub_contradiction = 1
        pre_act_ub[violation_idx] = IBP_ub[violation_idx]
        pre_act_lb[violation_idx] = IBP_lb[violation_idx]

    pre_act_bds = {'lb': [pre_act_lb.to(torch.device('cpu'))], 'ub': [pre_act_ub.to(torch.device('cpu'))]}
    torch.save(pre_act_bds, file_head + '.pt')

    running_time_list = running_time_list_lb + running_time_list_ub
    diagnostics = {'running_time': running_time_list, 'pre_act_bds': pre_act_bds, 'lb_ub_contradiction': lb_ub_contradiction}
    torch.save(diagnostics, file_head + '_diagnostics.pt')

    return pre_act_bds, sum(running_time_list)


def compute_output_bounds_ADMM(ADMM_sess, alg_options):
    """Run ``run_ADMM`` on ``ADMM_sess`` with the solver options in ``alg_options``.

    Reads 'eps_abs', 'eps_rel', 'residual_balancing', 'max_iter', 'record' and
    'verbose' from ``alg_options``, then forces a garbage collection.

    Returns:
        The 6-tuple ``(obj, running_time, result, termination_example_id,
        latest_peak, latest_valley)``.
    """
    option_names = ('eps_abs', 'eps_rel', 'residual_balancing', 'max_iter', 'record', 'verbose')
    solver_kwargs = {name: alg_options[name] for name in option_names}

    obj, running_time, result, termination_example_id, latest_peak, latest_valley = run_ADMM(ADMM_sess, **solver_kwargs)
    gc.collect()

    return obj, running_time, result, termination_example_id, latest_peak, latest_valley
