# -*- coding: utf-8 -*-
import os
import tensorflow as tf

# Module is always loaded from hadamard/python folder location
from layers import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrHadamardDenseV2, StrConv2D
from callbacks import HadamardCallback
from initializers import TwiceTruncatedNormalInitializer, ExactNormalFactorization, equivar_initializer, equivar_initializer_conv2d
from utils import GroupLassoRegularizer

# Components for ResNet construction
def regularized_padded_conv(filters, kernel_size, strides, depth, kernel_initializer, multfac_initializer, la, reg='l2', *args, **kwargs):
    """Create a bias-free, 'same'-padded convolution layer.

    depth == 1:
        Returns a plain ``tf.keras.layers.Conv2D`` with 'he_normal'
        initialization and a kernel regularizer of strength ``la`` (L2 when
        ``reg == 'l2'``, L1 for any other value). Extra ``*args`` and
        ``**kwargs`` are forwarded to ``Conv2D``; ``kernel_initializer`` and
        ``multfac_initializer`` are not used on this path.

    any other depth:
        Returns a ``StrConv2D`` layer that receives ``depth``, both
        initializers and ``la``. ``kernel_size`` is expanded to the square
        tuple ``(kernel_size, kernel_size)``; ``reg``, ``*args`` and
        ``**kwargs`` are not used on this path.
    """
    if depth != 1:
        # Factorized path: the layer gets `la` directly instead of a regularizer object.
        return StrConv2D(filters=filters,
                         kernel_size=(kernel_size, kernel_size),
                         strides=strides,
                         use_bias=False,
                         depth=depth,
                         kernel_initializer=kernel_initializer,
                         multfac_initializer=multfac_initializer,
                         la=la,
                         padding='same')

    penalty = tf.keras.regularizers.L2(l2=la) if reg == 'l2' else tf.keras.regularizers.L1(l1=la)
    return tf.keras.layers.Conv2D(filters=filters,
                                  kernel_size=kernel_size,
                                  padding='same',
                                  kernel_regularizer=penalty,
                                  kernel_initializer='he_normal',
                                  use_bias=False,
                                  strides=strides,
                                  *args, **kwargs)

def bn_relu(x):
    """Apply a fresh BatchNormalization layer followed by a fresh ReLU layer to ``x``."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activation = tf.keras.layers.ReLU()
    return activation(normalized)


def shortcut(x, filters, stride, mode, depth,  kernel_initializer, multfac_initializer, la, reg):
    """Build the skip connection of a residual block.

    If the last dimension of ``x`` already equals ``filters`` the input is
    returned untouched (``stride`` and ``mode`` are not consulted). Otherwise:

    * ``'B'``          -- 1x1 convolution from ``regularized_padded_conv``.
    * ``'B_original'`` -- the same 1x1 convolution followed by batch normalization.
    * ``'A'``          -- zero-pads the last axis up to ``filters``; when
      ``stride > 1`` the input first goes through ``MaxPool2D(1, stride)``.

    Raises:
        KeyError: if ``mode`` is none of the values above.
    """
    in_channels = x.shape[-1]
    if in_channels == filters:
        return x

    if mode in ('B', 'B_original'):
        projection = regularized_padded_conv(filters=filters, kernel_size=1, strides=stride, depth=depth,
                                             kernel_initializer=kernel_initializer,
                                             multfac_initializer=multfac_initializer, la=la, reg=reg)
        projected = projection(x)
        if mode == 'B_original':
            projected = tf.keras.layers.BatchNormalization()(projected)
        return projected

    if mode == 'A':
        subsampled = x
        if stride > 1:
            # Pool size 1 with the given stride only subsamples spatial positions.
            subsampled = tf.keras.layers.MaxPool2D(1, stride)(x)
        missing_channels = filters - in_channels
        return tf.pad(subsampled, paddings=[(0, 0), (0, 0), (0, 0), (0, missing_channels)])

    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block(x, filters, depth, kernel_initializer, multfac_initializer, la, reg, stride=1, **kwargs):
    """Residual block: conv -> BN/ReLU -> conv -> BN, added to the shortcut, then ReLU.

    The first 3x3 convolution uses ``stride``; the second always uses stride 1.
    The shortcut mode is read from the module-level ``_shortcut_type``, with
    ``'B'`` mapped to ``'B_original'`` (projection followed by batch norm).
    Any additional keyword arguments are accepted and ignored.
    """
    conv_options = dict(filters=filters, kernel_size=3, depth=depth,
                        kernel_initializer=kernel_initializer,
                        multfac_initializer=multfac_initializer, la=la, reg=reg)

    residual = regularized_padded_conv(strides=stride, **conv_options)(x)
    residual = regularized_padded_conv(strides=1, **conv_options)(bn_relu(residual))
    residual = tf.keras.layers.BatchNormalization()(residual)

    shortcut_mode = _shortcut_type
    if shortcut_mode == 'B':
        shortcut_mode = 'B_original'
    skip = shortcut(x, filters, stride, mode=shortcut_mode, depth=depth,
                    kernel_initializer=kernel_initializer,
                    multfac_initializer=multfac_initializer, la=la, reg=reg)

    merged = skip + residual
    return tf.keras.layers.ReLU()(merged)
    
    
def preactivation_block(x, filters, depth, kernel_initializer, multfac_initializer, la, reg, stride=1, preact_block=False):
    """Pre-activation residual block: BN/ReLU -> conv -> [dropout] -> BN/ReLU -> conv, plus shortcut.

    Args:
        x: input tensor.
        filters: number of output filters of both 3x3 convolutions.
        depth, kernel_initializer, multfac_initializer, la, reg: forwarded to
            ``regularized_padded_conv`` and ``shortcut``.
        stride: stride of the first convolution and of the shortcut.
        preact_block: when true, the shortcut branches off after the initial
            BN/ReLU instead of from the raw input.

    Reads the module-level ``_dropout`` (a Dropout layer with that rate is
    inserted after the first convolution when it is truthy) and
    ``_shortcut_type`` (shortcut mode).

    Returns:
        The sum of the shortcut and the second convolution's output.
    """
    flow = bn_relu(x)
    if preact_block:
        x = flow

    c1 = regularized_padded_conv(filters=filters, kernel_size=3, strides=stride, depth=depth,
                                 kernel_initializer=kernel_initializer,
                                 multfac_initializer=multfac_initializer, la=la, reg=reg)(flow)
    if _dropout:
        c1 = tf.keras.layers.Dropout(_dropout)(c1)

    c2 = regularized_padded_conv(filters=filters, kernel_size=3, strides=1, depth=depth,
                                 kernel_initializer=kernel_initializer,
                                 multfac_initializer=multfac_initializer, la=la, reg=reg)(bn_relu(c1))
    # Forward the initializers under the names `shortcut` actually accepts; the
    # previous `init_type=init_type, init=init` referenced undefined names.
    x = shortcut(x, filters, stride, mode=_shortcut_type, depth=depth,
                 kernel_initializer=kernel_initializer,
                 multfac_initializer=multfac_initializer, la=la, reg=reg)
    return x + c2


def bottleneck_block(x, filters, depth, kernel_initializer, multfac_initializer, la, reg, stride=1, preact_block=False):
    """Pre-activation bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The two inner convolutions use ``filters // _bottleneck_width`` filters
    (``_bottleneck_width`` is a module-level setting); the 3x3 convolution
    carries ``stride``. Each convolution is preceded by BN/ReLU. When
    ``preact_block`` is true the shortcut starts from the pre-activated input
    rather than the raw input. The shortcut mode comes from the module-level
    ``_shortcut_type``.
    """
    pre_activated = bn_relu(x)
    skip_source = pre_activated if preact_block else x

    shared = dict(depth=depth, kernel_initializer=kernel_initializer,
                  multfac_initializer=multfac_initializer, la=la, reg=reg)
    inner_filters = filters // _bottleneck_width

    reduced = regularized_padded_conv(filters=inner_filters, kernel_size=1, strides=1, **shared)(pre_activated)
    spatial = regularized_padded_conv(filters=inner_filters, kernel_size=3, strides=stride, **shared)(bn_relu(reduced))
    expanded = regularized_padded_conv(filters=filters, kernel_size=1, strides=1, **shared)(bn_relu(spatial))

    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type, **shared)
    return skip + expanded


def group_of_blocks(x, block_type, num_blocks, filters, stride, depth, kernel_initializer, multfac_initializer, la, reg, block_idx=0):
    """Chain ``num_blocks`` residual blocks built by the callable ``block_type``.

    The first block receives ``stride`` and a ``preact_block`` flag that is
    true when the module-level ``_preact_shortcuts`` is truthy or when
    ``block_idx`` is 0. The remaining ``num_blocks - 1`` blocks are called with
    stride 1 and without a ``preact_block`` argument.
    """
    shared = dict(filters=filters, depth=depth, kernel_initializer=kernel_initializer,
                  multfac_initializer=multfac_initializer, la=la, reg=reg)

    first_preact = bool(_preact_shortcuts or block_idx == 0)
    x = block_type(x, stride=stride, preact_block=first_preact, **shared)

    for _ in range(num_blocks - 1):
        x = block_type(x, stride=1, **shared)
    return x

# ResNet model definition
def Resnet(input_shape, n_classes, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
           shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
           dropout=0, cardinality=1, bottleneck_width=4, preact_shortcuts=True, name = 'ResNet',
           use_bias=True, factorize_bias=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal, multfac_initializer=tf.keras.initializers.Ones, la=0, reg='l2'):
    """Assemble a ResNet-style ``tf.keras.Model``.

    Layout: stem convolution (configured by ``first_conv``) -> one group of
    residual blocks per ``(group_size, feature, stride)`` triple -> global
    average pooling -> softmax ``Dense`` layer with ``n_classes`` units.

    Args:
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of units of the softmax output layer.
        group_sizes, features, strides: per-group block count, filter count
            and stride of the group's first block (zipped together).
        shortcut_type: shortcut mode, published as ``_shortcut_type``.
        block_type: one of 'preactivated', 'bottleneck', 'original'; any
            other value raises ``KeyError``.
        first_conv: keyword arguments (filters, kernel_size, strides) for the
            stem convolution. The default dict is shared between calls and is
            not modified here.
        dropout, cardinality, bottleneck_width, preact_shortcuts: published as
            module-level settings (``_dropout`` etc.) that the block builders
            read.
        name: model name.
        use_bias: whether the output ``Dense`` layer has a bias.
        factorize_bias: accepted but not used by this function.
        depth, kernel_initializer, multfac_initializer, la, reg: forwarded to
            the convolution builders; ``la``/``reg`` also define the L2
            (``reg == 'l2'``) or L1 regularizer of the output layer.

    For 'original' blocks BN/ReLU is applied right after the stem; for the
    other block types it is applied after the last group instead.
    """
    global _shortcut_type, _preact_projection, _dropout, _cardinality, _bottleneck_width, _preact_shortcuts
    # Publish the settings the block builders read as module-level globals.
    _shortcut_type = shortcut_type          # used in blocks
    _dropout = dropout                      # used in Wide ResNets
    _preact_shortcuts = preact_shortcuts
    _bottleneck_width = bottleneck_width    # used in ResNeXts and bottleneck blocks
    _cardinality = cardinality              # used in ResNeXts

    builders = {'preactivated': preactivation_block,
                'bottleneck': bottleneck_block,
                'original': original_block}
    selected_block = builders[block_type]
    is_original = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    stem = regularized_padded_conv(depth=depth, kernel_initializer=kernel_initializer,
                                   multfac_initializer=multfac_initializer, la=la, reg=reg, **first_conv)
    flow = stem(inputs)
    if is_original:
        flow = bn_relu(flow)

    stages = zip(group_sizes, features, strides)
    for stage_idx, (n_blocks, n_filters, stage_stride) in enumerate(stages):
        flow = group_of_blocks(flow,
                               block_type=selected_block,
                               num_blocks=n_blocks,
                               filters=n_filters,
                               stride=stage_stride,
                               depth=depth,
                               kernel_initializer=kernel_initializer,
                               multfac_initializer=multfac_initializer,
                               la=la,
                               reg=reg,
                               block_idx=stage_idx)

    if not is_original:
        flow = bn_relu(flow)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(flow)

    # The classification head is a plain Dense layer for every `depth`.
    if reg == 'l2':
        head_regularizer = tf.keras.regularizers.L2(l2=la)
    else:
        head_regularizer = tf.keras.regularizers.L1(l1=la)
    outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=head_regularizer, activation='softmax',
                                    use_bias=use_bias)(pooled)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)


def load_weights_func(model, model_name):
    """Load weights from ``saved_models/<model_name>.tf`` into ``model``.

    A ``tf.errors.NotFoundError`` is reported with a printed message instead of
    being raised. The (possibly unchanged) model is returned in either case.
    """
    weights_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(weights_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model

# ResNet architecture wrappers
# small ResNet18 with 0.2 mio params
def small_hadamard_resnet18(block_type='original', shortcut_type='B', load_weights=False,
                            use_bias=True, factorize_bias=True, depth=2,
                            kernel_initializer=tf.keras.initializers.HeNormal,
                            multfac_initializer=tf.keras.initializers.Ones, la=0,
                            input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build the narrow ResNet18 variant: three groups of two blocks with 16/32/64 filters.

    All arguments are forwarded to ``Resnet``. With ``load_weights`` set, the
    weights stored under the name 'hadamard_16f_resnet18' are loaded through
    ``load_weights_func``.
    """
    model = Resnet(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2),
        features=(16, 32, 64),
        strides=(1, 2, 2),
        first_conv={"filters":16, "kernel_size": 3, "strides": 1},
        shortcut_type=shortcut_type,
        block_type=block_type,
        preact_shortcuts=False,
        name='small_hadamard_cifar_resnet18',
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_16f_resnet18')

# standard ResNet18 with 11.174 mio params
def hadamard_resnet18(block_type='original', shortcut_type='B', load_weights=False,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet18: four groups of two blocks with 64/128/256/512 filters.

    All arguments are forwarded to ``Resnet``. With ``load_weights`` set, the
    weights stored under the name 'hadamard_64f_resnet18' are loaded through
    ``load_weights_func``.
    """
    architecture = dict(
        group_sizes=(2, 2, 2, 2),
        features=(64, 128, 256, 512),
        strides=(1, 2, 2, 2),
        first_conv={"filters":64, "kernel_size": 3, "strides": 1},
        preact_shortcuts=False,
        name='hadamard_cifar_resnet18',
    )
    model = Resnet(input_shape=input_shape, n_classes=n_classes,
                   shortcut_type=shortcut_type, block_type=block_type,
                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                   kernel_initializer=kernel_initializer,
                   multfac_initializer=multfac_initializer, la=la, reg=reg,
                   **architecture)
    if load_weights:
        model = load_weights_func(model, 'hadamard_64f_resnet18')
    return model

# standard ResNet34 with 22.282 mio params
def hadamard_resnet34(block_type='original', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, kernel_initializer=tf.keras.initializers.HeNormal, multfac_initializer=tf.keras.initializers.Ones, la=0,
                   input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet34: groups of 3/4/6/3 blocks with 64/128/256/512 filters.

    All arguments are forwarded to ``Resnet``, including ``reg`` (selects the
    L2 or L1 penalty). With ``load_weights`` set, the weights stored under the
    name 'hadamard_64f_resnet34' are loaded through ``load_weights_func``.
    """
    # `reg` is forwarded as given; it used to be pinned to 'l2', which silently
    # ignored the caller's choice of regularizer.
    model = Resnet(input_shape=input_shape, n_classes=n_classes,
                   group_sizes=(3, 4, 6, 3), features=(64, 128, 256, 512),
                   strides=(1, 2, 2, 2), first_conv={"filters":64, "kernel_size": 3, "strides": 1}, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False, name = 'hadamard_cifar_resnet34',
                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth, kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer, la=la, reg=reg)
    if load_weights: model = load_weights_func(model, 'hadamard_64f_resnet34')
    return model

# standard ResNet50 with 23.521 mio params
def hadamard_resnet50(block_type='bottleneck', shortcut_type='B', load_weights=False,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet50: groups of 3/4/6/3 blocks with 256/512/1024/2048 filters.

    Defaults to bottleneck blocks. All arguments are forwarded to ``Resnet``.
    With ``load_weights`` set, the weights stored under the name
    'hadamard_64f_resnet50' are loaded through ``load_weights_func``.
    """
    model = Resnet(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(3, 4, 6, 3),
        features=(256,512,1024,2048),
        strides=(1, 2, 2, 2),
        first_conv={"filters":256, "kernel_size": 3, "strides": 1},
        shortcut_type=shortcut_type,
        block_type=block_type,
        preact_shortcuts=False,
        name='hadamard_cifar_resnet50',
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_64f_resnet50')

# standard ResNet101
def hadamard_resnet101(block_type='bottleneck', shortcut_type='B', load_weights=False,
                       use_bias=True, factorize_bias=True, depth=2,
                       kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0,
                       input_shape=(32, 32,3), n_classes=10,reg='l2'):
    """Build ResNet101: groups of 3/4/23/3 blocks with 256/512/1024/2048 filters.

    Defaults to bottleneck blocks. All arguments are forwarded to ``Resnet``.
    With ``load_weights`` set, the weights stored under the name
    'hadamard_64f_resnet101' are loaded through ``load_weights_func``.
    """
    architecture = dict(
        group_sizes=(3, 4, 23, 3),
        features=(256,512,1024,2048),
        strides=(1, 2, 2, 2),
        first_conv={"filters":256, "kernel_size": 3, "strides": 1},
        preact_shortcuts=False,
        name='hadamard_cifar_resnet101',
    )
    model = Resnet(input_shape=input_shape, n_classes=n_classes,
                   shortcut_type=shortcut_type, block_type=block_type,
                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                   kernel_initializer=kernel_initializer,
                   multfac_initializer=multfac_initializer, la=la, reg=reg,
                   **architecture)
    if load_weights:
        model = load_weights_func(model, 'hadamard_64f_resnet101')
    return model

# standard ResNet20 for cifar-like images
def hadamard_resnet20(block_type='original', shortcut_type='A', load_weights=False,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet20: three groups of three blocks with 16/32/64 filters.

    Defaults to shortcut type 'A'. All arguments are forwarded to ``Resnet``.
    With ``load_weights`` set, the weights stored under the name
    'hadamard_cifar_resnet20' are loaded through ``load_weights_func``.
    """
    model = Resnet(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(3, 3, 3),
        features=(16, 32, 64),
        strides=(1, 2, 2),
        first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
        shortcut_type=shortcut_type,
        block_type=block_type,
        preact_shortcuts=False,
        name='hadamard_cifar_resnet20',
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet20')


def hadamard_resnet32(block_type='original', shortcut_type='A',  load_weights=False,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet32: three groups of five blocks with 16/32/64 filters.

    Defaults to shortcut type 'A'. All arguments are forwarded to ``Resnet``.
    With ``load_weights`` set, the weights stored under the name
    'hadamard_cifar_resnet32' are loaded through ``load_weights_func``.
    """
    architecture = dict(
        group_sizes=(5, 5, 5),
        features=(16, 32, 64),
        strides=(1, 2, 2),
        first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
        preact_shortcuts=False,
        name='hadamard_cifar_resnet32',
    )
    model = Resnet(input_shape=input_shape, n_classes=n_classes,
                   shortcut_type=shortcut_type, block_type=block_type,
                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                   kernel_initializer=kernel_initializer,
                   multfac_initializer=multfac_initializer, la=la, reg=reg,
                   **architecture)
    if load_weights:
        model = load_weights_func(model, 'hadamard_cifar_resnet32')
    return model


def hadamard_resnet44(block_type='original', shortcut_type='A', load_weights=False,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10, reg='l2'):
    """Build ResNet44: three groups of seven blocks with 16/32/64 filters.

    Defaults to shortcut type 'A'. All arguments are forwarded to ``Resnet``.
    With ``load_weights`` set, the weights stored under the name
    'hadamard_cifar_resnet44' are loaded through ``load_weights_func``.
    """
    model = Resnet(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(7, 7, 7),
        features=(16, 32, 64),
        strides=(1, 2, 2),
        first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
        shortcut_type=shortcut_type,
        block_type=block_type,
        preact_shortcuts=False,
        name='hadamard_cifar_resnet44',
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet44')


def hadamard_wide_resnet(N, K, block_type='preactivated', shortcut_type='B', dropout=0,
                         use_bias=True, factorize_bias=True, depth=2,
                         kernel_initializer=tf.keras.initializers.HeNormal,
                         multfac_initializer=tf.keras.initializers.Ones, la=0,
                         input_shape=(32, 32,3), n_classes=10,reg='l2'):
    """Build a Wide ResNet with total conv-layer count ``N`` and widening factor ``K``.

    ``N - 4`` must be divisible by 6 (checked with an ``assert``); each of the
    three groups then holds ``(N - 4) // 6`` blocks with ``16*K``, ``32*K`` and
    ``64*K`` filters. The stem keeps 16 filters. Remaining arguments are
    forwarded to ``Resnet`` with ``preact_shortcuts=True``.
    """
    assert (N-4) % 6 == 0, "N-4 has to be divisible by 6"
    # N counts all convolutional layers of the network, hence the division by 6.
    blocks_per_group = (N-4) // 6
    widths = (16*K, 32*K, 64*K)
    model_name = 'hadamard_cifar_WRN_' + str(N) + '_' + str(K)

    return Resnet(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(blocks_per_group, blocks_per_group, blocks_per_group),
        features=widths,
        strides=(1, 2, 2),
        first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
        shortcut_type=shortcut_type,
        block_type=block_type,
        dropout=dropout,
        preact_shortcuts=True,
        name=model_name,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        reg=reg,
    )


def hadamard_WRN_16_4(shortcut_type='B', load_weights=False, dropout=0,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10,reg='l2'):
    """Build WRN-16-4 (``N=16``, ``K=4``) with pre-activated blocks.

    With ``load_weights`` set, the weights stored under the name
    'cifar_WRN_16_4' are loaded through ``load_weights_func``.
    """
    model = hadamard_wide_resnet(
        16, 4,
        block_type='preactivated',
        shortcut_type=shortcut_type,
        dropout=dropout,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        input_shape=input_shape,
        n_classes=n_classes,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_WRN_16_4')



def hadamard_WRN_16_8(shortcut_type='B', load_weights=False, dropout=0,
                      use_bias=True, factorize_bias=True, depth=2,
                      kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, la=0,
                      input_shape=(32, 32,3), n_classes=10,reg='l2'):
    """Build WRN-16-8 (``N=16``, ``K=8``) with pre-activated blocks.

    With ``load_weights`` set, the weights stored under the name
    'cifar_WRN_16_8' are loaded through ``load_weights_func``.
    """
    forwarded = dict(
        dropout=dropout,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        input_shape=input_shape,
        n_classes=n_classes,
        reg=reg,
    )
    model = hadamard_wide_resnet(16, 8, 'preactivated', shortcut_type, **forwarded)
    if load_weights:
        model = load_weights_func(model, 'cifar_WRN_16_8')
    return model


def hadamard_WRN_28_10(shortcut_type='B', load_weights=False, dropout=0,
                       use_bias=True, factorize_bias=True, depth=2,
                       kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0,
                       input_shape=(32, 32,3), n_classes=10,reg='l2'):
    """Build WRN-28-10 (``N=28``, ``K=10``) with pre-activated blocks.

    With ``load_weights`` set, the weights stored under the name
    'cifar_WRN_28_10' are loaded through ``load_weights_func``.
    """
    model = hadamard_wide_resnet(
        28, 10,
        block_type='preactivated',
        shortcut_type=shortcut_type,
        dropout=dropout,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        input_shape=input_shape,
        n_classes=n_classes,
        reg=reg,
    )
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_WRN_28_10')


#def cifar_resnext(N, cardinality, width, shortcut_type='B', use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal(), la=0,):
#    assert (N-3) % 9 == 0, "N-4 has to be divisible by 6"
#    lpb = (N-3) // 9 # layers per block - since N is total number of convolutional layers in Wide ResNet
#    model = Resnet(input_shape=(32, 32, 3), n_classes=10, l2_reg=1e-4, group_sizes=(lpb, lpb, lpb), features=(16*width, 32*width, 64*width),
#                   strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, shortcut_type=shortcut_type,
#                   block_type='resnext', cardinality=cardinality, bottleneck_width=width,
#                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth, init_type=init_type, init=init, la=la)
#    return model

#mod = hadamard_resnet18(depth=1)
#mod.summary()

#mod2 = hadamard_resnet56()
#mod2.summary()

#mod3 = hadamard_WRN_16_8()
#mod3.summary()

# Different WRN-16-8 implementation (to check for equivalence)

# Hadamard WideResNet function (WRN-16-8 has dep=16 k=8)
# remove IN_FILTERS manual definitions when outside paperspace
# Initial input-channel count for the WideResNet builders below (equals their n_filters[0] = 16).
IN_FILTERS=16 
def WideResNet(dep=16, k=8, input_shape = (32, 32, 3), n_classes = 10, depth=1, init_type='equivar', init = tf.keras.initializers.HeNormal(), la=0,
              use_bias=True, factorize_bias=True):
    """Build a Wide ResNet from ``HadamardConv2D`` / ``HadamardDense`` layers.

    Layout: 3x3 stem with 16 filters -> three stages with ``16*k``, ``32*k``
    and ``64*k`` filters (the 2nd and 3rd use stride 2 in their first block),
    each holding ``(dep - 4) // 6`` pre-activation residual blocks -> BN/ReLU
    -> 8x8 average pooling -> flatten -> softmax ``HadamardDense``.

    Args:
        dep: network depth; determines the number of blocks per stage.
        k: widening factor.
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of output units.
        depth, init_type, init, la: forwarded to every Hadamard layer.
        use_bias, factorize_bias: forwarded to the output ``HadamardDense``
            only; all convolutions are bias-free.

    Returns:
        A ``tf.keras.Model`` named "WideResNet".

    The running input-channel count is tracked in a closure variable, so every
    call starts from the stem width. The nested helpers previously declared
    ``global IN_FILTERS`` and therefore read and wrote the module-level
    variable, leaking the value left behind by an earlier build into the next.
    """
    print('Wide-Resnet %dx%d' %(dep, k))
    n_filters  = [16, 16*k, 32*k, 64*k]
    n_stack    = (dep - 4) // 6
    # Channels entering the current stage; updated by wide_residual_layer.
    in_filters = n_filters[0]

    def conv3x3(x,filters, depth=1, init_type='equivar', init = tf.keras.initializers.HeUniform(), la=0):
        return HadamardConv2D(filters=filters, kernel_size=(3,3), strides=1, padding='same', use_bias=False,
                              depth=depth, init_type=init_type, init=init, la=la)(x)

    def bn_relu(x):
        x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = tf.keras.layers.Activation('relu')(x)
        return x

    def residual_block(x, out_filters, increase=False, depth=1, init_type='equivar', init = tf.keras.initializers.HeUniform(), la=0):
        stride = 1
        if increase:
            stride = 2

        o1 = bn_relu(x)
        conv_1 = HadamardConv2D(filters=out_filters, kernel_size=(3,3), strides=stride,padding='same',use_bias=False,
                                depth=depth,init_type=init_type,init=init,la=la)(o1)
        o2 = bn_relu(conv_1)
        conv_2 = HadamardConv2D(filters=out_filters,kernel_size=(3,3),strides=1,padding='same',use_bias=False,
                                depth=depth,init_type=init_type,init=init,la=la)(o2)
        # A 1x1 projection is needed whenever the spatial size or channel count changes.
        if increase or in_filters != out_filters:
            proj = HadamardConv2D(filters=out_filters,kernel_size=(1,1),strides=stride,padding='same',use_bias=False,
                                depth=depth,init_type=init_type,init=init,la=la)(o1)
            block = tf.keras.layers.add([conv_2, proj])
        else:
            block = tf.keras.layers.add([conv_2,x])
        return block

    def wide_residual_layer(x,out_filters,increase=False, depth=1, init_type='equivar', init = tf.keras.initializers.HeUniform(), la=0):
        nonlocal in_filters
        x = residual_block(x,out_filters,increase, depth=depth, init_type=init_type, init = init, la=la)
        # After the first block the stage operates at its output width.
        in_filters = out_filters
        for _ in range(1,int(n_stack)):
            x = residual_block(x,out_filters, depth=depth, init_type=init_type, init = init, la=la)
        return x

    x_input = tf.keras.layers.Input(input_shape)
    x = conv3x3(x_input,n_filters[0], depth=depth, init_type=init_type, init = init, la=la)
    x = wide_residual_layer(x,n_filters[1], depth=depth, init_type=init_type, init = init, la=la)
    x = wide_residual_layer(x,n_filters[2], increase=True, depth=depth, init_type=init_type, init = init, la=la)
    x = wide_residual_layer(x,n_filters[3], increase=True, depth=depth, init_type=init_type, init = init, la=la)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.AveragePooling2D((8,8))(x)
    x = tf.keras.layers.Flatten()(x)
    x = HadamardDense(units=n_classes,activation='softmax',depth=depth,use_bias=use_bias,factorize_bias=factorize_bias,
                      init_type=init_type,init=init,la=la)(x)

    # define model
    model = tf.keras.models.Model(inputs = x_input, outputs = x, name = "WideResNet")
    return model

# Original WideResNet function (WRN-16-8 has dep=16 k=8)
# remove IN_FILTERS manual definitions when outside paperspace
def WideResNetVanilla(dep=16, k=8, input_shape = (32, 32, 3), n_classes = 10, la=0,
              use_bias=True):
    """Build a Wide ResNet from plain Keras layers with L2 kernel regularization.

    Layout: 3x3 stem with 16 filters -> three stages with ``16*k``, ``32*k``
    and ``64*k`` filters (the 2nd and 3rd use stride 2 in their first block),
    each holding ``(dep - 4) // 6`` pre-activation residual blocks -> BN/ReLU
    -> 8x8 average pooling -> flatten -> softmax ``Dense``.

    Args:
        dep: network depth; determines the number of blocks per stage.
        k: widening factor.
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of output units.
        la: L2 regularization strength applied to every kernel.
        use_bias: accepted but not used; every layer, including the output
            ``Dense`` layer, is created with ``use_bias=False``.

    Returns:
        A ``tf.keras.Model`` named "WideResNet".

    The running input-channel count is tracked in a closure variable, so every
    call starts from the stem width. The nested helpers previously declared
    ``global IN_FILTERS`` and therefore read and wrote the module-level
    variable, leaking the value left behind by an earlier build into the next.
    """
    print('Wide-Resnet %dx%d' %(dep, k))
    n_filters  = [16, 16*k, 32*k, 64*k]
    n_stack    = (dep - 4) // 6
    # Channels entering the current stage; updated by wide_residual_layer.
    in_filters = n_filters[0]
    l2 = tf.keras.regularizers.L2

    def conv3x3(x,filters, depth=1, la=0):
        return tf.keras.layers.Conv2D(filters=filters, kernel_size=(3,3), strides=(1,1), padding='same',
                      kernel_initializer='he_normal',kernel_regularizer=l2(la),use_bias=False)(x)

    def bn_relu(x):
        x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = tf.keras.layers.Activation('relu')(x)
        return x

    def residual_block(x, out_filters, increase=False, la=0):
        stride = 1
        if increase:
            stride = 2

        o1 = bn_relu(x)
        conv_1 = tf.keras.layers.Conv2D(out_filters,kernel_size=(3,3),strides=stride,padding='same',kernel_initializer='he_normal',
            kernel_regularizer=l2(la),use_bias=False)(o1)
        o2 = bn_relu(conv_1)
        conv_2 = tf.keras.layers.Conv2D(out_filters,kernel_size=(3,3), strides=(1,1), padding='same',kernel_initializer='he_normal',
            kernel_regularizer=l2(la),use_bias=False)(o2)
        # A 1x1 projection is needed whenever the spatial size or channel count changes.
        if increase or in_filters != out_filters:
            proj = tf.keras.layers.Conv2D(out_filters,kernel_size=(1,1),strides=stride,padding='same',kernel_initializer='he_normal',
                                kernel_regularizer=l2(la),use_bias=False)(o1)
            block = tf.keras.layers.add([conv_2, proj])
        else:
            block = tf.keras.layers.add([conv_2,x])
        return block

    def wide_residual_layer(x,out_filters,increase=False, la=0):
        nonlocal in_filters
        x = residual_block(x,out_filters,increase, la=la)
        # After the first block the stage operates at its output width.
        in_filters = out_filters
        for _ in range(1,int(n_stack)):
            x = residual_block(x,out_filters, la=la)
        return x

    x_input = tf.keras.layers.Input(input_shape)
    x = conv3x3(x_input,n_filters[0], la=la)
    x = wide_residual_layer(x,n_filters[1], la=la)
    x = wide_residual_layer(x,n_filters[2], increase=True,la=la)
    x = wide_residual_layer(x,n_filters[3], increase=True,la=la)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.AveragePooling2D((8,8))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(n_classes,activation='softmax',kernel_initializer='he_normal',kernel_regularizer=l2(la),use_bias=False)(x)

    # define model
    model = tf.keras.models.Model(inputs = x_input, outputs = x, name = "WideResNet")
    return model


###################### Vanilla ResNet-18 (cifar version) #########################

# Components for ResNet construction
def regularized_padded_conv_v(filters, kernel_size, strides, la, reg='l2', *args, **kwargs):
    """Create a bias-free, 'same'-padded ``Conv2D`` with 'he_normal' initialization.

    The kernel regularizer has strength ``la`` and is L2 when ``reg == 'l2'``,
    L1 for any other value. Extra ``*args`` / ``**kwargs`` go to ``Conv2D``.
    """
    penalty = tf.keras.regularizers.L2(l2=la) if reg == 'l2' else tf.keras.regularizers.L1(l1=la)
    return tf.keras.layers.Conv2D(filters=filters,
                                  kernel_size=kernel_size,
                                  padding='same',
                                  kernel_regularizer=penalty,
                                  kernel_initializer='he_normal',
                                  use_bias=False,
                                  strides=strides,
                                  *args, **kwargs)

def bn_relu(x):
    """Apply a fresh BatchNormalization layer followed by a fresh ReLU layer to ``x``."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activation = tf.keras.layers.ReLU()
    return activation(normalized)


def shortcut_v(x, filters, stride, mode, la, reg):
    """Build the shortcut (skip) branch of a residual block.

    When the channel dimension of ``x`` already equals ``filters`` the input
    is returned untouched (``stride`` is not considered in that case).

    Args:
        x: Input tensor; only its last dimension is inspected here.
        filters: Number of channels the shortcut must deliver.
        stride: Stride of the projection / pooling.
        mode: ``'B'`` for a 1x1 projection convolution, ``'B_original'`` for
            a 1x1 projection followed by batch normalization, ``'A'`` for
            stride-pooling plus zero-padding of the channel axis.
        la: Regularization strength for the projection convolution.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.

    Returns:
        The shortcut tensor.

    Raises:
        KeyError: If the channel counts differ and ``mode`` is not one of the
            supported values.
    """
    in_channels = x.shape[-1]
    if in_channels == filters:
        return x

    if mode in ('B', 'B_original'):
        projected = regularized_padded_conv_v(filters=filters, kernel_size=1, strides=stride, la=la, reg=reg)(x)
        if mode == 'B':
            return projected
        return tf.keras.layers.BatchNormalization()(projected)

    if mode == 'A':
        # A pool of size 1 with the given stride subsamples spatially.
        if stride > 1:
            pooled = tf.keras.layers.MaxPool2D(1, stride)(x)
        else:
            pooled = x
        missing_channels = filters - x.shape[-1]
        return tf.pad(pooled, paddings=[(0, 0), (0, 0), (0, 0), (0, missing_channels)])

    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block_v(x, filters, la, reg, stride=1, **kwargs):
    """Build a post-activation residual block (conv-BN-ReLU-conv-BN + skip, ReLU).

    The shortcut mode is read from the module global ``_shortcut_type``;
    ``'B'`` is mapped to ``'B_original'`` so the projection is batch-normalized.

    Args:
        x: Input tensor.
        filters: Number of feature maps of both 3x3 convolutions.
        la: Regularization strength.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.
        stride: Stride of the first convolution and of the shortcut.
        **kwargs: Accepted and ignored (e.g. ``preact_block``).

    Returns:
        The block output tensor.
    """
    first_conv = regularized_padded_conv_v(filters=filters, kernel_size=3, strides=stride, la=la, reg=reg)
    c1 = first_conv(x)
    second_conv = regularized_padded_conv_v(filters=filters, kernel_size=3, strides=1, la=la, reg=reg)
    c2 = second_conv(bn_relu(c1))
    residual = tf.keras.layers.BatchNormalization()(c2)

    if _shortcut_type == 'B':
        shortcut_mode = 'B_original'
    else:
        shortcut_mode = _shortcut_type
    skip = shortcut_v(x, filters, stride, mode=shortcut_mode, la=la, reg=reg)

    final_relu = tf.keras.layers.ReLU()
    return final_relu(skip + residual)
    
    
def preactivation_block_v(x, filters, la, reg, stride=1, preact_block=False):
    """Build a pre-activation residual block (BN-ReLU-conv-BN-ReLU-conv + skip).

    Reads the module globals ``_dropout`` (dropout rate inserted after the
    first convolution when truthy) and ``_shortcut_type``.

    Args:
        x: Input tensor.
        filters: Number of feature maps of both 3x3 convolutions.
        la: Regularization strength.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.
        stride: Stride of the first convolution and of the shortcut.
        preact_block: When true, the shortcut branches off after the initial
            BN-ReLU instead of from the raw input.

    Returns:
        The sum of the shortcut and the residual branch.
    """
    preactivated = bn_relu(x)
    skip_source = preactivated if preact_block else x

    conv_a = regularized_padded_conv_v(filters=filters, kernel_size=3, strides=stride, la=la, reg=reg)
    c1 = conv_a(preactivated)
    if _dropout:
        c1 = tf.keras.layers.Dropout(_dropout)(c1)

    conv_b = regularized_padded_conv_v(filters=filters, kernel_size=3, strides=1, la=la, reg=reg)
    c2 = conv_b(bn_relu(c1))

    skip = shortcut_v(skip_source, filters, stride, mode=_shortcut_type, la=la, reg=reg)
    return skip + c2


def bottleneck_block_v(x, filters, la, reg, stride=1, preact_block=False):
    """Build a pre-activation bottleneck block (1x1 -> 3x3 -> 1x1 + skip).

    The two inner convolutions use ``filters // _bottleneck_width`` feature
    maps; ``_bottleneck_width`` and ``_shortcut_type`` are module globals.

    Args:
        x: Input tensor.
        filters: Number of output feature maps of the block.
        la: Regularization strength.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.
        stride: Stride of the 3x3 convolution and of the shortcut.
        preact_block: When true, the shortcut branches off after the initial
            BN-ReLU instead of from the raw input.

    Returns:
        The sum of the shortcut and the residual branch.
    """
    preactivated = bn_relu(x)
    skip_source = preactivated if preact_block else x

    inner_width = filters // _bottleneck_width
    reduce_conv = regularized_padded_conv_v(filters=inner_width, kernel_size=1, strides=1, la=la, reg=reg)
    c1 = reduce_conv(preactivated)
    spatial_conv = regularized_padded_conv_v(filters=inner_width, kernel_size=3, strides=stride, la=la, reg=reg)
    c2 = spatial_conv(bn_relu(c1))
    expand_conv = regularized_padded_conv_v(filters=filters, kernel_size=1, strides=1, la=la, reg=reg)
    c3 = expand_conv(bn_relu(c2))

    skip = shortcut_v(skip_source, filters, stride, mode=_shortcut_type, la=la, reg=reg)
    return skip + c3


def group_of_blocks_v(x, block_type, num_blocks, filters, stride, la, reg, block_idx=0):
    """Chain ``num_blocks`` residual blocks of the same width.

    Only the first block receives ``stride`` and the ``preact_block`` flag;
    the remaining blocks are built with stride 1 and their own default for
    ``preact_block``.

    Args:
        x: Input tensor.
        block_type: Callable building one block; called with the keywords
            ``filters``, ``la``, ``reg``, ``stride`` (and ``preact_block``
            for the first block).
        num_blocks: Number of blocks in the group.
        filters: Number of feature maps passed to every block.
        stride: Stride of the first block.
        la: Regularization strength.
        reg: Regularizer kind.
        block_idx: Index used for the pre-activation decision; the first
            block is pre-activated when this is 0 or when the module global
            ``_preact_shortcuts`` is truthy.

    Returns:
        The output tensor of the last block.
    """
    first_is_preact = bool(_preact_shortcuts or block_idx == 0)

    out = block_type(x, filters=filters, la=la, reg=reg, stride=stride, preact_block=first_is_preact)
    for _ in range(1, num_blocks):
        out = block_type(out, filters=filters, la=la, reg=reg, stride=1)
    return out

# ResNet model definition
def Resnet_v(input_shape, n_classes, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
             shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
             dropout=0, cardinality=1, bottleneck_width=4, preact_shortcuts=True, name='ResNet',
             use_bias=True, la=0, reg='l2'):
    """Assemble a plain (non-factorized) ResNet classifier as a Keras model.

    The block builders read their configuration from module-level globals,
    which this function overwrites on every call.

    Args:
        input_shape: Shape of one input sample (without the batch axis).
        n_classes: Number of units of the softmax output layer.
        group_sizes: Number of blocks per group.
        features: Number of feature maps per group.
        strides: Stride of the first block of each group.
        shortcut_type: Shortcut mode, see ``shortcut_v``.
        block_type: ``'preactivated'``, ``'bottleneck'`` or ``'original'``.
        first_conv: Keyword arguments of the stem convolution
            (``filters``, ``kernel_size``, ``strides``). Not modified.
        dropout: Stored in ``_dropout``; read by the pre-activation block.
        cardinality: Stored in ``_cardinality``; not read by the blocks
            visible in this part of the file.
        bottleneck_width: Stored in ``_bottleneck_width``; read by the
            bottleneck block.
        preact_shortcuts: Stored in ``_preact_shortcuts``; read by
            ``group_of_blocks_v``.
        name: Name of the returned model.
        use_bias: Whether the output Dense layer has a bias.
        la: Regularization strength for all kernels.
        reg: ``'l2'`` for L2 penalties, anything else for L1.

    Returns:
        A ``tf.keras.Model`` mapping inputs to class probabilities.

    Raises:
        KeyError: If ``block_type`` is not one of the supported names.
    """
    global _shortcut_type, _dropout, _cardinality, _bottleneck_width, _preact_shortcuts
    # Publish the configuration for the block-building functions.
    _bottleneck_width = bottleneck_width
    _shortcut_type = shortcut_type
    _cardinality = cardinality
    _dropout = dropout
    _preact_shortcuts = preact_shortcuts

    block_builders = {
        'preactivated': preactivation_block_v,
        'bottleneck': bottleneck_block_v,
        'original': original_block_v,
    }
    selected_block = block_builders[block_type]
    uses_original_blocks = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    stem = regularized_padded_conv_v(la=la, reg=reg, **first_conv)
    flow = stem(inputs)

    # Post-activation blocks expect an already normalized/activated input.
    if uses_original_blocks:
        flow = bn_relu(flow)

    for group_idx, group_cfg in enumerate(zip(group_sizes, features, strides)):
        group_size, feature, stride = group_cfg
        flow = group_of_blocks_v(flow, block_type=selected_block, num_blocks=group_size,
                                 filters=feature, stride=stride, la=la, reg=reg,
                                 block_idx=group_idx)

    # Pre-activation style blocks end on a raw sum, so activate once more.
    if not uses_original_blocks:
        flow = bn_relu(flow)

    flow = tf.keras.layers.GlobalAveragePooling2D()(flow)

    if reg != 'l2':
        reg_dense = tf.keras.regularizers.L1(l1=la)
    else:
        reg_dense = tf.keras.regularizers.L2(l2=la)

    classifier = tf.keras.layers.Dense(n_classes, kernel_regularizer=reg_dense, activation='softmax',
                                       use_bias=use_bias)
    outputs = classifier(flow)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)



# ResNet architecture wrappers
# Small ResNet-18 (16/32/64 feature maps), roughly 0.2 million parameters
def small_vanilla_resnet18(block_type='original', shortcut_type='B', load_weights=False,
                           use_bias=True, la=0,
                           input_shape=(32, 32, 3), n_classes=10, reg='l2'):
    """Build the small vanilla CIFAR ResNet-18 (three groups of 16/32/64 maps).

    Args:
        block_type: Block kind forwarded to ``Resnet_v``.
        shortcut_type: Shortcut mode forwarded to ``Resnet_v``.
        load_weights: When truthy, the model is passed through
            ``load_weights_func`` with the key ``'vanilla_16f_resnet18'``.
        use_bias: Whether the output Dense layer has a bias.
        la: Regularization strength.
        input_shape: Shape of one input sample.
        n_classes: Number of output classes.
        reg: ``'l2'`` for L2 penalties, anything else for L1.

    Returns:
        The model returned by ``Resnet_v`` (or by ``load_weights_func``).
    """
    architecture = dict(
        group_sizes=(2, 2, 2),
        features=(16, 32, 64),
        strides=(1, 2, 2),
        first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
        preact_shortcuts=False,
        name='small_vanilla_cifar_resnet18',
    )
    model = Resnet_v(input_shape=input_shape, n_classes=n_classes,
                     shortcut_type=shortcut_type, block_type=block_type,
                     use_bias=use_bias, la=la, reg=reg, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'vanilla_16f_resnet18')

# Standard ResNet-18 (64/128/256/512 feature maps), roughly 11.174 million parameters
def vanilla_resnet18(block_type='original', shortcut_type='B', load_weights=False,
                     use_bias=True, factorize_bias=True, depth=1, init_type='equivar',
                     init=tf.keras.initializers.HeNormal, la=0,
                     input_shape=(32, 32, 3), n_classes=100, reg='l2'):
    """Build the standard vanilla CIFAR ResNet-18 (four groups of 64..512 maps).

    Args:
        block_type: Block kind forwarded to ``Resnet_v``.
        shortcut_type: Shortcut mode forwarded to ``Resnet_v``.
        load_weights: When truthy, the model is passed through
            ``load_weights_func`` with the key ``'vanilla_64f_resnet18'``.
        use_bias: Whether the output Dense layer has a bias.
        factorize_bias: Accepted but not used by this function.
        depth: Accepted but not used by this function.
        init_type: Accepted but not used by this function.
        init: Accepted but not used by this function.
        la: Regularization strength.
        input_shape: Shape of one input sample.
        n_classes: Number of output classes.
        reg: ``'l2'`` for L2 penalties, anything else for L1.

    Returns:
        The model returned by ``Resnet_v`` (or by ``load_weights_func``).
    """
    architecture = dict(
        group_sizes=(2, 2, 2, 2),
        features=(64, 128, 256, 512),
        strides=(1, 2, 2, 2),
        first_conv={"filters": 64, "kernel_size": 3, "strides": 1},
        preact_shortcuts=False,
        name='vanilla_cifar_resnet18',
    )
    model = Resnet_v(input_shape=input_shape, n_classes=n_classes,
                     shortcut_type=shortcut_type, block_type=block_type,
                     use_bias=use_bias, la=la, reg=reg, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'vanilla_64f_resnet18')


# --- New Individual Variant Functions ---

def individual_original_block_v(x, block_filter, la, reg, stride=1, **kwargs):
    """Build a post-activation residual block with per-layer filter counts.

    ``block_filter`` lists the (non-negative) filter count of every
    convolution of the block:

      - length 2: ``[conv1_filter, conv2_filter]``; the skip branch is built
        by ``shortcut_v`` using the module global ``_shortcut_type``.
      - length 3: ``[conv1_filter, proj_filter, conv2_filter]``; the skip
        branch is a 1x1 projection followed by batch normalization.

    A count of 0 means the layer was pruned away:

      - ``conv1_filter == 0``: the whole block is skipped and ``x`` is
        returned unchanged (this check happens before the length check).
      - ``proj_filter == 0``: the raw input is used as the skip branch.
      - ``conv2_filter == 0``: the output of the first conv/BN/ReLU is used
        as the residual branch.

    Args:
        x: Input tensor.
        block_filter: Sequence of filter counts as described above.
        la: Regularization strength.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.
        stride: Stride of the first convolution and of the skip branch.
        **kwargs: Accepted and ignored.

    Returns:
        ``x`` itself when the block is skipped, otherwise
        ``ReLU(skip + residual)``.

    Raises:
        ValueError: If the block is not skipped and ``block_filter`` does not
            have length 2 or 3.
    """
    # --- SKIP ENTIRE BLOCK WHEN THE FIRST CONV IS PRUNED AWAY ---
    if block_filter[0] == 0:
        return x

    n_entries = len(block_filter)
    if n_entries not in (2, 3):
        raise ValueError("block_filter length must be 2 or 3.")
    has_projection = n_entries == 3
    conv2_filters = block_filter[-1]

    # First convolution of the main branch. Thanks to the guard above its
    # filter count is known to be non-zero here.
    c1 = regularized_padded_conv_v(filters=block_filter[0], kernel_size=3, strides=stride, la=la, reg=reg)(x)
    c1 = bn_relu(c1)

    # Explicit projection shortcut (3-entry form); built before the second
    # convolution so the layer creation order stays as it was.
    if has_projection:
        if block_filter[1] > 0:
            shortcut = regularized_padded_conv_v(filters=block_filter[1], kernel_size=1, strides=stride, la=la, reg=reg)(x)
            shortcut = tf.keras.layers.BatchNormalization()(shortcut)
        else:
            shortcut = x

    # Second convolution of the main branch.
    if conv2_filters > 0:
        c2 = regularized_padded_conv_v(filters=conv2_filters, kernel_size=3, strides=1, la=la, reg=reg)(c1)
        c2 = tf.keras.layers.BatchNormalization()(c2)
    else:
        c2 = c1

    if not has_projection:
        # Match the residual width, or fall back to the input width when the
        # second convolution is pruned.
        # NOTE(review): unlike original_block_v, 'B' is not mapped to
        # 'B_original' here, so a projection created by shortcut_v has no
        # batch normalization - confirm this is intended.
        target_filters = conv2_filters if conv2_filters > 0 else int(x.shape[-1])
        shortcut = shortcut_v(x, filters=target_filters, stride=stride, mode=_shortcut_type, la=la, reg=reg)

    return tf.keras.layers.ReLU()(tf.keras.layers.add([shortcut, c2]))


def individual_preactivation_block_v(x, block_filter, la, reg, stride=1, preact_block=False):
    """Build a pre-activation residual block with per-layer filter counts.

    ``block_filter`` lists the (non-negative) filter count of every
    convolution of the block:

      - length 2: ``[conv1_filter, conv2_filter]``; the skip branch is built
        by ``shortcut_v`` using the module global ``_shortcut_type``.
      - length 3: ``[conv1_filter, proj_filter, conv2_filter]``; the skip
        branch is a 1x1 projection followed by batch normalization.

    A count of 0 means the layer was pruned away:

      - ``conv1_filter == 0``: the whole block is skipped and ``x`` is
        returned unchanged (this check happens before the length check).
      - ``proj_filter == 0``: the (possibly pre-activated) input is used as
        the skip branch.
      - ``conv2_filter == 0``: the second convolution is omitted and the
        normalized/activated output of the first one is the residual branch.

    Args:
        x: Input tensor.
        block_filter: Sequence of filter counts as described above.
        la: Regularization strength.
        reg: Regularizer kind, see ``regularized_padded_conv_v``.
        stride: Stride of the first convolution and of the skip branch.
        preact_block: When true, the skip branch starts from the
            BN-ReLU-activated input instead of the raw input.

    Returns:
        ``x`` itself when the block is skipped, otherwise ``skip + residual``.

    Raises:
        ValueError: If the block is not skipped and ``block_filter`` does not
            have length 2 or 3. Raised before any layer is created.
    """
    # If the first 3x3 conv is pruned, skip the whole block.
    if block_filter[0] == 0:
        return x

    n_entries = len(block_filter)
    if n_entries not in (2, 3):
        raise ValueError("block_filter length must be 2 or 3.")
    has_projection = n_entries == 3
    conv2_filters = block_filter[-1]

    flow = bn_relu(x)
    if preact_block:
        x = flow

    # First convolution. Thanks to the guard above its filter count is known
    # to be non-zero here.
    c1 = regularized_padded_conv_v(filters=block_filter[0], kernel_size=3, strides=stride, la=la, reg=reg)(flow)
    # NOTE(review): BN-ReLU is applied here and once more right before the
    # second convolution below, i.e. twice in a row. preactivation_block_v
    # applies it only once. Kept as is because removing one changes the
    # layer structure of existing models - confirm whether it is intended.
    c1 = bn_relu(c1)

    # Explicit projection shortcut (3-entry form); built before the second
    # convolution so the layer creation order stays as it was.
    if has_projection:
        if block_filter[1] > 0:
            shortcut = regularized_padded_conv_v(filters=block_filter[1], kernel_size=1, strides=stride, la=la, reg=reg)(x)
            shortcut = tf.keras.layers.BatchNormalization()(shortcut)
        else:
            shortcut = x

    # Second convolution of the main branch.
    if conv2_filters > 0:
        c2 = regularized_padded_conv_v(filters=conv2_filters, kernel_size=3, strides=1, la=la, reg=reg)(bn_relu(c1))
    else:
        c2 = bn_relu(c1)

    if not has_projection:
        # Match the residual width, or fall back to the width of the skip
        # input when the second convolution is pruned.
        target_filters = conv2_filters if conv2_filters > 0 else int(x.shape[-1])
        shortcut = shortcut_v(x, filters=target_filters, stride=stride, mode=_shortcut_type, la=la, reg=reg)

    return tf.keras.layers.add([shortcut, c2])


# --- New ResNet Model Builder Using Individual Blocks ---

def Resnet_v_individual(input_shape, n_classes, block_filters=None, shortcut_type='B',
                        block_type='original', dropout=0, preact_shortcuts=True,
                        name='ResNet_individual', use_bias=True, la=0, reg='l2'):
    """Build a ResNet whose every convolution has its own filter count.

    ``block_filters`` is a two-element nested list::

        [
            stem_filters,                  # int, filters of the stem conv
            [                              # one entry per group
                [                          # one entry per block
                    [conv1, conv2],        # block without projection
                    ...
                ],
                [
                    [conv1, proj, conv2],  # block with projection shortcut
                    [conv1, conv2],
                ],
                ...
            ],
        ]

    When omitted, the widths of ResNet-18 are used: stem 64, then four groups
    of two blocks with 64, 128, 256 and 512 feature maps, where the first
    block of groups 2-4 carries a projection entry.

    The first block of every group except the first one is built with
    stride 2; all other blocks use stride 1.

    Args:
        input_shape: Shape of one input sample (without the batch axis).
        n_classes: Number of units of the softmax output layer.
        block_filters: Nested filter specification, see above.
        shortcut_type: Stored in the module global ``_shortcut_type``.
        block_type: ``'original'`` or ``'preactivated'``.
        dropout: Stored in the module global ``_dropout``; the individual
            block builders in this file do not read it.
        preact_shortcuts: For ``'preactivated'`` blocks, pre-activate the
            skip branch of every block (the first block of each group is
            always pre-activated).
        name: Name of the returned model.
        use_bias: Whether the output Dense layer has a bias.
        la: Regularization strength for all kernels.
        reg: ``'l2'`` for L2 penalties, anything else for L1.

    Returns:
        A ``tf.keras.Model`` mapping inputs to class probabilities.

    Raises:
        ValueError: If a block has to be built and ``block_type`` is not
            supported.
    """
    global _shortcut_type, _dropout
    _shortcut_type = shortcut_type

    if block_filters is None:
        # ResNet-18 widths: group 1 has no projection, groups 2-4 open with
        # a [conv1, proj, conv2] block.
        default_groups = [[[64, 64], [64, 64]]]
        default_groups += [[[width, width, width], [width, width]] for width in (128, 256, 512)]
        block_filters = [64, default_groups]

    inputs = tf.keras.layers.Input(shape=input_shape)
    stem = regularized_padded_conv_v(filters=block_filters[0], kernel_size=3, strides=1, la=la, reg=reg)
    x = bn_relu(stem(inputs))

    groups = block_filters[1]
    _dropout = dropout

    for group_idx, group in enumerate(groups):
        for block_idx, block_filter in enumerate(group):
            opens_group = block_idx == 0
            # Only the opening block of groups 2+ downsamples.
            stride = 2 if (opens_group and group_idx > 0) else 1

            if block_type == 'original':
                x = individual_original_block_v(x, block_filter, la, reg, stride=stride)
            elif block_type == 'preactivated':
                x = individual_preactivation_block_v(x, block_filter, la, reg, stride=stride,
                                                     preact_block=bool(preact_shortcuts or opens_group))
            else:
                raise ValueError("Unsupported block type: {}".format(block_type))

    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    if reg == 'l2':
        reg_dense = tf.keras.regularizers.L2(l2=la)
    else:
        reg_dense = tf.keras.regularizers.L1(l1=la)

    classifier = tf.keras.layers.Dense(n_classes, kernel_regularizer=reg_dense, activation='softmax',
                                       use_bias=use_bias)
    outputs = classifier(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# --- ResNet Architecture Wrapper for the Individual Variant ---

def vanilla_resnet18_individual(block_type='original', shortcut_type='B', load_weights=False,
                                use_bias=False, la=0, input_shape=(32, 32, 3), n_classes=100, reg='l2', block_filters=None):
    """Build a CIFAR ResNet-18 with individually configurable filter counts.

    Args:
        block_type: Block kind forwarded to ``Resnet_v_individual``.
        shortcut_type: Shortcut mode forwarded to ``Resnet_v_individual``.
        load_weights: Accepted but currently not used (weight loading is
            commented out).
        use_bias: Whether the output Dense layer has a bias.
        la: Regularization strength.
        input_shape: Shape of one input sample.
        n_classes: Number of output classes.
        reg: ``'l2'`` for L2 penalties, anything else for L1.
        block_filters: Nested filter specification; ``None`` selects the
            defaults of ``Resnet_v_individual``.

    Returns:
        The model built by ``Resnet_v_individual``.
    """
    builder_options = dict(
        input_shape=input_shape,
        n_classes=n_classes,
        block_filters=block_filters,
        shortcut_type=shortcut_type,
        block_type=block_type,
        preact_shortcuts=False,
        name='vanilla_cifar_resnet18_individual',
        use_bias=use_bias,
        la=la,
        reg=reg,
    )
    # if load_weights: model = load_weights_func(model, 'vanilla_individual_resnet18')
    return Resnet_v_individual(**builder_options)

### Regularized L21 and Network Slimming ResNet18 definitions

# 1. Helper layers

def bn_relu_reg(x, use_bn, la_gamma):
    """Apply an optional gamma-regularized BatchNormalization, then a ReLU.

    Args:
        x: Input tensor.
        use_bn: When falsy, the batch normalization is skipped.
        la_gamma: Strength of the L1 penalty on the BN scale (gamma).

    Returns:
        The activated tensor.
    """
    if not use_bn:
        return tf.keras.layers.ReLU()(x)

    gamma_penalty = tf.keras.regularizers.L1(l1=la_gamma)
    normalized = tf.keras.layers.BatchNormalization(gamma_regularizer=gamma_penalty)(x)
    return tf.keras.layers.ReLU()(normalized)

def conv_reg(filters, kernel_size, strides, la, use_bias, **kwargs):
    """Create a 'same'-padded Conv2D layer with a group-lasso kernel penalty.

    Args:
        filters: Number of output feature maps.
        kernel_size: Kernel size forwarded to ``Conv2D``.
        strides: Stride forwarded to ``Conv2D``.
        la: Strength passed to ``GroupLassoRegularizer`` as ``lam``.
        use_bias: Whether the layer has a bias.
        **kwargs: Extra keyword arguments forwarded to ``Conv2D``.

    Returns:
        An uncalled ``tf.keras.layers.Conv2D`` layer using the 'he_normal'
        kernel initializer.
    """
    # axis=3 presumably groups the kernel by output channel - TODO confirm
    # against GroupLassoRegularizer in utils.
    group_penalty = GroupLassoRegularizer(lam=la, axis=3)
    layer_options = dict(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        use_bias=use_bias,
        kernel_initializer='he_normal',
        kernel_regularizer=group_penalty,
    )
    return tf.keras.layers.Conv2D(**layer_options, **kwargs)

# 2. Shortcuts

def shortcut_reg(x, filters, stride, mode):
    """Build the shortcut (skip) branch for the group-lasso ResNet blocks.

    When the channel dimension of ``x`` already equals ``filters`` the input
    is returned untouched (``stride`` is not considered in that case).

    The projection convolution takes its regularization strength from the
    module global ``_la``, which ``Resnet_v_reg`` sets before building blocks.

    Args:
        x: Input tensor; only its last dimension is inspected here.
        filters: Number of channels the shortcut must deliver.
        stride: Stride of the projection / pooling.
        mode: ``'A'`` for stride-pooling plus zero-padding of the channel
            axis, ``'B'`` for a 1x1 projection convolution, ``'B_original'``
            for a 1x1 projection followed by batch normalization.

    Returns:
        The shortcut tensor.

    Raises:
        ValueError: If the channel counts differ and ``mode`` is not one of
            the supported values.
    """
    if x.shape[-1] == filters:
        return x
    if mode == 'A':
        # A pool of size 1 with the given stride subsamples spatially.
        y = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x
        return tf.pad(y, [(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
    if mode in ('B', 'B_original'):
        y = conv_reg(filters=filters, kernel_size=1, strides=stride, la=_la, use_bias=False)(x)
        return tf.keras.layers.BatchNormalization()(y) if mode == 'B_original' else y
    # Same exception type as before, but now it says what went wrong.
    raise ValueError(
        "Unknown shortcut mode {!r}; expected 'A', 'B' or 'B_original'.".format(mode)
    )

# 3. Residual blocks

def original_block_reg(x, filters, stride, la, la_gamma, mode='B'):
    """Build a post-activation residual block with group-lasso convolutions.

    Layout: conv - BN (gamma-regularized) - ReLU - conv - BN, added to the
    shortcut and passed through a final ReLU.

    Args:
        x: Input tensor.
        filters: Number of feature maps of both 3x3 convolutions.
        stride: Stride of the first convolution and of the shortcut.
        la: Group-lasso strength of the convolutions.
        la_gamma: L1 strength on the gamma of the first batch normalization.
        mode: Shortcut mode, see ``shortcut_reg``.

    Returns:
        The block output tensor.
    """
    first_conv = conv_reg(filters, 3, stride, la, use_bias=False)
    hidden = bn_relu_reg(first_conv(x), True, la_gamma)

    second_conv = conv_reg(filters, 3, 1, la, use_bias=False)
    residual = tf.keras.layers.BatchNormalization()(second_conv(hidden))

    skip = shortcut_reg(x, filters, stride, mode)
    final_relu = tf.keras.layers.ReLU()
    return final_relu(residual + skip)

def preact_block_reg(x, filters, stride, la, la_gamma, preact_first=False, mode='B'):
    """Build a pre-activation residual block with group-lasso convolutions.

    Layout: BN-ReLU - conv - BN-ReLU - conv, added to the shortcut. Both
    batch normalizations carry an L1 penalty on gamma.

    Args:
        x: Input tensor.
        filters: Number of feature maps of both 3x3 convolutions.
        stride: Stride of the first convolution and of the shortcut.
        la: Group-lasso strength of the convolutions.
        la_gamma: L1 strength on the batch-normalization gammas.
        preact_first: When truthy, the shortcut branches off after the
            initial BN-ReLU instead of from the raw input.
        mode: Shortcut mode, see ``shortcut_reg``.

    Returns:
        The sum of the residual branch and the shortcut.
    """
    preactivated = bn_relu_reg(x, True, la_gamma)
    skip_source = preactivated if preact_first else x

    first_conv = conv_reg(filters, 3, stride, la, use_bias=False)
    hidden = bn_relu_reg(first_conv(preactivated), True, la_gamma)

    second_conv = conv_reg(filters, 3, 1, la, use_bias=False)
    residual = second_conv(hidden)

    skip = shortcut_reg(skip_source, filters, stride, mode)
    return residual + skip

# 4. Model builder

def Resnet_v_reg(input_shape, n_classes,
                 group_sizes=(2, 2, 2, 2),
                 features=(64, 128, 256, 512),
                 strides=(1, 2, 2, 2),
                 shortcut_type='B',
                 block_type='original',
                 use_bias=False,
                 la=0,
                 la_gamma=0,
                 preact_shortcuts=False,
                 name='resnet18_reg'):
    """Assemble a ResNet with group-lasso kernels and L1-regularized BN gammas.

    Sets the module global ``_la`` (read by ``shortcut_reg``) to ``la``.

    NOTE(review): the BN-ReLU placement differs from ``Resnet_v`` in this
    file - here the stem is followed by BN-ReLU only for non-'original'
    blocks, and a final BN-ReLU is applied for every block type. Confirm
    that this is intended.

    Args:
        input_shape: Shape of one input sample (without the batch axis).
        n_classes: Number of units of the softmax output layer.
        group_sizes: Number of blocks per group.
        features: Number of feature maps per group; the first entry is also
            the width of the stem convolution.
        strides: Stride of the first block of each group; the first entry is
            also the stride of the stem convolution.
        shortcut_type: Shortcut mode; ``'B'`` is mapped to ``'B_original'``.
        block_type: ``'original'`` for post-activation blocks, anything else
            for pre-activation blocks.
        use_bias: Whether the stem convolution and the output Dense layer
            have a bias.
        la: Group-lasso strength of the convolutions and L2 strength of the
            Dense layer.
        la_gamma: L1 strength on the regularized batch-normalization gammas.
        preact_shortcuts: For pre-activation blocks, pre-activate the
            shortcut of every block (otherwise only of the first block of
            groups 2+).
        name: Name of the returned model.

    Returns:
        A ``tf.keras.Model`` mapping inputs to class probabilities.
    """
    global _la
    _la = la

    if shortcut_type == 'B':
        mode = 'B_original'
    else:
        mode = shortcut_type
    uses_original_blocks = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    stem = conv_reg(filters=features[0], kernel_size=3, strides=strides[0],
                    la=la, use_bias=use_bias)
    x = stem(inputs)
    if not uses_original_blocks:
        x = bn_relu_reg(x, True, la_gamma)

    for group_idx, (n_blocks, width, group_stride) in enumerate(zip(group_sizes, features, strides)):
        for block_pos in range(n_blocks):
            opens_group = block_pos == 0
            block_stride = group_stride if opens_group else 1
            if uses_original_blocks:
                x = original_block_reg(x, width, block_stride, la, la_gamma, mode=mode)
            else:
                preact_first = (group_idx > 0 and opens_group) or preact_shortcuts
                x = preact_block_reg(x, width, block_stride, la, la_gamma,
                                     preact_first=preact_first, mode=mode)

    x = bn_relu_reg(x, True, la_gamma)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    classifier = tf.keras.layers.Dense(
        n_classes,
        activation='softmax',
        kernel_regularizer=tf.keras.regularizers.L2(l2=la),
        use_bias=use_bias
    )
    out = classifier(x)

    return tf.keras.Model(inputs, out, name=name)

# 5. User wrapper

def vanilla_resnet18_reg(load_weights=False,
                         input_shape=(32, 32, 3),
                         n_classes=100,
                         use_bias=False,
                         la=0,
                         la_gamma=0,
                         block_type='original',
                         shortcut_type='B',
                         preact_shortcuts=False):
    """Build the regularized ResNet-18 and optionally restore saved weights.

    Args:
        load_weights: When truthy, try to load weights from
            ``saved_models/vanilla_resnet18_reg.tf``; a
            ``tf.errors.NotFoundError`` is reported on stdout and otherwise
            ignored.
        input_shape: Shape of one input sample.
        n_classes: Number of output classes.
        use_bias: Forwarded to ``Resnet_v_reg``.
        la: Forwarded to ``Resnet_v_reg``.
        la_gamma: Forwarded to ``Resnet_v_reg``.
        block_type: Forwarded to ``Resnet_v_reg``.
        shortcut_type: Forwarded to ``Resnet_v_reg``.
        preact_shortcuts: Forwarded to ``Resnet_v_reg``.

    Returns:
        The model built by ``Resnet_v_reg``.
    """
    architecture = dict(
        group_sizes=(2, 2, 2, 2),
        features=(64, 128, 256, 512),
        strides=(1, 2, 2, 2),
        name='vanilla_resnet18_reg',
    )
    model = Resnet_v_reg(
        input_shape=input_shape,
        n_classes=n_classes,
        shortcut_type=shortcut_type,
        block_type=block_type,
        use_bias=use_bias,
        la=la,
        la_gamma=la_gamma,
        preact_shortcuts=preact_shortcuts,
        **architecture
    )
    if not load_weights:
        return model

    weights_path = os.path.join('saved_models', 'vanilla_resnet18_reg.tf')
    try:
        model.load_weights(weights_path)
    except tf.errors.NotFoundError:
        print("No weights found!")
    return model
