"""Produce environments with different obstacle configurations."""

import jax
import jax.numpy as jnp
import scripts.racer as racer


def get_centerline(num_obstacles=200,
                   init_clearance=12.5,
                   height=9862.5,
                   obstacle_radius=3.,
                   *,
                   side_guard=True,
                   **kwargs):
    """Create an environment with a centerline of obstacles.

    Obstacles are placed at ``x = 0`` with ``y`` evenly spaced between
    ``init_clearance`` and ``height``. Each obstacle is one row of the form
    ``[pos_x, pos_y, vel_0, vel_1, radius]``.

    Args:
        num_obstacles: Total number of obstacles, including the two side
            guards when ``side_guard`` is true.
        init_clearance: ``y`` coordinate of the first centerline obstacle;
            also forwarded to ``racer.Racer.create``.
        height: ``y`` coordinate of the last centerline obstacle; also
            forwarded to ``racer.Racer.create``.
        obstacle_radius: Radius assigned to every obstacle.
        side_guard: When true (the default, matching the previous hard-coded
            behaviour), the last two obstacles are not on the centerline but
            start at ``(-55, 0)`` and ``(55, 0)`` with velocity ``[0, 1]``.
        **kwargs: Extra keyword arguments forwarded to ``racer.Racer.create``.

    Returns:
        A tuple ``(env, init_state)`` where ``env`` is the object returned by
        ``racer.Racer.create`` and ``init_state`` is ``env.init(obstacles=...)``.

    Raises:
        ValueError: If ``side_guard`` is true and ``num_obstacles < 2``.
    """
    # NOTE(review): 55 is presumably just outside the track's half-width and
    # the unit velocity presumably tracks the racer along y -- confirm
    # against the Racer defaults.
    guard_offset_x = 55.
    num_guards = 2

    if side_guard:
        if num_obstacles < num_guards:
            raise ValueError(
                "num_obstacles must be at least 2 when side_guard is enabled, "
                f"got {num_obstacles}")
        num_centerline = num_obstacles - num_guards

        guard_x = jnp.array([[-guard_offset_x], [guard_offset_x]])
        pos_x = jnp.concatenate(
            (jnp.zeros((num_centerline, 1)), guard_x), axis=0)

        centerline_y = jnp.linspace(
            init_clearance, height, num_centerline)[:, None]
        pos_y = jnp.concatenate(
            (centerline_y, jnp.zeros((num_guards, 1))), axis=0)

        # Centerline obstacles are static; the guards get velocity [0, 1].
        guard_vel = jnp.concatenate(
            (jnp.zeros((num_guards, 1)), jnp.ones((num_guards, 1))), axis=1)
        vel = jnp.concatenate(
            (jnp.zeros((num_centerline, 2)), guard_vel), axis=0)
    else:
        pos_x = jnp.zeros((num_obstacles, 1))
        pos_y = jnp.linspace(init_clearance, height, num_obstacles)[:, None]
        vel = jnp.zeros((num_obstacles, 2))

    r = jnp.ones((num_obstacles, 1)) * obstacle_radius
    obstacles = jnp.hstack((pos_x, pos_y, vel, r))

    env = racer.Racer.create(
        num_obstacles=num_obstacles,
        init_clearance=init_clearance,
        height=height,
        **kwargs)
    init_state = env.init(obstacles=obstacles)
    return env, init_state


def get_slalom(
        num_gates=10,
        min_gate_width=6,
        max_gate_width=12,
        center=0,
        width=100,
        height=1000,
        init_clearance=50,
        obstacle_radius=4,
        d_min=7,
        **kwargs,
        ):
    """Create an environment with slalom of gates.

    Gates are placed at ``num_gates`` evenly spaced ``y`` values between
    ``init_clearance`` and ``height``. Each gate is a horizontal row of
    static obstacles spaced ``2 * obstacle_radius`` apart, with an opening
    whose centre and width are drawn uniformly at random. The draws use
    fixed PRNG keys folded with ``int(y)``, so the same arguments always
    produce the same layout.

    Args:
        num_gates: Number of gates (rows of obstacles).
        min_gate_width: Lower bound of the sampled gate opening width.
        max_gate_width: Upper bound of the sampled gate opening width.
        center: ``x`` coordinate of the track centre.
        width: Track width; obstacle rows extend out to about
            ``center +/- width / 2``.
        height: ``y`` coordinate of the last gate; forwarded to
            ``racer.Racer.create``.
        init_clearance: ``y`` coordinate of the first gate; forwarded to
            ``racer.Racer.create``.
        obstacle_radius: Radius of every obstacle. May be an int or a float.
        d_min: Gate centres are sampled within
            ``center +/- (d_min + min_gate_width / 2)``.
        **kwargs: Extra keyword arguments forwarded to ``racer.Racer.create``.

    Returns:
        A tuple ``(env, init_state)`` where ``env`` is the object returned by
        ``racer.Racer.create`` and ``init_state`` is ``env.init(obstacles=...)``.

    Raises:
        ValueError: If ``num_gates < 1`` or ``obstacle_radius <= 0``.
    """
    if num_gates < 1:
        raise ValueError(f"num_gates must be at least 1, got {num_gates}")
    if obstacle_radius <= 0:
        raise ValueError(
            f"obstacle_radius must be positive, got {obstacle_radius}")

    min_x = center - width / 2
    max_x = center + width / 2
    min_gate_center = center - d_min - min_gate_width / 2
    max_gate_center = center + d_min + min_gate_width / 2
    spacing = obstacle_radius * 2

    # Base keys are loop-invariant; each gate derives its own key from them.
    center_key = jax.random.PRNGKey(0)
    width_key = jax.random.PRNGKey(1)

    pos_y = jnp.linspace(init_clearance, height, num_gates)

    gates = []
    for y in pos_y:
        seed = int(y)
        x = jax.random.uniform(
            jax.random.fold_in(center_key, seed),
            shape=(),
            minval=min_gate_center,
            maxval=max_gate_center)
        gate_width = jax.random.uniform(
            jax.random.fold_in(width_key, seed),
            shape=(),
            minval=min_gate_width,
            maxval=max_gate_width)

        left_gate = x - gate_width / 2
        right_gate = x + gate_width / 2

        # Start and stop are truncated to integers as before. jnp.arange
        # (unlike range) also accepts a float step, so a float
        # obstacle_radius no longer raises TypeError.
        left_xs = jnp.arange(
            int(left_gate - obstacle_radius), int(min_x - spacing), -spacing)
        right_xs = jnp.arange(
            int(right_gate + obstacle_radius), int(max_x + spacing), spacing)

        gate_xs = jnp.concatenate((left_xs, right_xs))[:, None]
        gate_ys = jnp.ones((gate_xs.shape[0], 1)) * y
        gates.append(jnp.concatenate((gate_xs, gate_ys), axis=1))

    positions = jnp.concatenate(gates, axis=0)
    num_obstacles = len(positions)
    # Rows are [pos_x, pos_y, vel_0, vel_1, radius]; slalom obstacles are static.
    obstacles = jnp.concatenate(
        (positions,
         jnp.zeros((num_obstacles, 2)),
         jnp.ones((num_obstacles, 1)) * obstacle_radius),
        axis=1)

    env = racer.Racer.create(
        num_obstacles=num_obstacles,
        init_clearance=init_clearance,
        center=center,
        width=width,
        height=height,
        **kwargs)
    init_state = env.init(obstacles=obstacles)

    return env, init_state


def get_pedestrian(min_velocity=-10., max_velocity=10., **kwargs):
    """Build a pedestrian-obstacle environment and its initial state.

    Args:
        min_velocity: Forwarded to ``racer.PedestrianRacer.create`` as
            ``min_velocity``.
        max_velocity: Forwarded to ``racer.PedestrianRacer.create`` as
            ``max_velocity``.
        **kwargs: Any further keyword arguments for
            ``racer.PedestrianRacer.create``.

    Returns:
        A tuple ``(env, init_state)``, where ``init_state`` comes from
        calling ``env.init()`` with no arguments.
    """
    velocity_bounds = {
        "min_velocity": min_velocity,
        "max_velocity": max_velocity,
    }
    env = racer.PedestrianRacer.create(**velocity_bounds, **kwargs)
    return env, env.init()