

import gym
import d4rl # Import required to register environments, you may need to also import the submodule


# data_name = "maze2d-umaze"

data_name = "halfcheetah-random"
# NOTE(review): d4rl ids normally carry a version suffix such as "-v0";
# confirm this unversioned id resolves with the installed gym/d4rl.

# Instantiate the task by id.
env = gym.make(data_name)

# Exercise the standard gym loop once: reset, then take one random action.
env.reset()
random_action = env.action_space.sample()
env.step(random_action)

# Fetch the raw dataset attached to the task and show its observations.
dataset = env.get_dataset()
print(dataset['observations'])

# Rebind `dataset` to the Q-learning view, which also carries
# "next_observations".
dataset = d4rl.qlearning_dataset(env)

# Dump the main arrays, next observations first.
for key in ("next_observations", "observations", "actions", "rewards"):
    print(dataset[key])

states = dataset["observations"]
actions = dataset["actions"]
rewards = dataset["rewards"]

# Shapes of the three arrays, followed by the total reward in the dataset.
for arr in (states, actions, rewards):
    print(arr.shape)

print(rewards.sum())

# Available fields and the terminal flags.
print(dataset.keys())
print(dataset["terminals"])



from d3rlpy.datasets import MDPDataset
import numpy as np

# for item in dataset.episodes:
    # print(item)



def get_d4rl2d3rl(env_name: str):
    """Return a d4rl dataset as a d3rlpy ``MDPDataset`` plus its environment.

    The raw dataset is fetched through d4rl (``env.get_dataset()``); its
    arrays are cast to ``float32`` and handed to ``MDPDataset``. A step is
    marked as an episode boundary when it is either a terminal or a timeout.

    .. code-block:: python

        dataset, env = get_d4rl2d3rl('hopper-medium-v0')

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of d4rl dataset.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Warns:
        UserWarning: if the dataset has no ``"timeouts"`` entry, in which
            case only ``"terminals"`` are used as episode boundaries.

    """
    import warnings

    # Imported only for its side effect: importing d4rl registers its
    # environments so that ``gym.make`` can resolve ``env_name``.
    import d4rl  # type: ignore

    env = gym.make(env_name)
    dataset = env.get_dataset()

    terminals = dataset["terminals"]

    if "timeouts" in dataset:
        # An episode ends when the task terminates OR the time limit hits.
        episode_terminals = np.logical_or(terminals, dataset["timeouts"])
    else:
        # Not every dataset is guaranteed to ship a "timeouts" array; rather
        # than failing with a bare KeyError, fall back to the terminal flags
        # and tell the caller that episode splitting may be coarser.
        warnings.warn(
            f"dataset for {env_name!r} has no 'timeouts' entry; "
            "using 'terminals' alone as episode boundaries",
            stacklevel=2,
        )
        episode_terminals = terminals

    mdp_dataset = MDPDataset(
        observations=np.array(dataset["observations"], dtype=np.float32),
        actions=np.array(dataset["actions"], dtype=np.float32),
        rewards=np.array(dataset["rewards"], dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
    )

    return mdp_dataset, env



# Convert the selected task into a d3rlpy dataset; this rebinds both
# `dataset` and `env`.
dataset, env = get_d4rl2d3rl(data_name)
print(dataset)

# Report the number of episodes, then the length of the first one.
print(f"{len(dataset.episodes)} episodes")

episode_0 = dataset.episodes[0]
print(f"{len(episode_0)} : transitions")

# for trans_i in episode_0:
#     # transition.observation
#     # transition.action
#     # transition.reward
#     # transition.next_observation
#     print(trans_i.observation)
#     print(trans_i.action)
#     print(trans_i.reward)
#     print(trans_i.next_observation)
