import numpy as np
import torch as th


class GAE(object):
    """Generalized Advantage Estimation (GAE) helper.

    The object holds no state; all inputs are passed to
    ``estimate_advantages`` on each call.
    """

    def __init__(self):
        # Nothing to configure: the estimator is stateless.
        pass

    def estimate_advantages(self, rewards, masks, values, gamma, tau, bootstrap_value):
        """Compute GAE advantages and value targets for a single trajectory.

        The trajectory is traversed from the last step to the first. For each
        step ``t``::

            delta_t = r_t + gamma * V_next * mask_t - V_t
            A_t     = delta_t + gamma * tau * A_next * mask_t
            R_t     = V_t + A_t

        For the final step, ``V_next`` is ``bootstrap_value`` and ``A_next`` is 0.
        For every other step they are the value and advantage of step ``t + 1``.

        Args:
            rewards: Indexable per-step rewards. Its length determines how
                many steps are processed.
            masks: Indexable per-step multipliers, indexed like ``rewards``.
                A 0 entry cancels both the bootstrapped next-value term and the
                carried-over advantage at that step. It presumably marks
                episode termination; confirm with the caller.
            values: Indexable per-step value estimates ``V_t``.
            gamma: Discount factor.
            tau: GAE smoothing coefficient applied together with ``gamma`` to
                the carried-over advantage.
            bootstrap_value: Value estimate used in place of ``V_next`` for
                the last step.

        Returns:
            A tuple ``(advantages, returns)``. Each is a Python list with
            ``len(rewards)`` entries, where ``returns[t] == values[t] +
            advantages[t]``.
        """
        horizon = len(rewards)
        advantages = [None] * horizon
        returns = [None] * horizon

        # Quantities belonging to step t + 1, seeded for the final step.
        next_value = bootstrap_value
        next_advantage = 0

        for step in range(horizon - 1, -1, -1):
            keep = masks[step]
            value = values[step]
            # One-step TD residual; the mask drops the bootstrap term.
            td_error = rewards[step] + gamma * next_value * keep - value
            # Exponentially weighted sum of residuals; the mask stops the
            # recursion from leaking across a masked boundary.
            advantage = td_error + gamma * tau * next_advantage * keep
            advantages[step] = advantage
            returns[step] = value + advantage
            next_value, next_advantage = value, advantage

        return advantages, returns

