"""PyTorch policy class used for DQN"""

from typing import Dict, List, Tuple

import gym
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import (
    PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE)
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
                                                      TorchDistributionWrapper)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping, \
    concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, \
    reduce_mean_ignore_inf, softmax_cross_entropy_with_logits
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
import numpy as np

torch, nn = try_import_torch()
F = None
if nn:
    F = nn.functional





def build_custom_dqn_policy(args):
    """Build a customized DQN torch policy class.

    The loss, model-construction, stats and action hooks are defined as
    closures inside this function and then combined into one policy class
    via ``build_policy_class``.

    :param args: options object. Within this function only
        ``args.include_concepts`` is read: when truthy, the extra action
        outputs also contain ``"CONCEPT_PREDS"`` (the result of
        ``model.concept_function()``).
    :return: the policy class named ``"DQNTorchPolicy"`` returned by
        ``build_policy_class``.
    """
    class QLoss:
        """Computes the DQN loss, the per-sample TD error and Q statistics.

        After construction the instance exposes three attributes:

        * ``td_error``: per-sample error (cross entropy in the distributional
          case, ``q_t_selected - target`` otherwise).
        * ``loss``: mean of the importance-weighted error.
        * ``stats``: dict with ``mean_q`` / ``min_q`` / ``max_q`` for the
          non-distributional case; empty for the distributional case.
        """

        def __init__(self,
                     q_t_selected: TensorType,
                     q_logits_t_selected: TensorType,
                     q_tp1_best: TensorType,
                     q_probs_tp1_best: TensorType,
                     importance_weights: TensorType,
                     rewards: TensorType,
                     done_mask: TensorType,
                     gamma=0.99,
                     n_step=1,
                     num_atoms=1,
                     v_min=-10.0,
                     v_max=10.0):
            """Build the loss tensors.

            Args:
                q_t_selected: Q-values of the taken actions (only used when
                    ``num_atoms == 1``).
                q_logits_t_selected: Support logits of the taken actions
                    (only used when ``num_atoms > 1``).
                q_tp1_best: Q-values of the best next-state actions (only
                    used when ``num_atoms == 1``).
                q_probs_tp1_best: Support probabilities of the best
                    next-state actions (only used when ``num_atoms > 1``).
                importance_weights: Per-sample weights multiplied into the
                    error before averaging.
                rewards: Per-sample rewards.
                done_mask: Float mask, 1.0 where the episode terminated.
                gamma: Discount factor; applied as ``gamma ** n_step``.
                n_step: Number of steps the reward spans.
                num_atoms: Number of support atoms; ``> 1`` selects the
                    distributional loss.
                v_min: Lower bound of the value support.
                v_max: Upper bound of the value support.
            """
            if num_atoms > 1:
                # Distributional Q-learning which corresponds to an entropy loss
                # `torch.range` is deprecated; `torch.arange` with an exclusive
                # end of `num_atoms` yields the same values 0 .. num_atoms - 1.
                z = torch.arange(
                    0.0, num_atoms, dtype=torch.float32).to(rewards.device)
                z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

                # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
                r_tau = torch.unsqueeze(
                    rewards, -1) + gamma**n_step * torch.unsqueeze(
                        1.0 - done_mask, -1) * torch.unsqueeze(z, 0)
                r_tau = torch.clamp(r_tau, v_min, v_max)
                # Fractional atom index of every projected support point.
                b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
                lb = torch.floor(b)
                ub = torch.ceil(b)

                # Indispensable judgement which is missed in most implementations
                # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
                # be discarded because (ub-b) == (b-lb) == 0.
                floor_equal_ceil = ((ub - lb) < 0.5).float()

                # (batch_size, num_atoms, num_atoms)
                l_project = F.one_hot(lb.long(), num_atoms)
                # (batch_size, num_atoms, num_atoms)
                u_project = F.one_hot(ub.long(), num_atoms)
                # Split each atom's probability mass between its lower and
                # upper neighbour on the fixed support.
                ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
                mu_delta = q_probs_tp1_best * (b - lb)
                ml_delta = torch.sum(
                    l_project * torch.unsqueeze(ml_delta, -1), dim=1)
                mu_delta = torch.sum(
                    u_project * torch.unsqueeze(mu_delta, -1), dim=1)
                m = ml_delta + mu_delta

                # Rainbow paper claims that using this cross entropy loss for
                # priority is robust and insensitive to `prioritized_replay_alpha`
                self.td_error = softmax_cross_entropy_with_logits(
                    logits=q_logits_t_selected, labels=m.detach())
                self.loss = torch.mean(self.td_error * importance_weights)
                self.stats = {
                    # TODO: better Q stats for dist dqn
                }
            else:
                # Terminal transitions contribute no bootstrapped value.
                q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

                # compute RHS of bellman equation
                q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

                # compute the error (potentially clipped); the target is
                # detached so no gradient flows through it.
                self.td_error = q_t_selected - q_t_selected_target.detach()
                self.loss = torch.mean(
                    importance_weights.float() * huber_loss(self.td_error))
                self.stats = {
                    "mean_q": torch.mean(q_t_selected),
                    "min_q": torch.min(q_t_selected),
                    "max_q": torch.max(q_t_selected),
                }


    class ComputeTDErrorMixin:
        """Attach a `compute_td_error` callable to the policy instance.

        This allows us to prioritize on the worker side.
        """

        def __init__(self):
            def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                                 importance_weights):
                # Start from the (lazily converted) current observations and
                # add the remaining columns the loss function reads.
                batch = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
                remaining_columns = (
                    (SampleBatch.ACTIONS, act_t),
                    (SampleBatch.REWARDS, rew_t),
                    (SampleBatch.NEXT_OBS, obs_tp1),
                    (SampleBatch.DONES, done_mask),
                    (PRIO_WEIGHTS, importance_weights),
                )
                for column, values in remaining_columns:
                    batch[column] = values

                # The loss pass is run only for its side effect: it stores
                # the QLoss object under `tower_stats["q_loss"]`.
                build_q_losses(self, self.model, None, batch)

                q_loss = self.model.tower_stats["q_loss"]
                return q_loss.td_error

            self.compute_td_error = compute_td_error


    def build_q_model_and_distribution(
            policy: Policy, obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
        """Create the online Q-model and the target Q-model for DQN.

        Args:
            policy (Policy): The policy, which will use the model for
                optimization. The target model is stored on it as
                `policy.target_model`.
            obs_space (gym.spaces.Space): The policy's observation space.
            action_space (gym.spaces.Space): The policy's action space; must
                be `gym.spaces.Discrete`.
            config (TrainerConfigDict): Trainer config. Reads `hiddens`,
                `model`, `exploration_config`, `dueling`, `num_atoms`,
                `noisy`, `v_min`, `v_max` and `sigma0`. When `hiddens` is
                non-empty, `config["model"]["no_final_linear"]` is set to
                True in place.

        Returns:
            (q_model, TorchCategorical)
                Note: The target q model will not be returned, just assigned
                to `policy.target_model`.

        Raises:
            UnsupportedSpaceException: If `action_space` is not Discrete.
        """
        if not isinstance(action_space, gym.spaces.Discrete):
            raise UnsupportedSpaceException(
                "Action space {} is not supported for DQN.".format(action_space))

        if config["hiddens"]:
            # Use the size of the last fcnet layer as the feature size; fall
            # back to 128 when `fcnet_hiddens` is empty.
            fcnet_sizes = list(config["model"]["fcnet_hiddens"])
            num_outputs = fcnet_sizes[-1] if fcnet_sizes else 128
            config["model"]["no_final_linear"] = True
        else:
            num_outputs = action_space.n

        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        has_param_noise_exploration = isinstance(
            getattr(policy, "exploration", None), ParameterNoise)
        add_layer_norm = (
            has_param_noise_exploration
            or config["exploration_config"]["type"] == "ParameterNoise")

        # Both networks are built from identical settings; only their scope
        # name differs.
        # NOTE(review): `DQNAttention` is not imported or defined in the
        # visible part of this file -- confirm where it comes from.
        shared_model_kwargs = dict(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=config["model"],
            framework="torch",
            model_interface=DQNAttention,
            q_hiddens=config["hiddens"],
            dueling=config["dueling"],
            num_atoms=config["num_atoms"],
            use_noisy=config["noisy"],
            v_min=config["v_min"],
            v_max=config["v_max"],
            sigma0=config["sigma0"],
            add_layer_norm=add_layer_norm)

        model = ModelCatalog.get_model_v2(name=Q_SCOPE, **shared_model_kwargs)
        policy.target_model = ModelCatalog.get_model_v2(
            name=Q_TARGET_SCOPE, **shared_model_kwargs)

        return model, TorchCategorical


    def get_distribution_inputs_and_class(
            policy: Policy,
            model: ModelV2,
            input_dict: SampleBatch,
            *,
            explore: bool = True,
            is_training: bool = False,
            **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
        """Return the Q-values as action-distribution inputs.

        The Q-values are also stored in `model.tower_stats["q_values"]`.

        Returns:
            Tuple of (Q-values, `TorchCategorical`, empty state-out list).
        """
        outputs = compute_q_values(
            policy, model, input_dict, explore=explore, is_training=is_training)
        # When a tuple comes back, only its first element (the Q-values) is
        # needed here.
        if isinstance(outputs, tuple):
            q_vals = outputs[0]
        else:
            q_vals = outputs

        model.tower_stats["q_values"] = q_vals

        state_out = []
        return q_vals, TorchCategorical, state_out


    def build_q_losses(policy: Policy, model, _,
                       train_batch: SampleBatch) -> TensorType:
        """Constructs the loss for DQNTorchPolicy.

        Besides returning the loss, the per-sample TD error and the `QLoss`
        object are stored in `model.tower_stats` under `"td_error"` and
        `"q_loss"`.

        Args:
            policy (Policy): The Policy to calculate the loss for.
            model (ModelV2): The Model to calculate the loss for.
            _: Unused.
            train_batch (SampleBatch): The training data. Reads the current
                and next observations, actions, rewards, dones and
                `PRIO_WEIGHTS`.

        Returns:
            TensorType: A single loss tensor.
        """
        config = policy.config
        num_actions = policy.action_space.n

        # Q-network evaluation.
        q_t, q_logits_t, _, _ = compute_q_values(
            policy,
            model, {"obs": train_batch[SampleBatch.CUR_OBS]},
            explore=False,
            is_training=True)

        # Target Q-network evaluation.
        q_tp1, _, q_probs_tp1, _ = compute_q_values(
            policy,
            policy.target_models[model],
            {"obs": train_batch[SampleBatch.NEXT_OBS]},
            explore=False,
            is_training=True)

        # Q scores for actions which we know were selected in the given state.
        # Entries at or below FLOAT_MIN are zeroed before the one-hot
        # selection.
        one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                      num_actions)
        q_t_selected = torch.sum(
            torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device))
            * one_hot_selection, 1)
        q_logits_t_selected = torch.sum(
            q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1)

        # compute estimate of best possible value starting from state at t + 1
        if config["double_q"]:
            # Double-Q: the online network picks the action, the target
            # network (below) provides its value.
            q_tp1_using_online_net, _, _, _ = compute_q_values(
                policy,
                model,
                {"obs": train_batch[SampleBatch.NEXT_OBS]},
                explore=False,
                is_training=True)
            best_actions_tp1 = torch.argmax(q_tp1_using_online_net, 1)
        else:
            best_actions_tp1 = torch.argmax(q_tp1, 1)

        q_tp1_best_one_hot_selection = F.one_hot(best_actions_tp1, num_actions)
        q_tp1_best = torch.sum(
            torch.where(q_tp1 > FLOAT_MIN, q_tp1,
                        torch.tensor(0.0, device=q_tp1.device)) *
            q_tp1_best_one_hot_selection, 1)
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1)

        q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
                       q_probs_tp1_best, train_batch[PRIO_WEIGHTS],
                       train_batch[SampleBatch.REWARDS],
                       train_batch[SampleBatch.DONES].float(), config["gamma"],
                       config["n_step"], config["num_atoms"], config["v_min"],
                       config["v_max"])

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["td_error"] = q_loss.td_error
        # TD-error tensor in final stats
        # will be concatenated and retrieved for each individual batch item.
        model.tower_stats["q_loss"] = q_loss

        return q_loss.loss


    def adam_optimizer(policy: Policy,
                       config: TrainerConfigDict) -> "torch.optim.Optimizer":
        """Create an Adam optimizer over the policy model's variables.

        The variables are cached on the policy as `q_func_vars` the first
        time this is called. Learning rate comes from `policy.cur_lr`,
        epsilon from `config["adam_epsilon"]`.
        """
        # By this time, the models have been moved to the GPU - if any - and we
        # can define our optimizers using the correct CUDA variables.
        if not hasattr(policy, "q_func_vars"):
            policy.q_func_vars = policy.model.variables()

        params = policy.q_func_vars
        optimizer = torch.optim.Adam(
            params, lr=policy.cur_lr, eps=config["adam_epsilon"])
        return optimizer


    def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
        """Average the per-tower Q-loss stats and add the current LR.

        The stat names are taken from the first tower; towers without a
        `"q_loss"` entry in their `tower_stats` are left out of the average.
        """
        towers = policy.model_gpu_towers
        stat_names = towers[0].tower_stats["q_loss"].stats.keys()

        stats = {
            stat_name: torch.mean(
                torch.stack([
                    tower.tower_stats["q_loss"].stats[stat_name].to(
                        policy.device)
                    for tower in towers if "q_loss" in tower.tower_stats
                ]))
            for stat_name in stat_names
        }
        stats["cur_lr"] = policy.cur_lr
        return stats


    def setup_early_mixins(policy: Policy, obs_space, action_space,
                           config: TrainerConfigDict) -> None:
        """Initialize the learning-rate schedule mixin on the policy.

        Uses `config["lr"]` and `config["lr_schedule"]`; the space arguments
        are not used.
        """
        initial_lr = config["lr"]
        schedule = config["lr_schedule"]
        LearningRateSchedule.__init__(policy, initial_lr, schedule)


    def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
                         action_space: gym.spaces.Space,
                         config: TrainerConfigDict) -> None:
        """Initialize the TD-error and target-network mixins on the policy.

        The space and config arguments are not used.
        """
        # Order matters only in that it mirrors the original call sequence.
        for mixin_cls in (ComputeTDErrorMixin, TargetNetworkMixin):
            mixin_cls.__init__(policy)


    def compute_q_values(policy: Policy,
                         model: ModelV2,
                         input_dict,
                         state_batches=None,
                         seq_lens=None,
                         explore=None,
                         is_training: bool = False):
        """Run the model and derive Q-values from its outputs.

        `model.get_q_value_distributions` is unpacked as a 5-tuple when
        `config["num_atoms"] > 1` and as a 3-tuple otherwise. With
        `config["dueling"]` enabled, the state value from
        `model.get_state_value` is combined with mean-centered per-action
        terms.

        Returns:
            Tuple of (value, logits, probs_or_logits, state), where `state`
            is the second output of the model call.
        """
        config = policy.config

        model_out, state = model(input_dict, state_batches or [], seq_lens)

        distributional = config["num_atoms"] > 1
        dist_outputs = model.get_q_value_distributions(model_out)
        if distributional:
            (action_scores, z, support_logits_per_action, logits,
             probs_or_logits) = dist_outputs
        else:
            action_scores, logits, probs_or_logits = dist_outputs

        # Without dueling, the action scores are the Q-values as-is.
        if not config["dueling"]:
            return action_scores, logits, probs_or_logits, state

        state_score = model.get_state_value(model_out)
        if distributional:
            # Center the per-action support logits over dim 1, then add the
            # state score back in.
            mean_support_logits = torch.mean(support_logits_per_action, dim=1)
            centered_support_logits = (
                support_logits_per_action - torch.unsqueeze(
                    mean_support_logits, dim=1))
            logits = torch.unsqueeze(
                state_score, dim=1) + centered_support_logits
            probs_or_logits = nn.functional.softmax(logits, dim=-1)
            # Expected value over the support `z`.
            value = torch.sum(z * probs_or_logits, dim=-1)
        else:
            mean_advantage = reduce_mean_ignore_inf(action_scores, 1)
            centered_advantages = action_scores - torch.unsqueeze(
                mean_advantage, 1)
            value = state_score + centered_advantages

        return value, logits, probs_or_logits, state


    def grad_process_and_td_error_fn(policy: Policy,
                                     optimizer: "torch.optim.Optimizer",
                                     loss: TensorType) -> Dict[str, TensorType]:
        """Delegate gradient post-processing to `apply_grad_clipping`.

        Returns whatever `apply_grad_clipping` returns.
        """
        # Clip grads if configured.
        clipping_info = apply_grad_clipping(policy, optimizer, loss)
        return clipping_info


    def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                            action_dist) -> Dict[str, TensorType]:
        """Return extra per-action outputs.

        Always includes `"q_values"` from `model.tower_stats`. When
        `args.include_concepts` (from the enclosing factory) is truthy,
        `"CONCEPT_PREDS"` from `model.concept_function()` is added as well.
        """
        include_concepts = args.include_concepts
        fetches = {"q_values": model.tower_stats["q_values"]}
        if include_concepts:
            fetches["CONCEPT_PREDS"] = model.concept_function()
        return fetches

    



    # Assemble the policy class from the hooks defined above in this
    # function and hand it back to the caller.
    DQNTorchPolicy = build_policy_class(
        name="DQNTorchPolicy",
        framework="torch",
        loss_fn=build_q_losses,
        # Wrapped in a lambda so the default config is looked up only when
        # the callable is invoked, via attribute access on the `ray` package.
        get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
        make_model_and_action_dist=build_q_model_and_distribution,
        action_distribution_fn=get_distribution_inputs_and_class,
        stats_fn=build_q_stats,
        optimizer_fn=adam_optimizer,
        extra_grad_process_fn=grad_process_and_td_error_fn,
        extra_learn_fetches_fn=concat_multi_gpu_td_errors,
        extra_action_out_fn=extra_action_out_fn,
        # `setup_early_mixins` initializes LearningRateSchedule;
        # `before_loss_init` initializes the other two mixins listed below.
        before_init=setup_early_mixins,
        before_loss_init=before_loss_init,
        mixins=[
            TargetNetworkMixin,
            ComputeTDErrorMixin,
            LearningRateSchedule,
        ])
    return DQNTorchPolicy