# P4RL: this is a data sample visualization script. 

import seaborn as sns
import torch
import pandas as pd
import seaborn as sns
import umap
from typing import Dict
import numpy as np
from sklearn.decomposition import PCA
import pickle
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset
import os
import random


def sns_jointplot(embedding, colors, title, palette=None, save_dir="./logs/analysis/plots/", color_notation="Dataset"):
    """Draw a seaborn joint scatter plot of a 2-D embedding and save it as a PNG.

    Args:
        embedding: Two-column array-like; the columns become "Dim 1" and "Dim 2".
        colors: Per-row values stored in the ``color_notation`` column and used as hue.
        title: Used only to derive the output file name (spaces become underscores);
            it is not drawn on the figure.
        palette: Seaborn palette for the hue; falls back to "Set1" when falsy.
        save_dir: Directory the PNG is written to. Created if it does not exist.
        color_notation: Column name (and legend title) for the hue values.
    """
    # Convert to DataFrame for Seaborn
    embedding_df = pd.DataFrame(embedding, columns=["Dim 1", "Dim 2"])
    embedding_df[color_notation] = colors

    # Plot using sns.jointplot
    plot = sns.jointplot(
        data=embedding_df,
        x="Dim 1",
        y="Dim 2",
        hue=color_notation,
        kind="scatter",
        palette=palette if palette else "Set1",
        edgecolor="black",
        linewidth=0,  # zero width, so the marker edge is not actually drawn
        s=20,  # Size of the points
        alpha=0.5,
    )

    plot.figure.tight_layout()  # first tighten layout
    plot.figure.subplots_adjust(top=0.95)  # then leave a small margin at the top edge

    # Make sure the target directory exists; savefig does not create it.
    # An empty save_dir means "current directory", which needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Join with os.path.join so save_dir works with or without a trailing slash.
    fig_name = os.path.join(save_dir, title.replace(" ", "_") + ".png")
    plot.figure.savefig(fig_name, dpi=500)
    print(f"Plot saved as '{fig_name}'")


# Get the representation of an on-disk dataset by forwarding it through the trained state-delta encoding layer.

def get_data_samples_pretrain(path_trained_model: str, dataset_path: str, sample_num: int):
    """Encode sampled state deltas from a dataset with a trained inverse-dynamics model.

    Builds an ``InvDynamicsMLP`` from the hard-coded config below, loads the weights
    at ``path_trained_model``, samples windows from the dataset at ``dataset_path``
    and passes the difference between the last two entries of each window through
    ``model.state_delta_encoder``.

    Args:
        path_trained_model: Path to a state dict loadable by ``torch.load``.
        dataset_path: Path passed to ``DynamicSlidingWindowDataset`` as ``h5_path``.
        sample_num: Number of samples requested from the dataset.

    Returns:
        The encoder output as a NumPy array on the CPU.
    """
    inv_dynamics_cfg = {
        "class_name": "InvDynamicsMLP",
        # NOTE(review): the original comment read "33 + 9 (contact booleans)" while
        # the value is 33 -- confirm which state dimension the checkpoint expects.
        "dim_states": 33,
        "dim_actions": 12,
        "representation_dim": 256,
        "hidden_dims": [512, 256, 128],
        "mode": "inv",  # alternatives seen in the original comments: "fwd", "dl"
        "lstm_core": False,  # True for LSTM, False for MLP
        "activation_name": "elu",  # or "siren"
        "input_timesteps": 5,
    }
    # Reference the imported class directly instead of eval()-ing its name; the
    # full config (including "class_name") is still forwarded as keyword arguments.
    model = InvDynamicsMLP(device="cuda", **inv_dynamics_cfg)
    # map_location keeps loading independent of the device the checkpoint was saved from.
    model.load_state_dict(torch.load(path_trained_model, map_location=model.device))

    dataset = DynamicSlidingWindowDataset(
        h5_path=dataset_path,
        window_size=inv_dynamics_cfg["input_timesteps"],
        load_into_memory=True,
    )
    # Fixed seed so the same entries are drawn on every run.
    x, a = dataset.get_sample_entries_in_file(sample_num=sample_num, seed=42)

    # x is indexed as a 3-D tensor; the second axis is presumably the time window
    # (TODO confirm), so this is the delta between the last two steps.
    delta_x = x[:, -1, :] - x[:, -2, :]

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        samples_embedding = model.state_delta_encoder(delta_x.to(model.device)).cpu().numpy()
    return samples_embedding

def get_data_samples_RL_from_disk(path: str, buffer_interval: int, num_samples: int = 500):
    """Randomly sample rows from every ``buffer_interval``-th ``.pt`` file in ``path``.

    Files are sorted by name, every ``buffer_interval``-th one is selected, and
    ``num_samples // len(selected_files)`` rows are drawn without replacement from
    each (fewer when a file holds fewer rows than that quota).

    Args:
        path: Directory containing ``.pt`` files, each holding a single tensor
            whose first dimension is the batch dimension.
        buffer_interval: Stride used when selecting files from the sorted list.
        num_samples: Total number of rows requested across all selected files.
            Because of the integer division and the per-file clamp, the result
            may contain fewer rows.

    Returns:
        A NumPy array with the sampled rows concatenated along the first axis.

    Raises:
        FileNotFoundError: If no ``.pt`` file is selected from ``path``.
        ValueError: If a selected file does not contain a tensor.
    """
    # Step 1: Get sorted list of .pt files (assumes filenames sort in the intended order)
    files = sorted(f for f in os.listdir(path) if f.endswith(".pt"))
    print(f"Found {len(files)} files in {path}.")

    # Step 2: Select every buffer_interval-th file
    selected_files = files[::buffer_interval]
    if not selected_files:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise FileNotFoundError(f"No .pt files found in {path}.")

    num_entry_to_read_per_file = num_samples // len(selected_files)

    collected_samples = []
    for file_name in selected_files:
        file_path = os.path.join(path, file_name)

        # Load onto the CPU regardless of the device the tensor was saved from,
        # since the result is converted to NumPy anyway.
        data_tensor = torch.load(file_path, map_location="cpu")
        if not torch.is_tensor(data_tensor):
            raise ValueError(f"File {file_name} does not contain a tensor.")

        batch_size = data_tensor.shape[0]

        # Step 3: Randomly select up to the per-file quota; clamp so that small
        # files do not make random.sample raise.
        num_to_take = min(num_entry_to_read_per_file, batch_size)
        indices = random.sample(range(batch_size), num_to_take)
        collected_samples.append(data_tensor[indices])

    # Step 4: Concatenate and return
    result_tensor = torch.cat(collected_samples, dim=0)
    print(f"Collected {result_tensor.shape[0]} samples from {len(selected_files)} files.")
    # detach() guards against tensors that were saved with requires_grad=True.
    return result_tensor.detach().cpu().numpy()
    
# plot function

def plot_comparison():
    """Project pretraining embeddings and RL latents with UMAP and save a joint plot.

    Pretraining samples get hue value 0; RL samples get hue values that increase
    linearly with their index from 0.5 towards 1.0. Both sets are concatenated,
    reduced to 2-D with UMAP, and handed to ``sns_jointplot``.
    """
    # Embeddings produced by the trained state-delta encoder on the on-disk dataset.
    pretrain_embeddings = get_data_samples_pretrain(
        path_trained_model="p4rl_assets/inv_dynamics_new/absolute_0811_pedi_output_clamped.pt",
        dataset_path=dataset_paths["Pedipulation Init (Absolute, Noise)"],
        sample_num=1000,
    )

    # Latent samples dumped to disk, read from every 100th file.
    rl_latents = get_data_samples_RL_from_disk(
        "logs/analysis/high_level_policy_search_latent",
        buffer_interval=100,
        num_samples=2000,
    )

    # Hue values: a constant for the pretraining set, an index-based ramp for RL.
    n_rl = rl_latents.shape[0]
    pretrain_hue = np.array([0] * pretrain_embeddings.shape[0])
    rl_hue = np.arange(n_rl) * 1.0 / n_rl * 0.5 + 0.5
    hue_values = np.concatenate([pretrain_hue, rl_hue], axis=0)

    all_points = np.concatenate((pretrain_embeddings, rl_latents), axis=0)

    # Fixed random_state for the UMAP projection.
    umap_kwargs = dict(n_neighbors=40, min_dist=0.2, random_state=42)
    projection = umap.UMAP(**umap_kwargs).fit_transform(all_points)

    sns_jointplot(
        projection,
        hue_values,
        title="UMAP visualizations of high level policy search space distribution of INV Dynamics",
        palette="rainbow",
        save_dir="./logs/analysis/plots/",
    )

# Script entry point: generate and save the comparison plot when run directly.
if __name__ == "__main__":
    plot_comparison()