# Copyright 2022 The Deep RL Zoo Authors. All Rights Reserved.
# Copyright 2019 The SEED Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# These functions have been modified by The Deep RL Zoo Authors
# to run on PyTorch.
#
# ==============================================================================
"""Functions to compute V-trace off-policy actor critic targets.

For details and theory see:

"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.

See https://arxiv.org/abs/1802.01561 for the full paper.
"""
import collections
from typing import NamedTuple

import torch
import torch.nn.functional as F

from ESA import base


class VTraceFromLogitsReturns(NamedTuple):
    """Result of `from_logits`: V-trace targets plus the per-action log-probabilities."""

    vs: torch.Tensor  # V-trace value targets.
    pg_advantages: torch.Tensor  # Advantages for the policy-gradient loss.
    log_rhos: torch.Tensor  # target_action_log_probs - behavior_action_log_probs.
    behavior_action_log_probs: torch.Tensor  # Log-probs of the taken actions under the behavior policy.
    target_action_log_probs: torch.Tensor  # Log-probs of the taken actions under the target policy.


class VTraceReturns(NamedTuple):
    """Result of `from_importance_weights`."""

    vs: torch.Tensor  # V-trace value targets.
    pg_advantages: torch.Tensor  # Advantages for the policy-gradient loss.


def action_log_probs(policy_logits, actions):
    """Return the log-probability of each taken action under a softmax policy.

    Args:
      policy_logits: Unnormalized log-probabilities with the action dimension
        last, e.g. [T, B, action_dim]. All leading dimensions are flattened.
      actions: Integer tensor of taken action indices with one element per row
        of the flattened logits, e.g. [T, B]. Any integer dtype is accepted; it
        is cast to int64 because indexing requires it.

    Returns:
      A tensor shaped like `actions` holding log softmax(policy_logits)[a].
    """
    flat_log_probs = F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1)
    # gather selects log_probs[i, a_i] directly; unlike nll_loss it has no
    # ignore_index that would silently turn an action of -100 into log-prob 0.
    flat_actions = torch.flatten(actions).long().unsqueeze(-1)
    return flat_log_probs.gather(-1, flat_actions).squeeze(-1).view_as(actions)


def from_logits(
    behavior_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
    lambda_=1.0,
):
    """V-trace for softmax policies.

    Computes the log-probabilities of the taken actions under both policies from
    their logits, then delegates to `from_importance_weights`.

    Args:
      behavior_policy_logits: Unnormalized log-probabilities of the behavior
        policy, action dimension last (e.g. [T, B, action_dim]).
      target_policy_logits: Unnormalized log-probabilities of the target policy,
        same layout as `behavior_policy_logits`.
      actions: Integer tensor of taken actions; `from_importance_weights`
        asserts the resulting log-probs are rank 2, i.e. [T, B].
      discounts: Forwarded to `from_importance_weights`.
      rewards: Forwarded to `from_importance_weights`.
      values: Forwarded to `from_importance_weights`.
      bootstrap_value: Forwarded to `from_importance_weights`.
      clip_rho_threshold: Forwarded to `from_importance_weights`.
      clip_pg_rho_threshold: Forwarded to `from_importance_weights`.
      lambda_: Trace-mixing coefficient forwarded to `from_importance_weights`
        (1.0 reproduces the previous behavior).

    Returns:
      A VTraceFromLogitsReturns namedtuple. `vs` and `pg_advantages` come from
      `from_importance_weights`, which runs under `torch.no_grad()`. `log_rhos`
      and the two log-prob fields are computed here, outside no_grad, so they
      keep the autograd graph of the logits.
    """
    target_action_log_probs = action_log_probs(target_policy_logits, actions)
    behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions)
    log_rhos = target_action_log_probs - behavior_action_log_probs
    vtrace_returns = from_importance_weights(
        target_action_log_probs=target_action_log_probs,
        behavior_action_log_probs=behavior_action_log_probs,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        lambda_=lambda_,
    )
    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behavior_action_log_probs=behavior_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict(),
    )


# Make sure no gradients are backpropagated through the returned values.
@torch.no_grad()
def from_importance_weights(
    target_action_log_probs: torch.Tensor,
    behavior_action_log_probs: torch.Tensor,
    discounts: torch.Tensor,
    rewards: torch.Tensor,
    values: torch.Tensor,
    bootstrap_value: torch.Tensor,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
    lambda_=1.0,
):
    r"""V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1 and B refers to the batch size. The two
    log-probability inputs are asserted to be rank 2 ([T, B]); `values`,
    `discounts` and `rewards` must have that same rank and `bootstrap_value` one
    rank less ([B]).

    The function runs under `torch.no_grad()`, so the returned tensors carry no
    autograd history.

    Args:
      target_action_log_probs: A float32 tensor of shape [T, B] with
        log-probabilities of taking the action by the current policy.
      behavior_action_log_probs: A float32 tensor of shape [T, B] with
        log-probabilities of taking the action by the behavioral policy.
      discounts: A float32 tensor of shape [T, B] with discounts encountered when
        following the behavior policy.
      rewards: A float32 tensor of shape [T, B] containing rewards generated by
        following the behavior policy.
      values: A float32 tensor of shape [T, B] with the value function estimates
        wrt. the target policy.
      bootstrap_value: A float32 tensor of shape [B] with the value function
        estimate at time T.
      clip_rho_threshold: A Python float or scalar float32 tensor with the
        clipping threshold for importance weights (rho) when calculating the
        baseline targets (vs). \bar{\rho} in the paper. If None, no clipping is
        applied.
      clip_pg_rho_threshold: A Python float or scalar float32 tensor with the
        clipping threshold on rho_s in
        \rho_s \nabla \log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
        If None, no clipping is applied.
      lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). See Remark 2
        in paper. Defaults to lambda_=1. It scales the trace coefficients c_s.

    Returns:
      A VTraceReturns namedtuple (vs, pg_advantages) where:
        vs: A float32 tensor of shape [T, B]. Can be used as target to
          train a baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
          advantage in the calculation of policy gradients.
    """
    base.assert_rank_and_dtype(target_action_log_probs, 2, torch.float32)
    base.assert_rank_and_dtype(behavior_action_log_probs, 2, torch.float32)

    log_rhos = target_action_log_probs - behavior_action_log_probs
    device = log_rhos.device

    # torch.as_tensor accepts both Python floats and existing tensors; torch.tensor
    # would warn (and copy) when handed the scalar tensor the docstring allows.
    if clip_rho_threshold is not None:
        clip_rho_threshold = torch.as_tensor(clip_rho_threshold, dtype=torch.float32, device=device)
    if clip_pg_rho_threshold is not None:
        clip_pg_rho_threshold = torch.as_tensor(clip_pg_rho_threshold, dtype=torch.float32, device=device)

    # Make sure tensor ranks are consistent.
    rho_rank = log_rhos.dim()  # Always 2, given the asserts above.
    base.assert_rank_and_dtype(values, rho_rank, torch.float32)
    base.assert_rank_and_dtype(bootstrap_value, rho_rank - 1, torch.float32)
    base.assert_rank_and_dtype(discounts, rho_rank, torch.float32)
    base.assert_rank_and_dtype(rewards, rho_rank, torch.float32)

    if clip_rho_threshold is not None:
        base.assert_rank(clip_rho_threshold, 0)
    if clip_pg_rho_threshold is not None:
        base.assert_rank(clip_pg_rho_threshold, 0)

    rhos = torch.exp(log_rhos)
    clipped_rhos = rhos if clip_rho_threshold is None else torch.minimum(clip_rho_threshold, rhos)

    # Truncated trace coefficients c_s = lambda * min(1, rho_s). clamp stays on
    # rhos' device/dtype instead of mixing in freshly allocated CPU scalars.
    cs = torch.clamp(rhos, max=1.0) * torch.as_tensor(lambda_, dtype=torch.float32, device=device)

    # Append the bootstrapped value to get [V(x_1), ..., V(x_T)].
    values_t_plus_1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion:
    #   (v_s - V(x_s)) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})),
    # seeded with 0 after the last step.
    acc = torch.zeros_like(bootstrap_value)
    reversed_terms = []
    for t in reversed(range(discounts.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        reversed_terms.append(acc)
    vs_minus_v_xs = torch.stack(reversed_terms[::-1], dim=0)

    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    # Advantage for policy gradient: uses v_{s+1}, with the bootstrap at time T.
    vs_t_plus_1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    clipped_pg_rhos = rhos if clip_pg_rho_threshold is None else torch.minimum(clip_pg_rho_threshold, rhos)
    pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)

    return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
