


import gym
import d4rl # Import required to register environments, you may need to also import the submodule

from d3rlpy.datasets import MDPDataset
import numpy as np
import os

# maze2d-umaze-expert-v1

# Convention: local maze2d dataset names follow maze2d-{type}-{policy}-v1
# (or maze2d-{type}-{policy}-dense-v1 for the dense-reward variant).
def get_maze2d(env_name):
    """Load a locally stored maze2d dataset and create the matching gym env.

    ``env_name`` is expected to look like ``maze2d-{type}-{policy}-v1`` or
    ``maze2d-{type}-{policy}-dense-v1``. The policy component only selects
    the data file on disk; the gym environment id is built without it
    (``maze2d-{type}-v1``, or ``maze2d-{type}-dense-v1`` when the name
    contains ``dense``).

    Args:
        env_name: dataset identifier in the format described above.

    Returns:
        tuple of the object returned by ``np.load`` for
        ``data/maze2d/data/{policy}/{type}/data.npz`` (resolved relative to
        the current working directory) and the gym environment.

    Raises:
        ValueError: if ``env_name`` has fewer than four ``-`` separated
            components, so type/policy/reward cannot be extracted.

    """
    env_attr = env_name.split('-')
    if len(env_attr) < 4:
        # Fail with an explicit message instead of an IndexError below.
        raise ValueError(
            f"Invalid maze2d dataset name {env_name!r}: expected "
            "'maze2d-{type}-{policy}-v1' or 'maze2d-{type}-{policy}-dense-v1'"
        )

    env_type, policy, reward_type = env_attr[1], env_attr[2], env_attr[3]

    # The registered gym id never contains the policy name.
    if "dense" in env_name:
        env_name_gym = f"maze2d-{env_type}-{reward_type}-v1"
    else:
        env_name_gym = f"maze2d-{env_type}-v1"

    env = gym.make(env_name_gym)

    # Location of the stored transitions for this policy / maze layout.
    path = os.path.join('data/maze2d/', 'data', policy, env_type, 'data.npz')
    dataset = np.load(path)

    return dataset, env


def get_d4rl2d3rl(env_name: str):
    """Return a d3rlpy ``MDPDataset`` and gym environment for ``env_name``.

    The raw dataset is obtained from ``env.get_dataset()`` of the gym
    environment. For names containing ``maze2d``, if creating the
    environment or fetching its dataset fails, the data is loaded through
    :func:`get_maze2d` instead.

    .. code-block:: python

        dataset, env = get_d4rl2d3rl('hopper-medium-v0')

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of the dataset.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    # Imported for its side effect of registering the d4rl environments.
    import d4rl  # type: ignore

    if 'maze2d' in env_name:
        try:
            env = gym.make(env_name)
            dataset = env.get_dataset()
        except Exception:
            # Not available through d4rl: fall back to the local maze2d
            # files. ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt / SystemExit propagate.
            dataset, env = get_maze2d(env_name)
    else:
        env = gym.make(env_name)
        dataset = env.get_dataset()

    observations = dataset["observations"]
    actions = dataset["actions"]
    rewards = dataset["rewards"]
    terminals = dataset["terminals"]
    timeouts = dataset["timeouts"]
    # An episode ends either on a terminal flag or on a timeout flag.
    episode_terminals = np.logical_or(terminals, timeouts)

    mdp_dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=np.array(actions, dtype=np.float32),
        rewards=np.array(rewards, dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
    )

    return mdp_dataset, env


def build_mdpdata_from_episodes(episodes, env):
    """Flatten a list of episodes into a single ``MDPDataset``.

    Every transition of every episode is appended in order. The last
    transition of each episode is flagged in ``episode_terminals``. Summary
    statistics of per-episode returns and per-transition rewards are
    printed.

    Args:
        episodes: sequence of episodes; each must expose ``rewards`` and be
            iterable over transitions with ``observation``, ``action``,
            ``reward`` and ``terminal`` attributes.
        env: gym environment; its ``action_space`` decides whether actions
            are reshaped to ``action_space.shape`` (``Box`` spaces only).

    Returns:
        the assembled ``MDPDataset``.

    Raises:
        ValueError: if ``episodes`` is empty.

    """
    if len(episodes) == 0:
        # np.stack / np.max would otherwise fail with an obscure message.
        raise ValueError("cannot build an MDPDataset from zero episodes")

    # Loop invariant: evaluate the action-space type once, not per step.
    is_box_action = isinstance(env.action_space, gym.spaces.Box)
    action_shape = env.action_space.shape if is_box_action else None

    observations = []
    actions = []
    rewards = []
    terminals = []
    episode_terminals = []
    ep_rets = []

    for episode in episodes:
        ep_rets.append(episode.rewards.sum())
        last_idx = len(episode) - 1

        for idx, transition in enumerate(episode):
            observations.append(transition.observation)
            if is_box_action:
                actions.append(np.reshape(transition.action, action_shape))
            else:
                actions.append(transition.action)
            rewards.append(transition.reward)
            terminals.append(transition.terminal)
            episode_terminals.append(idx == last_idx)

    # Stacked once and shared by the dataset and the statistics below.
    stacked_rewards = np.stack(rewards)

    dataset = MDPDataset(
        observations=np.stack(observations),
        actions=np.stack(actions),
        rewards=stacked_rewards,
        terminals=np.stack(terminals).astype(float),
        episode_terminals=np.stack(episode_terminals).astype(float),
    )

    print(f"ep_reward_sum   Max/Mean/Median/Min: {np.max(ep_rets)}/{np.mean(ep_rets)}/{np.median(ep_rets)}/{np.min(ep_rets)}")
    print(f"state-action reward     Max/Mean/Median/Min:   {np.max(stacked_rewards)}/{np.mean(stacked_rewards)}/{np.median(stacked_rewards)}/{np.min(stacked_rewards)} ")

    return dataset


def get_sorted_idx(data, discount=0.99):
    """Rank the episodes of ``data`` from highest to lowest total reward.

    Args:
        data: object exposing an ``episodes`` iterable whose items have a
            ``rewards`` array.
        discount: accepted for interface compatibility; currently unused,
            the ranking is based on the plain (undiscounted) reward sum.

    Returns:
        array of episode indices, best episode first.

    """
    episode_returns = [ep.rewards.sum() for ep in data.episodes]
    ascending_order = np.argsort(episode_returns)
    # Reverse the ascending order to put the largest return first.
    return ascending_order[::-1]

# def get_episode_list(data:MDPDataset):
#     N = data.observations.shape[0]
#     ret = []
#     observations = []
#     rewards = []
#     actions = []
#     terminals = []
#     episode_terminals = []
#     for i in range(N):
#         observations.append(data.observations[i])
#         rewards.append(data.rewards[i])
#         actions.append(data.actions[i])
#         terminals.append(data.terminals[i])
#         episode_terminals.append(data.episode_terminals[i])
#         if(data.episode_terminals[i] == 1):
#             ret.append({
#                 observations
#             })
            

def get_offline_imitation_data(expert_name, offline_name, expert_num=10, offline_exp=0):
    """Assemble the expert and offline datasets for offline imitation.

    The first ``expert_num`` episodes of ``expert_name`` form the expert
    dataset. The offline dataset consists of the next ``offline_exp``
    episodes of ``expert_name`` followed by every episode of
    ``offline_name``.

    Args:
        expert_name: dataset id providing the expert episodes.
        offline_name: dataset id providing the offline episodes.
        expert_num: number of leading expert episodes to keep as expert data.
        offline_exp: number of additional expert episodes mixed into the
            offline data.

    Returns:
        tuple ``(data_expert, data_offline, env)`` where ``env`` is the
        environment created for ``offline_name``.

    """
    # Only the dataset is needed here; the returned env is the offline one.
    data_e, _ = get_d4rl2d3rl(expert_name)
    expert_episodes = data_e.episodes[:expert_num]

    data_o, env = get_d4rl2d3rl(offline_name)

    # Expert episodes not used as demonstrations are mixed into the
    # offline pool, ahead of the offline dataset's own episodes.
    offline_episodes = list(data_e.episodes[expert_num:expert_num + offline_exp])
    offline_episodes.extend(data_o.episodes)

    print("expert data: {} [{} episodes]".format(expert_name, len(expert_episodes)))
    data_expert = build_mdpdata_from_episodes(expert_episodes, env)

    print("offline data: {} [{} episodes]".format(offline_name, len(offline_episodes)))
    data_offline = build_mdpdata_from_episodes(offline_episodes, env)

    print(data_expert.observations.shape)
    print(data_expert.observations[0])

    return data_expert, data_offline, env







    

    



if __name__ == "__main__":
    # Manual smoke run: build the expert / offline datasets for one pair of
    # dataset names and print their summary statistics.
    #
    # Alternative dataset pairs tried previously:
    #   get_offline_imitation_data("hopper-expert-v2", "hopper-random-v2")
    #   get_offline_imitation_data("maze2d-umaze-expert-v1", "maze2d-umaze-random-v1")
    get_offline_imitation_data("maze2d-umaze-expert-v1", "maze2d-umaze-v1")


