from k_level_policy_gradients.src.core.environment import Environment, MDPInfo
from k_level_policy_gradients.src.utils.spaces import *
from k_level_policy_gradients.src.utils.viewer import Viewer


class Plug(Environment):
    """
    Minimal single-step, multi-agent environment with continuous actions.

    Each agent has a one-component action in ``[-1, 1]`` and is rewarded with
    the negated square of the sum of its own action components, so the best
    action is 0. The state is a constant ``np.array([0])`` that never changes
    and every call to ``step`` reports the transition as absorbing.
    """

    def __init__(self, n_agents=1, horizon=100, gamma=0.99, bool_render=False):
        """
        Build the action/observation/state spaces and register the MDP info.

        Arguments:
            n_agents (int): Number of agents acting in the environment.
            horizon (int): Horizon forwarded to ``MDPInfo``.
            gamma (float): Discount factor forwarded to ``MDPInfo``.
            bool_render (bool): Stored on the instance; not used by the
                methods of this class.
        """
        self._n_agents = n_agents
        # NOTE(review): _target, _bool_render and _dt are stored but never
        # read by this class -- presumably kept for parity with sibling
        # environments; confirm before removing.
        self._target = 0
        self._bool_render = bool_render
        self._dt = 0.01

        # One scalar action in [-1, 1] and one scalar observation per agent.
        self.action_space = [Box(-1, 1, shape=(1,)) for _ in range(self._n_agents)]
        state_space = Box(-np.inf, np.inf, shape=(1 * (self._n_agents),))
        observation_space = [
            Box(-np.inf, np.inf, shape=(1,)) for _ in range(self._n_agents)
        ]

        # Set the MDP info
        mdp_info = MDPInfo(
            state_space=state_space,
            observation_space=observation_space,
            action_space=self.action_space,
            discrete_actions=False,
            gamma=gamma,
            horizon=horizon,
            has_obs=True,
            has_action_masks=False,
            n_agents=self._n_agents,
        )

        super().__init__(mdp_info)

    def reset(self):
        """
        Reset the environment to its constant initial state.

        Returns:
            dict: ``"state"`` (the state array), ``"obs"`` (a list with one
            observation per agent, each being the state array) and an empty
            ``"info"`` dict.
        """
        # NOTE(review): the declared state_space has shape (n_agents,) while
        # the state itself has shape (1,); they only agree for n_agents == 1.
        # Left unchanged because the per-agent observation shape (1,) relies
        # on it -- confirm the intended state layout.
        self._state = np.array([0])

        # One observation per agent, matching observation_space and step().
        obs = [self._state for _ in range(self._n_agents)]

        return {"state": self._state, "obs": obs, "info": {}}

    def step(self, actions):
        """
        Returns the next state, obs, rewards, absorbing flag, and info.

        Arguments:
            actions (np.ndarray): The actions to take in the environment,
                indexed by agent. Per the action space, each agent's action
                has a single component between -1 and 1.

        Returns:
            dict: ``"state"`` (unchanged state array), ``"obs"`` (one
            observation per agent), ``"rewards"`` (one float per agent),
            ``"absorbing"`` (always ``True``) and an empty ``"info"`` dict.
        """
        # Reward is -(sum of the agent's action components) ** 2. The
        # parentheses make the precedence explicit: the negation is applied
        # after squaring, so rewards are never positive.
        rewards = [
            float(-(np.sum(actions[i]) ** 2)) for i in range(self._n_agents)
        ]
        # The state never changes, so every agent observes the same array.
        obs = [self._state for _ in range(self._n_agents)]

        return {
            "state": self._state,
            "obs": obs,
            "rewards": rewards,
            "absorbing": True,
            "info": {},
        }
