import flow
import flow.envs
from flow.controllers import (
    RLController,
    SimCarFollowingController,
    SimLaneChangeController,
)
from flow.controllers.routing_controllers import ContinuousRouter
from flow.core.params import (
    EnvParams,
    InFlows,
    InitialConfig,
    NetParams,
    SumoCarFollowingParams,
    SumoLaneChangeParams,
    SumoParams,
    TrafficLightParams,
    VehicleParams,
)
from flow.envs import BottleneckDesiredVelocityEnv
from flow.networks import BottleneckNetwork
from flow.networks.ring import ADDITIONAL_NET_PARAMS


def bottleneck(render="drgb"):
    """Build the flow parameters for the bottleneck RL experiment.

    Human and RL-controlled vehicles enter on edge "1" at a combined rate of
    ``2500 * SCALING`` veh/hr, a fraction ``AV_FRAC`` of which are RL
    vehicles.

    The experiment runs ``BottleneckDesiredVelocityEnv`` on a
    ``BottleneckNetwork``, simulated through ``"traci"``.

    Parameters
    ----------
    render : optional
        Forwarded unchanged to ``SumoParams(render=...)``. Defaults to
        ``"drgb"``.

    Returns
    -------
    dict
        Experiment description with the keys ``exp_tag``, ``env_name``,
        ``network``, ``simulator``, ``sim``, ``env``, ``net``, ``veh``,
        ``initial`` and ``tls``.
    """
    # time horizon of a single rollout
    HORIZON = 1500

    SCALING = 1
    # Toggles for the traffic lights on nodes "2" and "3". They feed both the
    # env's ``disable_tb`` / ``disable_ramp_metering`` options and the
    # TrafficLightParams below, so the two can never disagree.
    DISABLE_TB = True
    DISABLE_RAMP_METER = True
    # fraction of the total inflow made up of RL-controlled vehicles
    AV_FRAC = 0.10

    vehicles = VehicleParams()
    vehicles.add(
        veh_id="human",
        routing_controller=(ContinuousRouter, {}),
        car_following_params=SumoCarFollowingParams(
            speed_mode=9,
        ),
        lane_change_params=SumoLaneChangeParams(
            lane_change_mode=0,
        ),
        num_vehicles=1 * SCALING,
    )
    vehicles.add(
        veh_id="rl",
        acceleration_controller=(RLController, {}),
        routing_controller=(ContinuousRouter, {}),
        car_following_params=SumoCarFollowingParams(
            speed_mode=9,
        ),
        lane_change_params=SumoLaneChangeParams(
            lane_change_mode=0,
        ),
        num_vehicles=1 * SCALING,
    )

    controlled_segments = [
        ("1", 1, False),
        ("2", 2, True),
        ("3", 2, True),
        ("4", 2, True),
        ("5", 1, False),
    ]
    num_observed_segments = [("1", 1), ("2", 3), ("3", 3), ("4", 3), ("5", 1)]

    additional_env_params = {
        "target_velocity": 40,
        # use the same toggles that control the traffic lights below
        "disable_tb": DISABLE_TB,
        "disable_ramp_metering": DISABLE_RAMP_METER,
        "controlled_segments": controlled_segments,
        "symmetric": False,
        "observed_segments": num_observed_segments,
        "reset_inflow": False,
        "lane_change_duration": 5,
        "max_accel": 3,
        "max_decel": 3,
        "inflow_range": [1200, 2500],
    }

    # total flow rate entering edge "1" (veh/hr)
    flow_rate = 2500 * SCALING

    # split the total inflow between human and RL vehicles by AV_FRAC
    inflow = InFlows()
    inflow.add(
        veh_type="human",
        edge="1",
        vehs_per_hour=flow_rate * (1 - AV_FRAC),
        depart_lane="random",
        depart_speed=10,
    )
    inflow.add(
        veh_type="rl",
        edge="1",
        vehs_per_hour=flow_rate * AV_FRAC,
        depart_lane="random",
        depart_speed=10,
    )

    traffic_lights = TrafficLightParams()
    if not DISABLE_TB:
        traffic_lights.add(node_id="2")
    if not DISABLE_RAMP_METER:
        traffic_lights.add(node_id="3")

    additional_net_params = {"scaling": SCALING, "speed_limit": 23}
    # network-related parameters (see flow.core.params.NetParams and the
    # network's documentation or ADDITIONAL_NET_PARAMS component)
    net_params = NetParams(inflows=inflow, additional_params=additional_net_params)

    flow_params = dict(
        # name of the experiment
        exp_tag="bottleneck_0",
        # name of the flow environment the experiment is running on
        env_name=BottleneckDesiredVelocityEnv,
        # name of the network class the experiment is running on
        network=BottleneckNetwork,
        # simulator that is used by the experiment
        simulator="traci",
        # sumo-related parameters (see flow.core.params.SumoParams)
        sim=SumoParams(
            sim_step=0.5,
            render=render,
            save_render=True,
            print_warnings=False,
            restart_instance=True,
        ),
        # environment related parameters (see flow.core.params.EnvParams)
        env=EnvParams(
            warmup_steps=40,
            sims_per_step=1,
            horizon=HORIZON,
            additional_params=additional_env_params,
        ),
        # reuse the NetParams built above instead of constructing a duplicate
        net=net_params,
        # vehicles to be placed in the network at the start of a rollout (see
        # flow.core.params.VehicleParams)
        veh=vehicles,
        # parameters specifying the positioning of vehicles upon initialization/
        # reset (see flow.core.params.InitialConfig)
        initial=InitialConfig(
            spacing="uniform",
            min_gap=5,
            lanes_distribution=float("inf"),
            edges_distribution=["2", "3", "4", "5"],
        ),
        # traffic lights to be introduced to specific nodes (see
        # flow.core.params.TrafficLightParams)
        tls=traffic_lights,
    )
    return flow_params
