from gym import register
from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
from .safe_env_spec import SafeEnv, interval_barrier
import numpy as np

import torch
import torch.nn as nn

from safe.barrier import Barrier

# Device onto which SafeInvertedPendulumEnv2.step() moves the state/action
# tensors before evaluating the constraint model: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def build_lyapunov():
    """Build the module-level constraint model ``U`` used by the env's ``step()``.

    This module provides no implementation. ``step()`` only calls this
    function when the module-level ``U`` is still ``None``, so assigning a
    callable to ``U`` before the first step avoids reaching it.

    Raises:
        NotImplementedError: always. An explicit exception is used instead of
            ``assert False`` because asserts are stripped under ``python -O``.
            With the assert stripped, the stub would silently return ``None``
            and ``step()`` would later fail with an unrelated ``TypeError``.
    """
    raise NotImplementedError(
        "build_lyapunov() is a stub; assign the module-level `U` callable "
        "before stepping SafeInvertedPendulumEnv2."
    )

# Module-level constraint model shared by every SafeInvertedPendulumEnv2
# instance. step() calls it as U(state_tensor, action_tensor / 3) and treats a
# result whose `.item()` exceeds 1 as a violation that ends the episode.
# If it is still None on the first step, step() fills it via build_lyapunov();
# presumably callers assign a trained model here beforehand -- TODO confirm.
U = None


class SafeInvertedPendulumEnv2(InvertedPendulumEnv, SafeEnv):
    """Inverted pendulum with a deterministic reset and an extra constraint check.

    An episode terminates with reward -1000 in either of two cases:

    * the parent ``step`` reports ``done``. This is tagged with
      ``info['episode.unsafe']``.
    * the module-level constraint model ``U`` scores the pre-step
      (state, normalised action) pair above 1. This is tagged with
      ``{'U': True}``.

    Otherwise the reward is the negative quadratic cost computed by
    ``reward_fn``.
    """

    def reset_model(self):
        """Reset the simulator exactly to ``init_qpos`` / ``init_qvel``.

        No perturbation is applied here.

        Returns:
            The float32 observation after the reset.
        """
        self.set_state(self.init_qpos, self.init_qvel)
        return self._get_obs()

    def _get_obs(self):
        """Return the parent observation cast to float32.

        ``step()`` passes this array to ``torch.from_numpy``.
        """
        return super()._get_obs().astype(np.float32)

    def step(self, a):
        """Advance the simulation by one step and apply the safety checks.

        Args:
            a: action; ``a[0]`` is used in the reward. The action space is
                asserted to be [-3, 3], and the action is divided by 3 before
                it is passed to ``U``.

        Returns:
            ``(next_state, reward, done, info)``:

            * When ``U`` reports a violation:
              ``(next_state, -1000, True, {'U': True})``.
            * When the parent step reports ``done``: reward -1000 and
              ``info['episode.unsafe'] = True``.
            * Otherwise: reward
              ``-(next_state[0]**2 + next_state[1]**2) - 0.01 * a[0]**2``.

        Raises:
            AssertionError: if the action-space bounds are not exactly ±3
                (only when asserts are enabled).
            Whatever ``build_lyapunov`` raises when ``U`` is still unset.
        """
        global U
        if U is None:
            U = build_lyapunov()

        # The constraint model is evaluated on the state *before* the action.
        s = self._get_obs()
        next_state, _, done, info = super().step(a)
        assert self.action_space.low[0] == -3 and self.action_space.high[0] == 3

        # HACK: observation_space is still None while the parent constructor
        # runs. According to the original note, that constructor steps with a
        # random action and checks `done`. Skip the constraint check during
        # initialisation so it cannot force done=True there.
        if self.observation_space is not None:
            # Normalise the action to [-1, 1] and keep it float32, matching `s`.
            action = np.asarray(a, dtype=np.float32) / 3
            # Inference only: no autograd graph is needed for a scalar check.
            with torch.no_grad():
                score = U(torch.from_numpy(s).to(device),
                          torch.from_numpy(action).to(device)).item()
            if score > 1:
                return next_state, -1000, True, {'U': True}

        # Negative quadratic cost, the same formula as reward_fn. The original
        # returned the positive cost here, which rewarded large deviations.
        reward = -(next_state[0] ** 2 + next_state[1] ** 2) - a[0] ** 2 * 0.01
        if done:
            # The parent's termination is treated as an unsafe event.
            reward = -1000
            info['episode.unsafe'] = True
        return next_state, reward, done, info

    def is_state_safe(self, states):
        """Return a boolean mask that is true where ``|states[..., 1]| <= 0.2``.

        Uses the ``.abs()`` method, so presumably ``states`` is a torch tensor
        (NumPy arrays have no ``.abs()`` method) -- TODO confirm.
        """
        return states[..., 1].abs() <= 0.2

    def barrier_fn(self, states):
        """Return ``interval_barrier`` of ``states[..., 1]`` over [-0.2, 0.2].

        The meaning of the returned value is defined by ``interval_barrier``
        in ``safe_env_spec``, which is not visible here.
        """
        return interval_barrier(states[..., 1], -0.2, 0.2)

    def reward_fn(self, states, actions, next_states):
        """Batched reward: ``-(ns[..., 0]**2 + ns[..., 1]**2) - 0.01 * act[..., 0]**2``.

        ``states`` is accepted for interface compatibility but not used.
        """
        return -(next_states[..., 0]**2 + next_states[..., 1]**2) - actions[..., 0]**2 * 0.01


# Make the environment available to gym.make('SafeInvertedPendulum-v3'),
# with max_episode_steps=1000.
register(
    'SafeInvertedPendulum-v3',
    entry_point=SafeInvertedPendulumEnv2,
    max_episode_steps=1000,
)
