


import gym
import d4rl # Import required to register environments, you may need to also import the submodule

from d3rlpy.datasets import MDPDataset
import numpy as np
import os

# maze2d-umaze-expert-v1

# Locally stored maze2d datasets are named maze2d-{type}-{policy}-v1
# (or maze2d-{type}-{policy}-dense-v1 for the dense-reward variant).

def get_maze2d(env_name):
    """Load a locally stored maze2d dataset and create the matching gym env.

    ``env_name`` must look like ``maze2d-{type}-{policy}-v1`` or, for dense
    rewards, ``maze2d-{type}-{policy}-dense-v1``. The ``policy`` component is
    only used to locate the data file; the gym environment itself is created
    as ``maze2d-{type}-v1`` (or ``maze2d-{type}-dense-v1``).

    The data file is looked up relative to the current working directory at
    ``data/maze2d/data/{policy}/{type}[-dense]/data.npz``.

    Args:
        env_name: dataset name in the format described above.

    Returns:
        tuple of the object returned by ``np.load`` for the data file and the
        gym environment.

    Raises:
        ValueError: if ``env_name`` has fewer than four ``-``-separated parts.

    """
    env_attr = env_name.split('-')
    if len(env_attr) < 4:
        # Fail with a readable message instead of an IndexError further down.
        raise ValueError(
            f"unexpected maze2d dataset name {env_name!r}; expected "
            "'maze2d-{type}-{policy}-v1' or 'maze2d-{type}-{policy}-dense-v1'"
        )

    env_type, policy, reward_type = env_attr[1:4]

    if "dense" in env_name:
        env_name_gym = f"maze2d-{env_type}-{reward_type}-v1"
        data_dir = env_type + "-dense"
    else:
        env_name_gym = f"maze2d-{env_type}-v1"
        data_dir = env_type

    env = gym.make(env_name_gym)

    # Location of the stored dataset (relative to the working directory).
    path = os.path.join('data/maze2d/', 'data', policy, data_dir, 'data.npz')

    dataset = np.load(path)
    print("load data from <<< ", path)

    return dataset, env


def get_d4rl2d3rl(env_name: str):
    """Return a d4rl dataset converted to a d3rlpy ``MDPDataset`` and its env.

    The dataset is normally provided through d4rl. For names containing
    ``maze2d``, if creating the environment or fetching its dataset through
    d4rl fails, the data is loaded with ``get_maze2d`` instead.

    .. code-block:: python

        dataset, env = get_d4rl2d3rl('hopper-medium-v0')

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of d4rl dataset.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    import d4rl  # type: ignore

    if 'maze2d' in env_name:
        try:
            env = gym.make(env_name)
            dataset = env.get_dataset()
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must still
            # propagate instead of silently triggering the local fallback.
            dataset, env = get_maze2d(env_name)
    else:
        env = gym.make(env_name)
        dataset = env.get_dataset()

    observations = dataset["observations"]
    actions = dataset["actions"]
    rewards = dataset["rewards"]
    terminals = dataset["terminals"]
    timeouts = dataset["timeouts"]
    # An episode ends either on a terminal flag or on a timeout flag.
    episode_terminals = np.logical_or(terminals, timeouts)

    mdp_dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=np.array(actions, dtype=np.float32),
        rewards=np.array(rewards, dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
    )

    return mdp_dataset, env


def build_mdpdata_from_episodes(episodes, env):
    """Flatten episodes into a single ``MDPDataset`` and print reward stats.

    Every transition of every episode is appended in order. When
    ``env.action_space`` is a ``gym.spaces.Box`` each action is reshaped to
    the action-space shape; otherwise it is stored as-is. The last transition
    of each episode is flagged as an episode terminal.

    Args:
        episodes: iterable of episodes; each must expose ``rewards``, support
            ``len()`` and iterate over transitions with ``observation``,
            ``action``, ``reward`` and ``terminal`` attributes.
        env: gym environment whose ``action_space`` is inspected.

    Returns:
        the assembled ``MDPDataset``.

    """
    obs_buf, act_buf, rew_buf = [], [], []
    term_buf, ep_end_buf = [], []
    ep_rets = []

    for episode in episodes:
        ep_rets.append(episode.rewards.sum())
        last_step = len(episode) - 1

        for step, transition in enumerate(episode):
            obs_buf.append(transition.observation)

            action = transition.action
            if isinstance(env.action_space, gym.spaces.Box):
                action = np.reshape(action, env.action_space.shape)
            act_buf.append(action)

            rew_buf.append(transition.reward)
            term_buf.append(transition.terminal)
            ep_end_buf.append(step == last_step)

    # Stacked once; used both for the dataset and for the statistics below.
    rewards = np.stack(rew_buf)

    dataset = MDPDataset(
        observations=np.stack(obs_buf),
        actions=np.stack(act_buf),
        rewards=rewards,
        terminals=np.stack(term_buf).astype(float),
        episode_terminals=np.stack(ep_end_buf).astype(float),
    )

    print(f"ep_reward_sum   Max/Mean/Median/Min: {np.max(ep_rets)}/{np.mean(ep_rets)}/{np.median(ep_rets)}/{np.min(ep_rets)}")
    print(f"state-action reward     Max/Mean/Median/Min:   {np.max(rewards)}/{np.mean(rewards)}/{np.median(rewards)}/{np.min(rewards)} ")

    return dataset


def get_sorted_idx(data, discount = 0.99):
    """Return episode indices ordered from highest to lowest return.

    The return of an episode is the plain sum of its ``rewards``.

    Args:
        data: object exposing an ``episodes`` sequence whose items have a
            ``rewards`` array.
        discount: accepted for interface compatibility but not used; the
            returns are undiscounted.

    Returns:
        array of episode indices, best episode first.

    """
    episode_returns = [episode.rewards.sum() for episode in data.episodes]
    ascending = np.argsort(episode_returns)
    return ascending[::-1]

def get_offline_imitation_data(expert_name, offline_name, expert_num=10, offline_exp=0):
    """Build the expert and offline datasets for offline imitation learning.

    The expert set holds ``expert_num`` episodes of ``expert_name``: the
    highest-return episodes when the name contains ``antmaze``, otherwise the
    first ``expert_num`` episodes. The offline set holds the expert-dataset
    episodes at positions ``expert_num`` to ``expert_num + offline_exp``
    followed by every episode of ``offline_name``.

    If fewer than ``expert_num`` episodes are available, all of them are used.

    Args:
        expert_name: dataset name the expert episodes are taken from.
        offline_name: dataset name providing the offline episodes.
        expert_num: number of expert episodes to keep.
        offline_exp: number of additional expert-dataset episodes mixed into
            the offline set.

    Returns:
        tuple ``(data_expert, data_offline, env)`` where ``env`` is the
        environment created for ``offline_name``.

    """
    data_e, _ = get_d4rl2d3rl(expert_name)
    expert_pool = data_e.episodes

    if "antmaze" in expert_name:
        # Rank the episodes by return, best first.
        ep_rets = [episode_i.rewards.sum() for episode_i in expert_pool]
        ranked_idx = np.argsort(ep_rets)[::-1]
        print(ranked_idx)

        # Slicing (rather than indexing over range(expert_num)) keeps this
        # branch from raising IndexError when expert_num exceeds the number
        # of episodes, matching the truncating behaviour of the branch below.
        expert_episodes = [expert_pool[idx] for idx in ranked_idx[:expert_num]]
    else:
        expert_episodes = expert_pool[:expert_num]

    data_o, env = get_d4rl2d3rl(offline_name)

    # NOTE(review): the extra expert episodes are picked by position, while
    # the antmaze expert set above is picked by rank, so the two selections
    # can overlap for antmaze -- confirm this is intended.
    offline_episodes = list(expert_pool[expert_num:expert_num + offline_exp])
    offline_episodes.extend(data_o.episodes)

    print("expert data: {} [{} episodes]".format(expert_name, len(expert_episodes)))
    data_expert = build_mdpdata_from_episodes(expert_episodes, env)

    print("offline data: {} [{} episodes]".format(offline_name, len(offline_episodes)))
    data_offline = build_mdpdata_from_episodes(offline_episodes, env)

    return data_expert, data_offline, env







    

    



if __name__ == "__main__":
    # Manual smoke test: build the expert/offline split and print the reward
    # statistics. Other dataset pairs that can be tried here:
    # get_offline_imitation_data("hopper-expert-v2", "hopper-random-v2")
    # get_offline_imitation_data("maze2d-umaze-expert-v1", "maze2d-umaze-random-v1")
    get_offline_imitation_data("maze2d-umaze-expert-v1", "maze2d-umaze-v1")


