import ray
from src.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep
from src.rllib.evaluation.metrics import LEARNER_STATS_KEY
from src.rllib.models import ModelCatalog
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.annotations import override
from src.rllib.utils.error import UnsupportedSpaceException
from src.rllib.policy.policy import Policy
from src.rllib.policy.tf_policy import TFPolicy
from src.rllib.utils.framework import try_import_tf, try_import_tfp

import logging
from gym.spaces import Box, Discrete
import numpy as np

logger = logging.getLogger(__name__)

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()


class MADDPGPostprocessing:
    """Mixin providing per-agent done flags and n-step trajectory rewriting.

    The host class is expected to supply ``self.config`` (with ``n_step`` and
    ``gamma`` entries) and a ``self.get_done_from_info`` callable.
    """

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Replace the batch's dones from its infos, then apply n-step.

        Args:
            sample_batch: Batch to post-process; its ``DONES`` column is
                overwritten.
            other_agent_batches: Unused.
            episode: Unused.

        Returns:
            The same ``sample_batch`` object that was passed in.
        """
        # FIXME: The done flags have to be recovered from the info dicts
        #  because agentwise done is not supported now.
        infos = sample_batch[SampleBatch.INFOS]
        sample_batch[SampleBatch.DONES] = self.get_done_from_info(infos)

        n_step = self.config["n_step"]
        if not n_step > 1:
            # Single-step learning: nothing further to adjust.
            return sample_batch

        # N-step Q adjustments. The helper's return value is ignored, so it
        # presumably rewrites the passed columns in place -- TODO confirm.
        gamma = self.config["gamma"]
        column_keys = (
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.REWARDS,
            SampleBatch.NEXT_OBS,
            SampleBatch.DONES,
        )
        columns = [sample_batch[key] for key in column_keys]
        _adjust_nstep(n_step, gamma, *columns)

        return sample_batch


class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
    def __init__(self, obs_space, act_space, config):
        """Build the MADDPG graph for one agent and initialise the TFPolicy.

        Creates placeholders for every policy listed in
        ``config["multiagent"]["policies"]``, the critic / target-critic and
        actor / target-actor networks, both losses, one Adam optimizer per
        network, the target-update ops and the weight-assignment ops, and
        finally calls ``TFPolicy.__init__`` and initialises all variables.

        Args:
            obs_space: Observation space forwarded to ``TFPolicy.__init__``.
            act_space: Action space forwarded to ``TFPolicy.__init__``.
            config: Overrides merged on top of the MADDPG ``DEFAULT_CONFIG``.
                ``agent_id`` must be an ``int``; it indexes the per-policy
                lists built from ``config["multiagent"]["policies"]``.

        Raises:
            ValueError: If ``agent_id`` is ``None`` or not an ``int``.
            UnsupportedSpaceException: If a policy space is neither ``Box``
                nor ``Discrete``.
            AssertionError: If no default TensorFlow session is active.
        """
        # _____ Initial Configuration
        # The defaults are imported explicitly: no module-level statement
        # binds the bare name `src`, so an attribute lookup through it would
        # raise NameError.
        # NOTE(review): kept at function scope on the assumption that the
        # maddpg package itself imports this module -- confirm.
        from src.rllib.contrib.maddpg import DEFAULT_CONFIG

        config = dict(DEFAULT_CONFIG, **config)
        self.config = config
        self.global_step = tf1.train.get_or_create_global_step()

        # FIXME: Get done from info is required since agentwise done is not
        # supported now.
        self.get_done_from_info = np.vectorize(
            lambda info: info.get("done", False))

        agent_id = config["agent_id"]
        if agent_id is None:
            raise ValueError("Must set `agent_id` in the policy config.")
        if type(agent_id) is not int:
            raise ValueError("Agent ids must be integers for MADDPG.")

        # _____ Environment Setting
        def _make_continuous_space(space):
            # Discrete(n) is mapped to a Box of shape (n,) bounded by [0, 1].
            if isinstance(space, Box):
                return space
            elif isinstance(space, Discrete):
                return Box(
                    low=np.zeros((space.n, )), high=np.ones((space.n, )))
            else:
                raise UnsupportedSpaceException(
                    "Space {} is not supported.".format(space))

        # Each policy spec is unpacked as a 4-tuple whose 2nd and 3rd entries
        # are the observation and action space respectively.
        policy_specs = config["multiagent"]["policies"]
        obs_space_n = [
            _make_continuous_space(space)
            for (_, space, _, _) in policy_specs.values()
        ]
        act_space_n = [
            _make_continuous_space(space)
            for (_, _, space, _) in policy_specs.values()
        ]

        # _____ Placeholders
        # Placeholders for policy evaluation and updates
        def _make_ph_n(space_n, name=""):
            # One float32 placeholder per space, with a leading batch axis.
            return [
                tf1.placeholder(
                    tf.float32,
                    shape=(None, ) + space.shape,
                    name=name + "_%d" % i) for i, space in enumerate(space_n)
            ]

        obs_ph_n = _make_ph_n(obs_space_n, SampleBatch.OBS)
        act_ph_n = _make_ph_n(act_space_n, SampleBatch.ACTIONS)
        new_obs_ph_n = _make_ph_n(obs_space_n, SampleBatch.NEXT_OBS)
        new_act_ph_n = _make_ph_n(act_space_n, "new_actions")
        rew_ph = tf1.placeholder(
            tf.float32, shape=None, name="rewards_{}".format(agent_id))
        done_ph = tf1.placeholder(
            tf.float32, shape=None, name="dones_{}".format(agent_id))

        if config["use_local_critic"]:
            # A local critic only sees this agent's own data, so every
            # per-agent list is cut down to a single entry and the agent's
            # index into those lists becomes 0.
            obs_space_n, act_space_n = [obs_space_n[agent_id]], [
                act_space_n[agent_id]
            ]
            obs_ph_n, act_ph_n = [obs_ph_n[agent_id]], [act_ph_n[agent_id]]
            new_obs_ph_n, new_act_ph_n = [new_obs_ph_n[agent_id]], [
                new_act_ph_n[agent_id]
            ]
            agent_id = 0

        # _____ Value Network
        # Build critic network for t.
        critic, _, _, critic_vars = self._build_critic_network(
            obs_ph_n,
            act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")

        # Build critic network for t + 1.
        target_critic, _, _, target_critic_vars = self._build_critic_network(
            new_obs_ph_n,
            new_act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="target_critic")

        # Build critic loss. The bootstrapped target is wrapped in
        # stop_gradient so that only the online critic receives gradients.
        td_error = tf.subtract(
            tf.stop_gradient(
                rew_ph + (1.0 - done_ph) *
                (config["gamma"]**config["n_step"]) * target_critic[:, 0]),
            critic[:, 0])
        critic_loss = tf.reduce_mean(td_error**2)

        # _____ Policy Network
        # Build actor network for t.
        act_sampler, actor_feature, _, actor_vars = (
            self._build_actor_network(
                obs_ph_n[agent_id],
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="actor"))

        # Build actor network for t + 1.
        self.new_obs_ph = new_obs_ph_n[agent_id]
        self.target_act_sampler, _, _, target_actor_vars = (
            self._build_actor_network(
                self.new_obs_ph,
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="target_actor"))

        # Build actor loss. The critic is rebuilt in the same "critic" scope
        # (opened with AUTO_REUSE by the builder) with this agent's action
        # placeholder swapped for the actor's sampled action.
        act_n = act_ph_n.copy()
        act_n[agent_id] = act_sampler
        critic_on_policy, _, _, _ = self._build_critic_network(
            obs_ph_n,
            act_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")
        actor_loss = -tf.reduce_mean(critic_on_policy)
        if config["actor_feature_reg"] is not None:
            actor_loss += config["actor_feature_reg"] * tf.reduce_mean(
                actor_feature**2)

        # _____ Losses
        self.losses = {"critic": critic_loss, "actor": actor_loss}

        # _____ Optimizers
        self.optimizers = {
            "critic": tf1.train.AdamOptimizer(config["critic_lr"]),
            "actor": tf1.train.AdamOptimizer(config["actor_lr"])
        }

        # _____ Build variable update ops.
        self.tau = tf1.placeholder_with_default(
            config["tau"], shape=(), name="tau")

        def _make_target_update_op(vs, target_vs, tau):
            # target <- tau * online + (1 - tau) * target, pairwise.
            return [
                target_v.assign(tau * v + (1.0 - tau) * target_v)
                for v, target_v in zip(vs, target_vs)
            ]

        self.update_target_vars = _make_target_update_op(
            critic_vars + actor_vars, target_critic_vars + target_actor_vars,
            self.tau)

        def _make_set_weight_op(variables):
            # Flatten the dict's variable lists in iteration order and create
            # one matching placeholder + assign op per variable.
            vs = list()
            for v in variables.values():
                vs += v
            phs = [
                tf1.placeholder(
                    tf.float32,
                    shape=v.get_shape(),
                    name=v.name.split(":")[0] + "_ph") for v in vs
            ]
            return tf.group(*[v.assign(ph) for v, ph in zip(vs, phs)]), phs

        self.vars = {
            "critic": critic_vars,
            "actor": actor_vars,
            "target_critic": target_critic_vars,
            "target_actor": target_actor_vars
        }
        self.update_vars, self.vars_ph = _make_set_weight_op(self.vars)

        # _____ TensorFlow Initialization

        sess = tf1.get_default_session()
        assert sess, "MADDPGTFPolicy requires a default TensorFlow session."

        def _make_loss_inputs(placeholders):
            # Key each placeholder by its op name with any scope prefix and
            # the ":<index>" suffix stripped.
            return [(ph.name.split("/")[-1].split(":")[0], ph)
                    for ph in placeholders]

        loss_inputs = _make_loss_inputs(obs_ph_n + act_ph_n + new_obs_ph_n +
                                        new_act_ph_n + [rew_ph, done_ph])

        TFPolicy.__init__(
            self,
            obs_space,
            act_space,
            config=config,
            sess=sess,
            obs_input=obs_ph_n[agent_id],
            sampled_action=act_sampler,
            loss=actor_loss + critic_loss,
            loss_inputs=loss_inputs,
            dist_inputs=actor_feature)

        # None of the loss inputs above are previous actions or rewards.
        # NOTE(review): presumably that is why these entries are dropped --
        # confirm against TFPolicy's default view requirements.
        del self.view_requirements["prev_actions"]
        del self.view_requirements["prev_rewards"]

        self.get_session().run(tf1.global_variables_initializer())

        # Hard initial update: with tau=1.0 the update op copies the online
        # variables straight into the targets.
        self.update_target(1.0)

    @override(TFPolicy)
    def optimizer(self):
        return None

    @override(TFPolicy)
    def gradients(self, optimizer, loss):
        """Compute clipped gradients for the critic and the actor.

        The ``optimizer`` and ``loss`` arguments are not used: each entry of
        ``self.optimizers`` is paired with the loss and variable list stored
        under the same key in ``self.losses`` / ``self.vars``.

        Args:
            optimizer: Ignored.
            loss: Ignored.

        Returns:
            The critic's ``minimize_and_clip`` result followed by the
            actor's, concatenated with ``+``.
        """
        clip_norm_key = "grad_norm_clipping"
        grads_by_net = {}
        for net_name, net_optimizer in self.optimizers.items():
            grads_by_net[net_name] = minimize_and_clip(
                net_optimizer, self.losses[net_name], self.vars[net_name],
                self.config[clip_norm_key])
        # Stored on the instance so build_apply_op can apply each network's
        # gradients with its own optimizer.
        self.gvs = grads_by_net
        return self.gvs["critic"] + self.gvs["actor"]

    @override(TFPolicy)
    def build_apply_op(self, optimizer, grads_and_vars):
        """Create the training op that applies critic then actor gradients.

        The arguments are not used; the gradients cached in ``self.gvs`` are
        applied with the per-network optimizers in ``self.optimizers``.

        Args:
            optimizer: Ignored.
            grads_and_vars: Ignored.

        Returns:
            The actor's apply-gradients op, which carries control
            dependencies on the global-step increment and on the critic's
            apply-gradients op.
        """
        critic_step = self.optimizers["critic"].apply_gradients(
            self.gvs["critic"])

        # Two stacked dependency contexts: the actor update is created under
        # both, so running it also increments the global step and applies
        # the critic update.
        with tf1.control_dependencies(
                [tf1.assign_add(self.global_step, 1)]), \
                tf1.control_dependencies([critic_step]):
            actor_step = self.optimizers["actor"].apply_gradients(
                self.gvs["actor"])

        return actor_step

    @override(TFPolicy)
    def extra_compute_action_feed_dict(self):
        return {}

    @override(TFPolicy)
    def extra_compute_grad_fetches(self):
        """Return the extra fetches for gradient computation.

        Returns:
            A dict with a single ``LEARNER_STATS_KEY`` entry mapping to an
            empty dict.
        """
        learner_stats = {}
        return {LEARNER_STATS_KEY: learner_stats}

    @override(TFPolicy)
    def get_weights(self):
        var_list = []
        for var in self.vars.values():
            var_list += var
        return {"_state": self.get_session().run(var_list)}

    @override(TFPolicy)
    def set_weights(self, weights):
        self.get_session().run(
            self.update_vars,
            feed_dict=dict(zip(self.vars_ph, weights["_state"])))

    @override(Policy)
    def get_state(self):
        """Return the policy state by delegating to ``TFPolicy.get_state``.

        Returns:
            Whatever ``TFPolicy.get_state`` returns for this instance.
        """
        state = TFPolicy.get_state(self)
        return state

    @override(Policy)
    def set_state(self, state):
        """Restore the policy state by delegating to ``TFPolicy.set_state``.

        Args:
            state: State object forwarded unchanged to the base class.
        """
        restore = TFPolicy.set_state
        restore(self, state)

    def _build_critic_network(self,
                              obs_n,
                              act_n,
                              obs_space_n,
                              act_space_n,
                              use_state_preprocessor,
                              hiddens,
                              activation=None,
                              scope=None):
        """Build a critic: a dense stack over concatenated obs and actions.

        The variable scope is opened with ``AUTO_REUSE``, so building twice
        under the same ``scope`` name reuses the existing variables.

        Args:
            obs_n: List of observation tensors, one per agent.
            act_n: List of action tensors, one per agent.
            obs_space_n: Observation spaces, zipped with ``obs_n`` when the
                state preprocessor is used.
            act_space_n: Action spaces, zipped likewise.
            use_state_preprocessor: If true, each observation is first passed
                through a ``ModelCatalog`` model and its ``last_layer`` is
                used in place of the raw observation.
            hiddens: Iterable of hidden layer widths.
            activation: Activation applied to every hidden layer.
            scope: Name of the variable scope to build in.

        Returns:
            Tuple ``(out, feature, model_n, variables)``: the 1-unit output,
            the last hidden activation, the list of preprocessor models (or
            ``None`` entries), and the global variables under the scope.
        """
        with tf1.variable_scope(scope, reuse=tf1.AUTO_REUSE) as var_scope:
            if not use_state_preprocessor:
                # Raw observations feed the critic directly.
                model_n = [None] * len(obs_n)
                critic_inputs = obs_n
            else:
                model_n = []
                for obs, obs_space, act_space in zip(obs_n, obs_space_n,
                                                     act_space_n):
                    input_dict = {
                        SampleBatch.OBS: obs,
                        "is_training": self._get_is_training_placeholder(),
                    }
                    model_n.append(
                        ModelCatalog.get_model(input_dict, obs_space,
                                               act_space, 1,
                                               self.config["model"]))
                critic_inputs = [model.last_layer for model in model_n]

            # Observations (or their features) first, then actions, joined
            # along axis 1.
            net = tf.concat(critic_inputs + act_n, axis=1)

            for width in hiddens:
                net = tf1.layers.dense(
                    net, units=width, activation=activation)
            feature = net
            q_out = tf1.layers.dense(feature, units=1, activation=None)

        scope_vars = tf1.global_variables(var_scope.name)
        return q_out, feature, model_n, scope_vars

    def _build_actor_network(self,
                             obs,
                             obs_space,
                             act_space,
                             use_state_preprocessor,
                             hiddens,
                             activation=None,
                             scope=None):
        """Build an actor: a dense stack producing logits and a sample.

        The variable scope is opened with ``AUTO_REUSE``, so building twice
        under the same ``scope`` name reuses the existing variables.

        Args:
            obs: Observation tensor for this agent.
            obs_space: Observation space, passed to the preprocessor model.
            act_space: Action space; ``act_space.shape[0]`` sets the number
                of output units.
            use_state_preprocessor: If true, ``obs`` is first passed through
                a ``ModelCatalog`` model and its ``last_layer`` is used.
            hiddens: Iterable of hidden layer widths.
            activation: Activation applied to every hidden layer.
            scope: Name of the variable scope to build in.

        Returns:
            Tuple ``(sampler, feature, model, variables)``: a sample from a
            ``RelaxedOneHotCategorical`` (temperature 1.0) over the logits,
            the logits themselves, the preprocessor model (or ``None``), and
            the global variables under the scope.
        """
        with tf1.variable_scope(scope, reuse=tf1.AUTO_REUSE) as var_scope:
            if not use_state_preprocessor:
                model = None
                net = obs
            else:
                input_dict = {
                    SampleBatch.OBS: obs,
                    "is_training": self._get_is_training_placeholder(),
                }
                model = ModelCatalog.get_model(input_dict, obs_space,
                                               act_space, 1,
                                               self.config["model"])
                net = model.last_layer

            for width in hiddens:
                net = tf1.layers.dense(
                    net, units=width, activation=activation)

            # Linear output layer: one logit per action dimension.
            feature = tf1.layers.dense(
                net, units=act_space.shape[0], activation=None)
            action_dist = tfp.distributions.RelaxedOneHotCategorical(
                temperature=1.0, logits=feature)
            sampler = action_dist.sample()

        scope_vars = tf1.global_variables(var_scope.name)
        return sampler, feature, model, scope_vars

    def update_target(self, tau=None):
        if tau is not None:
            self.get_session().run(self.update_target_vars, {self.tau: tau})
        else:
            self.get_session().run(self.update_target_vars)
