import torch
import os
import einops
from tqdm import tqdm
import h5py
from torch.utils.data import Dataset
import seaborn as sns
import umap
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn.utils.rnn import pad_sequence


def read_transtions_sequence_from_disk(path):
    """Load one saved transitions chunk and unpack it.

    Args:
        path: File path passed straight to ``torch.load``; the loaded object
            must be a mapping with the keys ``"observations"``, ``"actions"``
            and ``"dones"``.

    Returns:
        Tuple ``(observations, actions, dones)`` exactly as stored.

    Raises:
        KeyError: if one of the three keys is missing.
    """
    payload = torch.load(path)
    observations = payload["observations"]
    actions = payload["actions"]
    dones = payload["dones"]
    return observations, actions, dones


def generate_trajectories_dataset(transitions_dir, maximum_trajectory_length: int, dataset_file_name):
    """
    Stitch per-file transition chunks into complete trajectories and store each one as an HDF5 group.

    Every file in ``transitions_dir`` is loaded with ``read_transtions_sequence_from_disk`` and is
    expected to contain ``observations`` [T, num_envs, obs_dim], ``actions`` [T, num_envs, act_dim]
    and ``dones`` [T, num_envs, 1] (or [T, num_envs]). For every env, steps are accumulated across
    files; at each step where ``dones`` is non-zero, the accumulated steps (including the done step)
    are written as group ``sample_{k}`` with float32 datasets ``obs_traj`` [len, obs_dim] and
    ``act_traj`` [len, act_dim]. Steps after the last done of the last file are not written.

    NOTE: The transition files must be consecutive in sorted filename order!
    NOTE: The output file is opened in append mode and sample numbering starts at 0, so writing
    into a file that already contains ``sample_0`` ... will fail in ``create_group``.

    Args:
        transitions_dir: Directory containing only transition files.
        maximum_trajectory_length: Capacity (in steps) of each env's trajectory buffer.
        dataset_file_name: Path of the HDF5 file to write.

    Raises:
        ValueError: if ``transitions_dir`` is empty, or a trajectory grows longer than
            ``maximum_trajectory_length``.
    """
    files = sorted(os.listdir(transitions_dir))
    if not files:
        raise ValueError(f"No transition files found in {transitions_dir}")

    # Read the first file only to infer the buffer dimensions.
    obs_it, actions_it, _ = read_transtions_sequence_from_disk(os.path.join(transitions_dir, files[0]))
    _, num_envs, obs_dim = obs_it.shape
    act_dim = actions_it.shape[-1]

    obs_traj_buffer = torch.zeros((num_envs, maximum_trajectory_length, obs_dim), dtype=torch.float32)  # [num_envs, max_trajectory_length, obs_dim]
    act_traj_buffer = torch.zeros((num_envs, maximum_trajectory_length, act_dim), dtype=torch.float32)  # [num_envs, max_trajectory_length, act_dim]
    # Number of steps of the current (unfinished) trajectory already stored per env.
    pointers = [0] * num_envs

    sample_counter = 0

    def _append(env, obs_seg, act_seg):
        """Append a [len, dim] segment to env's trajectory buffer, failing loudly on overflow."""
        start = pointers[env]
        stop = start + obs_seg.shape[0]
        if stop > maximum_trajectory_length:
            raise ValueError(
                f"Trajectory of env {env} reached {stop} steps, exceeding "
                f"maximum_trajectory_length={maximum_trajectory_length}"
            )
        obs_traj_buffer[env, start:stop].copy_(obs_seg)
        act_traj_buffer[env, start:stop].copy_(act_seg)
        pointers[env] = stop

    with h5py.File(dataset_file_name, "a") as f:
        for file in tqdm(files, desc="Stitching trajectories from files:", unit="file"):
            path = os.path.join(transitions_dir, file)
            obs_it, actions_it, dones_it = read_transtions_sequence_from_disk(path)  # [num_transitions, num_envs, dim]

            # Switch to env-major layout: [num_envs, num_transitions, dim].
            obs_it = obs_it.permute(1, 0, 2)
            actions_it = actions_it.permute(1, 0, 2)
            # Accept dones as [T, num_envs] or [T, num_envs, 1]; result is [num_envs, T].
            done_mask = dones_it.reshape(dones_it.shape[0], dones_it.shape[1]).t()

            for env in range(num_envs):
                start = 0
                for done_idx in torch.nonzero(done_mask[env]).flatten().tolist():
                    # The done step itself is the last step of the trajectory it terminates.
                    _append(env, obs_it[env, start:done_idx + 1], actions_it[env, start:done_idx + 1])
                    length = pointers[env]

                    # Dump the complete trajectory to file.
                    grp = f.create_group(f'sample_{sample_counter}')
                    grp.create_dataset('obs_traj', data=obs_traj_buffer[env, :length].cpu().numpy())
                    grp.create_dataset('act_traj', data=act_traj_buffer[env, :length].cpu().numpy())
                    sample_counter += 1

                    # Start a fresh trajectory right after the done step.
                    pointers[env] = 0
                    start = done_idx + 1

                # Carry this env's unfinished tail over to the next file (uses the
                # current file's length, so files need not all be the same size).
                _append(env, obs_it[env, start:], actions_it[env, start:])

    print(f"[INFO] Dataset saved to {dataset_file_name}")



class DynamicsTrajectoryDataset(Dataset):
    """Map-style dataset of variable-length trajectories stored in an HDF5 file.

    Every top-level group of the file is one sample and must contain the
    datasets ``obs_traj`` and ``act_traj``. The file is opened only inside
    ``with`` blocks (once in ``__init__`` to list groups, then once per
    ``__getitem__`` call), so the dataset object keeps no open HDF5 handle.
    """

    def __init__(self, h5_path):
        """
        Args:
            h5_path (str): Path to the HDF5 file.
        """
        self.h5_path = h5_path
        # Snapshot of the top-level group names; position in this list is the sample index.
        with h5py.File(h5_path, 'r') as h5_file:
            self.sample_names = [name for name in h5_file]

    def __len__(self):
        """Number of samples indexed at construction time."""
        return len(self.sample_names)

    def __getitem__(self, idx):
        """Return ``(obs_traj, act_traj)`` of sample ``idx`` as float32 tensors."""
        group_name = self.sample_names[idx]
        with h5py.File(self.h5_path, 'r') as h5_file:
            group = h5_file[group_name]
            obs_traj, act_traj = (
                torch.from_numpy(group[key][...]).float() for key in ('obs_traj', 'act_traj')
            )
        return obs_traj, act_traj
    
def obs_dim_indexing(obs):
    """
    Keep only the most recent history frame of selected slices of a stacked observation.

    The layout is hard-coded: the feature slices [0:72) (``jp``), [144:162) (``lin_vel``),
    [162:180) (``ang_vel``) and [180:198) (``grav``) each hold ``history_length_in_one_obs``
    (= 6) stacked frames. From each slice, the last frame is kept, and the results are
    concatenated. The [72:144) slice (``a``) and any features past index 198 are not included.

    Args:
        obs: Tensor of shape [..., D] with D >= 198 (any number of leading dims).

    Returns:
        Tensor of shape [..., 21] (12 + 3 + 3 + 3 features).

    Raises:
        ValueError: if the last dimension has fewer than 198 features.
    """
    history_length_in_one_obs = 6
    if obs.shape[-1] < 198:
        raise ValueError(f"obs_dim_indexing expects at least 198 features, got {obs.shape[-1]}")
    lead_shape = obs.shape[:-1]

    def _last_frame(start, stop):
        # [..., H * d] -> [..., H, d] -> last (most recent) frame [..., d].
        # An explicit frame size (not -1) keeps zero-sized batches reshapeable.
        frame_dim = (stop - start) // history_length_in_one_obs
        return obs[..., start:stop].reshape(*lead_shape, history_length_in_one_obs, frame_dim)[..., -1, :]

    return torch.cat(
        [_last_frame(0, 72), _last_frame(144, 162), _last_frame(162, 180), _last_frame(180, 198)],
        dim=-1,
    )

    


def collate_variable_length(batch):
    """
    Pad the obs_traj and act_traj sequences of a batch; return them with a validity mask and lengths.

    NOTE: the padded observations are passed through ``obs_dim_indexing``, so the returned
    observation feature size is the one produced by that function, not the stored one.

    Args:
        batch: Non-empty list of ``(obs_traj, act_traj)`` tuples, each tensor shaped [T_i, dim].
            Within a sample, obs_traj and act_traj must have the same length T_i.

    Returns:
        obs_padded: [B, T_max, dim_obs] (after ``obs_dim_indexing``)
        act_padded: [B, T_max, dim_act]
        mask: [B, T_max] (True = valid, False = pad)
        lengths: [B] original lengths

    Raises:
        ValueError: if the batch is empty or a sample's obs/act lengths differ.
    """
    if not batch:
        raise ValueError("collate_variable_length received an empty batch")

    obs_trajs, act_trajs = zip(*batch)  # Each is a tuple of tensors

    # A length mismatch would silently misalign the padded obs/act tensors and the mask.
    mismatched = [i for i, (o, a) in enumerate(zip(obs_trajs, act_trajs)) if o.shape[0] != a.shape[0]]
    if mismatched:
        raise ValueError(f"obs_traj and act_traj lengths differ for batch items {mismatched}")

    lengths = torch.tensor([t.shape[0] for t in obs_trajs], dtype=torch.long)
    max_len = int(lengths.max())

    obs_padded = obs_dim_indexing(pad_sequence(obs_trajs, batch_first=True))  # [B, T_max, dim_obs]
    act_padded = pad_sequence(act_trajs, batch_first=True)  # [B, T_max, dim_act]

    # True where the time step lies inside the original (unpadded) trajectory.
    mask = torch.arange(max_len)[None, :] < lengths[:, None]  # [B, T_max]

    return obs_padded, act_padded, mask, lengths


if __name__ == "__main__":
    # To (re)build the HDF5 file from raw transition chunks first, call:
    #   generate_trajectories_dataset(
    #       transitions_dir="logs/rsl_rl/pedipulation_EAC_baseline_rel/2025-06-11_09-39-25/transitions",
    #       maximum_trajectory_length=300,
    #       dataset_file_name="logs/datasets/dynamics_rel/pedi_data_only_init.h5")

    # Smoke test: iterate once over the dataset and print each batch's shapes.
    loader = DataLoader(
        DynamicsTrajectoryDataset('logs/datasets/dynamics_rel/pedi_data_only_init.h5'),
        batch_size=128,
        shuffle=True,
        collate_fn=collate_variable_length,
    )
    for batch in loader:
        obs_batch, act_batch, mask, lengths = batch
        print(obs_batch.shape, act_batch.shape, mask.shape, lengths)