import os
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.regularizers import l2

from layers import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense
from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d


# Vanilla LeNet5 model
def LeNet5(
    input_shape=(32, 32, 1),
    n_classes=10,
    init=tf.keras.initializers.HeNormal,
    reg='l2',
    la=0,
    use_bias=True,
    factorize_bias=False,
    init_type='vanilla',
    depth=1,
    name='VanillaLeNet5',
):
    """Build a plain LeNet5-style Keras functional model.

    Layout: two (5x5 'valid' Conv2D + ReLU -> 2x2 average pooling) stages
    with 6 and 16 filters, a Flatten, two ReLU Dense layers (120 and 84
    units) and a softmax Dense layer with ``n_classes`` units.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        n_classes: Number of units in the softmax output layer.
        init: Kernel initializer passed to every Conv2D/Dense layer.
        reg: ``'l2'`` or ``'l1'`` selects the matching Keras regularizer
            with factor ``la``; any other value is forwarded unchanged as
            ``kernel_regularizer``.
        la: Regularization factor used when ``reg`` is ``'l2'`` or ``'l1'``.
        use_bias: Forwarded to every Conv2D/Dense layer.
        factorize_bias: Not used by this function.
        init_type: Not used by this function.
        depth: Not used by this function.
        name: Name given to the returned model.

    Returns:
        The assembled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Resolve the string shortcuts; everything else goes to Keras as-is.
    if reg == 'l2':
        kernel_reg = tf.keras.regularizers.L2(l2=la)
    elif reg == 'l1':
        kernel_reg = tf.keras.regularizers.L1(l1=la)
    else:
        kernel_reg = reg

    # Settings common to every trainable layer of the network.
    shared = dict(kernel_initializer=init, kernel_regularizer=kernel_reg, use_bias=use_bias)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv + average-pooling stages with 6, then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = kl.Conv2D(filters=n_filters, kernel_size=5, strides=1, padding='valid',
                      activation='relu', **shared)(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Head: hidden FC layers followed by the softmax classifier.
    for n_units in (120, 84):
        x = kl.Dense(units=n_units, activation='relu', **shared)(x)
    outputs = kl.Dense(units=n_classes, activation='softmax', **shared)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# Vanilla LeNet5 model with batch normalization
def LeNet5BN(
    input_shape=(32, 32, 1),
    n_classes=10,
    init=tf.keras.initializers.HeNormal,
    reg='l2',
    la=0,
    use_bias=True,
    factorize_bias=False,
    init_type='vanilla',
    depth=1,
    name='VanillaLeNet5',
):
    """Build a LeNet5-style Keras functional model with batch normalization.

    Layout: two (5x5 'valid' Conv2D + ReLU -> BatchNormalization -> 2x2
    average pooling) stages with 6 and 16 filters, a Flatten, two ReLU
    Dense layers (120 and 84 units) and a softmax Dense layer with
    ``n_classes`` units. The ReLU is part of the Conv2D layer, so batch
    normalization acts on the activated output.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        n_classes: Number of units in the softmax output layer.
        init: Kernel initializer passed to every Conv2D/Dense layer.
        reg: ``'l2'`` or ``'l1'`` selects the matching Keras regularizer
            with factor ``la``; any other value is forwarded unchanged as
            ``kernel_regularizer``.
        la: Regularization factor used when ``reg`` is ``'l2'`` or ``'l1'``.
        use_bias: Forwarded to every Conv2D/Dense layer.
        factorize_bias: Not used by this function.
        init_type: Not used by this function.
        depth: Not used by this function.
        name: Name given to the returned model.
            NOTE(review): the default is the same string as ``LeNet5``'s
            default name -- confirm whether that is intentional.

    Returns:
        The assembled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Resolve the string shortcuts; everything else goes to Keras as-is.
    if reg == 'l2':
        kernel_reg = tf.keras.regularizers.L2(l2=la)
    elif reg == 'l1':
        kernel_reg = tf.keras.regularizers.L1(l1=la)
    else:
        kernel_reg = reg

    # Settings common to every Conv2D/Dense layer of the network.
    shared = dict(kernel_initializer=init, kernel_regularizer=kernel_reg, use_bias=use_bias)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv -> batch norm -> average pooling, twice.
    x = inputs
    for n_filters in (6, 16):
        x = kl.Conv2D(filters=n_filters, kernel_size=5, strides=1, padding='valid',
                      activation='relu', **shared)(x)
        x = kl.BatchNormalization()(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Head: hidden FC layers followed by the softmax classifier.
    for n_units in (120, 84):
        x = kl.Dense(units=n_units, activation='relu', **shared)(x)
    outputs = kl.Dense(units=n_classes, activation='softmax', **shared)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
           
# LeNet5 with overparametrization for unstructured sparsity via HPP_k          
def HadamardLeNet5(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    init_type='equivar',
    init=tf.keras.initializers.HeNormal,
    la=0,
    use_bias=True,
    factorize_bias=True,
    name='HadamardLeNet5',
):
    """Build the LeNet5 layout from ``HadamardConv2D`` / ``HadamardDense`` layers.

    Layout: two (5x5 'valid' HadamardConv2D + ReLU -> 2x2 average pooling)
    stages with 6 and 16 filters, a Flatten, two ReLU HadamardDense layers
    (120 and 84 units) and a softmax HadamardDense layer with ``n_classes``
    units.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        n_classes: Number of units in the softmax output layer.
        depth: Forwarded as ``depth`` to every Hadamard layer.
        init_type: Forwarded as ``init_type`` to every Hadamard layer.
        init: Forwarded as ``init`` to every Hadamard layer.
        la: Forwarded as ``la`` to every Hadamard layer (presumably the
            regularization strength -- confirm against the layer code).
        use_bias: Forwarded to every Hadamard layer.
        factorize_bias: Forwarded to every Hadamard layer.
        name: Name given to the returned model.

    Returns:
        The assembled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Every Hadamard layer in the network receives the same factorization settings.
    factor_opts = dict(depth=depth, init_type=init_type, init=init, la=la,
                       use_bias=use_bias, factorize_bias=factorize_bias)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv + average-pooling stages with 6, then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = HadamardConv2D(filters=n_filters, kernel_size=(5, 5), strides=1, padding='valid',
                           activation='relu', **factor_opts)(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Head: hidden FC layers followed by the softmax classifier.
    for n_units in (120, 84):
        x = HadamardDense(units=n_units, activation='relu', **factor_opts)(x)
    outputs = HadamardDense(units=n_classes, activation='softmax', **factor_opts)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# LeNet5 with batch normalization and overparametrization for unstructured sparsity via HPP_k
def HadamardLeNet5BN(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    init_type='equivar',
    init=tf.keras.initializers.HeNormal,
    la=0,
    use_bias=True,
    factorize_bias=True,
    name='HadamardLeNet5BN',
):
    """Build the batch-normalized LeNet5 layout from Hadamard layers.

    Layout: two (5x5 'valid' HadamardConv2D + ReLU -> BatchNormalization ->
    2x2 average pooling) stages with 6 and 16 filters, a Flatten, two ReLU
    HadamardDense layers (120 and 84 units) and a softmax HadamardDense
    layer with ``n_classes`` units.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        n_classes: Number of units in the softmax output layer.
        depth: Forwarded as ``depth`` to every Hadamard layer.
        init_type: Forwarded as ``init_type`` to every Hadamard layer.
        init: Forwarded as ``init`` to every Hadamard layer.
        la: Forwarded as ``la`` to every Hadamard layer (presumably the
            regularization strength -- confirm against the layer code).
        use_bias: Forwarded to every Hadamard layer.
        factorize_bias: Forwarded to every Hadamard layer.
        name: Name given to the returned model.

    Returns:
        The assembled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Every Hadamard layer in the network receives the same factorization settings.
    factor_opts = dict(depth=depth, init_type=init_type, init=init, la=la,
                       use_bias=use_bias, factorize_bias=factorize_bias)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv -> batch norm -> average pooling, twice.
    x = inputs
    for n_filters in (6, 16):
        x = HadamardConv2D(filters=n_filters, kernel_size=(5, 5), strides=1, padding='valid',
                           activation='relu', **factor_opts)(x)
        x = kl.BatchNormalization()(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Head: hidden FC layers followed by the softmax classifier.
    for n_units in (120, 84):
        x = HadamardDense(units=n_units, activation='relu', **factor_opts)(x)
    outputs = HadamardDense(units=n_classes, activation='softmax', **factor_opts)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# LeNet5 with overparametrization for structured sparsity via GHPowP (Conv2D) and GHPP (FC)          
def StrHadamardLeNet5(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    la=0,
    kernel_initializer=tf.keras.initializers.HeNormal,
    multfac_initializer=tf.keras.initializers.Ones,
    use_bias=True,
    factorize_bias=True,
    name='StrHadamardLeNet5',
):
    """Build the LeNet5 layout from ``SparseConv2D`` / ``StrHadamardDense`` layers.

    Layout: two (5x5 'valid' SparseConv2D + ReLU -> 2x2 average pooling)
    stages with 6 and 16 filters, a Flatten, two ReLU StrHadamardDense
    layers (120 and 84 units) and a softmax StrHadamardDense layer with
    ``n_classes`` units.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        n_classes: Number of units in the softmax output layer.
        depth: Forwarded as ``depth`` to every conv and dense layer.
        la: Forwarded as ``la`` to every conv and dense layer (presumably
            the regularization strength -- confirm against the layer code).
        kernel_initializer: Forwarded to every conv and dense layer.
        multfac_initializer: Forwarded to every conv and dense layer.
        use_bias: Forwarded to every conv and dense layer.
        factorize_bias: Forwarded to the StrHadamardDense layers only; the
            SparseConv2D layers are not given this argument.
        name: Name given to the returned model.

    Returns:
        The assembled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Settings common to the conv and dense layers; the dense layers
    # additionally receive the bias-factorization switch.
    common = dict(depth=depth, la=la, kernel_initializer=kernel_initializer,
                  multfac_initializer=multfac_initializer, use_bias=use_bias)
    dense_opts = dict(common, factorize_bias=factorize_bias)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv + average-pooling stages with 6, then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = SparseConv2D(filters=n_filters, kernel_size=(5, 5), strides=1, padding='valid',
                         activation='relu', **common)(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Head: hidden FC layers followed by the softmax classifier.
    for n_units in (120, 84):
        x = StrHadamardDense(units=n_units, activation='relu', **dense_opts)(x)
    outputs = StrHadamardDense(units=n_classes, activation='softmax', **dense_opts)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
