import argparse
import ray
from ray.tune.registry import register_env
from ray.rllib.env import ParallelPettingZooEnv
import supersuit as ss
from time import sleep
from random import randint
# Command-line interface of the evaluation script: which checkpoint to load,
# what to do with the rollouts, and which agents are played by adversaries.
parser = argparse.ArgumentParser()

parser.add_argument("agents", type=str, help="Path to the saved policy")
parser.add_argument(
    "mode",
    type=str,
    choices=["render", "save-img", "eval"],
    help="whether to render the environments, save screenshots or only evaluate",
)
# nargs="+" produces a list when the flag is given, so the default is an empty
# list too; code reading args.adversaries then always sees the same type.
parser.add_argument(
    "--adversaries",
    type=str,
    nargs="+",
    default=[],
    help="agents that will be controlled by the adversary",
)
parser.add_argument("--iter", type=int, default=1, help="Amount of episodes")
parser.add_argument(
    "--use-seed",
    action="store_true",
    help="whether to use a seed for more consistent results",
)

# The rest of the script reads args.agents, args.mode, args.adversaries,
# args.iter and args.use_seed from this namespace.
args = parser.parse_args()

# Start Ray with log forwarding to the driver switched off.
ray.init(log_to_driver=False)


# Environment factory and PPO configuration class used further down.
from pettingzoo.sisl import pursuit_v4
from ray.rllib.algorithms.ppo import PPOConfig

# Only the "render" mode uses the "human" render mode; "save-img" and "eval"
# request "rgb_array" so that env.render() can hand back image data.
render_mode = "human" if args.mode == "render" else "rgb_array"

# Pursuit with per-agent (non-shared) rewards, wrapped with SuperSuit's
# frame_stack_v1 (stack size 8) and then flatten_v0.
env = ss.flatten_v0(
    ss.frame_stack_v1(
        pursuit_v4.parallel_env(shared_reward=False, render_mode=render_mode),
        8,
    )
)


def _wrap_env_for_rllib(_):
    """Env creator for RLlib: wrap the module-level ``env`` (config is ignored)."""
    return ParallelPettingZooEnv(env)


# Make the environment available to RLlib under the name "pursuit".
register_env("pursuit", _wrap_env_for_rllib)

def policy_mapping_fn(agent_id, episode, worker):
    """Return the id of the policy that acts for ``agent_id``.

    Agents listed in ``args.adversaries`` are mapped to ``adv_<index>_e``,
    where ``<index>`` is the run of digits at the end of the agent id
    (``"pursuer_3"`` -> ``"adv_3_e"``). All other agents use the policy that
    has the same name as the agent.

    Args:
        agent_id: Agent identifier string, e.g. ``"pursuer_3"``.
        episode: Not used by this mapping.
        worker: Not used by this mapping.

    Returns:
        The policy id as a string.
    """
    if agent_id not in args.adversaries:
        return agent_id
    # Take the whole numeric suffix rather than just the last character, so
    # that e.g. "pursuer_12" and "pursuer_2" do not collapse onto "adv_2_e".
    index = agent_id[len(agent_id.rstrip("0123456789")):]
    if not index:
        # No trailing digits: fall back to the last character.
        index = agent_id[-1]
    return f"adv_{index}_e"

# Assemble the PPO configuration one call at a time, then build the algorithm.
# One policy id per pursuer: "pursuer_0" ... "pursuer_7".
pursuer_policy_ids = {f"pursuer_{i}" for i in range(8)}

ppo_config = PPOConfig()
ppo_config = ppo_config.framework("tf")
ppo_config = ppo_config.environment("pursuit")
ppo_config = ppo_config.resources(num_gpus=0)
ppo_config = ppo_config.multi_agent(
    policies=pursuer_policy_ids,
    # Inside the algorithm every agent is mapped to the policy named after it.
    policy_mapping_fn=lambda agent_id, episode, worker: agent_id,
)
ppo_config = ppo_config.rollouts(num_rollout_workers=6)
ppo_config = ppo_config.training(
    train_batch_size=5000,
    sgd_minibatch_size=500,
    lambda_=0.95,
    entropy_coeff=0.01,
)

algo = ppo_config.build()
        
# Evaluation: restore the checkpoint, roll out args.iter episodes with the
# restored policies and report the mean return of the non-adversary agents.
average = 0  # running sum of non-adversary episode returns over all episodes
team_size = 0  # number of non-adversary agents seen in the latest episode
adversaries = set(args.adversaries)
algo.restore(args.agents)

if args.mode == "save-img":
    # Imported here so the other modes do not require matplotlib, and hoisted
    # out of the step loop. The target directory is created up front because
    # plt.imsave does not create missing directories.
    import os
    import matplotlib.pyplot as plt

    os.makedirs("screenshots", exist_ok=True)

try:
    for episode in range(args.iter):
        # Seeding goes through reset(seed=...); a separate env.seed() call is
        # deprecated in PettingZoo versions whose step() returns five values.
        reset_result = env.reset(seed=episode) if args.use_seed else env.reset()
        # Depending on the PettingZoo version reset() returns either the
        # observations alone or an (observations, infos) pair.
        obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

        totalRew = {agent: 0 for agent in env.agents}
        steps = 0
        episode_over = False
        # Stop when no live agents remain or every agent reported done.
        while env.agents and not episode_over:
            steps += 1
            if args.mode == "save-img":
                # NOTE: file names only contain the step counter, so a later
                # episode overwrites the screenshots of an earlier one.
                plt.imsave("screenshots/{:03d}.png".format(steps), env.render())
            elif args.mode == "render":
                env.render()

            actions = {}
            for agent in env.agents:
                pol = policy_mapping_fn(agent, episode, None)
                # compute_single_action yields three values; only the action
                # itself is needed here.
                act, _, _ = algo.get_policy(pol).compute_single_action(
                    obs=obs[agent], explore=True
                )
                actions[agent] = act

            obs, rew, done, trunc, info = env.step(actions)
            for agent, reward in rew.items():
                totalRew[agent] += reward

            # An agent is finished when it is terminated OR truncated;
            # checking terminations alone misses episodes that end by
            # truncation (e.g. a step limit).
            episode_over = all(
                done[agent] or trunc.get(agent, False) for agent in done
            )

        team_size = 0
        for agent, total in totalRew.items():
            if agent not in adversaries:
                average += total
                team_size += 1
        print(episode, totalRew)

    # Mean return per episode and per non-adversary agent. The agent count is
    # taken from the episode itself instead of assuming 8 pursuers, and the
    # degenerate cases no longer end in a ZeroDivisionError.
    if args.iter > 0 and team_size > 0:
        print(average / args.iter / team_size)
    else:
        print("no non-adversary episode rewards to average")
finally:
    # Release the environment (and any render window) even if a rollout fails.
    env.close()
