if __name__ == "__main__":
    # Only when this file is executed directly as a script: extend the module
    # search path before the project imports further down are evaluated.
    import sys
    import os
    import pathlib

    # The directory two levels above this file's own directory. Presumably
    # this is the repository root, so that `diffusion_policy` can be imported
    # without installing the package -- TODO confirm against the repo layout.
    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    # Appended (not prepended), so entries already on sys.path take priority.
    sys.path.append(ROOT_DIR)


import os
import click
import h5py
import pickle
import copy
import zarr
import pathlib
import numpy as np
from tqdm import tqdm
from tf_agents.environments.wrappers import TimeLimit
from tf_agents.environments.gym_wrapper import GymWrapper
from tf_agents.trajectories.time_step import StepType
from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
from collect_square_exp_data import save_agentview_video

import pickle
import io
import h5py
import diffusion_policy.env.block_pushing  # Ensure the module is imported

class CustomUnpickler(pickle.Unpickler):
    """Unpickler that redirects ``block_pushing`` module paths.

    Any global whose module name starts with ``"block_pushing"`` is looked up
    under ``"diffusion_policy.env." + module`` instead; every other module
    name is resolved unchanged.

    NOTE: this class only rewrites module names. It does not restrict which
    globals may be loaded, and unpickling can execute arbitrary code, so it
    should only be used on trusted data.
    """

    def find_class(self, module, name):
        """Resolve the global ``name`` from ``module``, remapping if needed.

        Args:
            module: Module name recorded in the pickle stream.
            name: Global name recorded in the pickle stream. For pickle
                protocol 4+ this may be a dotted qualified name.

        Returns:
            The object found under the (possibly remapped) module.

        Raises:
            ModuleNotFoundError: If the (possibly remapped) module cannot be
                imported. The original import error is attached as the cause.
            AttributeError: If the module does not provide ``name``.
        """
        # Debug trace of every global the pickle stream asks for.
        print(f"Original module = {module}, class = {name}")

        # Rewrite module name by prepending "diffusion_policy.env."
        new_module = "diffusion_policy.env." + module if module.startswith("block_pushing") else module

        try:
            # Delegate to the stock lookup instead of a flat getattr(): it
            # resolves dotted qualified names (protocol 4+) and applies the
            # standard Python 2 name fix-ups for old-protocol pickles.
            return super().find_class(new_module, name)
        except ImportError as err:
            raise ModuleNotFoundError(f"Could not remap {module} to {new_module}") from err

def load_teleoperation_data(hdf5_file_path):
    """
    Read every recorded episode from an HDF5 file and unpickle its states.

    Top-level groups are visited in ascending order of the integer that
    follows the first underscore in their name; the entries inside each group
    are visited in ascending order of their name interpreted as an integer.
    The 'state' dataset of every entry is deserialized with CustomUnpickler.

    NOTE: the states are pickles, and unpickling can execute arbitrary code,
    so only files from a trusted source should be loaded.

    Returns a list with one element per episode; each element is a list of
    ``{'state': <unpickled object>}`` dicts, one per step.
    """
    def _episode_number(group_name):
        # The integer between the first and second underscore of the name.
        return int(group_name.split('_')[1])

    def _read_state(step):
        # Wrap the raw dataset value in a file-like object for the unpickler.
        raw = step['state'][()]
        return CustomUnpickler(io.BytesIO(raw)).load()

    episodes = []
    with h5py.File(hdf5_file_path, "r") as h5:
        for episode_name in sorted(h5.keys(), key=_episode_number):
            episode = h5[episode_name]
            ordered_steps = sorted(episode.keys(), key=int)
            episodes.append(
                [{'state': _read_state(episode[step_name])} for step_name in ordered_steps]
            )
    return episodes
