import ray
from ray.tune.registry import register_env
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print

from torch import nn
import torch

from itertools import combinations
from argparse import ArgumentParser
import random
import math

environments = ["pursuit", "spread_indp", "ant"]

# Command-line interface. The parser is kept separate from the parsed
# namespace (`args`) which the rest of the script reads.
parser = ArgumentParser()

parser.add_argument("environment", choices=environments, help="The environment the agent will be trained for")
parser.add_argument("--max-adversaries", type=int, default=1, help="The Cardinality k of the subsets of adversaries")
parser.add_argument("--iter", type=int, default=100, help="Number of iteration the agents will be trained for")
parser.add_argument("--name", type=str, default="unnamed_experiment", help="Name of the directory the trained policy will be saved in")
# Forwarded as Spread_Indp(param=...); ignored by the other environments.
parser.add_argument("--param", type=int, default=10, help="Integer parameter passed to the Spread_Indp environment (only used for spread_indp)")
parser.add_argument("--restore", type=str, help="Restore pre-trained policy")
args = parser.parse_args()

ray.init(log_to_driver=False, local_mode=False)

# Model that selects the subset of adversaries
class AdversarySelection(nn.Module):
    """Score every size-k subset of agents as a candidate adversary set.

    The network input is the concatenation of all agents' initial observations
    followed by the index of the attacked agent. The output has one logit per
    subset, i.e. ``math.comb(num_agents, k)`` values. Callers index these
    logits with their ``subsets`` list.
    """

    def __init__(self, num_agents, obs_size, max_adversaries=None) -> None:
        """Build the MLP.

        Args:
            num_agents: number of agents whose observations are concatenated.
            obs_size: length of a single agent's flat observation.
            max_adversaries: subset cardinality k. When omitted, the
                ``--max-adversaries`` command-line value
                (``args.max_adversaries``) is used, as before.
        """
        super().__init__()
        if max_adversaries is None:
            max_adversaries = args.max_adversaries
        # number of all subsets the adversary can control
        num_out = math.comb(num_agents, max_adversaries)
        self.network = nn.Sequential(
            # Input: initial observation of all agents, and id of attacked agent
            nn.Linear(num_agents * obs_size + 1, 300),
            nn.ReLU(),
            nn.Linear(300, 400),
            nn.ReLU(),
            nn.Linear(400, num_out),
        )

    def forward(self, obs, targets):
        """Return subset logits.

        Args:
            obs: either one observation dict (agent id -> flat observation) or
                a list of such dicts forming a batch. Values are concatenated in
                the dict's iteration order.
            targets: index of the attacked agent. This is an int for a single
                dict, or a 1-D sequence/tensor with one entry per dict for a batch.

        Returns:
            A 1-D logit tensor for a single dict, or a ``(batch, num_out)``
            tensor for a list of dicts.
        """
        if isinstance(obs, dict):
            parts = [torch.tensor(o, dtype=torch.float) for o in obs.values()]
            # cast the target id to float explicitly instead of relying on
            # torch.cat's int/float type promotion
            parts.append(torch.tensor([targets], dtype=torch.float))
            x = torch.cat(parts)
        else:
            rows = [
                torch.cat([torch.tensor(o, dtype=torch.float) for o in sample.values()])
                for sample in obs
            ]
            x = torch.stack(rows)
            target_col = torch.as_tensor(targets, dtype=torch.float).reshape(-1, 1)
            x = torch.cat((x, target_col), dim=1)
        return self.network(x)

if args.environment == "pursuit":
    # Pursuit (PettingZoo SISL) with individual rewards (shared_reward=False).
    from pettingzoo.sisl import pursuit_v4
    from ray.rllib.algorithms.ppo import PPOConfig
    import supersuit as ss
    from utils import IndependentAdversarialEnvironment
    from ray.rllib.policy.policy import PolicySpec
    
    env = pursuit_v4.parallel_env(shared_reward=False)
    # stack the last 8 frames, then flatten each agent's observation
    env = ss.frame_stack_v1(env, 8)
    env = ss.flatten_v0(env)
    env = IndependentAdversarialEnvironment(env)
    # NOTE(review): the creator returns this one pre-built instance rather than
    # building a new env per call; each remote worker only gets its own copy if
    # it is pickled per process — confirm this is intended.
    register_env("pursuit", lambda _ : env)
    agents = env._agent_ids
    
    # possible subsets of cardinality k
    subsets = list(combinations(agents, args.max_adversaries))
    print(subsets)

    # initialize subset selection model \omega
    # assumes env.observation_space is the flat per-agent space shared by all agents — TODO confirm
    selModel = AdversarySelection(len(agents), env.observation_space.shape[0])
    optim = torch.optim.Adam(selModel.parameters())

    # Per pursuer: a benign policy (lr 5e-05/2) and an adversarial policy
    # "pursuer_<i>_adv" used when that pursuer is selected as an adversary.
    policies = {}
    for i in range(8):
        policies[f"pursuer_{i}"] = PolicySpec(config = {"lr" : 5e-05/2})
        policies[f"pursuer_{i}_adv"] = PolicySpec()

    def policy_mapping_fn(agent_id, episode, worker):
        """Map a pursuit agent to its benign or adversarial policy for this episode.

        With probability 0.5 the episode is clean (no adversaries). Otherwise
        ``selModel`` scores every subset in ``subsets`` from the episode's first
        observation. Subsets containing the pursuer trained on this worker are
        masked out, and one subset is sampled from the softmax.

        This is called once for every agent, so the RNGs are re-seeded with the
        episode id to make all calls of one episode draw the same subset.

        Args:
            agent_id: agent to map, e.g. ``"pursuer_3"``.
            episode: current episode; ``user_data["first_obs"]`` must already
                be set (by ``SetAdversaryCallback.on_episode_start``).
            worker: rollout worker; ``worker_index % 8`` is the trained pursuer.

        Returns:
            ``"pursuer_<i>_adv"`` if the agent is an adversary, else ``agent_id``.
        """
        random.seed(episode.episode_id)  # keep selection consistent, as this is called once for every agent
        if random.random() < 0.5:
            adversaries = []
        else:
            trained_agent = worker.worker_index % 8  # 1 agent is trained on every worker
            torch.manual_seed(episode.episode_id)
            with torch.no_grad():  # sampling only; no gradient needed here
                values = selModel(episode.user_data["first_obs"], trained_agent)
            # Pursuit agent ids are "pursuer_<i>" (the policy ids use the same
            # prefix); checking "agent_<i>" would never match, and the trained
            # pursuer could then be picked as its own adversary.
            trained_id = f"pursuer_{trained_agent}"
            subsets_with_agents = [i for i, a in enumerate(subsets) if trained_id in a]
            # mask out subsets with trained agent
            values[subsets_with_agents] = float("-inf")
            values = torch.softmax(values, dim=0)
            selection = values.multinomial(num_samples=1).item()
            # select subset
            adversaries = subsets[selection]
        worker.env.set_adversaries(adversaries, "social-welfare")
        if agent_id in adversaries:
            return f"pursuer_{agent_id[-1]}_adv"
        return agent_id

    class SetAdversaryCallback(DefaultCallbacks):
        """Pursuit callbacks: store the first observation and prune benign batches."""

        def on_episode_start(self, *, worker, base_env, policies, episode, env_index = None, **kwargs) -> None:
            # policy_mapping_fn feeds this observation to the selection model
            first_state = base_env.env_states[0]
            episode.user_data["first_obs"] = first_state.last_obs

        # Drop the benign batches of every pursuer except the one this worker
        # trains (worker_index % 8); adversarial batches are kept.
        # NOTE: this pruning is not strictly needed for this environment, and
        # training is much faster without it; it is kept to mirror the pseudocode.
        def on_sample_end(self, *, worker, samples, **kwargs) -> None:
            keep_id = f"pursuer_{worker.worker_index % 8}"
            for policy_id in (f"pursuer_{k}" for k in range(8)):
                if policy_id == keep_id:
                    continue
                if policy_id not in samples.policy_batches.keys():
                    continue
                samples.policy_batches.pop(policy_id)

    # PPO trainer for pursuit: policies from `policies`, routed by
    # policy_mapping_fn, with SetAdversaryCallback attached and 8 rollout
    # workers (each trains pursuer worker_index % 8).
    algo = (PPOConfig()
            .framework("tf")
            .environment("pursuit")
            .resources(num_gpus=0)
            .multi_agent(
                policies=policies,
                policy_mapping_fn = policy_mapping_fn
            )
            .callbacks(SetAdversaryCallback)
            .rollouts(num_rollout_workers=8)
            .training(train_batch_size=5000,
                      sgd_minibatch_size=500,
                      lambda_ = 0.95,
                      entropy_coeff = 0.01
                      )
            .build())
    

elif args.environment == "spread_indp":
    # Project-local Spread environment; --param is forwarded to it.
    from envs.Spread.Spread import Spread_Indp
    from ray.rllib.algorithms.ppo import PPOConfig
    from utils import IndependentAdversarialEnvironment
    from ray.rllib.policy.policy import PolicySpec

    env = Spread_Indp(render=False, param=args.param)
    env = IndependentAdversarialEnvironment(env)
    # NOTE(review): the creator returns this single pre-built instance — confirm
    # each worker ends up with its own copy.
    register_env("spread_indp", lambda _ : env)
    agents = env._agent_ids

    # possible subsets for adversary selection
    subsets = list(combinations(agents, args.max_adversaries))
    print(subsets)

    # initialize subset selection model \omega
    # assumes env.observation_space is the flat per-agent space shared by all agents — TODO confirm
    selModel = AdversarySelection(len(agents), env.observation_space.shape[0])
    optim = torch.optim.Adam(selModel.parameters())

    # One benign policy "agent_<i>" per agent (lr 5e-05/2), plus "adv_<i>_<j>"
    # for i != j: agent i's policy when it attacks trained agent j. The loops
    # below assume three agents.
    policies = {}
    for i in range(3):
        for j in range(3):
            if i == j:
                policies[f"agent_{i}"] = PolicySpec(config = {"lr" : 5e-05/2})
            else:
                policies[f"adv_{i}_{j}"] = PolicySpec()

    def policy_mapping_fn(agent_id, episode, worker):
        """Map a spread agent to its benign or adversarial policy for this episode.

        Worker indices below 4 train agent 0, below 7 agent 1, and the rest
        agent 2. The RNGs are seeded with the episode id, so every per-agent
        call in one episode samples the same subset. The subset is drawn from
        the softmax of ``selModel``'s logits, with subsets containing the
        trained agent masked out.

        Returns:
            ``"adv_<i>_<trained>"`` for an adversary, otherwise ``agent_id``.
        """
        idx = worker.worker_index
        trained_agent = 0 if idx < 4 else (1 if idx < 7 else 2)
        trained_name = f"agent_{trained_agent}"

        seed = episode.episode_id
        random.seed(seed)
        torch.manual_seed(seed)
        scores = selModel(episode.user_data["first_obs"], trained_agent)
        # the trained agent may never be one of the adversaries
        blocked = [k for k, subset in enumerate(subsets) if trained_name in subset]
        scores[blocked] = float("-inf")
        probs = torch.softmax(scores, dim = 0)
        pick = probs.multinomial(num_samples = 1)
        adversaries = subsets[pick]
        worker.env.set_adversaries(adversaries, trained_name)
        return f"adv_{agent_id[-1]}_{trained_agent}" if agent_id in adversaries else agent_id

    class SetAdversaryCallback(DefaultCallbacks):
        """Spread callbacks: prune benign batches and store the first observation."""

        def on_sample_end(self, *, worker, samples, **kwargs) -> None:
            # Keep only the benign batch of the agent this worker trains (same
            # worker-index split as policy_mapping_fn); drop the other agents'.
            idx = worker.worker_index
            trained_agent = 0 if idx < 4 else (1 if idx < 7 else 2)
            batches = samples.policy_batches
            for policy_id in [f"agent_{k}" for k in range(3) if k != trained_agent]:
                if policy_id in batches.keys():
                    batches.pop(policy_id)

        def on_episode_start(self, *, worker, base_env, policies, episode, env_index = None, **kwargs) -> None:
            # Initial observation read by the selection model \omega.
            first_state = base_env.env_states[0]
            episode.user_data["first_obs"] = first_state.last_obs

    # PPO trainer for spread with 9 rollout workers. policy_mapping_fn trains
    # agent 0 on worker indices < 4, agent 1 on < 7, and agent 2 on the rest.
    algo = (PPOConfig()
            .framework("tf")
            .environment("spread_indp")
            .multi_agent(
                policies=policies,
                policy_mapping_fn = policy_mapping_fn,
            )
            .callbacks(SetAdversaryCallback)
            .resources(num_gpus = 0)
            .rollouts(num_rollout_workers=9, num_envs_per_worker=1)
            .build())
    
elif args.environment == "ant":
    from gymnasium_robotics import mamujoco_v0

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.policy.policy import PolicySpec

    from utils import IndependentAdversarialEnvironment

    # Multi-agent MuJoCo Ant with the "4x2" split (the policy loop below assumes
    # 4 agents). Control cost, contact cost and healthy reward are set to 0.
    env = mamujoco_v0.parallel_env("Ant", "4x2", agent_obsk = 2, ctrl_cost_weight = 0, contact_cost_weight = 0, healthy_reward = 0)
    env = IndependentAdversarialEnvironment(env)
    # NOTE(review): the creator returns this single pre-built instance — confirm
    # each worker ends up with its own copy.
    register_env("ant", lambda _ : env)
    agents = env._agent_ids

    # possible subsets for adversary selection
    subsets = list(combinations(agents, args.max_adversaries))
    print(subsets)

    # initialize subset selection model \omega
    # assumes env.observation_space is the flat per-agent space shared by all agents — TODO confirm
    selModel = AdversarySelection(len(agents), env.observation_space.shape[0])
    optim = torch.optim.Adam(selModel.parameters(), lr = 0.0003)

    # Per agent: benign policy (lr 0.0003/2) and adversarial policy "agent_<i>_adv" (lr 0.0003).
    policies = {}
    for i in range(4):
        policies[f"agent_{i}"] = PolicySpec(config = {"lr" : 0.0003/2})
        policies[f"agent_{i}_adv"] = PolicySpec(config = {"lr" : 0.0003})
    
    def policy_mapping_fn(agent_id, episode, worker):
        """Map an Ant agent to its benign or adversarial policy for this episode.

        The trained agent is ``worker_index % 4``. The RNGs are seeded with the
        episode id so that all per-agent calls of an episode agree. A subset is
        sampled from the softmax of ``selModel``'s logits, with subsets
        containing the trained agent masked out.

        Returns:
            ``agent_id + "_adv"`` for an adversary, otherwise ``agent_id``.
        """
        trained_agent = worker.worker_index % 4
        trained_name = f"agent_{trained_agent}"

        seed = episode.episode_id
        random.seed(seed)
        torch.manual_seed(seed)
        # Clean-episode switch: with a threshold of 0 this branch never fires,
        # so clean episodes are effectively disabled.
        if random.random() < 0:
            worker.env.set_adversaries([], trained_name)
            return agent_id

        scores = selModel(episode.user_data["first_obs"], trained_agent)
        blocked = [k for k, subset in enumerate(subsets) if trained_name in subset]
        scores[blocked] = float("-inf")
        probs = torch.softmax(scores, dim = 0)
        pick = probs.multinomial(num_samples = 1)
        adversaries = subsets[pick]
        # target choice should not matter here: the comment in the original code says Ant uses a shared reward
        worker.env.set_adversaries(adversaries, trained_name)
        return agent_id + "_adv" if agent_id in adversaries else agent_id

    class SetAdversaryCallback(DefaultCallbacks):
        """Ant callbacks: store the first observation and prune benign batches."""

        def on_episode_start(self, *, worker, base_env, policies, episode, env_index = None, **kwargs) -> None:
            # Initial observation read by the selection model in policy_mapping_fn.
            episode.user_data["first_obs"] = base_env.env_states[0].last_obs

        def on_sample_end(self, *, worker, samples, **kwargs) -> None:
            # Remove non-adversarial policies of agents that are not currently trained
            # NOTE: Technically it is not required for this environment, and training can be significantly sped up if the following lines are removed
            # However, we kept it to be consistent with the pseudocode
            # Must agree with the Ant policy_mapping_fn (worker_index % 4) and the
            # four "agent_<i>" policies. The previous spread-style split (<4/<7,
            # range(3)) kept the wrong agent's batch and never pruned agent_3.
            trained_agent = worker.worker_index % 4
            for i in range(4):
                if i != trained_agent and f"agent_{i}" in samples.policy_batches.keys():
                    samples.policy_batches.pop(f"agent_{i}")

    # PPO trainer for Ant with 8 rollout workers (each trains agent
    # worker_index % 4). The default model is a [400, 300] tanh MLP.
    algo = (PPOConfig()
            .framework("tf")
            .environment("ant")
            .resources(num_gpus=0)
            .multi_agent(
                policies = policies, 
                policy_mapping_fn = policy_mapping_fn)
            .rollouts(num_rollout_workers = 8)
            .callbacks(SetAdversaryCallback)
            .training(
                gamma = 0.99,
                train_batch_size = 65536,
                kl_coeff = 1.0,
                sgd_minibatch_size = 4096,
                vf_loss_coeff = 0.5,
                clip_param = 0.2,
                grad_clip = 0.5,
                model = {"fcnet_hiddens" : [400,300], "fcnet_activation" : "tanh"})
            .build())

def sample_one_episode():
    """Roll out one episode on the driver with a uniformly random adversary subset.

    A target agent is drawn uniformly from ``agents``. The adversary subset is
    then drawn uniformly from the entries of ``subsets`` that do not contain
    the target. Benign agents act with their own policy. Adversaries act with
    the adversarial policy that the environment's ``policy_mapping_fn`` uses:
    ``adv_<i>_<target>`` for spread_indp, ``<agent>_adv`` for pursuit and ant.

    Returns:
        A tuple ``(init_obs, selection, target, target_rew)``:

        - the observation dict returned by ``env.reset()``;
        - the index of the chosen subset in ``subsets``;
        - the target agent id;
        - the target's summed reward over the episode.
    """
    init_obs, _ = env.reset()
    obs = init_obs
    target_rew = 0
    done, truncated = {"__all__": False}, {"__all__": False}
    # list() so this also works when _agent_ids is a set
    target = random.choice(list(agents))
    subsets_without_target = [i for i, a in enumerate(subsets) if target not in a]
    selection = random.choice(subsets_without_target)
    adversaries = subsets[selection]
    env.set_adversaries(adversaries, target)
    while not done["__all__"] and not truncated["__all__"]:
        actions = {}
        for agent in obs.keys():
            if agent not in adversaries:
                policy_id = agent
            elif args.environment == "spread_indp":
                policy_id = f"adv_{agent[-1]}_{target[-1]}"
            else:
                # pursuit ("pursuer_<i>_adv") and ant ("agent_<i>_adv"); pursuit
                # adversaries previously fell through to their benign policy.
                policy_id = agent + "_adv"
            action, _, _ = algo.get_policy(policy_id).compute_single_action(obs=obs[agent])
            actions[agent] = action
        obs, rew, done, truncated, _ = env.step(actions)
        # an agent that already finished may be absent from the reward dict
        target_rew += rew.get(target, 0)
    return init_obs, selection, target, target_rew
            
def train_selector_model(num_episodes=10):
    """Run one gradient update of the adversary-selection model ``selModel``.

    Collects ``num_episodes`` episodes with ``sample_one_episode``, then
    minimises ``mean(log p(selection) * target_reward) + 0.01 * mean(p log p)``.
    Here ``p`` is the selector's probability of the subset that was played.
    Lowering this loss makes subsets that left the target with a high reward
    less likely.

    Probabilities use the same masking as the ``policy_mapping_fn`` sampling:
    subsets containing the target get zero probability. ``sample_one_episode``
    never selects such a subset, so the selected log-probabilities stay finite.

    NOTE(review): the subsets are drawn uniformly rather than from ``selModel``.
    Also, ``policy_mapping_fn`` runs on RLlib rollout workers, which may hold
    their own copy of ``selModel`` that this update does not reach. Confirm
    both are intended.

    Args:
        num_episodes: number of episodes collected for the update (default 10).
    """
    obs, selections, targets, target_names, target_rews = [], [], [], [], []
    # collect samples
    for _ in range(num_episodes):
        init_obs, selection, target, total_rew = sample_one_episode()
        obs.append(init_obs)
        selections.append(selection)
        target_names.append(target)
        # assumes agent ids end in a single-digit index
        targets.append(int(target[-1]))
        target_rews.append(total_rew)
    selections = torch.tensor(selections)
    target_rews = torch.tensor(target_rews, dtype=torch.float)
    targets = torch.tensor(targets)

    logits = selModel(obs, targets)
    # same mask as at sampling time: rule out subsets that contain the target
    mask = torch.tensor([[name in subset for subset in subsets] for name in target_names])
    logits = logits.masked_fill(mask, float("-inf"))
    probs = torch.softmax(logits, dim=1)
    chosen = probs[torch.arange(len(probs)), selections]
    entropy = chosen * torch.log(chosen)
    log_probs = torch.log(chosen)
    # update model
    loss = log_probs * target_rews
    loss = loss + 0.01 * entropy.mean()
    loss = loss.mean()
    optim.zero_grad()
    loss.backward()
    optim.step()

# Optionally resume from a checkpoint. Then alternate one RLlib training
# iteration with one update of the adversary-selection model, and save the
# result under RobustTraining/Agents/Robust/<environment>/<name>.
if args.restore:
    algo.restore(args.restore)

for _ in range(args.iter):
    result = algo.train()
    print(pretty_print(result))
    train_selector_model()

checkpoint_dir = f"RobustTraining/Agents/Robust/{args.environment}/{args.name}"
algo.save(checkpoint_dir)
