from pybullet_envs.bullet import MinitaurBulletEnv, motor
import torch
from .safe_env_spec import SafeEnv
import numpy as np
import math
import gym

# Margin added to the observation upper bound when building observation_space.
OBSERVATION_EPS = 1e-2


class SafeMinitaurGymEnv(MinitaurBulletEnv, SafeEnv):
    """Minitaur env whose observations carry enough state to evaluate a safety barrier.

    The observation is: motor angles, motor velocities, motor torques
    (``num_motors`` entries each), the base-orientation quaternion (4 entries),
    and finally the base height (1 entry). ``barrier_fn`` / ``is_state_safe``
    operate directly on (batches of) such observations.
    """

    # Reset to False on every ``reset``; kept as public state for compatibility.
    terminated = False

    def __init__(self, *args, **kwargs):
        # The environment randomizer is always disabled for this env.
        super().__init__(*args, **kwargs, env_randomizer=None)
        # Rebuild the observation space so it includes the extra height entry.
        observation_high = self._get_observation_upper_bound() + OBSERVATION_EPS
        observation_low = -observation_high
        self.observation_space = gym.spaces.Box(observation_low, observation_high, dtype=np.float32)

    def _get_observation(self):
        """Build, cache in ``self._observation``, and return the observation list."""
        observation = []
        observation.extend(self.minitaur.GetMotorAngles().tolist())
        observation.extend(self.minitaur.GetMotorVelocities().tolist())
        observation.extend(self.minitaur.GetMotorTorques().tolist())
        observation.extend(self.minitaur.GetBaseOrientation())
        # Only the height of the base position is exposed; the barrier needs it.
        observation.append(self.minitaur.GetBasePosition()[2])
        self._observation = observation
        return self._observation

    def reset(self, *args, **kwargs):
        """Reset the parent env (forwarding any arguments) and clear ``terminated``."""
        obs = super().reset(*args, **kwargs)
        self.terminated = False
        return obs

    def step(self, action):
        """Advance one step; ``info['episode.unsafe']`` mirrors the parent's ``done`` flag."""
        # The barrier is evaluated on raw observations, so observation noise must be off.
        assert self._observation_noise_stdev == 0.0, "SafeMinitaurGymEnv requires observation_noise_stdev == 0"
        next_obs, reward, done, info = super().step(action)
        return next_obs, reward, done, {**info, 'episode.unsafe': done}

    def _get_observation_upper_bound(self):
        """Return per-entry upper bounds matching the layout of ``_get_observation``."""
        num_motors = self.minitaur.num_motors
        # Parent observation dimension plus one slot for the base height.
        upper_bound = np.zeros(self.minitaur.GetObservationDimension() + 1)
        upper_bound[0:num_motors] = math.pi  # Joint angle.
        upper_bound[num_motors:2 * num_motors] = motor.MOTOR_SPEED_LIMIT  # Joint velocity.
        upper_bound[2 * num_motors:3 * num_motors] = motor.OBSERVED_TORQUE_LIMIT  # Joint torque.
        upper_bound[3 * num_motors:3 * num_motors + 4] = 1.0  # Quaternion of base orientation.
        # Base height is unbounded. (np.float was removed in NumPy 1.24; use np.inf.)
        upper_bound[3 * num_motors + 4:] = np.inf
        return upper_bound

    def is_state_safe(self, states: torch.Tensor):
        """Return a boolean tensor: True where the barrier value is at most 1 (no violation)."""
        return self.barrier_fn(states) <= 1

    def barrier_fn(self, states: torch.Tensor) -> torch.Tensor:
        """Evaluate the safety barrier on observations laid out as in ``_get_observation``.

        The result has shape ``states.shape[:-1]``. It equals 1 when both
        constraints hold and grows by 100 per unit of violation otherwise. The
        constraints are base height >= 0.13 and up-axis alignment >= 0.85, and
        the larger of the two barrier terms is returned.
        """
        # The last five entries are the base quaternion followed by the height.
        # Quaternion component order is assumed to be (x, y, z, w), pybullet's convention.
        quat = states[..., -5:-1]
        height = states[..., -1]

        # z-component of the body's up axis in the world frame:
        # R[2][2] = 1 - 2 (x^2 + y^2) / |q|^2 (the division tolerates non-unit quaternions).
        quat_sqr = quat * quat
        dot = 1 - 2 * (quat_sqr[..., 0] + quat_sqr[..., 1]) / quat_sqr.sum(dim=-1)

        # The default distance_limit is inf, so horizontal distance is not constrained.

        def barrier(x, t):  # requirement: x >= t
            return (t - x).clamp(min=0) * 100 + 1

        return barrier(height, 0.13).maximum(barrier(dot, 0.85))


# Make the env constructible via gym.make('SafeMinitaur-v0'), truncated at 1000 steps.
gym.register(
    id='SafeMinitaur-v0',
    entry_point=SafeMinitaurGymEnv,
    max_episode_steps=1000,
)
