"""TensorFlow policy class used for Simple Q-Learning"""

import logging
from typing import List, Tuple, Type

import gym
import ray
from src.rllib.models import ModelCatalog
from src.rllib.models.modelv2 import ModelV2
from src.rllib.models.tf.tf_action_dist import (Categorical,
                                                TFActionDistribution)
from src.rllib.models.torch.torch_action_dist import TorchCategorical
from src.rllib.policy import Policy
from src.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.policy.tf_policy import TFPolicy
from src.rllib.policy.tf_policy_template import build_tf_policy
from src.rllib.utils.annotations import override
from src.rllib.utils.error import UnsupportedSpaceException
from src.rllib.utils.framework import try_import_tf
from src.rllib.utils.tf_ops import huber_loss, make_tf_callable
from src.rllib.utils.typing import TensorType, TrainerConfigDict

# Values returned by `try_import_tf()`; `tf` is the handle used for all
# TensorFlow ops built in this module.
# NOTE(review): presumably these are None when TensorFlow is unavailable --
# confirm against `try_import_tf`.
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)

# Names handed to `ModelCatalog.get_model_v2` in `build_q_models` for the
# Q-model and the target Q-model, respectively.
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"


class TargetNetworkMixin:
    """Mixin that attaches an `update_target` callable to SimpleQTFPolicy.

    `update_target` assigns every Q-network variable to its counterpart in
    the target Q-network. The function is called every
    `target_network_update_freq` steps by the master learner.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):
        """Create `self.update_target`, bound to the policy's session.

        Args:
            obs_space (gym.spaces.Space): Not used by this constructor.
            action_space (gym.spaces.Space): Not used by this constructor.
            config (TrainerConfigDict): Not used by this constructor.
        """

        @make_tf_callable(self.get_session())
        def do_update():
            # Copies the Q network onto the target Q network; relies on
            # `q_func_vars` / `target_q_func_vars` being set on the policy.
            online_vars = self.q_func_vars
            target_vars = self.target_q_func_vars
            assert len(online_vars) == len(target_vars), \
                (online_vars, target_vars)

            assign_ops = []
            for source_var, target_var in zip(online_vars, target_vars):
                assign_ops.append(target_var.assign(source_var))
                logger.debug("Update target op {}".format(target_var))
            # Bundle all per-variable assignments into one grouped op.
            return tf.group(*assign_ops)

        self.update_target = do_update

    @override(TFPolicy)
    def variables(self):
        """Return the Q-network variables followed by the target ones.

        Each list is fetched from its model the first time it is needed and
        cached on the instance as `q_func_vars` / `target_q_func_vars`.
        """
        if not hasattr(self, "q_func_vars"):
            self.q_func_vars = self.model.variables()
        if not hasattr(self, "target_q_func_vars"):
            self.target_q_func_vars = self.target_model.variables()
        all_vars = self.q_func_vars + self.target_q_func_vars
        return all_vars


def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
                   action_space: gym.spaces.Space,
                   config: TrainerConfigDict) -> ModelV2:
    """Create the Q-model and the target Q-model for Simple Q learning.

    Both models come from `ModelCatalog.get_model_v2` with identical
    settings and differ only in their name (`Q_SCOPE` vs. `Q_TARGET_SCOPE`).
    The framework is taken from `config["framework"]`, so this function
    works for both Tensorflow and PyTorch.

    Args:
        policy (Policy): The Policy, which will use the model for
            optimization. The target model is stored on it as
            `policy.target_model`.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space. Must be
            a `gym.spaces.Discrete`.
        config (TrainerConfigDict): The config; its "model" and "framework"
            entries are read here.

    Returns:
        ModelV2: The Q-model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.

    Raises:
        UnsupportedSpaceException: If `action_space` is not Discrete.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    # Everything except the name is shared between the two networks.
    shared_kwargs = dict(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=config["model"],
        framework=config["framework"])

    q_model = ModelCatalog.get_model_v2(name=Q_SCOPE, **shared_kwargs)
    policy.target_model = ModelCatalog.get_model_v2(
        name=Q_TARGET_SCOPE, **shared_kwargs)

    return q_model


def get_distribution_inputs_and_class(
        policy: Policy,
        q_model: ModelV2,
        obs_batch: TensorType,
        *,
        explore=True,
        is_training=True,
        **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
    """Compute Q-values and pick the action distribution class.

    The Q-values are also stored on the policy as `policy.q_values`.

    Args:
        policy (Policy): The Policy; receives the `q_values` attribute and
            provides `config["framework"]`.
        q_model (ModelV2): The model used to compute the Q-values.
        obs_batch (TensorType): The observations to compute Q-values for.
        explore: Forwarded to `compute_q_values`.
        is_training: Forwarded to `compute_q_values`.
        **kwargs: Accepted but ignored.

    Returns:
        Tuple: The Q-values (distribution inputs), the distribution class
            (`TorchCategorical` for the "torch" framework, `Categorical`
            otherwise) and an empty list of state-outs.
    """
    q_out = compute_q_values(policy, q_model, obs_batch, explore, is_training)
    # Keep only the first element if a tuple came back.
    if isinstance(q_out, tuple):
        q_out = q_out[0]
    policy.q_values = q_out

    if policy.config["framework"] == "torch":
        dist_class = TorchCategorical
    else:
        dist_class = Categorical

    state_outs = []
    return policy.q_values, dist_class, state_outs


def build_q_losses(policy: Policy, model: ModelV2,
                   dist_class: Type[TFActionDistribution],
                   train_batch: SampleBatch) -> TensorType:
    """Build the one-step TD loss for SimpleQTFPolicy.

    The target is `reward + gamma * (1 - done) * max_a Q_target(next_obs, a)`
    and the loss is the mean Huber loss of the TD error. As side effects,
    the TD error is stored as `policy.td_error` and, if not present yet,
    `policy.q_func_vars` / `policy.target_q_func_vars` are set.

    Args:
        policy (Policy): The Policy to calculate the loss for. Its `model`
            and `target_model` are used for the forward passes.
        model (ModelV2): The Model to calculate the loss for. Only used here
            to obtain the Q-network variables.
        dist_class (Type[ActionDistribution]): The action distribution class.
            Not used by this loss.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    # Online Q network evaluated on the current observations.
    q_values_t = compute_q_values(
        policy, policy.model, train_batch[SampleBatch.CUR_OBS], explore=False)

    # Target Q network evaluated on the next observations.
    q_values_tp1 = compute_q_values(
        policy,
        policy.target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False)

    # Record both variable lists on the policy the first time through.
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = model.variables()
        policy.target_q_func_vars = policy.target_model.variables()

    num_actions = policy.action_space.n

    # Q-value of the action that was actually taken in each state.
    taken_actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32)
    taken_mask = tf.one_hot(taken_actions, num_actions)
    q_taken = tf.reduce_sum(q_values_t * taken_mask, 1)

    # Best target Q-value at t + 1, zeroed where the episode is done.
    done_mask = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    greedy_actions = tf.argmax(q_values_tp1, 1)
    greedy_mask = tf.one_hot(greedy_actions, num_actions)
    best_next_q = tf.reduce_sum(q_values_tp1 * greedy_mask, 1)
    best_next_q_masked = (1.0 - done_mask) * best_next_q

    # Right-hand side of the Bellman equation.
    rewards = train_batch[SampleBatch.REWARDS]
    discounted_next_q = policy.config["gamma"] * best_next_q_masked
    bellman_target = rewards + discounted_next_q

    # TD error; no gradient flows into the target. Huber loss clips it.
    td_error = q_taken - tf.stop_gradient(bellman_target)
    loss = tf.reduce_mean(huber_loss(td_error))

    # Expose the TD error on the policy for outside access.
    policy.td_error = td_error

    return loss


def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training=None) -> TensorType:
    """Run `model` on `obs` and return its output.

    Args:
        policy (Policy): The Policy; only consulted for its is-training
            placeholder when `is_training` is None.
        model (ModelV2): The model to call.
        obs (TensorType): The observations, passed under
            `SampleBatch.CUR_OBS`.
        explore: Accepted but not used by this function.
        is_training: Value for the "is_training" input. If None,
            `policy._get_is_training_placeholder()` is used instead.

    Returns:
        TensorType: The first element of the pair returned by the model
            call; the second element is discarded.
    """
    if is_training is None:
        is_training = policy._get_is_training_placeholder()

    input_dict = {
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }
    model_out, _ = model(input_dict, [], None)

    return model_out


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call the mixin classes' constructors on an existing policy object.

    This function is registered as the `after_init` hook of
    `SimpleQTFPolicy`, so it presumably runs once the policy's own
    initialization has finished (TODO confirm against `build_tf_policy`).
    Currently the only mixin set up here is `TargetNetworkMixin`, whose
    constructor attaches `update_target` to the policy.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SimpleQTFPolicy: Type[DynamicTFPolicy] = build_tf_policy(
    name="SimpleQTFPolicy",
    # The lambda defers the config lookup until the hook is called.
    # NOTE(review): this resolves through the `ray` package, while the rest
    # of this module imports from `src.rllib` -- confirm this is the intended
    # config source.
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    # Returns the Q-model and attaches the target model to the policy.
    make_model=build_q_models,
    # Computes the Q-values and stores them as `policy.q_values`.
    action_distribution_fn=get_distribution_inputs_and_class,
    # Builds the TD loss and stores the TD error as `policy.td_error`.
    loss_fn=build_q_losses,
    # Expose the attributes set by the two functions above.
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    # Runs `TargetNetworkMixin.__init__`, which creates `update_target`.
    after_init=setup_late_mixins,
    mixins=[TargetNetworkMixin])
