from abc import abstractmethod
from typing import List

import os
import gym
import numpy as np
import torch
from tensordict import TensorDict

from torchrl.data import LazyTensorStorage, LazyMemmapStorage
from grl.utils.log import log


class GPDataset(torch.utils.data.Dataset):
    """
    Overview:
        Base in-memory dataset for Generative Policy algorithms.
        Besides the real transitions, training sometimes needs "fake" actions, i.e. actions \
        sampled from the behaviour policy as a form of data augmentation. Those are attached \
        afterwards through ``load_fake_actions``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``load_fake_actions``.
    """

    def __init__(self):
        """
        Overview:
            Initialization method of GPDataset class. It sets no attributes itself; subclasses \
            are expected to fill ``states``, ``actions``, ``rewards``, ``next_states``, \
            ``is_finished`` and ``len`` before items are requested.
        """
        pass

    def __getitem__(self, index):
        """
        Overview:
            Return the transition stored at ``index``; the index is wrapped modulo the dataset length.
        Arguments:
            index (:obj:`int`): Index of data
        Returns:
            data (:obj:`dict`): Data dict

        .. note::
            The data dict contains the following keys, in this order:

            s (:obj:`torch.Tensor`): State
            a (:obj:`torch.Tensor`): Action
            r (:obj:`torch.Tensor`): Reward
            s_ (:obj:`torch.Tensor`): Next state
            d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training, \
                or the float ``0.0`` when no fake actions have been loaded
            fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training, \
                or the float ``0.0`` when no fake next actions have been loaded
        """
        # Wrap the index so that any integer maps onto a stored transition.
        i = index % self.len
        sample = {
            "s": self.states[i],
            "a": self.actions[i],
            "r": self.rewards[i],
            "s_": self.next_states[i],
            "d": self.is_finished[i],
        }
        # Fake actions are optional; a scalar placeholder keeps the same keys in every sample.
        # Per the original authors, fake_actions / fake_next_actions are laid out as <D, 16, A>.
        sample["fake_a"] = self.fake_actions[i] if hasattr(self, "fake_actions") else 0.0
        sample["fake_a_"] = (
            self.fake_next_actions[i] if hasattr(self, "fake_next_actions") else 0.0
        )
        return sample

    def __len__(self):
        """Number of stored transitions (the ``len`` attribute set by subclasses)."""
        return self.len

    def load_fake_actions(self, fake_actions, fake_next_actions):
        """
        Overview:
            Attach fake actions so that ``__getitem__`` returns them under ``fake_a`` / ``fake_a_``.
        Arguments:
            fake_actions: Indexable per-transition fake actions.
            fake_next_actions: Indexable per-transition fake next actions.
        """
        self.fake_actions, self.fake_next_actions = fake_actions, fake_next_actions

    @abstractmethod
    def return_range(self, dataset, max_episode_steps):
        """
        Overview:
            Hook for computing the (min, max) episode return of a raw dataset.
            NOTE(review): the decorator appears not to be enforced (GPCustomizedDataset does not \
            override it) — confirm whether the base class uses ABCMeta.
        """
        raise NotImplementedError

class GPTensorDictDataset(torch.utils.data.Dataset):
    """
    Overview:
        Base dataset for Generative Policy algorithms whose transitions live in a TensorDict \
        storage (``self.storage``).
        Besides the real transitions, training sometimes needs "fake" actions sampled from the \
        behaviour policy (data augmentation); ``load_fake_actions`` writes them into the storage.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``load_fake_actions``.
    """

    def __init__(self):
        """
        Overview:
            Initialization method of GPTensorDictDataset class. It sets no attributes itself; \
            subclasses are expected to provide ``storage``, ``len``, ``action_augment_num`` and \
            the transition tensors used by ``load_fake_actions``.
        """
        pass

    def __getitem__(self, index):
        """
        Overview:
            Fetch the entry at ``index`` from the underlying storage.
        Arguments:
            index (:obj:`int`): Index of data
        Returns:
            data: Whatever ``self.storage.get`` returns for that index.

        .. note::
            The stored entries contain the following keys:

            s (:obj:`torch.Tensor`): State
            a (:obj:`torch.Tensor`): Action
            r (:obj:`torch.Tensor`): Reward
            s_ (:obj:`torch.Tensor`): Next state
            d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \
                (only present when ``action_augment_num`` is truthy)
            fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \
                (only present when ``action_augment_num`` is truthy)
        """
        return self.storage.get(index=index)

    def __len__(self):
        """Number of stored transitions (the ``len`` attribute set by subclasses)."""
        return self.len

    def load_fake_actions(self, fake_actions, fake_next_actions):
        """
        Overview:
            Remember the given fake actions and rewrite the whole storage with them.
            The fake columns are only written when ``self.action_augment_num`` is truthy.
        Arguments:
            fake_actions (:obj:`torch.Tensor`): Fake actions, one row per transition.
            fake_next_actions (:obj:`torch.Tensor`): Fake next actions, one row per transition.
        """
        self.fake_actions = fake_actions
        self.fake_next_actions = fake_next_actions
        columns = {
            "s": self.states,
            "a": self.actions,
            "r": self.rewards,
            "s_": self.next_states,
            "d": self.is_finished,
        }
        if self.action_augment_num:
            columns["fake_a"] = self.fake_actions
            columns["fake_a_"] = self.fake_next_actions
        self.storage.set(range(self.len), TensorDict(columns, batch_size=[self.len]))

    @abstractmethod
    def return_range(self, dataset, max_episode_steps):
        """Hook for computing the (min, max) episode return of a raw dataset."""
        raise NotImplementedError


class GPD4RLDataset(GPDataset):
    """
    Overview:
        D4RL Dataset for Generative Policy algorithm.
        Loads ``d4rl.qlearning_dataset`` for an environment and rescales its rewards.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``return_range``.
    """

    def __init__(
        self,
        env_id: str,
    ):
        """
        Overview:
            Initialization method of GPD4RLDataset class.
            Rewards are rescaled as follows:
            - env ids containing ``antmaze``: ``reward - 1``;
            - all others: ``reward / (max_return - min_return) * 1000``, where the returns are \
              taken over complete episodes (capped at 1000 steps).
        Arguments:
            env_id (:obj:`str`): The environment id
        """

        super().__init__()
        import d4rl

        data = d4rl.qlearning_dataset(gym.make(env_id))
        self.states = torch.from_numpy(data["observations"]).float()
        self.actions = torch.from_numpy(data["actions"]).float()
        self.next_states = torch.from_numpy(data["next_observations"]).float()
        reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
        self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()

        # Only "iql_antmaze" and "iql_locomotion" can be selected here; the other presets
        # are kept for reference.
        reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
        if reward_tune == "normalize":
            reward = (reward - reward.mean()) / reward.std()
        elif reward_tune == "iql_antmaze":
            reward = reward - 1.0
        elif reward_tune == "iql_locomotion":
            min_ret, max_ret = GPD4RLDataset.return_range(data, 1000)
            # NOTE(review): divides by zero if every complete episode has the same return.
            reward /= max_ret - min_ret
            reward *= 1000
        elif reward_tune == "cql_antmaze":
            reward = (reward - 0.5) * 4.0
        elif reward_tune == "antmaze":
            reward = (reward - 0.25) * 2.0
        self.rewards = reward
        self.len = self.states.shape[0]
        log.info(f"{self.len} data loaded in GPD4RLDataset")

    @staticmethod
    def return_range(dataset, max_episode_steps):
        """
        Overview:
            Compute the minimum and maximum undiscounted return over the episodes of ``dataset``.
            An episode ends at a terminal flag or after ``max_episode_steps`` steps. A trailing \
            incomplete trajectory counts towards the step total but not towards the returns.
        Arguments:
            dataset (:obj:`dict`): Mapping with ``"rewards"`` and ``"terminals"`` sequences.
            max_episode_steps (:obj:`int`): Step count that closes an episode.
        Returns:
            (:obj:`tuple`): ``(min_return, max_return)``.
        Raises:
            ValueError: If the dataset contains no complete episode.
        """
        returns, lengths = [], []
        ep_ret, ep_len = 0.0, 0
        for r, d in zip(dataset["rewards"], dataset["terminals"]):
            ep_ret += float(r)
            ep_len += 1
            if d or ep_len == max_episode_steps:
                returns.append(ep_ret)
                lengths.append(ep_len)
                ep_ret, ep_len = 0.0, 0
        # The incomplete final trajectory is excluded from the returns,
        # but its steps are still counted for the consistency check below.
        lengths.append(ep_len)
        assert sum(lengths) == len(dataset["rewards"])
        if not returns:
            raise ValueError("return_range: the dataset contains no complete episode")
        return min(returns), max(returns)


class GPOnlineDataset(GPDataset):
    """
    Overview:
        Dataset for Generative Policy algorithm for online data collection.
        Transitions are appended with ``load_data`` and removed with ``drop_data``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``load_data``, ``drop_data``, ``return_range``.
    """

    def __init__(
        self,
        fake_action_shape: int = None,
        data: dict = None,
    ):
        """
        Overview:
            Initialization method of GPOnlineDataset class.
        Arguments:
            fake_action_shape (:obj:`int`): Stored as ``self.fake_action_shape``; this class does not \
                otherwise use it.
            data (:obj:`dict`): Optional initial data with numpy arrays under ``observations``, \
                ``actions``, ``next_observations``, ``rewards`` and ``terminals``. When ``None`` \
                the dataset starts empty.
        """

        super().__init__()
        self.fake_action_shape = fake_action_shape
        if data is not None:
            self.states = torch.from_numpy(data["observations"]).float()
            self.actions = torch.from_numpy(data["actions"]).float()
            self.next_states = torch.from_numpy(data["next_observations"]).float()
            self.rewards = torch.from_numpy(data["rewards"]).view(-1, 1).float()
            self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()
            self.len = self.states.shape[0]
        else:
            # Empty 1-D tensors; load_data concatenates real transitions onto them.
            self.states = torch.tensor([])
            self.actions = torch.tensor([])
            self.next_states = torch.tensor([])
            self.is_finished = torch.tensor([])
            self.rewards = torch.tensor([])
            self.len = 0
        log.debug(f"{self.len} data loaded in GPOnlineDataset")

    def drop_data(self, drop_ratio: float, random: bool = True):
        """
        Overview:
            Remove ``int(len * drop_ratio)`` transitions from the dataset.
            NOTE: fake actions attached via ``load_fake_actions`` are not resized here; reload them afterwards.
        Arguments:
            drop_ratio (:obj:`float`): Fraction of the transitions to drop.
            random (:obj:`bool`): Drop a random subset if True, otherwise drop the first \
                (earliest appended) transitions.
        """
        drop_num = int(self.len * drop_ratio)
        if random:
            drop_indices = torch.randperm(self.len)[:drop_num]
        else:
            drop_indices = torch.arange(drop_num)
        keep_mask = torch.ones(self.len, dtype=torch.bool)
        keep_mask[drop_indices] = False
        self.states = self.states[keep_mask]
        self.actions = self.actions[keep_mask]
        self.next_states = self.next_states[keep_mask]
        self.is_finished = self.is_finished[keep_mask]
        self.rewards = self.rewards[keep_mask]
        self.len = self.states.shape[0]
        log.debug(f"{drop_num} data dropped in GPOnlineDataset")

    def load_data(self, data: List):
        """
        Overview:
            Append transitions to the dataset.
            NOTE: fake actions attached via ``load_fake_actions`` are not extended here; reload them afterwards.
        Arguments:
            data (:obj:`List`): Transition dicts, each with keys ``obs``, ``action``, ``done``, \
                ``next_obs`` and ``reward``. An empty list is a no-op.
        """
        if len(data) == 0:
            # np.stack cannot stack an empty sequence; there is nothing to append anyway.
            return

        keys = ["obs", "action", "done", "next_obs", "reward"]
        collated_data = {k: torch.tensor(np.stack([item[k] for item in data])) for k in keys}

        self.states = torch.cat([self.states, collated_data["obs"].float()], dim=0)
        self.actions = torch.cat([self.actions, collated_data["action"].float()], dim=0)
        self.next_states = torch.cat(
            [self.next_states, collated_data["next_obs"].float()], dim=0
        )
        reward = collated_data["reward"].view(-1, 1).float()
        self.is_finished = torch.cat(
            [self.is_finished, collated_data["done"].view(-1, 1).float()], dim=0
        )
        self.rewards = torch.cat([self.rewards, reward], dim=0)
        self.len = self.states.shape[0]
        log.debug(f"{self.len} data loaded in GPOnlineDataset")

    @staticmethod
    def return_range(dataset, max_episode_steps):
        """
        Overview:
            Compute the minimum and maximum undiscounted return over the episodes of ``dataset``.
            An episode ends at a terminal flag or after ``max_episode_steps`` steps. A trailing \
            incomplete trajectory counts towards the step total but not towards the returns.
        Arguments:
            dataset (:obj:`dict`): Mapping with ``"rewards"`` and ``"terminals"`` sequences.
            max_episode_steps (:obj:`int`): Step count that closes an episode.
        Returns:
            (:obj:`tuple`): ``(min_return, max_return)``.
        Raises:
            ValueError: If the dataset contains no complete episode.
        """
        returns, lengths = [], []
        ep_ret, ep_len = 0.0, 0
        for r, d in zip(dataset["rewards"], dataset["terminals"]):
            ep_ret += float(r)
            ep_len += 1
            if d or ep_len == max_episode_steps:
                returns.append(ep_ret)
                lengths.append(ep_len)
                ep_ret, ep_len = 0.0, 0
        # The incomplete final trajectory is excluded from the returns,
        # but its steps are still counted for the consistency check below.
        lengths.append(ep_len)
        assert sum(lengths) == len(dataset["rewards"])
        if not returns:
            raise ValueError("return_range: the dataset contains no complete episode")
        return min(returns), max(returns)


class GPD4RLOnlineDataset(GPDataset):
    """
    Overview:
        D4RL Dataset for GP algorithm for online data collection.
        Starts from ``d4rl.qlearning_dataset``; transitions can then be appended with \
        ``load_data`` and removed with ``drop_data``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``load_data``, ``drop_data``, ``return_range``.
    """

    def __init__(
        self,
        env_id: str,
        fake_action_shape: int = None,
    ):
        """
        Overview:
            Initialization method of GPD4RLOnlineDataset class.
            Rewards are rescaled as follows:
            - env ids containing ``antmaze``: ``reward - 1``;
            - all others: ``reward / (max_return - min_return) * 1000``, where the returns are \
              taken over complete episodes (capped at 1000 steps).
        Arguments:
            env_id (:obj:`str`): The environment id
            fake_action_shape (:obj:`int`): Stored as ``self.fake_action_shape``; this class does not \
                otherwise use it.
        """

        super().__init__()
        self.fake_action_shape = fake_action_shape
        import d4rl

        data = d4rl.qlearning_dataset(gym.make(env_id))
        self.states = torch.from_numpy(data["observations"]).float()
        self.actions = torch.from_numpy(data["actions"]).float()
        self.next_states = torch.from_numpy(data["next_observations"]).float()
        reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
        self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()

        # Only "iql_antmaze" and "iql_locomotion" can be selected here; the other presets
        # are kept for reference.
        reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
        if reward_tune == "normalize":
            reward = (reward - reward.mean()) / reward.std()
        elif reward_tune == "iql_antmaze":
            reward = reward - 1.0
        elif reward_tune == "iql_locomotion":
            min_ret, max_ret = GPD4RLOnlineDataset.return_range(data, 1000)
            # NOTE(review): divides by zero if every complete episode has the same return.
            reward /= max_ret - min_ret
            reward *= 1000
        elif reward_tune == "cql_antmaze":
            reward = (reward - 0.5) * 4.0
        elif reward_tune == "antmaze":
            reward = (reward - 0.25) * 2.0
        self.rewards = reward
        self.len = self.states.shape[0]

        log.debug(f"{self.len} data loaded in GPD4RLOnlineDataset")

    def drop_data(self, drop_ratio: float, random: bool = True):
        """
        Overview:
            Remove ``int(len * drop_ratio)`` transitions from the dataset.
            NOTE: fake actions attached via ``load_fake_actions`` are not resized here; reload them afterwards.
        Arguments:
            drop_ratio (:obj:`float`): Fraction of the transitions to drop.
            random (:obj:`bool`): Drop a random subset if True, otherwise drop the first transitions.
        """
        drop_num = int(self.len * drop_ratio)
        if random:
            drop_indices = torch.randperm(self.len)[:drop_num]
        else:
            drop_indices = torch.arange(drop_num)
        keep_mask = torch.ones(self.len, dtype=torch.bool)
        keep_mask[drop_indices] = False
        self.states = self.states[keep_mask]
        self.actions = self.actions[keep_mask]
        self.next_states = self.next_states[keep_mask]
        self.is_finished = self.is_finished[keep_mask]
        self.rewards = self.rewards[keep_mask]
        self.len = self.states.shape[0]
        log.debug(f"{drop_num} data dropped in GPD4RLOnlineDataset")

    def load_data(self, data: List):
        """
        Overview:
            Append transitions to the dataset.
            NOTE: fake actions attached via ``load_fake_actions`` are not extended here; reload them afterwards.
        Arguments:
            data (:obj:`List`): Transition dicts, each with keys ``obs``, ``action``, ``done``, \
                ``next_obs`` and ``reward``. An empty list is a no-op.
        """
        if len(data) == 0:
            # np.stack cannot stack an empty sequence; there is nothing to append anyway.
            return

        keys = ["obs", "action", "done", "next_obs", "reward"]
        collated_data = {k: torch.tensor(np.stack([item[k] for item in data])) for k in keys}

        self.states = torch.cat([self.states, collated_data["obs"].float()], dim=0)
        self.actions = torch.cat([self.actions, collated_data["action"].float()], dim=0)
        self.next_states = torch.cat(
            [self.next_states, collated_data["next_obs"].float()], dim=0
        )
        reward = collated_data["reward"].view(-1, 1).float()
        self.is_finished = torch.cat(
            [self.is_finished, collated_data["done"].view(-1, 1).float()], dim=0
        )
        self.rewards = torch.cat([self.rewards, reward], dim=0)
        self.len = self.states.shape[0]
        log.debug(f"{self.len} data loaded in GPD4RLOnlineDataset")

    @staticmethod
    def return_range(dataset, max_episode_steps):
        """
        Overview:
            Compute the minimum and maximum undiscounted return over the episodes of ``dataset``.
            An episode ends at a terminal flag or after ``max_episode_steps`` steps. A trailing \
            incomplete trajectory counts towards the step total but not towards the returns.
        Arguments:
            dataset (:obj:`dict`): Mapping with ``"rewards"`` and ``"terminals"`` sequences.
            max_episode_steps (:obj:`int`): Step count that closes an episode.
        Returns:
            (:obj:`tuple`): ``(min_return, max_return)``.
        Raises:
            ValueError: If the dataset contains no complete episode.
        """
        returns, lengths = [], []
        ep_ret, ep_len = 0.0, 0
        for r, d in zip(dataset["rewards"], dataset["terminals"]):
            ep_ret += float(r)
            ep_len += 1
            if d or ep_len == max_episode_steps:
                returns.append(ep_ret)
                lengths.append(ep_len)
                ep_ret, ep_len = 0.0, 0
        # The incomplete final trajectory is excluded from the returns,
        # but its steps are still counted for the consistency check below.
        lengths.append(ep_len)
        assert sum(lengths) == len(dataset["rewards"])
        if not returns:
            raise ValueError("return_range: the dataset contains no complete episode")
        return min(returns), max(returns)


class GPCustomizedDataset(GPDataset):
    """
    Overview:
        Dataset for Generative Policy algorithm backed by a user-provided numpy file.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(
        self,
        env_id: str = None,
        numpy_data_path: str = None,
    ):
        """
        Overview:
            Initialization method of GPCustomizedDataset class. Reads the arrays ``obs``, \
            ``action``, ``next_obs``, ``reward`` and ``done`` from the loaded file.
        Arguments:
            env_id (:obj:`str`): The environment id (accepted for interface compatibility; not used here)
            numpy_data_path (:obj:`str`): Path passed to ``np.load``; the result must be indexable by the keys above
        """
        super().__init__()

        archive = np.load(numpy_data_path)

        def to_float_tensor(key):
            # Wrap the numpy array without copying, then cast to float32.
            return torch.from_numpy(archive[key]).float()

        def to_column_tensor(key):
            # Flatten to a (N, 1) column before casting to float32.
            return torch.from_numpy(archive[key]).view(-1, 1).float()

        self.states = to_float_tensor("obs")
        self.actions = to_float_tensor("action")
        self.next_states = to_float_tensor("next_obs")
        self.rewards = to_column_tensor("reward")
        self.is_finished = to_column_tensor("done")

        self.len = self.states.shape[0]
        log.info(f"{self.len} data loaded in GPCustomizedDataset")


class GPD4RLTensorDictDataset(GPTensorDictDataset):
    """
    Overview:
        D4RL Dataset for Generative Policy algorithm, stored in a ``LazyTensorStorage``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``, ``return_range``.
    """

    def __init__(
        self,
        env_id: str,
        action_augment_num: int = None,
    ):
        """
        Overview:
            Initialization method of GPD4RLTensorDictDataset class.
            Rewards are rescaled as follows:
            - env ids containing ``antmaze``: ``reward - 1``;
            - all others: ``reward / (max_return - min_return) * 1000``, where the returns are \
              taken over complete episodes (capped at 1000 steps).
        Arguments:
            env_id (:obj:`str`): The environment id
            action_augment_num (:obj:`int`): Number of fake actions per transition. When truthy, \
                zero-filled ``fake_a`` / ``fake_a_`` columns are allocated in the storage; \
                otherwise those keys are omitted.
        """

        super().__init__()
        import d4rl

        data = d4rl.qlearning_dataset(gym.make(env_id))
        self.states = torch.from_numpy(data["observations"]).float()
        self.actions = torch.from_numpy(data["actions"]).float()
        self.next_states = torch.from_numpy(data["next_observations"]).float()
        reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
        self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()

        # Only "iql_antmaze" and "iql_locomotion" can be selected here; the other presets
        # are kept for reference.
        reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
        if reward_tune == "normalize":
            reward = (reward - reward.mean()) / reward.std()
        elif reward_tune == "iql_antmaze":
            reward = reward - 1.0
        elif reward_tune == "iql_locomotion":
            min_ret, max_ret = GPD4RLTensorDictDataset.return_range(data, 1000)
            # NOTE(review): divides by zero if every complete episode has the same return.
            reward /= max_ret - min_ret
            reward *= 1000
        elif reward_tune == "cql_antmaze":
            reward = (reward - 0.5) * 4.0
        elif reward_tune == "antmaze":
            reward = (reward - 0.25) * 2.0
        self.rewards = reward
        self.len = self.states.shape[0]
        log.info(f"{self.len} data loaded in GPD4RLTensorDictDataset")
        self.action_augment_num = action_augment_num
        self.storage = LazyTensorStorage(max_size=self.len)

        columns = {
            "s": self.states,
            "a": self.actions,
            "r": self.rewards,
            "s_": self.next_states,
            "d": self.is_finished,
        }
        if self.action_augment_num:
            # Zero placeholders of shape (len, action_augment_num, *actions.shape[1:]);
            # load_fake_actions overwrites them with real fake actions later.
            columns["fake_a"] = (
                torch.zeros_like(self.actions)
                .unsqueeze(1)
                .repeat_interleave(self.action_augment_num, dim=1)
            )
            columns["fake_a_"] = (
                torch.zeros_like(self.actions)
                .unsqueeze(1)
                .repeat_interleave(self.action_augment_num, dim=1)
            )
        self.storage.set(range(self.len), TensorDict(columns, batch_size=[self.len]))

    @staticmethod
    def return_range(dataset, max_episode_steps):
        """
        Overview:
            Compute the minimum and maximum undiscounted return over the episodes of ``dataset``.
            An episode ends at a terminal flag or after ``max_episode_steps`` steps. A trailing \
            incomplete trajectory counts towards the step total but not towards the returns.
        Arguments:
            dataset (:obj:`dict`): Mapping with ``"rewards"`` and ``"terminals"`` sequences.
            max_episode_steps (:obj:`int`): Step count that closes an episode.
        Returns:
            (:obj:`tuple`): ``(min_return, max_return)``.
        Raises:
            ValueError: If the dataset contains no complete episode.
        """
        returns, lengths = [], []
        ep_ret, ep_len = 0.0, 0
        for r, d in zip(dataset["rewards"], dataset["terminals"]):
            ep_ret += float(r)
            ep_len += 1
            if d or ep_len == max_episode_steps:
                returns.append(ep_ret)
                lengths.append(ep_len)
                ep_ret, ep_len = 0.0, 0
        # The incomplete final trajectory is excluded from the returns,
        # but its steps are still counted for the consistency check below.
        lengths.append(ep_len)
        assert sum(lengths) == len(dataset["rewards"])
        if not returns:
            raise ValueError("return_range: the dataset contains no complete episode")
        return min(returns), max(returns)

    def __getitem__(self, index):
        """
        Overview:
            Fetch the entry at ``index`` from the underlying storage.
        Arguments:
            index (:obj:`int`): Index of data
        Returns:
            data: Whatever ``self.storage.get`` returns for that index.

        .. note::
            The stored entries contain the following keys:

            s (:obj:`torch.Tensor`): State
            a (:obj:`torch.Tensor`): Action
            r (:obj:`torch.Tensor`): Reward
            s_ (:obj:`torch.Tensor`): Next state
            d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action (only when ``action_augment_num`` is truthy)
            fake_a_ (:obj:`torch.Tensor`): Fake next action (only when ``action_augment_num`` is truthy)
        """

        data = self.storage.get(index=index)
        return data

    def __len__(self):
        return self.len

class GPCustomizedTensorDictDataset(GPTensorDictDataset):
    """
    Overview:
        Dataset for Generative Policy algorithm for customized data, stored in a ``LazyTensorStorage``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(
        self,
        env_id: str = None,
        action_augment_num: int = 16,
        numpy_data_path: str = None,
    ):
        """
        Overview:
            Initialization method of GPCustomizedTensorDictDataset class. Reads the arrays \
            ``obs``, ``action``, ``next_obs``, ``reward`` and ``done`` from ``numpy_data_path``.
        Arguments:
            env_id (:obj:`str`): The environment id (accepted for interface compatibility; not used here)
            action_augment_num (:obj:`int`): Number of fake actions per transition. When truthy, \
                zero-filled ``fake_a`` / ``fake_a_`` columns are allocated; when falsy (e.g. ``None``) \
                they are omitted, matching ``load_fake_actions``.
            numpy_data_path (:obj:`str`): The path to the numpy data
        """

        super().__init__()

        data = np.load(numpy_data_path)

        self.states = torch.from_numpy(data["obs"]).float()
        self.actions = torch.from_numpy(data["action"]).float()
        self.next_states = torch.from_numpy(data["next_obs"]).float()
        reward = torch.from_numpy(data["reward"]).view(-1, 1).float()
        self.is_finished = torch.from_numpy(data["done"]).view(-1, 1).float()

        self.rewards = reward
        self.len = self.states.shape[0]
        log.info(f"{self.len} data loaded in GPCustomizedDataset")
        self.action_augment_num = action_augment_num
        self.storage = LazyTensorStorage(max_size=self.len)

        columns = {
            "s": self.states,
            "a": self.actions,
            "r": self.rewards,
            "s_": self.next_states,
            "d": self.is_finished,
        }
        # Previously repeat_interleave(None) crashed here; falsy values now skip the fake columns,
        # the same rule load_fake_actions applies.
        if self.action_augment_num:
            columns["fake_a"] = (
                torch.zeros_like(self.actions)
                .unsqueeze(1)
                .repeat_interleave(self.action_augment_num, dim=1)
            )
            columns["fake_a_"] = (
                torch.zeros_like(self.actions)
                .unsqueeze(1)
                .repeat_interleave(self.action_augment_num, dim=1)
            )
        self.storage.set(range(self.len), TensorDict(columns, batch_size=[self.len]))


class GPDMcontrolTensorDictDataset(torch.utils.data.Dataset):
    """
    Overview:
        DeepMind Control dataset with dict-valued observations, stored in a ``LazyMemmapStorage``.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(
        self,
        path: str,
    ):
        """
        Overview:
            Load all transitions from ``path`` into a memory-mapped TensorDict storage.
        Arguments:
            path (:obj:`str`): File loaded with ``np.load(..., allow_pickle=True)``; it must hold a \
                sequence of dicts with keys ``s``, ``a``, ``r`` and ``s_``, where ``s`` / ``s_`` are \
                dicts of observation arrays (presumably a ``.npy`` object array — confirm).
        Raises:
            ValueError: If the file holds no transitions.
        """
        # NOTE(review): allow_pickle=True unpickles the file contents, which can execute
        # arbitrary code; only load trusted files.
        data = np.load(path, allow_pickle=True)
        if len(data) == 0:
            raise ValueError(f"No transitions found in {path}")
        obs_keys = list(data[0]["s"].keys())

        def _to_tensor(values):
            return torch.tensor(np.array(values, dtype=np.float32))

        state_tensors = {key: _to_tensor([item["s"][key] for item in data]) for key in obs_keys}
        next_state_tensors = {key: _to_tensor([item["s_"][key] for item in data]) for key in obs_keys}

        self.actions = _to_tensor([item["a"] for item in data])
        self.rewards = torch.tensor(
            np.array([item["r"] for item in data], dtype=np.float32).reshape(-1, 1)
        )
        self.len = self.actions.shape[0]
        self.states = TensorDict(state_tensors, batch_size=[self.len])
        self.next_states = TensorDict(next_state_tensors, batch_size=[self.len])
        # The source transitions carry no termination flag, so every transition is marked non-terminal.
        self.is_finished = torch.zeros_like(self.rewards, dtype=torch.bool)
        self.storage = LazyMemmapStorage(max_size=self.len)
        self.storage.set(
            range(self.len), TensorDict(
                {
                    "s": self.states,
                    "a": self.actions,
                    "r": self.rewards,
                    "s_": self.next_states,
                    "d": self.is_finished,
                },
                batch_size=[self.len],
            )
        )

    def __getitem__(self, index):
        """
        Overview:
            Fetch the entry at ``index`` from the storage; it contains the keys ``s``, ``a``, \
            ``r``, ``s_`` and ``d``.
        Arguments:
            index (:obj:`int`): Index of data
        """
        return self.storage.get(index=index)

    def __len__(self):
        return self.len


class GPDMControlVisualTensorDictDataset_backup(torch.utils.data.Dataset):
    """
    Overview:
        Pixel-based DeepMind Control dataset (single frame per observation), stored in a \
        ``LazyMemmapStorage``. Every ``.npz`` file in ``<path>/<env_id>/<policy_type>/<pixel_size>px`` \
        is loaded as one episode.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(
        self,
        env_id: str,
        policy_type: str,
        pixel_size: int,
        path: str,
    ):
        """
        Overview:
            Load every episode file and write all transitions into a memory-mapped storage.
        Arguments:
            env_id (:obj:`str`): One of ``cheetah_run``, ``humanoid_walk``, ``walker_walk``.
            policy_type (:obj:`str`): One of ``expert``, ``medium``, ``medium_expert``, ``medium_replay``, ``random``.
            pixel_size (:obj:`int`): Image resolution, ``64`` or ``84``.
            path (:obj:`str`): Root folder of the dataset.
        Raises:
            FileNotFoundError: If the episode folder contains no ``.npz`` files.
        """
        assert env_id in ["cheetah_run", "humanoid_walk", "walker_walk"]
        assert policy_type in ["expert", "medium", "medium_expert", "medium_replay", "random"]
        assert pixel_size in [64, 84]
        if pixel_size == 64:
            npz_folder_path = os.path.join(path, env_id, policy_type, "64px")
        else:
            npz_folder_path = os.path.join(path, env_id, policy_type, "84px")

        obs_list = []
        action_list = []
        reward_list = []
        next_obs_list = []
        is_finished_list = []
        episode_list = []
        step_list = []

        # Sorted file order makes the episode indices reproducible across machines.
        for index, npz_file in enumerate(self._list_npz_files(npz_folder_path)):
            npz_path = os.path.join(npz_folder_path, npz_file)
            # NOTE(review): allow_pickle=True can execute code from a crafted file; only load trusted data.
            with np.load(npz_path, allow_pickle=True) as data:
                images = data["image"]
                obs = torch.from_numpy(images[:-1])
                next_obs = torch.from_numpy(images[1:])
                # Entries at t+1 are paired with image t (assumes action[t] is the action that
                # produced image[t] — confirm against the data format).
                action = torch.from_numpy(data["action"][1:])
                reward = torch.from_numpy(data["reward"][1:])
                is_finished = torch.from_numpy(data["is_last"][1:] + data["is_terminal"][1:])

            num_transitions = obs.shape[0]
            obs_list.append(obs)
            action_list.append(action)
            reward_list.append(reward)
            next_obs_list.append(next_obs)
            is_finished_list.append(is_finished)
            # Explicit integer dtype: torch.tensor([]) would default to float for an empty episode.
            episode_list.append(torch.full((num_transitions,), index, dtype=torch.long))
            step_list.append(torch.arange(num_transitions))

        self.states = torch.cat(obs_list, dim=0)
        self.actions = torch.cat(action_list, dim=0)
        self.rewards = torch.cat(reward_list, dim=0)
        self.next_states = torch.cat(next_obs_list, dim=0)
        self.is_finished = torch.cat(is_finished_list, dim=0)
        self.episode = torch.cat(episode_list, dim=0)
        self.step = torch.cat(step_list, dim=0)
        self.len = self.states.shape[0]
        self.storage = LazyMemmapStorage(max_size=self.len)

        self.storage.set(
            range(self.len), TensorDict(
                {
                    "s": self.states,
                    "a": self.actions,
                    "r": self.rewards,
                    "s_": self.next_states,
                    "d": self.is_finished,
                    "episode": self.episode,
                    "step": self.step,
                },
                batch_size=[self.len],
            )
        )

    def __getitem__(self, index):
        """
        Overview:
            Get data by index
        Arguments:
            index (:obj:`int`): Index of data
        Returns:
            data: Whatever ``self.storage.get`` returns for that index.

        .. note::
            The stored entries contain the following keys:

            s (:obj:`torch.Tensor`): State
            a (:obj:`torch.Tensor`): Action
            r (:obj:`torch.Tensor`): Reward
            s_ (:obj:`torch.Tensor`): Next state
            d (:obj:`torch.Tensor`): Is finished
            episode (:obj:`torch.Tensor`): Episode index
            step (:obj:`torch.Tensor`): Step index within the episode
        """

        data = self.storage.get(index=index)
        return data

    def __len__(self):
        return self.len

    @staticmethod
    def _list_npz_files(folder: str) -> List[str]:
        """Return the ``.npz`` file names in ``folder`` in sorted order; raise if there are none."""
        npz_files = sorted(f for f in os.listdir(folder) if f.endswith(".npz"))
        if not npz_files:
            raise FileNotFoundError(f"No .npz files found in {folder}")
        return npz_files
    

class GPDMControlVisualTensorDictDataset(torch.utils.data.Dataset):
    """
    Overview:
        Pixel-based DeepMind Control dataset with frame stacking, stored in a ``LazyMemmapStorage``.
        Every ``.npz`` file in ``<path>/<env_id>/<policy_type>/<pixel_size>px`` is loaded as one episode.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(
        self,
        env_id: str,
        policy_type: str,
        pixel_size: int,
        path: str,
        stack_frames: int,
    ):
        """
        Overview:
            Load every episode file, build stacked observation windows and write all transitions \
            into a memory-mapped storage. Episodes with at most ``stack_frames`` images contribute \
            no transitions.
        Arguments:
            env_id (:obj:`str`): One of ``cheetah_run``, ``humanoid_walk``, ``walker_walk``.
            policy_type (:obj:`str`): One of ``expert``, ``medium``, ``medium_expert``, ``medium_replay``, ``random``.
            pixel_size (:obj:`int`): Image resolution, ``64`` or ``84``.
            path (:obj:`str`): Root folder of the dataset.
            stack_frames (:obj:`int`): Number of consecutive images per observation.
        Raises:
            FileNotFoundError: If the episode folder contains no ``.npz`` files.
        """
        assert env_id in ["cheetah_run", "humanoid_walk", "walker_walk"]
        assert policy_type in ["expert", "medium", "medium_expert", "medium_replay", "random"]
        assert pixel_size in [64, 84]
        if pixel_size == 64:
            npz_folder_path = os.path.join(path, env_id, policy_type, "64px")
        else:
            npz_folder_path = os.path.join(path, env_id, policy_type, "84px")

        obs_list = []
        action_list = []
        reward_list = []
        next_obs_list = []
        is_finished_list = []
        episode_list = []
        step_list = []

        # Sorted file order makes the episode indices reproducible across machines.
        for index, npz_file in enumerate(self._list_npz_files(npz_folder_path)):
            npz_path = os.path.join(npz_folder_path, npz_file)
            # NOTE(review): allow_pickle=True can execute code from a crafted file; only load trusted data.
            with np.load(npz_path, allow_pickle=True) as data:
                # Read the image array once instead of once per slice.
                images = data["image"]
                # obs[j] stacks images j .. j+stack_frames-1; next_obs[j] is the same window shifted by one.
                obs = self._frame_windows(images, stack_frames, 0)
                next_obs = self._frame_windows(images, stack_frames, 1)
                # Entries at j+stack_frames are paired with obs[j] (assumes action[t] is the action
                # that produced image[t] — confirm against the data format).
                action = torch.from_numpy(data["action"][stack_frames:])
                reward = torch.from_numpy(data["reward"][stack_frames:])
                is_finished = torch.from_numpy(
                    data["is_last"][stack_frames:] + data["is_terminal"][stack_frames:]
                )

            num_transitions = obs.shape[0]
            obs_list.append(obs)
            action_list.append(action)
            reward_list.append(reward)
            next_obs_list.append(next_obs)
            is_finished_list.append(is_finished)
            # Explicit integer dtype: torch.tensor([]) would default to float for an empty episode.
            episode_list.append(torch.full((num_transitions,), index, dtype=torch.long))
            step_list.append(torch.arange(num_transitions))

        self.states = torch.cat(obs_list, dim=0)
        self.actions = torch.cat(action_list, dim=0)
        self.rewards = torch.cat(reward_list, dim=0)
        self.next_states = torch.cat(next_obs_list, dim=0)
        self.is_finished = torch.cat(is_finished_list, dim=0)
        self.episode = torch.cat(episode_list, dim=0)
        self.step = torch.cat(step_list, dim=0)
        self.len = self.states.shape[0]
        self.storage = LazyMemmapStorage(max_size=self.len)

        self.storage.set(
            range(self.len), TensorDict(
                {
                    "s": self.states,
                    "a": self.actions,
                    "r": self.rewards,
                    "s_": self.next_states,
                    "d": self.is_finished,
                    "episode": self.episode,
                    "step": self.step,
                },
                batch_size=[self.len],
            )
        )

    def __getitem__(self, index):
        """
        Overview:
            Get data by index
        Arguments:
            index (:obj:`int`): Index of data
        Returns:
            data: Whatever ``self.storage.get`` returns for that index.

        .. note::
            The stored entries contain the following keys:

            s (:obj:`torch.Tensor`): Stacked state frames
            a (:obj:`torch.Tensor`): Action
            r (:obj:`torch.Tensor`): Reward
            s_ (:obj:`torch.Tensor`): Stacked next-state frames
            d (:obj:`torch.Tensor`): Is finished
            episode (:obj:`torch.Tensor`): Episode index
            step (:obj:`torch.Tensor`): Step index within the episode
        """

        data = self.storage.get(index=index)
        return data

    def __len__(self):
        return self.len

    @staticmethod
    def _frame_windows(frames: np.ndarray, stack_frames: int, start: int) -> torch.Tensor:
        """
        Stack ``stack_frames`` consecutive frames for each window beginning at ``start + j``.

        Returns a tensor of shape ``(max(len(frames) - stack_frames, 0), stack_frames, *frames.shape[1:])``.
        Using an explicit window count avoids the negative slice ends that the previous
        ``length - stack_frames + i`` bound produced for short episodes.
        """
        num_windows = max(frames.shape[0] - stack_frames, 0)
        return torch.stack(
            [
                torch.from_numpy(frames[start + i : start + i + num_windows])
                for i in range(stack_frames)
            ],
            dim=1,
        )

    @staticmethod
    def _list_npz_files(folder: str) -> List[str]:
        """Return the ``.npz`` file names in ``folder`` in sorted order; raise if there are none."""
        npz_files = sorted(f for f in os.listdir(folder) if f.endswith(".npz"))
        if not npz_files:
            raise FileNotFoundError(f"No .npz files found in {folder}")
        return npz_files
 
