import time

import numpy as np

from traffic_junction_env import TrafficJunctionEnv
import argparse
import sys
import signal

class RandomAgent(object):
    """Agent that chooses actions by sampling from a given action space.

    It takes no observation into account: ``act`` has no parameters and
    simply forwards to ``action_space.sample()``.
    """

    def __init__(self, action_space):
        """Store the space that later calls to ``act`` will sample from.

        Args:
            action_space: any object exposing a ``sample()`` method.
        """
        self.action_space = action_space

    def act(self):
        """Return whatever a single ``self.action_space.sample()`` call yields."""
        sampled_action = self.action_space.sample()
        return sampled_action


if __name__ == '__main__':
    # Demo driver: run a RandomAgent in TrafficJunctionEnv for a fixed number
    # of episodes and print the mean of the per-episode 'success' statistic.
    parser = argparse.ArgumentParser('Example GCCNet environment random agent')
    parser.add_argument('--nagents', type=int, default=4, help="Number of agents")
    parser.add_argument('--display', action="store_true", default=False,
                        help="Use to display environment")

    max_episodes = 100
    sleep_time = 0.01  # seconds to pause after each rendered step

    env = TrafficJunctionEnv()
    # NOTE(review): curses is initialised whether or not --display is given,
    # while the SIGINT handler only calls exit_render() when it is set --
    # confirm this is intended.
    env.init_curses()
    env.init_args(parser)

    args = parser.parse_args()
    # NOTE(review): the parsed values are not handed to the environment
    # (the env.init_arg_values(args) call is disabled); the fixed settings
    # below are used instead. As a result args.nagents is never read, so
    # --nagents currently has no effect on the run.
    env.init_ic3net_default('easy', False)
    env.nagents = 3
    env.add_rate_max = env.add_rate_min = 1
    env.multi_agent_init()

    def signal_handler(signum, frame):
        """Handle SIGINT: leave render mode if it was requested, then exit 0.

        The first parameter is named ``signum`` so that it does not shadow
        the imported ``signal`` module inside the handler.
        """
        print('You pressed Ctrl+C! Exiting gracefully.')
        if args.display:
            env.exit_render()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    agent = RandomAgent(env.action_space)

    # One slot per episode; every slot is written before the mean is taken.
    successes = np.empty(max_episodes)

    for episode in range(max_episodes):
        env.reset()
        done = False
        step = 0
        # An episode ends when the environment reports done or after
        # env.max_steps steps, whichever comes first.
        while not done and step < env.max_steps:
            # One independently sampled action per agent.
            actions = [agent.act() for _ in range(env.nagents)]

            # Only the done flag is used; the other step outputs are ignored.
            _, _, done, _ = env.step(actions)

            if args.display:
                env.render()
                time.sleep(sleep_time)

            step += 1

        successes[episode] = env.stat['success']

    env.close()

    print(successes.mean())
