import torch 
import torch.nn as nn 

class LogRewardDist(nn.Module):
    """Log-reward ``-log(1 + d)``, where ``d`` is the L1 (Manhattan) distance
    from a state's position to the closest reference point."""

    def __init__(self, refs):
        super().__init__()
        # Reference points, stored as given (not registered as a buffer).
        # Presumably a (num_refs, 2) tensor laid out like ``batch_state.pos``
        # -- TODO confirm with callers.
        self.refs = refs

    @torch.no_grad()
    def forward(self, batch_state):
        """Return one log-reward per state in ``batch_state``.

        Only ``batch_state.pos`` is read; it is indexed as (batch, coord)
        (e.g. the ``pos`` of a ``Grid2D``). Gradients are disabled.
        """
        # Insert a reference axis so every position is compared with every ref.
        offsets = batch_state.pos[:, None, :] - self.refs
        l1_to_each_ref = offsets.abs().sum(dim=-1)
        # Keep only the distance to the nearest reference point.
        nearest, _ = l1_to_each_ref.min(dim=-1)
        return torch.log(nearest + 1).neg()

class Grid2D:
    """Batch of trajectories on a 2-D grid, all starting at the origin.

    Actions are rows of ``self.actions``:

    * ``0`` -> ``[0, 0]``: the stop action (leaves the position unchanged);
    * ``1`` -> ``[1, 0]``: increment the first coordinate;
    * ``2`` -> ``[0, 1]``: increment the second coordinate.

    ``self.mask[b, a]`` is ``1`` while action ``a`` is allowed for trajectory
    ``b``. It becomes ``0`` once the action would move past ``topright``, or
    (for the non-stop actions) once the trajectory has stopped.
    ``self.stopped[b]`` counts the stop actions trajectory ``b`` has taken.
    """

    def __init__(self, width, height, batch_size, log_reward):
        """
        Args:
            width: largest allowed value of the first coordinate.
            height: largest allowed value of the second coordinate.
            batch_size: number of trajectories in the batch.
            log_reward: callable invoked as ``log_reward(self)`` by
                :meth:`log_reward` (e.g. a ``LogRewardDist`` module).
        """
        self.width = width
        self.height = height
        self.batch_size = batch_size

        # Positions may equal ``topright``; only moves beyond it are masked.
        self.topright = torch.tensor([width, height], dtype=torch.get_default_dtype())
        # Not read by the methods of this class.
        self.bottomleft = torch.tensor([0, 0], dtype=torch.get_default_dtype())

        self._log_reward = log_reward

        # Every trajectory starts at (0, 0).
        self.pos = torch.zeros((2,)).repeat(batch_size, 1)

        self.actions = torch.tensor([[0, 0], [1, 0], [0, 1]], dtype=torch.get_default_dtype())
        self.mask = torch.ones((batch_size, self.actions.shape[0]))
        self.batch_ids = torch.arange(self.batch_size)

        self.stopped = torch.zeros((self.batch_size,))

    @torch.no_grad()
    def apply(self, indices):
        """Advance every trajectory by one action, updating state in place.

        Args:
            indices: integer tensor with one action index per trajectory.

        Returns:
            Boolean tensor that is ``True`` where the trajectory's stop count is
            still below two after this step. A trajectory's first stop
            therefore still yields ``True``.
        """
        actions = self.actions[indices]
        stop_action_mask = (indices == 0).to(dtype=self.pos.dtype)
        # Stop actions do not move the trajectory.
        self.pos = self.pos + (1 - stop_action_mask)[..., None] * actions
        # Disable every action that would take a trajectory past the top-right
        # corner. Boolean-mask indexing replaces the argwhere + ``mask[*idx]``
        # form, whose star-unpacking in a subscript needs Python >= 3.11.
        would_overflow = torch.gt(self.pos[:, None, :] + self.actions, self.topright).any(dim=-1)
        self.mask[would_overflow] = 0
        # Once a trajectory stops, only the stop action remains available.
        self.mask[(stop_action_mask == 1), 1:] = 0
        self.stopped += stop_action_mask
        # NOTE(review): original comment reads "Consider exclusively the
        # non-stop actions"; presumably the caller samples from ``mask``, so
        # the step after the first stop is a forced second stop -- confirm.
        return self.stopped < 2.

    @torch.no_grad()
    def log_reward(self):
        """Return ``self._log_reward(self)`` with gradients disabled."""
        return self._log_reward(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the trajectories of another ``Grid2D`` to this batch in place.

        The incoming ``batch_ids`` are offset by this batch's current
        ``batch_size`` before being appended.
        """
        self.pos = torch.vstack([self.pos, batch_state.pos])
        self.mask = torch.vstack([self.mask, batch_state.mask])
        # ``stopped`` must grow with the batch as well; otherwise the next
        # ``apply`` fails with a shape mismatch on ``self.stopped += ...``.
        self.stopped = torch.hstack([self.stopped, batch_state.stopped])

        self.batch_ids = torch.hstack([self.batch_ids, self.batch_size + batch_state.batch_ids])
        self.batch_size += batch_state.batch_size