import os
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, BatchNormalization
from tensorflow.keras.regularizers import l2

from layers import HadamardDense, StrHadamardDense, StrHadamardDenseV2
from initializers import TwiceTruncatedNormalInitializer, ExactNormalFactorization, equivar_initializer, equivar_initializer_conv2d


# Vanilla LeNet-300-100 model
def LeNet300100(input_shape=(28,28,1), n_classes=10, init=tf.keras.initializers.HeNormal, reg='l2', la=0,
                use_bias=True, factorize_bias=False, init_type ='vanilla', depth=1, name='VanillaLeNet300100',
                units1=300, units2=100, use_bn=False):
    """Build a plain (non-factorized) LeNet-300-100 style MLP.

    Args:
        input_shape: Shape of one sample, without the batch dimension. Inputs
            of rank >= 2 are flattened before the first dense layer.
        n_classes: Number of units in the softmax output layer.
        init: Kernel initializer passed to every ``Dense`` layer.
        reg: ``'l2'`` or ``'l1'`` builds the corresponding Keras regularizer
            with strength ``la``. Any other value is passed through unchanged
            as ``kernel_regularizer``.
        la: Regularization strength used when ``reg`` is ``'l2'`` or ``'l1'``.
        use_bias: Whether the dense layers use a bias term.
        factorize_bias, init_type, depth: Not used by this model. They mirror
            parameters of the Hadamard variants in this module.
        name: Name of the returned model.
        units1, units2: Widths of the first and second hidden layers.
        use_bn: If True, apply batch normalization after each hidden layer.

    Returns:
        A ``tf.keras.Model`` mapping the input to softmax class probabilities.
    """
    # Define regularizer for vanilla model
    if reg == 'l2':
        reg = tf.keras.regularizers.L2(l2=la)
    elif reg == 'l1':
        reg = tf.keras.regularizers.L1(l1=la)

    # Input layer
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Multi-dimensional inputs are flattened so the dense head sees a vector.
    if len(input_shape) >= 2:
        x = tf.keras.layers.Flatten()(inputs)
    else:
        x = inputs

    # Head: FC layers with units1, units2, n_classes units
    x = tf.keras.layers.Dense(units=units1, activation='relu', kernel_initializer=init,
                              kernel_regularizer=reg, use_bias=use_bias)(x)
    if use_bn:
        # Instantiate the layer, then call it on x. Passing x to the
        # constructor would bind it to `axis` and return a layer, not a tensor.
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(units=units2, activation='relu', kernel_initializer=init,
                              kernel_regularizer=reg, use_bias=use_bias)(x)
    if use_bn:
        x = tf.keras.layers.BatchNormalization()(x)
    outputs = tf.keras.layers.Dense(units=n_classes, activation='softmax', kernel_initializer=init,
                                    kernel_regularizer=reg, use_bias=use_bias)(x)

    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    return model
           
# LeNet-300-100 with overparametrization for unstructured sparsity via HPP_k          
def HadamardLeNet300100(input_shape=(28,28,1), n_classes=10, depth=2, init_type='equivar', init=tf.keras.initializers.HeNormal, la=0,
                   use_bias=True, factorize_bias=True, name= 'HadamardLeNet300100', units1=300, units2=100, use_bn=False):
    """Build a LeNet-300-100 style MLP whose layers are all ``HadamardDense``.

    Every layer, including the softmax output layer, is a ``HadamardDense``
    layer from the project's ``layers`` module. It receives the same
    ``depth``, ``la``, ``init_type``, ``init``, ``use_bias`` and
    ``factorize_bias`` arguments. The meaning of those arguments is defined
    by ``HadamardDense``. Per the module comment, this layer presumably
    implements the HPP_k overparametrization for unstructured sparsity;
    confirm in ``layers``.

    Args:
        input_shape: Shape of one sample, without the batch dimension. Inputs
            of rank >= 2 are flattened before the first layer.
        n_classes: Number of units in the softmax output layer.
        depth, init_type, init, la, use_bias, factorize_bias: Forwarded
            unchanged to every ``HadamardDense`` layer.
        name: Name of the returned model.
        units1, units2: Widths of the first and second hidden layers.
        use_bn: If True, apply batch normalization after each hidden layer.

    Returns:
        A ``tf.keras.Model`` mapping the input to softmax class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Multi-dimensional inputs are flattened so the dense layers see a vector.
    if len(input_shape) >= 2:
        x = tf.keras.layers.Flatten()(inputs)
    else:
        x = inputs

    x = HadamardDense(units=units1, activation='relu', depth=depth, la=la, init_type=init_type, init=init,
                      use_bias=use_bias, factorize_bias=factorize_bias)(x)
    if use_bn:
        # Instantiate the layer, then call it on x. Passing x to the
        # constructor would bind it to `axis` and return a layer, not a tensor.
        x = tf.keras.layers.BatchNormalization()(x)
    x = HadamardDense(units=units2, activation='relu', depth=depth, la=la, init_type=init_type, init=init,
                      use_bias=use_bias, factorize_bias=factorize_bias)(x)
    if use_bn:
        x = tf.keras.layers.BatchNormalization()(x)

    # classification layer with unstructured sparsity
    outputs = HadamardDense(units=n_classes, activation='softmax', depth=depth, la=la, init_type=init_type, init=init,
                            use_bias=use_bias, factorize_bias=factorize_bias)(x)

    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    return model

# LeNet-300-100 with overparametrization for structured sparsity grouped by hidden units/weights incoming to hidden unit (not outgoing/inputs!)
def StrHadamardLeNet300100(input_shape=(28,28,1), n_classes=10, depth=2, la=0, init=tf.keras.initializers.HeNormal,
                          init_rest=tf.keras.initializers.Ones, use_bias=True, factorize_bias=False, name= 'StrHadamardLeNet300100',
                          units1=300, units2=100, groupsize=None, use_bn=False):
    """Build a LeNet-300-100 style MLP from structured Hadamard layers.

    The two hidden layers are ``StrHadamardDense`` and the softmax output
    layer is ``StrHadamardDenseV2``, both from the project's ``layers``
    module. All of them receive the same ``depth``, ``la``, ``init``,
    ``init_rest``, ``use_bias`` and ``factorize_bias`` arguments. Their
    grouping semantics are defined in ``layers``.

    NOTE(review): the output layer uses ``StrHadamardDenseV2`` while the hidden
    layers use ``StrHadamardDense``. Confirm that this mix is intended.

    Args:
        input_shape: Shape of one sample, without the batch dimension. Inputs
            of rank >= 2 are flattened before the first layer.
        n_classes: Number of units in the softmax output layer.
        depth, la, init, init_rest, use_bias, factorize_bias: Forwarded
            unchanged to every structured Hadamard layer.
        name: Name of the returned model.
        units1, units2: Widths of the first and second hidden layers.
        groupsize: Accepted but not used by this function.
        use_bn: If True, apply batch normalization after each hidden layer.

    Returns:
        A ``tf.keras.Model`` mapping the input to softmax class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Multi-dimensional inputs are flattened so the dense layers see a vector.
    if len(input_shape) >= 2:
        x = tf.keras.layers.Flatten()(inputs)
    else:
        x = inputs

    x = StrHadamardDense(units=units1, activation='relu', depth=depth, la=la, init=init,
                         init_rest=init_rest, use_bias=use_bias, factorize_bias=factorize_bias)(x)
    if use_bn:
        # Instantiate the layer, then call it on x. Passing x to the
        # constructor would bind it to `axis` and return a layer, not a tensor.
        x = tf.keras.layers.BatchNormalization()(x)
    x = StrHadamardDense(units=units2, activation='relu', depth=depth, la=la, init=init,
                         init_rest=init_rest, use_bias=use_bias, factorize_bias=factorize_bias)(x)
    if use_bn:
        x = tf.keras.layers.BatchNormalization()(x)
    outputs = StrHadamardDenseV2(units=n_classes, activation='softmax', depth=depth, la=la, init=init,
                                 init_rest=init_rest, use_bias=use_bias, factorize_bias=factorize_bias)(x)

    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    return model

# LeNet-300-100 with input sparsity via GHPP_k          
def InpHadamardLeNet300100(input_shape=(None,), n_classes=26, depth=2, la=0, init=tf.keras.initializers.HeNormal,
                           init_rest=tf.keras.initializers.Ones, use_bias=True, factorize_bias=False, name= 'InpHadamardLeNet300100',
                           units1=300, units2=100, groupsize=None, use_bn=False):
    """Build a LeNet-300-100 style MLP with a structured Hadamard input layer.

    Only the first layer is factorized. It is a bias-free
    ``StrHadamardDenseV2`` from the project's ``layers`` module, which
    receives ``depth``, ``la``, ``init`` and ``init_rest``. Per the module
    comment, it presumably induces input sparsity (GHPP_k); confirm in
    ``layers``. The remaining two layers are ordinary ``Dense`` layers with
    L2 kernel regularization of strength ``0.01 * la``.

    Args:
        input_shape: Shape of one sample, without the batch dimension. Inputs
            of rank >= 2 are flattened before the first layer.
        n_classes: Number of units in the softmax output layer.
        depth, la, init_rest: Forwarded to the ``StrHadamardDenseV2`` layer.
            ``la`` also scales the L2 penalty on the plain ``Dense`` layers.
        init: Kernel initializer for all layers.
        use_bias: Bias flag for the two plain ``Dense`` layers. The first
            layer is always built with ``use_bias=False``.
        factorize_bias, groupsize: Accepted but not used by this function.
        name: Name of the returned model.
        units1, units2: Widths of the first and second hidden layers.
        use_bn: If True, apply batch normalization after each hidden layer.

    Returns:
        A ``tf.keras.Model`` mapping the input to softmax class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Multi-dimensional inputs are flattened so the dense layers see a vector.
    if len(input_shape) >= 2:
        x = tf.keras.layers.Flatten()(inputs)
    else:
        x = inputs

    # Apply the first layer to x (not `inputs`) so the flattening above is used.
    x = StrHadamardDenseV2(units=units1, activation='relu', depth=depth, la=la, init=init,
                           init_rest=init_rest, use_bias=False, factorize_bias=False)(x)
    if use_bn:
        # Instantiate the layer, then call it on x. Passing x to the
        # constructor would bind it to `axis` and return a layer, not a tensor.
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(units=units2, activation='relu', kernel_initializer=init,
                              kernel_regularizer=tf.keras.regularizers.L2(l2=0.01*la), use_bias=use_bias)(x)
    if use_bn:
        x = tf.keras.layers.BatchNormalization()(x)
    outputs = tf.keras.layers.Dense(units=n_classes, activation='softmax', kernel_initializer=init,
                                    kernel_regularizer=tf.keras.regularizers.L2(l2=0.01*la),
                                    use_bias=use_bias)(x)

    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    return model