import random
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import numpy
import numpy as np
from numpy import mean
from numpy import std
from numpy import dstack
from pandas import read_csv
import tensorflow as tf
import tensorflow.compat.v2.keras as keras
from tensorflow.compat.v2.keras.utils import to_categorical
from tensorflow.compat.v2.keras.callbacks import ModelCheckpoint
from tensorflow.compat.v2.keras import backend as K
from tensorflow.compat.v2.keras.callbacks import TensorBoard
from tensorflow.compat.v2.keras.losses import get
from tensorflow.keras.optimizers import Adam
import scprep
from scipy.fft import fft, fftfreq
import m_phate.train
import m_phate.data
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import random
import os
from numpy import asarray
from numpy import save
import scipy
from sklearn.preprocessing import OneHotEncoder
import glob
import itertools as it
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical


#%% Load the Data from the original file and preprocess

# dataset = NWBDataset("mm-phate/Area2_Bump/000127\sub-Han", "*train", split_heldout=False)
# # Smooth spikes with 40 ms std Gaussian kernel
# dataset.smooth_spk(40, name='smth_40')
# # Choose lag value
# lag = 40
# align_field='move_onset_time'
# align_range=(-100, 500)
# neur_num = 1001
# # All 16 conditions, in the format (ctr_hold_bump, cond_dir)
# # unique_conditions = [(False, 0.0), (False, 45.0), (False, 90.0), (False, 135.0),
# #                      (False, 180.0), (False, 225.0), (False, 270.0), (False, 315.0)]
# unique_conditions = [(False, 0.0), (False, 90.0), (False, 180.0), (False, 270.0)] # Edit to select the directions for classification
# data = []
# label = []
#
# # Loop through conditions
# for idx, cond in enumerate(unique_conditions):
#     # Filter out invalid trials (labeled 'none') and trials in other conditions
#     cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
#                 (dataset.trial_info.split != 'none')
#     # Extract relevant portion of selected trials
#     cond_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-100, 500), ignored_trials=~cond_mask)['spikes_smth_40'].to_numpy().reshape(-1, 600, 65)
#     cond_label = idx * np.ones(cond_data.shape[0])
#     data.append(cond_data)
#     label.append(cond_label)
# data = np.concatenate(data, axis=0)
# label = np.concatenate(label, axis=0)
#
# # Find the minimum class sample size
# _, counts = np.unique(label, return_counts=True)
# min_samples = np.min(counts)
#
# # Balance the dataset
# balanced_data = []
# balanced_label = []
# leftover_data = []
# leftover_label = []
# for unique_label in np.unique(label):
#     class_idx = np.where(label == unique_label)[0]
#     np.random.shuffle(class_idx)  # Shuffle indices
#     class_idx_balanced = class_idx[:min_samples]  # Keep only `min_samples` indices
#     leftover_idx = class_idx[min_samples:]
#     leftover_data.append(data[leftover_idx])
#     leftover_label.append(label[leftover_idx])
#     balanced_data.append(data[class_idx_balanced])
#     balanced_label.append(label[class_idx_balanced])
#
# # Convert lists to numpy arrays
# balanced_data = np.concatenate(balanced_data, axis=0)
# balanced_label = np.concatenate(balanced_label, axis=0)
# leftover_data = np.concatenate(leftover_data, axis=0)
# leftover_label = np.concatenate(leftover_label, axis=0)
#
# # Shuffle the balanced dataset
# balanced_data, balanced_label = shuffle(balanced_data, balanced_label, random_state=42)
# leftover_data, leftover_label = shuffle(leftover_data, leftover_label, random_state=42)
#
# X_train, X_test, y_train, y_test = train_test_split(balanced_data, balanced_label, test_size=0.2, random_state=42)
#
# X_test = np.concatenate((X_test,leftover_data),axis=0)
# y_test = np.concatenate((y_test,leftover_label),axis=0)
#
# # Convert labels to one-hot encoding
# y_train_categorical = to_categorical(y_train)
# y_test_categorical = to_categorical(y_test)


#%% Load the preprocessed data
# Read the four pre-split arrays from disk, in train-X, train-y, test-X, test-y
# order; presumably produced by the commented-out preprocessing cell above
# (TODO confirm).
trainX, trainy, testX, testy = (
    np.load(path)
    for path in (
        'Area2_Bump/trainX.npy',
        'Area2_Bump/trainy.npy',
        'Area2_Bump/testX.npy',
        'Area2_Bump/testy.npy',
    )
)

# Integer class labels -> one-hot rows, as required by categorical cross-entropy.
trainy, testy = to_categorical(trainy), to_categorical(testy)

#%% Set the visible GPU devices
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)
tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(gpus)
with (strategy.scope()):
    def set_seed(seed: int = 42) -> None:
        """Seed every RNG this script uses and request deterministic kernels.

        Seeds Python's ``random``, NumPy's global RNG and TensorFlow's global
        seed with ``seed``. It then sets ``TF_CUDNN_DETERMINISTIC`` and
        ``TF_DETERMINISTIC_OPS`` to ``'1'`` and ``PYTHONHASHSEED`` to
        ``str(seed)`` in ``os.environ``.

        Note: Python reads ``PYTHONHASHSEED`` at interpreter start-up, so
        setting it here only affects child processes, not this one.
        """
        # Same seeding order as before: Python, NumPy, then TensorFlow.
        for seeder in (random.seed, np.random.seed, tf.random.set_seed):
            seeder(seed)
        # Ask TF / cuDNN for deterministic kernels and fix the hash seed.
        os.environ.update({
            'TF_CUDNN_DETERMINISTIC': '1',
            'TF_DETERMINISTIC_OPS': '1',
            "PYTHONHASHSEED": str(seed),
        })

    for seed in [42]:
        print('seed:', seed)
        set_seed(seed)
        # Select a fixed subset of *training* trials (up to num_sample per
        # class). Their hidden-state traces are recorded by the TraceHistory
        # callback below and later embedded with M-PHATE.
        trace_idx = []
        num_sample = 5
        # Take the class count from the one-hot label width: a hard-coded
        # value fails (IndexError on trainy[:, i]) when the data has fewer
        # classes, and it must match the Dense output layer anyway.
        num_class = trainy.shape[1]
        for i in range(num_class):
            print(i)
            condition_indices = np.flatnonzero(trainy[:, i] == 1)

            # Safeguard in case there are fewer available samples than num_sample
            actual_num_samples = min(len(condition_indices), num_sample)

            sampled_indices = np.random.choice(condition_indices, actual_num_samples, replace=False)
            trace_idx.append(sampled_indices)
        # Indices stay grouped by class (class 0 first, then class 1, ...).
        trace_idx = np.concatenate(trace_idx)

        # extract the inputs of the selected trials
        x_trace = trainX[trace_idx]
        #%%
        class TestCallback(keras.callbacks.Callback):
            """Evaluate the model on a held-out set at the end of every epoch.

            Adds ``test_accuracy`` and ``val_loss`` (both from
            ``model.evaluate`` on ``test_data``) to the epoch ``logs``. It also
            keeps its own per-epoch ``history`` dict of every logged value.

            Args:
                test_data: ``(x, y)`` tuple passed to ``model.evaluate``.
            """

            def __init__(self, test_data):
                # Initialise the Keras Callback base state before adding ours.
                super().__init__()
                self.test_data = test_data
                self.history = {}

            def on_epoch_end(self, epoch, logs=None):
                # Keras may call this with logs=None; use a dict so the metrics
                # below can still be recorded.
                logs = {} if logs is None else logs
                x, y = self.test_data
                # Unpacking assumes the model was compiled with exactly one
                # metric, so evaluate returns [loss, metric].
                loss, acc = self.model.evaluate(x, y, verbose=0)
                logs['test_accuracy'] = acc
                # NOTE: overwrites the 'val_loss' Keras computed from validation_data.
                logs['val_loss'] = loss
                for k, v in logs.items():
                    self.history.setdefault(k, []).append(v)

                # Set the history attribute on the model after the epoch ends. This will
                # make sure that the state which is set is the latest one.
                self.model.history = self

        # Model Variables
        num_layer = 1
        num_unit = 20
        architecture_type = 'LSTM'
        # Number of time steps per trial, taken from the data so the M-PHATE
        # down-sampling indices (built from it later) always stay in range.
        intrinsic_steps = trainX.shape[1]
        intrinsic_step_sample_size = 100  # time steps kept per epoch for M-PHATE
        learning_rate = 0.00010
        verbose, epochs, batch_size = 2, 200, 64
        epoch_sample_step_after_epoch_30 = 5
        dropout = False
        # Set the path to the directory you want to save the plots to
        save_dir = f"Area2_Bump/{num_unit}_{architecture_type}_seed_{seed}_lr_{learning_rate}_big kernel"
        # Never overwrite an earlier run: if the directory exists, append the
        # first unused numeric suffix (_1, _2, ...).
        if os.path.exists(save_dir):
            i = 1
            while os.path.exists(save_dir + f"_{i}"):
                i += 1
            save_dir = save_dir + f"_{i}"
        os.makedirs(save_dir)

        os.makedirs(save_dir+'/gradient')

        print(save_dir)
        # Use a context manager so the log file handle is closed (and flushed) right away.
        with open(save_dir + '/log.txt', 'a') as log_file:
            print(f'seed: {seed}', file=log_file)
            print(f'Number of class: {num_class}', file=log_file)

        #%%
        # build the model
        log_path = save_dir + '/log.txt'
        # Each log write uses a context manager so file handles are not leaked.
        with open(log_path, 'a') as log_file:
            print('Keras LSTM '+' adam optimizer '+'loss = categorical crossentropy', file=log_file)
            print('num_layer: '+str(num_layer)+' num_unit: '+str(num_unit)+' epochs: '+str(epochs)+' batch_size: '+str(batch_size), file=log_file)
        architecture = str(num_layer)+' layer_'+str(num_unit)+f' units_{architecture_type}_'

        inputs = tf.keras.Input(shape=(trainX.shape[1], trainX.shape[2]))
        # Full hidden-state sequence; also the output of the trace model below.
        lstm = tf.keras.layers.LSTM(num_unit, return_sequences=True)(inputs)
        # Bind the dropout output to its own name. The boolean `dropout` flag
        # must stay a bool; a tensor would break `if dropout:` on a later seed.
        features = tf.keras.layers.Dropout(0.8)(lstm) if dropout else lstm
        flatten = tf.keras.layers.Flatten()(features)
        outputs = tf.keras.layers.Dense(num_class, activation='softmax')(flatten)

        model = tf.keras.Model(inputs=inputs, outputs=outputs)

        # Display the model's architecture
        model.summary()

        # Set custom learning rate
        with open(log_path, 'a') as log_file:
            print(f'Adam optimizer learning rate: {learning_rate}', file=log_file)
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # Callback to save model after each epoch (built but not passed to fit below)
        checkpoint_callback = ModelCheckpoint(
            save_dir + '/model_epoch_{epoch:03d}.h5',
            save_best_only=False,
            save_freq='epoch'
        )

        # Initialize TensorBoard and Custom Callback
        # gradient_logger = GradientLogger(val_data=(trainX, trainy),
        #                                  log_dir=save_dir+'/gradient')

        # Trace model: shares the input and outputs the LSTM hidden-state
        # sequence. It is handed, with x_trace, to m_phate's TraceHistory callback.
        model_trace = keras.models.Model(inputs=inputs, outputs=[lstm])
        trace = m_phate.train.TraceHistory(x_trace, model_trace)
        # Create another callback to store loss, accuracy etc
        history = keras.callbacks.History()
        test_callback = TestCallback((testX, testy))
        # fit network (append checkpoint_callback / gradient_logger to enable them)
        model.fit(trainX, trainy, validation_data=(testX, testy),
                  callbacks=[trace, history, test_callback],
                  epochs=epochs, batch_size=batch_size, verbose=verbose)
        # evaluate model
        _, accuracy = model.evaluate(testX, testy, batch_size=batch_size, verbose=0)
        score = accuracy * 100.0
        with open(log_path, 'a') as log_file:
            print('>#%d: %.3f' % (1, score), file=log_file)

        # Save the weights
        model.save_weights(save_dir + '/model')

        # save trace data (as .npy)
        data = np.array(trace.trace)
        np.save(save_dir+'/'+architecture+'_trace_data', data)

        # save training and testing data, plus the per-epoch metric curves
        np.save(save_dir+'/'+'training_data_X', trainX)
        np.save(save_dir+'/'+'training_data_Y', trainy)
        np.save(save_dir+'/'+'testing_data_X', testX)
        np.save(save_dir+'/'+'testing_data_Y', testy)
        np.save(save_dir+'/'+'accuracy', test_callback.history['accuracy'])
        np.save(save_dir+'/'+'test_accuracy', test_callback.history['test_accuracy'])
        np.save(save_dir+'/'+'loss', test_callback.history['loss'])
        np.save(save_dir + '/' + 'val_loss', test_callback.history['val_loss'])

        # summary history for accuracy
        fig_accuracy = plt.figure()
        plt.plot(test_callback.history['accuracy'], c='b')
        plt.plot(test_callback.history['test_accuracy'], c='r')
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        # plt.show()
        fig_accuracy.savefig(save_dir+'/'+architecture+'accuracy.png')
        plt.close(fig_accuracy)  # release the figure; a new one is made per seed

        # summary history for loss
        train_loss = test_callback.history['loss']
        # indices of prominent spikes in the training-loss curve
        peaks, _ = find_peaks(train_loss, prominence=0.2)
        minimum = np.argmin(np.array(train_loss))
        a = [train_loss[minimum]]
        with open(log_path, 'a') as log_file:
            print('peaks: ' + str(peaks), file=log_file)
            print('loss' + str([train_loss[i] for i in peaks]), file=log_file)
            print(f'minimum: {minimum}', file=log_file)
            print(f'loss {a}', file=log_file)
        fig_loss = plt.figure()
        plt.plot(test_callback.history['loss'], c='b')
        plt.plot(test_callback.history['val_loss'], c='r')
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        # plt.show()
        fig_loss.savefig(save_dir+'/'+architecture+'loss.png')
        plt.close(fig_loss)

    # M-PHATE
    mphate_2D = False
    mphate_3D = True
    # extract trace data

    entropies = []
    variances = []
    # Keep each of the first 29 epochs, then every
    # `epoch_sample_step_after_epoch_30`-th epoch. Clipped to the trained epoch
    # count so indexing the trace never goes out of range.
    epoch_samples = np.concatenate([np.arange(min(29, epochs)),
                                    np.arange(29, epochs, epoch_sample_step_after_epoch_30)])

    # Down-sample the time axis to `intrinsic_step_sample_size` evenly spaced
    # steps, always including the first and last step.
    intrinsic_step_samples = np.linspace(0, intrinsic_steps-1, intrinsic_step_sample_size, endpoint=True, dtype=int)

    np.save(save_dir + '/epoch_samples.npy', epoch_samples)
    variance = []
    with open(save_dir + '/log.txt', 'a') as log_file:
        print(' ', file=log_file)
    # NOTE(review): assumes trace.trace stacks to (epochs, units, intrinsic
    # steps, samples), as the original comment stated; confirm against
    # m_phate.train.TraceHistory.
    trace_data = np.array(trace.trace).transpose(0, 2, 1, 3)[
        epoch_samples]  # (epochs, intrinsic steps, units, samples)
    print(f'trace data shape: {np.shape(trace_data)}')
    trace_data = trace_data[:, intrinsic_step_samples, :, :]
    # normalize the data (m_phate.utils is assumed to be loaded by the m_phate imports)
    trace_data_norm = m_phate.utils.normalize(trace_data)  # (epochs, intrinsic steps, units, samples)
    # Merge epochs and time steps into one axis. Using -1 for the sample axis
    # keeps the reshape valid even if a class supplied fewer than num_sample trials.
    trace_data = trace_data.reshape(len(epoch_samples) * intrinsic_step_sample_size, num_unit,
                                    -1)  # (epochs*intrinsic steps, units, samples)
    print(f'trace data shape: {np.shape(trace_data)}')
    # Per-point labels, ordered epoch -> intrinsic step -> hidden unit, i.e.
    # presumably matching M-PHATE's (time, unit) point order; confirm.
    intrinsic_step = np.tile(intrinsic_step_samples, len(epoch_samples))
    intrinsic_step = np.repeat(intrinsic_step, num_unit)  # (intrinsic_step*num_unit)
    unit = np.tile(np.arange(trace_data.shape[1]), len(epoch_samples) * intrinsic_step_sample_size)
    epoch_label = np.repeat(epoch_samples, intrinsic_step_sample_size * num_unit)

    # Class label of every traced trial, read from the labels of the trials
    # actually selected. This equals np.repeat(np.arange(num_class), num_sample)
    # only when every class supplied num_sample trials.
    digit_ids = np.argmax(trainy[trace_idx], axis=1)

    # Summed absolute normalized activity per class for every sampled epoch,
    # time step and hidden unit: (epochs, classes, intrinsic steps, units).
    digit_activity = np.empty((len(epoch_samples), num_class, intrinsic_step_sample_size, num_unit))
    for idx, _sampled_epoch in enumerate(epoch_samples):
        trace_data_norm_sample = trace_data_norm[idx]
        digit_activity[idx] = np.array(
            [np.sum(np.abs(trace_data_norm_sample[:, :, digit_ids == digit]), axis=2)
             for digit in range(num_class)])
    final_train_acc = test_callback.history['accuracy'][-1]
    output_prefix = save_dir + '/' + architecture + '_accuracy ' + ('%.3f' % final_train_acc)
    np.save(output_prefix + '_digit activity', digit_activity)

    # Class with the largest summed activity for each (epoch, step, unit),
    # flattened in the same epoch -> step -> unit order (kept as float64).
    most_active_output = np.argmax(digit_activity, axis=1).astype(float).ravel()
    unique, counts = np.unique(most_active_output, return_counts=True)
    with open(save_dir + '/log.txt', 'a') as log_file:
        print(np.asarray((unique, counts)).T, file=log_file)
        print('loss: ' + str(test_callback.history['loss'][-1]), file=log_file)
        print('training_accuracy: ' + str(final_train_acc), file=log_file)
        print('testing_accuracy: ' + str(test_callback.history['test_accuracy'][-1]),
              file=log_file)
    np.save(output_prefix + '_m-phate_most_active_output', most_active_output)
    np.save(output_prefix + '_m-phate_intrinsic_step', intrinsic_step)
    np.save(output_prefix + '_m-phate_hidden_unit', unit)
    np.save(output_prefix + '_m-phate_epoch_label', epoch_label)

    # apply M-PHATE
    # Common file-name prefix: architecture plus final training accuracy.
    output_prefix = save_dir + '/' + architecture + '_accuracy ' + (
        '%.3f' % test_callback.history['accuracy'][-1])
    m_phate_op = m_phate.M_PHATE(n_jobs=1)
    m_phate_data = m_phate_op.fit_transform(trace_data)
    np.save(output_prefix + '_m-phate 2D', m_phate_data)
    if mphate_2D:
        # 2D plot file names are tagged with the total number of training epochs.
        plot_prefix = (save_dir + '/' + architecture + 'epoch ' + str(epochs) + '_accuracy '
                       + ('%.3f' % test_callback.history['accuracy'][-1]))

        # plot the result in 2D
        scprep.plot.scatter2d(m_phate_data, c=intrinsic_step, ticks=True,
                              label_prefix="M-PHATE",
                              legend_title="intrinsic step",
                              filename=plot_prefix + '_intrinsic step.png',
                              title="intrinsic step",
                              figsize=(8, 8),
                              dpi=1600)
        plt.close()
        # scprep.plot.scatter2d(m_phate_data, c=layer, ticks=False,
        #                       label_prefix="M-PHATE",
        #                       legend_title="layer")
        scprep.plot.scatter2d(m_phate_data, c=unit, ticks=True,
                              label_prefix="M-PHATE",
                              legend_title="hidden unit",
                              legend_loc='upper left',
                              legend_anchor=(1, 1),
                              figsize=(9, 7.5),
                              filename=plot_prefix + '_hidden unit.png',
                              title="hidden unit",
                              dpi=1600)
        plt.close()
        # A two-colour list only fits a binary task. With more classes, pass
        # None so scprep uses its default discrete colormap instead of
        # interpolating between red and blue.
        scprep.plot.scatter2d(m_phate_data, c=most_active_output, ticks=True,
                              label_prefix="M-PHATE",
                              legend_title="most active output",
                              legend_loc='upper left',
                              filename=plot_prefix + '_most active output.png',
                              cmap=['red', 'blue'] if num_class == 2 else None,
                              title="most active output",
                              legend_anchor=(1, 1),
                              figsize=(9, 7.5),
                              dpi=1600)
        plt.close()

    if mphate_3D:
        # 3D transform: re-embed with the same fitted operator, now with 3 components.
        m_phate_op.set_params(n_components=3)
        m_phate_data_3D = m_phate_op.transform()

        np.save(output_prefix + '_m-phate 3D', m_phate_data_3D)
