# from tools.base import Environment
# from tools.environments import DrivingEnvironment
# from tools.graphics import Canvas2D
# from tools.logic import RandomPolicy
# from tools.math import Direction2D, Box2D
# from tools.data import HighDSampleReader
# from tools.utils import combine_dicts
# import gym, tqdm
# import numpy as np
# import random
# from copy import deepcopy
# import gc

# class ExiDSampleEnvironment(DrivingEnvironment):
#     """
#     Environment based on ExiD dataset sample.
#     """

#     def __init__(self, static_elements, agents, timelimit, boxes, discrete=False):
#         """
#         Initialize ExiDSampleEnvironment.
#         """
#         self.canvas = Canvas2D(1000, 100,
#             static_elements, agents,
#             ox = 500, oy = 50, scale = 100/100, agentscale=2.
#         )
#         # double lane
#         self.canvas.set_lane_width(boxes[-1].x1, boxes[-1].x2, boxes[-1].y1, boxes[-1].y2, factor=2.)
#         act_space = gym.spaces.Box(
#                     low = np.array([-2, -1]),
#                     high = np.array([2, 1]),
#                 )
#         default_policy = lambda agent: RandomPolicy(act_space)
#         super().__init__(self.canvas, discrete=discrete, default_policy=default_policy)
#         # Specify reward structure
#         euclidean = lambda x1, y1, x2, y2: np.sqrt((x1-x2)**2+(y1-y2)**2)
#         norm = euclidean(boxes[0].center[0], boxes[0].center[1],
#                          boxes[1].center[0], boxes[1].center[1])
#         d = {
#             "caf_dist": lambda p, t: p['ego'].closest_agent_forward(
#                  p['ego'].agents_in_front_behind())['how_far'],
#             "caf_dist_parsed": lambda p, t: p["caf_dist"] if p["caf_dist"] >= 0. else 1000.
#         }
#         p = {
#             "stopped": lambda p, t: t > 0 and p['ego'].f['v'] == 0,
#             "near_goal": lambda p, t: boxes[1].inside(p['ego'].f['x'], p['ego'].f['y']),
#             "time_limit": lambda p, t: t >= timelimit,
#             "out_of_canvas": lambda p, t: p['ego'].ego_appeared and not (-500 <= p['ego'].f['x'] <= 500 and \
#                 -50 <= p['ego'].f['y'] <= 50),
#             "collision": lambda p, t: p['ego'].collided(gap=4.),
#         }
#         r = [
#             ["stopped", -0.1, 'satisfaction'],
#         ]
#         t = [
#             ["time_limit", \
#                 lambda p, t: -p['ego'].Lp(boxes[1].center[0], boxes[1].center[1])/norm, 'satisfaction'],
#             ["collision", \
#                 lambda p, t: -p['ego'].Lp(boxes[1].center[0], boxes[1].center[1])/norm, 'satisfaction'],
#             ["out_of_canvas", \
#                 lambda p, t: -p['ego'].Lp(boxes[1].center[0], boxes[1].center[1])/norm, 'satisfaction'],
#         ]
#         s = [
#             ['near_goal', 1, 'satisfaction'],
#         ]
#         self.reward_structure(d, p, r, t, s, round_to = 3)
#         # Empty state
#         ego_fn = lambda agent, rs: combine_dicts(agent.f.get_dict(), rs._p.get_dict(), rs._d.get_dict())
#         other_fn = lambda agent, rs: {}
#         self.specify_state(ego_fn, other_fn)
#         self.specify_action_multipliers([1, 0])
#         # Make ready
#         self.make_ready()
#         # Change car constants
#         for agent in self.agents:
#             agent.MAX_STEERING_ANGLE = np.pi/2
#             agent.THETA_DEVIATION_ALLOWED = np.pi

# class HighDSampleEnvironmentWrapper(Environment):
#     """
#     Wrapper around HighDSampleEnvironment.
#     """
#     def __init__(self, discrete=False, timelimit=1000, sequential=False):
#         """
#         Initialize HighDSampleEnvironmentWrapper.
#         """
#         self.timelimit = timelimit
#         self.discrete = discrete
#         self.reader = HighDSampleReader(dim=[1000, 100])
#         self.reader.read_data()
#         self.static_elements = [
#             ['StretchBackground', 'assets/highD/17_highway.png'],
#             ['Rectangle', -500, -400, -30, 0, (1, 1, 1, 0.4)],
#             ['Rectangle', 400, 500, -30, 0, (1, 1, 1, 0.4)],
#             ['Text', 't=0', -450, -45, (0, 0, 0, 1)], # Special text
#         ]
#         self.agents = []
#         self.boxes = [
#             Box2D(-500, -400, -30, 0, name="box0"),
#             Box2D(400, 500, -30, 0, name="box1"),
#             Box2D(-500, 500, -30, 0, name="box3"),
#         ]
#         self.possibleegoids = []
#         self.possibleego = lambda c, f: c[-1][0] > c[0][0] and\
#                     self.boxes[0].inside(c[0][0], c[0][1]) and\
#                     self.boxes[1].inside(c[-1][0], c[-1][1])
#         self.start_loc = self.reader.get_best_start(gap=self.timelimit, frames_len_max=self.timelimit)
#         assert(self.start_loc != None)
#         self.sequential = sequential
#         if sequential:
#             self.eid = 0
#         for i in range(len(self.reader.bboxes)):
#             centerpts, angles, frames, speedxs, speedys,\
#                 accxs, accys, h, w = self.reader.get_track(i)
#             if len(frames) > self.timelimit: continue
#             if frames[0] < self.start_loc or frames[-1] > self.start_loc+self.timelimit: continue
#             frames = np.array(frames)
#             if (not (np.isnan(centerpts).any() or np.isnan(angles).any() or \
#                 np.isnan(frames).any())) and self.boxes[0].inside(centerpts[0][0], centerpts[0][1]):
#                 frames -= self.start_loc
#                 frames = frames.astype(float)
#                 frames = np.floor(frames).astype(int)
#                 if self.possibleego(centerpts, frames):
#                     self.possibleegoids += [len(self.agents)]
#                 self.agents += [['VehicleDataHighD', centerpts, angles, frames, speedxs, \
#                     speedys, accxs, accys, None, None]]
#                 # print(np.max(np.sqrt(speedxs**2+speedys**2)), 
#                 #     np.max(np.sqrt(accxs**2+accys**2)))
#         print("%d agents, %d possible egos" % (len(self.agents), len(self.possibleegoids)))
#         self.choose_ego()
#         # self.env = HighDSampleEnvironment(self.static_elements, new_agents, timelimit=self.timelimit,
#         #     boxes=self.boxes, discrete=self.discrete)
#         # self.observation_space = self.env.observation_space
#         # self.action_space = self.env.action_space
#         self.env.reset()
#         for i in range(self.play_until):
#             ret = self.env.step(np.array([0, 0]))
#         # print("Played until", self.play_until)

#     def choose_ego(self):
#         """
#         Choose an ego id from list of ego ids.
#         """
#         if not hasattr(self, "envs"):
#             self.envs = {}
#             self.play_untils = {}
#             for egoid in self.possibleegoids:
#                 direction = '+x' if self.agents[egoid][1][-1][0] > self.agents[egoid][1][0][0] else '-x'
#                 ix = float(self.agents[egoid][1][0][0])
#                 iy = float(self.agents[egoid][1][0][1])
#                 iv = float(np.sqrt(self.agents[egoid][4][0]**2+self.agents[egoid][5][0]**2))
#                 ia = float(np.sqrt(self.agents[egoid][6][0]**2+self.agents[egoid][7][0]**2))
#                 # print(np.sqrt(agents[egoid][4]**2+agents[egoid][5]**2))
#                 # print(np.sqrt(agents[egoid][6]**2+agents[egoid][7]**2))
#                 new_agents = deepcopy(self.agents)
#                 self.play_untils[egoid] = self.agents[egoid][3][0]
#                 new_agents[egoid] = ['Ego', ix, iy, iv, Direction2D(mode = direction), self.agents[egoid][3][0], ia]
#                 # print(egoid, new_agents[egoid])
#                 env = HighDSampleEnvironment(self.static_elements, new_agents, timelimit=self.timelimit,
#                     boxes=self.boxes, discrete=self.discrete)
#                 self.observation_space = env.observation_space
#                 self.action_space = env.action_space
#                 self.envs[egoid] = env
#         if self.sequential:
#             egoid = self.possibleegoids[self.eid]
#             self.eid = (self.eid+1)%len(self.possibleegoids)
#         else:
#             egoid = random.choice(self.possibleegoids)
#         self.env = self.envs[egoid]
#         self.play_until = self.play_untils[egoid]

#     def step(self, action):
#         """
#         Step through the environment.
#         """
#         return self.env.step(action)

#     def seed(self, s=None):
#         """
#         Seed the environment.
#         """
#         return self.env.seed(s=s)

#     @property
#     def state(self):
#         """
#         Return the environment state.
#         """
#         return self.env.state

#     def reset(self, **kwargs):
#         """
#         Choose an ego id, recreate the environment and play until ego is starting.
#         """
#         self.choose_ego()
#         # self.env.canvas.close()
#         # del self.env
#         # gc.collect()
#         # self.env = HighDSampleEnvironment(self.static_elements, new_agents, timelimit=self.timelimit,
#         #     boxes=self.boxes, discrete=self.discrete)
#         # self.observation_space = self.env.observation_space
#         # self.action_space = self.env.action_space
#         ret = {"next_state": self.env.reset(**kwargs)}
#         for i in range(self.play_until):
#             ret = self.env.step(np.array([0, 0]))
#         # print("Played until", self.play_until)
#         return ret["next_state"]
    
#     def render(self, **kwargs):
#         """
#         Render the environment.
#         """
#         return self.env.render(**kwargs)


from tools.base import Environment
from tools.graphics import Plot2D
import copy
import glob
import os
import numpy as np
import random
from tools.utils import get_package_root_path
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from tools.exid import ExiDSampleReader, frame_by_frame
import pandas
from tools.exid import OSMReader
from tools.exid import load_trajectories, load_trajectories2, lane_change_start_end, \
    find_closest_node, find_segment
from shapely.geometry import Point, LineString
from gym import spaces

class ExiDSampleEnvironmentLateral(Environment):
    """
    Lateral-control environment that replays one exiD lane-change recording.

    On every reset a trajectory id is sampled from ``self.trajs`` (produced by
    ``load_trajectories``). Longitudinal progress along the centre line of
    lanelet ``startlaneletid`` is replayed from the recorded positions. The
    lateral offset from that line is either replayed as well (``step(None)``)
    or driven by the action, which is interpreted as a lateral velocity
    (offset change = action * dt).

    The observation is a 1-element array holding the ego's distance to the
    point at arc length ``lon_dist`` on the centre line of lanelet
    ``endlaneletid``. The value is negated when the ego lies inside the
    100 m band on the left of that line.
    """

    def __init__(self):
        """
        Load the exiD sample from ``<package root>/assets/exiD`` together with
        its lanelet map and the lane-change trajectories.

        Raises:
            FileNotFoundError: if the ``.osm`` map or the
                ``*_recordingMeta.csv`` file is missing from the assets
                directory.
        """
        self.startlaneletid = '1412'
        self.endlaneletid = '1411'
        self.dt = 1/25. # from meta file
        self.trajectories_pickle_file = 'traj.pt' # will create if does not exist
        self.er = ExiDSampleReader()
        self.er.read_data(lonlat=True)
        root_path = get_package_root_path()
        josmfile = self._first_match(os.path.join(root_path, "assets/exiD/*.osm"))
        recordingMeta_file = self._first_match(
            os.path.join(root_path, "assets/exiD/*_recordingMeta.csv"))
        recordingMeta = pandas.read_csv(recordingMeta_file)
        # Index explicitly: float() on a one-element Series is deprecated
        # since pandas 2.1.
        utmx = float(recordingMeta["xUtmOrigin"].iloc[0])
        utmy = float(recordingMeta["yUtmOrigin"].iloc[0])
        self.osm = OSMReader(josmfile, utmx, utmy)
        self.startlaneletcenters = self.osm.get_relation(self.startlaneletid).mid
        self.endlaneletcenters = self.osm.get_relation(self.endlaneletid).mid
        self.frame_dict = frame_by_frame(self.er)
        self.trajs = load_trajectories(self.er, self.osm, self.trajectories_pickle_file)
        self.discrete = False
        high = float('inf')
        self.action_space = spaces.Box(-high, high, dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _first_match(pattern):
        """Return the first path matching glob ``pattern``; raise FileNotFoundError if none."""
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError("No file matches %r" % pattern)
        return matches[0]

    @property
    def state(self):
        """
        Return the current state.
        """
        return self.curr_state

    def reset(self):
        """
        Sample a lane-change trajectory and place the ego at its start.

        Trajectories for which the track lookup, ``lane_change_start_end`` or
        ``find_closest_node`` raise are skipped and another one is drawn.

        Returns:
            np.ndarray: the initial 1-element observation.

        Raises:
            RuntimeError: if there are no trajectories to sample from.
        """
        if len(self.trajs) == 0:
            raise RuntimeError("No lane-change trajectories available to sample from.")
        while True:
            # `except Exception` rather than a bare except so that Ctrl-C and
            # SystemExit can still break out of this resampling loop.
            try:
                tid = self.trajs[np.random.randint(len(self.trajs))]
                centerpts, headings, frames, vx, vy, ax, ay, _, _, llp, llco = \
                    self.er.get_track(tid, llp=True)
                startframe, endframe = lane_change_start_end(self.osm, centerpts)
                closestnode, ipline = find_closest_node(
                    self.startlaneletcenters, centerpts[startframe][0], centerpts[startframe][1])
            except Exception:
                continue
            break
        self.data = {
            "tid": tid,
            "centerpts": centerpts,
            "headings": headings,
            "frames": frames,
            "vx": vx,
            "vy": vy,
            "ax": ax,
            "ay": ay,
            "llp": llp,
            "llco": llco,
        }
        self.data["startframe"], self.data["endframe"] = startframe, endframe
        self.data["closestnode"], self.data["ipline"] = closestnode, ipline
        # Offset of the recorded start position from the closest centre-line
        # node; added to every reconstructed ego position.
        start = centerpts[startframe]
        self.data["dr"] = (start[0]-closestnode.x, start[1]-closestnode.y)
        self.data["lon_dist"], self.data["lat_dist"] = 0, 0
        self.bbox = self._find_bbox(tid, frames[startframe])
        self.i = 0
        self.terminated = False
        self.max_i = endframe-startframe-1
        self._build_target_geometry()
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        x, y = self._ego_xy(ref_point, tx, ty)
        side, dist_target = self._signed_target_distance(x, y)
        self.curr_state = np.array([side*dist_target])
        return self.curr_state

    def seed(self, s=None):
        """
        Seed this environment.
        """
        random.seed(s)
        np.random.seed(s)

    def step(self, action=None):
        """
        Advance the replay by one recorded frame.

        Args:
            action: lateral velocity for this step (converted with
                ``float``). When None, the recorded lateral motion is
                replayed instead.

        Returns:
            dict: ``"next_state"`` is the 1-element observation.
            ``"reward"`` is ``1 - d/10`` for a target distance ``d <= 10``,
            else 0. ``"done"`` becomes True once more than ``max_i`` steps
            have been taken. ``"info"`` is ``{"action": recorded lateral
            velocity}`` when ``action`` is None, else ``{}``.
        """
        self.i += 1
        self.terminated = self.i > self.max_i
        # Tangent frame at the position *before* this step's longitudinal update.
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        centerpts = self.data["centerpts"]
        k = self.data["startframe"]+self.i
        dc = (centerpts[k, 0]-centerpts[k-1, 0], centerpts[k, 1]-centerpts[k-1, 1])
        # Split the recorded displacement into tangential (lon) and normal (lat) parts.
        self.data["lon_dist"] += dc[0]*tx+dc[1]*ty
        dclat = dc[0]*ty-dc[1]*tx
        info = {}
        if action is not None:
            dclat = float(action)*self.dt
        else:
            info = {"action": dclat/self.dt}
        self.data["lat_dist"] += dclat
        # NOTE(review): the ego position is built from ref_point, which is the
        # start-lane point *before* this step's lon_dist update. The target
        # point (and render()) use the updated lon_dist. Kept as-is to
        # preserve observations; confirm this one-step lag is intended.
        x, y = self._ego_xy(ref_point, tx, ty)
        side, dist_target = self._signed_target_distance(x, y)
        self.curr_state = np.array([side*dist_target])
        return {
            "next_state": self.curr_state,
            "reward": 1-dist_target/10. if dist_target <= 10. else 0.,
            "done": self.terminated,
            "info": info,
        }

    def render(self, **kwargs):
        """
        Draw the ego bounding box, shifted to its reconstructed position, over
        the start (cyan) and target (green) lanelets.

        Returns:
            np.ndarray | None: an (H, W, 3) uint8 image when
            ``mode="rgb_array"`` is passed, otherwise None.
        """
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        x, y = self._ego_xy(ref_point, tx, ty)
        start = self.data["centerpts"][self.data["startframe"]]
        new_bbox = copy.deepcopy(self.bbox)
        new_bbox[:, 0] += x-start[0]
        new_bbox[:, 1] += y-start[1]
        self.display_bbox = new_bbox
        if not hasattr(self, "plot"):
            self.plot = Plot2D({
                "env": lambda p, l, t: self,
            }, [
                [
                    lambda p, l, t: not l["env"].terminated,
                    lambda p, l, o, t: p.polygon(l["env"].display_bbox, o=o, facecolor="brown")
                ],
            ], mode="dynamic", interval=50)
            self.plot.ax.axis("equal")
            self.osm.plot({self.startlaneletid:'cyan', self.endlaneletid:'green'},
                show_all=False, ax=self.plot.ax)
        self.plot.show(block=False)
        if kwargs.get("mode") == "rgb_array":
            return self._canvas_to_rgb(self.plot.fig.canvas)

    # ------------------------------------------------------------------ helpers

    def _find_bbox(self, tid, frame):
        """Return the bounding box of vehicle ``tid`` at ``frame``, or None if absent."""
        for vehicle, fid in self.frame_dict[frame]:
            if vehicle == tid:
                return self.er.bboxes[vehicle][fid]
        return None

    def _build_target_geometry(self):
        """Cache the target-lane centre line and the 100 m band on its left."""
        self._target_line = LineString([(node.x, node.y) for node in self.endlaneletcenters])
        # Shapely places a positive single-sided buffer on the LEFT of the line.
        self._target_left_band = self._target_line.buffer(100, single_sided=True)

    def _reference_frame(self, lon_dist):
        """Return the start-lane point at arc length ``lon_dist`` and its segment's unit tangent (tx, ty)."""
        point = self.data["ipline"].interpolate(lon_dist)
        segpt1, segpt2 = find_segment(self.data["ipline"], point)
        length = ((segpt1[0]-segpt2[0])**2+(segpt1[1]-segpt2[1])**2)**0.5
        return point, (segpt2[0]-segpt1[0])/length, (segpt2[1]-segpt1[1])/length

    def _ego_xy(self, point, tx, ty):
        """Offset ``point`` by ``lat_dist`` along the normal (ty, -tx), then by the start offset ``dr``."""
        x = point.x+ty*self.data["lat_dist"]
        x += self.data["dr"][0]
        y = point.y-tx*self.data["lat_dist"]
        y += self.data["dr"][1]
        return x, y

    def _signed_target_distance(self, x, y):
        """Return (side, distance) from (x, y) to the target line point at ``lon_dist``; side is -1 inside the left band, else 1."""
        ego = Point(x, y)
        side = -1 if self._target_left_band.contains(ego) else 1
        dist = ego.distance(self._target_line.interpolate(self.data["lon_dist"]))
        return side, dist

    @staticmethod
    def _canvas_to_rgb(canvas):
        """Draw ``canvas`` and return its pixels as an (H, W, 3) uint8 array."""
        canvas.draw()
        if hasattr(canvas, "buffer_rgba"):
            rgba = np.asarray(canvas.buffer_rgba())
            if rgba.ndim == 3:
                # Copy: the memoryview aliases the renderer's reusable buffer.
                return rgba[..., :3].copy()
        # Fallback for canvases without a 2-D RGBA buffer (old matplotlib).
        img = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return img.reshape(canvas.get_width_height()[::-1] + (3,))

class ExiDSampleEnvironmentLateral2(Environment):
    """
    Lateral-control environment replaying lane changes from several exiD
    recordings.

    Each reset samples a recording uniformly, then one of its lane changes
    returned by ``load_trajectories2``. The tuple supplies the track id,
    start/target lanelet ids and the start/end frames. Longitudinal progress
    along the start-lanelet centre line is replayed from the recording. The
    lateral offset is replayed too (``step(None)``) or driven by the action,
    interpreted as a lateral velocity (offset change = action * dt).

    The observation is a 1-element array holding the ego's distance to the
    point at arc length ``lon_dist`` on the target-lanelet centre line. The
    value is negated when the ego lies inside the 100 m band on the left of
    that line.
    """

    def __init__(self,
                 data_dir="~/Projects/Datasets/exiD/exiD-dataset-v2.0/data/",
                 maps_dir="~/Projects/Datasets/exiD/exiD-dataset-v2.0/maps/lanelet2/",
                 max_recordings=5,
                 traj_dir="exidtraj",
                 max_start_offset=6):
        """
        Load several exiD recordings with their lanelet maps and lane changes.

        Args:
            data_dir: directory holding ``<num>_tracks.csv``,
                ``<num>_tracksMeta.csv`` and ``<num>_recordingMeta.csv``.
                ``~`` is expanded.
            maps_dir: directory holding lanelet2 ``.osm`` maps whose file
                names start with the recording's ``locationId``. ``~`` is
                expanded.
            max_recordings: load at most this many recordings, taken in
                sorted order of their number. ``None`` loads all of them.
            traj_dir: directory of the per-recording ``<num>.pt`` paths
                passed to ``load_trajectories2``.
            max_start_offset: ``reset`` rejects lane changes whose initial
                distance to the target centre line exceeds this.

        Raises:
            FileNotFoundError: if no map matches a recording's location.
        """
        self.dt = 1/25. # from meta file
        self.max_start_offset = max_start_offset
        self.ers = []
        data_dir = os.path.expanduser(data_dir)
        maps_dir = os.path.expanduser(maps_dir)
        # Sort so that the subset chosen by max_recordings does not depend on
        # the arbitrary order in which glob returns files.
        nums = sorted(os.path.basename(path).split("_")[0]
                      for path in glob.glob(os.path.join(data_dir, "*_tracks.csv")))
        for num in nums[:max_recordings]:
            t = os.path.join(data_dir, num + "_tracks.csv")
            tm = os.path.join(data_dir, num + "_tracksMeta.csv")
            rm = os.path.join(data_dir, num + "_recordingMeta.csv")
            rmf = pandas.read_csv(rm)
            # Index explicitly: int()/float() on a one-element Series is
            # deprecated since pandas 2.1.
            mapid = "%d" % int(rmf["locationId"].iloc[0])
            osmf = self._first_match(os.path.join(maps_dir, mapid + "*.osm"))
            utmx = float(rmf["xUtmOrigin"].iloc[0])
            utmy = float(rmf["yUtmOrigin"].iloc[0])
            osm = OSMReader(osmf, utmx, utmy)
            er = ExiDSampleReader(files = {"t": t, "tm": tm, "rm": rm})
            er.read_data(lonlat=True)
            lane_change_data = load_trajectories2(er, osm, os.path.join(traj_dir, "%s.pt" % num),
                cond1 = lambda x, y: True,
                cond2 = lambda x, y: True)
            self.ers += [(er, osm, lane_change_data)]
        self.discrete = False
        high = float('inf')
        self.action_space = spaces.Box(-high, high, dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        # recording index -> indices of its lane changes not yet rejected
        self.allowed_ers = {}
        # recording index -> frame_by_frame(er), computed on first use
        self._frame_dicts = {}

    @staticmethod
    def _first_match(pattern):
        """Return the first path matching glob ``pattern``; raise FileNotFoundError if none."""
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError("No file matches %r" % pattern)
        return matches[0]

    @property
    def state(self):
        """
        Return the current state.
        """
        return self.curr_state

    def reset(self):
        """
        Sample a recorded lane change and place the ego at its start.

        A lane change whose initial distance to the target centre line
        exceeds ``max_start_offset`` is removed from ``allowed_ers`` for
        good, and a new sample is drawn. Resampling is iterative, so many
        rejections cannot hit the recursion limit.

        Returns:
            np.ndarray: the initial 1-element observation.

        Raises:
            RuntimeError: if no recordings were loaded, or if every lane change
                of every recording has been rejected.
        """
        if not self.ers:
            raise RuntimeError("No exiD recordings were loaded.")
        while True:
            erid = np.random.randint(len(self.ers))
            self.er, self.osm, self.lcd = self.ers[erid]
            if erid not in self.allowed_ers:
                self.allowed_ers[erid] = list(range(len(self.lcd)))
            candidates = self.allowed_ers[erid]
            if not candidates:
                # Give up only once every recording has been seen and emptied.
                if len(self.allowed_ers) == len(self.ers) and not any(self.allowed_ers.values()):
                    raise RuntimeError(
                        "Every lane change in every loaded recording was rejected by reset().")
                continue
            erid2 = np.random.choice(candidates)
            dist_target = self._start_episode(erid, erid2)
            if dist_target > self.max_start_offset:
                candidates.remove(erid2)
                continue
            return self.curr_state

    def _start_episode(self, erid, lcid):
        """
        Initialise episode state for lane change ``lcid`` of recording ``erid``.

        ``self.er``, ``self.osm`` and ``self.lcd`` must already refer to that
        recording. Sets ``self.data``, ``self.bbox`` and ``self.curr_state``.

        Returns:
            float: the initial distance from the ego to the target point.
        """
        tid, self.startlaneletid, self.endlaneletid, startframe, endframe = \
            self.lcd[lcid]
        self.startlaneletcenters = self.osm.get_relation(self.startlaneletid).mid
        self.endlaneletcenters = self.osm.get_relation(self.endlaneletid).mid
        # Cached per recording; assumes frame_by_frame(er) is deterministic
        # for a given reader, as the original per-reset call implied.
        if erid not in self._frame_dicts:
            self._frame_dicts[erid] = frame_by_frame(self.er)
        self.frame_dict = self._frame_dicts[erid]
        centerpts, headings, frames, vx, vy, ax, ay, _, _, llp, llco = self.er.get_track(tid, llp=True)
        self.data = {
            "tid": tid,
            "centerpts": centerpts,
            "headings": headings,
            "frames": frames,
            "vx": vx,
            "vy": vy,
            "ax": ax,
            "ay": ay,
            "llp": llp,
            "llco": llco,
        }
        self.data["startframe"], self.data["endframe"] = startframe, endframe
        start = centerpts[startframe]
        closestnode, ipline = find_closest_node(self.startlaneletcenters, start[0], start[1])
        self.data["closestnode"], self.data["ipline"] = closestnode, ipline
        # Offset of the recorded start position from the closest centre-line
        # node; added to every reconstructed ego position.
        self.data["dr"] = (start[0]-closestnode.x, start[1]-closestnode.y)
        self.data["lon_dist"], self.data["lat_dist"] = 0, 0
        self.bbox = self._find_bbox(tid, frames[startframe])
        self.i = 0
        self.terminated = False
        self.max_i = endframe-startframe-1
        self._build_target_geometry()
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        x, y = self._ego_xy(ref_point, tx, ty)
        side, dist_target = self._signed_target_distance(x, y)
        self.curr_state = np.array([side*dist_target])
        return dist_target

    def seed(self, s=None):
        """
        Seed this environment.
        """
        random.seed(s)
        np.random.seed(s)

    def step(self, action=None):
        """
        Advance the replay by one recorded frame.

        Args:
            action: lateral velocity for this step (converted with
                ``float``). When None, the recorded lateral motion is
                replayed instead.

        Returns:
            dict: ``"next_state"`` is the 1-element observation.
            ``"reward"`` is ``1 - d/10`` for a target distance ``d <= 10``,
            else 0. ``"done"`` becomes True once more than ``max_i`` steps
            have been taken. ``"info"`` is ``{"action": recorded lateral
            velocity}`` when ``action`` is None, else ``{}``.
        """
        self.i += 1
        self.terminated = self.i > self.max_i
        # Tangent frame at the position *before* this step's longitudinal update.
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        centerpts = self.data["centerpts"]
        k = self.data["startframe"]+self.i
        dc = (centerpts[k, 0]-centerpts[k-1, 0], centerpts[k, 1]-centerpts[k-1, 1])
        # Split the recorded displacement into tangential (lon) and normal (lat) parts.
        self.data["lon_dist"] += dc[0]*tx+dc[1]*ty
        dclat = dc[0]*ty-dc[1]*tx
        info = {}
        if action is not None:
            dclat = float(action)*self.dt
        else:
            info = {"action": dclat/self.dt}
        self.data["lat_dist"] += dclat
        # NOTE(review): the ego position is built from ref_point, which is the
        # start-lane point *before* this step's lon_dist update. The target
        # point (and render()) use the updated lon_dist. Kept as-is to
        # preserve observations; confirm this one-step lag is intended.
        x, y = self._ego_xy(ref_point, tx, ty)
        side, dist_target = self._signed_target_distance(x, y)
        self.curr_state = np.array([side*dist_target])
        return {
            "next_state": self.curr_state,
            "reward": 1-dist_target/10. if dist_target <= 10. else 0.,
            "done": self.terminated,
            "info": info,
        }

    def render(self, **kwargs):
        """
        Draw the ego bounding box, shifted to its reconstructed position, over
        the start (cyan) and target (green) lanelets.

        Returns:
            np.ndarray | None: an (H, W, 3) uint8 image when
            ``mode="rgb_array"`` is passed, otherwise None.
        """
        ref_point, tx, ty = self._reference_frame(self.data["lon_dist"])
        x, y = self._ego_xy(ref_point, tx, ty)
        start = self.data["centerpts"][self.data["startframe"]]
        new_bbox = copy.deepcopy(self.bbox)
        new_bbox[:, 0] += x-start[0]
        new_bbox[:, 1] += y-start[1]
        self.display_bbox = new_bbox
        if not hasattr(self, "plot"):
            # NOTE(review): the map and lanelet colouring are drawn only when
            # the plot is first created, so later episodes from a different
            # recording/lanelet pair still show the first episode's map.
            self.plot = Plot2D({
                "env": lambda p, l, t: self,
            }, [
                [
                    lambda p, l, t: not l["env"].terminated,
                    lambda p, l, o, t: p.polygon(l["env"].display_bbox, o=o, facecolor="brown")
                ],
            ], mode="dynamic", interval=50)
            self.plot.ax.axis("equal")
            self.osm.plot({self.startlaneletid:'cyan', self.endlaneletid:'green'},
                show_all=False, ax=self.plot.ax)
        self.plot.show(block=False)
        if kwargs.get("mode") == "rgb_array":
            return self._canvas_to_rgb(self.plot.fig.canvas)

    # ------------------------------------------------------------------ helpers

    def _find_bbox(self, tid, frame):
        """Return the bounding box of vehicle ``tid`` at ``frame``, or None if absent."""
        for vehicle, fid in self.frame_dict[frame]:
            if vehicle == tid:
                return self.er.bboxes[vehicle][fid]
        return None

    def _build_target_geometry(self):
        """Cache the target-lane centre line and the 100 m band on its left."""
        self._target_line = LineString([(node.x, node.y) for node in self.endlaneletcenters])
        # Shapely places a positive single-sided buffer on the LEFT of the line.
        self._target_left_band = self._target_line.buffer(100, single_sided=True)

    def _reference_frame(self, lon_dist):
        """Return the start-lane point at arc length ``lon_dist`` and its segment's unit tangent (tx, ty)."""
        point = self.data["ipline"].interpolate(lon_dist)
        segpt1, segpt2 = find_segment(self.data["ipline"], point)
        length = ((segpt1[0]-segpt2[0])**2+(segpt1[1]-segpt2[1])**2)**0.5
        return point, (segpt2[0]-segpt1[0])/length, (segpt2[1]-segpt1[1])/length

    def _ego_xy(self, point, tx, ty):
        """Offset ``point`` by ``lat_dist`` along the normal (ty, -tx), then by the start offset ``dr``."""
        x = point.x+ty*self.data["lat_dist"]
        x += self.data["dr"][0]
        y = point.y-tx*self.data["lat_dist"]
        y += self.data["dr"][1]
        return x, y

    def _signed_target_distance(self, x, y):
        """Return (side, distance) from (x, y) to the target line point at ``lon_dist``; side is -1 inside the left band, else 1."""
        ego = Point(x, y)
        side = -1 if self._target_left_band.contains(ego) else 1
        dist = ego.distance(self._target_line.interpolate(self.data["lon_dist"]))
        return side, dist

    @staticmethod
    def _canvas_to_rgb(canvas):
        """Draw ``canvas`` and return its pixels as an (H, W, 3) uint8 array."""
        canvas.draw()
        if hasattr(canvas, "buffer_rgba"):
            rgba = np.asarray(canvas.buffer_rgba())
            if rgba.ndim == 3:
                # Copy: the memoryview aliases the renderer's reusable buffer.
                return rgba[..., :3].copy()
        # Fallback for canvases without a 2-D RGBA buffer (old matplotlib).
        img = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return img.reshape(canvas.get_width_height()[::-1] + (3,))
