import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.ticker import MaxNLocator
import seaborn as sns
from sklearn.neighbors import KernelDensity
import os

sns.set_context('poster')
# Set style to 'white' to remove grid lines
sns.set_style('white')

# --- Run configuration: edit these to analyse a different training run --------------------
folder_path = 'Area2_Bump/20_LSTM_seed_42_lr_0.0001_big kernel/'  # replace with the actual path to your folder
save_path = folder_path
save_dir = folder_path
num_layer = 1
num_unit = 20  # number of hidden units; must match the saved results (see reshape below)
architecture = 'LSTM'  # architecture tag used in the result file names
is2Dor3D = 'm-phate'  # or 'm-phate 2D'
epoch = 200  # total number of training epochs
original_timesteps = 600  # sequence length before time-step subsampling
timesteps = 100  # number of time-steps kept per sequence
epoch_sample_step_after_epoch_30 = 5  # stride between sampled epochs in the sparse part
entropies = np.empty(epoch)
# Sampled epochs: every epoch 0..28, then every
# `epoch_sample_step_after_epoch_30`-th epoch starting at 29.
epoch_samples = np.concatenate([np.arange(29), np.arange(29, epoch, epoch_sample_step_after_epoch_30)])
# Positions in the original sequence of the `timesteps` evenly spaced time-steps that were kept.
intrinsic_step_samples = np.linspace(0, original_timesteps - 1, timesteps, endpoint=True, dtype=int)
np.save(folder_path + '/epoch_samples.npy', epoch_samples)

# Sort the matches so the chosen file is deterministic, and fail with a clear message
# (instead of a bare IndexError) when the result file is missing.
file_pattern = folder_path + f'/{num_layer} layer_{num_unit} units_{architecture}__accuracy *_m-phate_most_active_output.npy'
file_list = sorted(glob.glob(file_pattern))
if not file_list:
    raise FileNotFoundError(f'No file matches {file_pattern!r}')
most_active_output = np.load(file_list[0])
# Load accuracy
validation_accuracy = np.load(save_dir+'/test_accuracy.npy')
training_accuracy = np.load(save_dir+'/accuracy.npy')
# Load loss
training_loss = np.load(save_dir+'/loss.npy')
validation_loss = np.load(save_dir+'/val_loss.npy')

most_active_output = np.reshape(most_active_output, (len(epoch_samples), timesteps, num_unit))

for idx, epoch_sample in enumerate(epoch_samples):
    print(epoch_sample)
    most_active_output_sample = most_active_output[idx]

    # change_timestep[u, t] is 1 when unit u's most-active output at time-step t differs
    # from time-step t - 1; column 0 has no predecessor and stays 0.
    # NOTE(review): recomputed for every sampled epoch but never stored or read elsewhere
    # in this file, so only the last epoch's matrix survives the loop.
    change_timestep = np.zeros((num_unit, timesteps))
    change_timestep[:, 1:] = (most_active_output_sample[1:] != most_active_output_sample[:-1]).T


# 3D M-PHATE
import scprep
import numpy as np
import matplotlib.pyplot as plt

intrinsic_steps = timesteps
save_dir = folder_path + '/'


def _load_m_phate_result(suffix):
    """
    Load the ``.npy`` result of this run whose file name ends in ``_<suffix>.npy``.

    The run's files are named ``'<num_layer> layer_<num_unit> units_<architecture>__accuracy *_<suffix>.npy'``
    inside ``save_dir``. Matches are sorted so the choice is deterministic if several exist.

    Raises:
    - FileNotFoundError: no file matches the pattern.
    """
    pattern = save_dir + f'{num_layer} layer_{num_unit} units_{architecture}__accuracy *_{suffix}.npy'
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No file matches {pattern!r}')
    return np.load(matches[0])


m_phate_data = _load_m_phate_result('m-phate 3D')
intrinsic_step = _load_m_phate_result('m-phate_intrinsic_step')
unit = _load_m_phate_result('m-phate_hidden_unit')
most_active_output = _load_m_phate_result('m-phate_most_active_output')
epoch_label = _load_m_phate_result('m-phate_epoch_label')


# plot the result in 3D
def plot_scatter3d(data, color_array, label_prefix, legend_title, file_name, figsize=(9, 7.5), dpi=300, cmap='inferno',
                   out_dir=None):
    """
    Draw a 3D scatter of the first three columns of ``data`` and save it as an image.

    Parameters:
    - data: array indexed as ``data[:, 0]``, ``data[:, 1]``, ``data[:, 2]`` (x, y, z of each point).
    - color_array: per-point values mapped through ``cmap`` to colour the markers.
    - label_prefix: axis labels are ``'<label_prefix> 1'``, ``'<label_prefix> 2'``, ``'<label_prefix> 3'``.
    - legend_title: label of the colorbar.
    - file_name: image file name, joined onto ``out_dir``.
    - figsize, dpi: forwarded to ``plt.figure``.
    - cmap: Matplotlib colormap name.
    - out_dir: directory to save into; ``None`` (default) uses the module-level ``save_dir``.

    Returns:
    - ``(fig, ax)``. The figure is left open; callers are responsible for closing it.
    """
    if out_dir is None:
        out_dir = save_dir  # looked up at call time, so reassignments of save_dir are honoured
    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=color_array, cmap=cmap, s=2)

    # Colorbar explaining the per-point colours
    cbar = fig.colorbar(scatter, pad=0.1)
    cbar.set_label(legend_title)

    # Label axes and hide the tick marks
    ax.set_xlabel(f'{label_prefix} 1')
    ax.set_ylabel(f'{label_prefix} 2')
    ax.set_zlabel(f'{label_prefix} 3')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    fig.savefig(os.path.join(out_dir, file_name))
    return fig, ax


# One 3D figure per colouring of the same embedding; each figure is closed right after saving.
_scatter3d_specs = [
    # (per-point colour values, colorbar title, output file, colormap)
    (intrinsic_step, "Time-step", 'intrinsic step_3D.png', 'inferno'),
    (unit, "Hidden Unit", 'hidden unit_3D.png', 'tab20b'),
    (most_active_output, "Most Active Output", 'most active output_3D.png', 'Paired'),
    (epoch_label, "Epoch", 'epoch_3D.png', 'inferno'),
]
for _values, _title, _fname, _cmap_name in _scatter3d_specs:
    plot_scatter3d(m_phate_data, _values, "MM-PHATE", _title, _fname, cmap=_cmap_name)
    plt.close()

# Embedding tensor layout: (sampled epochs, time-steps, hidden units, 3 coordinates).
units, intrinsic_steps, epochs, dimension = num_unit, timesteps, epoch, 3
m_phate_data_reshaped_reshaped = m_phate_data.reshape(len(epoch_samples), intrinsic_steps, units, dimension)


def global_feature_min_max_scaling(data):
    """
    Min-max scale every feature (last axis) using statistics pooled over all other axes.

    All leading axes (e.g. epochs, intrinsic steps, units) are treated as one pool of
    samples, so a single min and max per feature is shared across the whole array.

    Parameters:
    - data: Numpy array whose last axis holds the features, e.g. shape
      (n_epochs, n_intrinsic_steps, n_units, n_features).

    Returns:
    - Array of the same shape with each feature mapped onto [0, 1]. A feature that is
      constant everywhere maps to 0.
    """
    n_features = data.shape[-1]
    flat = data.reshape(-1, n_features)

    lo = flat.min(axis=0)
    span = flat.max(axis=0) - lo
    # A zero span (constant feature) would divide by zero; use 1 so it maps to 0 instead.
    span = np.where(span == 0, np.ones_like(span), span)

    return ((flat - lo) / span).reshape(data.shape)


# Map every embedding coordinate onto [0, 1] over the whole run (all epochs, time-steps
# and units pooled) before the bandwidth and entropy estimates below.
# To work on the unscaled embedding instead, assign m_phate_data_reshaped_reshaped directly.
scaled_data = global_feature_min_max_scaling(data=m_phate_data_reshaped_reshaped)

def scotts_bandwidth(data):
    """
    Return a single (isotropic) KDE bandwidth for ``data`` using Scott's rule.

    Scott's rule is ``h = n ** (-1 / (d + 4)) * sigma``. Here ``sigma`` is the mean of
    the per-dimension standard deviations, so one scalar serves every dimension.

    Parameters:
    - data: Numpy array of shape (n_samples, n_features).

    Returns:
    - Scalar bandwidth for KDE.
    """
    n_samples, n_dims = data.shape
    sigma = data.std(axis=0).mean()
    shrink = n_samples ** (-1.0 / (n_dims + 4))
    return shrink * sigma


# Calculate a good bandwidth with Scott's rule: one estimate per (epoch, time-step) cloud of
# hidden-unit points, averaged into the single `bandwidth` passed to every kde_entropy call below.
n_epochs, n_steps, _, _ = scaled_data.shape
# Comprehension variables stay local, so the `epoch` configuration constant is not overwritten.
bandwidths = [
    scotts_bandwidth(scaled_data[epoch_idx, step_idx])
    for epoch_idx in range(n_epochs)
    for step_idx in range(n_steps)
]
bandwidth = np.average(bandwidths)
print(bandwidth)


def kde_entropy(data, bandwidth):
    """
    Estimate the entropy of a dataset using Kernel Density Estimation (KDE).

    A Gaussian KDE with the given ``bandwidth`` is fitted to ``data``, and its
    log-density is evaluated back at the same points. The estimate is the negative
    mean of those log-densities. Each point's own kernel contributes to the density
    it is scored against, which inflates the density at the samples and so deflates
    the entropy. The effect is strongest for small samples and small bandwidths.

    Parameters:
    - data: Numpy array of shape (n_samples, n_features), where each row represents a data point.
    - bandwidth: The bandwidth of the kernel. This parameter greatly affects the estimate.

    Returns:
    - The estimated entropy of the data.
    """
    # Use the caller-supplied bandwidth; a fixed value here would make `bandwidth` a dead parameter.
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(data)

    # Evaluate the log density model on the data (log(pdf(data)))
    log_pdf = kde.score_samples(data)

    # The entropy is the negative average of the log density
    return -np.mean(log_pdf)


print(scaled_data.shape)

# Intra-step entropy: for each (epoch, time-step), the entropy of the cloud of hidden-unit
# points, both in the full 3D embedding and along each embedding axis separately.
entropy_matrix = np.zeros((n_epochs, n_steps))
entropy_x = np.zeros((n_epochs, n_steps))
entropy_y = np.zeros((n_epochs, n_steps))
entropy_z = np.zeros((n_epochs, n_steps))
per_axis_entropy = (entropy_x, entropy_y, entropy_z)  # column 0, 1, 2 of the embedding

np.random.seed(42)
# Distinct loop names so the `epoch` configuration constant is not overwritten.
for epoch_idx in range(n_epochs):
    for step_idx in range(n_steps):
        points = scaled_data[epoch_idx, step_idx]  # (units, 3)
        entropy_matrix[epoch_idx, step_idx] = kde_entropy(points, bandwidth)
        for axis, axis_entropy in enumerate(per_axis_entropy):
            axis_entropy[epoch_idx, step_idx] = kde_entropy(points[:, axis].reshape(-1, 1), bandwidth)
np.save(save_dir + '/' + 'entropy_across_units_in_time', entropy_matrix)

# Creating a colormap for the intrinsic time steps (one LUT entry per original time-step)
# and one for the hidden units (one entry per unit). Calling a colormap with an integer
# indexes its LUT directly, so intrinsic_cmap(t) is the colour of original time-step t.
# plt.get_cmap(name, lut) is used because plt.cm.get_cmap was deprecated in Matplotlib 3.7
# and removed in 3.9.
intrinsic_cmap = plt.get_cmap('inferno', original_timesteps)
unit_cmap = plt.get_cmap('tab20b', num_unit)
timestep_sm = ScalarMappable(cmap=intrinsic_cmap, norm=Normalize(vmin=0, vmax=original_timesteps - 1))
timestep_sm.set_array([])
norm = Normalize(vmin=0, vmax=num_unit - 1)
unit_sm = ScalarMappable(cmap=unit_cmap, norm=norm)
unit_sm.set_array([])


def _plot_entropy_vs_metric(entropy, line_colors, entropy_label, metric_label, val_metric, train_metric,
                            mappable, cbar_label, file_name):
    """
    Plot entropy curves against a validation/training metric over ``epoch_samples`` and save the figure.

    Parameters:
    - entropy: 2D array indexed [sampled epoch, curve]; column ``i`` is drawn in ``line_colors[i]``.
    - line_colors: one colour per plotted column of ``entropy``.
    - entropy_label: left y-axis label.
    - metric_label: right y-axis label (e.g. 'Accuracy' or 'Loss').
    - val_metric, train_metric: per-epoch arrays, indexed with ``epoch_samples``.
    - mappable: ScalarMappable whose colormap explains ``line_colors``.
    - cbar_label: colorbar label.
    - file_name: image name written into ``save_dir`` at 300 dpi.
    """
    fig, ax1 = plt.subplots(figsize=(12, 8))

    # Left Y-axis: one entropy curve per column
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel(entropy_label)
    for column, line_color in enumerate(line_colors):
        ax1.plot(epoch_samples, entropy[:, column], color=line_color)
    ax1.tick_params(axis='y')

    # Right Y-axis: validation (solid) and training (dashed) metric
    ax2 = ax1.twinx()  # shares the x-axis with ax1
    color = 'tab:green'
    ax2.set_ylabel(metric_label, color=color)
    ax2.plot(epoch_samples, val_metric[epoch_samples], color=color, label='val')
    ax2.plot(epoch_samples, train_metric[epoch_samples], '--', color=color, label='train')
    ax2.tick_params(axis='y', labelcolor=color)

    # Colorbar explaining the entropy-line colours, with integer-only ticks
    cbar = fig.colorbar(mappable, ax=ax2, pad=0.14)
    cbar.set_label(cbar_label)
    cbar.locator = MaxNLocator(integer=True)
    cbar.update_ticks()
    fig.tight_layout()  # to make sure that the labels don't get cut off
    ax2.legend()  # only ax2's curves carry labels ('val' / 'train')
    fig.savefig(save_dir + '/' + file_name, dpi=300)
    plt.close(fig)


# Intra-step entropy (one curve per kept time-step) vs. accuracy and vs. loss
timestep_colors = [intrinsic_cmap(the_step) for the_step in intrinsic_step_samples]
_plot_entropy_vs_metric(entropy_matrix, timestep_colors, 'Intra-step Entropy', 'Accuracy',
                        validation_accuracy, training_accuracy, timestep_sm, 'Time-step',
                        'epoch_vs_entropy_and_accuracy.png')
_plot_entropy_vs_metric(entropy_matrix, timestep_colors, 'Intra-step Entropy', 'Loss',
                        validation_loss, training_loss, timestep_sm, 'Time-step',
                        'epoch_vs_entropy_and_loss.png')

# Inter-step entropy: for each (epoch, hidden unit), the entropy of that unit's n_steps
# points in the 3D embedding. Distinct loop names keep the loaded `unit` label array and
# the `epoch` configuration constant intact.
entropy_across_time_matrix = np.zeros((n_epochs, num_unit))

# Reseed the global NumPy RNG at this point. The KDE itself is deterministic; presumably this
# seed matters for later code that draws from np.random (e.g. a clustering initialisation
# with no explicit random_state) — keep it in place.
np.random.seed(42)
for epoch_idx in range(n_epochs):
    for unit_idx in range(num_unit):
        entropy_across_time_matrix[epoch_idx, unit_idx] = kde_entropy(scaled_data[epoch_idx, :, unit_idx], bandwidth)
np.save(save_dir + '/' + 'entropy_across_time', entropy_across_time_matrix)

# Inter-step entropy (one curve per hidden unit) vs. accuracy and vs. loss
unit_colors = [unit_cmap(the_unit) for the_unit in range(num_unit)]
_plot_entropy_vs_metric(entropy_across_time_matrix, unit_colors, 'Inter-step Entropy', 'Accuracy',
                        validation_accuracy, training_accuracy, unit_sm, 'Hidden Unit',
                        'epoch_vs_entropy_across_time_and_accuracy.png')
_plot_entropy_vs_metric(entropy_across_time_matrix, unit_colors, 'Inter-step Entropy', 'Loss',
                        validation_loss, training_loss, unit_sm, 'Hidden Unit',
                        'epoch_vs_entropy_across_time_and_loss.png')
import numpy as np
from tslearn.clustering import TimeSeriesKMeans
from tslearn.metrics import dtw
import tslearn
from sklearn.metrics import silhouette_score

# One row per hidden unit: that unit's inter-step entropy over the sampled epochs.
data = entropy_across_time_matrix.T

# Optional column sub-selection, currently NOT applied: every 5th of the first 30 columns,
# then every column from 30 on. These are column positions in `data` (i.e. indices into
# `epoch_samples`), not raw epoch numbers.
index_first_30 = np.arange(0, 30, 5)
index_rest = np.arange(30, data.shape[1])
full_index = np.r_[index_first_30, index_rest]
# To apply it: data = data[:, full_index]

# Number of clusters, chosen from visual inspection of the entropy curves.
n_clusters = 2

# Time-series k-means with the DTW metric; max_iter is set so large it is effectively unbounded.
model = TimeSeriesKMeans(
    n_clusters=n_clusters,
    metric="dtw",
    verbose=True,
    max_iter=10000000000,
)

# Cluster index for each hidden unit; persisted and echoed.
labels = model.fit_predict(data)
np.save(save_dir + '/' + 'cluster_entropy_across_time_label', labels)
print("Cluster assignments:", labels)

# Cluster quality: silhouette score over the pairwise DTW distance matrix.
distance_matrix = tslearn.metrics.cdist_dtw(data)
silhouette_avg = silhouette_score(distance_matrix, labels, metric="precomputed")

# Start plotting: one panel per cluster, each member unit's inter-step entropy curve in
# faint black and the cluster centre in red. The panel count follows n_clusters, so a
# different number of clusters no longer overflows a hard-coded 2-row grid.
cluster_fig = plt.figure()
for yi in range(n_clusters):
    plt.subplot(n_clusters, 1, yi + 1)
    for xx in data[labels == yi]:
        plt.plot(epoch_samples, xx.ravel(), "k-", alpha=.2)
    plt.plot(epoch_samples, model.cluster_centers_[yi].ravel(), "r-")
    plt.xlim(epoch_samples[0], epoch_samples[-1])
    plt.ylim(data.min(), data.max())
    plt.title(f"Cluster {yi}")
plt.tight_layout()
plt.savefig(save_dir + '/' + 'cluster_entropy_across_time.png', dpi=300)
plt.close(cluster_fig)
# Plotting: cluster centres (left axis) against training/validation loss (right axis).
fig, ax1 = plt.subplots(figsize=(12, 8))

# Left Y-axis: one red curve per cluster centre, distinguished by line style
ax1.set_xlabel('Epoch', fontsize=43)
ax1.set_ylabel('Entropy', fontsize=43)
center_styles = ['-', '--', ':', '-.']
center_lines = []
for k, center in enumerate(model.cluster_centers_):
    line, = ax1.plot(epoch_samples, center.ravel(), 'r' + center_styles[k % len(center_styles)],
                     label=f'Cluster {k}', linewidth=4)
    center_lines.append(line)
ax1.tick_params(axis='y', labelsize=26)

# Right Y-axis: validation (solid) and training (dashed) loss
ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
color = 'tab:green'
ax2.set_ylabel('Loss', color=color, fontsize=43)
val_line, = ax2.plot(epoch_samples, validation_loss[epoch_samples], color=color, label='val', linewidth=4)
train_line, = ax2.plot(epoch_samples, training_loss[epoch_samples], '--', color=color, label='train', linewidth=4)
ax2.tick_params(axis='y', labelcolor=color, labelsize=26)

# One figure-level legend covering both axes. A separate name is used so the cluster
# assignments stored in `labels` are not overwritten.
legend_lines = center_lines + [val_line, train_line]
legend_labels = [line.get_label() for line in legend_lines]
fig.legend(legend_lines, legend_labels, loc='upper right', bbox_to_anchor=(0.85, 0.95), fontsize=34)
# Add the title before tight_layout so the layout reserves room for the large title text.
ax2.set_title(f"Cluster Center (Silhouette Score: {silhouette_avg:.2f})", fontsize=43)
fig.tight_layout()
fig.savefig(save_dir + '/' + 'cluster_center_and_loss.png', dpi=300)
plt.close(fig)