import numpy as np
import gym
import torch

def to_torch(x, dtype=None, device=None):
    """Copy ``x`` into a new torch tensor.

    Args:
        x: Anything ``torch.tensor`` accepts (list, ndarray, scalar, ...).
        dtype: Target dtype; any falsy value falls back to ``torch.float``.
        device: Target device; any falsy value falls back to ``'cuda:0'``.

    Returns:
        A freshly allocated ``torch.Tensor``.
    """
    if not dtype:
        dtype = torch.float
    if not device:
        device = 'cuda:0'
    return torch.tensor(x, device=device, dtype=dtype)

def load_environment(name):
    """Create the gym environment ``name`` and return its unwrapped core env.

    The step limit stored on the outer wrapper (its ``_max_episode_steps``)
    is copied onto the returned env as ``max_episode_steps``. The env is also
    tagged with ``name``.
    """
    from src.utils.suppress import suppress_output

    # d4rl has to be imported before gym.make can build its environments;
    # the import is noisy, so its output is silenced.
    with suppress_output():
        import d4rl

    with suppress_output():
        outer = gym.make(name)

    core = outer.unwrapped
    core.max_episode_steps = outer._max_episode_steps
    core.name = name
    return core

def qlearning_dataset_with_timeouts(env=None, dataset=None, terminate_on_end=False, esper=False, **kwargs):
    """Build (s, a, r, s', done) transitions by pairing each row with the next.

    Row ``i`` uses ``observations[i + 1]`` as its successor observation. Only
    rows ``0 .. N-2`` can form a transition (``N = len(dataset['rewards'])``),
    so the final row is never emitted.

    Rows flagged in ``dataset['timeouts']`` are dropped unless
    ``terminate_on_end`` is true. In that case they are kept and marked done.

    Args:
        env: Object exposing ``get_dataset(**kwargs)``; used only when
            ``dataset`` is None.
        dataset: Mapping of array-likes with keys ``observations``,
            ``actions``, ``rewards``, ``terminals`` and ``timeouts``, plus
            ``esper_return`` when ``esper`` is true.
        terminate_on_end: Keep timeout rows (marked done) instead of
            skipping them.
        esper: Also collect ``dataset['esper_return']`` for kept rows.
        **kwargs: Forwarded to ``env.get_dataset``.

    Returns:
        dict containing:
        - ``observations``, ``actions``, ``next_observations``: row-aligned
          arrays, always copies of the input.
        - ``rewards``, ``terminals``, ``realterminals``, ``esper_returns``:
          ``(M, 1)`` column vectors. ``terminals`` is terminal OR timeout;
          ``realterminals`` is terminal only. ``esper_returns`` has shape
          ``(0, 1)`` when ``esper`` is false.
    """
    assert (env is not None) or (dataset is not None), 'pass either env or dataset'
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    # The last row has no successor observation, so it cannot start a transition.
    candidates = np.arange(max(N - 1, 0))
    n = candidates.size

    def take(key, rows):
        # Integer-array indexing always copies, so outputs never alias the
        # input. It raises IndexError if a field has too few rows.
        return np.asarray(dataset[key])[rows]

    # reshape(n) accepts (n,) or (n, 1) flag arrays; wider shapes raise.
    terminals = take('terminals', candidates).astype(bool).reshape(n)
    timeouts = take('timeouts', candidates).astype(bool).reshape(n)

    # `candidates` is arange(n), so mask positions equal dataset row indices.
    rows = candidates if terminate_on_end else candidates[~timeouts]

    # A kept timeout row is reported as done; realterminals keeps only true
    # environment terminations.
    dones = (terminals | timeouts)[rows]
    esper_returns = take('esper_return', rows) if esper else np.array([])

    return {
        'observations': take('observations', rows),
        'actions': take('actions', rows),
        'next_observations': take('observations', rows + 1),
        'rewards': take('rewards', rows)[:, None],
        'terminals': dones[:, None],
        'realterminals': terminals[rows][:, None],
        'esper_returns': esper_returns[:, None],
    }


def qlearning_dataset(env=None, dataset=None, terminate_on_end=False, esper=False, **kwargs):
    """Repackage a dataset that already stores explicit ``next_observations``.

    Successor observations are read from ``dataset['next_observations']``
    rather than derived from the following row. Every one of the first ``N``
    rows is kept, where ``N = len(dataset['rewards'])``.

    Args:
        env: Object exposing ``get_dataset(**kwargs)``; used only when
            ``dataset`` is None.
        dataset: Mapping of array-likes with keys ``observations``,
            ``next_observations``, ``actions``, ``rewards`` and
            ``terminals``, plus ``esper_return`` when ``esper`` is true.
        terminate_on_end: Not used by this function; accepted so the
            signature matches ``qlearning_dataset_with_timeouts``.
        esper: Also collect ``dataset['esper_return']``.
        **kwargs: Forwarded to ``env.get_dataset``.

    Returns:
        dict containing:
        - ``observations``, ``actions``, ``next_observations``: copies of the
          first ``N`` rows.
        - ``rewards``, ``terminals``, ``realterminals``, ``esper_returns``:
          ``(N, 1)`` column vectors. ``terminals`` and ``realterminals`` are
          identical but independent bool arrays. ``esper_returns`` has shape
          ``(0, 1)`` when ``esper`` is false.
    """
    assert (env is not None) or (dataset is not None), 'pass either env or dataset'
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    rows = np.arange(N)

    def take(key):
        # Integer-array indexing always copies the rows, and it raises
        # IndexError if a field has fewer than N rows.
        return np.asarray(dataset[key])[rows]

    # reshape(N) accepts (N,) or (N, 1) flag arrays; wider shapes raise.
    terminals = take('terminals').astype(bool).reshape(N)
    esper_returns = take('esper_return') if esper else np.array([])

    return {
        'observations': take('observations'),
        'actions': take('actions'),
        'next_observations': take('next_observations'),
        'rewards': take('rewards')[:, None],
        'terminals': terminals[:, None],
        # Separate copy so mutating one flag array never changes the other.
        'realterminals': terminals.copy()[:, None],
        'esper_returns': esper_returns[:, None],
    }