import os
import tensorflow as tf


from layers import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrHadamardDenseV2
from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d

# Components for VGG construction
def bn_relu(x, use_bn):
    """Apply an optional batch normalization followed by a ReLU activation.

    Args:
        x: Keras tensor to transform.
        use_bn: when truthy, a freshly created ``BatchNormalization`` layer is
            applied to ``x`` before the activation.

    Returns:
        The output of a new ``ReLU`` layer applied to the (optionally
        normalized) input.
    """
    normalized = tf.keras.layers.BatchNormalization()(x) if use_bn else x
    activation = tf.keras.layers.ReLU()
    return activation(normalized)

###########################################################################################################

# Vanilla VGG with L2 regularization

# vanilla Conv2D modules
def vanilla_regularized_padded_conv3(filters, init, la, use_bias, *args, **kwargs):
    """Create a 3x3, 'same'-padded ``Conv2D`` layer with an L2 kernel penalty.

    Args:
        filters: number of output filters of the convolution.
        init: kernel initializer handed to ``Conv2D`` as ``kernel_initializer``.
        la: coefficient of the L2 penalty placed on the kernel.
        use_bias: whether the convolution carries a bias term.
        *args: extra positional arguments forwarded to ``Conv2D``.
        **kwargs: extra keyword arguments forwarded to ``Conv2D``.

    Returns:
        The (not yet called) ``tf.keras.layers.Conv2D`` layer.
    """
    # The regularizer lives under tf.keras.regularizers; the previous
    # `tf.kernel.regularizers` path does not exist and raised AttributeError.
    return tf.keras.layers.Conv2D(
        *args,
        filters=filters,
        kernel_size=3,
        padding='same',
        kernel_initializer=init,
        use_bias=use_bias,
        kernel_regularizer=tf.keras.regularizers.L2(l2=la),
        **kwargs,
    )

def vanilla_conv_block(x, num_blocks, filters, pool, init, la, use_bn, use_bias):
    """Stack L2-regularized 3x3 convolutions and finish with a max-pool.

    The first convolution is applied directly to ``x``. Each of the remaining
    ``num_blocks - 1`` convolutions is fed through ``bn_relu`` first, and the
    last convolution output passes through ``bn_relu`` again before pooling.

    Args:
        x: input Keras tensor.
        num_blocks: total number of convolutions in the block.
        filters: number of filters of every convolution.
        pool: pool size handed to ``MaxPool2D``.
        init: kernel initializer for the convolutions.
        la: L2 penalty coefficient for the convolution kernels.
        use_bn: whether ``bn_relu`` inserts batch normalization.
        use_bias: whether the convolutions use a bias.

    Returns:
        The pooled output tensor.
    """
    conv_kwargs = dict(filters=filters, init=init, la=la, use_bias=use_bias)

    out = vanilla_regularized_padded_conv3(**conv_kwargs)(x)
    for _ in range(1, num_blocks):
        conv = vanilla_regularized_padded_conv3(**conv_kwargs)
        out = conv(bn_relu(out, use_bn=use_bn))

    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(out, use_bn=use_bn))

# Define VGG template without sparse regularization (L2 only)
def VanillaVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
               init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,  #HeNormal()
               use_bias=True, factorize_bias=False, dense_units=(4096,4096), depth=1):
    """Assemble a plain VGG-style Keras model that uses L2 regularization only.

    Args:
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of units of the final softmax ``Dense`` layer.
        group_sizes: number of convolutions in each convolutional group.
        features: number of filters for each group.
        pools: max-pool size applied at the end of each group.
        init_type: accepted for signature parity with the sparse templates;
            not used by this template.
        init: kernel initializer for the convolutions (the ``Dense`` head
            keeps the Keras default initializer).
        la: L2 penalty coefficient applied to every conv and dense kernel.
        use_bn: whether ``bn_relu`` inserts batch normalization.
        use_bias: whether conv and dense layers use a bias.
        factorize_bias: accepted for signature parity; not used here.
        dense_units: widths of up to two hidden dense layers; an entry that is
            not positive skips that layer. Only the first two entries are read.
        depth: accepted for signature parity with the sparse templates (the
            ``vanilla_vgg*`` factories pass it); not used here.

    Returns:
        A ``tf.keras.Model`` mapping the input tensor to class probabilities.
    """
    # Inputs
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Conv flow
    flow = inputs
    for group_size, feature, pool in zip(group_sizes, features, pools):
        # vanilla_conv_block has no `init_type` parameter, so it is not forwarded.
        flow = vanilla_conv_block(flow, num_blocks=group_size, filters=feature, pool=pool,
                                  init=init, la=la, use_bn=use_bn, use_bias=use_bias)

    # Head
    pooled = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(flow)
    pooled = tf.keras.layers.Flatten()(pooled)
    # Slicing (instead of indexing [0] and [1]) tolerates shorter tuples.
    for units in dense_units[:2]:
        if units > 0:
            pooled = tf.keras.layers.Dense(units, activation='relu',
                                           kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                           use_bias=use_bias)(bn_relu(pooled, use_bn=use_bn))
    output = tf.keras.layers.Dense(n_classes, activation='softmax',
                                   kernel_regularizer=tf.keras.regularizers.L2(l2=la),
                                   use_bias=use_bias)(pooled)

    return tf.keras.Model(inputs=inputs, outputs=output)

###########################################################################################################

# Hadamard VGG for unstructured sparsity

# Unstructured sparse Conv modules
def unstr_regularized_padded_conv3(filters, depth, init_type, init, la, use_bias, factorize_bias):
    """Create a 3x3, stride-1, 'same'-padded ``HadamardConv2D`` layer.

    ``filters``, ``depth``, ``init_type``, ``init``, ``la``, ``use_bias`` and
    ``factorize_bias`` are forwarded unchanged to ``HadamardConv2D``.

    Returns:
        The (not yet called) ``HadamardConv2D`` layer.
    """
    layer_config = dict(
        filters=filters,
        kernel_size=(3, 3),
        strides=1,
        padding='same',
        depth=depth,
        init_type=init_type,
        init=init,
        la=la,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    return HadamardConv2D(**layer_config)

def unstr_conv_block(x, num_blocks, filters, pool, depth, init_type, init, la, use_bn, use_bias, factorize_bias):
    """Stack ``HadamardConv2D`` 3x3 convolutions and finish with a max-pool.

    The first convolution is applied directly to ``x``. Each of the remaining
    ``num_blocks - 1`` convolutions receives its input through ``bn_relu``,
    and the last convolution output goes through ``bn_relu`` before pooling.

    Args:
        x: input Keras tensor.
        num_blocks: total number of convolutions in the block.
        filters: number of filters of every convolution.
        pool: pool size handed to ``MaxPool2D``.
        depth, init_type, init, la, use_bias, factorize_bias: forwarded to
            ``unstr_regularized_padded_conv3``.
        use_bn: whether ``bn_relu`` inserts batch normalization.

    Returns:
        The pooled output tensor.
    """
    conv_kwargs = dict(filters=filters, depth=depth, init_type=init_type, init=init, la=la,
                       use_bias=use_bias, factorize_bias=factorize_bias)

    out = unstr_regularized_padded_conv3(**conv_kwargs)(x)
    for _ in range(1, num_blocks):
        conv = unstr_regularized_padded_conv3(**conv_kwargs)
        out = conv(bn_relu(out, use_bn=use_bn))

    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(out, use_bn=use_bn))

# Define VGG template with differentiable unstructured sparsity 
def HadamardVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2),
           features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
           depth=1, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
           use_bias=True, factorize_bias=True, dense_units=(4096,4096)): #HeNormal()
    """Assemble a VGG-style model from ``HadamardConv2D``/``HadamardDense`` layers.

    Args:
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of units of the final softmax ``HadamardDense`` layer.
        group_sizes: number of convolutions in each convolutional group.
        features: number of filters for each group.
        pools: max-pool size applied at the end of each group.
        depth, init_type, init, la, use_bias, factorize_bias: forwarded to
            every Hadamard conv and dense layer.
        use_bn: whether ``bn_relu`` inserts batch normalization.
        dense_units: widths of the two hidden dense layers; entries 0 and 1
            are read and an entry that is not positive skips that layer.

    Returns:
        A ``tf.keras.Model`` from the input tensor to the softmax output.
    """
    net_input = tf.keras.layers.Input(shape=input_shape)

    # Convolutional groups
    stream = net_input
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        stream = unstr_conv_block(stream, num_blocks=n_convs, filters=n_filters, pool=pool_size, depth=depth,
                                  init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                                  factorize_bias=factorize_bias)

    # Head: one more pooling step, flatten, optional hidden layers, classifier
    stream = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(stream)
    stream = tf.keras.layers.Flatten()(stream)

    dense_kwargs = dict(depth=depth, la=la, init_type=init_type, init=init,
                        use_bias=use_bias, factorize_bias=factorize_bias)
    for position in (0, 1):
        width = dense_units[position]
        if width > 0:
            hidden = HadamardDense(units=width, activation='relu', **dense_kwargs)
            stream = hidden(bn_relu(stream, use_bn=use_bn))

    classifier = HadamardDense(units=n_classes, activation='softmax', **dense_kwargs)
    return tf.keras.Model(inputs=net_input, outputs=classifier(stream))
    
###########################################################################################################

# Filter-sparse VGG for structured filter/channel and neuron sparsity

# Structured sparse Conv modules using GHPowP
def str_regularized_padded_conv3(filters, depth, kernel_initializer, multfac_initializer, la, use_bias):
    """Create a 3x3, stride-1, 'same'-padded ``SparseConv2D`` layer.

    ``filters``, ``depth``, ``kernel_initializer``, ``multfac_initializer``,
    ``la`` and ``use_bias`` are forwarded unchanged to ``SparseConv2D``.

    Returns:
        The (not yet called) ``SparseConv2D`` layer.
    """
    layer_config = dict(
        filters=filters,
        kernel_size=(3, 3),
        strides=1,
        padding='same',
        depth=depth,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
        la=la,
        use_bias=use_bias,
    )
    return SparseConv2D(**layer_config)

def str_conv_block(x, num_blocks, filters, pool, depth, kernel_initializer, multfac_initializer, la, use_bn, use_bias):
    """Stack ``SparseConv2D`` 3x3 convolutions and finish with a max-pool.

    The first convolution is applied directly to ``x``. Each of the remaining
    ``num_blocks - 1`` convolutions receives its input through ``bn_relu``,
    and the last convolution output goes through ``bn_relu`` before pooling.

    Args:
        x: input Keras tensor.
        num_blocks: total number of convolutions in the block.
        filters: number of filters of every convolution.
        pool: pool size handed to ``MaxPool2D``.
        depth, kernel_initializer, multfac_initializer, la, use_bias:
            forwarded to ``str_regularized_padded_conv3``.
        use_bn: whether ``bn_relu`` inserts batch normalization.

    Returns:
        The pooled output tensor.
    """
    conv_kwargs = dict(filters=filters, depth=depth, kernel_initializer=kernel_initializer,
                       multfac_initializer=multfac_initializer, la=la, use_bias=use_bias)

    out = str_regularized_padded_conv3(**conv_kwargs)(x)
    for _ in range(1, num_blocks):
        conv = str_regularized_padded_conv3(**conv_kwargs)
        out = conv(bn_relu(out, use_bn=use_bn))

    pooling = tf.keras.layers.MaxPool2D(pool)
    return pooling(bn_relu(out, use_bn=use_bn))

# Define VGG template with differentiable structured sparsity 
def StructuredVGG(input_shape, n_classes, group_sizes=(1, 1, 2, 2, 2),
                  features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2),
                  depth=1, kernel_initializer=tf.keras.initializers.HeNormal, multfac_initializer=tf.keras.initializers.Ones, #HeNormal()
                  la=0, use_bn=True, use_bias=True, factorize_bias=False, dense_units=(4096, 4096)):
    """Assemble a VGG-style model from ``SparseConv2D`` and ``StrHadamardDense`` layers.

    Args:
        input_shape: shape passed to ``tf.keras.layers.Input``.
        n_classes: number of units of the final softmax layer.
        group_sizes: number of convolutions in each convolutional group.
        features: number of filters for each group.
        pools: max-pool size applied at the end of each group.
        depth: forwarded to every conv and dense layer.
        kernel_initializer: forwarded to the conv layers and, as ``init``,
            to the dense layers.
        multfac_initializer: forwarded to the conv layers only.
        la: forwarded to the conv and hidden dense layers; the output layer
            receives ``0.1 * la``.
        use_bn: whether ``bn_relu`` inserts batch normalization.
        use_bias: forwarded to every conv and dense layer.
        factorize_bias: forwarded to the dense layers only.
        dense_units: widths of the two hidden dense layers; entries 0 and 1
            are read and an entry that is not positive skips that layer.

    Returns:
        A ``tf.keras.Model`` from the input tensor to the softmax output.
    """
    net_input = tf.keras.layers.Input(shape=input_shape)

    # Convolutional groups
    stream = net_input
    for n_convs, n_filters, pool_size in zip(group_sizes, features, pools):
        stream = str_conv_block(stream, num_blocks=n_convs, filters=n_filters, pool=pool_size, depth=depth,
                                kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                                la=la, use_bn=use_bn, use_bias=use_bias)

    # Head: one more pooling step, flatten, optional hidden layers, classifier
    stream = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")(stream)
    stream = tf.keras.layers.Flatten()(stream)

    dense_kwargs = dict(depth=depth, init_type='ones', init=kernel_initializer,
                        use_bias=use_bias, factorize_bias=factorize_bias)
    for position in (0, 1):
        width = dense_units[position]
        if width > 0:
            hidden = StrHadamardDense(units=width, activation='relu', la=la, **dense_kwargs)
            stream = hidden(bn_relu(stream, use_bn=use_bn))

    # The classifier uses the V2 dense layer with a ten times smaller penalty.
    classifier = StrHadamardDenseV2(units=n_classes, activation='softmax', la=0.1*la, **dense_kwargs)
    return tf.keras.Model(inputs=net_input, outputs=classifier(stream))

###########################################################################################################

# VGG model construction from template

def load_weights_func(model, model_name):
    """Try to load saved weights for ``model`` and return the model.

    The weights are looked up at ``saved_models/<model_name>.tf``. When
    ``load_weights`` raises ``tf.errors.NotFoundError`` a message is printed
    and the model is returned as it was passed in.

    Args:
        model: object exposing ``load_weights(path)``.
        model_name: checkpoint base name, without the ``.tf`` suffix.

    Returns:
        The same ``model`` object.
    """
    checkpoint_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(checkpoint_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model

###########################################################################################################

### VGG7

# Vanilla 
def vanilla_vgg7(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                 input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0), use_bias=True):
    """Build the VGG7 configuration of ``VanillaVGG``.

    The conv layout is fixed to group sizes (2, 2, 2), filters (32, 64, 128)
    and pool sizes (2, 2, 2); all other arguments are forwarded unchanged to
    ``VanillaVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'mnist_vanilla_vgg7'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2), features=(32, 64, 128), pools=(2, 2, 2))
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, depth=depth, init_type=init_type,
                       init=init, la=la, use_bn=use_bn, dense_units=dense_units, use_bias=use_bias,
                       **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'mnist_vanilla_vgg7')

# Unstr. Hadamard
def hadamard_vgg7(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0),use_bias=True, factorize_bias=True):
    """Build the VGG7 configuration of ``HadamardVGG``.

    The conv layout is fixed to group sizes (2, 2, 2), filters (32, 64, 128)
    and pool sizes (2, 2, 2); all other arguments are forwarded unchanged to
    ``HadamardVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'mnist_hadamard_vgg7'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2), features=(32, 64, 128), pools=(2, 2, 2))
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, depth=depth, init_type=init_type,
                        init=init, la=la, use_bn=use_bn, dense_units=dense_units, use_bias=use_bias,
                        factorize_bias=factorize_bias, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'mnist_hadamard_vgg7')

# Str. Hadamard
def str_hadamard_vgg7(load_weights=False, depth=2, la=0, use_bn=True, kernel_initializer=tf.keras.initializers.HeNormal,
                      multfac_initializer=tf.keras.initializers.Ones, input_shape=(28, 28, 1), n_classes=10, dense_units =(256,0),
                      use_bias=True, factorize_bias=False):
    """Build the VGG7 configuration of ``StructuredVGG``.

    The conv layout is fixed to group sizes (2, 2, 2), filters (32, 64, 128)
    and pool sizes (2, 2, 2); all other arguments are forwarded unchanged to
    ``StructuredVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'mnist_str_hadamard_vgg7'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2), features=(32, 64, 128), pools=(2, 2, 2))
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, depth=depth, la=la, use_bn=use_bn,
                          kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                          dense_units=dense_units, use_bias=use_bias, factorize_bias=factorize_bias,
                          **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'mnist_str_hadamard_vgg7')

###########################################################################################################

### VGG 11 originally for (224,224,3) images

# Vanilla
def vanilla_vgg11(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10,use_bias=True, dense_units=(512,0)):
    """Build the VGG11 configuration of ``VanillaVGG``.

    The conv layout is fixed to group sizes (1, 1, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``VanillaVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_vanilla_vgg11'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(1, 1, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                       init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                       **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_vanilla_vgg11')

# Unstr. Hadamard
def hadamard_vgg11(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10,use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build the VGG11 configuration of ``HadamardVGG``.

    The conv layout is fixed to group sizes (1, 1, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``HadamardVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_hadamard_vgg11'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(1, 1, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                        init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                        factorize_bias=factorize_bias, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_hadamard_vgg11')

# Str. Hadamard
def str_hadamard_vgg11(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                       use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build the VGG11 configuration of ``StructuredVGG``.

    The conv layout is fixed to group sizes (1, 1, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``StructuredVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_str_hadamard_vgg11'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(1, 1, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                          kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                          la=la, use_bn=use_bn, use_bias=use_bias, factorize_bias=factorize_bias,
                          **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_str_hadamard_vgg11')

###########################################################################################################

### VGG 13

# Vanilla
def vanilla_vgg13(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10, use_bias=True, dense_units=(512,0)):
    """Build the VGG13 configuration of ``VanillaVGG``.

    The conv layout is fixed to group sizes (2, 2, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``VanillaVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_vanilla_vgg13'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                       init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                       **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_vanilla_vgg13')

# Unstr. Hadamard
def hadamard_vgg13(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10,use_bias=True, factorize_bias=True, dense_units=(512,0,0)):
    """Build the VGG13 configuration of ``HadamardVGG``.

    The conv layout is fixed to group sizes (2, 2, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``HadamardVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_hadamard_vgg13'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                        init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                        factorize_bias=factorize_bias, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_hadamard_vgg13')

# Str. Hadamard
def str_hadamard_vgg13(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                       use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build the VGG13 configuration of ``StructuredVGG``.

    The conv layout is fixed to group sizes (2, 2, 2, 2, 2), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``StructuredVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_str_hadamard_vgg13'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 2, 2, 2), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                          kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                          la=la, use_bn=use_bn, use_bias=use_bias, factorize_bias=factorize_bias,
                          **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_str_hadamard_vgg13')

###########################################################################################################

### VGG 16

# Vanilla
def vanilla_vgg16(load_weights=False, depth=2, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10,use_bias=True, dense_units=(512,0)):
    """Build the VGG16 configuration of ``VanillaVGG``.

    The conv layout is fixed to group sizes (2, 2, 3, 3, 3), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``VanillaVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_vanilla_vgg16'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 3, 3, 3), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                       init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                       **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_vanilla_vgg16')

# Unstr. Hadamard
def hadamard_vgg16(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10, use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build the VGG16 configuration of ``HadamardVGG``.

    The conv layout is fixed to group sizes (2, 2, 3, 3, 3), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``HadamardVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_hadamard_vgg16'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 3, 3, 3), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                        init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                        factorize_bias=factorize_bias, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_hadamard_vgg16')

# Str. Hadamard
def str_hadamard_vgg16(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                       multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True, input_shape=(32, 32, 3), n_classes=10,
                      use_bias=True, factorize_bias=False, dense_units=(512,0)):
    """Build the VGG16 configuration of ``StructuredVGG``.

    The conv layout is fixed to group sizes (2, 2, 3, 3, 3), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``StructuredVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_str_hadamard_vgg16'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 3, 3, 3), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                          kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                          la=la, use_bn=use_bn, use_bias=use_bias, factorize_bias=factorize_bias,
                          **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_str_hadamard_vgg16')

###########################################################################################################

### VGG19

# Vanilla
def vanilla_vgg19(load_weights=False, depth=1, init_type='vanilla', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                  input_shape=(32, 32, 3), n_classes=10, use_bias=True, dense_units=(512,0)):
    """Build the VGG19 configuration of ``VanillaVGG``.

    The conv layout is fixed to group sizes (2, 2, 4, 4, 4), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``VanillaVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_vanilla_vgg19'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 4, 4, 4), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = VanillaVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                       init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                       **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_vanilla_vgg19')

# Unstr. Hadamard
def hadamard_vgg19(load_weights=False, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0, use_bn=True,
                   input_shape=(32, 32, 3), n_classes=10, use_bias=True, factorize_bias=True, dense_units=(512,0)):
    """Build the VGG19 configuration of ``HadamardVGG``.

    The conv layout is fixed to group sizes (2, 2, 4, 4, 4), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``HadamardVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_hadamard_vgg19'``.

    Returns:
        The constructed model.
    """
    architecture = dict(group_sizes=(2, 2, 4, 4, 4), features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2))
    model = HadamardVGG(input_shape=input_shape, n_classes=n_classes, dense_units=dense_units, depth=depth,
                        init_type=init_type, init=init, la=la, use_bn=use_bn, use_bias=use_bias,
                        factorize_bias=factorize_bias, **architecture)
    if not load_weights:
        return model
    return load_weights_func(model, 'cifar_hadamard_vgg19')

# Str. Hadamard
def str_hadamard_vgg19(load_weights=False, depth=2, kernel_initializer=tf.keras.initializers.HeNormal,
                   multfac_initializer=tf.keras.initializers.Ones, la=0, use_bn=True,input_shape=(32, 32, 3), n_classes=10,
                   use_bias=False, factorize_bias=False, dense_units=(512,0)):
    """Build the VGG19 configuration of ``StructuredVGG``.

    The conv layout is fixed to group sizes (2, 2, 4, 4, 4), filters
    (64, 128, 256, 512, 512) and pool sizes (2, 2, 2, 2, 2); all other
    arguments are forwarded unchanged to ``StructuredVGG``.

    Args:
        load_weights: when truthy, the model is passed through
            ``load_weights_func`` with the name ``'cifar_str_hadamard_vgg19'``.

    Returns:
        The constructed model.
    """
    # Must be StructuredVGG: HadamardVGG has no kernel_initializer /
    # multfac_initializer parameters, so the previous call raised TypeError.
    model = StructuredVGG(input_shape=input_shape, n_classes=n_classes, group_sizes=(2, 2, 4, 4, 4),
                          features=(64, 128, 256, 512, 512), pools=(2, 2, 2, 2, 2), dense_units=dense_units,
                          depth=depth, kernel_initializer=kernel_initializer, multfac_initializer=multfac_initializer,
                          la=la, use_bn=use_bn, use_bias=use_bias, factorize_bias=factorize_bias)
    if load_weights:
        model = load_weights_func(model, 'cifar_str_hadamard_vgg19')
    return model

        
