import supersuit as ss

import ray
from ray.tune.registry import register_env
from ray.tune.logger import Logger, pretty_print
from ray.rllib.env import ParallelPettingZooEnv

import warnings
# Install a blanket "ignore" filter: no category or module is specified, so
# every Python warning raised in this process is suppressed.
warnings.filterwarnings("ignore")

import argparse

# Names accepted for the positional ``environment`` argument. Each one has a
# matching branch further down that registers the env and builds the algorithm.
environments = ["pursuit", "ant", "spread_indp"]

# Command-line interface: one required positional (which environment to train
# on) and three optional flags (iteration count, checkpoint to restore from,
# and the experiment name used in the save path).
parser = argparse.ArgumentParser()
parser.add_argument(
    "environment",
    type=str,
    choices=environments,
    help="The environment the agent will be trained for",
)
parser.add_argument(
    "--iter",
    type=int,
    default=5000,
    help="Number of iteration the agents will be trained for",
)
parser.add_argument(
    "--restore",
    type=str,
    help="Restore pre-trained policy",
)
parser.add_argument(
    "--name",
    type=str,
    default="unnamed_experiment",
    help="Name of the directory the trained policy will be saved in",
)

# Parsed namespace; the rest of the script reads its options from ``args``.
args = parser.parse_args()

# Start Ray before any algorithm is built.
ray.init(log_to_driver=False, local_mode=False)


def mapping_fn(agent_id, episode, worker, **kwargs):
    """Return the policy id for ``agent_id``: every agent trains its own policy.

    The policy ids configured in each branch below are exactly the agent ids
    of the corresponding environment, so the mapping is the identity.

    Args:
        agent_id: Id of the agent that needs a policy.
        episode: Unused; accepted to match RLlib's mapping-function signature.
        worker: Unused; accepted to match RLlib's mapping-function signature.
        **kwargs: Unused; tolerates any extra keyword arguments from the caller.

    Returns:
        ``agent_id`` unchanged.
    """
    return agent_id


# Each branch registers an environment *factory* and then builds the PPO
# algorithm for it. The factory constructs a new environment on every call
# instead of handing out one pre-built instance, so no live env object has to
# be serialised to the rollout workers and no two callers share env state.
if args.environment == "pursuit":
    from pettingzoo.sisl import pursuit_v4
    from ray.rllib.algorithms.ppo import PPOConfig

    def _make_pursuit_env(_env_config):
        """Build a fresh pursuit env wrapped for RLlib.

        Args:
            _env_config: Argument supplied by the env registry; unused.

        Returns:
            A ``ParallelPettingZooEnv`` around the frame-stacked (8 frames),
            flattened ``pursuit_v4`` parallel env with ``shared_reward=False``.
        """
        pursuit = pursuit_v4.parallel_env(shared_reward=False)
        pursuit = ss.frame_stack_v1(pursuit, 8)
        pursuit = ss.flatten_v0(pursuit)
        return ParallelPettingZooEnv(pursuit)

    register_env("pursuit", _make_pursuit_env)

    algo = (PPOConfig()
            .framework("tf")
            .environment("pursuit")
            .resources(num_gpus=0)
            .multi_agent(
                # One independent policy per pursuer; mapping_fn is the identity.
                policies={f"pursuer_{i}" for i in range(8)},
                policy_mapping_fn=mapping_fn,
            )
            .rollouts(num_rollout_workers=10)
            .training(train_batch_size=5000,
                      sgd_minibatch_size=500,
                      lambda_=0.95,
                      entropy_coeff=0.01,
                      )
            .build())

elif args.environment == "ant":
    from gymnasium_robotics import mamujoco_v0
    from ray.rllib.algorithms.ppo import PPO, PPOConfig

    def _make_ant_env(_env_config):
        """Build a fresh multi-agent Ant ("4x2") env wrapped for RLlib.

        Args:
            _env_config: Argument supplied by the env registry; unused.

        Returns:
            A ``ParallelPettingZooEnv`` around the ``mamujoco_v0`` parallel env
            created with ``agent_obsk=2`` and the control-cost weight,
            contact-cost weight and healthy reward all set to 0.
        """
        ant = mamujoco_v0.parallel_env(
            "Ant",
            "4x2",
            agent_obsk=2,
            ctrl_cost_weight=0,
            contact_cost_weight=0,
            healthy_reward=0,
        )
        return ParallelPettingZooEnv(ant)

    register_env("ant", _make_ant_env)

    algo = (PPOConfig()
            .framework("tf")
            .environment("ant")
            .resources(num_gpus=1)
            .multi_agent(
                # One independent policy per agent; mapping_fn is the identity.
                policies={"agent_0", "agent_1", "agent_2", "agent_3"},
                policy_mapping_fn=mapping_fn)
            .rollouts(num_rollout_workers=12)
            .training(
                gamma=0.99,
                lr=0.0003,
                train_batch_size=65536,
                kl_coeff=1.0,
                sgd_minibatch_size=4096,
                vf_loss_coeff=0.5,
                clip_param=0.2,
                grad_clip=0.5,
                model={"fcnet_hiddens": [400, 300], "fcnet_activation": "tanh"})
            .build())

elif args.environment == "spread_indp":
    from envs.Spread.Spread import Spread_Indp
    from ray.rllib.algorithms.ppo import PPOConfig

    def _make_spread_env(_env_config):
        """Build a fresh ``Spread_Indp`` env wrapped for RLlib.

        Args:
            _env_config: Argument supplied by the env registry; unused.

        Returns:
            A ``ParallelPettingZooEnv`` around a new ``Spread_Indp`` instance.
        """
        return ParallelPettingZooEnv(Spread_Indp())

    register_env("spread_indp", _make_spread_env)

    algo = (PPOConfig()
            .framework("tf")
            .environment("spread_indp")
            .multi_agent(
                # One independent policy per agent; mapping_fn is the identity.
                policies={f"agent_{i}" for i in range(3)},
                policy_mapping_fn=mapping_fn,
            )
            .resources(num_gpus=1)
            .rollouts(num_rollout_workers=8, num_envs_per_worker=1)
            .build())

# Optionally resume from a previously saved state when --restore was given.
restore_path = args.restore
if restore_path:
    algo.restore(restore_path)

# Main training loop: one algo.train() call per iteration, with the returned
# result pretty-printed to stdout each time.
for _ in range(args.iter):
    iteration_result = algo.train()
    print(pretty_print(iteration_result))

# Save the trained algorithm under a per-environment, per-experiment path.
save_dir = f"RobustTraining/Agents/{args.environment}/{args.name}"
algo.save(save_dir)
