import torch
from khrylib.utils import batch_to


def estimate_advantages(rewards, masks, values, gamma, tau, normalize_rewards=False):
    """Compute advantages and returns with a backward GAE-style recursion.

    Walking the batch from the last step to the first, each step computes::

        delta_i = rewards[i] + gamma * values[i + 1] * masks[i] - values[i]
        adv_i   = delta_i + gamma * tau * adv_{i + 1} * masks[i]

    where the "next" value and advantage start at 0 for the last step.
    ``returns`` is ``values + advantages`` taken *before* any normalization.

    Args:
        rewards: Per-step rewards; the first dimension is the step index.
        masks: Per-step multiplier applied to the bootstrapped next-step
            terms (presumably 0 where an episode ends and 1 otherwise --
            confirm against the caller).
        values: Per-step value estimates. Indexed as ``values[i, 0]``, so it
            must have at least two dimensions (presumably ``(N, 1)``).
        gamma: Discount factor.
        tau: Multiplier on the propagated advantage (the GAE lambda).
        normalize_rewards: If true, the *advantages* (not the rewards) are
            shifted to zero mean and scaled to unit standard deviation.

    Returns:
        ``(advantages, returns)``, both passed through ``batch_to`` with the
        device that ``rewards`` had on entry.
    """
    device = rewards.device
    # The recursion is a sequential Python loop, so it is run on CPU tensors.
    rewards, masks, values = batch_to(torch.device('cpu'), rewards, masks, values)

    # Allocate with the inputs' promoted dtype instead of the legacy
    # ``type(rewards)(n, 1)`` constructor, which on current PyTorch always
    # yields the default dtype and would silently round float64 inputs.
    dtype = torch.result_type(rewards, values)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    num_steps = rewards.size(0)
    advantages = torch.empty((num_steps, 1), dtype=dtype)

    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(num_steps)):
        # masks[i] gates everything bootstrapped from step i + 1.
        delta = rewards[i] + gamma * prev_value * masks[i] - values[i]
        advantages[i] = delta + gamma * tau * prev_advantage * masks[i]

        prev_value = values[i, 0]
        prev_advantage = advantages[i, 0]

    # Returns are built from the un-normalized advantages.
    returns = values + advantages
    if normalize_rewards:
        centered = advantages - advantages.mean()
        std = advantages.std()
        # std is NaN for a single sample and 0 for constant advantages;
        # dividing would turn every advantage into NaN, so only center then.
        if torch.isfinite(std) and std > 0:
            advantages = centered / std
        else:
            advantages = centered

    advantages, returns = batch_to(device, advantages, returns)
    return advantages, returns
