"""TensorFlow policy class used for R2D2."""

from typing import Dict, List, Optional, Tuple

import gym
import ray
from src.rllib.agents.dqn.dqn_tf_policy import clip_gradients, \
    compute_q_values, PRIO_WEIGHTS, postprocess_nstep_and_prio
from src.rllib.agents.dqn.dqn_tf_policy import build_q_model
from src.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
from src.rllib.models.action_dist import ActionDistribution
from src.rllib.models.modelv2 import ModelV2
from src.rllib.models.tf.tf_action_dist import Categorical
from src.rllib.models.torch.torch_action_dist import TorchCategorical
from src.rllib.policy.policy import Policy
from src.rllib.policy.tf_policy_template import build_tf_policy
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.policy.tf_policy import LearningRateSchedule
from src.rllib.utils.framework import try_import_tf
from src.rllib.utils.tf_ops import huber_loss
from src.rllib.utils.typing import ModelInputDict, TensorType, \
    TrainerConfigDict

# `try_import_tf()` yields three objects that are bound module-wide here.
# NOTE(review): presumably `tf1` is the TF1-compat API (used below for
# `tf1.train.AdamOptimizer`), `tf` the TensorFlow module used for all ops in
# this file, and `tfv` the detected TF major version - confirm against
# `try_import_tf`.
tf1, tf, tfv = try_import_tf()


def build_r2d2_model(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space, config: TrainerConfigDict
                     ) -> ModelV2:
    """Build the Q-model for R2D2 and verify that it is a recurrent model.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The trainer config; forwarded unchanged
            to `build_q_model`.

    Returns:
        ModelV2: The model object returned by `build_q_model`.
            Note: The target Q-model is not returned. NOTE(review): it is
            presumably assigned to `policy.target_model` inside
            `build_q_model` (`r2d2_loss` reads that attribute) - confirm.

    Raises:
        AssertionError: If the model's `get_initial_state()` returns an
            empty list, i.e. the model carries no internal state.
    """
    # Model construction is fully delegated to the DQN helper.
    model = build_q_model(policy, obs_space, action_space, config)

    # R2D2 needs internal state tensors to unroll over time; an empty
    # initial-state list is taken to mean "not recurrent".
    # NOTE: kept as `assert` so the raised exception type stays the same;
    # be aware that the check disappears when Python runs with `-O`.
    assert model.get_initial_state() != [], \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return model


def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TFPolicy.

    Besides returning the loss, this function stores `_total_loss`,
    `_td_error` and `_loss_stats` on `policy` (and `target_q_func_vars`, if
    the policy does not have that attribute yet).

    Args:
        policy (Policy): The Policy to calculate the loss for. Its `config`,
            `target_model` and `action_space` are read here.
        model (ModelV2): The Model to calculate the loss for.
        _: Unused placeholder for the third positional argument (the name is
            re-bound below when unpacking `compute_q_values` results).
        train_batch (SampleBatch): The training data. Must contain at least
            a `state_in_0` entry, plus `seq_lens`, actions, dones, rewards
            and `PRIO_WEIGHTS`.

    Returns:
        TensorType: A single loss tensor.

    Raises:
        AssertionError: If `train_batch` contains no `state_in_0` entry.
        ValueError: If `config["num_atoms"] > 1` (distributional Q is not
            supported).
    """
    config = policy.config

    # Construct internal state inputs: collect `state_in_0`, `state_in_1`,
    # ... until the first missing index.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    # Without any state input there is nothing recurrent to unroll.
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    # Target Q-network evaluation on the same batch and states; the shift
    # to t+1 is applied further below via the `[1:]` slice.
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    # Remember the target model's variables once (first call only).
    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    rewards = train_batch[SampleBatch.REWARDS]
    weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)

    # B: size of the leading dim of the first state input; T: number of rows
    # in `q` divided by B. The reshapes to [B, T] below rely on `q` having
    # exactly B * T rows.
    # NOTE(review): presumably B = number of sequences and T = padded max
    # sequence length - confirm against the batch layout.
    B = tf.shape(state_batches[0])[0]
    T = tf.shape(q)[0] // B

    # Q scores for actions which we know were selected in the given state.
    # Entries that are not greater than `tf.float32.min` are replaced by 0
    # before the one-hot multiply (presumably these mark masked-out
    # actions - confirm).
    one_hot_selection = tf.one_hot(actions, policy.action_space.n)
    q_selected = tf.reduce_sum(
        tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection,
        axis=1)

    # Double-Q: pick the greedy action with the online net, but evaluate it
    # with the target net. Otherwise both come from the target net.
    if config["double_q"]:
        best_actions = tf.argmax(q, axis=1)
    else:
        best_actions = tf.argmax(q_target, axis=1)

    best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
    q_target_best = tf.reduce_sum(
        tf.where(q_target > tf.float32.min, q_target, tf.zeros_like(q_target))
        * best_actions_one_hot,
        axis=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        # Shift the target values by one row (t -> t+1) along the flat axis,
        # padding the very end with 0.0, and zero them where `dones` is set.
        # The last time slot of every sequence then holds a value taken from
        # the following sequence; those slots are dropped below by the
        # `[:, :-1]` slices.
        q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
            [q_target_best[1:], tf.constant([0.0])], axis=0)

        if config["use_h_function"]:
            # targets = h(r + gamma^n_step * h_inverse(Q^)).
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            # Plain n-step target: r + gamma^n_step * Q^(t+1).
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms (last time slot is cut off to
        # match the `[:, :-1]` slices below).
        seq_mask = tf.sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        # Making sure, this works for both static graph and eager: for
        # framework "tf" the `or` short-circuits, so `burn_in < T` is only
        # evaluated as a Python condition for other framework settings.
        if burn_in > 0 and (config["framework"] == "tf" or burn_in < T):
            seq_mask = tf.concat(
                [tf.fill([B, burn_in], False), seq_mask[:, burn_in:]], axis=1)

        def reduce_mean_valid(t):
            # Mean over only those entries where `seq_mask` is True.
            return tf.reduce_mean(tf.boolean_mask(t, seq_mask))

        # Make sure use the correct time indices:
        # Q(t) - [r + gamma^n_step * Q^(t+1)]
        q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
        # No gradient flows into the target.
        td_error = q_selected - tf.stop_gradient(
            tf.reshape(target, [B, T])[:, :-1])
        # Zero the TD-error at masked-out positions.
        td_error = td_error * tf.cast(seq_mask, tf.float32)
        weights = tf.reshape(weights, [B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        # Flattened (still zero-filled at masked positions) TD-errors.
        policy._td_error = tf.reshape(td_error, [-1])
        # NOTE(review): `min_q` / `max_q` are reduced over all of
        # `q_selected`, including positions excluded by `seq_mask`, while the
        # two means only use valid positions - confirm this is intended.
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": tf.reduce_min(q_selected),
            "max_q": tf.reduce_max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss


def h_function(x, epsilon=1.0):
    """Squashing function h used to rescale target Q-values (paper [1]).

    h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x

    In [1] it is paired with `h_inverse`:
      targets = h(r + gamma * h_inverse(Q^))

    Args:
        x: The value(s) to transform (anything accepted by the `tf` ops
            used here).
        epsilon: Weight of the linear term `epsilon * x`.

    Returns:
        h(x), computed element-wise.
    """
    # sqrt(|x| + 1) - 1: the sign-free, square-root-compressed magnitude.
    compressed_magnitude = tf.sqrt(tf.abs(x) + 1.0) - 1.0
    # Re-attach the sign of the input.
    signed_part = tf.sign(x) * compressed_magnitude
    # The linear term is what makes h invertible in closed form.
    linear_part = epsilon * x
    return signed_part + linear_part


def h_inverse(x, epsilon=1.0):
    """Inverse of the h-function h(x) defined in this module (paper [1]).

    With c = 2eps + 1:

    If x >= 0.0:
    h-1(x) = [2eps * x + c - sqrt(4eps * x + c^2)] / (2 * eps^2)

    If x < 0.0:
    h-1(x) = [2eps * x - c + sqrt(-4eps * x + c^2)] / (2 * eps^2)

    (The x < 0 case is the odd reflection of the x >= 0 case:
    h-1(x) = -h-1(-x).)

    Args:
        x: The value(s) to invert (anything accepted by the `tf` ops used
            here).
        epsilon: The same epsilon that was used for the h-function. Must be
            non-zero, because the result is divided by `2 * epsilon^2`.

    Returns:
        h-1(x), computed element-wise.
    """
    two_epsilon = epsilon * 2
    # Sub-expressions shared by both branches.
    offset = two_epsilon + 1.0
    offset_squared = offset**2
    denominator = 2.0 * epsilon**2
    linear_term = two_epsilon * x

    # Both branches are always computed; `tf.where` then selects per
    # element. The sqrt argument of the branch that is NOT selected can be
    # negative (yielding NaN there), which does not affect the selected
    # values.
    if_x_pos = (linear_term + offset -
                tf.sqrt(4.0 * epsilon * x + offset_squared)) / denominator
    if_x_neg = (linear_term - offset +
                tf.sqrt(-4.0 * epsilon * x + offset_squared)) / denominator
    return tf.where(x < 0.0, if_x_neg, if_x_pos)


class ComputeTDErrorMixin:
    """Mixin that attaches a `compute_td_error` callable to the policy.

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            # Start from a lazy tensor dict holding only the observations,
            # then add the remaining columns one item at a time.
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            extra_columns = (
                (SampleBatch.ACTIONS, act_t),
                (SampleBatch.REWARDS, rew_t),
                (SampleBatch.NEXT_OBS, obs_tp1),
                (SampleBatch.DONES, done_mask),
                (PRIO_WEIGHTS, importance_weights),
            )
            for column, value in extra_columns:
                input_dict[column] = value

            # Running the loss has the side effect of (re-)setting
            # `self._td_error`; the returned loss itself is not needed.
            # NOTE(review): `r2d2_loss` asserts that the batch contains at
            # least a `state_in_0` entry and reads `seq_lens`; neither is
            # set here - confirm how this path is expected to be called.
            r2d2_loss(self, self.model, None, input_dict)

            return self._td_error

        # Bound per instance (as an attribute, not a class-level method).
        self.compute_td_error = compute_td_error


def get_distribution_inputs_and_class(
        policy: Policy,
        model: ModelV2,
        *,
        input_dict: ModelInputDict,
        state_batches: Optional[List[TensorType]] = None,
        seq_lens: Optional[TensorType] = None,
        explore: bool = True,
        is_training: bool = False,
        **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
    """Compute Q-values and pick the matching categorical action dist class.

    Side effects: stores the Q-values in `policy.q_values` and, if not yet
    present, the model's variables in `policy.q_func_vars`.

    Args:
        policy: The policy; its `config["framework"]` selects the torch or
            the TF code path.
        model: The model handed to the Q-value function.
        input_dict: The model input, forwarded to the Q-value function.
        state_batches: Optional internal state inputs.
        seq_lens: Optional sequence lengths.
        explore: Forwarded to the Q-value function.
        is_training: Forwarded to the Q-value function.
        **kwargs: Accepted, but not used here.

    Returns:
        Tuple of (Q-values, action distribution class, state outputs).
    """
    use_torch = policy.config["framework"] == "torch"

    if use_torch:
        # Imported lazily, only when the torch path is actually taken.
        from src.rllib.agents.dqn.r2d2_torch_policy import \
            compute_q_values as torch_compute_q_values
        q_value_fn = torch_compute_q_values
    else:
        q_value_fn = compute_q_values

    # Only the first and the last of the four results are needed here.
    q_vals, _, _, state_out = q_value_fn(
        policy, model, input_dict, state_batches, seq_lens, explore,
        is_training)

    policy.q_values = q_vals
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = model.variables()

    if use_torch:
        action_dist_class = TorchCategorical
    else:
        action_dist_class = Categorical

    return policy.q_values, action_dist_class, state_out


def adam_optimizer(policy: Policy, config: TrainerConfigDict
                   ) -> "tf.keras.optimizers.Optimizer":
    """Create the Adam optimizer used by this policy.

    Args:
        policy: The policy; its `cur_lr` is used as the learning rate.
        config: The trainer config; `config["adam_epsilon"]` is used as
            Adam's epsilon.

    Returns:
        A `tf1.train.AdamOptimizer` instance. NOTE(review): the return
        annotation names `tf.keras.optimizers.Optimizer`, which is not the
        class instantiated here - confirm which one is intended.
    """
    optimizer_cls = tf1.train.AdamOptimizer
    learning_rate = policy.cur_lr
    adam_epsilon = config["adam_epsilon"]
    return optimizer_cls(learning_rate=learning_rate, epsilon=adam_epsilon)


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Return the stats dict: current learning rate plus the loss stats.

    Args:
        policy: The policy; `cur_lr` and `_loss_stats` are read from it.
        batch: Not used.

    Returns:
        A new dict with a "cur_lr" entry followed by all entries of
        `policy._loss_stats` (which win on a key clash).
    """
    return {"cur_lr": policy.cur_lr, **policy._loss_stats}


def setup_early_mixins(policy: Policy, obs_space, action_space,
                       config: TrainerConfigDict) -> None:
    """Initialize the learning-rate schedule mixin on `policy`.

    Args:
        policy: The policy object to run the mixin initializer on.
        obs_space: Not used.
        action_space: Not used.
        config: The trainer config; `config["lr"]` and
            `config["lr_schedule"]` are passed to the mixin.
    """
    init_lr_schedule = LearningRateSchedule.__init__
    init_lr_schedule(policy, config["lr"], config["lr_schedule"])


def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space,
                     config: TrainerConfigDict) -> None:
    """Initialize the TD-error and target-network mixins on `policy`.

    Args:
        policy: The policy object to run the mixin initializers on.
        obs_space: Forwarded to `TargetNetworkMixin.__init__`.
        action_space: Forwarded to `TargetNetworkMixin.__init__`.
        config: Forwarded to `TargetNetworkMixin.__init__`.
    """
    # Attaches `policy.compute_td_error`.
    ComputeTDErrorMixin.__init__(policy)

    # The target-network mixin additionally needs the spaces and the config.
    target_init_args = (obs_space, action_space, config)
    TargetNetworkMixin.__init__(policy, *target_init_args)


# Assemble the R2D2 TF policy class from the functions defined in this
# module, plus `postprocess_nstep_and_prio` and `clip_gradients`, which are
# reused from the DQN TF policy module.
#
# - The two hooks `before_init` / `before_loss_init` run the mixin
#   initializers explicitly (see `setup_early_mixins` and `before_loss_init`).
# - `extra_action_out_fn` exposes `policy.q_values`, which is set in
#   `get_distribution_inputs_and_class`.
# - `extra_learn_fetches_fn` exposes `policy._td_error`, which is set in
#   `r2d2_loss`.
#
# NOTE(review): the default-config lambda resolves `DEFAULT_CONFIG` through
# the `ray` package (`ray.rllib.agents.dqn.r2d2`), while every other import
# in this file comes from `src.rllib`. The attribute chain is only resolved
# when the lambda is called - confirm that this module path is importable
# and already loaded at that point.
R2D2TFPolicy = build_tf_policy(
    name="R2D2TFPolicy",
    loss_fn=r2d2_loss,
    get_default_config=lambda: ray.rllib.agents.dqn.r2d2.DEFAULT_CONFIG,
    postprocess_fn=postprocess_nstep_and_prio,
    stats_fn=build_q_stats,
    make_model=build_r2d2_model,
    action_distribution_fn=get_distribution_inputs_and_class,
    optimizer_fn=adam_optimizer,
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
    compute_gradients_fn=clip_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
