import os
import tensorflow as tf

from layers import HadamardDense, HadamardConv2D, StrHadamardDense
from callbacks import HadamardCallback
from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d

# Components for ResNet construction
def regularized_padded_conv(filters, kernel_size, strides, depth, init_type, init, la, *args, **kwargs):
    """Create a square-kernel, 'same'-padded HadamardConv2D layer without bias.

    Args:
        filters: number of output channels.
        kernel_size: side length of the square kernel (expanded to ``(k, k)``).
        strides: convolution stride, forwarded unchanged.
        depth, init_type, init, la: forwarded unchanged to ``HadamardConv2D``.
        *args, **kwargs: accepted so call sites may pass extra options; they are
            ignored and never reach the layer.

    Returns:
        An un-applied ``HadamardConv2D`` layer; call it on a tensor to use it.
    """
    layer_options = {
        'filters': filters,
        'kernel_size': (kernel_size, kernel_size),
        'strides': strides,
        'use_bias': False,
        'factorize_bias': False,
        'depth': depth,
        'init_type': init_type,
        'init': init,
        'la': la,
        'padding': 'same',
    }
    return HadamardConv2D(**layer_options)

def bn_relu(x):
    """Apply a fresh BatchNormalization layer and then a fresh ReLU layer to ``x``."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activation = tf.keras.layers.ReLU()
    return activation(normalized)


def shortcut(x, filters, stride, mode, depth, init_type, init, la):
    """Return the shortcut (skip) branch of a residual block.

    ``x`` is returned unchanged only when it already matches the block output,
    i.e. it has ``filters`` channels AND ``stride == 1``. Otherwise it is adapted
    according to ``mode``:

    * ``'A'``: parameter-free. Subsample spatially with a 1x1 max-pool of stride
      ``stride`` (only if ``stride > 1``), then zero-pad the last (channel) axis
      up to ``filters``.
    * ``'B'``: 1x1 Hadamard projection convolution with stride ``stride``.
    * ``'B_original'``: the same projection followed by BatchNormalization
      (selected by ``original_block``).

    Args:
        x: input tensor; the 'A' padding spec assumes rank 4 with channels last.
        filters: channel count the shortcut must produce.
        stride: spatial stride of the residual branch (an int).
        mode: one of ``'A'``, ``'B'``, ``'B_original'``.
        depth, init_type, init, la: forwarded to the projection convolution.

    Returns:
        The shortcut tensor.

    Raises:
        KeyError: if a projection is needed and ``mode`` is not recognized.
    """
    # A stride > 1 changes the spatial size, so identity is only valid when both
    # the channel count and the resolution are unchanged.
    if x.shape[-1] == filters and stride == 1:
        return x
    if mode == 'B':
        return regularized_padded_conv(filters=filters, kernel_size=1, strides=stride, depth=depth,
                                       init_type=init_type, init=init, la=la)(x)
    if mode == 'B_original':
        x = regularized_padded_conv(filters=filters, kernel_size=1, strides=stride, depth=depth,
                                    init_type=init_type, init=init, la=la)(x)
        return tf.keras.layers.BatchNormalization()(x)
    if mode == 'A':
        pooled = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x
        return tf.pad(pooled, paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block(x, filters, depth, init_type, init, la, stride=1, **kwargs):
    """Post-activation ("original") residual block: conv-BN-ReLU-conv-BN, add, ReLU.

    Reads the module-level ``_shortcut_type`` set by ``Resnet``. A ``'B'`` shortcut
    is upgraded to ``'B_original'`` so the projection is batch-normalized too.
    Extra keyword arguments (e.g. ``preact_block`` passed by ``group_of_blocks``)
    are accepted and ignored.
    """
    conv_args = dict(depth=depth, init_type=init_type, init=init, la=la)
    residual = regularized_padded_conv(filters=filters, kernel_size=3, strides=stride, **conv_args)(x)
    residual = regularized_padded_conv(filters=filters, kernel_size=3, strides=1, **conv_args)(bn_relu(residual))
    residual = tf.keras.layers.BatchNormalization()(residual)

    if _shortcut_type == 'B':
        shortcut_mode = 'B_original'
    else:
        shortcut_mode = _shortcut_type
    skip = shortcut(x, filters, stride, mode=shortcut_mode, **conv_args)
    return tf.keras.layers.ReLU()(skip + residual)
    
    
def preactivation_block(x, filters, depth, init_type, init, la, stride=1, preact_block=False):
    """Pre-activation residual block: BN-ReLU-conv(-dropout)-BN-ReLU-conv plus shortcut.

    Reads the module-level ``_dropout`` and ``_shortcut_type`` set by ``Resnet``.
    When ``preact_block`` is true, the shortcut is taken from the pre-activated
    input instead of the raw input. No BN/ReLU is applied after the addition.
    """
    conv_args = dict(depth=depth, init_type=init_type, init=init, la=la)
    preactivated = bn_relu(x)
    skip_source = preactivated if preact_block else x

    hidden = regularized_padded_conv(filters=filters, kernel_size=3, strides=stride, **conv_args)(preactivated)
    if _dropout:
        hidden = tf.keras.layers.Dropout(_dropout)(hidden)

    out = regularized_padded_conv(filters=filters, kernel_size=3, strides=1, **conv_args)(bn_relu(hidden))
    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type, **conv_args)
    return skip + out


def bottleneck_block(x, filters, depth, init_type, init, la, stride=1, preact_block=False):
    """Pre-activation bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand, plus shortcut.

    The inner width is ``filters // _bottleneck_width``. Reads the module-level
    ``_bottleneck_width`` and ``_shortcut_type`` set by ``Resnet``. When
    ``preact_block`` is true, the shortcut is taken from the pre-activated input.
    """
    conv_args = dict(depth=depth, init_type=init_type, init=init, la=la)
    inner = filters // _bottleneck_width
    preactivated = bn_relu(x)
    skip_source = preactivated if preact_block else x

    reduced = regularized_padded_conv(filters=inner, kernel_size=1, strides=1, **conv_args)(preactivated)
    spatial = regularized_padded_conv(filters=inner, kernel_size=3, strides=stride, **conv_args)(bn_relu(reduced))
    expanded = regularized_padded_conv(filters=filters, kernel_size=1, strides=1, **conv_args)(bn_relu(spatial))
    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type, **conv_args)
    return skip + expanded


def group_of_blocks(x, block_type, num_blocks, filters, stride, depth, init_type, init, la, block_idx=0):
    """Stack ``num_blocks`` residual blocks of the same width.

    Only the first block uses ``stride`` and receives ``preact_block``. That flag
    is true when the module-level ``_preact_shortcuts`` is truthy or this is the
    first group (``block_idx == 0``). The remaining blocks use stride 1 and the
    block's default ``preact_block``. At least one block is always built.
    """
    shared = dict(filters=filters, depth=depth, init_type=init_type, init=init, la=la)
    first_preact = bool(_preact_shortcuts or block_idx == 0)

    x = block_type(x, stride=stride, preact_block=first_preact, **shared)
    for _ in range(num_blocks - 1):
        x = block_type(x, stride=1, **shared)
    return x

# ResNet model definition
def Resnet(input_shape, n_classes, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
           shortcut_type='B', block_type='preactivated', first_conv=None,
           dropout=0, cardinality=1, bottleneck_width=4, preact_shortcuts=True, name = 'ResNet',
           use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0):
    """Build a Hadamard-factorized ResNet as a Keras functional model.

    Args:
        input_shape: shape of one sample, passed to ``tf.keras.layers.Input``.
        n_classes: number of units of the final softmax ``HadamardDense`` layer.
        group_sizes: number of residual blocks per group.
        features: number of filters per group.
        strides: stride of the first block of each group. ``group_sizes``,
            ``features`` and ``strides`` are zipped, so the shortest one sets the
            number of groups.
        shortcut_type: ``'A'`` or ``'B'``; see ``shortcut``.
        block_type: ``'preactivated'``, ``'bottleneck'`` or ``'original'``.
        first_conv: keyword arguments (``filters``, ``kernel_size``, ``strides``)
            for the stem convolution. ``None`` means
            ``{"filters": 16, "kernel_size": 3, "strides": 1}``.
        dropout: dropout rate after the first conv of each preactivated block
            (0 disables it).
        cardinality: stored in the module-level ``_cardinality``; no active
            block in this file reads it.
        bottleneck_width: divisor giving the inner width of bottleneck blocks.
        preact_shortcuts: if true, the first block of every group takes its
            shortcut from the pre-activated input; otherwise only the first
            group does. Ignored by ``'original'`` blocks.
        name: Keras model name.
        use_bias, factorize_bias: forwarded to the final ``HadamardDense`` only.
        depth, init_type, init, la: forwarded to every Hadamard layer.

    Returns:
        A ``tf.keras.Model`` mapping ``input_shape`` inputs to class probabilities.

    Raises:
        KeyError: if ``block_type`` is not one of the supported names.
    """
    # The block builders read these settings from module-level globals.
    global _shortcut_type, _dropout, _cardinality, _bottleneck_width, _preact_shortcuts
    if first_conv is None:
        # Build a fresh dict per call rather than sharing a mutable default argument.
        first_conv = {"filters": 16, "kernel_size": 3, "strides": 1}
    _bottleneck_width = bottleneck_width  # used in ResNeXts and bottleneck blocks
    _shortcut_type = shortcut_type  # used in blocks
    _cardinality = cardinality  # used in ResNeXts
    _dropout = dropout  # used in Wide ResNets
    _preact_shortcuts = preact_shortcuts

    block_types = {'preactivated': preactivation_block,
                   'bottleneck': bottleneck_block,
                   'original': original_block}

    selected_block = block_types[block_type]
    inputs = tf.keras.layers.Input(shape=input_shape)
    flow = regularized_padded_conv(depth=depth, init_type=init_type, init=init, la=la, **first_conv)(inputs)

    if block_type == 'original':
        # Post-activation network: normalize/activate the stem output here.
        flow = bn_relu(flow)

    for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)):
        flow = group_of_blocks(flow,
                               block_type=selected_block,
                               num_blocks=group_size,
                               filters=feature,
                               stride=stride,
                               depth=depth,
                               init_type=init_type,
                               init=init,
                               la=la,
                               block_idx=block_idx)

    if block_type != 'original':
        # Pre-activation blocks end with a bare addition, so apply BN/ReLU once before pooling.
        flow = bn_relu(flow)

    flow = tf.keras.layers.GlobalAveragePooling2D()(flow)
    outputs = HadamardDense(units=n_classes, activation='softmax', depth=depth, la=la, init_type=init_type, init=init,
                            use_bias=use_bias, factorize_bias=factorize_bias)(flow)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
    return model


def load_weights_func(model, model_name):
    """Load ``saved_models/<model_name>.tf`` into ``model`` if that checkpoint exists.

    A missing checkpoint (``tf.errors.NotFoundError``) is reported with a printed
    message instead of raising; any other error propagates.

    Returns:
        The same ``model`` object, so the call can be chained.
    """
    checkpoint_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(checkpoint_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model

# ResNet architecture wrappers
# Small ResNet-18 variant (16/32/64 filters) with about 0.2M parameters
def small_hadamard_resnet18(block_type='original', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build a narrow ResNet-18: 3 groups of 2 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_16f_resnet18' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='small_hadamard_cifar_resnet18', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_16f_resnet18')

# Standard ResNet-18 with about 11.174M parameters
def hadamard_resnet18(block_type='original', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build a ResNet-18: 4 groups of 2 blocks with 64/128/256/512 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_64f_resnet18' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(2, 2, 2, 2), features=(64, 128, 256, 512), strides=(1, 2, 2, 2),
                        first_conv={"filters": 64, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet18', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_64f_resnet18')

# Standard ResNet-34 with about 22.282M parameters
def hadamard_resnet34(block_type='original', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build a ResNet-34: groups of 3/4/6/3 blocks with 64/128/256/512 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_64f_resnet34' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(3, 4, 6, 3), features=(64, 128, 256, 512), strides=(1, 2, 2, 2),
                        first_conv={"filters": 64, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet34', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_64f_resnet34')

# Standard ResNet-50 with about 23.521M parameters
def hadamard_resnet50(block_type='bottleneck', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build a ResNet-50: groups of 3/4/6/3 bottleneck blocks, 256..2048 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_64f_resnet50' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(3, 4, 6, 3), features=(256, 512, 1024, 2048), strides=(1, 2, 2, 2),
                        first_conv={"filters": 256, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet50', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_64f_resnet50')

# standard ResNet101
def hadamard_resnet101(block_type='bottleneck', shortcut_type='B', load_weights=False, #l2_reg = 5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build a ResNet-101: groups of 3/4/23/3 bottleneck blocks, 256..2048 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_64f_resnet101' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(3, 4, 23, 3), features=(256, 512, 1024, 2048), strides=(1, 2, 2, 2),
                        first_conv={"filters": 256, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet101', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_64f_resnet101')

def hadamard_resnet20(block_type='original', shortcut_type='A', load_weights=False, #l2_reg=1e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-20: 3 groups of 3 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_cifar_resnet20' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(3, 3, 3), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet20', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet20')


def hadamard_resnet32(block_type='original', shortcut_type='A',  load_weights=False, #l2_reg=1e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-32: 3 groups of 5 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_cifar_resnet32' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(5, 5, 5), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet32', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet32')


def hadamard_resnet44(block_type='original', shortcut_type='A', load_weights=False, #l2_reg=1e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-44: 3 groups of 7 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_cifar_resnet44' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(7, 7, 7), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet44', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet44')


def hadamard_resnet56(block_type='original', shortcut_type='A', load_weights=False, #l2_reg=1e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-56: 3 groups of 9 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_cifar_resnet56' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(9, 9, 9), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet56', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet56')


def hadamard_resnet110(block_type='preactivated', shortcut_type='B', load_weights=False, #l2_reg=1e-4,
                    use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                    input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-110: 3 groups of 18 blocks with 16/32/64 filters.

    Arguments are forwarded to ``Resnet``. If ``load_weights`` is true, the
    checkpoint 'hadamard_cifar_resnet110' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(18, 18, 18), features=(16, 32, 64), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type=block_type, preact_shortcuts=False,
                   name='hadamard_cifar_resnet110', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet110')


def hadamard_resnet164(shortcut_type='B', load_weights=False, #l2_reg=1e-4,
                    use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                    input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-164: 3 groups of 18 bottleneck blocks, 64/128/256 filters.

    Every group's first block uses a pre-activated shortcut. Arguments are
    forwarded to ``Resnet``. If ``load_weights`` is true, the checkpoint
    'hadamard_cifar_resnet164' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(18, 18, 18), features=(64, 128, 256), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type='bottleneck', preact_shortcuts=True,
                   name='hadamard_cifar_resnet164', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet164')


def hadamard_resnet1001(shortcut_type='B', load_weights=False, #l2_reg=1e-4,
                     use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                     input_shape=(32, 32,3), n_classes=10):
    """Build the CIFAR ResNet-1001: 3 groups of 111 bottleneck blocks, 64/128/256 filters.

    Every group's first block uses a pre-activated shortcut. Arguments are
    forwarded to ``Resnet``. If ``load_weights`` is true, the checkpoint
    'hadamard_cifar_resnet1001' is loaded via ``load_weights_func``.
    """
    architecture = dict(group_sizes=(111, 111, 111), features=(64, 128, 256), strides=(1, 2, 2),
                        first_conv={"filters": 16, "kernel_size": 3, "strides": 1})
    hadamard_options = dict(use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                            init_type=init_type, init=init, la=la)
    model = Resnet(input_shape=input_shape, n_classes=n_classes, shortcut_type=shortcut_type,
                   block_type='bottleneck', preact_shortcuts=True,
                   name='hadamard_cifar_resnet1001', **architecture, **hadamard_options)
    if not load_weights:
        return model
    return load_weights_func(model, 'hadamard_cifar_resnet1001')


def hadamard_wide_resnet(N, K, block_type='preactivated', shortcut_type='B', dropout=0, #l2_reg=2.5e-4,
                      use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                      input_shape=(32, 32,3), n_classes=10):
    """Build a Hadamard Wide ResNet WRN-N-K via ``Resnet``.

    ``N`` must have the form ``6 * b + 4``. Each of the three groups then holds
    ``b`` blocks (the code builds at least one per group), with widths
    ``16K``, ``32K`` and ``64K``. The stem conv has 16 filters, and every
    group's first block uses a pre-activated shortcut.

    Raises:
        AssertionError: if ``N - 4`` is not divisible by 6 (assertions enabled).
    """
    assert (N-4) % 6 == 0, "N-4 has to be divisible by 6"
    # Number of residual blocks per group (not layers per block).
    blocks_per_group = (N - 4) // 6
    widths = tuple(base * K for base in (16, 32, 64))
    model_name = '_'.join(['hadamard_cifar_WRN', str(N), str(K)])
    return Resnet(input_shape=input_shape, n_classes=n_classes,
                  group_sizes=(blocks_per_group,) * 3, features=widths, strides=(1, 2, 2),
                  first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
                  shortcut_type=shortcut_type, block_type=block_type, dropout=dropout,
                  preact_shortcuts=True, name=model_name,
                  use_bias=use_bias, factorize_bias=factorize_bias, depth=depth,
                  init_type=init_type, init=init, la=la)


def hadamard_WRN_16_4(shortcut_type='B', load_weights=False, dropout=0, #l2_reg=2.5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the Hadamard WRN-16-4 with pre-activated blocks.

    NOTE(review): the checkpoint name 'cifar_WRN_16_4' lacks the 'hadamard_'
    prefix of the model name ('hadamard_cifar_WRN_16_4'); confirm it is the
    intended file.
    """
    model = hadamard_wide_resnet(N=16, K=4, block_type='preactivated', shortcut_type=shortcut_type,
                                 dropout=dropout, use_bias=use_bias, factorize_bias=factorize_bias,
                                 depth=depth, init_type=init_type, init=init, la=la,
                                 input_shape=input_shape, n_classes=n_classes)
    return load_weights_func(model, 'cifar_WRN_16_4') if load_weights else model


def hadamard_WRN_40_4(shortcut_type='B', load_weights=False, dropout=0, #l2_reg=2.5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the Hadamard WRN-40-4 with pre-activated blocks.

    NOTE(review): the checkpoint name 'cifar_WRN_40_4' lacks the 'hadamard_'
    prefix of the model name ('hadamard_cifar_WRN_40_4'); confirm it is the
    intended file.
    """
    model = hadamard_wide_resnet(N=40, K=4, block_type='preactivated', shortcut_type=shortcut_type,
                                 dropout=dropout, use_bias=use_bias, factorize_bias=factorize_bias,
                                 depth=depth, init_type=init_type, init=init, la=la,
                                 input_shape=input_shape, n_classes=n_classes)
    return load_weights_func(model, 'cifar_WRN_40_4') if load_weights else model


def hadamard_WRN_16_8(shortcut_type='B', load_weights=False, dropout=0, #l2_reg=2.5e-4,
                   use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   input_shape=(32, 32,3), n_classes=10):
    """Build the Hadamard WRN-16-8 with pre-activated blocks.

    NOTE(review): the checkpoint name 'cifar_WRN_16_8' lacks the 'hadamard_'
    prefix of the model name ('hadamard_cifar_WRN_16_8'); confirm it is the
    intended file.
    """
    model = hadamard_wide_resnet(N=16, K=8, block_type='preactivated', shortcut_type=shortcut_type,
                                 dropout=dropout, use_bias=use_bias, factorize_bias=factorize_bias,
                                 depth=depth, init_type=init_type, init=init, la=la,
                                 input_shape=input_shape, n_classes=n_classes)
    return load_weights_func(model, 'cifar_WRN_16_8') if load_weights else model


def hadamard_WRN_28_10(shortcut_type='B', load_weights=False, dropout=0, #l2_reg=2.5e-4,
                    use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                    input_shape=(32, 32,3), n_classes=10):
    """Build the Hadamard WRN-28-10 with pre-activated blocks.

    NOTE(review): the checkpoint name 'cifar_WRN_28_10' lacks the 'hadamard_'
    prefix of the model name ('hadamard_cifar_WRN_28_10'); confirm it is the
    intended file.
    """
    model = hadamard_wide_resnet(N=28, K=10, block_type='preactivated', shortcut_type=shortcut_type,
                                 dropout=dropout, use_bias=use_bias, factorize_bias=factorize_bias,
                                 depth=depth, init_type=init_type, init=init, la=la,
                                 input_shape=input_shape, n_classes=n_classes)
    return load_weights_func(model, 'cifar_WRN_28_10') if load_weights else model


#def cifar_resnext(N, cardinality, width, shortcut_type='B', use_bias=True, factorize_bias=True, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal(), la=0,):
#    assert (N-3) % 9 == 0, "N-4 has to be divisible by 6"
#    lpb = (N-3) // 9 # layers per block - since N is total number of convolutional layers in Wide ResNet
#    model = Resnet(input_shape=(32, 32, 3), n_classes=10, l2_reg=1e-4, group_sizes=(lpb, lpb, lpb), features=(16*width, 32*width, 64*width),
#                   strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, shortcut_type=shortcut_type,
#                   block_type='resnext', cardinality=cardinality, bottleneck_width=width,
#                   use_bias=use_bias, factorize_bias=factorize_bias, depth=depth, init_type=init_type, init=init, la=la)
#    return model

#mod = hadamard_resnet18(depth=1)
#mod.summary()

#mod2 = hadamard_resnet56()
#mod2.summary()

#mod3 = hadamard_WRN_16_8()
#mod3.summary()

# Alternative WRN-16-8 implementation (to check for equivalence with the one above)

# Hadamard WideResNet builder (WRN-16-8 corresponds to dep=16, k=8)
# TODO: remove the manual IN_FILTERS definition below when running outside Paperspace
# Module-level channel count matching the 16-filter WRN stem convolution.
IN_FILTERS: int = 16
def WideResNet(dep=16, k=8, input_shape = (32, 32, 3), n_classes = 10, depth=1, init_type='equivar', init = tf.keras.initializers.HeUniform(), la=0,
              use_bias=True, factorize_bias=True):
    """Build a Hadamard Wide ResNet WRN-``dep``-``k`` as a Keras functional model.

    Independent implementation kept to cross-check the ``Resnet``-based WRN
    builders. It uses a 16-filter 3x3 stem and three stages of ``(dep - 4) // 6``
    pre-activation residual blocks (at least one per stage) with widths 16k,
    32k and 64k. The last two stages downsample with stride 2. The head is
    BN-ReLU, ``AveragePooling2D((8, 8))``, flatten and a softmax
    ``HadamardDense``.

    Args:
        dep: network depth; determines the blocks per stage as ``(dep - 4) // 6``.
        k: width multiplier.
        input_shape: shape of one sample for ``tf.keras.layers.Input``. The
            8x8 average pool reduces the final map to 1x1 only when that map is
            8x8, which is the case for 32x32 inputs.
        n_classes: number of softmax outputs.
        depth, init_type, init, la: forwarded to every Hadamard layer.
            NOTE(review): the default ``init`` is one HeUniform instance created
            at import time and shared by every layer and call; confirm the
            Hadamard layers handle a shared instance as intended.
        use_bias, factorize_bias: forwarded to the final ``HadamardDense`` only.

    Returns:
        A ``tf.keras.Model`` named "WideResNet".
    """
    print('Wide-Resnet %dx%d' %(dep, k))
    n_filters = [16, 16*k, 32*k, 64*k]
    n_stack = int((dep - 4) // 6)
    # Channel count of the tensor entering the next residual block. It lives in
    # this call's scope (it used to be a module-level global mutated by the
    # helpers, while a same-named *local* was reset here). With the global, a
    # second call started from the previous model's final width, which could
    # change or break the architecture.
    in_filters = n_filters[0]

    def hadamard_conv(x, filters, kernel_size, strides):
        # All convolutions here are bias-free, 'same'-padded Hadamard convolutions.
        return HadamardConv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same',
                              use_bias=False, depth=depth, init_type=init_type, init=init, la=la)(x)

    def bn_act(x):
        x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return tf.keras.layers.Activation('relu')(x)

    def residual_block(x, out_filters, increase=False):
        stride = 2 if increase else 1
        o1 = bn_act(x)
        conv_1 = hadamard_conv(o1, out_filters, (3, 3), stride)
        o2 = bn_act(conv_1)
        conv_2 = hadamard_conv(o2, out_filters, (3, 3), 1)
        if increase or in_filters != out_filters:
            # Shape changes: project the pre-activated input with a (strided) 1x1 conv.
            skip = hadamard_conv(o1, out_filters, (1, 1), stride)
        else:
            skip = x
        return tf.keras.layers.add([conv_2, skip])

    def wide_residual_layer(x, out_filters, increase=False):
        nonlocal in_filters
        x = residual_block(x, out_filters, increase)
        in_filters = out_filters
        for _ in range(1, n_stack):
            x = residual_block(x, out_filters)
        return x

    x_input = tf.keras.layers.Input(input_shape)
    x = hadamard_conv(x_input, n_filters[0], (3, 3), 1)
    x = wide_residual_layer(x, n_filters[1])
    x = wide_residual_layer(x, n_filters[2], increase=True)
    x = wide_residual_layer(x, n_filters[3], increase=True)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.AveragePooling2D((8,8))(x)
    x = tf.keras.layers.Flatten()(x)
    x = HadamardDense(units=n_classes,activation='softmax',depth=depth,use_bias=use_bias,factorize_bias=factorize_bias,
                      init_type=init_type,init=init,la=la)(x)

    model = tf.keras.models.Model(inputs = x_input, outputs = x, name = "WideResNet")
    return model
