import argparse
import numpy as np
import os
import sys
from typing import Dict
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import (
    AgentID,
    PolicyID,
)
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker, MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch

import pickle


class BaseFortAttackCallbacks(DefaultCallbacks):
    """Episode-level metric collection for the two-agent fort-attack setup.

    * On every step, the last action of agents ``0`` and ``1`` is appended to
      ``episode.user_data["action_dist_agent_<i>"]``.
    * At episode end, those per-agent lists are exposed through
      ``episode.hist_data``. Entries 0, 1 and 2 of agent 0's final
      ``info["result"]`` are reported as the ``guard_win``, ``timeout`` and
      ``attacker_win`` custom metrics.
    """

    def __init__(self):
        super().__init__()
        # Flag kept for compatibility; nothing in this class reads it.
        self.not_initialized = True

    def on_episode_start(
        self,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: MultiAgentEpisode,
        **kwargs,
    ):
        """Create empty per-agent action buffers for a fresh episode."""
        for agent_idx in range(2):
            buffer_key = f"action_dist_agent_{agent_idx}"
            episode.user_data[buffer_key] = []
            episode.hist_data[buffer_key] = []

    def on_episode_step(
        self,
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: MultiAgentEpisode,
        **kwargs,
    ):
        """Append the most recent action of agents 0 and 1 to their buffers."""
        user_data = episode.user_data
        for agent_idx in range(2):
            last_action = episode.last_action_for(agent_idx)
            user_data[f"action_dist_agent_{agent_idx}"].append(last_action)

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs,
    ):
        """Publish the episode outcome and the recorded action histories."""
        # Agent 0's final info dict carries the outcome under "result";
        # positions 0/1/2 map to guard win / timeout / attacker win.
        outcome = episode.last_info_for(0)["result"]
        metric_names = ("guard_win", "timeout", "attacker_win")
        for slot, metric_name in enumerate(metric_names):
            episode.custom_metrics[metric_name] = outcome[slot]

        for agent_idx in range(2):
            buffer_key = f"action_dist_agent_{agent_idx}"
            episode.hist_data[buffer_key] = episode.user_data[buffer_key]

    def on_postprocess_trajectory(
        self,
        *,
        worker: "RolloutWorker",
        episode: Episode,
        agent_id: AgentID,
        policy_id: PolicyID,
        policies: Dict[PolicyID, Policy],
        postprocessed_batch: SampleBatch,
        original_batches: Dict[AgentID, SampleBatch],
        **kwargs,
    ) -> None:
        """No-op hook: trajectories are passed through unchanged."""
        return None

    def on_learn_on_batch(
        self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
    ) -> None:
        """No-op hook invoked at the start of ``Policy.learn_on_batch()``.

        RLlib calls this before zero-padding via
        `pad_batch_to_sequences_of_same_size`. With the tf1 framework the
        SampleBatch.INFOS column is absent from ``train_batch``, because the
        static graph would otherwise treat it as a model input. With tf2 and
        torch the column is present.

        Args:
            policy: The Policy about to be trained.
            train_batch: The SampleBatch to be trained on (mutable).
            result: Dict that custom metrics can be added to.
            kwargs: Forward compatibility placeholder.
        """
        return None

    def on_train_result(self, *, trainer, result, **kwargs):
        """No-op hook; subclasses override this to react to training results."""
        return None


def make_default_fort_attack_callback(args):
    """Return the callbacks class used when no self-play schedule is selected.

    ``args`` is accepted so this factory has the same signature as the other
    ``make_*`` factories in this module; it is not used here.
    """
    callbacks_cls = BaseFortAttackCallbacks
    return callbacks_cls


def make_selfplay_with_history(args):
    """Build a callbacks class that grows a league of frozen "main" snapshots.

    After each training iteration, the episode rewards of policy ``main`` are
    compared against another reward series from ``result["hist_stats"]``.
    When main's win rate exceeds ``args.win_rate_threshold``, a frozen copy
    of ``main`` (``main_v1``, ``main_v2``, ...) is added to the trainer. The
    policy mapping is then replaced so that ``main`` plays against a randomly
    chosen earlier snapshot.

    Args:
        args: Parsed CLI namespace; must provide ``win_rate_threshold``.

    Returns:
        A subclass of ``BaseFortAttackCallbacks`` (so the episode metrics it
        records are still collected).
    """

    class SelfPlayHistoryCallback(BaseFortAttackCallbacks):
        def __init__(self):
            super().__init__()
            # 0=RandomPolicy, 1=1st main policy snapshot,
            # 2=2nd main policy snapshot, etc..
            self.current_opponent = 0

        def on_train_result(self, *, trainer, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally, one should set up a proper evaluation config,
            # such that evaluation always happens on the already updated policy,
            # instead of on the already used train_batch.
            hist_stats = result["hist_stats"]
            main_rew = hist_stats.pop("policy_main_reward")
            # NOTE(review): this takes whichever hist_stats series comes first
            # after removing main's rewards and treats it as the opponent's.
            # It depends on RLlib's key ordering — confirm it is the intended
            # series.
            opponent_rew = list(hist_stats.values())[0]
            assert len(main_rew) == len(opponent_rew)
            won = sum(
                1
                for r_main, r_opponent in zip(main_rew, opponent_rew)
                if r_main > r_opponent
            )
            # An iteration with no main-policy episodes has no win rate;
            # report 0.0 instead of dividing by zero.
            win_rate = won / len(main_rew) if main_rew else 0.0
            result["win_rate"] = win_rate
            print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")
            # If win rate is good -> Snapshot current policy and play against
            # it next, keeping the snapshot fixed and only improving the "main"
            # policy.
            if win_rate > args.win_rate_threshold:
                self.current_opponent += 1
                new_pol_id = f"main_v{self.current_opponent}"
                print(f"adding new opponent to the mix ({new_pol_id}).")

                # Re-define the mapping function, such that "main" is forced
                # to play against any of the previously played policies
                # (excluding "random").
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                    # agent_id = [0|1] -> policy depends on episode ID.
                    # This way, "main" sometimes plays as agent 0 and
                    # sometimes as agent 1; the other agent is a randomly
                    # chosen snapshot main_v1 .. main_v<current_opponent>.
                    return (
                        "main"
                        if episode.episode_id % 2 == agent_id
                        else "main_v{}".format(
                            np.random.choice(list(range(1, self.current_opponent + 1)))
                        )
                    )

                new_policy = trainer.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(trainer.get_policy("main")),
                    policy_mapping_fn=policy_mapping_fn,
                )

                # Set the weights of the new policy to the main policy.
                # We'll keep training the main policy, whereas `new_pol_id` will
                # remain fixed.
                main_state = trainer.get_policy("main").get_state()
                new_policy.set_state(main_state)
                # We need to sync the just copied local weights (from main policy)
                # to all the remote workers as well.
                trainer.workers.sync_weights()
            else:
                print("not good enough; will keep learning ...")

            # +2 = main + random
            result["league_size"] = self.current_opponent + 2

    return SelfPlayHistoryCallback


def make_selfplay_sequential(args):
    """Build a callbacks class that alternates which side is being trained.

    On the first training result, every worker is restricted to training only
    ``adversary_policy`` (the attackers). After that, whenever the side
    currently being trained has a mean win rate above
    ``args.sequential_win_rate_threshold`` for the iteration, training
    switches to the other side: ``good_policy`` for the guards, or
    ``adversary_policy`` for the attackers.

    Args:
        args: Parsed CLI namespace; must provide
            ``sequential_win_rate_threshold``.

    Returns:
        A subclass of ``BaseFortAttackCallbacks``. That base class produces
        the ``guard_win``/``timeout``/``attacker_win`` custom metrics this
        schedule reads.
    """
    # Aggregated custom-metric keys this schedule needs from the iteration.
    outcome_keys = ("guard_win_mean", "attacker_win_mean", "timeout_mean")

    class SelfPlaySequentialCallback(BaseFortAttackCallbacks):
        def __init__(self):
            super().__init__()
            # Side currently being trained: "attacker" or "guard".
            self.current_trainable_policy = "attacker"
            # Kept for compatibility; not read or updated in this class.
            self.iteration_counter = 0
            # True until the workers have been restricted to training the
            # attacker policy on the first on_train_result call.
            self.not_init = True

        @staticmethod
        def _train_only(trainer, policy_id):
            """Make ``policy_id`` the only trainable policy on every worker."""

            def fn(worker: RolloutWorker):
                worker.set_policies_to_train(policies_to_train={policy_id})

            trainer.workers.foreach_worker(fn)

        def on_train_result(self, *, trainer, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally, one should set up a proper evaluation config,
            # such that evaluation always happens on the already updated policy,
            # instead of on the already used train_batch.
            metrics = result.get("custom_metrics", {})
            print(metrics.keys())
            has_outcomes = all(key in metrics for key in outcome_keys)
            if has_outcomes:
                print(metrics["guard_win_mean"])
                print(metrics["attacker_win_mean"])
                print(metrics["timeout_mean"])

            if self.not_init:
                self._train_only(trainer, "adversary_policy")
                self.not_init = False
                self.current_trainable_policy = "attacker"
                print(trainer.workers.local_worker().policies_to_train)
                print("initialized")

            if not has_outcomes:
                # Outcome metrics can be missing (typically when no episode
                # finished this iteration); there is nothing to decide on.
                print(
                    self.current_trainable_policy
                    + " has no outcome metrics this iteration; will keep learning ..."
                )
                return

            threshold = args.sequential_win_rate_threshold

            # If win rate is good for the attacker -> switch the current
            # trainable policy to the guards policy.
            if (
                self.current_trainable_policy == "attacker"
                and metrics["attacker_win_mean"] > threshold
            ):
                print("1")
                print(trainer.workers.local_worker().policies_to_train)
                self._train_only(trainer, "good_policy")
                # Push the local worker's weights to the remote workers so
                # every worker continues from the same weights after the switch.
                trainer.workers.sync_weights()

                self.current_trainable_policy = "guard"
                print(
                    "attacker win rate > "
                    + str(threshold)
                    + " , switching to the guard policy"
                )

            # If win rate is good for the guard -> switch the current trainable
            # policy to the attacker policy.
            elif (
                self.current_trainable_policy == "guard"
                and metrics["guard_win_mean"] > threshold
            ):
                print("2")
                print(trainer.workers.local_worker().policies_to_train)
                self._train_only(trainer, "adversary_policy")
                # Same weight broadcast as in the attacker -> guard switch.
                trainer.workers.sync_weights()

                self.current_trainable_policy = "attacker"
                print(
                    "guard win rate > "
                    + str(threshold)
                    + ", switching to the attacker policy"
                )
            else:
                print(
                    self.current_trainable_policy
                    + " not good enough; will keep learning ..."
                )

    return SelfPlaySequentialCallback
