import os
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.regularizers import l2

# Module is always loaded from hadamard/python folder location
from layers import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrConv2D
from initializers import TwiceTruncatedNormalInitializer, ExactNormalFactorization, equivar_initializer, equivar_initializer_conv2d


# Vanilla LeNet5 model
def LeNet5(
    input_shape=(32, 32, 1),
    n_classes=10,
    init=tf.keras.initializers.HeNormal,
    reg='l2',
    la=0,
    use_bias=False,
    factorize_bias=False,
    init_type='vanilla',
    depth=1,
    name='VanillaLeNet5',
):
    """Build the vanilla LeNet5 classifier as a Keras functional model.

    Architecture: two 5x5 'valid' ReLU convolutions (6 and 16 filters), each
    followed by 2x2 average pooling, then a flatten and three dense layers
    (120, 84, ``n_classes``) with a softmax output.

    Args:
        input_shape: Shape of a single input sample (without batch axis).
        n_classes: Number of units of the softmax output layer.
        init: Kernel initializer handed to every Conv2D/Dense layer.
        reg: ``'l2'`` or ``'l1'`` to build the matching Keras regularizer with
            coefficient ``la``; any other value is forwarded to Keras as-is.
        la: Regularization coefficient used when ``reg`` is ``'l2'``/``'l1'``.
        use_bias: Whether the Conv2D/Dense layers carry a bias term.
        factorize_bias: Unused by this function.
        init_type: Unused by this function.
        depth: Unused by this function.
        name: Name given to the returned model.

    Returns:
        The assembled (uncompiled) ``tf.keras.Model``.

    NOTE(review): ``factorize_bias``, ``init_type`` and ``depth`` are presumably
    accepted only so the signature lines up with the Hadamard builders -- confirm.
    """
    kl = tf.keras.layers

    # Turn the string shorthand into a concrete regularizer object that is
    # shared by all weight layers; other values pass through untouched.
    if reg == 'l2':
        reg = tf.keras.regularizers.L2(l2=la)
    elif reg == 'l1':
        reg = tf.keras.regularizers.L1(l1=la)

    # Keyword arguments common to every trainable layer of the network.
    weight_cfg = dict(kernel_initializer=init, kernel_regularizer=reg, use_bias=use_bias)
    conv_cfg = dict(kernel_size=5, strides=1, activation='relu', padding='valid', **weight_cfg)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv -> average pool, for 6 and then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = kl.Conv2D(filters=n_filters, **conv_cfg)(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Classifier head: two hidden ReLU layers followed by the softmax output.
    for width in (120, 84):
        x = kl.Dense(units=width, activation='relu', **weight_cfg)(x)
    outputs = kl.Dense(units=n_classes, activation='softmax', **weight_cfg)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# Vanilla LeNet5 model with batch normalization
def LeNet5BN(
    input_shape=(32, 32, 1),
    n_classes=10,
    init=tf.keras.initializers.HeNormal,
    reg='l2',
    la=0,
    use_bias=False,
    factorize_bias=False,
    init_type='vanilla',
    depth=1,
    name='VanillaLeNet5',
):
    """Build the vanilla LeNet5 classifier with batch normalization.

    Same layout as the plain LeNet5 (two 5x5 'valid' ReLU convolutions with 6
    and 16 filters, 2x2 average pooling, dense head 120-84-``n_classes``), but
    a ``BatchNormalization`` layer sits between each convolution and its
    pooling layer.

    Args:
        input_shape: Shape of a single input sample (without batch axis).
        n_classes: Number of units of the softmax output layer.
        init: Kernel initializer handed to every Conv2D/Dense layer.
        reg: ``'l2'`` or ``'l1'`` to build the matching Keras regularizer with
            coefficient ``la``; any other value is forwarded to Keras as-is.
        la: Regularization coefficient used when ``reg`` is ``'l2'``/``'l1'``.
        use_bias: Whether the Conv2D/Dense layers carry a bias term.
        factorize_bias: Unused by this function.
        init_type: Unused by this function.
        depth: Unused by this function.
        name: Name given to the returned model.

    Returns:
        The assembled (uncompiled) ``tf.keras.Model``.

    NOTE(review): the default ``name`` is 'VanillaLeNet5', i.e. it carries no
    'BN' suffix -- confirm whether that is intended.
    """
    kl = tf.keras.layers

    # Turn the string shorthand into a concrete regularizer object that is
    # shared by all weight layers; other values pass through untouched.
    if reg == 'l2':
        reg = tf.keras.regularizers.L2(l2=la)
    elif reg == 'l1':
        reg = tf.keras.regularizers.L1(l1=la)

    # Keyword arguments common to every Conv2D/Dense layer of the network.
    weight_cfg = dict(kernel_initializer=init, kernel_regularizer=reg, use_bias=use_bias)
    conv_cfg = dict(kernel_size=5, strides=1, activation='relu', padding='valid', **weight_cfg)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: conv -> batch norm -> average pool, for 6 then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = kl.Conv2D(filters=n_filters, **conv_cfg)(x)
        x = kl.BatchNormalization()(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Classifier head: two hidden ReLU layers followed by the softmax output.
    for width in (120, 84):
        x = kl.Dense(units=width, activation='relu', **weight_cfg)(x)
    outputs = kl.Dense(units=n_classes, activation='softmax', **weight_cfg)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
           
# LeNet5 with overparametrization for unstructured sparsity via HPP_k          
def HadamardLeNet5(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    init_type='equivar',
    init=tf.keras.initializers.HeNormal,
    la=0,
    use_bias=True,
    factorize_bias=True,
    name='HadamardLeNet5',
):
    """Build LeNet5 from ``HadamardConv2D`` / ``HadamardDense`` layers.

    The topology mirrors the vanilla LeNet5: two 5x5 'valid' ReLU convolutions
    (6 and 16 filters), each followed by 2x2 average pooling, a flatten, and a
    dense head with 120, 84 and ``n_classes`` units (softmax output). Every
    weight layer is the Hadamard counterpart imported from the project's
    ``layers`` module.

    Args:
        input_shape: Shape of a single input sample (without batch axis).
        n_classes: Number of units of the softmax output layer.
        depth: Forwarded unchanged to every Hadamard layer.
        init_type: Forwarded unchanged to every Hadamard layer.
        init: Forwarded unchanged to every Hadamard layer.
        la: Forwarded unchanged to every Hadamard layer (presumably the
            regularization strength -- confirm against ``layers``).
        use_bias: Forwarded unchanged to every Hadamard layer.
        factorize_bias: Forwarded unchanged to every Hadamard layer.
        name: Name given to the returned model.

    Returns:
        The assembled (uncompiled) ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Options every Hadamard layer receives, convolutional or dense.
    hadamard_cfg = dict(
        depth=depth,
        la=la,
        init_type=init_type,
        init=init,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    conv_cfg = dict(kernel_size=(5, 5), strides=1, padding='valid', activation='relu', **hadamard_cfg)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: Hadamard conv -> average pool, for 6 and then 16 filters.
    x = inputs
    for n_filters in (6, 16):
        x = HadamardConv2D(filters=n_filters, **conv_cfg)(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Classifier head: two hidden ReLU layers followed by the softmax output.
    for width in (120, 84):
        x = HadamardDense(units=width, activation='relu', **hadamard_cfg)(x)
    outputs = HadamardDense(units=n_classes, activation='softmax', **hadamard_cfg)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# LeNet5 with batch normalization and overparametrization for unstructured sparsity via HPP_k
def HadamardLeNet5BN(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    init_type='equivar',
    init=tf.keras.initializers.HeNormal,
    la=0,
    use_bias=True,
    factorize_bias=True,
    name='HadamardLeNet5BN',
):
    """Build the batch-normalized LeNet5 from Hadamard layers.

    Two 5x5 'valid' ReLU ``HadamardConv2D`` layers (6 and 16 filters), each
    followed by ``BatchNormalization`` and 2x2 average pooling, then a flatten
    and a ``HadamardDense`` head with 120, 84 and ``n_classes`` units (softmax
    output).

    Args:
        input_shape: Shape of a single input sample (without batch axis).
        n_classes: Number of units of the softmax output layer.
        depth: Forwarded unchanged to every Hadamard layer.
        init_type: Forwarded unchanged to every Hadamard layer.
        init: Forwarded unchanged to every Hadamard layer.
        la: Forwarded unchanged to every Hadamard layer (presumably the
            regularization strength -- confirm against ``layers``).
        use_bias: Forwarded unchanged to every Hadamard layer.
        factorize_bias: Forwarded unchanged to every Hadamard layer.
        name: Name given to the returned model.

    Returns:
        The assembled (uncompiled) ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Options every Hadamard layer receives, convolutional or dense.
    hadamard_cfg = dict(
        depth=depth,
        la=la,
        init_type=init_type,
        init=init,
        use_bias=use_bias,
        factorize_bias=factorize_bias,
    )
    conv_cfg = dict(kernel_size=(5, 5), strides=1, padding='valid', activation='relu', **hadamard_cfg)

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: Hadamard conv -> batch norm -> average pool.
    x = inputs
    for n_filters in (6, 16):
        x = HadamardConv2D(filters=n_filters, **conv_cfg)(x)
        x = kl.BatchNormalization()(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Classifier head: two hidden ReLU layers followed by the softmax output.
    for width in (120, 84):
        x = HadamardDense(units=width, activation='relu', **hadamard_cfg)(x)
    outputs = HadamardDense(units=n_classes, activation='softmax', **hadamard_cfg)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

# LeNet5 with batch normalization and overparametrization for structured sparsity via GHPP_k (Conv2D layers only; the head uses vanilla Dense layers)
def StrHadamardLeNet5BN(
    input_shape=(32, 32, 1),
    n_classes=10,
    depth=1,
    la=0,
    kernel_initializer=tf.keras.initializers.HeNormal,
    multfac_initializer=tf.keras.initializers.Ones,
    use_bias=True,
    factorize_bias=True,
    name='StrHadamardLeNet5BN',
):
    """Build the batch-normalized LeNet5 with ``StrConv2D`` convolutions.

    Two 5x5 'valid' ReLU ``StrConv2D`` layers (6 and 16 filters), each followed
    by ``BatchNormalization`` and 2x2 average pooling, then a flatten and a
    head of ordinary Keras ``Dense`` layers with 120, 84 and ``n_classes``
    units (softmax output). The dense layers are L2-regularized with
    coefficient ``0.01 * la`` and always use a bias.

    Args:
        input_shape: Shape of a single input sample (without batch axis).
        n_classes: Number of units of the softmax output layer.
        depth: Forwarded unchanged to the ``StrConv2D`` layers.
        la: Forwarded unchanged to the ``StrConv2D`` layers; also scaled by
            0.01 to form the L2 coefficient of the dense head.
        kernel_initializer: Kernel initializer for the ``StrConv2D`` and
            ``Dense`` layers.
        multfac_initializer: Forwarded unchanged to the ``StrConv2D`` layers.
        use_bias: Forwarded to the ``StrConv2D`` layers only; the dense head
            hard-codes ``use_bias=True``.
        factorize_bias: Unused by this function.
        name: Name given to the returned model.

    Returns:
        The assembled (uncompiled) ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    # Options shared by both structured convolutions.
    conv_cfg = dict(
        kernel_size=(5, 5),
        strides=1,
        padding='valid',
        activation='relu',
        use_bias=use_bias,
        depth=depth,
        la=la,
        kernel_initializer=kernel_initializer,
        multfac_initializer=multfac_initializer,
    )

    def head_dense(units, activation):
        # Each dense layer gets its own L2 regularizer instance.
        return kl.Dense(
            units=units,
            activation=activation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=tf.keras.regularizers.L2(l2=0.01 * la),
            use_bias=True,
        )

    inputs = kl.Input(shape=input_shape)

    # Feature extractor: structured conv -> batch norm -> average pool.
    x = inputs
    for n_filters in (6, 16):
        x = StrConv2D(filters=n_filters, **conv_cfg)(x)
        x = kl.BatchNormalization()(x)
        x = kl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = kl.Flatten()(x)

    # Classifier head: two hidden ReLU layers followed by the softmax output.
    for width in (120, 84):
        x = head_dense(width, 'relu')(x)
    outputs = head_dense(n_classes, 'softmax')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)