"""PyTorch policy class used for Simple Q-Learning"""

import logging
from typing import List, Dict, Tuple

import gym
from agents.dqn.simple_q.config import SIMPLE_Q_DEFAULT_CONFIG
from models import ModelCatalog
from models.modelv2 import ModelV2
from models.torch.torch_action_dist import TorchCategorical, \
    TorchDistributionWrapper
from policy import Policy
from policy.policy_template import build_policy_class
from policy.sample_batch import SampleBatch
from policy.torch_policy import TorchPolicy
from utils.annotations import override
from utils.framework import try_import_torch
from utils.torch_utils import concat_multi_gpu_td_errors, huber_loss
from utils.typing import TensorType, TrainerConfigDict
from utils.error import UnsupportedSpaceException

torch, nn = try_import_torch()
# `F` is `nn.functional` when `nn` is truthy, otherwise None.
F = nn.functional if nn else None
logger = logging.getLogger(__name__)

# Names passed to the ModelCatalog for the Q-model and the target Q-model.
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"


class TargetNetworkMixin:
    """Mixin that provides `update_target` to the SimpleQTorchPolicy.

    `update_target` copies the weights of the main model into every target
    model. It runs once when the mixin is constructed and again after each
    `set_weights` call. It is also called every
    `target_network_update_freq` steps by the master learner.
    """

    def __init__(self):
        # Begin with the target Q-net(s) hard-synced to the Q-net.
        self.update_target()

    def update_target(self):
        """Load the main model's state dict into each target model."""
        source_weights = self.model.state_dict()
        for target_net in self.target_models.values():
            target_net.load_state_dict(source_weights)

    @override(TorchPolicy)
    def set_weights(self, weights):
        """Set the policy's weights, then re-sync the target network(s).

        Restoring weights only touches the main model, so the target
        network(s) are refreshed from it right afterwards.
        """
        TorchPolicy.set_weights(self, weights)
        self.update_target()


def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Create the Q-model and the target Q-model for Simple Q-Learning.

    Both models are requested from the ModelCatalog with the same settings
    (one output per discrete action); only their names differ.

    Args:
        policy (Policy): The Policy, which will use the model for
            optimization. The target model is stored on it as
            `policy.target_model`.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space. Must be
            a `gym.spaces.Discrete`.
        config (TrainerConfigDict): The Policy's config. Its "model" and
            "framework" entries are read here.

    Returns:
        Tuple: The Q-model for the Policy to use and the `TorchCategorical`
            action distribution class.
            Note: The target Q-model is not returned, just assigned to
            `policy.target_model`.

    Raises:
        UnsupportedSpaceException: If `action_space` is not Discrete.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    # Everything except the name is shared by the two networks.
    shared_kwargs = dict(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=config["model"],
        framework=config["framework"])

    q_model = ModelCatalog.get_model_v2(**shared_kwargs, name=Q_SCOPE)
    policy.target_model = ModelCatalog.get_model_v2(
        **shared_kwargs, name=Q_TARGET_SCOPE)

    return q_model, TorchCategorical


def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training=None) -> TensorType:
    """Call `model` on `obs` and return its output.

    Args:
        policy (Policy): Only consulted for its is-training placeholder
            when `is_training` is None.
        model (ModelV2): The model to evaluate.
        obs (TensorType): The observations to feed into the model.
        explore: Not used by this function.
        is_training: Value placed into the input batch as `_is_training`.
            If None, `policy._get_is_training_placeholder()` is used.

    Returns:
        TensorType: The first element of the pair returned by the model.
    """
    # Only ask the policy for its placeholder if no flag was given.
    if is_training is None:
        is_training = policy._get_is_training_placeholder()

    input_batch = SampleBatch(obs=obs, _is_training=is_training)
    # Called with empty state-ins and no sequence lengths.
    outputs, _ = model(input_batch, [], None)
    return outputs


def get_distribution_inputs_and_class(
        policy: Policy,
        q_model: ModelV2,
        obs_batch: TensorType,
        *,
        explore=True,
        is_training=True,
        **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
    """Compute the Q-values that serve as action distribution inputs.

    The Q-values are also stored on the policy as `policy.q_values`.

    Returns:
        Tuple: The Q-values, the `TorchCategorical` class and an empty list
            of state-outs.
    """
    model_output = compute_q_values(
        policy, q_model, obs_batch, explore, is_training)
    # If a tuple comes back, only its first element holds the Q-values.
    if isinstance(model_output, tuple):
        model_output = model_output[0]

    policy.q_values = model_output
    state_outs = []
    return policy.q_values, TorchCategorical, state_outs


def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Compute the Simple Q-Learning TD loss for one training batch.

    The Q-network is evaluated on the current observations and the target
    network on the next observations. The TD-error is the difference
    between the Q-value of the action taken and the one-step target
    `reward + gamma * (1 - done) * Q_target(next_obs, greedy_action)`.
    The loss is the mean Huber loss of that error; the target is detached,
    so no gradient flows through it.

    The loss and the TD-error are also written to `model.tower_stats`.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for. Also the key
            used to look up its target model in `policy.target_models`.
        dist_class (Type[ActionDistribution]): The action distribution
            class (not used here).
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_net = policy.target_models[model]
    num_actions = policy.action_space.n

    # Q-network evaluation on the current observations.
    q_values_t = compute_q_values(
        policy,
        model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # Target Q-network evaluation on the next observations.
    q_values_tp1 = compute_q_values(
        policy,
        target_net,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True)

    # Q-value of the action that was actually taken at time t.
    taken_mask = F.one_hot(
        train_batch[SampleBatch.ACTIONS].long(), num_actions)
    q_taken = torch.sum(q_values_t * taken_mask, 1)

    # Value of the greedy action at t + 1, zeroed out where `dones` is set.
    greedy_mask = F.one_hot(torch.argmax(q_values_tp1, 1), num_actions)
    q_greedy_tp1 = torch.sum(q_values_tp1 * greedy_mask, 1)
    done_flags = train_batch[SampleBatch.DONES].float()
    q_greedy_tp1_masked = (1.0 - done_flags) * q_greedy_tp1

    # Right-hand side of the Bellman equation.
    bellman_target = (train_batch[SampleBatch.REWARDS] +
                      policy.config["gamma"] * q_greedy_tp1_masked)

    # Huber loss on the TD-error.
    td_error = q_taken - bellman_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Keep the stats on the model (tower) so that, for multi-GPU, parallel
    # loss computations do not override each other's values.
    model.tower_stats["loss"] = loss
    # The TD-error tensor in the final stats will be concatenated and
    # retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss


def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    """Return the mean of the "loss" values collected from all towers.

    Args:
        policy (Policy): The Policy whose tower stats are read.
        batch (SampleBatch): Not used by this function.

    Returns:
        Dict[str, TensorType]: A dict with the single key "loss".
    """
    tower_losses = policy.get_tower_stats("loss")
    stacked_losses = torch.stack(tower_losses)
    return {"loss": torch.mean(stacked_losses)}


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    """Expose the Q-values stored on the policy as an extra action output.

    Only `policy` is read; the remaining arguments are not used here.

    Returns:
        Dict[str, TensorType]: A dict mapping "q_values" to
            `policy.q_values`.
    """
    extra_outputs = {}
    extra_outputs["q_values"] = policy.q_values
    return extra_outputs


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Run the TargetNetworkMixin constructor on an existing policy.

    This function is passed to `build_policy_class` as the `after_init`
    hook. The mixin constructor performs the initial hard update of the
    target Q-net(s) from the Q-net.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space
            (not used here).
        action_space (gym.spaces.Space): The Policy's action space
            (not used here).
        config (TrainerConfigDict): The Policy's config (not used here).
    """
    TargetNetworkMixin.__init__(policy)


# Assemble the Simple Q-Learning torch policy class from the hooks defined
# in this module.
SimpleQTorchPolicy = build_policy_class(
    name="SimpleQPolicy",
    framework="torch",
    get_default_config=lambda: SIMPLE_Q_DEFAULT_CONFIG,
    # Model construction and action computation.
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_action_out_fn=extra_action_out_fn,
    # Loss and reporting.
    loss_fn=build_q_losses,
    stats_fn=stats_fn,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    # Target network handling.
    mixins=[TargetNetworkMixin],
    after_init=setup_late_mixins,
)
