def count_flops_dense(in_features, out_features, bias=True, activation=True):
    """Estimate the FLOPs of a fully connected layer.

    Every output unit is charged ``2*in_features - 1`` operations
    (presumably ``in_features`` multiplies plus ``in_features - 1`` adds).
    The bias and the activation each add one further operation per output
    unit when enabled.

    Args:
        in_features: Number of input units.
        out_features: Number of output units.
        bias: If true, count one extra operation per output unit.
        activation: If true, count one extra operation per output unit.

    Returns:
        The estimated number of operations.
    """
    total = (2 * in_features - 1) * out_features
    # Bias and activation are both priced at one operation per output unit.
    for enabled in (bias, activation):
        if enabled:
            total += out_features
    return total


def count_flops_conv(height, width, in_channels, out_channels, kernel_size,
        stride=1, padding=0, bias=True, activation=True):
    """Estimate the FLOPs of a 2-D convolution layer.

    Each output element is charged ``2*n - 1`` operations, where
    ``n = kernel_height * kernel_width * in_channels`` (the multiplies and
    adds needed for one output element). The bias and the activation each
    add one further operation per output element when enabled.

    Args:
        height: Input height.
        width: Input width.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int (square kernel) or a
            ``(height, width)`` pair.
        stride: Stride, either a single value used for both axes or a
            ``(vertical, horizontal)`` list/tuple.
        padding: Padding applied on each side of both spatial axes.
        bias: If true, count one extra operation per output element.
        activation: If true, count one extra operation per output element.

    Returns:
        The estimated number of operations.
    """
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size]*2
    # Accept either a scalar stride or a per-axis pair.
    if isinstance(stride, (list, tuple)):
        stride_h, stride_w = stride
    else:
        stride_h = stride_w = stride
    n = kernel_size[0] * kernel_size[1] * in_channels
    # Per output element: n multiplies and n - 1 adds.
    flops_per_instance = 2*n - 1
    # Floor division: a window that does not fully fit produces no output,
    # so the output size is never fractional.
    out_height = (height - kernel_size[0] + 2*padding) // stride_h + 1
    out_width = (width - kernel_size[1] + 2*padding) // stride_w + 1
    num_instances_per_channel = out_height * out_width
    flops_per_channel = num_instances_per_channel * flops_per_instance
    total_flops = out_channels * flops_per_channel
    if bias:
        total_flops += out_channels * num_instances_per_channel
    if activation:
        total_flops += out_channels * num_instances_per_channel
    return total_flops

def count_flops_dense_dbb(num_gates):
    """Estimate the FLOPs of a DBB gate applied to a dense layer.

    Mask construction and mask multiplication together are charged five
    operations per gate.

    Args:
        num_gates: Number of gates.

    Returns:
        ``5 * num_gates``.
    """
    ops_per_gate = 5  # mask construction + multiplication
    return ops_per_gate * num_gates

def count_flops_conv_dbb(height, width, num_gates):
    """Estimate the FLOPs of a DBB gate applied to a convolutional layer.

    The estimate is the sum of three terms: a global average pool over the
    feature map, the mask construction, and the mask multiplication.

    Args:
        height: Feature-map height.
        width: Feature-map width.
        num_gates: Number of gates.

    Returns:
        The estimated number of operations.
    """
    # One operation per gated feature-map element; charged once for the
    # global average pool and once for the mask multiplication.
    per_element_pass = num_gates * height * width
    mask_construction = 4 * num_gates
    return per_element_pass + mask_construction + per_element_pass

def count_flops_max_pool(height, width, channels, kernel_size,
        stride=None, padding=0):
    """Estimate the FLOPs of a 2-D max-pooling layer.

    Each output element is charged ``kernel_height * kernel_width``
    operations.

    Args:
        height: Input height.
        width: Input width.
        channels: Number of channels.
        kernel_size: Pooling window, either an int (square window) or a
            ``(height, width)`` pair.
        stride: Stride, either an int or a ``(vertical, horizontal)`` pair.
            Defaults to ``kernel_size`` when ``None``.
        padding: Padding applied on each side of both spatial axes.

    Returns:
        The estimated number of operations.
    """
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size]*2
    stride = kernel_size if stride is None else stride
    if isinstance(stride, int):
        stride = [stride]*2
    flops_per_instance = kernel_size[0] * kernel_size[1]
    # Floor division: a window that does not fully fit produces no output,
    # so the output size is never fractional.
    out_height = (height - kernel_size[0] + 2*padding) // stride[0] + 1
    out_width = (width - kernel_size[1] + 2*padding) // stride[1] + 1
    num_instances_per_channel = out_height * out_width
    flops_per_channel = num_instances_per_channel * flops_per_instance
    total_flops = channels * flops_per_channel
    return total_flops

def count_flops_global_avg_pool(height, width, channels):
    """Estimate the FLOPs of a global average pool.

    One operation is charged per element of the input feature map.

    Args:
        height: Input height.
        width: Input width.
        channels: Number of channels.

    Returns:
        ``channels * height * width``.
    """
    column_elements = channels * height
    return column_elements * width
def count_memory_dense(in_features, out_features, bias=True, batch_norm=False):
    """Estimate the memory footprint (in elements) of a dense layer.

    Args:
        in_features: Number of input units.
        out_features: Number of output units.
        bias: If true, add ``out_features`` elements for the bias.
        batch_norm: If true, add ``2 * in_features`` elements.

    Returns:
        The estimated number of stored elements.
    """
    weight_elems = in_features * out_features
    bias_elems = out_features if bias else 0
    # NOTE(review): the batch-norm term is sized by in_features here, while
    # count_memory_conv sizes it by out_channels -- confirm this is intended.
    bn_elems = 2 * in_features if batch_norm else 0
    return weight_elems + bias_elems + bn_elems

def count_memory_conv(height, width, in_channels, out_channels, kernel_size,
        stride=1, padding=0, bias=True, batch_norm=False):
    """Estimate the memory footprint (in elements) of a 2-D convolution.

    With ``n = kernel_height * kernel_width * in_channels`` the estimate is
    ``n * out_height * out_width`` (feature-map term) plus
    ``n * out_channels`` (kernel term), plus ``2 * out_channels`` when
    ``batch_norm`` is set.

    Args:
        height: Input height.
        width: Input width.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int (square kernel) or a
            ``(height, width)`` pair.
        stride: Stride, either a single value used for both axes or a
            ``(vertical, horizontal)`` list/tuple.
        padding: Padding applied on each side of both spatial axes.
        bias: Accepted for signature parity with the other counters but
            currently not included in the estimate.
        batch_norm: If true, add ``2 * out_channels`` elements.

    Returns:
        The estimated number of stored elements.
    """
    # NOTE(review): `bias` is never used, whereas count_memory_dense adds
    # out_features for it. Left unchanged so existing totals do not shift;
    # confirm whether a `+ out_channels` term is wanted.
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size]*2
    # Accept either a scalar stride or a per-axis pair.
    if isinstance(stride, (list, tuple)):
        stride_h, stride_w = stride
    else:
        stride_h = stride_w = stride
    n = kernel_size[0] * kernel_size[1] * in_channels
    # Floor division: a window that does not fully fit produces no output,
    # so the output size is never fractional.
    out_height = (height - kernel_size[0] + 2*padding) // stride_h + 1
    out_width = (width - kernel_size[1] + 2*padding) // stride_w + 1
    mem_fmap = n * out_height * out_width
    mem_kernel = n * out_channels
    mem = mem_fmap + mem_kernel
    if batch_norm:
        mem += 2 * out_channels
    return mem

def count_memory_dbb(num_gates):
    """Estimate the memory footprint (in elements) of a DBB gate layer.

    Args:
        num_gates: Number of gates.

    Returns:
        ``4 * num_gates``.
    """
    elems_per_gate = 4
    return elems_per_gate * num_gates

def func(num_units):
    """Build the layer layout with channel counts taken from ``num_units``.

    The layout template has 16 numeric slots interleaved with the markers
    ``'M'`` and ``'A'``. Each numeric slot is overwritten, in order, with
    the next entry of ``num_units``; the markers stay in place.

    Args:
        num_units: Indexable sequence with at least 16 entries. Entries
            beyond the 16th are ignored.

    Returns:
        A new list of 21 items: the first 16 entries of ``num_units`` with
        the ``'M'`` / ``'A'`` markers at their template positions.

    Raises:
        IndexError: If ``num_units`` has fewer than 16 entries.
    """
    layout = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M', 512, 512, 512, 512, 'A']
    next_unit = 0
    for position, entry in enumerate(layout):
        if entry == 'M' or entry == 'A':
            # Markers are kept as they are.
            continue
        layout[position] = num_units[next_unit]
        next_unit += 1
    return layout



def count_flops(num_units):
    """Estimate the total FLOPs of the network described by ``num_units``.

    The network follows the layout produced by ``func``: 3x3 convolutions
    with padding 1, interleaved with 2x2 stride-2 pooling stages, and
    followed by a dense layer with 10 outputs. The input is taken to have
    3 channels and a 32x32 spatial size.

    Args:
        num_units: Sequence with the output channel count of each
            convolution, in order. At least 16 entries are required; the
            last entry is used as the input width of the final dense layer.

    Returns:
        The estimated number of operations for one forward pass.

    Raises:
        IndexError: If ``num_units`` has fewer than 16 entries.
    """
    height = 32
    # Prepend the input channel count so that channels[i] / channels[i + 1]
    # are the in / out channels of the i-th convolution. list() lets tuples
    # and other sequences work, not only lists.
    channels = [3] + list(num_units)
    # Only the positions of the 'M' / 'A' markers in cfg are used below.
    cfg = func(channels)
    flops = 0
    idx = 0
    for v in cfg:
        if v == 'M' or v == 'A':
            # Both pooling markers are costed as a 2x2, stride-2 pool over
            # the output channels of the preceding convolution.
            flops += count_flops_max_pool(height, height, channels[idx],
                                          kernel_size=2, stride=2)
            # Floor division keeps the spatial size an integer.
            height = height // 2
        else:
            flops += count_flops_conv(height, height, channels[idx],
                                      channels[idx + 1], 3, padding=1)
            idx += 1
    flops += count_flops_dense(channels[-1], 10)
    return flops
def count_memory(num_units):
    """Estimate the total memory footprint of the network in ``num_units``.

    The network follows the layout produced by ``func``: 3x3 convolutions
    with padding 1, interleaved with pooling stages that halve the spatial
    size, and followed by a dense layer with 10 outputs. The input is taken
    to have 3 channels and a 32x32 spatial size. Pooling stages contribute
    no memory of their own.

    Args:
        num_units: Sequence with the output channel count of each
            convolution, in order. At least 16 entries are required; the
            last entry is used as the input width of the final dense layer.

    Returns:
        The estimated number of stored elements.

    Raises:
        IndexError: If ``num_units`` has fewer than 16 entries.
    """
    height = 32
    # Prepend the input channel count so that channels[i] / channels[i + 1]
    # are the in / out channels of the i-th convolution. list() lets tuples
    # and other sequences work, not only lists.
    channels = [3] + list(num_units)
    # Only the positions of the 'M' / 'A' markers in cfg are used below.
    cfg = func(channels)
    mem = 0
    idx = 0
    for v in cfg:
        if v == 'M' or v == 'A':
            # Floor division keeps the spatial size an integer.
            height = height // 2
        else:
            mem += count_memory_conv(height, height, channels[idx],
                                     channels[idx + 1], 3, padding=1)
            idx += 1
    mem += count_memory_dense(channels[-1], 10)
    return mem
