import ray
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print
import supersuit as ss
import argparse

# Command-line interface. The parser lives in its own variable; the parsed
# namespace is exposed as `args`, which the rest of the script reads.
parser = argparse.ArgumentParser()

parser.add_argument("environment", type=str, choices=["pursuit", "ant", "spread_indp"], help="The environment the agent will be trained for")
parser.add_argument("agents", type=str, help="Path to pre-trained policies")
# Default to an empty list instead of None: the policy mapping functions below
# evaluate `agent_id in args.adversaries`, which raises TypeError on None.
parser.add_argument("--adversaries", type=str, nargs="+", default=[], help="Agents that will be controlled by the adversary")
parser.add_argument("--train", type=str, default="agent_0", help="Agent for which the training will be continued")
parser.add_argument("--iter", type=int, default=10, help="Number of iteration the agents will be trained for")
# NOTE: --name has no default; when omitted, args.name is None.
parser.add_argument("--name", type=str, help="Name of the directory the trained policy will be saved in")

args = parser.parse_args()
ray.init(log_to_driver=False, local_mode=False)

# Build the environment, the policy mapping function and the PPO algorithm for
# the selected environment. Each branch defines the module-level names `env`,
# `policy_mapping_fn` and `algo`, which the remainder of the script uses.
#
# PPOConfig and the adversarial wrapper are needed by every branch, so they
# are imported once here; the environment packages stay lazily imported so
# only the selected environment's dependencies have to be installed.
from ray.rllib.algorithms.ppo import PPOConfig
from utils import IndependentAdversarialEnvironment

if args.environment == "pursuit":
    from pettingzoo.sisl import pursuit_v4

    # PettingZoo pursuit with per-agent rewards, 8 stacked frames, flattened
    # observations, wrapped in the project's adversarial environment.
    env = pursuit_v4.parallel_env(shared_reward=False)
    env = ss.frame_stack_v1(env, 8)
    env = ss.flatten_v0(env)
    env = IndependentAdversarialEnvironment(env)
    print(args.adversaries)
    env.set_adversaries(args.adversaries, "social-welfare")
    register_env("pursuit", lambda _: env)

    def policy_mapping_fn(agent_id, episode, worker):
        """Map adversary-controlled agents to ``adv_<n>_e``, others to their own id.

        ``<n>`` is the last character of the agent id.
        """
        if agent_id in args.adversaries:
            return f"adv_{agent_id[-1]}_e"
        return agent_id

    # The algorithm is built with an identity mapping over the eight pursuer
    # policies; `policy_mapping_fn` above is only installed later in the
    # script through `algo.add_policy`.
    algo = (
        PPOConfig()
        .framework("tf")
        .environment("pursuit")
        .resources(num_gpus=0)
        .multi_agent(
            policies={f"pursuer_{i}" for i in range(8)},
            policy_mapping_fn=lambda agent_id, episode, worker: agent_id,
        )
        .rollouts(num_rollout_workers=6)
        .training(
            train_batch_size=5000,
            sgd_minibatch_size=500,
            lambda_=0.95,
            entropy_coeff=0.01,
        )
        .build()
    )

elif args.environment == "ant":
    from gymnasium_robotics import mamujoco_v0

    env = mamujoco_v0.parallel_env("Ant", "4x2", agent_obsk=2, ctrl_cost_weight=0, contact_cost_weight=0, healthy_reward=0)
    env = IndependentAdversarialEnvironment(env)
    env.set_adversaries(args.adversaries, args.train)
    register_env("ant", lambda _: env)

    # One regular and one adversarial policy per agent; the regular policy
    # uses half the learning rate of its adversarial counterpart.
    policies = {}
    for i in range(4):
        policies[f"agent_{i}"] = PolicySpec(config={"lr": 0.0003 / 2})
        policies[f"agent_{i}_adv"] = PolicySpec(config={"lr": 0.0003})

    def policy_mapping_fn(agent_id, episode, worker):
        """Map adversary-controlled agents to ``<agent_id>_adv``, others to their own id."""
        if agent_id in args.adversaries:
            return agent_id + "_adv"
        return agent_id

    algo = (
        PPOConfig()
        .framework("tf")
        .environment("ant")
        .resources(num_gpus=1)
        .multi_agent(
            policies=policies,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=[args.train],
        )
        .rollouts(num_rollout_workers=12)
        .training(
            gamma=0.99,
            train_batch_size=65536,
            kl_coeff=1.0,
            sgd_minibatch_size=4096,
            vf_loss_coeff=0.5,
            clip_param=0.2,
            grad_clip=0.5,
            model={"fcnet_hiddens": [400, 300], "fcnet_activation": "tanh"},
        )
        .build()
    )

elif args.environment == "spread_indp":
    from envs.Spread.Spread import Spread_Indp

    env = Spread_Indp(render=False)
    env = IndependentAdversarialEnvironment(env)
    env.set_adversaries(args.adversaries, args.train)
    register_env("spread_indp", lambda _: env)

    # NOTE(review): this dict is built but never passed to `.multi_agent`
    # below, which only receives the three `agent_<i>` ids. Confirm whether
    # the `adv_<i>_<j>` specs and the halved learning rate were meant to be
    # used; behaviour is left unchanged here.
    policies = {}
    for i in range(3):
        for j in range(3):
            if i == j:
                policies[f"agent_{i}"] = PolicySpec(config={"lr": 5e-05 / 2})
            else:
                policies[f"adv_{i}_{j}"] = PolicySpec()

    def policy_mapping_fn(agent_id, episode, worker):
        """Map adversary-controlled agents to ``adv_<n>_<t>``, others to their own id.

        ``<n>`` is the last character of the agent id and ``<t>`` the last
        character of ``args.train``.
        """
        if agent_id in args.adversaries:
            return f"adv_{agent_id[-1]}_{args.train[-1]}"
        return agent_id

    algo = (
        PPOConfig()
        .framework("tf")
        .environment("spread_indp")
        .multi_agent(
            policies={f"agent_{i}" for i in range(3)},
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=[args.train],
        )
        .resources(num_gpus=0)
        .rollouts(num_rollout_workers=9, num_envs_per_worker=1)
        .build()
    )
    
# Restore the pre-trained algorithm state from the checkpoint path given on
# the command line.
algo.restore(args.agents)

# `get_policy` yields None for an unknown policy id. Fail with a clear message
# instead of handing `NoneType` to `add_policy` as the policy class.
trained_policy = algo.get_policy(args.train)
if trained_policy is None:
    raise ValueError(
        f"--train {args.train!r} does not name a policy of the restored algorithm"
    )

# Adding a throw-away policy seems to be required in order to change the
# policy_mapping_fn (and the set of policies to train) after restoring.
algo.add_policy(
    "dummy_0",
    policy_cls=type(trained_policy),
    policies_to_train={args.train},
    policy_mapping_fn=policy_mapping_fn,
)

# Continue training for the requested number of iterations, printing the
# result dictionary of each one.
for _ in range(args.iter):
    result = algo.train()
    print(pretty_print(result))

# If --name was omitted, args.name is None and the last path component is the
# literal string "None".
algo.save(f"RobustTraining/Agents/eps-check/{args.environment}/{args.name}")
