"""Racer environment."""
import xml.etree.ElementTree as etree

import deluca.core
import jax
import jax.numpy as jnp
import matplotlib
from matplotlib import animation
from matplotlib import path
from matplotlib.collections import PatchCollection
from matplotlib.patches import ArrowStyle
from matplotlib.patches import Circle
from matplotlib.patches import FancyArrowPatch
from matplotlib.patches import PathPatch
import matplotlib.pyplot as plt
import numpy as np
from six import StringIO

# Imports needed only by the disabled `RacerViz` code below; left commented out
# along with it.
#import scripts.svgpath2mpl as svgpath2mpl
#from pyglib import resources
#from pyglib import gfile
# A bare command name: matplotlib's ffmpeg animation writer will launch
# whatever `ffmpeg` is found on PATH.
matplotlib.rcParams['animation.ffmpeg_path'] = 'ffmpeg'

# NOTE: The string literal below is disabled code, not a docstring. It holds a
# `RacerViz` class that loads an SVG car sprite and converts it into
# matplotlib `PathPatch`es. It depends on `gfile` and `svgpath2mpl`, whose
# imports are commented out above. Evaluating the literal has no effect, and
# nothing in this module instantiates `RacerViz`.
'''
class RacerViz:
    """Racer visualization."""

    def __init__(self,
               url='reachability-regret/car.svg',
               scale=5,
               facecolors='#565656',
               edgecolors='#191919',
               linewidths=1):
        text = gfile.GFile(url, 'r').read()
        tree = etree.parse(StringIO(text.decode('utf-8')))
        root = tree.getroot()
        path_elems = root.findall('.//{http://www.w3.org/2000/svg}path')

        paths = [svgpath2mpl.parse_path(elem.attrib['d']) for elem in path_elems]
        verts = np.concatenate([path.vertices for path in paths], axis=0)
        mean = verts.mean(axis=0)
        vert_max = verts.max(axis=0)
        vert_min = verts.min(axis=0)
        size = vert_max - vert_min
        self.verts = [
            np.matmul((path.vertices - mean) / size[1] * scale,
                      np.array([[0, 1], [-1, 0]])) for path in paths
        ]
        self.codes = [path.codes for path in paths]
        self.facecolors = [
            elem.attrib.get('fill', facecolors) for elem in path_elems
        ]
        self.edgecolors = [
            elem.attrib.get('stroke', edgecolors) for elem in path_elems
        ]
        self.linewidths = [
            elem.attrib.get('stroke_width', linewidths) for elem in path_elems
        ]

    def to_path_patch_array(self, shift=np.zeros(2)):
        verts = [vert + shift for vert in self.verts]
        return [
            PathPatch(path.Path(vert, code))
            for vert, code in zip(verts, self.codes)
        ]

'''

class RacerState(deluca.Obj):
    """Racer state.

    The environment is a rectangular field of size `width` x `height`
    (attributes of `Racer`):
    - The x-axis is width. `Racer` clips the racer's x to
      [center - width / 2, center + width / 2]. The x range is centered on
      `Racer.center`, so the origin is the bottom-left corner only when
      center == width / 2. With the default center of 0 it is the bottom-middle.
    - The y-axis is height. The racer starts at y = 0. Obstacles generated by
      `Racer.init` are placed at y in [init_clearance, height).
    """
    # Racer state as a (4, 1) column vector: x, y, x_dot, y_dot.
    arr: jnp.ndarray = deluca.field(
      jnp.array([0., 0., 0., 0.]).reshape(-1, 1), jaxed=True)

    # Obstacles, one per column. Shape: (attributes, num_objects) = (5, N).
    # Each column: [pos_x, pos_y, vel_x, vel_y, r].
    obstacles: jnp.ndarray = deluca.field(jaxed=True)
    # Per-obstacle 0./1. float mask of shape (N,), as returned by
    # `Racer.get_observed`.
    obstacles_observed: jnp.ndarray = deluca.field(jaxed=True)
    # Initialized to zeros of shape (N,) by `Racer.init`; nothing in this module
    # updates it afterwards.
    obstacles_collided: jnp.ndarray = deluca.field(jaxed=True)
    # Filled by `Racer.get_10_obstacles`: a (5, 10) array of observed obstacle
    # columns, padded with obstacle 0. The zeros(10) default here is only a
    # placeholder and has a different shape.
    ten_obstacles_observed: jnp.ndarray = deluca.field(jnp.zeros(10), jaxed=True)
    # Per-obstacle boolean (vel_x < 0), set only by `PedestrianRacer`. `Racer.init`
    # does not pass it, so it takes `deluca.field`'s default
    # (NOTE(review): presumably None — confirm).
    left: jnp.ndarray = deluca.field(jaxed=True)


class Racer(deluca.Env):
    """Racer environment: a point mass driving up a track past circular obstacles.

    The track spans x in [center - width / 2, center + width / 2]. The racer's x
    is clipped to that range on every step. Obstacles are stored column-wise as
    a (5, num_objects) array whose rows are [pos_x, pos_y, vel_x, vel_y, r].
    They advance by dt * velocity on every step. Obstacles generated by `init`
    have zero velocity and radius 1.
    """

    center: float = 0.
    width: int = 100
    height: int = 10000

    racer_radius: float = 0.
    racer_mass: float = 1.

    num_obstacles: int = 2000
    sensor_range: float = 25.
    init_clearance: float = 10.
    dt: float = 0.5
    y_dot: float = 1.0

    # NOTE(review): A and B are evaluated once, when this class body runs, from
    # the class-level defaults of `dt` and `racer_mass` above. An instance with
    # a different `dt` or `racer_mass` does not get rebuilt matrices.
    #
    # A acts on [x, y, x_dot, y_dot]. The x channel has a restoring term that
    # pulls x toward 0 (not toward `center`), plus damping on x_dot. y
    # integrates y_dot, which A leaves unchanged.
    A: jnp.array = jnp.array([[1.-0.05*dt*dt, 0, dt*(1.-0.15*dt), 0], [0, 1, 0, dt], [-0.1*dt, 0, 1-0.3*dt, 0],
                            [0, 0, 0, 1]])
    # B maps action[0] onto (x, x_dot) and action[1] onto (y, y_dot). The y gains
    # are 100x smaller than the x gains.
    B: jnp.array = jnp.array([[0.5 * dt * dt / racer_mass, 0],
                            [0, 0.005 * dt * dt / racer_mass],
                            [dt / racer_mass, 0], [0, 0.01*dt / racer_mass]])
    # NOTE(review): not referenced in this file. It is hard-coded rather than
    # derived from B (B is 4x2, so it has no ordinary inverse); confirm what
    # callers expect.
    B_inv: jnp.array = jnp.array([[2.0, 0.], [0., 2.0]])  # pylint: disable=invalid-name

    def _random_obstacles(self, key):
        """Sample `num_obstacles` static, unit-radius obstacles as (N, 5) rows."""
        n = self.num_obstacles
        # Uniform in [0, width) x [0, height - init_clearance).
        pos = jnp.matmul(
            jax.random.uniform(key, (n, 2)),
            jnp.array([[self.width, 0], [0, self.height - self.init_clearance]]))
        # Shift to x in [center - width/2, center + width/2) and y >= init_clearance.
        # This leaves the racer's start (center, 0) init_clearance of free space.
        pos = pos + jnp.array([self.center - self.width / 2, self.init_clearance])
        return jnp.hstack((pos, jnp.zeros((n, 2)), jnp.ones((n, 1))))

    def init(self, key=jax.random.PRNGKey(0), arr=None, obstacles=None):
        """Build the initial RacerState.

        Args:
          key: PRNG key, used only to place obstacles when `obstacles` is None.
          arr: optional (4, 1) initial racer state [x, y, x_dot, y_dot]. Defaults
            to position (center, 0) with velocity (0, y_dot).
          obstacles: optional (num_objects, 5) array of rows
            [pos_x, pos_y, vel_x, vel_y, r]. It is stored transposed. Any number
            of obstacles is accepted. Defaults to `num_obstacles` random static
            obstacles.

        Returns:
          RacerState with the observation mask and the 10-obstacle view computed
          from `arr` and the obstacles.
        """
        if obstacles is None:
            obstacles = self._random_obstacles(key)

        obstacles_T = jnp.transpose(obstacles)  # pylint: disable=invalid-name

        if arr is None:
            arr = jnp.array([self.center, 0., 0., self.y_dot]).reshape(-1, 1)

        obstacles_observed = self.get_observed(obstacles_T, arr)
        return RacerState(
            arr=arr,
            obstacles=obstacles_T,
            obstacles_observed=obstacles_observed,
            # Sized from the actual obstacle count, not `num_obstacles`, so that
            # caller-supplied obstacle arrays of any size work.
            obstacles_collided=jnp.zeros(obstacles_T.shape[1]),
            ten_obstacles_observed=self.get_10_obstacles(obstacles_T,
                                                         obstacles_observed))

    def __call__(self, state, action, w=jnp.zeros((4, 1))):
        """Advance the environment by one step of length `dt`.

        Args:
          state: current RacerState.
          action: 2-element control (shape (2,) or (2, 1)). It is reshaped to a
            column so that it cannot broadcast against the (4, 1) state.
          w: additive disturbance with 4 elements, reshaped to (4, 1).

        Returns:
          (new_state, new_state).
        """
        action = jnp.asarray(action).reshape(-1, 1)
        arr = (jnp.matmul(self.A, state.arr) + jnp.matmul(self.B, action) +
               w.reshape(-1, 1))
        # Keep the racer on the track. Only position is clipped; velocity is
        # left as is.
        arr = arr.at[0, :].set(
            jnp.clip(arr[0, :], self.center - self.width / 2,
                     self.center + self.width / 2))

        # Move obstacles first. The observation is then computed from the same
        # obstacle positions that are stored in the new state, matching `init`
        # and `render`.
        pos, vel, r = jnp.split(state.obstacles, [2, 4], axis=0)
        obstacles = jnp.concatenate((pos + self.dt * vel, vel, r), axis=0)

        observed = self.get_observed(obstacles, arr)
        new_state = state.replace(
            arr=arr,
            obstacles=obstacles,
            obstacles_observed=observed,
            ten_obstacles_observed=self.get_10_obstacles(obstacles, observed))

        return new_state, new_state

    def get_observed(self, obstacles, arr):
        """Return a 0./1. mask of the obstacles the racer observes.

        An obstacle counts as observed when both hold:
        - its center is within `sensor_range + r` of the racer (Euclidean);
        - `pos_y - racer_y + 0.5 + r >= -2`, i.e. it is not too far behind
          the racer.

        Args:
          obstacles: (5, num_objects) array of [pos_x, pos_y, vel_x, vel_y, r].
          arr: (4, 1) racer state [x, y, x_dot, y_dot].

        Returns:
          float array of shape (num_objects,) with 1. for observed obstacles.
        """
        num_objects = obstacles.shape[1]
        obstacle_pos = obstacles[:2]        # (2, N)
        obstacle_r = obstacles[4]           # (N,)
        racer_pos = arr[:2]                 # (2, 1)

        sq_distance = jnp.sum(jnp.square(obstacle_pos - racer_pos), axis=0)
        in_range = sq_distance <= jnp.square(self.sensor_range + obstacle_r)

        dy = obstacle_pos[1] - racer_pos[1, 0] + 0.5 + obstacle_r
        not_behind = dy >= -2.

        return jnp.where(jnp.logical_and(in_range, not_behind),
                         jnp.ones(num_objects), jnp.zeros(num_objects))

    def get_10_obstacles(self, obstacles, observed_bool_arr):
        """Gather the columns of up to 10 observed obstacles.

        Args:
          obstacles: (5, num_objects) obstacle array.
          observed_bool_arr: (num_objects,) mask. Nonzero entries count as
            observed.

        Returns:
          (5, 10) array of observed obstacles in index order. If more than 10
          are observed, the first 10 by index are kept. If fewer are observed,
          the remaining columns repeat obstacle 0.
        """
        observed_idx_arr = jnp.nonzero(observed_bool_arr, size=10, fill_value=0)[0]
        # TODO(dsuo): currently fill with first obstacles if less than 10 obstacles.
        return jnp.take(obstacles, observed_idx_arr, axis=1)

    def render(self,
               states,
               title='Racer',
               xlabel='X (m)',
               ylabel='Y (m)',
               pedestrian=False,
               racer_scale=5.0,
               obstacle_scale=4.0):
        """Animate a sequence of states in a frame that scrolls with the racer.

        Obstacles are shifted down by the racer's y, so the racer (and its
        sensor disc) sits at y = 0. Observed obstacles are colored differently
        from unobserved ones.

        Args:
          states: indexable sequence of RacerState (must support len()).
          title, xlabel, ylabel: axis labels.
          pedestrian: if True, draw each obstacle as an arrow pointing in its
            direction of travel. This requires `state.left`, which is set by
            `PedestrianRacer`. Otherwise obstacles are drawn as circles.
          racer_scale: unused; kept for API compatibility.
          obstacle_scale: arrow mutation-scale multiplier in pedestrian mode.

        Returns:
          matplotlib.animation.FuncAnimation.
        """
        fig, ax = plt.subplots()

        obstacle_patches = PatchCollection([], cmap='rainbow')
        ax.add_collection(obstacle_patches)

        sensor_patch = PatchCollection([], alpha=0.4)
        ax.add_collection(sensor_patch)

        racer_patch = PatchCollection([], alpha=0.8)
        ax.add_collection(racer_patch)

        ax.set_xlim((self.center - self.width / 2, self.center + self.width / 2))
        ax.set_ylim((0, self.sensor_range * 2))
        ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

        def animate(n):
            state = states[n]

            racer_pos = np.asarray(state.arr).reshape(-1)[:2]
            observed = np.asarray(
                self.get_observed(state.obstacles, state.arr)).reshape(-1)

            obstacles = np.asarray(state.obstacles)
            # (N, 2) positions in the racer-relative frame; .T without squeeze
            # stays correct for a single obstacle.
            obstacle_pos = obstacles[:2].T - np.array([0., racer_pos[1]])
            obstacle_r = obstacles[4]
            obstacle_c = np.where(observed, 0.25, 1.)

            if pedestrian:
                patches = []
                for pos, r, is_left in zip(obstacle_pos, obstacle_r,
                                           np.asarray(state.left)):
                    left = (pos[0] - r / 2, pos[1])
                    right = (pos[0] + r / 2, pos[1])
                    patches.append(
                        FancyArrowPatch(
                            posA=(right if is_left else left),
                            posB=(left if is_left else right),
                            arrowstyle='fancy',
                            mutation_scale=r * obstacle_scale))
            else:
                patches = [
                    Circle(pos, radius=r)
                    for pos, r in zip(obstacle_pos, obstacle_r)
                ]

            # Obstacles, sensor and racer are updated in both modes; earlier the
            # pedestrian branch never drew the racer or the sensor disc.
            obstacle_patches.set_paths(patches)
            obstacle_patches.set_array(obstacle_c)
            sensor_patch.set_paths(
                [Circle((racer_pos[0], 0.), radius=self.sensor_range)])
            racer_patch.set_paths(
                [Circle(np.array((racer_pos[0], 0.)), radius=1.0)])

            return (obstacle_patches, sensor_patch, racer_patch)

        ax.set_aspect('equal')
        fig.set_size_inches(8, 8)
        return animation.FuncAnimation(
            fig, animate, frames=len(states), interval=100)


class PedestrianRacer(Racer):
    """Racer whose obstacles walk horizontally and bounce off the track edges.

    Each obstacle's x velocity is drawn uniformly from
    [min_velocity, max_velocity). When an obstacle leaves
    [center - width / 2, center + width / 2] while still heading outward, the
    sign of its x velocity flips, so the magnitude never changes.
    `state.left` marks obstacles currently moving in -x.
    """
    min_velocity: float = -1.
    max_velocity: float = 1.

    def init(self, key=jax.random.PRNGKey(0), arr=None, obstacles=None):
        """Build the initial state and assign random x velocities to obstacles.

        Args:
          key: PRNG key. It is passed to `Racer.init` for obstacle placement,
            and a key derived from it is used for the velocities.
          arr: optional (4, 1) initial racer state.
          obstacles: optional (num_objects, 5) obstacle rows. Their vel_x is
            overwritten by the random draw.

        Returns:
          RacerState with `left` set and the 10-obstacle view reflecting the
          new velocities.
        """
        state = super().init(key, arr, obstacles)

        # JAX keys must not be reused. `Racer.init` already drew obstacle
        # positions from `key`, and drawing velocities from the same key can
        # correlate each velocity with a position coordinate. Derive a fresh
        # key instead; positions stay identical to `Racer` for the same `key`.
        vel_key = jax.random.fold_in(key, 1)
        x_vel = jax.random.uniform(
            key=vel_key,
            shape=(state.obstacles.shape[1],),
            minval=self.min_velocity,
            maxval=self.max_velocity)
        obstacles = state.obstacles.at[2].set(x_vel)

        # The observation mask depends only on positions and radii, so it is
        # still valid. The gathered 10-obstacle view must be rebuilt so it
        # carries the new velocities instead of zeros.
        return state.replace(
            obstacles=obstacles,
            left=x_vel < 0,
            ten_obstacles_observed=self.get_10_obstacles(
                obstacles, state.obstacles_observed))

    def __call__(self, state, action, w=jnp.zeros((4, 1))):
        """Step the base dynamics, then bounce obstacles off the track edges.

        Returns:
          (new_state, new_state).
        """
        state, _ = super().__call__(state, action, w)

        lower = self.center - self.width / 2
        upper = self.center + self.width / 2

        x_pos = state.obstacles[0]
        x_vel = state.obstacles[2]

        # Flip only obstacles that are outside the track AND still heading
        # outward. An obstacle already moving back inside keeps its velocity, so
        # it cannot get stuck oscillating outside the bounds.
        heading_out = jnp.logical_or(
            jnp.logical_and(x_pos > upper, x_vel > 0),
            jnp.logical_and(x_pos < lower, x_vel < 0))
        new_x_vel = jnp.where(heading_out, -x_vel, x_vel)
        obstacles = state.obstacles.at[2].set(new_x_vel)

        new_state = state.replace(
            obstacles=obstacles,
            left=new_x_vel < 0,
            # Re-gather so the 10-obstacle view carries post-bounce velocities.
            ten_obstacles_observed=self.get_10_obstacles(
                obstacles, state.obstacles_observed))

        return new_state, new_state