import torch
import h5py
from einops import rearrange
import seaborn as sns
import umap
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset



def sns_jointplot(embedding, colors, title, save_dir="./logs/analysis/plots/"):
    """
    Scatter a 2-D embedding with marginal distributions, colored by label, and save it as a PNG.

    embedding: 2D numpy array of shape (n_samples, 2)
    colors: (n_samples,) list of labels for each sample; used as the seaborn ``hue`` ("File")
    title: str, title of the plot; spaces are replaced by underscores to form the file name
    save_dir: str, directory to save the plot; created if it does not exist

    Returns the path of the saved PNG file.
    """
    import os

    import pandas as pd

    # Convert to DataFrame for Seaborn
    embedding_df = pd.DataFrame(embedding, columns=["Dim 1", "Dim 2"])
    embedding_df["File"] = colors

    # Plot using sns.jointplot
    plot = sns.jointplot(
        data=embedding_df,
        x="Dim 1",
        y="Dim 2",
        hue="File",
        kind="scatter",
        palette="Set1",
        edgecolor="black",
        linewidth=0.1,
        alpha=0.4,
    )

    fig = plot.fig
    fig.suptitle(title, fontsize=14)

    fig.tight_layout()  # first tighten layout
    fig.subplots_adjust(top=0.95)  # then leave room for the suptitle (1.0 = top edge of figure)

    # os.path.join works whether or not save_dir ends with a separator; plain string
    # concatenation silently wrote "<dir>Title.png" next to the directory when it did not.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig_name = os.path.join(save_dir, title.replace(" ", "_") + ".png")

    # Save the plot to an image file with specified DPI
    fig.savefig(fig_name, dpi=500)
    # Release the figure so repeated calls do not accumulate open figures in pyplot.
    plt.close(fig)
    print(f"Plot saved as '{fig_name}'")
    return fig_name



def main():
    """
    Embed dynamics-analysis samples from several dataset files with UMAP and save a jointplot.

    For each of the first ``num_files`` files in the dataset, draws ``num_samples_per_file``
    samples, concatenates ``st`` with ``at[:, 0]``, labels every sample with its file index,
    reduces all samples to 2-D with UMAP and plots the result via ``sns_jointplot``.
    """
    # UMAP parameters
    n_neighbors = 10
    min_dist = 0.0
    random_state = 42

    # Sampling parameters
    num_files = 4
    num_samples_per_file = 200

    dataset_path = "logs/datasets/dynamics_analysis"
    ds = DynamicsAnalysisDataset(dataset_path)

    samples_list = []
    color_list = []
    for file_idx in range(num_files):
        # Draw from file `file_idx`. Previously this was always file 0, so every
        # "File" group came from the same file and the plot compared nothing.
        st, at, _ = ds.get_batch_from_file_by_index(file_idx, num_samples=num_samples_per_file)
        # NOTE(review): at[:, 0] presumably selects the first action of a per-sample
        # action sequence — confirm against DynamicsAnalysisDataset.
        samples = torch.cat([st, at[:, 0]], dim=-1)
        samples_list.append(samples)
        # Size labels from the actual batch so they always line up with the samples.
        color_list.append([file_idx] * samples.shape[0])
    samples = torch.cat(samples_list, dim=0).cpu().numpy()
    colors = np.concatenate(color_list, axis=0)

    # Reduce dimensionality to 2D using UMAP
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )
    embedding = reducer.fit_transform(samples)

    # Plot using Seaborn
    title = "Dynamics Analysis Data UMAP visualization"
    sns_jointplot(embedding, colors, title)


if __name__ == "__main__":
    # Script entry point: build the UMAP embedding and save the plot.
    main()
