from gym import register, Env
from gym.spaces import Box
from .safe_env_spec import SafeEnv, interval_barrier
import numpy as np
import torch


class RealEnv(Env, SafeEnv):
    """One-dimensional, action-independent environment for safe-RL experiments.

    The scalar state follows the explicit-Euler update

        s' = s + s * (|s| - barrier) * dt

    regardless of the action taken. The drift vanishes at ``s = 0`` and
    ``s = +/-barrier``. For ``0 < |s| < barrier`` the state shrinks toward 0;
    beyond ``barrier`` it grows. The state is unsafe once ``|s| > threshold``.

    Besides the gym ``reset``/``step`` API, the class provides batched torch
    versions of the model (``trans_fn``, ``reward_fn``, ``done_fn``,
    ``is_state_safe``, ``barrier_fn``). These are presumably the hooks
    required by ``SafeEnv``; confirm against its definition.

    Args:
        s0: Initial scalar state used by :meth:`reset`.
        threshold: Safety bound; states with ``|s| > threshold`` are unsafe.
        dt: Euler integration step size.
        barrier: Magnitude at which the drift changes sign.
    """

    def __init__(self, s0=0.3, threshold=1, dt=0.001, barrier=0.7):
        super().__init__()
        self.s0 = np.array([s0])
        self.threshold = threshold
        self.dt = dt
        self.barrier = barrier
        self.state = None  # populated by reset(); step() requires it
        self.observation_space = Box(-10, 10, [1])
        # Declared for API compatibility only: the dynamics ignore actions.
        self.action_space = Box(-1, 1, [1])

    def _next_state(self, states):
        """Apply one Euler step of the drift to ``states``.

        Uses builtin ``abs`` so the same formula works on numpy arrays
        (:meth:`step`) and on torch tensors (:meth:`trans_fn`). This keeps
        the real environment and the batched model from drifting apart.
        """
        return states + states * (abs(states) - self.barrier) * self.dt

    def step(self, action: np.ndarray):
        """Advance the environment by one time step.

        Args:
            action: Ignored; the dynamics do not depend on the action.

        Returns:
            The classic gym 4-tuple ``(state, reward, done, info)``. The
            reward is always 1. ``done`` is True once
            ``|state| > threshold``, and in that case
            ``info['episode.unsafe']`` is set to True.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.state is None:
            raise RuntimeError("RealEnv.step() called before reset()")
        self.state = self._next_state(self.state)
        done = bool(abs(self.state[0]) > self.threshold)
        info = {}
        if done:
            info['episode.unsafe'] = True
        return self.state, 1, done, info

    def reset(self):
        """Start a new episode at the initial state ``s0`` and return it."""
        # Copy so in-place edits of the returned observation (or of
        # self.state) can never corrupt s0 for later episodes.
        self.state = self.s0.copy()
        return self.state

    def is_state_safe(self, states):
        """Return a bool tensor: True where ``|states[..., 0]| <= threshold``.

        ``states`` is expected to be a torch tensor (``.abs()`` is a tensor
        method) whose last axis holds the scalar state.
        """
        return states[..., 0].abs() <= self.threshold

    def barrier_fn(self, states):
        """Evaluate ``interval_barrier`` on the safe interval [-threshold, threshold]."""
        return interval_barrier(states[..., 0], -self.threshold, self.threshold)

    def reward_fn(self, states, actions, next_states):
        """Return a constant reward of 1 shaped like ``actions`` (a tensor)."""
        return torch.full_like(actions, 1.)

    def done_fn(self, states, actions, next_states):
        """Return an all-False bool tensor shaped like ``actions``.

        NOTE(review): unlike :meth:`step`, the model never signals
        termination on unsafe ``next_states``. Confirm this is intended
        (e.g. safety handled via ``barrier_fn`` instead).
        """
        return torch.full_like(actions, False, dtype=torch.bool)

    def trans_fn(self, states, actions):
        """Batched model of :meth:`step`'s dynamics; ``actions`` are ignored."""
        return self._next_state(states)

    def render(self, mode='human'):
        """No-op: this environment has no visualization."""
        pass


# Expose the environment via gym.make('SafeReal-v1'); max_episode_steps caps
# each episode created through gym.make at 1000 steps.
register(
    id='SafeReal-v1',
    entry_point=RealEnv,
    max_episode_steps=1000,
)
