import config_inputsparse_compar
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import os
# Optional third-party dependencies: probe for each package, install it on
# demand, and only then import the names this module needs.  The probes must
# run BEFORE the `from ... import ...` lines; otherwise a missing package
# raises ImportError before the install fallback ever gets a chance to run.
import importlib
import subprocess
import sys

try:
    import pyHSICLasso
except ImportError:
    # Install with the running interpreter so the package lands in the
    # environment that is actually executing this module.  Best effort: a
    # failed install surfaces as ImportError on the import below.
    subprocess.run([sys.executable, "-m", "pip", "install", "pyHSICLasso"], check=False)
    importlib.invalidate_caches()
try:
    import lassonet
except ImportError:
    subprocess.run([sys.executable, "-m", "pip", "install", "lassonet"], check=False)
    importlib.invalidate_caches()

from lassonet import LassoNetClassifier
from pyHSICLasso import HSICLasso


import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation #,InputLayer
from tensorflow.keras import activations
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
#from data_utils import one_hot_encode
from utils import get_optimizer
from models.lenet300100 import LeNet300100 #, HadamardLeNet300100, ImpHadamardLeNet300100, StrHadamardLeNet300100

def one_hot_encode(y, num_classes):
    """Convert integer class labels into one-hot rows.

    Args:
        y: Integer label or integer array; used directly as a row index into
            a ``num_classes`` x ``num_classes`` identity matrix.
        num_classes: Number of classes, i.e. the length of each one-hot row.

    Returns:
        A float array whose last axis has length ``num_classes``; for an
        integer array ``y`` the shape is ``y.shape + (num_classes,)``.
    """
    identity = np.eye(num_classes)
    return identity[y]

# Default model
def default_nn(train_X, train_y, test_X, test_y,
               units1=300, units2=100,
               init_lr=0.1,
               epochs=100,
               batch_size=256,
               use_bias=True,
               la=0
              ):
    """Train a plain (non-Hadamard) LeNet-300-100 and return its test accuracy.

    Args:
        train_X: Training inputs; ``train_X.shape[1]`` is used as the input size.
        train_y: One-hot training targets; ``train_y.shape[1]`` is used as the
            number of classes.
        test_X: Test inputs.
        test_y: One-hot test targets.
        units1: Forwarded to ``LeNet300100`` as ``units1``.
        units2: Forwarded to ``LeNet300100`` as ``units2``.
        init_lr: Initial learning rate handed to ``get_optimizer``.
        epochs: Number of training epochs (also handed to ``get_optimizer``).
        batch_size: Mini-batch size used for both the optimizer schedule and
            ``model.fit``.
        use_bias: Forwarded to ``LeNet300100`` as ``use_bias``.
        la: Forwarded to ``LeNet300100`` as ``la`` (with ``reg='l2'``).

    Returns:
        The second entry of ``model.evaluate(test_X, test_y)``, i.e. the
        ``accuracy`` metric the model is compiled with.
    """
    # Derive the network dimensions from the data.
    num_classes = train_y.shape[1]
    input_size = train_X.shape[1]

    print(f'Input size is {input_size}')

    # Honor the caller's `la` and `use_bias` instead of silently pinning them
    # to their defaults.
    model = LeNet300100(input_shape=(input_size,), n_classes=num_classes,
                        init=tf.keras.initializers.HeNormal, reg='l2', la=la,
                        use_bias=use_bias, name='VanillaLeNet300100',
                        units1=units1, units2=units2, use_bn=False)

    # Identical optimizer to Hadamard models.
    # NOTE(review): hadamard_nn calls get_optimizer with `nrow=` where this
    # call passes `dat=` -- confirm get_optimizer supports both keywords.
    optimizer = get_optimizer(lr_schedule='cosine', init_lr=init_lr, epochs=epochs,
                              dat=train_X, batch_size=batch_size, opt='sgd', alpha=0)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy', metrics=['accuracy'])

    # Train with the same batch size the optimizer schedule was built for.
    model.fit(train_X, train_y, epochs=epochs, verbose=0, batch_size=batch_size)

    # evaluate() returns [loss, accuracy] for the metrics compiled above.
    test_accuracy = model.evaluate(test_X, test_y, verbose=1)

    print(f'test acc for vanilla model is {test_accuracy}')
    print(test_accuracy[0],test_accuracy[1])
    return test_accuracy[1]

def compute_input_sparsity(model):
    """Return the fraction of all-zero rows in the reconstructed first-layer kernel.

    Reads ``model.layers[0].get_weights()`` and treats entry 1 as a 2-D kernel
    and entries 2 onward as per-row factor vectors.  The factors are multiplied
    together elementwise and used to scale the kernel rows; a row counts as
    "zero" when every entry is smaller in magnitude than float machine epsilon.

    NOTE(review): entry 0 is skipped -- presumably the bias of a Hadamard-style
    first layer built with ``use_bias=True``; confirm against the model used.

    Args:
        model: Object exposing ``layers[0].get_weights()``.

    Returns:
        Mean of the per-row zero indicator (a value between 0 and 1).
    """
    first_layer_weights = model.layers[0].get_weights()
    kernel = first_layer_weights[1]
    factor_vectors = first_layer_weights[2:]

    # Collapse the stacked factor vectors into a single per-row multiplier.
    row_scale = np.stack(factor_vectors).prod(axis=0)
    effective_kernel = kernel * row_scale[:, np.newaxis]

    tolerance = np.finfo(float).eps
    row_is_zero = (np.abs(effective_kernel) < tolerance).all(axis=1)
    return np.mean(row_is_zero.astype(int))

# Factorize weights function
def factorize_weights(weights, depth):
    """Insert ``depth - 1`` all-ones vectors right after the first weight array.

    Args:
        weights: Sequence of weight arrays; the first must expose ``shape[0]``.
        depth: Factorization depth; ``depth - 1`` ones vectors are inserted.

    Returns:
        A new list: the first array, then ``depth - 1`` vectors of ones with
        length ``weights[0].shape[0]``, then the remaining arrays in their
        original order.  The input arrays themselves are not copied.
    """
    head, tail = weights[0], weights[1:]
    n_rows = head.shape[0]

    # Unit factors leave the first array's effect unchanged when multiplied in.
    unit_factors = [np.ones(n_rows) for _ in range(depth - 1)]

    return [head, *unit_factors, *tail]

def hadamard_nn(train_X, train_y, test_X, test_y, depth=2,units1=300, units2=100,
               lr_schedule = 'cosine',
               init_lr = 0.2,
               init = tf.keras.initializers.HeNormal,
               init_type = 'ones',
               epochs=100,
               batch_size=256,
               opt='sgd',
               use_bn=False,
               use_bias=True,
               verbose=1,
               la=0):
    """Fit an input-sparse Hadamard LeNet-300-100 along a sequence of penalties.

    A baseline model is trained with lambda = 0, then one model per value in
    ``config_inputsparse_compar.LAMBDA_LIST``.  When
    ``config_inputsparse_compar.WARMSTARTS`` is truthy, each model is
    initialized from the weights of the previously trained one.

    Args:
        train_X: Training inputs; ``train_X.shape[1]`` is the input size.
        train_y: One-hot training targets; ``train_y.shape[1]`` is the number
            of classes.
        test_X: Test inputs.
        test_y: One-hot test targets.
        depth: Factorization depth passed to the model and to
            ``factorize_weights``.
        units1, units2: Accepted for signature parity; not used in this function.
        lr_schedule, init_lr, epochs, batch_size, opt: Passed to ``get_optimizer``
            (``epochs`` and ``batch_size`` are also used for ``fit``).
        init: Initializer passed to the model.
        init_type, use_bn, la: Accepted but not used in this function.
        use_bias: Passed to the model.
        verbose: Verbosity for ``fit`` and for evaluation inside the lambda loop.

    Returns:
        ``(sparsity, accuracy, lambda_seq)``.  ``sparsity`` and ``accuracy``
        start with the lambda = 0 baseline and therefore hold one more entry
        than ``lambda_seq``, which is ``LAMBDA_LIST`` returned unchanged.
    """
    # NOTE(review): this class is used below but was never imported at module
    # level.  It is assumed to live next to LeNet300100 under this exact name;
    # confirm the module and spelling.
    from models.lenet300100 import InpHadamardLeNet300100

    # Lambda sequence
    lambda_seq = config_inputsparse_compar.LAMBDA_LIST

    # Compute sizes
    num_classes = train_y.shape[1]
    input_size = train_X.shape[1]

    # Per-model results, starting with the lambda = 0 baseline.
    accuracy = []
    sparsity = []

    def build_nn(_lambda):
        """Build and compile a fresh model with penalty strength ``_lambda``."""
        model = InpHadamardLeNet300100(input_shape=(input_size,), n_classes=num_classes, depth=depth, la=_lambda, init=init,
                                       use_bias=use_bias)

        # NOTE(review): default_nn calls get_optimizer with `dat=` where this
        # call passes `nrow=` -- confirm get_optimizer supports both keywords.
        optimizer = get_optimizer(lr_schedule=lr_schedule, init_lr=init_lr, epochs=epochs, nrow=train_y.shape[0], batch_size=batch_size, opt=opt, alpha=0)

        # Compile the model
        model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

        return model

    model = build_nn(0)

    # Initial training with lambda = 0
    print(f"Current lambda value: 0")

    # Fit the model with lambda = 0
    model.fit(train_X, train_y, verbose=verbose,
                  epochs=epochs, batch_size=batch_size)

    # Evaluate the model with lambda = 0; evaluate() returns [loss, accuracy].
    test_accuracy = model.evaluate(test_X, test_y, verbose=0)
    accuracy.append(test_accuracy[1])

    # The baseline is recorded with sparsity 0 rather than measured.
    this_sparsity = 0
    sparsity.append(this_sparsity)
    print(f"Sparsity: 0, Accuracy: {test_accuracy[1]:.4f}")

    if config_inputsparse_compar.WARMSTARTS:

        # Save weights to warm start next model
        initial_weights = model.get_weights()

        # factorize_weights inserts depth-1 all-ones vectors after the first
        # weight array.
        # NOTE(review): the baseline is already built with `depth`; whether its
        # get_weights() already contains those factor vectors cannot be told
        # from this file -- confirm the result matches set_weights() below.
        initial_weights = factorize_weights(initial_weights, depth)

    # Subsequent training with other lambda values
    for lamb in lambda_seq:
        print(f"Current lambda value: {lamb:.6f}")

        # Fresh model with the current penalty strength
        model = build_nn(lamb)

        if config_inputsparse_compar.WARMSTARTS:

            # Start from the previously trained model's weights
            model.set_weights(initial_weights)

        # Fit the model
        model.fit(train_X, train_y, verbose=verbose,
                  epochs=epochs, batch_size=batch_size)

        # Evaluate the model
        test_accuracy = model.evaluate(test_X, test_y, verbose=verbose)
        accuracy.append(test_accuracy[1])

        # compute_input_sparsity takes only the model and returns a scalar;
        # the previous two-argument call with tuple unpacking raised TypeError.
        # NOTE(review): it reads model.layers[0].get_weights(); confirm the
        # Hadamard layer is layers[0] of this model.
        this_sparsity = compute_input_sparsity(model)
        sparsity.append(this_sparsity)
        print(f"Sparsity: {this_sparsity:.4f}, Accuracy: {test_accuracy[1]:.4f}")

        # Save weights to warm start next model
        initial_weights = model.get_weights()

    return sparsity, accuracy, lambda_seq

def hadamard_nn_depth_2(train_X, train_y, test_X, test_y, **kwargs):
    """Run ``hadamard_nn`` with ``depth`` fixed to 2.

    Extra keyword arguments are forwarded unchanged; passing ``depth`` among
    them raises ``TypeError`` because it would be supplied twice.
    """
    datasets = (train_X, train_y, test_X, test_y)
    return hadamard_nn(*datasets, depth=2, **kwargs)

def hadamard_nn_depth_3(train_X, train_y, test_X, test_y, **kwargs):
    """Run ``hadamard_nn`` with ``depth`` fixed to 3.

    Extra keyword arguments are forwarded unchanged; passing ``depth`` among
    them raises ``TypeError`` because it would be supplied twice.
    """
    datasets = (train_X, train_y, test_X, test_y)
    return hadamard_nn(*datasets, depth=3, **kwargs)

def hadamard_nn_depth_4(train_X, train_y, test_X, test_y, **kwargs):
    """Run ``hadamard_nn`` with ``depth`` fixed to 4.

    Extra keyword arguments are forwarded unchanged; passing ``depth`` among
    them raises ``TypeError`` because it would be supplied twice.
    """
    datasets = (train_X, train_y, test_X, test_y)
    return hadamard_nn(*datasets, depth=4, **kwargs)

# Diagonal layer building block
class SimplyConnected(tf.keras.layers.Layer):
    """Diagonal layer: multiplies each input feature by its own trainable weight.

    The weight vector has one entry per element of the last input axis and
    carries an L2 penalty of strength ``la``.

    Args:
        la: L2 regularization strength applied to the weight vector.
        multfac_initializer: Initializer for the weight vector.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class.  Accepting them is what allows the
            dictionary produced by ``get_config`` to be passed back to the
            constructor (as ``from_config`` does).
    """

    def __init__(self, la=0, multfac_initializer=tf.initializers.Ones, **kwargs):
        super().__init__(**kwargs)
        self.la = la
        self.multfac_initializer = multfac_initializer

    def build(self, input_shape):
        """Create the per-feature weight vector once the input size is known."""
        self.w = self.add_weight(
            shape=(input_shape[-1], ),
            initializer=self.multfac_initializer,
            regularizer=tf.keras.regularizers.l2(self.la),
            trainable=True,
        )

    def call(self, inputs):
        """Scale ``inputs`` elementwise by the weight vector."""
        return tf.math.multiply(inputs, self.w)

    def get_config(self):
        """Return the base layer config extended with ``la``.

        ``multfac_initializer`` is not serialized; a layer rebuilt from this
        config uses the constructor default.
        """
        config = super().get_config().copy()
        config.update({
            'la': self.la
        })
        return config
    
class HadamardLayer(tf.keras.layers.Layer):
    """Dense layer preceded by ``depth - 1`` diagonal ``SimplyConnected`` layers.

    Inputs pass through the diagonal layers, then through a bias-free Dense
    layer with ``units`` outputs; an optional bias owned by this layer is added
    at the end.  The Dense kernel and every diagonal layer share the L2
    strength ``la``.

    Args:
        units: Number of output units of the Dense layer.
        depth: Total factorization depth; ``depth - 1`` diagonal layers are built.
        la: L2 regularization strength.
        use_bias: Whether to add a trainable bias to the output.
        kernel_initializer: Initializer for the Dense kernel.
        multfac_initializer: Initializer for each diagonal layer's weights.
    """

    def __init__(self, units, depth, la=0, use_bias=True, 
                 kernel_initializer=tf.keras.initializers.HeNormal(),
                 multfac_initializer=tf.initializers.Ones):
        super().__init__()
        self.units = units
        self.la = la
        self.depth = depth
        self.use_bias = use_bias

        # Fully connected part; the bias is handled separately below.
        self.fc_hadamard = tf.keras.layers.Dense(
            self.units,
            use_bias=False,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=tf.keras.regularizers.l2(self.la),
        )

        # Chain of diagonal layers applied to the inputs before the Dense layer.
        diagonal_layers = []
        for _ in range(self.depth - 1):
            diagonal_layers.append(
                SimplyConnected(self.la, multfac_initializer=multfac_initializer)
            )
        self.simply_connected = tf.keras.Sequential(diagonal_layers)

        if use_bias:
            self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)

    def call(self, inputs):
        """Apply the diagonal layers, the Dense layer, and the optional bias."""
        scaled_inputs = self.simply_connected(inputs)
        projected = self.fc_hadamard(scaled_inputs)
        if not self.use_bias:
            return projected
        return projected + self.bias
        

# HSIC+DNN
def hsic_dnn(X_train, y_train, X_test, y_test,
         downstream_model = "nn", nr_models=10):
    """Select features with HSIC Lasso and score a downstream classifier.

    For each feature budget in ``[1, 5, 10]`` followed by ``nr_models`` evenly
    spaced values from 20 to the total number of columns, HSIC Lasso is run on
    the training data and the downstream model is trained and evaluated on the
    selected columns only.

    Args:
        X_train: Training inputs, indexed as ``X_train[:, columns]``.
        y_train: Training labels; flattened before use.
        X_test: Test inputs, indexed as ``X_test[:, columns]``.
        y_test: Test labels; flattened before use.
        downstream_model: ``"nn"`` trains ``default_nn`` on one-hot labels; any
            other value trains an RBF ``SVC``.
        nr_models: Number of evenly spaced feature budgets starting at 20.

    Returns:
        ``(sparsity, accuracy, num_feature_seq)``: per-budget sparsity
        ``1 - budget / n_columns``, per-budget test accuracy, and the integer
        array of feature budgets.
    """
    num_feature_seq = np.round(np.linspace(20, X_train.shape[1], num=nr_models)).astype(int)
    # Both parts are already integers, so no second rounding pass is needed.
    num_feature_seq = np.concatenate(([1, 5, 10], num_feature_seq))
    print(f'Sequence of features is {num_feature_seq}')

    sparsity = []
    accuracy = []

    print(f'Downstream model for HSIC is {downstream_model}')

    use_nn = downstream_model == "nn"
    if use_nn:
        # Loop-invariant label preparation.  Flatten first: one-hot encoding a
        # column vector of shape (n, 1) would give a 3-D target, and default_nn
        # would then read the number of classes as 1.
        # NOTE(review): one_hot_encode indexes an identity matrix, so labels
        # are presumably expected to be 0..num_classes-1 -- confirm.
        num_classes = np.unique(y_train).shape[0]
        train_targets = one_hot_encode(np.ravel(y_train).astype(int), num_classes)
        test_targets = one_hot_encode(np.ravel(y_test).astype(int), num_classes)

    for num_features in num_feature_seq:

        print(f"Current number of features: {num_features}")

        hsic_lasso = HSICLasso()
        hsic_lasso.input(X_train, y_train.squeeze())
        hsic_lasso.classification(num_features)

        # Fetch the selected column indices once and reuse them.
        selected = hsic_lasso.get_index()
        X_train_selected = X_train[:, selected]
        X_test_selected = X_test[:, selected]

        if use_nn:

            accuracy.append(default_nn(X_train_selected, train_targets,
                                       X_test_selected, test_targets))

        else:

            clf = SVC(kernel='rbf', gamma='auto')
            clf.fit(X_train_selected, y_train.squeeze())

            y_pred = clf.predict(X_test_selected)
            accuracy.append(accuracy_score(y_test.squeeze(), y_pred))

        # Sparsity is based on the requested budget, not on len(selected).
        # NOTE(review): if HSICLasso can return fewer features than requested,
        # this under-states the achieved sparsity -- confirm.
        sparsity.append(1 - (num_features / X_train.shape[1]))

    return sparsity, accuracy, num_feature_seq

# HSIC+SVM
def hsic_svm(X_train, y_train, X_test, y_test,
         downstream_model = "svm", nr_models=10):
    """Run the HSIC Lasso feature-selection pipeline with an SVM by default.

    This is ``hsic_dnn`` with a different default for ``downstream_model``; it
    delegates to that function rather than duplicating the whole pipeline.

    Args:
        X_train: Training inputs.
        y_train: Training labels.
        X_test: Test inputs.
        y_test: Test labels.
        downstream_model: ``"nn"`` selects the neural network; any other value
            (including the default ``"svm"``) selects the RBF ``SVC``.
        nr_models: Number of evenly spaced feature budgets starting at 20.

    Returns:
        ``(sparsity, accuracy, num_feature_seq)`` exactly as returned by
        ``hsic_dnn``.
    """
    return hsic_dnn(X_train, y_train, X_test, y_test,
                    downstream_model=downstream_model, nr_models=nr_models)

# LassoNet
def lassoNet(X_train, y_train, X_test, y_test, M=10):
    """Fit a LassoNet regularization path and score every point on the test set.

    A ``LassoNetClassifier`` with hidden sizes (300, 100) is fitted along its
    path; for each saved point the stored state is loaded back into the model,
    test predictions are made, and the number of selected inputs is recorded.

    Args:
        X_train: Training inputs; ``X_train.shape[1]`` is the feature count.
        y_train: Training labels.
        X_test: Test inputs.
        y_test: Test labels.
        M: Forwarded to ``LassoNetClassifier`` as ``M``.

    Returns:
        ``(sparsity, accuracy, lambda_)``: per-point sparsity
        ``1 - n_selected / n_features``, test accuracy, and ``lambda_`` value.
    """
    classifier = LassoNetClassifier(hidden_dims=(300, 100), M=M, verbose=False)
    checkpoints = classifier.path(X_train, y_train, return_state_dicts=True)

    selected_counts, accuracy, lambda_ = [], [], []

    for checkpoint in checkpoints:
        # Restore this point's weights before predicting.
        classifier.load(checkpoint.state_dict)
        predictions = classifier.predict(X_test)

        selected_counts.append(checkpoint.selected.sum())
        accuracy.append(accuracy_score(y_test, predictions))
        lambda_.append(checkpoint.lambda_)

    # Turn each selected-feature count into the fraction of inputs dropped.
    sparsity = [1-(count.numpy()/X_train.shape[1]) for count in selected_counts]
    return sparsity, accuracy, lambda_