import torch 
import torch.nn as nn 

class LogReward(nn.Module):
    """Linear log-reward with fixed, seed-determined random weights.

    A weight vector of length ``src_size`` is drawn once at construction,
    uniformly from ``[-5, 5)``, using a generator seeded with ``seed``.
    The forward pass returns
    ``(values * batch_state.unique_input).sum(dim=1) / alpha - shift``.
    """

    def __init__(self, src_size, seed, device='cpu', shift=0., alpha=1.):
        """Draw the weight vector and store the scaling constants.

        Args:
            src_size: Number of weights to draw (length of ``values``).
            seed: Seed for the random generator used to draw the weights.
            device: Device on which the generator and weights are created.
            shift: Constant subtracted from the scaled log-reward.
            alpha: Divisor applied to the raw log-reward; must be non-zero
                for a finite result.
        """
        super().__init__()
        self.src_size = src_size
        self.seed = seed
        # Device the weights were created on; the buffer below may later be
        # moved by ``Module.to()``.
        self.device = device
        self.shift = shift
        self.alpha = alpha

        # A dedicated generator keeps the draw reproducible without touching
        # the global RNG state.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
        values = 10 * torch.rand(
            (self.src_size,), device=self.device, generator=generator
        ) - 5

        # Register as a buffer so ``.to()``/``.cuda()`` and module replication
        # move the weights together with the module. ``persistent=False``
        # keeps ``state_dict()`` unchanged (the weights are re-derived from
        # the seed), so existing checkpoints still load.
        self.register_buffer('values', values, persistent=False)

    def forward(self, batch_state):
        """Return the scaled and shifted log-reward for a batch.

        Args:
            batch_state: Object exposing a ``unique_input`` tensor that
                broadcasts against ``values`` and has at least two
                dimensions (presumably ``(batch, src_size)`` -- confirm
                against the caller).

        Returns:
            Tensor with dimension 1 summed out:
            ``sum(values * unique_input, dim=1) / alpha - shift``.
        """
        log_reward = (self.values * batch_state.unique_input).sum(dim=1)
        return log_reward / self.alpha - self.shift

