import gym
import rware
from utils import Warehouse
from ray.tune.registry import register_env
from argparse import ArgumentParser 
from time import sleep
from matplotlib.pyplot import imsave

# Command-line interface of the evaluation script.
parser = ArgumentParser()
parser.add_argument("agents", type=str, help="Path to the saved policy")
parser.add_argument("--iter", type=int, default=1, help="Amount of episodes")
parser.add_argument("--render", action="store_true", help="Render the environment")
parser.add_argument("--save-img", action="store_true", help="save screenshots of every timestep")
parser.add_argument("--adversaries", type=str, nargs="+", default=[], help="selection of adversarial agents")
parser.add_argument("--target", type=str, help="target of the adversarial agents")

# Parsed options; the rest of the script reads them through this name.
args = parser.parse_args()

# The adversarial policy ids are built from the last character of --target,
# so a missing/empty target would otherwise surface later as an opaque
# TypeError/IndexError. Fail early with a usage message instead.
if args.adversaries and not args.target:
    parser.error("--target is required when --adversaries is given")

from envs.Spread.Spread import Spread_Indp
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env import ParallelPettingZooEnv

def policy_mapping_fn(agent_id, episode, worker):
    """Return the id of the policy that should act for ``agent_id``.

    Agents listed in ``args.adversaries`` are mapped to ``adv_<a>_<t>``,
    where ``<a>`` is the last character of the agent id and ``<t>`` is the
    last character of ``args.target``. Every other agent is mapped to the
    policy carrying its own id. ``episode`` and ``worker`` are accepted but
    not used.
    """
    if agent_id not in args.adversaries:
        return agent_id
    return f"adv_{agent_id[-1]}_{args.target[-1]}"

# One environment instance is created up front; the same object is handed to
# RLlib (wrapped) and stepped directly by the evaluation loop further down.
env = Spread_Indp(render=args.render)
env._agent_ids = ["agent_%d" % idx for idx in range(3)]


def _make_wrapped_env(_config):
    """Wrap the shared ``env`` for RLlib; the passed env config is ignored."""
    return ParallelPettingZooEnv(env)


register_env("spread_indp", _make_wrapped_env)

# Policy ids known to the algorithm: one per regular agent, plus one
# "adv_<agent>_<target>" policy for every agent passed via --adversaries.
_agent_policy_ids = [f"agent_{i}" for i in range(3)]
_adversary_policy_ids = [f"adv_{adv[-1]}_{args.target[-1]}" for adv in args.adversaries]


def _identity_policy_mapping(agent_id, episode, worker):
    """Map every agent to the policy that carries the agent's own id."""
    return agent_id


# Same settings as a single fluent chain, applied one step at a time.
_config = PPOConfig()
_config = _config.framework("tf")
_config = _config.environment("spread_indp")
_config = _config.multi_agent(
    policies=_agent_policy_ids + _adversary_policy_ids,
    policy_mapping_fn=_identity_policy_mapping,
)
_config = _config.resources(num_gpus=1)
_config = _config.rollouts(num_rollout_workers=1, num_envs_per_worker=1)
algo = _config.build()

# Load the trained state from the path given as the positional CLI argument.
algo.restore(args.agents)

import os


def _episode_over(terminated, truncated):
    """Return True once the episode ended by termination or by truncation.

    ``terminated`` is the per-agent termination dict returned by ``env.step``
    (the episode is over when all of its values are true). ``truncated`` is
    the matching truncation result; it may be a dict or a plain flag.
    """
    if all(terminated.values()):
        return True
    if isinstance(truncated, dict):
        # An empty dict carries no signal; do not read it as "all truncated".
        return bool(truncated) and all(truncated.values())
    return bool(truncated)


# Per-agent sum of episode returns; turned into an average after the loop.
totalRew = {agent: 0 for agent in env._agent_ids}

if args.save_img:
    # imsave does not create missing directories.
    os.makedirs("screenshots", exist_ok=True)

for i in range(args.iter):
    steps = 0
    # Reward accumulated by each agent during this episode.
    episode_rew = {agent: 0 for agent in env._agent_ids}
    # Each episode is seeded with its own index.
    obs, info = env.reset(seed=i)
    done, trunc = {"__all__": False}, {}
    while not _episode_over(done, trunc):
        actions = {}
        for agent in env._agent_ids:
            pol = policy_mapping_fn(agent, None, None)
            # compute_single_action returns a 3-tuple; only the action is used.
            action, _, _ = algo.get_policy(pol).compute_single_action(obs=obs[agent])
            actions[agent] = action
        obs, rew, done, trunc, info = env.step(actions)
        for agent, r in rew.items():
            episode_rew[agent] += r
        if args.render:
            # Slow the loop down so the rendering can be followed.
            sleep(0.05)
        if args.save_img:
            # NOTE: file names depend only on the step counter, so with
            # --iter > 1 later episodes overwrite earlier screenshots.
            img = env.render()
            imsave(f"screenshots/{steps}.png", img)
        steps += 1
    for agent in totalRew:
        totalRew[agent] += episode_rew[agent]
    print(episode_rew)

# Guard against --iter 0 (or negative), which would divide by zero.
if args.iter > 0:
    for agent in totalRew:
        totalRew[agent] /= args.iter
print("average")
print(totalRew)
