import torch
import h5py
from einops import rearrange
import seaborn as sns
import umap
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


# Load Data from HDF5
# def load_h5_dataset(h5_file):
#     with h5py.File(h5_file, "r") as f:
#         actor_latent = np.array(f["actor_latent"])
#         critic_latent = np.array(f["critic_latent"])
#         proprioceptive_obs = np.array(f["proprioceptive_obs"])
#     return actor_latent, critic_latent, proprioceptive_obs

def load_h5_dataset_kine(h5_file, n_layers=4):
    """Load the per-layer actor latents stored in an HDF5 file.

    Args:
        h5_file: Path to an HDF5 file containing datasets named
            ``actor_latent_0`` ... ``actor_latent_{n_layers - 1}``.
        n_layers: Number of ``actor_latent_{i}`` datasets to read. Defaults
            to 4, the value this loader previously hard-coded.

    Returns:
        A list with one in-memory numpy array per layer, in layer order.
        The arrays are copied out of the file, so they remain valid after
        the file is closed.

    Raises:
        KeyError: raised by h5py if an expected ``actor_latent_{i}`` dataset
            is missing from the file.
    """
    with h5py.File(h5_file, "r") as f:
        # np.array(...) materializes each dataset before the file closes.
        return [np.array(f[f"actor_latent_{i}"]) for i in range(n_layers)]


def sns_jointplot(embedding, colors, title, save_dir="./logs/analysis/plots/"):
    """Save a seaborn joint scatter plot of a 2-D embedding, colored by label.

    Args:
        embedding: Array-like with two columns (one row per sample); the
            columns are plotted as "Dim 1" and "Dim 2".
        colors: Sequence of labels, one per row of ``embedding``. It is used
            as the "Task" hue.
        title: Figure title. Spaces are replaced by underscores to form the
            file name ``<title>.png``.
        save_dir: Directory to write the PNG into. It is created if missing.
            A trailing path separator is optional.

    Returns:
        The path of the saved PNG file.
    """
    import os

    import pandas as pd

    # Convert to a DataFrame so seaborn can address columns by name.
    embedding_df = pd.DataFrame(embedding, columns=["Dim 1", "Dim 2"])
    embedding_df["Task"] = colors

    plot = sns.jointplot(
        data=embedding_df,
        x="Dim 1",
        y="Dim 2",
        hue="Task",
        kind="scatter",
        palette="Set1",
        edgecolor="black",
        linewidth=0.1,
        alpha=0.5,
    )
    fig = plot.fig

    try:
        fig.suptitle(title, fontsize=14)
        fig.tight_layout()  # first tighten layout
        fig.subplots_adjust(top=0.95)  # then leave room for the suptitle

        # os.path.join works whether or not save_dir ends with a separator.
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, title.replace(" ", "_") + ".png")
        fig.savefig(fig_name, dpi=500)
    finally:
        # Close the figure explicitly: this function is called once per layer
        # in a loop, and open pyplot figures would otherwise accumulate.
        plt.close(fig)

    print(f"Plot saved as '{fig_name}'")
    return fig_name


def main():
    """Plot 2-D UMAP projections of per-layer actor latents, colored by source task.

    For each policy type:

    1. Loads ``./logs/analysis/{task}_obs_in_{policy_type}_latent.h5`` for
       every source task via ``load_h5_dataset_kine``.
    2. Concatenates each layer's samples across tasks.
    3. Projects that layer to 2-D with UMAP.
    4. Saves a joint plot titled ``"{policy_type} latent_{k} latent space"``,
       where ``k`` is the 1-based layer index.
    """
    # UMAP parameters; the fixed seed keeps embeddings reproducible.
    n_neighbors = 200
    min_dist = 0.2
    random_state = 42

    source_tasks = ["loco", "pedi"]
    # policy_types = ["loco", "pedi", "scratch"]
    policy_types = ["kine"]

    for policy_type in policy_types:
        # per_task_layers[t][l] is layer l's latent array for source_tasks[t].
        per_task_layers = [
            load_h5_dataset_kine(f"./logs/analysis/{task}_obs_in_{policy_type}_latent.h5")
            for task in source_tasks
        ]

        # zip(*...) groups the same layer across all tasks.
        layers = list(zip(*per_task_layers))
        for layer_idx, layer_per_task in enumerate(tqdm(layers), start=1):
            latent = np.concatenate(layer_per_task, axis=0)

            # Label rows by the task they came from, using each task's own
            # sample count. This matches the concatenation order above and
            # works for any number of source tasks.
            colors = [
                task
                for task, task_latent in zip(source_tasks, layer_per_task)
                for _ in range(task_latent.shape[0])
            ]

            reducer = umap.UMAP(
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                random_state=random_state,
            )
            embedding = reducer.fit_transform(latent)

            latent_space_description = f"latent_{layer_idx}"
            title = f"{policy_type} {latent_space_description} latent space"
            sns_jointplot(embedding, colors, title)


if __name__ == "__main__":
    # Run the main function
    main()
