import torch 
import torch.nn as nn 

from var_red_gfn.utils import Environment 

class LogReward(nn.Module):
    """Linear log-reward over ``src_size`` source items.

    Each source item gets a fixed weight drawn from ``torch.rand`` (uniform on
    [0, 1)) using a generator seeded with ``seed``. Two instances built with the
    same ``seed`` and ``device`` therefore share identical weights. The
    log-reward of a batch is the product of the weights with
    ``batch_state.unique_input``, summed over dim 1, minus ``shift``.

    NOTE(review): ``values`` is a plain tensor attribute, not a registered
    buffer. ``module.to(...)`` will not move it and it is not part of
    ``state_dict()``.
    """

    def __init__(self, src_size, seed, device='cpu', shift=0.):
        super().__init__()
        self.src_size = src_size
        self.seed = seed
        self.device = device
        self.shift = shift

        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        # One weight per source item, shape (src_size,).
        self.values = torch.rand((src_size,), device=device, generator=rng)

    def forward(self, batch_state):
        """Return the shifted log-reward of every state in ``batch_state``.

        ``batch_state.unique_input`` is presumably a 2-D tensor with
        ``src_size`` columns (TODO confirm with callers). The weights are
        broadcast against it and summed over dim 1, giving one value per row.
        """
        weighted = self.values * batch_state.unique_input
        return torch.sum(weighted, dim=1) - self.shift

class Set(Environment):
    """Environment that builds a set by adding one of ``src_size`` items per step.

    ``state`` is an integer tensor of shape ``(batch_size, src_size)``. Entry
    ``[b, i]`` counts how many times item ``i`` has been added in trajectory
    ``b``. A trajectory is marked stopped once its row of ``state`` sums to
    ``set_size``.

    ``forward_mask`` and ``backward_mask`` are floating-point tensors of the
    same shape. They are always derived from ``state`` as ``1 - state`` and
    ``state`` respectively. With ``forward_mask`` respected, entries of
    ``state`` presumably stay in {0, 1}.

    NOTE(review): ``batch_size``, ``device``, ``batch_ids``, ``stopped`` and
    ``is_initial`` are read here but are presumably created (and merged) by
    ``Environment`` — confirm against ``var_red_gfn.utils``.
    """

    def __init__(self, src_size, set_size, batch_size, log_reward, device='cpu'):
        super(Set, self).__init__(batch_size, set_size, log_reward, device=device)
        self.src_size = src_size
        self.set_size = set_size
        self.state = torch.zeros((self.batch_size, self.src_size), device=self.device, dtype=int)
        # Equivalent to the masks ``_refresh_masks`` would derive from the all-zero state.
        self.forward_mask = torch.ones((self.batch_size, self.src_size), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, self.src_size), device=self.device)

    def _refresh_masks(self):
        """Recompute both action masks from the current ``state``, keeping their dtypes."""
        # Items already present can no longer be added (forward) but may be removed (backward).
        self.forward_mask = (1 - self.state).type(self.forward_mask.dtype)
        self.backward_mask = self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def apply(self, indices):
        """Add item ``indices[b]`` to trajectory ``b`` and update the flags and masks.

        ``indices`` is assumed to hold one item index per batch row (TODO confirm).
        """
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        # After any forward step, no trajectory is at the empty initial state.
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.set_size)
        self._refresh_masks()

    @torch.no_grad()
    def backward(self, indices):
        """Remove item ``indices[b]`` from trajectory ``b``; return ``indices`` unchanged."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        # A trajectory that has just taken a backward step is no longer terminal.
        self.stopped[:] = 0
        self._refresh_masks()
        return indices

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the trajectories of ``batch_state`` to this batch."""
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])
        # Rebuild the masks so their batch dimension matches the merged ``state``.
        # Previously they kept the pre-merge size. Within this class the masks
        # are always a function of ``state``, so this matches concatenating them.
        self._refresh_masks()

    @property
    def unique_input(self):
        """Return ``state`` cast to ``backward_mask``'s dtype (e.g. consumed by ``LogReward.forward``)."""
        return self.state.type(self.backward_mask.dtype)