import numpy as np
import gymnasium as gym
from gymnasium.envs.registration import register
from safety_gym.envs.mujoco import Engine
import gymnasium.spaces
from ldba import make_automaton
from custom_fetch import CustomMujocoFetchPushEnv, MujocoFetchReachEnv

import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Optional
import gymnasium.spaces as spaces

import functools
def plotlive(func):
    """Decorate a plotting function so every call refreshes a live figure.

    Matplotlib interactive mode is switched on once, when the decorator is
    applied. On each call the target axes are cleared, ``func`` runs, and the
    figure is redrawn followed by a 1 ms GUI pause so the window updates.

    The axes to clear come from the first positional argument. If it is a
    ``np.ndarray`` it is used directly as the collection of axes. Otherwise
    its ``_mdp.ax`` attribute is read and wrapped with ``np.atleast_1d``.
    """
    plt.ion()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        target = args[0]
        axes_to_clear = target if isinstance(target, np.ndarray) else np.atleast_1d(target._mdp.ax)
        for ax in axes_to_clear:
            ax.cla()

        output = func(*args, **kwargs)

        # Push new artists to the screen and let the GUI event loop run briefly.
        plt.draw()
        plt.pause(0.001)
        return output

    return wrapper

def make_env(env_id, gamma, kwargs):
    """Return a zero-argument factory that builds one wrapped environment.

    The factory calls ``gym.make(env_id, **kwargs)``, then wraps the result in
    ``FlattenObservation`` and then ``RecordEpisodeStatistics``.
    ``gamma`` is not used by this function.
    """
    def _build():
        wrapped = gym.make(env_id, **kwargs)
        for wrapper_cls in (gym.wrappers.FlattenObservation, gym.wrappers.RecordEpisodeStatistics):
            wrapped = wrapper_cls(wrapped)
        return wrapped

    return _build


class CustomVectorEnv(gym.vector.SyncVectorEnv):
    """``SyncVectorEnv`` whose info dicts are stacked and carry a truncation flag.

    Info entries that the base class returns as object-dtype NumPy arrays are
    stacked with ``np.stack(..., 0)``. All other entries pass through unchanged.
    A ``'truncated'`` entry is then added: all ``False`` after a reset, and the
    step's ``truncated`` array after a step.
    """

    @staticmethod
    def _stack_object_infos(infos):
        # Object arrays (one element per sub-env) become one stacked array.
        stacked = {}
        for key, value in infos.items():
            needs_stack = isinstance(value, np.ndarray) and value.dtype == object
            stacked[key] = np.stack(value, 0) if needs_stack else value
        return stacked

    def reset_wait(self, seed=None, options=None):
        observations, infos = super().reset_wait(seed, options)
        infos = self._stack_object_infos(infos)
        infos['truncated'] = np.array([False] * len(self.envs))
        return observations, infos

    def step_wait(self):
        observations, rewards, terminated, truncated, infos = super().step_wait()
        infos = self._stack_object_infos(infos)
        infos['truncated'] = truncated
        return observations, rewards, terminated, truncated, infos


class GridBase(gym.Env):
    """Minimal 2-D grid MDP used as the inner MDP of the grid LTL tasks.

    The position is a float32 2-vector clipped to ``[-clip_range, clip_range]``
    on each axis. Discrete actions move one cell in +x, -x, +y or -y. When
    ``allow_rest`` is set, a fifth action stays in place. The base task always
    returns reward 0, never terminates, and truncates once
    ``t >= zone_size * 20``.

    Args:
        zone_size: Zone side length; also determines ``clip_range`` and the horizon.
        allow_rest: Whether the 5th "stay" action is exposed.
        **kwargs: Accepted and ignored.
    """

    # Displacement per discrete action index (row 4 is "stay").
    _MOVES = np.array([[1, 0], [-1, 0], [0, 1], [0, -1], [0, 0]], dtype=np.float32)

    def __init__(self, zone_size=7, allow_rest=True, **kwargs):
        self.zone_size = zone_size
        self.clip_range = (self.zone_size * 3 + 2) // 2
        self.pos = np.array([0, 0], dtype=np.float32)
        self.action_space = gym.spaces.Discrete(5 if allow_rest else 4)
        self.observation_space = gym.spaces.Box(low=np.array([-1e2, -1e2]), high=np.array([1e2, 1e2]))
        self.initial_pos = (0, 0)
        # Initialise the step counter so step() before reset() does not raise
        # AttributeError.
        self.t = 0

    def reset(self, seed=None, options=None):
        """Place the agent at ``initial_pos`` and zero the step counter."""
        self.pos = np.array(self.initial_pos, dtype=np.float32)
        self.t = 0
        return self.pos.copy(), {}

    def step(self, action):
        """Move one cell; return ``(obs, reward, terminated, truncated, info)``."""
        move = self._MOVES[int(action)]
        self.pos = np.clip(self.pos + move, -self.clip_range, self.clip_range)
        self.t += 1
        return self.pos.copy(), 0., False, self.t >= self.zone_size * 20, {}


class GridCircular(GridBase):
    """Grid with four zones around a central zone, optionally walled off.

    Zones ``a``-``d`` sit ``zone_size`` cells below, right of, above and left of
    the origin. Zone ``z`` is the origin. The agent starts at
    ``(0, -clip_range)``. When ``wall`` is True, a move whose target lies
    strictly inside the central square (max-abs coordinate < zone_size / 2) is
    cancelled. Episodes truncate once
    ``t >= (zone_size + 1) * 9 * time_coeff * 2``.
    """

    def __init__(self, wall=True, time_coeff=1, **kwargs):
        super().__init__(**kwargs)
        self.wall = wall
        self.time_coeff = time_coeff
        self.initial_pos = (0, -self.clip_range)
        zs = self.zone_size
        self.center_ap = dict(
            a=np.array([0., -zs]),
            b=np.array([zs, 0.]),
            c=np.array([0., zs]),
            d=np.array([-zs, 0.]),
            z=np.array([0., 0.]),
        )

    def step(self, action):
        dx, dy = ((1, 0), (-1, 0), (0, 1), (0, -1), (0, 0))[action]
        candidate = self.pos + np.array([dx, dy], dtype=np.float32)
        blocked = self.wall and np.abs(candidate).max() < self.zone_size / 2
        if blocked:
            candidate = self.pos
        self.pos = np.clip(candidate, -self.clip_range, self.clip_range)
        self.t += 1
        horizon = (self.zone_size + 1) * 9 * self.time_coeff * 2
        return self.pos.copy(), 0., False, self.t >= horizon, {}


class GridSequential(GridBase):
    """Grid with eight zones ``a``-``h``.

    With ``make_linear`` (the default) the zones lie on the x-axis at
    ``-3*zone_size`` ... ``4*zone_size``. The agent starts at
    ``(-4*zone_size, 0)`` and the clip range is widened to
    ``(zone_size*9 + 2) // 2``. Otherwise the zones ring the origin, starting
    to the right and going clockwise, and the agent starts at the origin.
    Episodes truncate once ``t >= (zone_size + 1) * 9 * time_coeff``.
    """

    def __init__(self, make_linear=True, time_coeff=1, **kwargs):
        super().__init__(**kwargs)
        self.time_coeff = time_coeff
        zs = self.zone_size
        if not make_linear:
            self.initial_pos = (0, 0)
            self.center_ap = {
                'a': np.array([zs, 0.]),
                'b': np.array([zs, -zs]),
                'c': np.array([0., -zs]),
                'd': np.array([-zs, -zs]),
                'e': np.array([-zs, 0.]),
                'f': np.array([-zs, zs]),
                'g': np.array([0, zs]),
                'h': np.array([zs, zs]),
            }
            return
        self.clip_range = (zs * 9 + 2) // 2
        self.initial_pos = (-zs * 4, 0)
        # Zone k (0-based) is centred at x = (k - 3) * zone_size.
        self.center_ap = {ap: np.array([(k - 3) * zs, 0.]) for k, ap in enumerate('abcdefgh')}

    def step(self, action):
        obs, reward, terminated, _base_truncated, info = super().step(action)
        truncated = self.t >= (self.zone_size + 1) * 9 * self.time_coeff
        return obs, reward, terminated, truncated, info


class GridAvoid(GridBase):
    """Corridor task: reach ``x > zone_size`` without leaving the corridor.

    The agent starts at the origin on a grid clipped to ``±3*zone_size``.
    ``is_success(pos)`` holds once ``pos[0] > zone_size``. ``is_failure(pos)``
    holds when ``|pos[1]| > corridor_width - 1`` or
    ``pos[0] < -(corridor_width - 1)``. Episodes truncate once
    ``t >= zone_size + extra_time``.
    """

    def __init__(self, corridor_width=1, extra_time=10, **kwargs):
        super().__init__(**kwargs)
        self.clip_range = self.zone_size * 3
        self.initial_pos = (0, 0)
        self.corridor_width = corridor_width
        self.extra_time = extra_time

    def is_success(self, x):
        return x[0] > self.zone_size

    def is_failure(self, x):
        slack = self.corridor_width - 1
        return (np.abs(x[1]) > slack) or (x[0] < -slack)

    def step(self, action):
        obs, reward, terminated, _base_truncated, info = super().step(action)
        truncated = self.t >= self.zone_size + self.extra_time
        return obs, reward, terminated, truncated, info


class LTLEnv(gym.Env):
    """Product of an inner MDP with an LDBA built from an LTL formula.

    Subclasses provide the inner MDP through ``_make_mdp`` and the truth of the
    atomic propositions through ``evaluate_ap``. ``evaluate_ap`` returns a
    boolean array aligned with ``self.aps``.

    The action space is the MDP action space extended with LDBA jump choices:

    * Discrete MDP with ``n`` actions: ``n + (0 if n_jumps == 1 else n_jumps)``
      actions. The MDP action is ``min(i, n - 1)`` and the jump id is
      ``max(0, i - n)``.
    * Continuous MDP: ``mdp_dims + n_jumps`` dims in ``[-1, 1]``. The jump id
      is the argmax over the trailing ``n_jumps`` dims.

    Jump id 0 means "no jump". A non-zero jump taken while the LDBA's current
    state has epsilon transitions only advances the automaton. In that case the
    MDP step counter ``_mdp.t`` is incremented and the last observation is
    repeated. Any other action steps the MDP. The reward is 1.0 on steps where
    the LDBA reports acceptance and 0.0 otherwise.
    ``info['n_accept']`` is the running count of accepting steps divided by 10.
    """

    def __init__(self, formula, **kwargs):
        self._mdp = self._make_mdp(**kwargs)
        self._ldba = make_automaton(formula)
        self.aps = self._ldba.aps
        self.observation_space = self._mdp.observation_space
        if isinstance(self._mdp.action_space, gym.spaces.Discrete):
            self.is_discrete = True
            self.mdp_action_size = self._mdp.action_space.n
            # No extra discrete actions when get_n_jumps() == 1
            # (presumably that single option is "no jump").
            act_size = self.mdp_action_size + (0 if self.get_n_jumps() == 1 else self.get_n_jumps())
            self.action_space = gymnasium.spaces.Discrete(act_size)
        else:
            self.is_discrete = False
            self.mdp_action_size = np.prod(self._mdp.action_space.shape)
            act_size = self.mdp_action_size + self.get_n_jumps()
            self.action_space = gymnasium.spaces.Box(low=-np.ones(act_size), high=np.ones(act_size), dtype=np.float64)
        # (last MDP observation, last info dict); replayed on epsilon transitions.
        self.buffer = None, None
        self.n_accept = 0.

    def _make_mdp(self, **kwargs):
        raise NotImplementedError('Need to construct inner MDP.')

    def reset(self, seed=None, options=None):
        """Reset the MDP (forwarding ``seed``/``options``) and the LDBA."""
        # gymnasium's Env.reset takes seed/options as keyword-only arguments.
        mdp_obs, info = self._mdp.reset(seed=seed, options=options)
        info['ap'] = self.evaluate_ap()
        self.buffer = mdp_obs, info
        self.n_accept = 0.
        info['ldba_obs'], _ = self._ldba.reset(info['ap'])
        info['n_accept'] = self.n_accept / 10.
        return mdp_obs, info

    def step(self, action):
        """Apply an MDP action or an LDBA epsilon jump; see the class docstring."""
        if self.is_discrete:
            # Accept plain Python ints as well as array/tensor-like scalars.
            flat_action = action.item() if hasattr(action, 'item') else int(action)
            mdp_action = min(flat_action, self.mdp_action_size - 1)
            jump_id = max(0, flat_action - self.mdp_action_size)
        else:
            mdp_action = action[:self.mdp_action_size]
            jump_id = np.argmax(action[self.mdp_action_size:])
        if jump_id != 0 and self._ldba.curr_state in self._ldba.eps:  # jump!
            # Epsilon transition: only the automaton moves; the MDP clock still advances.
            self._mdp.t = self._mdp.t + 1
            mdp_obs, prev_info = self.buffer
            # Copy so the info dict handed out on the previous step is not mutated.
            info = dict(prev_info)
            terminated, truncated = False, False

            info['ldba_obs'] = self._ldba.epsilon_step(jump_id - 1)
            accepting = False
        else:
            mdp_obs, _, terminated, truncated, info = self._mdp.step(mdp_action)
            info['ap'] = self.evaluate_ap()
            info['ldba_obs'], accepting = self._ldba.step(info['ap'], jump_id=0)

        self.n_accept += float(accepting)
        info['n_accept'] = self.n_accept / 10.
        self.buffer = mdp_obs, info
        return mdp_obs, float(accepting), terminated, truncated, info

    def evaluate_ap(self):
        """Return a boolean array giving the truth of each AP in ``self.aps``."""
        raise NotImplementedError

    def render(self):
        return self._mdp.render()

    def get_jump_mask(self):
        return self._ldba.get_jump_mask()

    def get_n_jumps(self):
        return self._ldba.get_n_jumps()

    def get_graph(self):
        return self._ldba.get_graph()

    def get_ldba(self):
        return self._ldba


class LTLWorldEnv(LTLEnv):
    """LTL product over a safety-gym ``Engine`` world with five circular zones.

    Zones are implemented as hazards of radius ``zone_size / 2``. Zone ``z`` is
    at the origin; zones ``a``, ``b``, ``c``, ``d`` are at (0, -zs), (zs, 0),
    (0, zs) and (-zs, 0). An AP holds when the robot's planar position is
    within ``zone_size / 2`` of its zone centre.
    """

    def _get_hazard_locations(self):
        return [(0, 0), (0, -self.zone_size), (self.zone_size, 0), (0, self.zone_size), (-self.zone_size, 0.)]

    def _get_config(self, robot_name):
        """Build the ``Engine`` config dict from the options stored by ``_make_mdp``."""
        return {
            'num_steps': self.num_steps,
            'robot_base': f'xmls/{robot_name.lower()}.xml',
            'sensors_obs': ['accelerometer', 'velocimeter', 'gyro', 'magnetometer'] if self.include_sensors else [],
            'observe_goal_lidar': False,
            'observe_box_lidar': False,
            # Robot starts 0.5 units below the centre of zone 'a'.
            'robot_locations': [(0,-self.zone_size-0.5)],
            'task': 'none',
            'hazards_num': 5,
            'hazards_size': [self.zone_size/2]*5,
            # Central zone red, the four outer zones increasingly bright blue.
            'hazards_color': [np.array([1., 0., 0., 1.])] + [np.array([0., 0., (i+1)/4, 1.]) for i in range(4)],
            'hazards_keepout': 0.01,
            'constrain_hazards': False,
            'observe_hazards': self.include_lidar,
            'hazards_locations': [np.array(h) for h in self._get_hazard_locations()],
            'robot_rot': np.pi if robot_name=='Car' else -np.pi/2,
            'observe_qpos': self.include_clean_obs,  # Observe the qpos of the world
            'observe_qvel': self.include_clean_obs,  # Observe the qvel of the robot
            'observe_ctrl': False,  # Observe the previous action
            'observe_freejoint': False,  # Observe base robot free joint
            'observe_com': self.include_clean_obs,  # Observe the center of mass of the robot
        }

    def _make_mdp(self, robot_name, zone_size=7, include_clean_obs=False, include_sensors=True,
                  include_lidar=True, num_steps=500, corridor_width=2.0, make_linear=True, config=None, **kwargs):
        """Store task options on ``self`` and construct the ``Engine`` world.

        Args:
            robot_name: One of ``'Point'``, ``'Car'``, ``'Doggo'``.
            zone_size: Zone spacing; divided by 5 before being stored.
            include_clean_obs: Enables the qpos/qvel/com observations.
            include_sensors: Enables accelerometer/velocimeter/gyro/magnetometer.
            include_lidar: Enables hazard lidar observations.
            num_steps: Engine episode length.
            corridor_width: Stored for subclasses.
            make_linear: Stored for subclasses.
            config: NOTE(review): currently ignored -- the Engine config is
                always produced by ``_get_config``. Kept for call compatibility.
            **kwargs: Ignored.
        """
        assert robot_name in ('Point', 'Car', 'Doggo')
        self.zone_size = zone_size / 5
        self.robot_name = robot_name
        self.include_clean_obs = include_clean_obs
        self.include_sensors = include_sensors
        self.include_lidar = include_lidar
        self.num_steps = num_steps
        self.corridor_width = corridor_width
        self.make_linear = make_linear
        engine_config = self._get_config(robot_name)
        if robot_name == 'Doggo':
            engine_config['sensors_obs'] += ['touch_ankle_1a', 'touch_ankle_2a', 'touch_ankle_3a',
                'touch_ankle_4a', 'touch_ankle_1b', 'touch_ankle_2b', 'touch_ankle_3b', 'touch_ankle_4b']
        if robot_name == 'Car':
            engine_config.update({
                'box_size': 0.125,  # Box half-radius size
                'box_keepout': 0.125,  # Box keepout radius for placement
                'box_density': 0.0005,
            })
        return Engine(engine_config)

    def step(self, action):
        """Rescale the MDP part of ``action`` per robot, then step the product env.

        NOTE(review): for ``alpha > 0`` (Car uses 0.2), inputs below ``alpha``
        map to ``(a - alpha) / (1 + alpha)``, which lies in ``[-1, 0)``. Inputs
        at or above ``alpha`` map to ``(a + alpha) / (1 - alpha)``, which lies
        in ``[2*alpha/(1-alpha), (1+alpha)/(1-alpha)]`` (``[0.5, 1.5]`` for Car).
        The mapping is therefore discontinuous and can exceed 1; confirm this
        is intended.
        """
        scaled_action = action.copy()
        mdp_a = scaled_action[:self.mdp_action_size]
        alpha = {'Point': 0.0, 'Car': 0.2, 'Doggo': 0.0}[self.robot_name]
        scaled_action[:self.mdp_action_size] = np.where(mdp_a < alpha, (mdp_a-alpha)/(1+alpha), (mdp_a+alpha)/(1-alpha))
        return super().step(scaled_action)

    def evaluate_ap(self):
        """Return which zone APs (``z, a, b, c, d``) contain the robot."""
        robot_xy = self._mdp.world.robot_pos()[:2]
        radius = self.zone_size / 2
        active = {
            zone_ap for hazard, zone_ap in zip(self._get_hazard_locations(), 'zabcd')
            if np.linalg.norm(robot_xy - hazard, ord=2) <= radius
        }
        return np.array([ap in active for ap in self.aps])


class LTLAvoidEnv(LTLWorldEnv):
    """Safety-gym corridor task with no hazards.

    AP ``a`` holds when the robot's x exceeds ``zone_size``. AP ``z`` holds when
    ``|y| > corridor_width`` or ``x < -0.2``. The robot starts at the origin.
    """

    def _get_hazard_locations(self):
        return []

    def _get_config(self, robot_name):
        config = super()._get_config(robot_name)
        config.update({
            'robot_locations': [(0,0)],
            'hazards_num': 0,
            'hazards_size': [],
            'hazards_color': [],
            'robot_rot': np.pi/2 if robot_name=='Car' else 0.,
        })
        return config

    def render(self):
        """Draw the task regions as viewer markers (if a viewer exists), then render.

        NOTE(review): ``type=6`` is presumably MuJoCo's box geom, whose ``size``
        is a half-extent. Under that reading, the blue bar spans x in
        [-0.2, zone_size] and the green bar spans [zone_size, 3.5].
        """
        viewer = self._mdp.viewer
        if viewer:
            # Semi-transparent dark-red floor marker centred at the origin.
            viewer.add_marker(
                pos=np.array([0.0, 0.0, 0.0]),
                size=np.array([3.5, 3.5, 0.005]),
                rgba=np.array([0.5, 0.0, 0.0, 0.4], dtype=np.float32),
                type=6, label="",
            )
            # Blue corridor bar.
            viewer.add_marker(
                pos=np.array([0.5*self.zone_size-0.1, 0.0, 0.0]),
                size=np.array([0.5*self.zone_size+0.1, 0.2, 0.01]),
                rgba=np.array([0.0, 0.0, 1.0, 1.0], dtype=np.float32),
                type=6, label="",
            )
            # Green goal bar beyond x = zone_size.
            viewer.add_marker(
                pos=np.array([1.75+0.5*self.zone_size, 0.0, 0.0]),
                size=np.array([1.75-0.5*self.zone_size, 0.2, 0.01]),
                rgba=np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32),
                type=6, label="",
            )
        return self._mdp.render()

    def evaluate_ap(self):
        """Return the truth of APs ``a`` (goal reached) and ``z`` (left corridor)."""
        x, y = self._mdp.world.robot_pos()[:2]
        aps = []
        if x > self.zone_size:
            aps.append('a')
        if (np.abs(y) > self.corridor_width) or (x < -0.2):
            aps.append('z')
        return np.array([ap in aps for ap in self.aps])


class LTLNavigateEnv(LTLWorldEnv):
    """Safety-gym task with eight circular zones ``a``-``h``.

    With ``make_linear`` the zones lie on the positive x-axis at
    ``zone_size, 2*zone_size, ..., 8*zone_size``. Otherwise they ring the
    origin, starting to the right. An AP holds when the robot's planar
    position is within ``zone_size / 2`` of its zone centre.
    """

    def _get_hazard_locations(self):
        zs = self.zone_size
        if not self.make_linear:
            return [(zs, 0.), (zs, -zs), (0., -zs),
                    (-zs, -zs), (-zs, 0.), (-zs, zs),
                    (0, zs), (zs, zs)]
        return [((k + 1) * zs, 0.) for k in range(8)]

    def _get_config(self, robot_name):
        overrides = {
            'robot_locations': [(0,0)],
            'hazards_num': 8,
            'hazards_size': [self.zone_size/2]*8,
            'hazards_color': [np.array([0., 0., (i+1)/8, 1.]) for i in range(8)],
            'robot_rot': np.pi/2 if robot_name=='Car' else 0.,
        }
        return {**super()._get_config(robot_name), **overrides}

    def evaluate_ap(self):
        radius = self.zone_size / 2
        inside = set()
        for center, label in zip(self._get_hazard_locations(), 'abcdefgh'):
            offset = self._mdp.world.robot_pos()[:2] - center
            if np.linalg.norm(offset, ord=2) <= radius:
                inside.add(label)
        return np.array([ap in inside for ap in self.aps])


class LTLGridworldEnv(LTLEnv):
    """LTL product over ``GridCircular``.

    An AP among ``a, b, c, d, z`` holds when the agent lies in that zone's
    square: max-abs offset from the centre <= ``zone_size / 2``.
    """

    def _make_mdp(self, **kwargs):
        return GridCircular(**kwargs)

    def evaluate_ap(self):
        mdp = self._mdp
        half_width = mdp.zone_size / 2
        inside = {
            label for label in ('a', 'b', 'c', 'd', 'z')
            if np.abs(mdp.pos - mdp.center_ap[label]).max() <= half_width
        }
        return np.array([ap in inside for ap in self.aps])


class LTLGridspiralEnv(LTLEnv):
    """LTL product over ``GridSequential``.

    An AP among ``a``-``h`` holds when the agent lies in that zone's square:
    max-abs offset from the centre <= ``zone_size / 2``.
    """

    def _make_mdp(self, **kwargs):
        return GridSequential(**kwargs)

    def evaluate_ap(self):
        mdp = self._mdp
        half_width = mdp.zone_size / 2
        inside = {
            label for label in 'abcdefgh'
            if np.abs(mdp.pos - mdp.center_ap[label]).max() <= half_width
        }
        return np.array([ap in inside for ap in self.aps])


class LTLGridavoidEnv(LTLEnv):
    """LTL product over ``GridAvoid``.

    AP ``a`` mirrors ``GridAvoid.is_success`` and AP ``z`` mirrors
    ``GridAvoid.is_failure``, both evaluated at the current position.
    """

    def _make_mdp(self, **kwargs):
        return GridAvoid(**kwargs)

    def evaluate_ap(self):
        pos = self._mdp.pos
        truth = {'a': self._mdp.is_success(pos), 'z': self._mdp.is_failure(pos)}
        return np.array([bool(truth.get(ap, False)) for ap in self.aps])


class LTLFetchEnv(LTLEnv):
    """LTL product over a Fetch MDP whose observations are dicts.

    Only the ``'observation'`` entry of the MDP observation is exposed, and
    ``observation_space`` is narrowed to match. Episodes are truncated after 50
    calls to ``step``; epsilon jumps count toward that limit.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.observation_space = self._mdp.observation_space['observation']

    def reset(self, seed=None, options=None):
        """Reset the MDP (forwarding ``seed``/``options``), the LDBA and the step count."""
        # Forward seed/options so episodes are reproducible (keyword-only in gymnasium).
        mdp_obs, info = self._mdp.reset(seed=seed, options=options)
        self.n_step = 0
        info['ap'] = self.evaluate_ap()
        # Buffer the full dict observation; LTLEnv.step replays it on epsilon jumps.
        self.buffer = mdp_obs, info
        self.n_accept = 0.
        info['ldba_obs'], _ = self._ldba.reset(info['ap'])
        info['n_accept'] = self.n_accept / 10.
        return mdp_obs['observation'], info

    def step(self, action):
        obs, accepting, terminated, truncated, info = super().step(action)
        self.n_step += 1
        return obs['observation'], accepting, terminated, (truncated or (self.n_step >= 50)), info

    def render(self, **kwargs):
        """Switch the MDP to ``'human'`` render mode (persistently) and render.

        ``kwargs`` are ignored.
        """
        self._mdp.render_mode='human'
        return self._mdp.render()


class LTLFetchAvoidEnv(LTLFetchEnv):
    """Fetch-reach corridor task on the gripper position.

    AP ``a`` holds when the gripper x exceeds 1.54. AP ``z`` holds when the
    gripper drifts more than 0.06 in y from ``initial_xpos``, or falls more
    than 0.06 behind it in x.
    """

    def _make_mdp(self, **kwargs):
        # Reference position for the corridor checks (presumably the gripper's
        # start position -- confirm).
        self.initial_xpos = np.array([1.34185368, 0.74910045, 0.5547072])
        return MujocoFetchReachEnv()

    def evaluate_ap(self):
        gripper = self._mdp._get_gripper_xpos()
        reached = gripper[0] > 1.54
        left_corridor = (
            (np.abs(gripper[1] - self.initial_xpos[1]) > 0.06)
            or (gripper[0] < self.initial_xpos[0] - 0.06)
        )
        truth = {'a': reached, 'z': left_corridor}
        return np.array([bool(truth.get(ap, False)) for ap in self.aps])


class LTLFetchAlignEnv(LTLFetchEnv):
    """Fetch-push task with three cubes, one AP per cube.

    AP ``a``/``b``/``c`` holds when cube ``i`` is within 0.04 (max-abs, in xy)
    of its initial position shifted by +0.1 in x.
    """

    def _make_mdp(self, **kwargs):
        self.n_cubes = 3
        return CustomMujocoFetchPushEnv(n_cubes=self.n_cubes)

    def evaluate_ap(self):
        shift = np.array([0.1, 0.])
        active = []
        for idx, label in enumerate('abcd'[:self.n_cubes]):
            goal_xy = self._mdp.initial_cube_position[idx] + shift
            cube_xyz = self._mdp._utils.get_site_xpos(self._mdp.model, self._mdp.data, f"object{idx}")
            if np.abs(cube_xyz[:2] - goal_xy).max() <= 0.04:
                active.append(label)
        return np.array([ap in active for ap in self.aps])

class LTLHalfCheetahRoundEnv(LTLEnv):
    """HalfCheetah-v4 with APs on the root angle ``data.qpos[2]``.

    The angle (presumably the torso pitch -- confirm) is wrapped into
    ``[0, 2*pi)``. AP ``b`` holds for angles in ``(1.7*pi/4, 3.3*pi/4)`` and
    AP ``d`` for angles in ``(6*pi/4, 6.3*pi/4)``.
    """

    def reset(self, seed=None, options=None):
        """Reset the MDP (forwarding ``seed``/``options``) and the LDBA."""
        mdp_obs, info = self._mdp.reset(seed=seed, options=options)
        info['ap'] = self.evaluate_ap()
        self.buffer = mdp_obs, info
        self.n_accept = 0.
        info['ldba_obs'], _ = self._ldba.reset(info['ap'])
        self.ldba_obs = info['ldba_obs']
        info['n_accept'] = self.n_accept / 10.
        return mdp_obs, info

    def _make_mdp(self, render_mode=None, **kwargs):
        return gym.make('HalfCheetah-v4', render_mode=render_mode)

    def evaluate_ap(self):
        angle = np.mod(float(self._mdp.unwrapped.data.qpos[2]), 2 * np.pi)  # in [0, 2*pi)
        aps = []
        if 1.7*np.pi/4 < angle < 3.3*np.pi/4:
            aps.append('b')
        # Window labelled '2_1218'. A variant labelled '2_1219' used
        # 6.2*pi/4 < angle < 6.3*pi/4 instead.
        if 6*np.pi/4 < angle < 6.3*np.pi/4:
            aps.append('d')
        return np.array([ap in aps for ap in self.aps])


from CARLO.world import World
from CARLO.agents import Car, RingBuilding, CircleBuilding, Painting, Pedestrian
from CARLO.geometry import Point, Line

import gymnasium.spaces as spaces

class CarloEnv:
    """Single car driving on a circular road in a CARLO ``World``.

    The road is the ring between a central ``CircleBuilding`` and an outer
    ``RingBuilding``. Non-collidable blue ``CircleBuilding`` discs of radius 5,
    placed on the lane-marker circle, serve as waypoints (``self.waypoints``).

    Observations are ``[x, y, xp, yp, heading]`` divided element-wise by
    ``[world_width, world_height, max_speed, max_speed, 2*pi]``. ``step``
    passes ``action[0]`` and ``action[1]`` to ``Car.set_control``, returns the
    action norm as the reward, never terminates, and truncates after 500 steps.

    Args:
        continuous_actions: Selects ``Box(-1, 1, (2,))`` (True) or
            ``Discrete(30)`` (False) as ``action_space``. NOTE(review): ``step``
            always indexes the action as a 2-vector, so the discrete space does
            not appear usable as-is.
    """

    def __init__(self, continuous_actions=True):
        dt = 0.1  # seconds per world tick (1/dt frames per second)
        world_width = 120/2  # metres
        world_height = 120/2
        self.inner_building_radius = 30/2
        num_lanes = 2
        self.lane_marker_width = 0.5
        num_of_lane_markers = 52/2
        self.lane_width = 6  # was 3.5
        self.continuous_actions = continuous_actions
        self.ppm = 20
        # Generator used for start positions once reset(seed=...) is called.
        self._rng = None

        # A world_width x world_height metre world drawn at ppm pixels per metre.
        w = World(dt, width=world_width, height=world_height, ppm=self.ppm)

        # Circular road: a solid building in the middle and a ring building
        # around it. Both are static and, unlike Paintings, collidable.
        cb = CircleBuilding(Point(world_width/2, world_height/2), self.inner_building_radius, 'gray80')
        w.add(cb)
        rb = RingBuilding(
            Point(world_width/2, world_height/2),
            self.inner_building_radius + num_lanes * self.lane_width + (num_lanes - 1) * self.lane_marker_width,
            1+np.sqrt((world_width/2)**2 + (world_height/2)**2),
            'gray80',
        )
        w.add(rb)

        # Waypoints: every 13th vertex of a num_of_lane_markers-gon on each
        # lane-marker circle, drawn as non-collidable blue discs.
        self.waypoints = []
        for lane_no in range(num_lanes - 1):
            lane_markers_radius = self.inner_building_radius + (lane_no + 1) * self.lane_width + (lane_no + 0.5) * self.lane_marker_width
            for i, theta in enumerate(np.arange(0, 2*np.pi, 2*np.pi / num_of_lane_markers)):
                if i % 13 != 0:
                    continue
                dx = lane_markers_radius * np.cos(theta)
                dy = lane_markers_radius * np.sin(theta)
                wp = CircleBuilding(Point(world_width/2 + dx, world_height/2 + dy), 5, 'blue')
                wp.collidable = False
                self.waypoints.append(wp)
                w.add(wp)

        self.world_width = world_width
        self.world_height = world_height
        self.dt = dt
        self.world = w
        self.reset()
        self.world.render()  # visualise the freshly built world

        if continuous_actions:
            self.action_space = spaces.Box(-1., 1., shape=(2,), dtype=np.float32)
        else:
            self.action_space = spaces.Discrete(30)

        self.observation_space = spaces.Box(-np.inf, np.inf, shape=self.get_state().shape, dtype='float32')

    def reset(self, seed=None, options=None):
        """Close any old viewer, reset the world and spawn a fresh car.

        The car is placed on the first lane-marker circle at an angle drawn
        from 10 evenly spaced values, heading tangentially (``theta + pi/2``).
        Passing ``seed`` makes this and later resets reproducible. Without a
        seed, NumPy's global RNG is used. ``options`` is unused.
        """
        # Best-effort: there may be no visualiser yet, or it may already be closed.
        try:
            self.world.visualizer.close()
        except Exception:
            pass

        self.world.reset()

        if seed is not None:
            self._rng = np.random.default_rng(seed)
        rng = np.random if self._rng is None else self._rng

        self.current_wp = 0
        self.t = 0
        # NOTE(review): linspace includes both endpoints, which coincide modulo
        # 2*pi, so the start angle pi/2 is twice as likely as the others.
        start_angles = np.linspace(np.pi/2, 2*np.pi+np.pi/2, 10)
        theta = rng.choice(start_angles) % (np.pi*2)
        lane_markers_radius = self.inner_building_radius + (0 + 1) * self.lane_width + (0 + 0.5) * self.lane_marker_width
        x = self.world_width/2 + lane_markers_radius * np.cos(theta)
        y = self.world_height/2 + lane_markers_radius * np.sin(theta)

        c1 = Car(Point(x, y), theta + np.pi/2)
        c1.max_speed = 30.0  # 30 m/s (108 km/h)
        c1.min_speed = 3.0  # 3 m/s
        c1.velocity = Point(0.0, 0.0)
        self.agent = c1
        self.world.add(c1)

        self.state = self.get_state()
        return self.state, {}

    def _state_scale(self):
        # Per-component normaliser for [x, y, xp, yp, heading].
        return np.array([self.world_width, self.world_height, self.agent.max_speed, self.agent.max_speed, 2*np.pi])

    def unnormalize(self, states):
        """Map normalised states back to world units."""
        return states * self._state_scale()

    def get_state(self):
        """Return the normalised ``[x, y, xp, yp, heading]`` of the car."""
        return np.array([self.agent.x, self.agent.y, self.agent.xp, self.agent.yp, self.agent.heading]) / self._state_scale()

    def step(self, action):
        """Apply ``(action[0], action[1])`` as car controls and tick the world once.

        NOTE(review): collisions do not terminate the episode here
        (``terminated`` is always False).
        """
        u = action
        self.t += 1
        self.agent.set_control(u[0], u[1])
        self.world.tick()  # advance the simulation by dt seconds

        cost = np.linalg.norm(u)
        self.state = self.get_state()
        return self.state, cost, False, self.t >= 500, {}
        
class LTLCarloEnv(LTLEnv):
    """LTL product over ``CarloEnv``.

    AP ``wp_<i>`` holds when the car is within 5 m of waypoint ``i``. AP
    ``crash`` holds when the world reports a collision.
    """

    def _make_mdp(self, **kwargs):
        return CarloEnv()

    def distance_to_waypoints(self, state):
        """Return the planar distance from a normalised ``state`` to each waypoint."""
        scale = np.array([self._mdp.world_width, self._mdp.world_height, self._mdp.agent.max_speed, self._mdp.agent.max_speed, 2*np.pi])
        x, y = (state * scale)[:2]
        return np.array([np.linalg.norm([x - wp.x, y - wp.y]) for wp in self._mdp.waypoints])

    def evaluate_ap(self):
        aps = []
        state = self._mdp.get_state()
        for idx, distance in enumerate(self.distance_to_waypoints(state)):
            if distance <= 5:
                aps.append('wp_%s' % idx)
        if self._mdp.world.collision_exists():
            aps.append('crash')
        return np.array([ap in aps for ap in self.aps])

    def myrender(self, states=None, save_dir=None):
        """Render the world, overlay a trajectory, and optionally save an image.

        Does nothing when the CARLO world is headless.

        Args:
            states: Sequence of states whose first two columns are x, y in world
                coordinates (presumably already unnormalised -- confirm with
                callers). Consecutive points are joined by red segments.
                ``None`` or empty draws no trajectory.
            save_dir: If given, agents are removed from the world and the
                visualiser image is saved to this path.
        """
        world = self._mdp.world
        if world.headless:
            return
        world.render()

        if states is not None and len(states):
            ppm = world.visualizer.ppm
            dh = world.visualizer.display_height
            xy = np.array(states)[:, 0:2]
            for (x1, y1), (x2, y2) in zip(xy[:-1], xy[1:]):
                # Screen y grows downward, hence dh - ppm*y.
                world.visualizer.win.plot_line_(ppm*x1, dh - ppm*y1, ppm*x2, dh - ppm*y2, fill='red', width='2')

        world.render()

        if save_dir is not None:
            world.remove_agents()
            world.visualizer.save_fig(save_dir)

    def unnormalize(self, states):
        """Return ``states`` in world units, or unchanged if the MDP cannot unnormalise."""
        try:
            return self._mdp.unnormalize(states)
        except Exception:
            # Best-effort fallback; no longer swallows KeyboardInterrupt/SystemExit.
            return states

class FlatWorld(gym.Env):
    """2-D single-integrator point mass: ``s' = s + 0.4 * a``.

    The agent starts at ``(-1, -1)``. ``step`` returns the action norm as the
    reward, never terminates, and truncates after 50 steps. Actions are not
    clipped. Circular regions ``(array([cx, cy, radius]), colour)`` are exposed
    as ``circles``, ``circles_2`` and ``paths``. A matplotlib figure and axes
    are created eagerly as ``fig``/``ax``.
    """

    def __init__(
        self,
        render_mode: Optional[str] = "human",
    ):
        # Nominal position bounds (step() does not clip the state).
        low = np.array([-2, -2]).astype(np.float32)
        high = np.array([2, 2]).astype(np.float32)
        self.observation_space = spaces.Box(low, high)

        self.action_space = spaces.Box(-1, +1, (2,), dtype=np.float32)

        # Each region is (center_x, center_y, radius).
        self.obs_1 = np.array([.9/2, -1.5/2., .3])  # red, lower right
        self.obs_2 = np.array([.9/2, 1., .3])  # green, upper right
        self.obs_3 = np.array([0.0, 0.0, 0.8])  # blue, centre
        self.obs_4 = np.array([-1.7/2, .3/2, .3])  # yellow ('y'), left

        self.circles_2 = [(self.obs_1, 'r'), (self.obs_2, 'g'), (self.obs_4, 'y'), (self.obs_3, 'b')]
        self.circles = [(self.obs_1, 'r'), (self.obs_4, 'y'), (self.obs_3, 'b')]
        self.paths = [(self.obs_4, 'y'), (self.obs_3, 'b')]

        self.state = np.array([-1.0, -1.0])
        self.t = 0
        self.render_mode = render_mode
        self.fig, self.ax = plt.subplots(1, 1)

    def reset(self, seed=None, options=None):
        # Float start state, so reset() and step() return the same dtype.
        self.state = np.array([-1.0, -1.0])
        self.t = 0
        return self.state.copy(), {}

    def get_state(self):
        return self.state

    def step(self, action):
        """Advance one step of ``x <- A x + B u`` with ``A = I`` and ``B = 0.4 I``."""
        self.t += 1
        dt = .4
        A = np.eye(2)
        B = np.eye(2) * dt
        u = np.asarray(action)  # not clipped to action_space
        self.state = (A @ self.state.reshape(2, 1) + B @ u.reshape(2, 1)).reshape(-1)
        cost = np.linalg.norm(u)
        return self.state, cost, False, self.t >= 50, {}

class LTLFlatPathEnv(LTLEnv):
    """LTL product over ``FlatWorld`` using the regions in ``FlatWorld.paths``.

    AP ``<colour>`` holds when the state lies strictly inside that region's
    circle.
    """

    def _make_mdp(self, **kwargs):
        return FlatWorld(**kwargs)

    def evaluate_ap(self):
        aps = []
        state = self._mdp.get_state()
        for circle, color in self._mdp.paths:
            if np.linalg.norm(state - circle[:-1]) < circle[-1]:
                aps.append(color)
        return np.array([ap in aps for ap in self.aps])

    @plotlive
    def myrender(self, states=None, save_dir=None):
        """Plot the path regions, a trajectory and the current state on ``self._mdp.ax``.

        Args:
            states: Sequence of 2-D points drawn as a dashed green polyline.
                ``None`` or an empty sequence draws no trajectory (previously
                an empty list raised IndexError).
            save_dir: If given, the figure is saved to this path.
        """
        ax = self._mdp.ax
        for center, color in self._mdp.paths:
            ax.add_patch(plt.Circle((center[0], center[1]), center[2], color=color, fill=True, alpha=.2))

        if states is not None and len(states):
            xy = np.array(states)
            ax.plot(xy[:, 0], xy[:, 1], color='green', marker='o', linestyle='dashed',
                    linewidth=2, markersize=4)

        ax.scatter([self._mdp.state[0]], [self._mdp.state[1]], s=100, marker='o', c="g")

        ax.axis('square')
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        if save_dir is not None:
            self._mdp.fig.savefig(save_dir)

class LTLFlatCycleworldEnv(LTLEnv):
    """LTL environment over a ``FlatWorld`` whose propositions come from ``circles``.

    A proposition (a colour label) holds when the agent's state lies strictly
    inside a circle carrying that label in ``self._mdp.circles``.
    """

    def _make_mdp(self, **kwargs):
        """Build the underlying ``FlatWorld`` MDP, forwarding all keyword arguments."""
        return FlatWorld(**kwargs)

    def evaluate_ap(self):
        """Return a boolean array with one entry per proposition in ``self.aps``.

        Entry ``i`` is True when the current state is strictly inside at least
        one circle labelled ``self.aps[i]``. Each circle is indexed as
        ``(*center, radius)``: everything but the last element is the centre,
        the last element is the radius.
        """
        state = self._mdp.get_state()
        # A list (not a set) keeps equality-based membership, so colour labels
        # do not need to be hashable.
        active = [
            color
            for circle, color in self._mdp.circles
            if np.linalg.norm(state - circle[:-1]) < circle[-1]
        ]
        return np.array([ap in active for ap in self.aps])

    @plotlive
    def myrender(self, states=None, save_dir=None):
        """Draw the circles, the visited trajectory and the current agent position.

        Args:
            states: Sequence of visited states; the first two coordinates of
                each are drawn as a dashed green path. ``None`` or an empty
                sequence skips the path.
            save_dir: If not ``None``, the figure is written to this path with
                ``self._mdp.fig.savefig``.
        """
        ax = self._mdp.ax

        for circle, color in self._mdp.circles:
            # Translucent filled disc centred at (x, y) with the circle's radius.
            patch = plt.Circle((circle[0], circle[1]), circle[2], color=color, fill=True, alpha=.2)
            ax.add_patch(patch)

        # ``np.array([])[:, 0]`` raises IndexError, so a missing or empty
        # trajectory is skipped rather than crashing the render.
        if states is not None and len(states) > 0:
            path = np.array(states)
            ax.plot(path[:, 0], path[:, 1], color='green', marker='o', linestyle='dashed',
                    linewidth=2, markersize=4)

        ax.scatter([self._mdp.state[0]], [self._mdp.state[1]], s=100, marker='o', c="g")

        ax.axis('square')
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        if save_dir is not None:
            self._mdp.fig.savefig(save_dir)

class LTLFlatRoundEnv(LTLEnv):
    """LTL environment over a ``FlatWorld`` whose propositions come from ``circles_2``.

    A proposition (a colour label) holds when the agent's state lies strictly
    inside a circle carrying that label in ``self._mdp.circles_2``.
    """

    def _make_mdp(self, **kwargs):
        """Build the underlying ``FlatWorld`` MDP, forwarding all keyword arguments."""
        return FlatWorld(**kwargs)

    def evaluate_ap(self):
        """Return a boolean array with one entry per proposition in ``self.aps``.

        Entry ``i`` is True when the current state is strictly inside at least
        one circle of ``self._mdp.circles_2`` labelled ``self.aps[i]``. Each
        circle is indexed as ``(*center, radius)``: everything but the last
        element is the centre, the last element is the radius.
        """
        state = self._mdp.get_state()
        # A list (not a set) keeps equality-based membership, so colour labels
        # do not need to be hashable.
        active = [
            color
            for circle, color in self._mdp.circles_2
            if np.linalg.norm(state - circle[:-1]) < circle[-1]
        ]
        return np.array([ap in active for ap in self.aps])

    @plotlive
    def myrender(self, states=None, save_dir=None):
        """Draw the ``circles_2`` regions, the visited trajectory and the agent.

        Args:
            states: Sequence of visited states; the first two coordinates of
                each are drawn as a dashed green path. ``None`` or an empty
                sequence skips the path.
            save_dir: If not ``None``, the figure is written to this path with
                ``self._mdp.fig.savefig``.
        """
        ax = self._mdp.ax

        for circle, color in self._mdp.circles_2:
            # Translucent filled disc centred at (x, y) with the circle's radius.
            patch = plt.Circle((circle[0], circle[1]), circle[2], color=color, fill=True, alpha=.2)
            ax.add_patch(patch)

        # ``np.array([])[:, 0]`` raises IndexError, so a missing or empty
        # trajectory is skipped rather than crashing the render.
        if states is not None and len(states) > 0:
            path = np.array(states)
            ax.plot(path[:, 0], path[:, 1], color='green', marker='o', linestyle='dashed',
                    linewidth=2, markersize=4)

        ax.scatter([self._mdp.state[0]], [self._mdp.state[1]], s=100, marker='o', c="g")

        ax.axis('square')
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        if save_dir is not None:
            self._mdp.fig.savefig(save_dir)

ROBOT_NAMES = ('Point', 'Car', 'Doggo')
for robot_name in ROBOT_NAMES:
    # One navigation and one avoidance environment per robot name; the robot
    # name is forwarded to the environment constructor as a keyword argument.
    register(id=f'LTLNavigate-{robot_name}-v0', entry_point=LTLNavigateEnv, kwargs={'robot_name': robot_name})
    register(id=f'LTLAvoid-{robot_name}-v0', entry_point=LTLAvoidEnv, kwargs={'robot_name': robot_name})

# Environments registered without extra kwargs, in registration order.
# Plain string ids: these names have no interpolated parts.
_LTL_ENV_ENTRY_POINTS = (
    ('LTLGridCircular-v0', LTLGridworldEnv),
    ('LTLGridSequential-v0', LTLGridspiralEnv),
    ('LTLGridAvoid-v0', LTLGridavoidEnv),
    ('LTLFetchAvoid-v0', LTLFetchAvoidEnv),
    ('LTLFetchAlign-v0', LTLFetchAlignEnv),
    ('LTLCarlo-v0', LTLCarloEnv),
    ('LTLCheetahFrontround-v0', LTLHalfCheetahRoundEnv),
    ('LTLFlatCycleWorld-v0', LTLFlatCycleworldEnv),
    ('LTLFlatPath-v0', LTLFlatPathEnv),
    ('LTLFlatRound-v0', LTLFlatRoundEnv),
)
for _env_id, _entry_point in _LTL_ENV_ENTRY_POINTS:
    register(id=_env_id, entry_point=_entry_point)