import os

import gym
from gym.utils import EzPickle
import numpy as np

import envs


def create_wrapper(extract_behavior_func):
    """Build an environment class that pairs each observation with a behavior.

    Args:
        extract_behavior_func: Callable invoked as ``f(obs_memory)`` on reset
            and as ``f(obs_memory, info)`` on every step, where ``obs_memory``
            is the list of raw observations collected since the last reset.
            Its return value is exposed under the ``"behavior"`` key.

    Returns:
        The ``ObservationWrapper`` class (not an instance). Instantiate it
        with an environment id and any keyword arguments for ``gym.make``.
    """

    class ObservationWrapper(gym.Env, EzPickle):
        """Environment whose observations are ``{"obs": ..., "behavior": ...}``."""

        def __init__(self, env_name, **kwargs):
            # EzPickle re-creates the object by calling the constructor with
            # the arguments recorded here, so they must be passed along;
            # recording none would make unpickling fail on the missing
            # ``env_name``.
            EzPickle.__init__(self, env_name, **kwargs)
            self.env = gym.make(env_name, **kwargs)
            self.extract_behavior_func = extract_behavior_func
            self.action_space = self.env.action_space
            # NOTE: this is the wrapped env's space, i.e. it describes only the
            # "obs" entry of the dicts returned by reset() and step().
            self.observation_space = self.env.observation_space
            # Requires the env returned by gym.make to expose this attribute.
            self._max_episode_steps = self.env._max_episode_steps
            self.obs_memory = []

        def reset(self):
            """Reset the wrapped env and restart the observation history."""
            obs = self.env.reset()
            self.obs_memory = [obs]
            # No ``info`` exists yet, so the behavior function is called with
            # the observation history only.
            return {"obs": obs, "behavior": self.extract_behavior_func(self.obs_memory)}

        def step(self, action):
            """Step the wrapped env and return the dict observation.

            Returns:
                ``(wrapped_obs, reward, done, info)`` where ``wrapped_obs`` is
                ``{"obs": obs, "behavior": behavior}``; the other three items
                are passed through from the wrapped env unchanged.
            """
            obs, reward, done, info = self.env.step(action)
            self.obs_memory.append(obs)
            behavior = self.extract_behavior_func(self.obs_memory, info)
            wrapped_obs = {"obs": obs, "behavior": behavior}
            return wrapped_obs, reward, done, info

        def render(self, *args, **kwargs):
            """Render the wrapped env, forwarding arguments and its result."""
            # Returning the result lets callers obtain whatever the wrapped
            # env's render() produces instead of always receiving None.
            return self.env.render(*args, **kwargs)

        def seed(self, seed=None):
            """Seed the wrapped env and return whatever it returns."""
            return self.env.seed(seed)

        def close(self):
            """Close the wrapped env so its resources are released."""
            self.env.close()

    return ObservationWrapper


def behavior_func_point_maze(obs_memory, info=False):
    """Return the most recent observation as a new NumPy array.

    Args:
        obs_memory: Non-empty sequence of observations; only the last entry
            is read.
        info: Unused; accepted so all behavior functions share one signature.

    Returns:
        ``np.array`` built from ``obs_memory[-1]``.
    """
    return np.array(obs_memory[-1])


def behavior_func_point_maze_inertia(obs_memory, info=False):
    """Return the first two entries of the most recent observation.

    Args:
        obs_memory: Non-empty sequence of observations; only the last entry
            is read.
        info: Unused; accepted so all behavior functions share one signature.

    Returns:
        The leading two elements of ``np.array(obs_memory[-1])``
        (presumably the planar position -- confirm against the env).
    """
    latest = np.array(obs_memory[-1])
    return latest[:2]


def behavior_func_ant(obs_memory, info=False):
    """Return the 2-D behavior descriptor read from the step ``info``.

    Args:
        obs_memory: Observation history; accepted for signature parity with
            the other behavior functions but not read here.
        info: Mapping with ``"x_position"`` and ``"y_position"`` keys. Any
            falsy value (the default ``False``, ``None``, an empty dict)
            selects the fallback.

    Returns:
        ``np.array([x, y])`` built from ``info``, or a float zero vector of
        shape ``(2,)`` when no info is available.
    """
    if not info:
        # Float zeros rather than integer zeros, so the fallback has the same
        # dtype as a descriptor built from float positions and consumers never
        # see an int64 array followed by float64 ones.
        return np.zeros(2)
    return np.array([info["x_position"], info["y_position"]])


def behavior_func_humanoid(obs_memory, info=False):
    """Return the 2-D behavior descriptor read from the step ``info``.

    Args:
        obs_memory: Observation history; accepted for signature parity with
            the other behavior functions but not read here.
        info: Mapping with ``"x_position"`` and ``"y_position"`` keys. Any
            falsy value (the default ``False``, ``None``, an empty dict)
            selects the fallback.

    Returns:
        ``np.array([x, y])`` built from ``info``, or a float zero vector of
        shape ``(2,)`` when no info is available.
    """
    if not info:
        # Float zeros rather than integer zeros, so the fallback has the same
        # dtype as a descriptor built from float positions and consumers never
        # see an int64 array followed by float64 ones.
        return np.zeros(2)
    return np.array([info["x_position"], info["y_position"]])