# P4RL: visualization script for inverse-dynamics dataset samples (2-D UMAP/PCA projections).

import seaborn as sns
import torch
import pandas as pd
import seaborn as sns
import umap
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset
from rsl_rl.addons.invdynamics.inv_dynamics_module import build_mlp
from typing import Dict, Callable, Optional
import numpy as np
from sklearn.decomposition import PCA
import pickle
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
import os



def sns_jointplot(embedding, colors, title, palette=None, save_dir="./logs/analysis/plots/", color_notation="Dataset"):
    """Draw a 2-D embedding as a seaborn joint scatter plot and save it as a PDF.

    Args:
        embedding: two-column array-like holding the projected samples.
        colors: one hue value per row of ``embedding``; stored in the
            ``color_notation`` column and passed to seaborn as ``hue``.
        title: only used to build the output file name (spaces become
            underscores); it is not drawn on the figure.
        palette: seaborn palette; any falsy value falls back to ``"Set1"``.
        save_dir: output directory, created if it does not exist.
        color_notation: name of the column that holds the hue values.
    """
    x_col, y_col = "Dim 1", "Dim 2"
    frame = pd.DataFrame(embedding, columns=[x_col, y_col])
    frame[color_notation] = colors

    scatter_style = dict(edgecolor="black", linewidth=0, s=20, alpha=0.5)  # s = marker size
    grid = sns.jointplot(
        data=frame,
        x=x_col,
        y=y_col,
        hue=color_notation,
        kind="scatter",
        palette=palette or "Set1",
        **scatter_style,
    )

    joint_ax = grid.ax_joint
    # Embedding coordinates have no physical meaning, so hide ticks and axis labels.
    joint_ax.set_xticks([])
    joint_ax.set_yticks([])
    joint_ax.set_xlabel("")
    joint_ax.set_ylabel("")

    # Redraw the legend as a single, frameless column anchored just below the joint axes.
    handles, labels = joint_ax.get_legend_handles_labels()
    joint_ax.legend(
        handles=handles,
        labels=labels,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.01),
        ncol=1,
        frameon=False,
        fontsize=15,
    )

    fig = grid.figure
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)  # keep a small margin below the top edge of the figure

    os.makedirs(save_dir, exist_ok=True)
    fig_name = os.path.join(save_dir, f"{title.replace(' ', '_')}.pdf")
    fig.savefig(fig_name)  # vector PDF output, so no dpi is passed
    print(f"Plot saved as '{fig_name}'")


class InvSamplesVisualization:
    """Project inverse-dynamics dataset samples to 2-D and plot them.

    Samples are standardized per feature (statistics fitted on the concatenated
    samples when ``fit=True``), reduced to two dimensions with UMAP or PCA, and
    drawn with ``sns_jointplot``.
    """

    def __init__(self, n_neighbors=15, min_dist=0.1, random_seed=42, use_PCA=False):
        """
        Args:
            n_neighbors: UMAP ``n_neighbors`` (unused when ``use_PCA`` is True).
            min_dist: UMAP ``min_dist`` (unused when ``use_PCA`` is True).
            random_seed: seed for the reducer, ``torch.manual_seed`` and the
                per-dataset sample selection.
            use_PCA: use a 2-component PCA instead of UMAP.
        """
        self.use_PCA = use_PCA
        if use_PCA:
            self.reducer = PCA(n_components=2, random_state=random_seed)
        else:
            self.reducer = umap.UMAP(
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                random_state=random_seed,
            )
        torch.manual_seed(random_seed)
        self.random_seed = random_seed
        # Per-feature normalization statistics, set by fit_normalization().
        self.mean = None
        self.std = None

    def visualize_samples_hue_dataset(
        self,
        dict_dataset_paths: Dict[str, str],
        vis_samples_per_dataset: Optional[Dict[str, int]] = None,
        fit=True,
        dataset_palette=None,
    ):
        """Embed samples from several datasets and colour them by dataset name.

        Args:
            dict_dataset_paths: dataset name (used as legend label) -> HDF5 path.
            vis_samples_per_dataset: samples drawn per dataset, keyed like
                ``dict_dataset_paths``; missing entries default to 500.
            fit: refit normalization and reducer on these samples; otherwise
                reuse the previously fitted ones.
            dataset_palette: seaborn palette forwarded to ``sns_jointplot``.

        Raises:
            ValueError: if ``dict_dataset_paths`` is empty.
        """
        if not dict_dataset_paths:
            raise ValueError("dict_dataset_paths must contain at least one dataset")
        if vis_samples_per_dataset is None:
            vis_samples_per_dataset = {}

        samples_for_vis: Dict[str, np.ndarray] = {}
        for dataset_name, dataset_path in dict_dataset_paths.items():
            print(f"Processing dataset: {dataset_name}")

            # Load files up to 4 GB fully into memory; larger ones are not.
            file_size_gb = os.path.getsize(dataset_path) / (1024 ** 3)
            load_into_memory = file_size_gb <= 4

            dataset = DynamicSlidingWindowDataset(
                h5_path=dataset_path,
                window_size=1,
                load_into_memory=load_into_memory,
            )
            x, _ = dataset.get_sample_entries_in_file(
                sample_num=vis_samples_per_dataset.get(dataset_name, 500),
                seed=self.random_seed,
            )
            # Drop the reference so the dataset can be freed before the next one is loaded.
            del dataset

            # Presumably x is (N, window_size, dim); with window_size=1 squeeze the window axis.
            samples_for_vis[dataset_name] = x.squeeze(1).detach().cpu().numpy()

        reps = np.concatenate(list(samples_for_vis.values()), axis=0)
        vis_reps = self._reduce(reps, fit)

        colors = np.concatenate(
            [[dataset_tag] * samples.shape[0] for dataset_tag, samples in samples_for_vis.items()],
            axis=0,
        )

        sns_jointplot(
            vis_reps,
            colors,
            title=f"{'PCA' if self.use_PCA else 'UMAP'} visualizations of inverse dynamics dataset",
            palette=dataset_palette,
            save_dir="./logs/analysis/plots/",
        )

    def visualize_samples_hue_function(
        self,
        dict_dataset_paths: Dict[str, str],
        hue_function: Callable,
        vis_samples_per_dataset: Optional[Dict[str, int]] = None,
        fit=True,
    ):
        """Embed samples from several datasets and colour them by a per-sample score.

        Args:
            dict_dataset_paths: dataset name -> HDF5 path.
            hue_function: called as ``hue_function(x, a)`` on each dataset's
                sampled windows (window_size=2); presumably returns a torch
                tensor with one value per sample -- TODO confirm.
            vis_samples_per_dataset: samples drawn per dataset, keyed like
                ``dict_dataset_paths``; missing entries default to 500.
            fit: refit normalization and reducer on these samples; otherwise
                reuse the previously fitted ones.

        Raises:
            ValueError: if ``dict_dataset_paths`` is empty.
        """
        if not dict_dataset_paths:
            raise ValueError("dict_dataset_paths must contain at least one dataset")
        if vis_samples_per_dataset is None:
            vis_samples_per_dataset = {}

        samples_for_vis: Dict[str, np.ndarray] = {}
        hue_values_list = []
        for dataset_name, dataset_path in dict_dataset_paths.items():
            print(f"Processing dataset: {dataset_name}")

            dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=2)
            x, a = dataset.get_sample_entries_in_file(
                sample_num=vis_samples_per_dataset.get(dataset_name, 500),
                seed=self.random_seed,
            )
            del dataset

            hues = hue_function(x, a)
            # detach() so hue tensors that require grad can still be converted.
            hue_values_list.append(hues.detach().cpu().numpy())

            samples_for_vis[dataset_name] = x[:, -1].detach().cpu().numpy()  # last state of each window

        reps = np.concatenate(list(samples_for_vis.values()), axis=0)
        vis_reps = self._reduce(reps, fit)

        colors = np.concatenate(hue_values_list, axis=0)
        print(f"Colors shape: {colors.shape}, min: {np.min(colors)}, max: {np.max(colors)}")
        colors = self._min_max_scale(colors)

        sns_jointplot(
            vis_reps,
            colors,
            title=f"{'PCA' if self.use_PCA else 'UMAP'} visualizations of inverse dynamics dataset",
            palette="rainbow",
            save_dir="./logs/analysis/plots/",
            color_notation="Intrinsic Reward",
        )

    def _reduce(self, reps: np.ndarray, fit: bool) -> np.ndarray:
        """Normalize ``reps`` and project them to 2-D, refitting first when ``fit``."""
        if fit:
            self.fit_normalization(reps)
            vis_reps = self.reducer.fit_transform(self.normalize_samples(reps))
        else:
            vis_reps = self.reducer.transform(self.normalize_samples(reps))
        if isinstance(self.reducer, PCA):
            print(f"Explained variance ratio: {self.reducer.explained_variance_ratio_}")
        return vis_reps

    @staticmethod
    def _min_max_scale(values: np.ndarray) -> np.ndarray:
        """Linearly rescale ``values`` to [0, 1]; constant input maps to all zeros."""
        values = np.asarray(values, dtype=float)
        lo, hi = np.min(values), np.max(values)
        span = hi - lo
        if span == 0:
            # Avoid 0/0 -> NaN when every hue value is identical.
            return np.zeros_like(values)
        return (values - lo) / span

    def fit_normalization(self, samples: np.ndarray):
        """
        Fit the normalization parameters (per-feature mean and std) on ``samples``.
        """
        self.mean = np.mean(samples, axis=0, keepdims=True)
        self.std = np.std(samples, axis=0, keepdims=True)

    def normalize_samples(self, samples: np.ndarray) -> np.ndarray:
        """
        Standardize ``samples`` with the fitted mean and std.

        Features whose fitted std is 0 are divided by 1 instead, so they stay
        finite rather than becoming inf/NaN.

        Raises:
            RuntimeError: if ``fit_normalization`` has not been called.
        """
        if self.mean is None or self.std is None:
            raise RuntimeError(
                "Normalization statistics are not fitted; call fit_normalization() first (or use fit=True)."
            )
        safe_std = np.where(self.std > 0, self.std, 1.0)
        return (samples - self.mean) / safe_std
    




if __name__ == "__main__":
    # Example usage: embed the selected datasets in 2-D and save a scatter plot
    # coloured by dataset.

    # Keys into `dataset_paths` to visualize; uncomment entries to include them.
    visualize_list = [
        # "Pedipulation Init",
        # "Velocity Init",
        # "Pedipulation Training",
        # "Velocity Training",
        # # "Pedipulation Expert",
        # # "Velocity Expert",
        # # "RND",
        # # "INV Ensemble Exploration",
        # # "RLE exploration",
        # "Pedipulation Init (no random)",

        # "Pedi 100",
        # "Velo 100",
        # "INV collected",

        "Pedi 100 Absolute",
        "Velo 100 Absolute",
        "Exploration Flat",
    ]

    # Legend label for each entry of `visualize_list`, matched by position.
    label_list = [
        "Pedipulation (First 100 RL Iterations)",
        "Velocity (First 100 RL Iterations)",
        "Exploration-Based Data Collection (Flat)",
    ]

    # zip() would silently drop unmatched entries, so require equal lengths.
    if len(visualize_list) != len(label_list):
        raise ValueError(
            f"visualize_list has {len(visualize_list)} entries but label_list has {len(label_list)}"
        )

    dataset_paths_to_visualize = {
        label: dataset_paths[key] for key, label in zip(visualize_list, label_list)
    }

    # Samples drawn per dataset. Entries may be keyed by dataset key or by label;
    # datasets without an entry fall back to the visualizer's default (500).
    sample_num = {
        "Pedipulation Init": 500,
        "Velocity Init": 500,
        "Pedipulation Training": 300,
        "Velocity Training": 300,

        "RND": 500,
        "INV Ensemble Exploration": 500,
        "RLE exploration": 500,
        "Pedipulation Expert": 500,
        "Velocity Expert": 500,

        "Pedipulation Init (no random)": 500,

        "Pedi 100 Absolute": 300,
        "Velo 100 Absolute": 300,
        "INV collected Absolute": 2400,

        "Exploration-Based Data Collection (Flat)": 2400,
    }

    # Hue colour per dataset, keyed by dataset key or by label.
    palette = {
        "Pedipulation Init": "#FF6600",
        "Velocity Init": "#00ffc8",
        "Pedipulation Expert": "#B42F2F",
        "Velocity Expert": "#6f24e9",
        "RND": "#0EC70E",
        "INV Ensemble Exploration": "#D165C3",
        "RLE exploration": "#D3B820",

        "Pedipulation Training": "#CA2610",
        "Velocity Training": "#187FD4",

        "Pedipulation Init (no random)": "#6ED32B",

        "Pedi 100": "#FF5733",
        "Velo 100": "#33FF57",
        "INV collected": "#3357FF",
    }

    def _remap_to_labels(table):
        """Re-key ``table`` by plot label.

        The visualizer looks entries up by the keys of
        ``dataset_paths_to_visualize`` (the labels), so entries keyed by dataset
        key would otherwise be ignored. A dataset-key entry wins over a label entry.
        """
        remapped = {}
        for key, label in zip(visualize_list, label_list):
            if key in table:
                remapped[label] = table[key]
            elif label in table:
                remapped[label] = table[label]
        return remapped

    samples_per_label = _remap_to_labels(sample_num)
    palette_per_label = _remap_to_labels(palette)
    # A dict palette must cover every hue level or seaborn raises, so only use
    # it when every plotted label has a colour.
    dataset_palette = palette_per_label if len(palette_per_label) == len(label_list) else None

    fit = True  # False: reuse a previously pickled, already fitted visualizer

    if fit:
        visualizer = InvSamplesVisualization(n_neighbors=30, min_dist=0.99, random_seed=86, use_PCA=False)
        visualizer.visualize_samples_hue_dataset(
            dict_dataset_paths=dataset_paths_to_visualize,
            vis_samples_per_dataset=samples_per_label,
            fit=True,
            dataset_palette=None,
        )
        # with open('logs/analysis/umap_visualizer/velo_pedi_dataset_visualizer.pkl', 'wb') as f:
        #     pickle.dump(visualizer, f)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
        with open('logs/analysis/umap_visualizer/velo_pedi_dataset_visualizer.pkl', 'rb') as f:
            visualizer: InvSamplesVisualization = pickle.load(f)
        visualizer.visualize_samples_hue_dataset(
            dict_dataset_paths=dataset_paths_to_visualize,
            vis_samples_per_dataset=samples_per_label,
            fit=False,
            dataset_palette=dataset_palette,
        )