import numpy as np
from sklearn.utils.validation import check_random_state
# from rampwf.utils.importing import import_module_from_source

from gym.envs.classic_control import AcrobotEnv

# reward_module = import_module_from_source(
#     'reward_function.py', 'reward_function')
# reward_func = reward_module.reward_func


def reward_func(observations):
    """Compute the acrobot reward for a batch of observations.

    The observations are also the ones predicted by the model.
    For compatibility with other environments the actions are expected to be
    appended to the observations, as in general the reward can be a function
    of both the actions and the observations.

    The original reward of acrobot is in [-2, 2]. It is shifted by 2 here so
    that the returned reward is in [0, 4].

    Parameters
    ----------
    observations : array, shape (n_samples, n_observations + n_actions)
        Observations and actions. Only the first four columns are read. The
        last feature is the action, which is not used here but put for
        compatibility with other environments.
        Note that this is the action leading to the obtained observations.

    Return
    ------
    reward : array, shape (n_samples,)
        Reward of each sample.
    """
    # NOTE(review): the names below assume the standard Gym Acrobot layout
    # [cos(theta1), sin(theta1), cos(theta2), sin(theta2), ...] -- confirm.
    cos_theta1, sin_theta1 = observations[:, 0], observations[:, 1]
    cos_theta2, sin_theta2 = observations[:, 2], observations[:, 3]

    # Under that assumption this is cos(theta1) + cos(theta1 + theta2), with
    # the second cosine expanded by the angle addition formula.
    shifted_away = cos_theta1 + cos_theta1 * cos_theta2 - sin_theta1 * sin_theta2

    return 2 - shifted_away


def safety_func(observations, unsafe_tip_limit=3):
    """Flag the samples that fall in the unsafe region.

    A cost is computed from the first four columns of the observations and a
    sample is declared unsafe when this cost is strictly greater than
    ``unsafe_tip_limit``.

    Parameters
    ----------
    observations : array, shape (n_samples, n_observations + n_actions)
        Observations and actions. Only the first four columns are read; any
        extra column (e.g. the appended action) is ignored.
        NOTE(review): presumably the Gym Acrobot layout
        [cos(theta1), sin(theta1), cos(theta2), sin(theta2), ...] and not the
        raw angles -- confirm against the caller.
    unsafe_tip_limit : float, default=3
        Cost threshold above which a sample is unsafe. The default keeps the
        threshold that was previously hard-coded.

    Returns
    -------
    unsafe : array of bool, shape (n_samples,)
        True for the samples whose cost exceeds ``unsafe_tip_limit``.
    """
    cost = 2 - (observations[:, 0] +
                observations[:, 0] * observations[:, 2] -
                observations[:, 1] * observations[:, 3])

    # Strict comparison: a cost exactly equal to the limit is still safe.
    return cost > unsafe_tip_limit


class Env(AcrobotEnv):
    """Open AI Gym acrobot env with reward of benchmark paper.

    Differences with the parent ``AcrobotEnv`` implemented in this class:

    * ``seed`` also accepts instances of numpy.random.RandomState.
    * ``step`` replaces the parent reward by ``reward_func``, only reports a
      time limit through ``done`` and stores a safety flag in
      ``info['cost']``.
    * ``reset`` can optionally return an entry sampled from
      ``real_states_history`` instead of calling the parent ``reset``.

    This is a modified version of acrobot with a risky region that needs to
    be avoided. Whether a step is unsafe is decided by ``self.safety_func``.

    Parameters
    ----------
    max_episode_steps : int, default=200
        Number of elapsed steps from which ``step`` reports ``done``.

    Attributes
    ----------
    dynamic_reset : bool
        If True, ``reset`` samples its return value from
        ``real_states_history``. False by default.
    real_states_history : sequence
        Candidates returned by ``reset`` when ``dynamic_reset`` is True.
        Empty by default, to be filled by the user.
    unsafe_tip_limit : float
        Set to 3. Not read by any method of this class.
    safety_func : callable
        Called in ``step`` with an array of shape
        (1, n_observations + n_actions). The first element of its output is
        stored in ``info['cost']``.
    """
    def __init__(self, max_episode_steps=200):
        self.max_episode_steps = max_episode_steps
        super(Env, self).__init__()

        # Defined here so that the counter exists even if step is reached
        # without going through reset, set_state or set_numpy_state first.
        self._elapsed_steps = 0
        self.dynamic_reset = False
        self.real_states_history = []
        self.unsafe_tip_limit = 3
        self.safety_func = safety_func

    def seed(self, seed=None):
        """Same as parent method but passing a RandomState instance is allowed.

        Parameters
        ----------
        seed : None, int or numpy.random.RandomState
            Converted with scikit-learn's ``check_random_state`` and stored
            in ``self.np_random``.

        Returns
        -------
        seeds : list
            One-element list containing ``seed`` exactly as it was passed.
        """
        self.np_random = check_random_state(seed)
        return [seed]

    def reset(self):
        """Same as parent method but resetting the number of elapsed steps.

        If ``dynamic_reset`` is True the returned value is drawn at random
        (using ``self.np_random``) from ``real_states_history`` and the
        parent ``reset`` is not called.

        Returns
        -------
        observations : array
            Output of the parent ``reset``, or the sampled history entry.

        Raises
        ------
        ValueError
            If ``dynamic_reset`` is True and ``real_states_history`` is
            empty.
        """
        if self.dynamic_reset:
            n_candidates = len(self.real_states_history)
            # len() rather than truthiness as the history may be an array.
            if n_candidates == 0:
                raise ValueError(
                    'dynamic_reset is True but real_states_history is empty.')
            # NOTE(review): self.state is not updated in this branch, so the
            # state held by the parent class is left as it was before the
            # call -- confirm this is intended.
            observations = self.real_states_history[
                self.np_random.choice(n_candidates)]
        else:
            observations = super(Env, self).reset()

        self._elapsed_steps = 0
        return observations

    def step(self, action):
        """Same as parent method but different reward, done flag and info.

        The reward and the termination flag returned by the parent are
        discarded. The reward is recomputed with ``reward_func`` and ``done``
        only tells whether ``max_episode_steps`` has been reached.

        Parameters
        ----------
        action : int or array
            Action passed to the parent ``step``. It is also appended to the
            observations given to the reward and safety functions.

        Returns
        -------
        observations : array
            Observations returned by the parent ``step``.
        reward : float
            First element of the output of ``reward_func``.
        done : int
            1 if the number of elapsed steps is greater than or equal to
            ``max_episode_steps``, 0 otherwise.
        info : dict
            Info of the parent ``step`` with an extra ``'cost'`` key holding
            the first element of the output of ``self.safety_func``.
        """
        observations, _, _, info = super(Env, self).step(action)

        self._elapsed_steps += 1
        # Single row of shape (1, n_observations + n_actions), built once
        # and shared by the reward and the safety functions.
        obs_action = np.r_[observations, action].reshape(1, -1)
        reward = reward_func(obs_action)[0]
        # using >= in case we need the info when planning with the real env
        done = (self._elapsed_steps >= self.max_episode_steps)
        info['cost'] = self.safety_func(obs_action)[0]

        return observations, reward, int(done), info

    def set_state(self, state_dict):
        """Set state of the environment.

        Parameters
        ----------
        state_dict : dict
            Dictionary with the keys ``'qpos'``, ``'qvel'`` and
            ``'_elapsed_steps'``, as returned by ``get_state``. ``'qpos'``
            and ``'qvel'`` are concatenated, in this order, into
            ``self.state``.
        """
        self._elapsed_steps = state_dict['_elapsed_steps']
        self.state = np.r_[state_dict['qpos'], state_dict['qvel']]

    def get_state(self):
        """Get state of the environment.

        Returns
        -------
        state_dict : dict
            ``'qpos'`` holds the first two entries of ``self.state``,
            ``'qvel'`` the remaining ones and ``'_elapsed_steps'`` the
            current step counter.
        """
        state = self.state
        state_dict = {
            'qpos': state[:2],
            'qvel': state[2:],
            '_elapsed_steps': self._elapsed_steps,
        }
        return state_dict

    def get_numpy_state(self):
        """Get the state numpy array from the environment.

        Returns
        -------
        numpy_state : array
            Concatenation of the ``'qpos'`` and ``'qvel'`` entries of
            ``get_state``. The step counter is not included.
        """
        state_dict = self.get_state()
        return np.r_[state_dict['qpos'], state_dict['qvel']]

    def set_numpy_state(self, numpy_state):
        """Set the state from a numpy array.

        Note that the _elapsed_steps attribute is reset to 0.

        Parameters
        ----------
        numpy_state : array
            Stored as is (no copy) in ``self.state``.
        """
        self.state = numpy_state
        self._elapsed_steps = 0