import tensorflow as tf
import tensorflow.compat.v2.keras as keras
from tensorflow.compat.v2.keras.utils import to_categorical
import numpy as np
import random
import os
from itertools import combinations
import matplotlib.pyplot as plt
from tensorflow.compat.v2.keras.callbacks import Callback
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's 'poster' plotting context globally for the figures created below.
sns.set_context('poster')
# Set style to 'white' to remove grid lines
sns.set_style('white')

def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and TensorFlow RNGs and request deterministic TF ops.

    Args:
        seed: Value passed to every random generator and written to
            ``PYTHONHASHSEED``.
    """
    os.environ.update({
        # When running on the CuDNN backend, two further options must be set
        'TF_CUDNN_DETERMINISTIC': '1',
        'TF_DETERMINISTIC_OPS': '1',
        # Set a fixed value for the hash seed.
        # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so
        # this presumably only affects child processes -- confirm if it matters.
        'PYTHONHASHSEED': str(seed),
    })

    # Seed each library's global generator with the same value.
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


# Seed every random generator before any device or model state is created.
set_seed(42)

# Set the visible GPU devices
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Enable memory growth on the first GPU before it is made the only visible one.
        tf.config.experimental.set_memory_growth(gpus[0], True)
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)
# Turn on logging of the device each operation is placed on.
tf.debugging.set_log_device_placement(True)
# `gpus` is rebound here: from now on it holds the *logical* GPU list,
# which is what the distribution strategy is built from.
gpus = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(gpus)
with strategy.scope():

    # Folder that holds the saved data splits and the trained weights.
    save_dir = 'Area2_Bump/20_LSTM_seed_42_lr_0.0001_big kernel/'
    log_file = '/subnetwork_fine_tuning.txt'

    # Reload the stored splits; the test split is also used as validation data.
    trainX = np.load(save_dir + 'training_data_X.npy')
    trainy = np.load(save_dir + 'training_data_Y.npy')
    testX = np.load(save_dir + 'testing_data_X.npy')
    testy = np.load(save_dir + 'testing_data_Y.npy')
    validation_dataX, validation_datay = testX, testy

    # Order the test samples by their class index (argmax of the one-hot label).
    sorted_indices = np.argsort(np.argmax(testy, axis=1))
    X_test_sorted, y_test_sorted = testX[sorted_indices], testy[sorted_indices]

    # Same ordering for the training samples; `sorted_indices` now refers to the
    # training set.
    sorted_indices = np.argsort(np.argmax(trainy, axis=1))
    X_train_sorted, y_train_sorted = trainX[sorted_indices], trainy[sorted_indices]

    # Hyper-parameters of the network whose weights are restored below.
    num_layer, num_unit = 1, 20
    architecture_type = 'LSTM'
    verbose = 2
    epochs = 200
    batch_size = 64
    num_class = 8
    dropout = False
    architecture = f'{num_layer} layer_{num_unit} units_{architecture_type}_'

    n_timesteps = trainX.shape[1]
    n_features = trainX.shape[2]
    n_outputs = trainy.shape[1]

    # Input -> LSTM (full sequence) -> [Dropout] -> Flatten -> Dense softmax.
    inputs = tf.keras.Input(shape=(n_timesteps, n_features))
    lstm = tf.keras.layers.LSTM(num_unit, return_sequences=True)(inputs)
    if dropout:
        # NOTE: `dropout` is rebound from the boolean flag to the layer output here.
        dropout = tf.keras.layers.Dropout(0.8)(lstm)
        flatten = tf.keras.layers.Flatten()(dropout)
    else:
        flatten = tf.keras.layers.Flatten()(lstm)
    outputs = tf.keras.layers.Dense(num_class, activation='softmax')(flatten)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.summary()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Restore the weights
    model.load_weights(save_dir + 'model')

    # Check the restored model on the validation (= test) split.
    loss, acc = model.evaluate(validation_dataX, validation_datay, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

    def per_class_accuracy(model, testX, testy):
        """Return the accuracy of `model` separately for every class.

        Args:
            model: Object with a ``predict(testX)`` method returning per-class
                scores with the class dimension last.
            testX: Inputs forwarded unchanged to ``model.predict``.
            testy: One-hot encoded true labels as an array; its last dimension
                gives the number of classes.

        Returns:
            A list with one entry per class: the fraction of samples whose true
            label is that class and that were predicted as that class. A class
            with no true samples yields NaN (undefined accuracy) instead of
            triggering a division by zero.
        """
        predictions = model.predict(testX)

        # Collapse scores / one-hot labels to integer class ids.
        y_pred = np.argmax(predictions, axis=-1)
        y_true = np.argmax(testy, axis=-1)

        num_classes = testy.shape[-1]

        class_accuracies = []
        for class_id in range(num_classes):
            in_class = (y_true == class_id)
            total_samples = np.sum(in_class)

            if total_samples == 0:
                # No sample of this class: accuracy is undefined, report NaN
                # explicitly rather than computing 0 / 0.
                class_accuracies.append(np.float64(np.nan))
                continue

            correct_preds = np.sum(y_pred[in_class] == class_id)
            class_accuracies.append(correct_preds / total_samples)

        return class_accuracies

    # Per-class accuracy of the restored full model on the test split.
    accuracies = per_class_accuracy(model, testX, testy)
    print(accuracies)


    # Load the stored cluster label of every unit and collect, for labels 0 and 1,
    # the index tuple returned by np.where (element [0] holds the unit indices).
    clusters = np.load(save_dir + '/cluster_entropy_across_time_label.npy')
    print(clusters)
    cluster0, cluster1 = (np.where(clusters == label) for label in (0, 1))


    def copy_weight_to_subnetwork(model, num_unit, index_array, index_array_stable):
        """Build a smaller LSTM classifier from a subset of the units of `model`.

        `model` is expected to be the Input -> LSTM(return_sequences=True) ->
        Flatten -> Dense stack built in this script: ``model.layers[1]`` must be
        the LSTM and ``model.layers[3]`` the Dense head. The input shape, the
        number of timesteps and the number of classes of the sub-model are all
        derived from the weights of `model`, so no layer size is hard-coded.

        Args:
            model: Trained source model.
            num_unit: Number of LSTM units of the sub-model; must equal the
                number of entries in `index_array`.
            index_array: 1-D sequence of integer unit indices to copy.
            index_array_stable: Indices of units regarded as "stable"; only used
                to compute the returned flag.

        Returns:
            ``(sub_model, is_all_unstable)``: the compiled sub-model carrying the
            copied weights, and ``True`` when none of the selected units appears
            in `index_array_stable`.

        Raises:
            ValueError: If `num_unit` does not match the number of selected
                indices, or the Dense kernel height is not a multiple of the
                LSTM unit count.
        """
        # EXTRACT THE WEIGHT PARAMETERS FOR THE HIDDEN LAYER
        W, U, b = model.layers[1].get_weights()[:3]  # kernel, recurrent kernel, bias
        # The last axis holds four equally sized gate blocks, treated here as
        # input (i), forget (f), cell (c) and output (o) in that order.
        units = W.shape[1] // 4

        print('Nb units:' + str(num_unit))
        print(index_array)

        selected = np.asarray(index_array, dtype=np.intp)
        if selected.size != num_unit:
            raise ValueError(
                f'num_unit ({num_unit}) must equal the number of selected '
                f'indices ({selected.size})')

        is_all_unstable = np.intersect1d(selected, index_array_stable).size == 0

        gate_slices = [slice(g * units, (g + 1) * units) for g in range(4)]

        # Kernel: keep the selected columns of every gate block.
        kernel_parts = []
        for gate in gate_slices:
            full = W[:, gate]
            part = full[:, selected]
            print(full.shape)
            print(part.shape)
            kernel_parts.append(part)

        # Recurrent kernel: keep the selected rows AND columns of every gate block,
        # because both axes are indexed by hidden unit.
        recurrent_parts = []
        for gate in gate_slices:
            full = U[:, gate]
            part = full[np.ix_(selected, selected)]
            print(full.shape)
            print(part.shape)
            recurrent_parts.append(part)

        # Bias: keep the selected entries of every gate block.
        bias_parts = []
        for gate in gate_slices:
            full = b[gate]
            part = full[selected]
            print(full.shape)
            print(part.shape)
            bias_parts.append(part)

        # Recreate weight matrices with good sizes
        W_sub = np.concatenate(kernel_parts, axis=1)
        print(W_sub.shape)
        U_sub = np.concatenate(recurrent_parts, axis=1)
        print(U_sub.shape)
        b_sub = np.concatenate(bias_parts, axis=0)
        print(b_sub.shape)

        # GET WEIGHTS FOR OUTPUT LAYER
        output_weights, output_bias = model.layers[3].get_weights()[:2]
        n_rows, n_classes = output_weights.shape
        if n_rows % units != 0:
            raise ValueError(
                f'Dense kernel has {n_rows} rows, which is not a multiple of '
                f'the {units} LSTM units')
        # The Dense layer sees the flattened (timesteps, units) LSTM output, so
        # row t * units + u belongs to unit u at timestep t.
        n_timesteps = n_rows // units
        flat_positions = np.arange(n_rows).reshape(n_timesteps, units)
        print(flat_positions.shape)
        indeces = flat_positions[:, selected]
        print(indeces.shape)
        indeces = indeces.flatten()
        print(indeces.shape)
        output_weights_sub = output_weights[indeces, :]
        print('last')
        print(output_weights_sub.shape)

        # CREATE A SUBMODEL with the same layout as the source model
        inputs = tf.keras.Input(shape=(n_timesteps, W.shape[0]))
        lstm = tf.keras.layers.LSTM(num_unit, return_sequences=True)(inputs)
        flatten = tf.keras.layers.Flatten()(lstm)
        outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(flatten)
        sub_model = tf.keras.Model(inputs=inputs, outputs=outputs)

        # FIT THE NEW WEIGHTS TO THE SUBMODEL
        sub_model.layers[1].set_weights([W_sub, U_sub, b_sub])
        # (keep old biases of output layer)
        sub_model.layers[3].set_weights([output_weights_sub, output_bias])
        sub_model.summary()
        sub_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return sub_model, is_all_unstable


    # Custom callback to stop training when validation accuracy peaks
    class MaxValidationAccuracyCallback(Callback):
        """Track the best validation accuracy and stop training once it stalls.

        Training is stopped when ``val_accuracy`` has not exceeded its best value
        for `patience` consecutive epochs.

        Args:
            patience: Number of epochs without improvement after which training
                is stopped. Defaults to 5.
        """

        def __init__(self, patience=5):
            super().__init__()
            self.patience = patience
            # Best validation accuracy seen so far.
            self.max_val_acc = 0.0
            # Epochs elapsed since `max_val_acc` was last improved.
            self.epochs_since_last_max = 0

        def on_epoch_end(self, epoch, logs=None):
            """Update the best accuracy and request a stop when patience runs out."""
            current_val_acc = (logs or {}).get('val_accuracy')
            if current_val_acc is None:
                # Without the metric there is nothing to compare; warn instead of
                # failing on a None comparison.
                import warnings
                warnings.warn(
                    "MaxValidationAccuracyCallback requires 'val_accuracy' in "
                    "logs; skipping epoch {}.".format(epoch),
                    RuntimeWarning,
                )
                return

            if current_val_acc > self.max_val_acc:
                self.max_val_acc = current_val_acc
                self.epochs_since_last_max = 0
            else:
                self.epochs_since_last_max += 1

            # Stop training if validation accuracy hasn't improved for `patience` epochs
            if self.epochs_since_last_max >= self.patience:
                self.model.stop_training = True


    def test_subnetwork_epochs(subnetwork_model, trainX, trainy, testX, testy):
        """Fine-tune `subnetwork_model` until its validation accuracy stalls.

        The model is recompiled (adam, categorical cross-entropy, accuracy) and
        trained for at most 100 epochs with a MaxValidationAccuracyCallback.

        Returns:
            ``(num_epochs, max_val_acc)``: the number of epochs actually run and
            the best validation accuracy recorded by the callback.
        """
        # Assuming you have categorical data, adjust loss and metrics if not
        subnetwork_model.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'],
        )

        early_stop = MaxValidationAccuracyCallback()
        history = subnetwork_model.fit(
            trainX,
            trainy,
            validation_data=(testX, testy),
            epochs=100,  # upper bound; the callback normally stops earlier
            batch_size=32,
            verbose=2,  # one log line per epoch
            callbacks=[early_stop],
        )

        # One 'val_accuracy' entry is recorded per completed epoch.
        num_epochs = len(history.history['val_accuracy'])
        return num_epochs, early_stop.max_val_acc


    # Empty result accumulators, one independent list per experiment condition.
    # Only `clusters_per_class` is filled by the code visible in this file.
    (
        all_stable_epochs_needed,
        all_unstable_epochs_needed,
        all_stable_plus_unstable_epochs_needed,
        all_unstable_plus_stable_epochs_needed,
        half_stable_half_unstable_epochs_needed,
    ) = ([] for _ in range(5))

    (
        all_stable_max_accuracy,
        all_unstable_max_accuracy,
        all_stable_plus_unstable_max_accuracy,
        all_unstable_plus_stable_max_accuracy,
        half_stable_half_unstable_max_accuracy,
    ) = ([] for _ in range(5))

    (
        all_stable_per_class,
        all_unstable_per_class,
        all_stable_plus_unstable_per_class,
        all_unstable_plus_stable_per_class,
        half_stable_half_unstable_per_class,
    ) = ([] for _ in range(5))

    # Per-class accuracies of each cluster sub-network.
    clusters_per_class = list()

    counter = 0

    def plot_confusion_matrix(ax, cm, vmax, title='CM'):
        """Draw confusion matrix `cm` as an annotated heatmap on axes `ax`.

        Args:
            ax: Matplotlib axes to draw on.
            cm: Confusion matrix array; it is copied, the caller's array is not
                modified.
            vmax: Upper limit of the colour scale, shared between plots.
            title: Axes title.
        """
        sns.set('poster')  # for plot styling

        # Float copy in which empty cells become NaN.
        counts = cm.astype(float)
        counts[counts == 0] = np.nan

        sns.heatmap(
            counts,
            annot=True,
            fmt=".0f",
            linewidths=.5,
            cmap='Blues',
            ax=ax,
            cbar=False,
            vmax=vmax,  # Use common vmax
        )

        label_size = 28
        ax.set_ylabel('Actual labels', fontsize=label_size)
        ax.set_xlabel('Predicted labels', fontsize=label_size)
        ax.set_title(title, fontsize=label_size)

        # Show the full frame around the heatmap.
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)


    count = 0
    # One figure per data split, each with one panel per cluster.
    fig_val, axes_val = plt.subplots(1, 2, figsize=(18, 8))  # Two columns for validation
    fig_train, axes_train = plt.subplots(1, 2, figsize=(18, 8))  # Two columns for training

    # Shared upper limit of the colour scale; computed once all matrices exist.
    vmax = 0
    val_conf_matrices = []
    train_conf_matrices = []

    for cluster in (cluster0, cluster1):
        unit_indices = cluster[0]
        sub_model, is_all_unstable = copy_weight_to_subnetwork(
            model, len(unit_indices), unit_indices, index_array_stable=[])
        per_class_accuracies = per_class_accuracy(sub_model, validation_dataX, validation_datay)
        print('accuracy:' + str(per_class_accuracies))
        clusters_per_class.append(per_class_accuracies)

        # Confusion matrix of the sub-network on the class-sorted validation
        # split first, then on the class-sorted training split.
        for split_X, split_y, matrices in (
                (X_test_sorted, y_test_sorted, val_conf_matrices),
                (X_train_sorted, y_train_sorted, train_conf_matrices)):
            preds = sub_model.predict(split_X)
            y_pred = np.argmax(preds, axis=1)
            true_labels = np.argmax(split_y, axis=1)
            cm = confusion_matrix(true_labels, y_pred)
            matrices.append(cm)

        count += 1

    # Largest cell over every matrix, so all panels share one colour scale.
    val_vmax = max(np.max(cm) for cm in val_conf_matrices)
    train_vmax = max(np.max(cm) for cm in train_conf_matrices)
    vmax = max(val_vmax, train_vmax)

    # Validation panels: the y-axis label is cleared on every panel.
    for cluster_id, (cm, ax_val) in enumerate(zip(val_conf_matrices, axes_val)):
        plot_confusion_matrix(ax_val, cm, vmax, title=f'CM for Cluster {cluster_id}, Validation')
        ax_val.set_ylabel('')

    # Training panels: only the second panel loses its y-axis label.
    for cluster_id, (cm, ax_train) in enumerate(zip(train_conf_matrices, axes_train)):
        plot_confusion_matrix(ax_train, cm, vmax, title=f'CM for Cluster {cluster_id}, Training')
        if cluster_id == 1:
            ax_train.set_ylabel('')

    # Colorbar setup (validation figure only)
    colour_scale = plt.cm.ScalarMappable(cmap='Blues', norm=plt.Normalize(vmin=0, vmax=vmax))
    fig_val.colorbar(colour_scale, ax=axes_val, orientation='vertical', fraction=0.046, pad=0.04)

    # Save figures
    fig_val.savefig(save_dir + '/validation_conf_matrices.png', dpi=300)
    fig_train.savefig(save_dir + '/training_conf_matrices.png', dpi=300)

    clusters_per_class = np.array(clusters_per_class)

    print(clusters_per_class)