import ray
from ray.tune.registry import register_env
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print

from torch import nn
import torch

from itertools import combinations
from argparse import ArgumentParser
import random
import math

# Environment names accepted on the command line.
environments = ["pursuit", "spread_indp", "ant"]

# Command-line interface: one positional environment name plus optional
# settings for the adversary subset size, iteration count and output name.
parser = ArgumentParser()
parser.add_argument(
    "environment",
    choices=environments,
    help="The environment the agent will be trained for",
)
parser.add_argument(
    "--max-adversaries",
    type=int,
    default=1,
    help="The Cardinality k of the subsets of adversaries",
)
parser.add_argument(
    "--iter",
    type=int,
    default=100,
    help="amount of iteration for training",
)
parser.add_argument(
    "--name",
    type=str,
    default="unnamed_experiment",
    help="Name of the directory the trained policy will be saved in",
)
# The rest of the script reads the parsed options through ``args``.
args = parser.parse_args()

# Initialise Ray before any algorithm is built.
ray.init(log_to_driver=False, local_mode=False)

if args.environment == "pursuit":
    from pettingzoo.sisl import pursuit_v4
    from ray.rllib.algorithms.ppo import PPOConfig
    import supersuit as ss
    from utils import IndependentAdversarialEnvironment
    from ray.rllib.policy.policy import PolicySpec
    
    # Build the pursuit environment step by step: base parallel env with
    # shared_reward disabled, frame stacking (8), observation flattening, and
    # finally the project's adversarial wrapper.
    base_env = pursuit_v4.parallel_env(shared_reward=False)
    stacked_env = ss.frame_stack_v1(base_env, 8)
    flat_env = ss.flatten_v0(stacked_env)
    env = IndependentAdversarialEnvironment(flat_env)
    register_env("pursuit", lambda _: env)
    agents = env._agent_ids

    # Every pursuer gets a regular policy with an explicit learning rate and
    # an "_adv" policy created with a default PolicySpec.
    policies = {}
    for idx in range(8):
        policies.update({
            f"pursuer_{idx}": PolicySpec(config={"lr": 5e-05}),
            f"pursuer_{idx}_adv": PolicySpec(),
        })

    def policy_mapping_fn(agent_id, episode, worker):
        """Map an agent to its policy id and configure the worker's adversaries.

        The adversary group depends on ``worker.worker_index % 8``: values
        below 3 select from pursuers 0-3, all other values from pursuers 4-7.
        The group is truncated to ``args.max_adversaries`` entries and handed
        to ``worker.env.set_adversaries`` together with "social-welfare" on
        every call.

        Args:
            agent_id: Id of the agent to map; its last character is used to
                build the adversarial policy id.
            episode: Unused.
            worker: Rollout worker providing ``worker_index`` and ``env``.

        Returns:
            ``pursuer_<c>_adv`` (``c`` = last character of ``agent_id``) when
            the agent is one of the selected adversaries, otherwise
            ``agent_id`` unchanged.
        """
        # preselected adversaries
        if worker.worker_index % 8 < 3:
            group = ["pursuer_0", "pursuer_1", "pursuer_2", "pursuer_3"]
        else:
            group = ["pursuer_4", "pursuer_5", "pursuer_6", "pursuer_7"]
        adversaries = group[:args.max_adversaries]

        worker.env.set_adversaries(adversaries, "social-welfare")
        if agent_id not in adversaries:
            return agent_id
        return f"pursuer_{agent_id[-1]}_adv"

    # Keyword groups for the builder calls below, kept separate for readability.
    multi_agent_settings = dict(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
    )
    training_settings = dict(
        train_batch_size=5000,
        sgd_minibatch_size=500,
        lambda_=0.95,
        entropy_coeff=0.01,
    )

    # Assemble the PPO configuration for the registered "pursuit" environment,
    # then build the algorithm from it.
    ppo_config = (
        PPOConfig()
        .framework("tf")
        .environment("pursuit")
        .resources(num_gpus=0)
        .multi_agent(**multi_agent_settings)
        .rollouts(num_rollout_workers=4)
        .training(**training_settings)
    )
    algo = ppo_config.build()

elif args.environment == "spread_indp":
    from envs.Spread.Spread import Spread_Indp
    from ray.rllib.algorithms.ppo import PPOConfig
    from utils import IndependentAdversarialEnvironment
    from ray.rllib.policy.policy import PolicySpec

    # Spread environment (rendering off) wrapped by the adversarial wrapper.
    env = IndependentAdversarialEnvironment(Spread_Indp(render=False))
    register_env("spread_indp", lambda _: env)
    agents = env._agent_ids

    # One "agent_<i>" policy per index with an explicit learning rate, plus an
    # "adv_<i>_<j>" policy (default PolicySpec) for every ordered pair i != j.
    policies = {}
    for i in range(3):
        for j in range(3):
            if i != j:
                policies[f"adv_{i}_{j}"] = PolicySpec()
                continue
            policies[f"agent_{i}"] = PolicySpec(config={"lr": 5e-05 / 2})

    def policy_mapping_fn(agent_id, episode, worker):
        """Map an agent to its policy id and configure the worker's adversaries.

        The worker index decides which agent is treated as the trained one
        (index < 4 -> agent 0, 4-6 -> agent 1, otherwise agent 2) and which
        preselected agents act as adversaries. With ``args.max_adversaries``
        equal to 1 a single adversary is used, otherwise a pair.

        The adversaries are always passed to ``worker.env.set_adversaries`` as
        a tuple. Previously the single-adversary case used ``("agent_1")``,
        which is a plain string: the environment received a ``str`` instead of
        a collection and the membership test below was a substring match.

        Args:
            agent_id: Id of the agent to map; its last character is used to
                build the adversarial policy id.
            episode: Current episode; its ``episode_id`` seeds the global RNG.
            worker: Rollout worker providing ``worker_index`` and ``env``.

        Returns:
            ``adv_<c>_<trained_agent>`` (``c`` = last character of
            ``agent_id``) for an adversary, otherwise ``agent_id`` unchanged.
        """
        # preselected adversaries: (trained agent, k == 1 choice, k > 1 choice)
        if worker.worker_index < 4:
            trained_agent, single, pair = 0, ("agent_1",), ("agent_1", "agent_2")
        elif worker.worker_index < 7:
            trained_agent, single, pair = 1, ("agent_0",), ("agent_0", "agent_2")
        else:
            trained_agent, single, pair = 2, ("agent_1",), ("agent_0", "agent_1")
        adversaries = single if args.max_adversaries == 1 else pair

        # Nothing in this function draws from the RNG after seeding it.
        # NOTE(review): presumably kept so later code in this worker process
        # sees a per-episode seed - confirm before removing.
        random.seed(episode.episode_id)

        worker.env.set_adversaries(adversaries, f"agent_{trained_agent}")
        if agent_id in adversaries:
            return f"adv_{agent_id[-1]}_{trained_agent}"
        return agent_id

    class SetAdversaryCallback(DefaultCallbacks):
        """Callback that prunes sampled batches of regular policies not being trained."""

        def on_sample_end(self, *, worker, samples, **kwargs) -> None:
            """Drop ``agent_<i>`` batches for every agent this worker does not train.

            The trained agent is derived from ``worker.worker_index`` with the
            same thresholds the policy mapping uses (< 4 -> 0, < 7 -> 1,
            otherwise 2). ``samples.policy_batches`` is modified in place.
            """
            index = worker.worker_index
            trained_agent = 0 if index < 4 else (1 if index < 7 else 2)

            # Remove non-adversarial policies of agents that are not currently trained
            for other in (n for n in range(3) if n != trained_agent):
                samples.policy_batches.pop(f"agent_{other}", None)

    # Assemble the PPO configuration for the registered "spread_indp"
    # environment, then build the algorithm from it.
    ppo_config = (
        PPOConfig()
        .framework("tf")
        .environment("spread_indp")
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
        .callbacks(SetAdversaryCallback)
        .resources(num_gpus=1)
        .rollouts(num_rollout_workers=9, num_envs_per_worker=1)
    )
    algo = ppo_config.build()

elif args.environment == "ant":
    from gymnasium_robotics import mamujoco_v0

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.policy.policy import PolicySpec

    from utils import IndependentAdversarialEnvironment

    # Multi-agent Ant ("4x2" partitioning) with the listed reward/cost terms
    # set to zero, wrapped by the project's adversarial wrapper.
    ant_options = dict(
        agent_obsk=2,
        ctrl_cost_weight=0,
        contact_cost_weight=0,
        healthy_reward=0,
    )
    env = IndependentAdversarialEnvironment(
        mamujoco_v0.parallel_env("Ant", "4x2", **ant_options)
    )
    register_env("ant", lambda _: env)
    agents = env._agent_ids

    # Each agent gets a regular policy and an "_adv" policy; the regular one
    # uses half the learning rate of the adversarial one.
    policies = {}
    for idx in range(4):
        policies.update({
            f"agent_{idx}": PolicySpec(config={"lr": 0.0003 / 2}),
            f"agent_{idx}_adv": PolicySpec(config={"lr": 0.0003}),
        })
    
    def policy_mapping_fn(agent_id, episode, worker):
        """Map an agent to its policy id and configure the worker's adversaries.

        The worker index decides which agent is treated as the trained one
        (index < 4 -> agent 0, 4-6 -> agent 1, 7-9 -> agent 2, otherwise
        agent 3) and which preselected agents act as adversaries. With
        ``args.max_adversaries`` equal to 1 a single adversary is used,
        otherwise a pair.

        The adversaries are always passed to ``worker.env.set_adversaries`` as
        a tuple. Previously the single-adversary case used ``("agent_1")``,
        which is a plain string: the environment received a ``str`` instead of
        a collection and the membership test below was a substring match.

        Args:
            agent_id: Id of the agent to map.
            episode: Current episode; its ``episode_id`` seeds the global RNG.
            worker: Rollout worker providing ``worker_index`` and ``env``.

        Returns:
            ``agent_id + "_adv"`` for an adversary, otherwise ``agent_id``
            unchanged.
        """
        # preselected adversaries: (trained agent, k == 1 choice, k > 1 choice)
        if worker.worker_index < 4:
            trained_agent, single, pair = 0, ("agent_1",), ("agent_1", "agent_2")
        elif worker.worker_index < 7:
            trained_agent, single, pair = 1, ("agent_0",), ("agent_0", "agent_2")
        elif worker.worker_index < 10:
            trained_agent, single, pair = 2, ("agent_1",), ("agent_1", "agent_3")
        else:
            trained_agent, single, pair = 3, ("agent_1",), ("agent_1", "agent_2")
        adversaries = single if args.max_adversaries == 1 else pair

        # Nothing in this function draws from the RNG after seeding it.
        # NOTE(review): presumably kept so later code in this worker process
        # sees a per-episode seed - confirm before removing.
        random.seed(episode.episode_id)

        # select adversaries
        worker.env.set_adversaries(adversaries, f"agent_{trained_agent}")
        if agent_id in adversaries:
            return agent_id + "_adv"
        return agent_id

    class SetAdversaryCallback(DefaultCallbacks):
        """Callback that records an episode's first observation in its user data."""

        def on_episode_start(self, *, worker, base_env, policies, episode, env_index = None, **kwargs) -> None:
            """Store ``last_obs`` of env state 0 under ``episode.user_data["first_obs"]``.

            State 0 of ``base_env.env_states`` is always read; ``env_index``
            is not consulted.
            """
            first_state = base_env.env_states[0]
            episode.user_data["first_obs"] = first_state.last_obs

    # Keyword groups for the builder calls below, kept separate for readability.
    network_settings = {"fcnet_hiddens" : [400,300], "fcnet_activation" : "tanh"}
    training_settings = dict(
        gamma=0.99,
        train_batch_size=65536,
        kl_coeff=1.0,
        sgd_minibatch_size=4096,
        vf_loss_coeff=0.5,
        clip_param=0.2,
        grad_clip=0.5,
        model=network_settings,
    )

    # Assemble the PPO configuration for the registered "ant" environment,
    # then build the algorithm from it.
    ppo_config = (
        PPOConfig()
        .framework("tf")
        .environment("ant")
        .resources(num_gpus=0.5)
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
        .rollouts(num_rollout_workers=12)
        .callbacks(SetAdversaryCallback)
        .training(**training_settings)
    )
    algo = ppo_config.build()

# Run the requested number of training iterations and print each result.
for _ in range(args.iter):
    print(pretty_print(algo.train()))

# Save the trained algorithm under a path built from the environment name and
# the experiment name given on the command line.
algo.save(f"RobustTraining/Agents/Robust-fixed-k/{args.environment}/{args.name}")
