import torch
import os
import einops
from tqdm import tqdm
import h5py
from torch.utils.data import Dataset
import seaborn as sns
import umap
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split


def read_transtions_sequence_from_disk(path):
    """
        Load one transitions buffer that was written with ``torch.save``.

        args:
            path: str, path to the saved file; it must hold a mapping with the keys
                "observations", "actions" and "dones"

        returns:
            tuple (observations, actions, dones), in that order
    """
    buffer = torch.load(path)
    wanted_keys = ("observations", "actions", "dones")
    return tuple(buffer[key] for key in wanted_keys)

def generate_dynamics_dataset(transitions_dir, sequence_timesteps: int, save_dir: str = "logs/datasets/dynamics/pedipulation_vanilla_EAC/"):
    """
        Generate dataset for dynamics model training from transitions stored on disk. This includes checking on the
        dones and splitting the trajectories into segments.

        Every file of ``transitions_dir`` (in sorted name order) is loaded with ``read_transtions_sequence_from_disk``
        and written to ``save_dir/it_XXX.h5``, where XXX is the position of the file in the sorted listing. Each h5 file
        holds ``obs_seq_tensor`` [num_segments, sequence_timesteps, obs_dim] and
        ``actions_seq_tensor`` [num_segments, sequence_timesteps, action_dim]; segments are ordered env-major
        (all windows of env 0, then env 1, ...).

        A window starting at timestep k is dropped when any done flag in its first ``sequence_timesteps - 1`` steps is
        set, because it would stitch two episodes together; a done on the window's last step is allowed.

        args:
            transitions_dir: str, path to the directory containing the transitions
            sequence_timesteps: int, number of timesteps in each segment. Should be past_timesteps + future_timestep (usually 1)
            save_dir: str, path to save the dataset

        raises:
            ValueError: if sequence_timesteps < 1
    """
    if sequence_timesteps < 1:
        raise ValueError(f"sequence_timesteps must be >= 1, got {sequence_timesteps}")

    os.makedirs(save_dir, exist_ok=True)

    for i, file in enumerate(tqdm(sorted(os.listdir(transitions_dir)), desc="Loading transitions", unit="file")):
        path = os.path.join(transitions_dir, file)
        obs_it, actions_it, dones_it = read_transtions_sequence_from_disk(path)

        # [num_transitions, num_envs, dim] -> [num_envs, num_transitions, dim]
        obs_it, actions_it, dones_it = obs_it.permute(1, 0, 2), actions_it.permute(1, 0, 2), dones_it.permute(1, 0, 2)

        # number of window start positions per trajectory (<= 0 when trajectories are too short)
        num_windows = obs_it.shape[1] - (sequence_timesteps - 1)
        obs_iter_i_list, actions_iter_i_list = [], []

        # iterate over the trajectories of the individual envs
        for obs, actions, dones in zip(obs_it, actions_it, dones_it):
            for k in range(num_windows):
                if dones[k:k + sequence_timesteps - 1].any():
                    # a done before the last timestep means this segment is stitched and must be skipped
                    continue
                obs_iter_i_list.append(obs[k:k + sequence_timesteps])
                actions_iter_i_list.append(actions[k:k + sequence_timesteps])

        if obs_iter_i_list:
            obs_seq = torch.stack(obs_iter_i_list, dim=0)
            actions_seq = torch.stack(actions_iter_i_list, dim=0)
        else:
            # No valid segment in this buffer: torch.stack would raise on an empty list. Write empty datasets
            # instead so the it_XXX numbering stays aligned with the input files.
            obs_seq = obs_it.new_empty((0, sequence_timesteps, obs_it.shape[-1]))
            actions_seq = actions_it.new_empty((0, sequence_timesteps, actions_it.shape[-1]))

        # Save dataset
        with h5py.File(os.path.join(save_dir, f"it_{i:03d}.h5"), "w") as f:
            f.create_dataset("obs_seq_tensor", data=obs_seq.cpu().numpy())
            f.create_dataset("actions_seq_tensor", data=actions_seq.cpu().numpy())

    print(f"[INFO] Dataset saved to {save_dir}")



def _split_history_observations(obs_flat, history_length_in_one_obs, mode):
    """
        Slice flattened 318-dim observations into per-history-step feature blocks.

        args:
            obs_flat: tensor [num_samples, 318], laid out as documented in generate_dynamics_dataset_history
            history_length_in_one_obs: int, number of history steps stacked inside every logged term
            mode: str, "only_joint_position" or "all"

        returns:
            (obs_seq, actions_seq), both [num_samples, history_length_in_one_obs, feature_dim]. obs_seq holds the
            joint positions ("only_joint_position") or joint positions | base lin vel | base ang vel | projected
            gravity concatenated on the last axis ("all"); actions_seq holds the last actions.

        raises:
            ValueError: if mode is unknown, history_length_in_one_obs < 1, or a term width is not divisible by
                history_length_in_one_obs
    """
    if history_length_in_one_obs < 1:
        raise ValueError(f"history_length_in_one_obs must be >= 1, got {history_length_in_one_obs}")

    def term(start, stop):
        # [N, H * d] -> [N, H, d]. The explicit divisibility check stops reshape from silently regrouping
        # values across samples when the term width does not match the history length.
        width = stop - start
        if width % history_length_in_one_obs != 0:
            raise ValueError(
                f"Observation term [{start}:{stop}] (width {width}) is not divisible by "
                f"history_length_in_one_obs={history_length_in_one_obs}"
            )
        return obs_flat[:, start:stop].reshape(-1, history_length_in_one_obs, width // history_length_in_one_obs)

    if mode == "only_joint_position":
        obs_seq = term(0, 72)
    elif mode == "all":
        # joint pos | base lin vel | base ang vel | projected gravity -> [N, H, 21] for H = 6
        obs_seq = torch.cat([term(0, 72), term(144, 162), term(162, 180), term(180, 198)], dim=-1)
    else:
        raise ValueError("Mode should be either 'only_joint_position' or 'all'")
    actions_seq = term(72, 144)
    return obs_seq, actions_seq


def generate_dynamics_dataset_history(transitions_dir, save_dir: str,
                                      history_length_in_one_obs: int = 6,
                                      mode="only_joint_position"  # "only_joint_position", "all"
                                      ):
    """
        Generate dataset for dynamics model training from transitions stored on disk. ATTENTION: This function assumes that
        the observation of one single timestep is 318-dim: the 270-dim logging group below (noise free, every term with a
        history of 6 steps) followed by the 48-dim policy group. Only the first 198 entries are used. Indices of the
        policy group are relative to that group, which therefore starts at 270 in the stored vector.

        ```
        # observation terms for data collection (dim_0=270), noise free
        joint_pos_log = ObsTerm(func=mdp.joint_pos_rel, history_length=6) # indices 0:72
        actions_log = ObsTerm(func=mdp.last_action, clip=(-100.00, 100.00), history_length=6) # indices 72:144
        base_lin_vel_log = ObsTerm(func=mdp.base_lin_vel, history_length=6) # indices 144:162
        base_ang_vel_log = ObsTerm(func=mdp.base_ang_vel, history_length=6) # indices 162:180
        projected_gravity_log = ObsTerm(
            func=mdp.projected_gravity,
            history_length=6
        ) # indices 180:198
        joint_vel_log = ObsTerm(func=mdp.joint_vel_rel, history_length=6) # indices 198:270

        # observations terms for the baseline policy to use (dim_1=48)
        base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1)) # indices 0:3
        base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2)) # indices 3:6
        projected_gravity = ObsTerm(
            func=mdp.projected_gravity,
            noise=Unoise(n_min=-0.05, n_max=0.05),
        ) # indices 6:9
        joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) # length 12, indices 9:21
        joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5)) # length 12, indices 21:33
        actions = ObsTerm(func=mdp.last_action, clip=(-100.00, 100.00)) # length 12, indices 33:45
        foot_tracking_commands = ObsTerm(func=mdp.foot_tracking_commands, params={"asset_cfg": SceneEntityCfg("robot", body_names=".*FOOT")}) # indices 45:48
        ```

        Every file of ``transitions_dir`` (sorted by name) produces ``save_dir/it_XXX.h5`` containing one sample per
        (env, timestep), ordered env-major:
            obs_seq_tensor: [num_envs * num_transitions, history_length_in_one_obs, 12] for mode "only_joint_position",
                or [..., 21] (joint pos | lin vel | ang vel | gravity) for mode "all"
            actions_seq_tensor: [num_envs * num_transitions, history_length_in_one_obs, 12]

        args:
            transitions_dir: str, path to the directory containing the transitions
            save_dir: str, path to save the dataset
            history_length_in_one_obs: int, number of history steps stacked inside every logged term
            mode: str, "only_joint_position" or "all"; selects which logged terms form obs_seq_tensor

        raises:
            ValueError: if mode is unknown, or a transition file does not hold 318-dim observations
    """
    # validate once up front instead of inside the per-sample loop
    if mode not in ("only_joint_position", "all"):
        raise ValueError("Mode should be either 'only_joint_position' or 'all'")

    os.makedirs(save_dir, exist_ok=True)

    for i, file in enumerate(tqdm(sorted(os.listdir(transitions_dir)), desc="Processing transition buffers", unit="file")):
        path = os.path.join(transitions_dir, file)
        # actions and dones are not needed: the logged action history is part of the observation
        obs_it, _, _ = read_transtions_sequence_from_disk(path)
        # [num_transitions, num_envs, obs_dim] -> [num_envs, num_transitions, obs_dim]
        obs_it = obs_it.permute(1, 0, 2)

        if obs_it.shape[-1] != 318:
            raise ValueError("Transition files contain unexpected shapes of data, make sure they are consistent with descriptions in the function comments!")

        # Flatten env-major so the sample order is (env 0, t 0), (env 0, t 1), ..., (env 1, t 0), ...
        # and slice all samples at once instead of one (env, timestep) pair at a time.
        obs_flat = obs_it.reshape(-1, obs_it.shape[-1])
        obs_seq, actions_seq = _split_history_observations(obs_flat, history_length_in_one_obs, mode)

        # Save dataset
        with h5py.File(os.path.join(save_dir, f"it_{i:03d}.h5"), "w") as f:
            f.create_dataset("obs_seq_tensor", data=obs_seq.contiguous().cpu().numpy())
            f.create_dataset("actions_seq_tensor", data=actions_seq.contiguous().cpu().numpy())

    print(f"[INFO] Dataset saved to {save_dir}")



def sns_jointplot(embedding, colors, title, save_dir="./logs/analysis/plots/", hue="buffer"):
    """
    Scatter a 2-D embedding with marginal distributions and save it as a PNG.

    args:
        embedding: array-like [num_points, 2], the 2-D coordinates to plot
        colors: array-like [num_points], per-point group labels used for the hue
        title: str, figure title; also used as the file name (spaces replaced by underscores, ".png" appended)
        save_dir: str, output directory; created if it does not exist
        hue: str, column name under which ``colors`` is stored and plotted (shown as the legend title)
    """
    # Convert to DataFrame for Seaborn
    import pandas as pd

    embedding_df = pd.DataFrame(embedding, columns=["Dim 1", "Dim 2"])
    embedding_df[hue] = colors

    plot = sns.jointplot(
        data=embedding_df,
        x="Dim 1",
        y="Dim 2",
        hue=hue,
        kind="scatter",
        palette="plasma",
        edgecolor="white",
        linewidth=0.1,
        alpha=0.5,
    )

    plot.fig.suptitle(title, fontsize=14)
    plot.fig.tight_layout()  # first tighten layout
    plot.fig.subplots_adjust(top=0.95)  # then move title up (1.0 = top edge of figure)

    # os.path.join works with or without a trailing separator on save_dir
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig_name = os.path.join(save_dir, title.replace(" ", "_") + ".png")
    plot.fig.savefig(fig_name, dpi=500)
    print(f"Plot saved as '{fig_name}'")



def visualize_dyna_with_UMAP(dataset_dir, iter_interval = (0, 150, 10), samples_per_iter: int = 200, n_neighbors = 2, min_dist = 0.2, random_state = 42):
    """
    Embed a subsample of a dynamics dataset with UMAP and save a 2-D joint plot of it.

    Files are considered in sorted name order. A file whose position lies in
    ``range(iter_interval[0], iter_interval[1], iter_interval[2])`` contributes up to ``samples_per_iter`` sequences of
    its ``obs_seq_tensor``, drawn without replacement. The last timestep of every sequence is dropped and the remaining
    steps are flattened into one feature vector per sample. Points are colored by the file position.

    args:
        dataset_dir: str, directory holding the h5 dataset files
        iter_interval: (start, stop, step), selects files by their position in the sorted listing
        samples_per_iter: int, maximum number of sequences drawn per selected file (capped at the file's size)
        n_neighbors: int, UMAP ``n_neighbors``
        min_dist: float, UMAP ``min_dist``
        random_state: seed for UMAP and for the per-file subsampling (None -> non-deterministic)

    raises:
        ValueError: if no file of dataset_dir falls inside iter_interval
    """
    # seeded generator so that the subsample, like the UMAP fit, is reproducible from random_state
    rng = np.random.default_rng(random_state)
    selected_iters = range(iter_interval[0], iter_interval[1], iter_interval[2])

    obs_list = []
    colors_list = []

    for i, file in enumerate(sorted(os.listdir(dataset_dir))):
        if i not in selected_iters:
            continue
        with h5py.File(os.path.join(dataset_dir, file), "r") as f:
            obs_t = np.array(f["obs_seq_tensor"])

        num_seq = obs_t.shape[0]
        # sampling without replacement cannot draw more rows than the file has
        num_samples = min(samples_per_iter, num_seq)
        indices = rng.choice(num_seq, size=num_samples, replace=False)

        obs_list.append(obs_t[indices])
        colors_list.append(np.full(num_samples, i, dtype=np.int16))

    if not obs_list:
        raise ValueError(f"No file in {dataset_dir} falls inside iter_interval={iter_interval}")

    # [num_samples, seq_len, dim] -> drop the last step -> [num_samples, (seq_len - 1) * dim]
    # (reshaping by the sample count keeps one row per sample for any seq_len / feature dim)
    obs_all = np.concatenate(obs_list, axis=0)[:, :-1]
    input_tensor = obs_all.reshape(obs_all.shape[0], -1)
    colors_tensor = np.concatenate(colors_list, axis=0)

    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )
    embedding = reducer.fit_transform(input_tensor)

    # Plot using Seaborn
    title = "input space of dynamics dataset (pedipulation)"
    sns_jointplot(embedding, colors_tensor, title)
    


class DynamicsDataset(Dataset):
    """
    Map-style dataset over a directory of h5 files.

    Every file in ``dataset_root_dir`` (taken in sorted name order) must contain the datasets ``obs_seq_tensor`` and
    ``actions_seq_tensor`` with the same number of rows. Global sample indices run through the files in that order,
    and ``ds[idx]`` returns ``[obs_seq, actions_seq]`` as float32 tensors.
    """

    def __init__(self, dataset_root_dir):
        """
        args:
            dataset_root_dir: str, directory whose entries are all h5 files containing the keys in ``dataset_key``
        """
        self.h5_paths = [os.path.join(dataset_root_dir, file_name) for file_name in sorted(os.listdir(dataset_root_dir))]
        self.dataset_key = ["obs_seq_tensor", "actions_seq_tensor"]

        self.file_lens = []  # number of items per file
        for path in self.h5_paths:
            with h5py.File(path, 'r') as f:
                self.file_lens.append(f[self.dataset_key[0]].shape[0])

        # Cumulative end offset of every file: file i covers global indices [_cum_lens[i - 1], _cum_lens[i]).
        # This replaces a per-sample list of (file_idx, local_idx) tuples, so memory is O(num_files) instead of
        # O(num_samples). That matters because every DataLoader worker holds its own copy of the dataset.
        self._cum_lens = np.cumsum(self.file_lens, dtype=np.int64)

        self.h5_files = None  # opened lazily, once per process (i.e. per DataLoader worker)
        self._files_pid = None  # pid of the process that opened self.h5_files

    @property
    def file_index_map(self):
        """(file_idx, local_idx) for every global index; built on demand, kept for backward compatibility."""
        return [(file_idx, i) for file_idx, length in enumerate(self.file_lens) for i in range(length)]

    def __len__(self):
        return int(self._cum_lens[-1]) if len(self._cum_lens) else 0

    def _locate(self, idx):
        # range() indexing normalises negative indices, accepts any __index__ type and raises IndexError when out
        # of range. This mirrors plain list indexing, and legacy iteration over the dataset relies on that IndexError.
        idx = range(len(self))[idx]
        # side="right" skips files of length 0, whose end offset equals their start offset
        file_idx = int(np.searchsorted(self._cum_lens, idx, side="right"))
        start = int(self._cum_lens[file_idx - 1]) if file_idx > 0 else 0
        return file_idx, idx - start

    def _init_files(self):
        # Each process gets its own set of file handles
        self.h5_files = [h5py.File(p, 'r') for p in self.h5_paths]
        self._files_pid = os.getpid()

    def _get_files(self):
        # Handles opened in another process (e.g. by get_sample_entries_in_file in the main process before
        # DataLoader workers were forked) must not be shared, so reopen whenever the pid differs.
        if self.h5_files is None or self._files_pid != os.getpid():
            self._init_files()
        return self.h5_files

    def _check_file_idx(self, file_idx):
        # Accept Python-style negative indices; reject anything outside [-n, n) with ValueError.
        n = len(self.file_lens)
        if not -n <= file_idx < n:
            raise ValueError(f"File index {file_idx} out of range. Max index is {n - 1}")
        return file_idx % n

    def __getitem__(self, idx):
        file_idx, local_idx = self._locate(idx)
        h5f = self._get_files()[file_idx]

        return [torch.tensor(h5f[key][local_idx], dtype=torch.float32) for key in self.dataset_key]

    def find_start_idx_of_file(self, file_idx):
        """
        Find the start index of a file in the dataset (the global index of its first entry).

        raises:
            ValueError: if file_idx is out of range
        """
        return sum(self.file_lens[:self._check_file_idx(file_idx)])

    def num_files(self):
        return len(self.h5_paths)

    def get_sample_entries_in_file(self, file_idx, sample_num=500):
        """
        Draw ``sample_num`` distinct random entries (without replacement) from one file.

        args:
            file_idx: int, index of the file in ``h5_paths``
            sample_num: int, number of entries to draw; must be between 0 and the file's length

        returns:
            [obs_seq, actions_seq] float32 tensors with leading dimension ``sample_num``, rows in increasing
            file order

        raises:
            ValueError: if file_idx or sample_num is out of range
        """
        files = self._get_files()
        file_idx = self._check_file_idx(file_idx)
        if not 0 <= sample_num <= self.file_lens[file_idx]:
            raise ValueError(f"Sample number {sample_num} out of range. Max sample number is {self.file_lens[file_idx]}")

        h5f = files[file_idx]
        length = h5f[self.dataset_key[0]].shape[0]
        # h5py fancy indexing needs strictly increasing indices given as a NumPy array / list, not a torch tensor
        indices = torch.sort(torch.randperm(length)[:sample_num])[0].numpy()

        return [torch.tensor(h5f[key][indices], dtype=torch.float32) for key in self.dataset_key]

    def __getstate__(self):
        # Open h5py handles cannot be pickled (needed for spawn-based DataLoader workers); drop them so the
        # receiving process reopens its own lazily.
        state = self.__dict__.copy()
        state["h5_files"] = None
        state["_files_pid"] = None
        return state

    def __del__(self):
        # Clean up open file handles. Use getattr because __init__ may have failed before h5_files was set.
        files = getattr(self, "h5_files", None)
        if files:
            for f in files:
                f.close()
        self.h5_files = None


if __name__ == "__main__":
    # Manual entry point for the dataset pipeline. The other steps, kept here as usage examples:
    #
    # 1. read one transitions buffer
    #    read_transtions_sequence_from_disk("logs/rsl_rl/pedipulation_CDAC_baseline/2025-04-15_10-11-57/transitions/transitions_iter_00000.pt")
    # 2. cut transitions into fixed-length segments
    #    generate_dynamics_dataset("logs/rsl_rl/pedipulation_EAC_baseline/2025-04-15_10-42-19/transitions", 4)
    # 3. visualize a dataset with UMAP
    #    visualize_dyna_with_UMAP("logs/datasets/dynamics/pedipulation_vanilla_EAC_new/", samples_per_iter=500)
    # 4. feed a dataset through a DataLoader
    #    ds = DynamicsDataset("logs/datasets/dynamics/pedipulation_vanilla_EAC/")
    #    for obs, actions in DataLoader(ds, batch_size=32, shuffle=True, num_workers=4):
    #        print(obs.shape, actions.shape)
    #        break

    # 5. generate a dataset from observations that already contain a history
    transitions_dir = "logs/rsl_rl/pedipulation_EAC_baseline_rel/2025-06-11_09-39-25/transitions"
    output_dir = "logs/datasets/dynamics_rel/pedi_vanilla_EAC_only_initial/"
    generate_dynamics_dataset_history(
        transitions_dir,
        save_dir=output_dir,
        history_length_in_one_obs=6,
        mode="all",
    )