import ray
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print
from ray.rllib.algorithms.callbacks import DefaultCallbacks

from utils import IndependentAdversarialEnvironment
import supersuit as ss
from ray.rllib.env import ParallelPettingZooEnv

from argparse import ArgumentParser

# Environment names accepted on the command line.
environments = ["pursuit", "spread_indp", "ant", "gfootball"]

# Keep the parser object around (instead of rebinding it to the parsed
# namespace) so that argument validation can report errors through argparse.
parser = ArgumentParser()

parser.add_argument("environment", choices=environments, help="The environment the agent will be trained for")
parser.add_argument("agents", type=str, help="Path to pre-trained policies")
parser.add_argument("--adversaries", type=str, nargs="+", default=[], help="adversaries that will be trained")
parser.add_argument("--target", type=str, help="agent that will be attacked")
parser.add_argument("--iter", type=int, default=100, help="amount of iteration for training")
parser.add_argument("--name", type=str, default="unnamed_experiment", help="Name of the directory the trained policy will be saved in")

args = parser.parse_args()

# Adversarial policy ids are built from the last character of --target, so a
# missing target would only surface later as a TypeError on `None[-1]`.
# Fail fast with a usage error before Ray is started.
if args.adversaries and args.target is None:
    parser.error("--target is required when --adversaries is given")

ray.init(log_to_driver=False, local_mode=False)

# Build the PPO algorithm for the selected environment. Every branch wraps the
# base environment in IndependentAdversarialEnvironment, tells it which agents
# are adversaries, registers it with RLlib under the CLI name, and builds a
# PPO algorithm with one policy per agent (policy id == agent id).
if args.environment == "pursuit":
    from pettingzoo.sisl import pursuit_v4
    from ray.rllib.algorithms.ppo import PPOConfig

    env = pursuit_v4.parallel_env(shared_reward=False)
    env = ss.frame_stack_v1(env, 8)
    env = ss.flatten_v0(env)
    env = IndependentAdversarialEnvironment(env)
    print(args.adversaries)
    # This branch always passes "social-welfare" as the target, not --target.
    env.set_adversaries(args.adversaries, "social-welfare")
    register_env("pursuit", lambda _: env)

    algo = (PPOConfig()
            .framework("tf")
            .environment("pursuit")
            .resources(num_gpus=0)
            .multi_agent(
                policies={f"pursuer_{i}" for i in range(8)},
                policy_mapping_fn=lambda agent_id, episode, worker: agent_id
            )
            .rollouts(num_rollout_workers=10)
            .training(train_batch_size=5000,
                      sgd_minibatch_size=500,
                      lambda_=0.95,
                      entropy_coeff=0.01
                      )
            .build())

elif args.environment == "spread_indp":
    from envs.Spread.Spread import Spread_Indp
    from ray.rllib.algorithms.ppo import PPOConfig

    env = Spread_Indp(render=False)
    env = IndependentAdversarialEnvironment(env)
    # Unlike the other branches, the attacked agent comes from --target here.
    env.set_adversaries(args.adversaries, args.target)
    register_env("spread_indp", lambda _: env)

    algo = (PPOConfig()
            .framework("tf")
            .environment("spread_indp")
            .multi_agent(
                policies={f"agent_{i}" for i in range(3)},
                policy_mapping_fn=lambda agent_id, episode, worker: agent_id
            )
            .resources(num_gpus=1)
            .rollouts(num_rollout_workers=12, num_envs_per_worker=1)
            .build())

elif args.environment == "ant":
    # environment
    from gymnasium_robotics import mamujoco_v0

    # ray imports
    from ray.rllib.algorithms.ppo import PPO, PPOConfig

    env = mamujoco_v0.parallel_env("Ant", "4x2", agent_obsk=2, ctrl_cost_weight=0, contact_cost_weight=0, healthy_reward=0)
    env = IndependentAdversarialEnvironment(env)
    env.set_adversaries(args.adversaries, "social-welfare")
    register_env("ant", lambda _: env)

    def mapping_fn(agent_id, episode, worker):
        """Map every agent to the policy that carries its own id."""
        return agent_id

    algo = (PPOConfig()
            .framework("tf")
            .environment("ant")
            .resources(num_gpus=1)
            .multi_agent(
                policies={"agent_0", "agent_1", "agent_2", "agent_3"},  # Store adversarial policies later
                policy_mapping_fn=mapping_fn)
            .rollouts(num_rollout_workers=12)
            .training(
                gamma=0.99,
                lr=0.0003,
                train_batch_size=65536,
                kl_coeff=1.0,
                sgd_minibatch_size=4096,
                vf_loss_coeff=0.5,
                clip_param=0.2,
                grad_clip=0.5,
                model={"fcnet_hiddens": [400, 300], "fcnet_activation": "tanh"})
            .build())

else:
    # The CLI accepts environment names (e.g. "gfootball") that have no setup
    # branch above. Without this guard `algo` would be unbound and the restore
    # call below would fail with an unhelpful NameError.
    raise NotImplementedError(
        f"No training setup is implemented for environment {args.environment!r}"
    )

# Load the pre-trained agent policies from the checkpoint path given on the
# command line.
algo.restore(args.agents)

def policy_mapping_fn(agent_id, episode, worker):
    """Return the id of the policy that should act for ``agent_id``.

    Agents that are not listed in ``args.adversaries`` keep the policy named
    after themselves. Adversaries are routed to ``adv_<a>_<t>``, where ``<a>``
    is the last character of the agent id and ``<t>`` is the last character of
    ``args.target``.

    ``episode`` and ``worker`` are accepted for signature compatibility and
    are not used.
    """
    if agent_id not in args.adversaries:
        return agent_id
    return f"adv_{agent_id[-1]}_{args.target[-1]}"

# Register one adversarial policy per adversary agent and restrict training to
# those policies.
policies_to_train = []
for adv in args.adversaries:
    # Policy id: last character of the adversary id and of the target id.
    adv_policy = f"adv_{adv[-1]}_{args.target[-1]}"
    # Replace a policy of the same id if the algorithm already has one
    # (presumably left over from the restored checkpoint). Checking explicitly
    # instead of wrapping the removal in a bare `except: pass` keeps genuine
    # failures and KeyboardInterrupt from being silently swallowed.
    if algo.get_policy(adv_policy) is not None:
        algo.remove_policy(adv_policy)
    # only train adversaries
    policies_to_train.append(adv_policy)
    # add adversarial policy with updated mapping fn and policies to train;
    # the new policy uses the same policy class as the agent it replaces
    algo.add_policy(
        adv_policy,
        policy_cls=type(algo.get_policy(adv)),
        policies_to_train=policies_to_train,
        policy_mapping_fn=policy_mapping_fn
    )

# Run the requested number of training iterations, printing the result of
# each one.
for _ in range(args.iter):
    print(pretty_print(algo.train()))

# Save the trained algorithm under a directory keyed by environment, number of
# adversaries and experiment name.
checkpoint_dir = f"RobustTraining/Agents/adversarialAgents/{args.environment}_{len(args.adversaries)}/{args.name}"
algo.save(checkpoint_dir)
