import torch

class MaskingModuleSigmoid(torch.nn.Module):
    """Multiply a signal by a smooth (sigmoid) step mask along its last axis.

    Positions before ``cT + delay_steps`` (measured in steps along the last
    axis) are driven towards 0 and positions after it towards 1. The sharpness
    of the transition is controlled by ``mask_temp``.
    """

    def __init__(self, mask_temp=0.00001, **kwargs):
        super().__init__(**kwargs)
        # Sigmoid temperature, in units of the normalised axis (step / x_len).
        # Smaller values give a sharper transition.
        self.mask_temp = mask_temp

    def compute_advection_mask(self, x_len, cT, delay_steps, antialiasing=1):
        """Build a soft step mask with ``x_len`` columns.

        Args:
            x_len: number of steps along the masked axis.
            cT: base offset in steps added to every delay. It must broadcast
                against ``[len(delay_steps), x_len * antialiasing]`` (e.g. a scalar).
            delay_steps: 1-D tensor of per-row offsets in steps.
            antialiasing: number of sub-samples averaged per step.

        Returns:
            Tensor of shape ``[len(delay_steps), x_len]`` (for a scalar ``cT``).
            Entry ``[r, j]`` approximates the fraction of step ``j`` that lies
            after ``cT + delay_steps[r]``.
        """
        # TODO: check if it makes sense to learn delay_steps / x_len instead of
        # delay_steps. In that case substitute delay_steps / x_len with delay_steps.
        # Then self.edge_time_coeffs would become length * cross_section / x_len.
        # Since x_len is fixed for the model this should be possible.
        aa = int(antialiasing)
        n_samples = x_len * aa
        # Sub-sample centres sit at step positions j + (k + 0.5) / aa, normalised
        # by x_len so they share the scale of the offsets below. The previous
        # linspace(0, 1, n_samples) grid spaced samples 1 / (n_samples - 1) apart,
        # which disagreed with both that scale and with dx, shifting the transition.
        x = ((torch.arange(n_samples, device=delay_steps.device) + 0.5) / n_samples)[None]
        # Width of one sub-sample in steps. Summing aa sub-samples times dx
        # averages them into one value per step.
        dx = 1 / aa
        sigmoid_offset = cT / x_len
        mask_offset = delay_steps[:, None] / x_len
        sigmoid_scale = 1. / self.mask_temp
        # TODO: advection_mask_temp
        # can actually be very low (very sharp sigmoid) in the forward pass
        # as long as it is a bit more smooth in the backward pass to allow
        # a strong gradient.
        advection_mask = torch.sigmoid((x - sigmoid_offset - mask_offset) * sigmoid_scale)
        return advection_mask.reshape(-1, x_len, aa).sum(-1) * dx

    def forward(self, x, offset, loc):
        """Return ``x`` multiplied by the soft mask.

        ``loc`` is used as ``cT`` and ``offset`` as ``delay_steps``. The mask's
        column count is taken from the last axis of ``x``; the original comment
        indicates ``x`` is ``[n_edges, x_len]``.
        """
        x_len = x.shape[-1]
        adv_mask = self.compute_advection_mask(x_len, loc, offset)
        return x * adv_mask  # shape: [n_edges, x_len]
    
class MaskingModuleHard(torch.nn.Module):
    """Multiply a signal by a hard 0/1 step mask along its last axis.

    Without antialiasing, mask entry ``[..., j]`` is 1 when
    ``j >= floor(clamp(cT + delay_steps, 0, x_len))`` and 0 otherwise. With
    antialiasing, a fractional start index linearly blends the masks of its
    two neighbouring integer starts.
    """

    def __init__(self, antialiasing=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): mask_temp is not read anywhere in this class; presumably
        # kept for attribute parity with MaskingModuleSigmoid.
        self.mask_temp = 0.00001
        # When truthy, fractional start indices interpolate between neighbouring masks.
        self.antialiasing = antialiasing

    def compute_advection_mask(self, x_len, cT, delay_steps):
        """Build a hard step mask with ``x_len`` columns.

        Args:
            x_len: number of steps along the masked axis.
            cT: base offset in steps, added to ``delay_steps``.
            delay_steps: tensor of offsets in steps. ``cT + delay_steps`` must be
                a tensor, since ``.clamp`` is called on it.

        Returns:
            float32 tensor of shape ``(cT + delay_steps).shape + (x_len,)``.
        """
        # Start of the unmasked region, kept within [0, x_len].
        start = (cT + delay_steps).clamp(0, x_len)
        # Row r is 1 for columns >= r. The extra last row (r == x_len) is all
        # zeros, so a start clamped to x_len masks everything instead of
        # indexing past the end of an x_len-row matrix.
        masks = torch.triu(torch.ones((x_len + 1, x_len), device=start.device))
        if self.antialiasing:
            frac = (start % 1).unsqueeze(-1)
            i_low = start.floor().long()
            i_high = start.ceil().long()
            temporal_mask = masks[i_low] * (1. - frac) + masks[i_high] * frac
        else:
            temporal_mask = masks[start.long()]
        return temporal_mask.float()

    def forward(self, x, offset, loc):
        """Return ``x`` multiplied by the hard mask.

        ``loc`` is used as ``cT`` and ``offset`` as ``delay_steps``. The mask's
        column count is taken from the last axis of ``x``; the original comment
        indicates ``x`` is ``[n_edges, x_len]``.
        """
        x_len = x.shape[-1]
        adv_mask = self.compute_advection_mask(x_len, loc, offset)
        return x * adv_mask  # shape: [n_edges, x_len]