import os
import collections
import numpy as np
import gym
import pdb
import h5py

from contextlib import (
    contextmanager,
    redirect_stderr,
    redirect_stdout,
)

def h5py_to_dict(path):
    """
        Load the contents of an HDF5 file into plain (nested) dictionaries.

        Args:
            path: location of the HDF5 file; it is opened read-only.
        Returns:
            A dict keyed like the file. Every dataset is read fully into
            memory and every group becomes a nested dict, so the result
            stays usable after the file has been closed.
    """
    def _to_python(node):
        ## datasets are read eagerly: a lazy h5py handle would be dead
        ## once the file is closed below
        if isinstance(node, h5py.Dataset):
            ## `[()]` reads the whole dataset and also works for scalars
            return node[()]
        if isinstance(node, h5py.Group):
            return {key: _to_python(node[key]) for key in node.keys()}
        ## anything else (neither dataset nor group) is passed through as is
        return node

    ## the context manager guarantees the file handle is released,
    ## even if reading one of the entries fails
    with h5py.File(path, 'r') as hdf5:
        return _to_python(hdf5)

@contextmanager
def suppress_output():
    """
        Silence the process for the duration of a `with` block: stderr and
        stdout are both redirected into a handle opened on os.devnull.
        Yields the values bound by redirect_stderr and redirect_stdout,
        in that order.
        https://stackoverflow.com/a/52442331
    """
    with open(os.devnull, 'w') as sink, \
            redirect_stderr(sink) as muted_err, \
            redirect_stdout(sink) as muted_out:
        yield (muted_err, muted_out)

with suppress_output():
    ## d4rl prints out a variety of warnings when it is imported, so the
    ## import runs with stdout and stderr redirected to devnull.
    ## NOTE(review): presumably imported for its side effect of registering
    ## the d4rl environments with gym -- confirm before removing.
    import d4rl

#-----------------------------------------------------------------------------#
#-------------------------------- general api --------------------------------#
#-----------------------------------------------------------------------------#

def load_environment(name):
    """
        Turn an environment name into an unwrapped gym environment.

        Args:
            name: a gym environment id, or an already-constructed environment.
        Returns:
            - `name` itself when it is not a string (it is assumed to already
              be an environment);
            - None when the name contains 'unitree' (no gym environment is
              built for those names);
            - otherwise the unwrapped environment created by `gym.make`,
              annotated with `max_episode_steps` (copied from the wrapper's
              `_max_episode_steps`) and `name`.
    """
    ## isinstance (rather than an exact type comparison) so that str
    ## subclasses are still treated as environment names
    if not isinstance(name, str):
        ## name is already an environment
        return name
    if 'unitree' in name:
        return None
    ## environment construction is noisy, so silence stdout / stderr
    with suppress_output():
        wrapped_env = gym.make(name)
    env = wrapped_env.unwrapped
    ## the step limit lives on the wrapper; keep it reachable after unwrapping
    env.max_episode_steps = wrapped_env._max_episode_steps
    env.name = name
    return env

def get_dataset(env):
    """
        Fetch the offline dataset of `env` via `env.get_dataset()`.

        For environments whose string representation contains 'antmaze'
        the dataset is additionally passed through antmaze_fix_timeouts and
        antmaze_scale_rewards, and get_max_delta is called on the result
        (its return value is not used here).
    """
    raw = env.get_dataset()

    env_label = str(env).lower()
    if 'antmaze' not in env_label:
        return raw

    ## the antmaze-v0 environments have a variety of bugs
    ## involving trajectory segmentation, so manually reset
    ## the terminal and timeout fields
    fixed = antmaze_fix_timeouts(raw)
    fixed = antmaze_scale_rewards(fixed)
    get_max_delta(fixed)
    return fixed

def _episode_step_limit(env):
    """
    Return the step limit used to cut episodes when the dataset has no
    `timeouts` field.
    """
    ## wrapped gym envs expose `_max_episode_steps`; the unwrapped envs
    ## returned by load_environment carry the same value as
    ## `max_episode_steps`, so accept either spelling
    limit = getattr(env, '_max_episode_steps', None)
    if limit is None:
        limit = env.max_episode_steps
    return limit


def sequence_dataset(env, preprocess_fn, dataset=None,po=0,occlude_start_idx=-2):
    """
    Returns an iterator through trajectories.
    Args:
        env: An OfflineEnv object. May be None when `dataset` is given and
            contains a `timeouts` field (the env is then never consulted).
        preprocess_fn: Callable applied to the whole dataset before it is
            split into episodes; its return value is what gets iterated.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to get_dataset(env).
        po: When equal to 1, the `observations` of every yielded episode
            are truncated to the columns `[:occlude_start_idx]`.
        occlude_start_idx: Column index at which the dropped observation
            features start (only used when po == 1).
    Returns:
        An iterator through dictionaries holding one episode each, with one
        numpy array per dataset key (keys containing 'metadata' are skipped),
        e.g.:
            observations
            actions
            rewards
            terminals
        Episodes of maze2d environments are additionally passed through
        process_maze2d_episode. Steps after the last episode boundary are
        never yielded.
    Note:
        If the dataset has a `gaits` entry it is repeated along axis 0 and
        written back into the dataset dict in place.
    """
    if dataset is None:
        dataset = get_dataset(env)
    dataset = preprocess_fn(dataset)

    N = dataset['rewards'].shape[0]
    if 'gaits' in dataset.keys():
        ## stretch `gaits` so that it can be indexed per step like the
        ## other fields
        print("BEFORE: ", dataset['gaits'].shape, dataset['observations'].shape)
        gaits = np.repeat(dataset['gaits'], int(N / dataset['gaits'].shape[0]), 0)
        dataset['gaits'] = gaits
        print("AFTER: ", dataset['gaits'].shape)
    data_ = collections.defaultdict(list)

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    episode_step = 0
    for i in range(N):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == _episode_step_limit(env) - 1)

        for k in dataset:
            if 'metadata' in k: continue
            data_[k].append(dataset[k][i])

        if not (done_bool or final_timestep):
            episode_step += 1
            continue

        ## episode boundary: restart the counter at exactly 0 so that every
        ## episode (not only the first) is allowed the full step limit
        episode_step = 0
        episode_data = {}
        for k in data_:
            if k == 'observations' and po==1:
                obs = np.array(data_[k])
                # Select only the observed features
                obs = obs[:, :occlude_start_idx]
                episode_data[k] = obs
            else:
                episode_data[k] = np.array(data_[k])

        ## env may legitimately be None (see load_environment)
        if env is not None and 'maze2d' in env.name:
            episode_data = process_maze2d_episode(episode_data)
        yield episode_data
        data_ = collections.defaultdict(list)


#-----------------------------------------------------------------------------#
#-------------------------------- maze2d fixes -------------------------------#
#-----------------------------------------------------------------------------#

def process_maze2d_episode(episode):
    '''
        builds a `next_observations` field from the observations shifted
        by one step, and trims the final step from every field already in
        the episode so that all fields end up with the same length;
        the episode dict is modified in place and returned
    '''
    assert 'next_observations' not in episode
    shifted = episode['observations'][1:].copy()
    ## snapshot the keys first; only existing entries are re-assigned
    for field in list(episode):
        episode[field] = episode[field][:-1]
    episode['next_observations'] = shifted
    return episode
