from collections import deque
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING

from src.rllib.policy.policy import PolicySpec
from src.rllib.utils.annotations import override
from src.rllib.utils.framework import try_import_tf
from src.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary
from src.rllib.utils.typing import PartialTrainerConfigDict, \
    PolicyID, TrainerConfigDict
from ray.tune.utils.util import merge_dicts

# Imported for static type checking only: `Policy` appears solely in type
# annotations in this file (e.g. `Type["Policy"]`).
if TYPE_CHECKING:
    from src.rllib.policy.policy import Policy

# Unpack the three handles returned by `try_import_tf()`. In the code visible
# here, only `tf1` is used (graph, session, seeding and variable-scope calls).
# NOTE(review): presumably these are None when TensorFlow is not installed -
# confirm against `try_import_tf`.
tf1, tf, tfv = try_import_tf()


class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Trainer for league-based setups w/o running out of memory.

    Bookkeeping is done via four attributes:
        valid_keys: All policy IDs known to this map (in memory or on disk).
        cache: The policy objects currently held in memory.
        deque: The IDs of the in-memory policies, ordered from least
            recently used (left) to most recently used (right).
        policy_specs: The specs needed to re-create a stashed policy.

    Note: The storage inherited from `dict` is never written to. Only the
    methods overridden below are aware of the cache; other inherited `dict`
    methods (e.g. `pop`, `setdefault`, `clear`) operate on the (empty)
    underlying dict and bypass it.
    """

    def __init__(
            self,
            worker_index: int,
            num_workers: int,
            capacity: Optional[int] = None,
            path: Optional[str] = None,
            policy_config: Optional[TrainerConfigDict] = None,
            session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            seed: Optional[int] = None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            worker_index (int): The worker index of the RolloutWorker this map
                resides in.
            num_workers (int): The total number of remote workers in the
                WorkerSet to which this map's RolloutWorker belongs to.
            capacity (int): The maximum number of policies to hold in memory.
                The least used ones are written to disk and retrieved
                when needed. Falls back to 10 if None (or 0).
            path (str): The path to store the policy pickle files to. Files
                will have the name: [policy_id].[worker idx].policy.pkl.
                Falls back to the current directory (".") if not provided.
            policy_config (TrainerConfigDict): The Trainer's base config dict.
            session_creator (Optional[Callable[[], tf1.Session]): An optional
                tf1.Session creation callable.
            seed (int): An optional seed (used to seed tf policies).
        """
        super().__init__()

        self.worker_index = worker_index
        self.num_workers = num_workers
        self.session_creator = session_creator
        self.seed = seed

        # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
        self.extension = f".{self.worker_index}.policy.pkl"

        # Set of keys that may be looked up (cached or not).
        self.valid_keys: Set[str] = set()
        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, "Policy"] = {}
        # The LRU queue holding the IDs of the currently in-memory objects
        # (leftmost = least recently used).
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config: TrainerConfigDict = policy_config or {}
        # The orig classes/obs+act spaces, and config overrides of the
        # Policies.
        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

    def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                      observation_space: gym.Space, action_space: gym.Space,
                      config_override: PartialTrainerConfigDict,
                      merged_config: TrainerConfigDict) -> None:
        """Creates a new policy and stores it to the cache.

        Args:
            policy_id (PolicyID): The policy ID. This is the key under which
                the created policy will be stored in this map.
            policy_cls (Type[Policy]): The (original) policy class to use.
                This may still be altered in case tf-eager (and tracing)
                is used.
            observation_space (gym.Space): The observation space of the
                policy.
            action_space (gym.Space): The action space of the policy.
            config_override (PartialTrainerConfigDict): The config override
                dict for this policy. This is the partial dict provided by
                the user.
            merged_config (TrainerConfigDict): The entire config (merged
                default config + `config_override`).
        """
        framework = merged_config.get("framework", "tf")
        class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)

        # Tf.
        if framework in ["tf2", "tf", "tfe"]:
            # Worker 0 uses the plain policy ID as variable scope; all other
            # workers get a "_wk[idx]" suffix.
            var_scope = policy_id + (("_wk" + str(self.worker_index))
                                     if self.worker_index else "")

            # For tf static graph, build every policy in its own graph
            # and create a new session for it.
            if framework == "tf":
                with tf1.Graph().as_default():
                    if self.session_creator:
                        sess = self.session_creator()
                    else:
                        sess = tf1.Session(
                            config=tf1.ConfigProto(
                                gpu_options=tf1.GPUOptions(allow_growth=True)))
                    with sess.as_default():
                        # Set graph-level seed.
                        if self.seed is not None:
                            tf1.set_random_seed(self.seed)
                        with tf1.variable_scope(var_scope):
                            self[policy_id] = class_(
                                observation_space, action_space, merged_config)
            # For tf-eager: no graph, no session.
            else:
                with tf1.variable_scope(var_scope):
                    self[policy_id] = \
                        class_(observation_space, action_space, merged_config)
        # Non-tf: No graph, no session. Always use the original class.
        else:
            class_ = policy_cls
            self[policy_id] = class_(observation_space, action_space,
                                     merged_config)

        # Store spec (class, obs-space, act-space, and config overrides) such
        # that the map will be able to reproduce on-the-fly added policies
        # from disk.
        self.policy_specs[policy_id] = PolicySpec(
            policy_class=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config_override)

    @override(dict)
    def __getitem__(self, item):
        """Returns the policy for `item`, restoring it from disk if needed.

        Raises:
            KeyError: If `item` is not a valid policy ID in this map.
        """
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.
        if item in self.cache:
            self.deque.remove(item)
            self.deque.append(item)
        # Item not currently in cache -> Get from disk and - if at capacity -
        # remove leftmost one.
        else:
            self._read_from_disk(policy_id=item)

        return self.cache[item]

    @override(dict)
    def __setitem__(self, key, value):
        """Stores `value` under `key` as the most recently used policy."""
        # Item already in cache -> Take it out of the deque; it is re-added
        # on the right (most recently used) below.
        if key in self.cache:
            self.deque.remove(key)
        # Item not currently in cache and cache at capacity -> Stash the
        # leftmost (least recently used) item to make room.
        elif len(self.deque) == self.deque.maxlen:
            self._stash_to_disk()

        self.deque.append(key)
        self.cache[key] = value
        self.valid_keys.add(key)

    @override(dict)
    def __delitem__(self, key):
        """Removes a policy from memory and deletes its file (if any).

        Raises:
            KeyError: If `key` is not a valid policy ID in this map.
        """
        # Make key invalid.
        self.valid_keys.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            # Also drop the key from the LRU queue. Otherwise, the stale
            # entry would keep occupying a slot and `_stash_to_disk` would
            # fail with a KeyError once it reaches the left end.
            self.deque.remove(key)
            policy = self.cache.pop(key)
            self._close_session(policy)
        # Remove file associated with the policy, if it exists.
        filename = self._policy_filename(key)
        if os.path.isfile(filename):
            os.remove(filename)

    @override(dict)
    def __iter__(self):
        """Iterates over all policy IDs, even the stashed-to-disk ones."""
        return self.keys()

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed-to-disk ones.

        Yields:
            (policy ID, policy) tuples. Stashed policies are read back into
            memory as they are reached.
        """
        for key in self.valid_keys:
            yield (key, self[key])

    @override(dict)
    def keys(self):
        """Iterates over all policy IDs, even the stashed-to-disk ones."""
        for key in self.valid_keys:
            yield key

    @override(dict)
    def values(self):
        """Iterates over all policy objects, even the stashed-to-disk ones.

        Stashed policies are read back into memory as they are reached.
        """
        for key in self.valid_keys:
            yield self[key]

    @override(dict)
    def update(self, __m=None, **kwargs):
        """Adds/overrides several policies at once.

        Args:
            __m: Optional mapping (anything providing `items()`) or iterable
                of (policy ID, policy) pairs.
            **kwargs: Additional policies, given as `policy_id=policy`.
        """
        if __m is not None:
            pairs = __m.items() if hasattr(__m, "items") else __m
            for k, v in pairs:
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @override(dict)
    def get(self, key, default=None):
        """Returns the policy for `key` or `default` if `key` is unknown."""
        if key not in self.valid_keys:
            return default
        return self[key]

    @override(dict)
    def __len__(self):
        """Returns number of all policies, including the stashed-to-disk ones.
        """
        return len(self.valid_keys)

    @override(dict)
    def __contains__(self, item):
        """Returns True if `item` is a valid policy ID (cached or stashed)."""
        return item in self.valid_keys

    def _policy_filename(self, policy_id):
        """Returns the file name under which `policy_id` is stashed."""
        return self.path + "/" + policy_id + self.extension

    def _stash_to_disk(self):
        """Writes the least-recently used policy to disk and rearranges cache.

        Also closes the session - if applicable - of the stashed policy.
        """
        # Get least recently used policy (all the way on the left in deque).
        # Only peek here: The policy is evicted after its state has been
        # written successfully, such that a failing write does not lose it.
        delkey = self.deque[0]
        policy = self.cache[delkey]
        # Get its state and write it to disk.
        policy_state = policy.get_state()
        with open(self._policy_filename(delkey), "wb") as f:
            pickle.dump(policy_state, file=f)
        # State is safely stored -> Evict from the LRU queue.
        self.deque.popleft()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[delkey]

    def _read_from_disk(self, policy_id):
        """Reads a policy ID from disk and re-adds it to the cache.
        """
        # Make sure this policy ID is not in the cache right now.
        assert policy_id not in self.cache
        # Read policy state from disk.
        # NOTE: Unpickling executes arbitrary code. `self.path` must only
        # contain files written by this map's own `_stash_to_disk`.
        with open(self._policy_filename(policy_id), "rb") as f:
            policy_state = pickle.load(f)

        # Get class, spaces, and config override.
        spec = self.policy_specs[policy_id]
        merged_conf = merge_dicts(self.policy_config, spec.config)

        # Create policy object (from its spec: cls, obs-space, act-space,
        # config). This re-adds the policy to the cache (and may stash
        # another one if the cache is at capacity).
        self.create_policy(
            policy_id,
            spec.policy_class,
            spec.observation_space,
            spec.action_space,
            spec.config,
            merged_conf,
        )
        # Restore policy's state.
        policy = self[policy_id]
        policy.set_state(policy_state)

    def _close_session(self, policy):
        """Closes the tf session of the given policy, if it has one."""
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()
