import os
import collections
import numpy as np
import gym
import pdb

from contextlib import (
    contextmanager,
    redirect_stderr,
    redirect_stdout,
)

@contextmanager
def suppress_output():
    """
        Silence stdout and stderr for the duration of the ``with`` block.

        Both streams are redirected into a single handle opened on
        ``os.devnull``. The value bound by ``as`` is the pair
        ``(stderr_target, stdout_target)``; both entries are that same
        devnull handle, which is closed when the block exits.

        Adapted from https://stackoverflow.com/a/52442331
    """
    with open(os.devnull, 'w') as sink, \
            redirect_stderr(sink) as stderr_target, \
            redirect_stdout(sink) as stdout_target:
        yield stderr_target, stdout_target

with suppress_output():
    ## d4rl prints out a variety of warnings
    import d4rl

#-----------------------------------------------------------------------------#
#-------------------------------- general api --------------------------------#
#-----------------------------------------------------------------------------#

def load_environment(name):
    """
    Build an environment from a gym id, or pass an existing one through.

    Args:
        name: a gym environment id string (passed to ``gym.make``), or an
            already-constructed environment object. Any non-string value is
            returned unchanged, without the attributes set below.
    Returns:
        For a string id: the *unwrapped* environment, with two extra
        attributes attached:
          - ``max_episode_steps``: copied from the wrapper's
            ``_max_episode_steps``
          - ``name``: the id it was created from
    """
    if not isinstance(name, str):
        ## name is already an environment
        return name
    with suppress_output():
        ## d4rl / gym print warnings on construction
        wrapped_env = gym.make(name)
    env = wrapped_env.unwrapped
    # the episode horizon lives on the wrapper (`_max_episode_steps`), which
    # `.unwrapped` strips away, so copy it onto the raw env
    env.max_episode_steps = wrapped_env._max_episode_steps
    env.name = name
    return env

def get_dataset(env):
    """
    Return the raw offline dataset of ``env``.

    This is a thin pass-through to ``env.get_dataset()``; no post-processing
    is applied.
    """
    # NOTE: the antmaze-v0 datasets are known to have trajectory-segmentation
    # problems. `antmaze_fix` in this module can re-segment them, but it is
    # intentionally not applied here. To re-enable it, call it when
    # 'antmaze' appears in str(env).lower().
    raw = env.get_dataset()
    return raw

def sequence_dataset(env, preprocess_fn):
    """
    Split an offline dataset into episodes and iterate over them.

    Args:
        env: environment providing ``get_dataset()``, typically the result
            of ``load_environment``. ``env.name`` is read to decide whether
            maze2d post-processing applies. When the dataset has no
            ``timeouts`` field, the episode horizon is read from
            ``env.max_episode_steps``, falling back to
            ``env._max_episode_steps``.
        preprocess_fn: callable applied to the raw dataset dict before
            segmentation; it must return a dict of per-step arrays that
            includes ``rewards`` and ``terminals``.
    Yields:
        One dict per completed episode. The dict maps every dataset key not
        containing 'metadata' (e.g. observations, actions, rewards,
        terminals) to a numpy array of that episode's steps. For maze2d
        environments, each episode goes through ``process_maze2d_episode``,
        which adds ``next_observations`` and drops the final step.
        Steps after the last terminal/timeout do not form a completed
        episode and are not yielded.
    """
    dataset = get_dataset(env)
    dataset = preprocess_fn(dataset)

    N = dataset['rewards'].shape[0]
    # keys containing 'metadata' are skipped (presumably not per-step arrays);
    # computed once instead of on every step
    keys = [k for k in dataset if 'metadata' not in k]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset
    if not use_timeouts:
        # load_environment stores the horizon as `max_episode_steps` on the
        # unwrapped env; a wrapped env exposes it as `_max_episode_steps`
        max_episode_steps = getattr(env, 'max_episode_steps', None)
        if max_episode_steps is None:
            max_episode_steps = env._max_episode_steps

    data_ = collections.defaultdict(list)
    # 0-based index of the current step within the current episode
    episode_step = 0
    for i in range(N):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == max_episode_steps - 1)

        for k in keys:
            data_[k].append(dataset[k][i])

        if done_bool or final_timestep:
            episode_data = {k: np.array(v) for k, v in data_.items()}
            if 'maze2d' in env.name:
                episode_data = process_maze2d_episode(episode_data)
            yield episode_data
            data_ = collections.defaultdict(list)
            # the next step is step 0 of a new episode. (Previously the
            # counter was reset and then incremented, so later episodes
            # started at 1 and were cut one step short.)
            episode_step = 0
        else:
            episode_step += 1


#-----------------------------------------------------------------------------#
#-------------------------------- maze2d fixes -------------------------------#
#-----------------------------------------------------------------------------#

def process_maze2d_episode(episode):
    '''
        Add a `next_observations` field to an episode, in place.

        `next_observations[t]` is `observations[t + 1]`. The final step has
        no successor, so every pre-existing field (including `observations`)
        is truncated by one step to keep all fields aligned.

        Args:
            episode: dict of sliceable per-step sequences. It must contain
                `observations` and must not already contain
                `next_observations`.
        Returns:
            The same dict object, mutated in place.
    '''
    assert 'next_observations' not in episode
    # copy so next_observations does not share memory with the
    # (sliced) observations array
    next_observations = episode['observations'][1:].copy()
    # reassigning existing keys does not resize the dict, so it is safe
    # to do while iterating over items()
    for key, val in episode.items():
        episode[key] = val[:-1]
    episode['next_observations'] = next_observations
    return episode

def antmaze_fix(dset, max_path_length=1000):
    """
    Re-segment a flat antmaze dataset into zero-padded per-episode arrays.

    An episode ends at every step where `terminals` or `timeouts` is
    nonzero. Each of the fields rewards, observations, actions, infos/goal,
    infos/qpos, infos/qvel, terminals and timeouts is replaced in `dset` by
    a float array of shape
    ``(num_episodes, max_path_length, *field.shape[1:])``. Each episode's
    steps are copied to the front of its own row, and the rest of the row is
    zero. Other keys are left untouched. Steps after the last episode
    boundary (an unterminated trailing segment) are discarded.

    Args:
        dset: dict of per-step arrays; mutated in place.
        max_path_length: number of (padded) steps per episode row.
    Returns:
        `dset` itself, after mutation.
    Raises:
        IndexError: if a complete episode has more than `max_path_length`
            steps. This is raised before `dset` is modified.
    """
    fields = (
        'rewards', 'observations', 'actions',
        'infos/goal', 'infos/qpos', 'infos/qvel',
        'terminals', 'timeouts',
    )
    # computed before any field is overwritten below
    end_episode = (dset['timeouts'] + dset['terminals']) > 0

    # [start, end) bounds of every complete episode; `end` is one past the
    # step that closes it
    ends = np.flatnonzero(end_episode) + 1
    starts = np.concatenate(([0], ends))[:-1]
    lengths = ends - starts
    if lengths.size and lengths.max() > max_path_length:
        raise IndexError(
            'episode of length %d exceeds max_path_length=%d'
            % (int(lengths.max()), max_path_length)
        )

    for key in fields:
        src = np.asarray(dset[key])
        # a fresh buffer with one row per episode. (The previous version
        # reused a single buffer, so every episode aliased the same data.)
        padded = np.zeros((len(starts), max_path_length) + src.shape[1:])
        for row, (start, end) in enumerate(zip(starts, ends)):
            padded[row, :end - start] = src[start:end]
        dset[key] = padded

    return dset








