from ray.tune.registry import register_env
from ray.rllib.env import ParallelPettingZooEnv
from ray.rllib.algorithms import Algorithm

import matplotlib.pyplot as plt

import argparse
from time import sleep
# Command-line interface. The parsed namespace is bound to the module-level
# name `args`, which the rest of the script (including mapping_fn) reads.
parser = argparse.ArgumentParser()

parser.add_argument("agents", type=str, help="Path to the saved policy")
parser.add_argument("--save-img", action="store_true", help="save screenshots of every timestep")
parser.add_argument("--render", action="store_true", help="Render the environment")
parser.add_argument("--iter", default=1, type=int, help="Amount of episodes")
parser.add_argument("--adversaries", type=str, nargs="+", default=[], help="selection of adversarial agents")
parser.add_argument("--target", type=str, help="target of the adversarial agents")
# The evaluation loop branches on args.environment ("multiwalker" selects the
# shared-policy path); without this option that lookup raised AttributeError.
parser.add_argument(
    "--environment",
    type=str,
    default="ant",
    help="Environment name; 'multiwalker' uses the shared policy path",
)

args = parser.parse_args()

# mapping_fn derives adversarial policy IDs from the target's last character,
# so adversaries are meaningless without a target; fail early with a clear message.
if args.adversaries and not args.target:
    parser.error("--target is required when --adversaries is given")

from gymnasium_robotics import mamujoco_v0

# ray imports
from ray.rllib.algorithms.ppo import PPO, PPOConfig

def _make_env():
    """Build the 4x2 multi-agent MuJoCo Ant parallel environment.

    Control and contact costs are zeroed and the healthy reward is 0.1;
    frames are rendered as RGB arrays so they can be drawn with matplotlib.
    """
    return mamujoco_v0.parallel_env(
        "Ant",
        "4x2",
        agent_obsk=3,
        render_mode="rgb_array",
        ctrl_cost_weight=0,
        contact_cost_weight=0,
        healthy_reward=0.1,
    )


# Instance stepped and rendered directly by the evaluation loop.
env = _make_env()
# RLlib invokes the creator whenever it needs an environment; give each call
# its own instance instead of sharing (or pickling) the evaluation env above.
register_env("ant", lambda _config: ParallelPettingZooEnv(_make_env()))

def mapping_fn(agent_id, episode=None, worker=None, **kwargs):
    """Return the policy ID that should act for ``agent_id``.

    Agents listed in ``--adversaries`` are routed to an adversarial policy named
    ``adv_<last char of agent_id>_<last char of --target>`` (e.g. ``agent_0``
    targeting ``agent_1`` -> ``"adv_0_1"``). Every other agent uses the policy
    named after itself.

    Args:
        agent_id: Agent identifier such as ``"agent_0"``.
        episode: Unused; part of RLlib's policy-mapping signature.
        worker: Unused; part of RLlib's policy-mapping signature.
        **kwargs: Ignored; accepted so newer RLlib callers passing extra
            keyword arguments keep working.

    Returns:
        The policy ID string.

    Raises:
        ValueError: If ``agent_id`` is an adversary but no ``--target`` was given.
    """
    if agent_id not in args.adversaries:
        return agent_id
    if not args.target:
        # Previously surfaced as an opaque TypeError/IndexError on args.target[-1].
        raise ValueError(
            f"agent {agent_id!r} is adversarial but no --target was specified"
        )
    return f"adv_{agent_id[-1]}_{args.target[-1]}"


# Policy IDs registered with the algorithm: one policy per Ant agent, an
# "_adv" counterpart for each, and "adv_0_1" (the ID mapping_fn yields for
# agent_0 targeting agent_1). More adversarial policies may be added later.
_policy_ids = {
    "agent_0", "agent_1", "agent_2", "agent_3",
    "agent_0_adv", "agent_1_adv", "agent_2_adv", "agent_3_adv",
    "adv_0_1",
}

# PPO hyperparameters; presumably these must match the run that produced the
# checkpoint restored below — TODO confirm.
_training_hparams = {
    "gamma": 0.99,
    "lr": 0.0003,
    "train_batch_size": 65536,
    "kl_coeff": 1.0,
    "sgd_minibatch_size": 4096,
    "vf_loss_coeff": 0.5,
    "clip_param": 0.2,
    "grad_clip": 0.5,
    "model": {"fcnet_hiddens": [400, 300]},
}

# NOTE: an AdversarySelectionCallbackPPO callback was previously wired in via
# .callbacks(...) and is currently disabled.
_ppo_config = (
    PPOConfig()
    .framework("tf")
    .environment("ant")
    .resources(num_gpus=0)
    .multi_agent(policies=_policy_ids, policy_mapping_fn=mapping_fn)
    .rollouts(num_rollout_workers=12)
    .training(**_training_hparams)
)
algo = _ppo_config.build()

import os  # only needed to create the screenshot directory

print(args.adversaries)
algo.restore(args.agents)
totalRew = 0

# The parser may not define --environment; fall back to the Ant path instead
# of raising AttributeError.
environment = getattr(args, "environment", "ant")
is_multiwalker = environment == "multiwalker"
# Non-multiwalker runs draw frames through a matplotlib figure. It is needed for
# live rendering AND for screenshots (otherwise savefig writes an empty figure).
use_figure = not is_multiwalker and (args.render or args.save_img)

if args.save_img:
    # savefig does not create missing directories.
    os.makedirs("screenshots", exist_ok=True)

for i in range(args.iter):
    print(i)
    # The PettingZoo parallel API returns (observations, infos) from reset(),
    # matching the 5-tuple step() below; tolerate envs returning only obs.
    reset_out = env.reset()
    obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out
    # Seed the loop condition; replaced by the env's own dicts after the first step.
    done, truncated = {"agent_0": False}, {"agent_0": False}

    if use_figure:
        fig, ax = plt.subplots(1, 1)
        img = ax.imshow(env.render())
    steps = 0
    while not (any(done.values()) or any(truncated.values())):
        steps += 1
        actions = {}
        for agent in obs:
            if is_multiwalker:
                policy = algo.get_policy("shared")
                action, _, _ = policy.compute_single_action(obs=obs[agent])
            else:
                policy = algo.get_policy(mapping_fn(agent, None, None))
                action, _, _ = policy.compute_single_action(obs=obs[agent], explore=False)
            actions[agent] = action
        obs, rew, done, truncated, _ = env.step(actions)

        # Accumulate this step's mean per-agent reward (previously hard-coded / 4).
        if rew:
            totalRew += sum(rew.values()) / len(rew)

        if is_multiwalker:
            if args.save_img:
                frame = env.render()
                plt.imshow(frame)
                plt.savefig(f"screenshots/{steps}.png")
            if args.render:
                env.render()
                sleep(0.01)
        else:
            if use_figure:
                img.set_data(env.render())
            if args.render:
                fig.canvas.draw_idle()
                plt.pause(0.01)
            if args.save_img:
                fig.savefig(f"screenshots/{steps}.png")
                # Cap screenshots per episode: the episode ends after step 251.
                if steps > 250:
                    break
    if use_figure:
        # Release the per-episode figure so figures don't pile up across episodes.
        plt.close(fig)
    # print rolling average
    print(totalRew / (i + 1))


