import os
import tensorflow as tf

# Module is always loaded from hadamard/python folder location
from layers import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrHadamardDenseV2, StrConv2D
from initializers import TwiceTruncatedNormalInitializer, ExactNormalFactorization, equivar_initializer, equivar_initializer_conv2d
from utils import GroupLassoRegularizer

# Components for VGG construction
def bn_relu(x, use_bn):
    """Apply an optional BatchNormalization followed by a ReLU.

    Args:
        x: Keras tensor to transform.
        use_bn: When truthy, a BatchNormalization layer is applied before the ReLU.

    Returns:
        The activated Keras tensor.
    """
    normalized = tf.keras.layers.BatchNormalization()(x) if use_bn else x
    return tf.keras.layers.ReLU()(normalized)

###########################################################################################################

# Vanilla VGG with L2 regularization

# vanilla Conv2D modules
def vanilla_regularized_padded_conv3(filters, init, la, use_bias, *args, **kwargs):
    """Create a 3x3, 'same'-padded Conv2D layer whose kernel carries an L2 penalty.

    Args:
        filters: Number of output channels.
        init: Passed to Conv2D as ``kernel_initializer``.
        la: L2 regularization strength for the kernel.
        use_bias: Whether the convolution has a bias vector.
        *args, **kwargs: Extra arguments forwarded unchanged to Conv2D.

    Returns:
        An uncalled ``tf.keras.layers.Conv2D`` layer.
    """
    l2_penalty = tf.keras.regularizers.L2(l2=la)
    return tf.keras.layers.Conv2D(
        *args,
        filters=filters,
        kernel_size=3,
        padding='same',
        kernel_initializer=init,
        use_bias=use_bias,
        kernel_regularizer=l2_penalty,
        **kwargs,
    )

def vanilla_conv_block(x, num_blocks, filters, pool, init, la, use_bn, use_bias):
    """Stack L2-regularized 3x3 convolutions and finish with max pooling.

    Layout: conv -> [bn_relu -> conv] * (num_blocks - 1) -> bn_relu -> MaxPool2D(pool).

    Returns:
        The pooled Keras tensor.
    """
    conv_kwargs = dict(filters=filters, init=init, la=la, use_bias=use_bias)
    x = vanilla_regularized_padded_conv3(**conv_kwargs)(x)
    for _ in range(num_blocks - 1):
        conv = vanilla_regularized_padded_conv3(**conv_kwargs)
        x = conv(bn_relu(x, use_bn=use_bn))
    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(x, use_bn=use_bn))

# Define VGG template without sparse regularization (L2 only)
def VanillaVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
               init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,  #HeNormal()
               use_bias=True, factorize_bias=False, dense_units=(4096,4096)):
    """Build a plain VGG-style classifier regularized only with L2 weight decay.

    Args:
        input_shape: Shape of one input sample (without batch dimension).
        n_classes: Width of the softmax output layer.
        group_sizes, features, pools: Zipped per stage -> (#convs, #filters, pool size).
        init: Kernel initializer for the conv layers.
        la: L2 strength used for every conv and dense kernel.
        use_bn: Insert BatchNormalization before each ReLU.
        use_bias: Bias flag for convs and dense layers.
        dense_units: Widths of the two hidden dense layers; an entry <= 0 skips that layer.
        init_type, factorize_bias: Accepted for signature parity with the sparse variants; unused.

    Returns:
        A ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Convolutional trunk: one block per stage.
    flow = inputs
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        flow = vanilla_conv_block(flow, num_blocks=n_convs, filters=n_filters, pool=pool_size,
                                  init=init, la=la, use_bn=use_bn, use_bias=use_bias)

    # Head: one extra 2x2 'same' max-pool (GlobalAveragePooling2D is a possible alternative),
    # flatten, up to two hidden dense layers, softmax classifier.
    head = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(flow)
    head = tf.keras.layers.Flatten()(head)
    for units in (dense_units[0], dense_units[1]):
        if units > 0:
            hidden = tf.keras.layers.Dense(units, activation='relu',
                                           kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                           use_bias=use_bias)
            head = hidden(bn_relu(head, use_bn=use_bn))
    classifier = tf.keras.layers.Dense(n_classes, activation='softmax',
                                       kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                       use_bias=use_bias)
    return tf.keras.Model(inputs=inputs, outputs=classifier(head))


###########################################################################################################


# Vanilla VGG model with finer control over filter number to create reduced networks for FLOP measurement

def vanilla_conv_block_custom(x, filters_list, pool, init, la, use_bn, use_bias):
    """Build one VGG stage with an individual filter count per conv layer.

    Entries of ``filters_list`` that are <= 0 are skipped entirely, which is how
    pruned (reduced) networks are expressed. bn_relu separates consecutive convs
    and follows the last one; if no conv survives, only the max-pool is applied.

    Returns:
        The pooled Keras tensor.
    """
    active_filters = [f for f in filters_list if f > 0]
    for position, n_filters in enumerate(active_filters):
        if position > 0:
            x = bn_relu(x, use_bn=use_bn)
        x = vanilla_regularized_padded_conv3(filters=n_filters, init=init, la=la, use_bias=use_bias)(x)
    if active_filters:
        x = bn_relu(x, use_bn=use_bn)
    return tf.keras.layers.MaxPool2D(pool)(x)

def VanillaVGGIndividual(input_shape, n_classes, block_filters, pools=(2,2,2,2,2),
                         init=tf.keras.initializers.HeNormal, la=0, use_bn=True, use_bias=True, dense_units=(4096,4096)):
    """Build a VGG classifier where every conv layer's filter count is given explicitly.

    Args:
        block_filters: One list of per-layer filter counts per stage; entries <= 0 are
            dropped (see ``vanilla_conv_block_custom``).
        pools: Pool size per stage (zipped with ``block_filters``).
        dense_units: Widths of the two hidden dense layers; an entry <= 0 skips it.
        Remaining arguments as in ``VanillaVGG``.

    Returns:
        A ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    feature_map = inputs
    for stage_filters, pool_size in zip(block_filters, pools):
        feature_map = vanilla_conv_block_custom(feature_map, stage_filters, pool_size, init, la, use_bn, use_bias)

    head = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(feature_map)
    head = tf.keras.layers.Flatten()(head)
    for units in (dense_units[0], dense_units[1]):
        if not units > 0:
            continue
        head = bn_relu(head, use_bn=use_bn)
        head = tf.keras.layers.Dense(units, activation='relu',
                                     kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                     use_bias=use_bias)(head)
    outputs = tf.keras.layers.Dense(n_classes, activation='softmax',
                                    kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                    use_bias=use_bias)(head)
    return tf.keras.Model(inputs=inputs, outputs=outputs)



def vanilla_conv_block_customOld(x, filters_list, pool, init, la, use_bn, use_bias):
    """Legacy per-layer-filter VGG stage.

    Unlike ``vanilla_conv_block_custom``, every entry of ``filters_list`` (zeros
    included) is passed to a conv layer, and an empty list raises IndexError.
    Layout: conv -> [bn_relu -> conv]* -> bn_relu -> MaxPool2D(pool).
    """
    def make_conv(n_filters):
        return vanilla_regularized_padded_conv3(filters=n_filters, init=init, la=la, use_bias=use_bias)

    x = make_conv(filters_list[0])(x)
    for n_filters in filters_list[1:]:
        activated = bn_relu(x, use_bn=use_bn)
        x = make_conv(n_filters)(activated)
    pooled_input = bn_relu(x, use_bn=use_bn)
    return tf.keras.layers.MaxPool2D(pool)(pooled_input)

def VanillaVGGIndividualOld(input_shape, n_classes, block_filters, pools=(2,2,2,2,2),
                         init=tf.keras.initializers.HeNormal, la=0, use_bn=True, use_bias=True, dense_units=(4096,4096)):
    """Legacy copy of ``VanillaVGGIndividual`` (same arguments, same head layout).

    Returns:
        A ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    tensor = inputs
    for stage_filters, pool_size in zip(block_filters, pools):
        # NOTE(review): this calls vanilla_conv_block_custom, not vanilla_conv_block_customOld.
        # For non-empty, all-positive filter lists both yield the same layer sequence,
        # but confirm which block this legacy builder was meant to use.
        tensor = vanilla_conv_block_custom(tensor, stage_filters, pool_size, init, la, use_bn, use_bias)

    # Head: pooling, flattening, and fully connected layers.
    tensor = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(tensor)
    tensor = tf.keras.layers.Flatten()(tensor)
    hidden_widths = (dense_units[0], dense_units[1])
    for width in hidden_widths:
        if width > 0:
            tensor = bn_relu(tensor, use_bn=use_bn)
            dense = tf.keras.layers.Dense(width, activation='relu',
                                          kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                          use_bias=use_bias)
            tensor = dense(tensor)
    softmax_layer = tf.keras.layers.Dense(n_classes, activation='softmax',
                                          kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                          use_bias=use_bias)
    return tf.keras.Model(inputs=inputs, outputs=softmax_layer(tensor))

# Example usage:
# Define filter lists for each block, e.g., for a VGG-16-like model:
# block_filters = [
#     [64, 64],        # Block 1: two conv layers with 64 filters each.
#     [128, 128],      # Block 2: two conv layers with 128 filters each.
#     [256, 256, 256], # Block 3: three conv layers with 256 filters each.
#     [512, 512, 512], # Block 4: three conv layers with 512 filters each.
#     [512, 512, 512]  # Block 5: three conv layers with 512 filters each.
# ]
# model = VanillaVGGIndividual(input_shape=(224, 224, 3), n_classes=1000, block_filters=block_filters)
# model.summary()

###########################################################################################################

# Hadamard VGG for unstructured sparsity

# Unstructured sparse Conv modules
def unstr_regularized_padded_conv3(filters, depth, init_type, init, la, use_bias, factorize_bias):
    """Create a 3x3, stride-1, 'same'-padded HadamardConv2D layer (unstructured sparsity).

    All arguments are forwarded to the project ``HadamardConv2D`` layer; see the
    ``layers`` module for the meaning of ``depth``, ``init_type`` and ``factorize_bias``.
    """
    layer_config = dict(
        filters=filters,
        kernel_size=(3, 3),
        strides=1,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        padding='same',
    )
    return HadamardConv2D(**layer_config)

def unstr_conv_block(x, num_blocks, filters, pool, depth, init_type, init, la, use_bn, use_bias, factorize_bias):
    """Stack HadamardConv2D 3x3 layers and finish with max pooling.

    Layout: conv -> [bn_relu -> conv] * (num_blocks - 1) -> bn_relu -> MaxPool2D(pool).
    """
    def new_conv():
        return unstr_regularized_padded_conv3(filters=filters, depth=depth, init_type=init_type, init=init,
                                              la=la, use_bias=use_bias, factorize_bias=factorize_bias)

    x = new_conv()(x)
    for _ in range(num_blocks - 1):
        conv = new_conv()
        x = conv(bn_relu(x, use_bn=use_bn))
    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(x, use_bn=use_bn))

# Define VGG template with differentiable unstructured sparsity 
def HadamardVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2),
           features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
           depth=1, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
           use_bias=True, factorize_bias=True, dense_units=(4096,4096)): #HeNormal()
    """Build a VGG classifier whose conv and dense layers are Hadamard-factorized layers.

    Conv stages use HadamardConv2D; the two hidden dense layers and the softmax output
    use HadamardDense, all sharing ``depth``, ``la``, ``init_type``, ``init``,
    ``use_bias`` and ``factorize_bias``. A hidden dense layer is skipped when its
    ``dense_units`` entry is <= 0.

    Returns:
        A ``tf.keras.Model``.
    """
    shared = dict(depth=depth, la=la, init_type=init_type, init=init,
                  use_bias=use_bias, factorize_bias=factorize_bias)

    inputs = tf.keras.layers.Input(shape=input_shape)
    flow = inputs
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        flow = unstr_conv_block(flow, num_blocks=n_convs, filters=n_filters, pool=pool_size, use_bn=use_bn, **shared)

    # Head (GlobalAveragePooling2D is a possible alternative to the extra max-pool).
    head = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(flow)
    head = tf.keras.layers.Flatten()(head)
    for units in (dense_units[0], dense_units[1]):
        if units > 0:
            hidden = HadamardDense(units=units, activation='relu', **shared)
            head = hidden(bn_relu(head, use_bn=use_bn))

    output = HadamardDense(units=n_classes, activation='softmax', **shared)(head)
    return tf.keras.Model(inputs=inputs, outputs=output)
    
###########################################################################################################

# Filter-sparse VGG for structured filter/channel or neuron sparsity

# Structured sparse Conv modules using GHPowP
def str_regularized_padded_conv3(filters, depth, kernel_initializer, multfac_initializer, la, use_bias):
    """Create a 3x3, stride-1, 'same'-padded StrConv2D layer (structured sparsity).

    ``use_bias`` is accepted for interface compatibility but ignored: the layer is
    always built with ``use_bias=False``. An alternative that was used here is
    ``SparseConv2D`` with the same arguments (and ``use_bias=use_bias``).

    Returns:
        An uncalled ``StrConv2D`` layer.
    """
    return StrConv2D(
        filters=filters,
        kernel_size=(3, 3),
        strides=1,
        padding='same',
        use_bias=False,  # intentionally ignores the use_bias argument
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
    )

def str_conv_block(x, num_blocks, filters, pool, depth, kernel_initializer, multfac_initializer, la, use_bn, use_bias):
    """Stack StrConv2D 3x3 layers and finish with max pooling.

    Layout: conv -> [bn_relu -> conv] * (num_blocks - 1) -> bn_relu -> MaxPool2D(pool).
    """
    conv_config = dict(filters=filters, depth=depth, kernel_initializer=kernel_initializer,
                       multfac_initializer=multfac_initializer, la=la, use_bias=use_bias)
    x = str_regularized_padded_conv3(**conv_config)(x)
    remaining = num_blocks - 1
    while remaining > 0:
        conv = str_regularized_padded_conv3(**conv_config)
        x = conv(bn_relu(x, use_bn=use_bn))
        remaining -= 1
    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(x, use_bn=use_bn))

# Define VGG template with differentiable structured sparsity 
def StructuredVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2),
                  features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
                  depth=1, kernel_initializer=tf.keras.initializers.HeNormal, multfac_initializer=tf.keras.initializers.Ones, #HeNormal()
                  la=0, use_bn=True, use_bias=False, factorize_bias=False, sparse_head=False, dense_units=(4096, 4096)):
    """Build a VGG classifier with structured (filter-level) sparse conv layers.

    The conv trunk always uses StrConv2D. The head depends on ``sparse_head``:
      * True: hidden layers are StrHadamardDense and the output is StrHadamardDenseV2,
        all with ``la=0.1*la`` and ``init_type='ones'``.
      * False: plain Dense layers; hidden kernels use L2(0.1*la), while the softmax
        output uses L2(la) and always has a bias.
    A hidden dense layer is skipped when its ``dense_units`` entry is <= 0.

    Returns:
        A ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    flow = inputs
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        flow = str_conv_block(flow, num_blocks=n_convs, filters=n_filters, pool=pool_size, depth=depth,
                              kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                              la=la, use_bn=use_bn, use_bias=use_bias)

    # Head (GlobalAveragePooling2D is a possible alternative to the extra max-pool).
    head = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(flow)
    head = tf.keras.layers.Flatten()(head)
    head_la = 0.1 * la
    hidden_widths = (dense_units[0], dense_units[1])

    if sparse_head:
        sparse_config = dict(depth=depth, la=head_la, init_type='ones', init=kernel_initializer,
                             use_bias=use_bias, factorize_bias=factorize_bias)
        for units in hidden_widths:
            if units > 0:
                hidden = StrHadamardDense(units=units, activation='relu', **sparse_config)
                head = hidden(bn_relu(head, use_bn=use_bn))
        output = StrHadamardDenseV2(units=n_classes, activation='softmax', **sparse_config)(head)
    else:
        for units in hidden_widths:
            if units > 0:
                hidden = tf.keras.layers.Dense(units, activation='relu',
                                               kernel_regularizer=tf.keras.regularizers.L2(l2=head_la),
                                               use_bias=use_bias)
                head = hidden(bn_relu(head, use_bn=use_bn))
        output = tf.keras.layers.Dense(n_classes, activation='softmax',
                                       kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                       use_bias=True)(head)

    return tf.keras.Model(inputs=inputs, outputs=output)

###########################################################################################################

# VGG model with direct group lasso regularization on the filter weights as well as l1 regularization on the batch norm gammas:

def bn_relu_reg(x, use_bn, la_gamma):
    """Apply an optional L1-gamma-regularized BatchNormalization, then a ReLU.

    Args:
        x: Keras tensor.
        use_bn: When truthy, insert BatchNormalization with an L1(la_gamma) penalty on gamma
            (the network-slimming style penalty).
        la_gamma: L1 strength for the BN gamma.
    """
    if not use_bn:
        return tf.keras.layers.ReLU()(x)
    gamma_penalty = tf.keras.regularizers.L1(l1=la_gamma)
    normalized = tf.keras.layers.BatchNormalization(gamma_regularizer=gamma_penalty)(x)
    return tf.keras.layers.ReLU()(normalized)

def vanilla_regularized_padded_conv3_reg(filters, init, la, use_bias, **kwargs):
    """Create a 3x3, 'same'-padded Conv2D whose kernel carries a group-lasso penalty.

    The penalty is ``GroupLassoRegularizer(lam=la, axis=3)`` from the project ``utils``
    module; axis 3 of a Conv2D kernel is the output-filter axis (presumably the groups are
    per output filter -- confirm against GroupLassoRegularizer).
    Extra keyword arguments are forwarded to Conv2D.
    """
    group_lasso = GroupLassoRegularizer(lam=la, axis=3)
    return tf.keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same',
                                  kernel_initializer=init, use_bias=use_bias,
                                  kernel_regularizer=group_lasso, **kwargs)

def vanilla_conv_block_reg(x, num_blocks, filters, pool, init, la, la_gamma, use_bn, use_bias):
    """Stack group-lasso-regularized 3x3 convolutions and finish with max pooling.

    Layout: conv -> [bn_relu_reg -> conv] * (num_blocks - 1) -> bn_relu_reg -> MaxPool2D(pool).
    """
    def fresh_conv():
        return vanilla_regularized_padded_conv3_reg(filters=filters, init=init, la=la, use_bias=use_bias)

    x = fresh_conv()(x)
    for _ in range(num_blocks - 1):
        conv = fresh_conv()
        x = conv(bn_relu_reg(x, use_bn, la_gamma))
    pooling = tf.keras.layers.MaxPool2D(pool_size=pool)
    return pooling(bn_relu_reg(x, use_bn, la_gamma))

def VanillaVGG_reg(
    input_shape,
    n_classes,
    group_sizes=(1, 1, 2, 2, 2),
    features=(64, 128, 256, 512, 512),
    pools=(2, 2, 2, 2, 2),
    init_type='vanilla',
    init=tf.keras.initializers.HeNormal,
    la=0,
    la_gamma=0,
    use_bn=True,
    use_bias=True,
    dense_units=(4096, 4096)
):
    """Build a VGG classifier with group-lasso conv penalties and L1 BN-gamma penalties.

    Conv kernels get ``GroupLassoRegularizer(lam=la, axis=3)``; each BatchNormalization
    (when ``use_bn``) gets L1(la_gamma) on gamma. Dense layers use a plain L2(la) penalty.
    A hidden dense layer is skipped when its ``dense_units`` entry is <= 0.
    ``init_type`` is accepted for signature parity and not used.

    Returns:
        A ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = inputs
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        x = vanilla_conv_block_reg(x, num_blocks=n_convs, filters=n_filters, pool=pool_size,
                                   init=init, la=la, la_gamma=la_gamma,
                                   use_bn=use_bn, use_bias=use_bias)

    x = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    x = tf.keras.layers.Flatten()(x)
    for units in (dense_units[0], dense_units[1]):
        if units > 0:
            dense = tf.keras.layers.Dense(units, activation='relu',
                                          kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                          use_bias=use_bias)
            x = dense(bn_relu_reg(x, use_bn, la_gamma))
    classifier = tf.keras.layers.Dense(n_classes, activation='softmax',
                                       kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                       use_bias=use_bias)
    return tf.keras.Model(inputs, classifier(x))


# VGG model construction from template

def load_weights_func(model, model_name):
    """Load weights into ``model`` from ``saved_models/<model_name>.tf`` if available.

    A ``tf.errors.NotFoundError`` raised by ``model.load_weights`` is reported with a
    printed message instead of propagating; any other exception propagates.

    Returns:
        The same ``model`` object (weights are loaded in place).
    """
    checkpoint_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(checkpoint_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model

###########################################################################################################

### VGG7

# Vanilla 
def vanilla_vgg7(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                 input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0), use_bias=True):
    """Build the 3-stage VGG7 baseline (defaults: 28x28x1 inputs, 10 classes).

    ``depth`` is accepted for signature parity with the Hadamard variants and unused.
    When ``load_weights`` is truthy, weights are read via ``load_weights_func`` under
    the name ``mnist_vanilla_vgg7``.
    """
    model = VanillaVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2),
        features=(32, 64, 128),
        pools=(2, 2, 2),
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        dense_units=dense_units,
        use_bias=use_bias,
    )
    return load_weights_func(model, 'mnist_vanilla_vgg7') if load_weights else model

# Unstr. Hadamard
def hadamard_vgg7(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0),use_bias=True, factorize_bias=True):
    """Build VGG7 with unstructured Hadamard sparsity (defaults: 28x28x1 inputs, 10 classes).

    Weights are loaded under ``mnist_hadamard_vgg7`` when ``load_weights`` is truthy.
    """
    model = HadamardVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2),
        features=(32, 64, 128),
        pools=(2, 2, 2),
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        dense_units=dense_units,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'mnist_hadamard_vgg7') if load_weights else model

# Str. Hadamard
def str_hadamard_vgg7(load_weights=False, depth=2, la=0, use_bn=True, kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0),
                      use_bias=True, factorize_bias=False):
    """Build VGG7 with structured (filter-level) sparsity via StructuredVGG.

    Weights are loaded under ``mnist_str_hadamard_vgg7`` when ``load_weights`` is truthy.
    """
    model = StructuredVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2),
        features=(32, 64, 128),
        pools=(2, 2, 2),
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        depth=depth,
        la=la,
        use_bn=use_bn,
        dense_units=dense_units,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'mnist_str_hadamard_vgg7') if load_weights else model

###########################################################################################################

### VGG 11 originally for (224,224,3) images

# Vanilla
def vanilla_vgg11(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10,use_bias=True, dense_units=(512,0)):
    """Build the VGG11 baseline (defaults: 32x32x3 inputs, 10 classes).

    ``depth`` is unused (signature parity). Weights are loaded under
    ``cifar_vanilla_vgg11`` when ``load_weights`` is truthy.
    """
    model = VanillaVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(1, 1, 2, 2, 2),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
    )
    return load_weights_func(model, 'cifar_vanilla_vgg11') if load_weights else model

# Unstr. Hadamard
def hadamard_vgg11(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10,use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build VGG11 with unstructured Hadamard sparsity.

    Weights are loaded under ``cifar_hadamard_vgg11`` when ``load_weights`` is truthy.
    """
    model = HadamardVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(1, 1, 2, 2, 2),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_hadamard_vgg11') if load_weights else model

# Str. Hadamard
def str_hadamard_vgg11(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                       use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build VGG11 with structured (filter-level) sparsity via StructuredVGG.

    Weights are loaded under ``cifar_str_hadamard_vgg11`` when ``load_weights`` is truthy.
    """
    model = StructuredVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(1, 1, 2, 2, 2),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_str_hadamard_vgg11') if load_weights else model

###########################################################################################################

### VGG 13

# Vanilla
def vanilla_vgg13(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10, use_bias=True, dense_units=(512,0)):
    """Build the VGG13 baseline (two convs in every stage).

    ``depth`` is unused (signature parity). Weights are loaded under
    ``cifar_vanilla_vgg13`` when ``load_weights`` is truthy.
    """
    model = VanillaVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2, 2, 2),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
    )
    return load_weights_func(model, 'cifar_vanilla_vgg13') if load_weights else model

# Unstr. Hadamard
def hadamard_vgg13(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10,use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build VGG13 with unstructured Hadamard sparsity (two convs in every stage).

    Args:
        dense_units: Widths of the two hidden dense layers (<= 0 skips a layer).
            HadamardVGG reads only the first two entries, so the default is a
            2-tuple like the sibling factories; longer tuples are still accepted.
        Remaining arguments are forwarded to ``HadamardVGG``.

    Weights are loaded under ``cifar_hadamard_vgg13`` when ``load_weights`` is truthy.
    """
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, group_sizes=(2, 2, 2, 2, 2),
                        features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2), dense_units=dense_units,
                        depth=depth, init_type=init_type, init=init, la=la, use_bn=use_bn,
                        use_bias=use_bias, factorize_bias=factorize_bias)
    if load_weights:
        model = load_weights_func(model, 'cifar_hadamard_vgg13')
    return model

# Str. Hadamard
def str_hadamard_vgg13(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                       use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build VGG13 with structured (filter-level) sparsity via StructuredVGG.

    Weights are loaded under ``cifar_str_hadamard_vgg13`` when ``load_weights`` is truthy.
    """
    model = StructuredVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 2, 2, 2),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_str_hadamard_vgg13') if load_weights else model

###########################################################################################################

### VGG 16

# Vanilla
def vanilla_vgg16(load_weights=False, depth=2, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10,use_bias=True, dense_units=(512,0)):
    """Build the VGG16 baseline (stages of 2, 2, 3, 3, 3 convs).

    Args:
        depth: Accepted for signature parity with the Hadamard variants; unused.
            It must not be forwarded, because ``VanillaVGG`` has no ``depth`` parameter
            (forwarding it raises TypeError).
        Remaining arguments are forwarded to ``VanillaVGG``.

    Weights are loaded under ``cifar_vanilla_vgg16`` when ``load_weights`` is truthy.
    """
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, group_sizes=(2, 2, 3, 3, 3),
                       features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2), dense_units=dense_units,
                       init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias)
    if load_weights:
        model = load_weights_func(model, 'cifar_vanilla_vgg16')
    return model

# Vanilla individual filter numbers
# VGG 16 using VanillaVGGIndividual

def vanilla_vgg16_individual(load_weights=False, depth=2, init_type='vanilla', 
                             init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                             input_shape=(32, 32, 3), n_classes=10, use_bias=True, 
                             dense_units=(512, 0), block_filters=None):
    """Build VGG16 with explicit per-layer filter counts (for reduced/pruned networks).

    Args:
        block_filters: Per-stage lists of filter counts; ``None`` selects the standard
            VGG16 layout. Entries <= 0 are dropped by ``vanilla_conv_block_custom``.
        depth, init_type: Accepted for signature parity; unused.

    Weights are loaded under ``cifar_vanilla_vgg16_individual`` when ``load_weights`` is truthy.
    """
    if block_filters is None:
        block_filters = [[64] * 2, [128] * 2, [256] * 3, [512] * 3, [512] * 3]

    model = VanillaVGGIndividual(
        input_shape=input_shape,
        n_classes=n_classes,
        block_filters=block_filters,
        pools=(2, 2, 2, 2, 2),
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        dense_units=dense_units,
    )
    return load_weights_func(model, 'cifar_vanilla_vgg16_individual') if load_weights else model

# Example usage:
# model = vanilla_vgg16_individual(load_weights=False, input_shape=(32, 32, 3), n_classes=10, la = 1e-4)
# model.summary()

# Example usage:
# For VGG16 with default filters:
# model16 = vanilla_vgg16_individual(load_weights=False, input_shape=(32, 32, 3), n_classes=10)
# model16.summary()
#
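# For VGG19 with custom reduced filters (VGG19 has four conv layers in blocks 3-5;
# listing fewer entries in a block removes layers, and entries <= 0 are skipped):
# custom_filters = [
#     [32, 32],
#     [64, 64],
#     [128, 128, 128, 128],
#     [256, 256, 256, 256],
#     [256, 256, 256, 256]
# ]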
# model19 = vanilla_vgg19_individual(load_weights=False, input_shape=(32, 32, 3), n_classes=10, block_filters=custom_filters)
# model19.summary()

# Unstr. Hadamard
def hadamard_vgg16(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10, use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build VGG16 with unstructured Hadamard sparsity.

    Weights are loaded under ``cifar_hadamard_vgg16`` when ``load_weights`` is truthy.
    """
    model = HadamardVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 3, 3, 3),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_hadamard_vgg16') if load_weights else model

# Str. Hadamard
def str_hadamard_vgg16(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                      use_bias=True, factorize_bias=False, dense_units=(512,0)):
    """Build VGG16 with structured (filter-level) sparsity via StructuredVGG.

    Weights are loaded under ``cifar_str_hadamard_vgg16`` when ``load_weights`` is truthy.
    """
    model = StructuredVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 3, 3, 3),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_str_hadamard_vgg16') if load_weights else model

# VGG16 with direct L21 and network slimming regularization options

def vanilla_vgg16_reg(
    load_weights=False,
    init_type='vanilla',
    init=tf.keras.initializers.HeNormal,
    la=0,
    la_gamma=0,
    use_bn=True,
    input_shape=(32, 32, 3),
    n_classes=10,
    use_bias=True,
    dense_units=(512, 0)
):
    """Build VGG16 with group-lasso conv penalties (``la``) and L1 BN-gamma penalties (``la_gamma``).

    NOTE(review): weights are loaded under ``cifar_vanilla_vgg16``, the same name used by
    ``vanilla_vgg16``; confirm that sharing the checkpoint is intended.
    """
    vgg16_layout = dict(
        group_sizes=(2, 2, 3, 3, 3),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
    )
    model = VanillaVGG_reg(input_shape=input_shape, n_classes=n_classes, init_type=init_type,
                           init=init, la=la, la_gamma=la_gamma, use_bn=use_bn,
                           use_bias=use_bias, dense_units=dense_units, **vgg16_layout)
    return load_weights_func(model, 'cifar_vanilla_vgg16') if load_weights else model

###########################################################################################################

### VGG19

# Vanilla
def vanilla_vgg19(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10, use_bias=True, dense_units=(512,0)):
    """Build the VGG19 baseline (stages of 2, 2, 4, 4, 4 convs).

    ``depth`` is unused (signature parity). Weights are loaded under
    ``cifar_vanilla_vgg19`` when ``load_weights`` is truthy.
    """
    model = VanillaVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 4, 4, 4),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
    )
    return load_weights_func(model, 'cifar_vanilla_vgg19') if load_weights else model

# Vanilla individual filter numbers for reduced models and FLOPs computation
def vanilla_vgg19_individual(load_weights=False, depth=1, init_type='vanilla', 
                             init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                             input_shape=(32, 32, 3), n_classes=10, use_bias=True, 
                             dense_units=(512, 0), block_filters=None):
    """Build VGG19 with explicit per-layer filter counts (for reduced models / FLOP counting).

    Args:
        block_filters: Per-stage lists of filter counts; ``None`` selects the standard
            VGG19 layout. Entries <= 0 are dropped by ``vanilla_conv_block_custom``.
        depth, init_type: Accepted for signature parity; unused.

    Weights are loaded under ``cifar_vanilla_vgg19_individual`` when ``load_weights`` is truthy.
    """
    if block_filters is None:
        block_filters = [[64] * 2, [128] * 2, [256] * 4, [512] * 4, [512] * 4]

    model = VanillaVGGIndividual(
        input_shape=input_shape,
        n_classes=n_classes,
        block_filters=block_filters,
        pools=(2, 2, 2, 2, 2),
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        dense_units=dense_units,
    )
    return load_weights_func(model, 'cifar_vanilla_vgg19_individual') if load_weights else model

# Example usage:
# For VGG16 with default filters:
# model16 = vanilla_vgg16_individual(load_weights=False, input_shape=(32, 32, 3), n_classes=10, la = 1e-4)
# model16.summary()
#
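# For VGG19 with custom reduced filters (VGG19 has four conv layers in blocks 3-5;
# listing fewer entries in a block removes layers, and entries <= 0 are skipped):
# custom_filters = [
#     [32, 32],
#     [64, 64],
#     [128, 128, 128, 128],
#     [256, 256, 256, 256],
#     [256, 256, 256, 256]
# ]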
# model19 = vanilla_vgg19_individual(load_weights=False, input_shape=(32, 32, 3), n_classes=10, block_filters=custom_filters)
# model19.summary()


# Unstr. Hadamard
def hadamard_vgg19(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10, use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build VGG19 with unstructured Hadamard sparsity.

    Weights are loaded under ``cifar_hadamard_vgg19`` when ``load_weights`` is truthy.
    """
    model = HadamardVGG(
        input_shape=input_shape,
        n_classes=n_classes,
        group_sizes=(2, 2, 4, 4, 4),
        features=(64, 128, 256, 512, 512),
        pools=(2, 2, 2, 2, 2),
        dense_units=dense_units,
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        use_bn=use_bn,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return load_weights_func(model, 'cifar_hadamard_vgg19') if load_weights else model

# Str. Hadamard
def str_hadamard_vgg19(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                   multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True,input_shape=(32, 32, 3), n_classes=10,
                   use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build VGG19 with structured (filter-level) sparsity (stages of 2, 2, 4, 4, 4 convs).

    Uses ``StructuredVGG``, the builder that accepts ``kernel_initializer`` and
    ``multfac_initializer``; ``HadamardVGG`` has no such parameters and raises
    TypeError when given them.

    Weights are loaded under ``cifar_str_hadamard_vgg19`` when ``load_weights`` is truthy.
    """
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, group_sizes=(2, 2, 4, 4, 4),
                          features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2), dense_units=dense_units,
                          depth=depth, kernel_initializer=kernel_initializer,
                          multfac_initializer=multfac_initializer, la=la, use_bn=use_bn,
                          use_bias=use_bias, factorize_bias=factorize_bias)
    if load_weights:
        model = load_weights_func(model, 'cifar_str_hadamard_vgg19')
    return model

        