# This file visualizes features using tsne
"""
Description of the different types of plots:
1. animation: An animation of the t-SNE embedding of every sample, one frame per time step
2. frames: A large figure with 3 x 5 subplots, each showing the t-SNE embedding at a different time step
3. trajectory: A plot of the trajectory over time of the mean t-SNE embedding (averaged over all samples with the same label)
4. traj-anim: An animation of the trajectory of the mean representation of all samples sharing a label, growing by one time step per frame
5. traj-frames: A grid of subplots, one per time step, each showing the mean-representation trajectory of every label up to that time step


"""

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import matplotlib.animation as animation

# Maps an action id to its human-readable name.
# NOTE: ids here are zero-based (0..26); the equivalent mapping in train.py starts at 1.
actions_dict = dict(enumerate([
    'Swipe left',
    'Swipe right',
    'Wave',
    'Clap',
    'Throw',
    'Arm cross',
    'Basketball shoot',
    'Draw X',
    'Draw circle (clockwise)',
    'Draw circle (counter clockwise)',
    'Draw triangle',
    'Bowling',
    'Boxing',
    'Baseball swing',
    'Tennis swing',
    'Arm curl',
    'Tennis serve',
    'Push',
    'Knock',
    'Catch',
    'Pickup and throw',
    'Jog',
    'Walk',
    'Sit to stand',
    'Stand to sit',
    'Lunge',
    'Squat',
]))

def _mean_trajectories(z_tsne, labels, all_labels):
    """Return the per-label mean trajectory of a time-major t-SNE embedding.

    Args:
        z_tsne: array indexed as ``z_tsne[time, sample, dim]`` (the layout built
            in ``plot_time_tsne``); only dims 0 and 1 are used.
        labels: 1-D array with one label per sample (axis 1 of ``z_tsne``).
        all_labels: labels to build a trajectory for, in plotting order.

    Returns:
        ``(all_plt_x, all_plt_y)``: lists with one array per entry of
        ``all_labels``. Each array holds, for every time step, the mean x
        (resp. y) coordinate of the samples carrying that label.
    """
    all_plt_x = []
    all_plt_y = []
    for l in all_labels:
        members = labels == l
        all_plt_x.append(np.mean(z_tsne[:, members, 0], axis=1))
        all_plt_y.append(np.mean(z_tsne[:, members, 1], axis=1))
    return all_plt_x, all_plt_y


def _time_step_grid(num_steps, ncols=5):
    """Create a figure with one subplot per time step, ``ncols`` subplots per row.

    Returns ``(fig, axs, step_axes)``. ``axs`` is the full 2-D array of axes and
    ``step_axes`` lists the first ``num_steps`` of them in row-major order.
    Surplus subplots in the last row are hidden.
    """
    nrows = max(1, (num_steps + ncols - 1) // ncols)
    # 4 inches per subplot reproduces the original 20 x 12 figure for a 3 x 5 grid
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    flat_axes = list(axs.ravel())
    for ax in flat_axes[num_steps:]:
        ax.set_visible(False)
    return fig, axs, flat_axes[:num_steps]


def plot_time_tsne(z, labels, title = 'z', plot_type = "animation", filter_few = True, pca_n_components = 30, tsne_n_components = 2, tsne_random_state = 42, filter_labels = (1, 11, 18)):
    """Embed time-resolved feature vectors with PCA + t-SNE and plot them.

    All ``n * t`` vectors are embedded jointly (one PCA fit, one t-SNE fit), so
    positions are comparable across time steps. See the module docstring for
    what each ``plot_type`` draws. Output files use relative paths, so they are
    written to the current working directory.

    Args:
        z: array of shape ``(n, t, d)``: samples, time steps, feature dimension.
        labels: array of shape ``(n,)``. The trajectory plots look every label
            up in ``actions_dict`` (legend names) and in the per-action colours.
        title: name used in plot titles and output file names.
        plot_type: one of ``"animation"``, ``"frames"``, ``"trajectory"``,
            ``"traj-anim"``, ``"traj-frames"``.
        filter_few: if True, keep only samples whose label is in
            ``filter_labels`` and use 5 PCA components.
        pca_n_components: PCA components kept before t-SNE. It is overridden
            by 5 when ``filter_few`` is True, and is clamped to what the data
            allows.
        tsne_n_components: t-SNE output dimensionality (plots use dims 0 and 1).
        tsne_random_state: seed passed to t-SNE.
        filter_labels: labels kept when ``filter_few`` is True.

    Raises:
        ValueError: if ``plot_type`` is unknown or no samples remain after filtering.
    """
    valid_plot_types = ("animation", "frames", "trajectory", "traj-anim", "traj-frames")
    # Fail fast: PCA + t-SNE are expensive, and an unknown type used to produce nothing.
    if plot_type not in valid_plot_types:
        raise ValueError(f"Unknown plot_type {plot_type!r}; expected one of {valid_plot_types}")

    z = np.asarray(z)
    labels = np.asarray(labels)
    if filter_few:
        mask = np.isin(labels, filter_labels)
        z = z[mask]
        labels = labels[mask]
        # Keep only requested labels that actually occur (in the requested order),
        # so the trajectory plots never average over an empty selection (NaNs).
        present = set(labels.tolist())
        all_labels = [l for l in filter_labels if l in present]
        pca_n_components = 5
    else:
        all_labels = np.unique(labels)
    print("Shape of z: ", z.shape)
    if z.shape[0] == 0:
        raise ValueError("No samples left to plot (check filter_labels against the labels present)")

    n, num_steps, d = z.shape  # samples, time steps, feature dimension
    # Embed every (sample, time step) vector together so all frames share one t-SNE space.
    z_flat = z.reshape(n * num_steps, d)

    # Step 1: reduce dimensionality with PCA; sklearn requires
    # n_components <= min(n_samples, n_features).
    pca = PCA(n_components=min(pca_n_components, n * num_steps, d))
    z_pca = pca.fit_transform(z_flat)

    # Step 2: t-SNE on the PCA-reduced vectors
    tsne = TSNE(n_components=tsne_n_components, random_state=tsne_random_state)
    z_tsne = tsne.fit_transform(z_pca)

    # Rows were flattened sample-major, so restore (n, t, k) first, then move time to the front: (t, n, k).
    z_tsne = z_tsne.reshape(n, num_steps, tsne_n_components).transpose(1, 0, 2)

    # One fixed colour per action id, spread evenly over the gist_ncar colormap.
    color_dict = {i: plt.cm.gist_ncar(i * 256 // len(actions_dict)) for i in range(len(actions_dict))}
    filter_tag = 'filtered' if filter_few else ''

    if plot_type == "animation":
        # Step 3: animate the embedding of every sample, one frame per time step
        fig, ax = plt.subplots(figsize=(10, 8))
        # Fixed limits across frames so points don't appear to jump from axis rescaling.
        x_lim = (np.min(z_tsne[:, :, 0]), np.max(z_tsne[:, :, 0]))
        y_lim = (np.min(z_tsne[:, :, 1]), np.max(z_tsne[:, :, 1]))

        def update_plot(frame):
            ax.clear()
            scatter = ax.scatter(z_tsne[frame][:, 0], z_tsne[frame][:, 1], c=labels, cmap='tab20')
            ax.set_title(f't-SNE Visualization of {title} Vectors - Frame {frame + 1}')
            ax.set_xlabel('t-SNE Dimension 1')
            ax.set_ylabel('t-SNE Dimension 2')
            ax.set_xlim(*x_lim)
            ax.set_ylim(*y_lim)
            return scatter,

        ani = animation.FuncAnimation(fig, update_plot, frames=range(num_steps), blit=False, repeat=True)

        out_name = f"TSNE_animation_{title}_{filter_tag}.gif"
        ani.save(out_name, writer='imagemagick', fps=2)
        print(f"Saving animation as {out_name}")

    elif plot_type == "frames":
        # Step 3: one scatter subplot per time step
        fig, axs, step_axes = _time_step_grid(num_steps)
        fig.suptitle(f't-SNE Visualization of {title} Vectors Over Time', fontsize=16)

        for step, ax in enumerate(step_axes):
            scatter = ax.scatter(z_tsne[step][:, 0], z_tsne[step][:, 1], c=labels, cmap='tab20')
            ax.set_title(f'Time Step {step + 1}')
            ax.set_xlabel('t-SNE Dimension 1')
            ax.set_ylabel('t-SNE Dimension 2')

        # Add a single colorbar for the entire figure
        cbar = fig.colorbar(scatter, ax=axs, orientation='vertical', fraction=.01)
        cbar.set_ticks(range(28))
        cbar.set_label('Labels')

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        out_name = f"TSNE_frames_{title}_{filter_tag}.png"
        fig.savefig(out_name)
        print(f"Saving frames as {out_name}")

    elif plot_type == "trajectory":
        # Step 3: one full mean trajectory per label
        fig, ax = plt.subplots(figsize=(10, 8))
        all_plt_x, all_plt_y = _mean_trajectories(z_tsne, labels, all_labels)

        for ind, (l, plt_x, plt_y) in enumerate(zip(all_labels, all_plt_x, all_plt_y)):
            # Only the first label's markers carry legend labels, so 'starts'/'ends' appear once.
            start_kw = {'label': 'starts'} if ind == 0 else {}
            end_kw = {'label': 'ends'} if ind == 0 else {}
            ax.plot(plt_x[0], plt_y[0], 'o', c='black', **start_kw)
            ax.plot(plt_x[-1], plt_y[-1], 'x', c='black', **end_kw)
            ax.plot(plt_x, plt_y, c=color_dict[l], label=actions_dict[l])

        ax.set_title(f't-SNE Visualization of {title} Vectors Over Time')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend()

        out_name = f"TSNE_trajectory_{title}_{filter_tag}.png"
        fig.savefig(out_name)
        print(f"Saving trajectory as {out_name}")

    elif plot_type == "traj-anim":
        # Step 3: animate the mean trajectories, growing by one time step per frame
        fig, ax = plt.subplots(figsize=(10, 8))
        all_plt_x, all_plt_y = _mean_trajectories(z_tsne, labels, all_labels)

        def update_plot(frame):
            ax.clear()
            for l, plt_x, plt_y in zip(all_labels, all_plt_x, all_plt_y):
                ax.plot(plt_x[0], plt_y[0], 'o', c=color_dict[l])
                ax.plot(plt_x[frame], plt_y[frame], 'x', c=color_dict[l])
                traj_plot = ax.plot(plt_x[:frame + 1], plt_y[:frame + 1], c=color_dict[l], label=actions_dict[l])

            ax.set_title(f't-SNE Visualization of {title} Vectors Over Time - Frame {frame + 1}')
            ax.set_xlabel('t-SNE Dimension 1')
            ax.set_ylabel('t-SNE Dimension 2')
            ax.legend()
            return traj_plot,

        ani = animation.FuncAnimation(fig, update_plot, frames=range(num_steps), blit=False, repeat=True)

        out_name = f"TSNE_traj-anim_{title}_{filter_tag}.gif"
        ani.save(out_name, writer='imagemagick', fps=2)
        print(f"Saving trajectory animation as {out_name}")

    else:  # plot_type == "traj-frames"
        # Step 3: one subplot per time step showing each label's mean trajectory so far
        all_plt_x, all_plt_y = _mean_trajectories(z_tsne, labels, all_labels)
        fig, axs, step_axes = _time_step_grid(num_steps)
        fig.suptitle(f't-SNE Visualization of {title} Vectors Over Time', fontsize=16)

        for step, ax in enumerate(step_axes):
            for ind, (l, plt_x, plt_y) in enumerate(zip(all_labels, all_plt_x, all_plt_y)):
                # fig.legend() collects labels from every subplot: label exactly one
                # marker pair, and give action names only in the first subplot.
                first_pair = step == 0 and ind == 0
                start_kw = {'label': 'starts'} if first_pair else {}
                end_kw = {'label': 'current/end'} if first_pair else {}
                ax.plot(plt_x[0], plt_y[0], 'o', c='black', **start_kw)
                ax.plot(plt_x[step], plt_y[step], 'x', c='black', **end_kw)

                traj_kw = {'label': actions_dict[l]} if step == 0 else {}
                ax.plot(plt_x[:step + 1], plt_y[:step + 1], c=color_dict[l], **traj_kw)

            ax.set_title(f'Time Step {step + 1}')
            ax.set_xlabel('t-SNE Dimension 1')
            ax.set_ylabel('t-SNE Dimension 2')

        fig.legend()
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        out_name = f"TSNE_traj-frames_{title}_{filter_tag}.png"
        fig.savefig(out_name)
        print(f"Saving trajectory frames as {out_name}")

    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)



def plot_single_tsne(z_rgb, labels, pca_n_components=30, tsne_random_state=42, out_file="TSNE_z.png"):
    """Embed one feature vector per sample with PCA + t-SNE and save a scatter plot.

    Args:
        z_rgb: 2-D array of shape ``(n_samples, n_features)``.
        labels: one value per sample, used to colour the points.
        pca_n_components: PCA components kept before t-SNE (clamped to what the
            data allows).
        tsne_random_state: seed passed to t-SNE.
        out_file: path of the saved PNG.

    Raises:
        ValueError: if ``z_rgb`` is not 2-D.
    """
    z = np.asarray(z_rgb)
    if z.ndim != 2:
        raise ValueError(f"plot_single_tsne expects a 2-D array (n_samples, n_features), got shape {z.shape}")

    # Step 1: Reduce dimensionality with PCA before t-SNE; sklearn requires
    # n_components <= min(n_samples, n_features).
    pca = PCA(n_components=min(pca_n_components, *z.shape))
    z_pca = pca.fit_transform(z)

    # Step 2: Apply t-SNE
    tsne = TSNE(n_components=2, random_state=tsne_random_state)
    z_tsne = tsne.fit_transform(z_pca)

    # Step 3: Plot the t-SNE results
    fig = plt.figure(figsize=(10, 8))
    scatter = plt.scatter(z_tsne[:, 0], z_tsne[:, 1], c=labels, cmap='tab20')
    plt.colorbar(scatter, ticks=range(28))
    plt.title('t-SNE Visualization of z Vectors')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.savefig(out_file)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

if __name__ == '__main__':
    # Load latent representations saved as .npy files and plot them with t-SNE.
    # Alternative inputs (keep_time test/train sets), kept for quick switching:
    # z_rgb_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_test_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_rgb.npy'
    # z_imu_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_test_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_imu.npy'
    # labels_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_test_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_labels.npy'
    # z_rgb_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_train_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_rgb.npy'
    # z_imu_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_train_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_imu.npy'
    # labels_file = '/home/ANON/toy_HAR/ANON/latent_representations/keep_time_train_set/toy-RGB-IMU-HAR-cross_modal1_keep_time_labels.npy'

    # NOTE(review): all three paths below point at the same manual_imu.npy file, so
    # `labels` is loaded from a feature file. Confirm the intended labels path.
    z_rgb_file = '/home/ANON/toy_HAR/ANON/latent_representations/manual_imu.npy'
    z_imu_file = '/home/ANON/toy_HAR/ANON/latent_representations/manual_imu.npy'
    labels_file = '/home/ANON/toy_HAR/ANON/latent_representations/manual_imu.npy'

    z_rgb = np.load(z_rgb_file)
    z_imu = np.load(z_imu_file)
    labels = np.load(labels_file)
    print("Shape of z_rgb: ", z_rgb.shape)
    print("Shape of z_imu: ", z_imu.shape)
    print("Shape of labels: ", labels.shape)

    # Catch a wrong labels file before the expensive PCA/t-SNE run: the scatter
    # plot needs exactly one label per sample.
    if np.size(labels) != z_rgb.shape[0]:
        raise ValueError(
            f"Expected one label per sample: labels shape {labels.shape} "
            f"does not match z_rgb shape {z_rgb.shape} (check labels_file)"
        )
    # z_rgb = (z_imu+ z_rgb)/2

    plot_single_tsne(z_rgb, labels)

    # plot_type = "traj-anim" # "frames" or "animation" or "trajectory" or "traj-anim" or "traj-frames"
    # filter_few = True
    # plot_time_tsne(z_rgb, labels, title = 'z_rgb', plot_type = plot_type, filter_few = filter_few)
    # plot_time_tsne(z_imu, labels, title = 'z_imu', plot_type = plot_type, filter_few = filter_few)
    # plot_time_tsne((z_imu+z_rgb)/2, labels, title = 'z_combined', plot_type = plot_type, filter_few = filter_few)



