import torch 
import torch.nn as nn 

class LogRewardLinear(nn.Module):
    """Additive log-reward: every warehouse element carries a fixed random value.

    The log-reward of a state row is the sum of the values of the element
    indices stored in that row of ``batch_state.state``. Index
    ``warehouse_size`` (the placeholder initial element used by ``Set``) maps
    to a value of 0, so it does not contribute.

    Args:
        seed: Seed of the generator that draws the per-element values.
        warehouse_size: Number of real elements; their values are drawn
            uniformly from [0, 1) for indices ``0 .. warehouse_size - 1``.
        device: Device on which the generator and the value table live.
    """

    def __init__(self, seed, warehouse_size, device):
        super().__init__()
        self.seed = seed
        self.warehouse_size = warehouse_size

        g = torch.Generator(device=device)
        g.manual_seed(seed)
        # Allocate on the generator's device: torch.rand rejects a generator
        # whose device differs from the device of the output tensor.
        values = torch.rand(size=(self.warehouse_size,), generator=g, device=device)
        # Extra zero entry for the placeholder index ``warehouse_size``;
        # new_zeros keeps the same dtype and device as ``values``.
        values = torch.hstack([values, values.new_zeros(1)])
        # Non-persistent buffer: it follows ``module.to(...)`` while leaving
        # the module's state_dict unchanged.
        self.register_buffer("el_values", values, persistent=False)

    @torch.no_grad()
    def forward(self, batch_state):
        """Return, per row of ``batch_state.state``, the sum of its element values.

        ``batch_state.state`` is used as an index tensor into ``el_values`` and
        reduced over dim 1 (presumably shape ``(batch, length)`` — TODO
        confirm). Repeated indices in a row are counted once per occurrence.
        """
        return self.el_values[batch_state.state].sum(dim=1)

class Set:
    """Batched environment that builds sets by appending warehouse elements.

    Each row of ``state`` is one trajectory. It starts with the placeholder
    index ``warehouse_size``, and every call to :meth:`apply` appends one
    element index to every row.

    Args:
        set_size: Maximum number of elements that may be appended after the
            placeholder.
        warehouse_size: Number of selectable elements; valid actions are
            ``0 .. warehouse_size - 1``.
        batch_size: Number of trajectories handled in parallel.
        log_reward: Callable invoked as ``log_reward(self)`` by
            :meth:`log_reward`.
        device: Device on which all internal tensors are allocated.
    """

    def __init__(self, set_size, warehouse_size, batch_size, log_reward, device='cpu'):
        self.set_size = set_size
        self.warehouse_size = warehouse_size
        self.batch_size = batch_size
        self.device = device

        self._log_reward = log_reward

        # Starts with an initial element (whose vectorial representation is
        # used to compute the transition probabilities). Its index
        # ``warehouse_size`` is never a valid action, so it never counts
        # towards ``size``.
        self.state = torch.full(
            (batch_size, 1), warehouse_size, dtype=torch.long, device=device
        )

        # Number of distinct elements appended to each row so far.
        self.size = torch.zeros((batch_size,), device=device)
        self.actions = torch.arange(warehouse_size, dtype=torch.long, device=device)
        self.batch_ids = torch.arange(batch_size, device=device)

        # Per-row counter, incremented by apply() once the state already has
        # ``set_size`` columns.
        self.stopped = torch.zeros((batch_size,), device=device)

    @torch.no_grad()
    def apply(self, indices):
        """Append the element ``self.actions[indices]`` to every row.

        Args:
            indices: Action indices, one per row (presumably a 1-D tensor of
                length ``batch_size`` — TODO confirm with callers).

        Returns:
            Boolean tensor ``self.stopped <= 1``.

        Raises:
            AssertionError: If the state already holds ``set_size`` appended
                elements. The object is left unmodified in that case.
        """
        # Check before mutating, so a rejected call cannot corrupt the state.
        assert self.state.shape[1] <= self.set_size, "set is already full"
        actions = self.actions[indices]
        # All rows have the same length, so this flag applies to the whole
        # batch. It fires on the call that appends the last allowed element.
        self.stopped += (self.state.shape[1] >= self.set_size)
        # Count the element only if the row does not already contain it.
        self.size += (self.state != actions.unsqueeze(1)).all(dim=1).float()
        self.state = torch.hstack([self.state, actions.unsqueeze(1)])
        return self.stopped <= 1

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the log-reward callable given to the constructor on this batch."""
        return self._log_reward(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of another ``Set`` to this one, in place.

        ``batch_state.state`` must have the same number of columns as
        ``self.state`` (required by ``vstack``). The incoming ``batch_ids``
        are offset by the current ``batch_size``. The per-row ``size`` and
        ``stopped`` counters are concatenated as well, so every per-row
        tensor stays aligned with ``state``.
        """
        self.batch_ids = torch.hstack([self.batch_ids, batch_state.batch_ids + self.batch_size])
        self.batch_size += batch_state.batch_size
        self.state = torch.vstack([self.state, batch_state.state])
        self.size = torch.hstack([self.size, batch_state.size])
        self.stopped = torch.hstack([self.stopped, batch_state.stopped])

    @property
    def sorted_state(self):
        """Return each row of ``state`` sorted in ascending order.

        This is the result of ``Tensor.sort(dim=1)``: a ``(values, indices)``
        named tuple, not a bare tensor.
        """
        return self.state.sort(dim=1)