import torch

def apply_thresholds(x, thresholds):
    """Compare every element of ``x`` against ``thresholds``.

    A trailing axis is added to ``x`` before the comparison, so the result is
    the broadcast of ``x[..., None] > thresholds``. For example, a 1-D ``x`` of
    length N with 1-D ``thresholds`` of length K gives an ``(N, K)`` result;
    a 2-D ``x`` of shape ``(N, F)`` with per-feature thresholds of shape
    ``(F, K)`` gives ``(N, F, K)``.

    Args:
        x: Values to binarize (tensor or anything ``torch.as_tensor`` accepts).
        thresholds: Thresholds, broadcastable against ``x.unsqueeze(-1)``.

    Returns:
        Boolean tensor that is ``True`` where the value is strictly greater
        than the corresponding threshold.
    """
    # isinstance (rather than an exact type check) lets Tensor subclasses such
    # as nn.Parameter through unchanged instead of copy-constructing them.
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x)
    # Put thresholds on x's device: comparing a CUDA tensor with a non-scalar
    # CPU tensor raises a device-mismatch error.
    thresholds = torch.as_tensor(thresholds, device=x.device)
    return x.unsqueeze(-1) > thresholds

def linear_thresholds(x, num_bits, min_value=None, max_value=None):
    """Place ``num_bits`` evenly spaced thresholds inside ``[min_value, max_value]``.

    The range is split into ``num_bits + 1`` equal intervals and the interior
    boundaries are returned: threshold ``k`` (``k = 1..num_bits``) is
    ``min_value + k * (max_value - min_value) / (num_bits + 1)``.

    Args:
        x: Data used to derive any missing bound (tensor or array-like).
        num_bits: Number of thresholds to produce per bound entry.
        min_value: Lower bound; defaults to ``x.min(dim=0)`` (per column).
        max_value: Upper bound; defaults to ``x.max(dim=0)`` (per column).

    Returns:
        Tensor with the thresholds on the last axis. For a 2-D ``x`` of shape
        ``(N, F)`` with default bounds this is ``(F, num_bits)``. When both
        bounds are 0-d (e.g. 1-D ``x`` or scalar bounds) it is
        ``(1, num_bits)``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x)
    # Caller-supplied bounds are converted onto x's device so the arithmetic
    # below never mixes devices.
    if min_value is None:
        min_value = x.min(dim=0).values
    else:
        min_value = torch.as_tensor(min_value, device=x.device)
    if max_value is None:
        max_value = x.max(dim=0).values
    else:
        max_value = torch.as_tensor(max_value, device=x.device)

    step = (max_value - min_value) / (num_bits + 1)
    # Row vector of multipliers 1..num_bits, created on x's device: a
    # non-scalar CPU tensor cannot be combined with CUDA tensors.
    multipliers = torch.arange(1, num_bits + 1, device=x.device).unsqueeze(0)
    return min_value.unsqueeze(-1) + multipliers * step.unsqueeze(-1)

def gaussian_thresholds(x, num_bits, individual=True):
    """Place thresholds at normal-distribution quantiles fitted to ``x``.

    A normal distribution with the mean and standard deviation of ``x`` is
    assumed. Threshold ``k`` (``k = 1..num_bits``) is
    ``mean + std * z_k``, where ``z_k`` is the standard-normal quantile at
    ``k / (num_bits + 1)``.

    Args:
        x: Data to fit (tensor or array-like). Integer/bool data is promoted
            to the default floating dtype, since mean/std need a float dtype.
        num_bits: Number of thresholds to produce.
        individual: If True, use per-column statistics (``dim=0``);
            otherwise use a single mean/std over all of ``x``.

    Returns:
        Tensor with the ``num_bits`` thresholds stacked on the last axis, e.g.
        ``(F, num_bits)`` for a 2-D ``x`` with ``individual=True`` and
        ``(num_bits,)`` with ``individual=False``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x)
    # Tensor.mean/std raise for integer and bool dtypes (e.g. a list of ints).
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.get_default_dtype())

    # Standard-normal quantiles at evenly spaced probabilities in (0, 1).
    probs = torch.arange(1, num_bits + 1) / (num_bits + 1)
    std_skews = torch.distributions.Normal(0, 1).icdf(probs)
    mean = x.mean(dim=0) if individual else x.mean()
    # x.std uses PyTorch's default (Bessel-corrected) estimator.
    std = x.std(dim=0) if individual else x.std()
    # Each std_skew is a 0-d tensor, so it combines with mean/std on any device.
    return torch.stack([std_skew * std + mean for std_skew in std_skews], dim=-1)

def distributive_thresholds(x, num_bits=1, individual=True):
    """Pick thresholds at evenly spaced empirical quantiles of ``x``.

    The data is sorted and the elements at positions
    ``int(count * k / (num_bits + 1))`` for ``k = 1..num_bits`` are returned,
    where ``count`` is the length of the sorted axis.

    Args:
        x: Data to sample thresholds from (tensor or array-like).
        num_bits: Number of thresholds to produce.
        individual: If True, sort each column independently along ``dim=0``;
            otherwise sort all values of ``x`` pooled together.

    Returns:
        Tensor whose last axis holds the ``num_bits`` thresholds, e.g.
        ``(F, num_bits)`` for a 2-D ``x`` with ``individual=True`` and
        ``(num_bits,)`` with ``individual=False``.
    """
    if type(x) is not torch.Tensor:
        x = torch.tensor(x)

    if individual:
        ordered = torch.sort(x, dim=0).values
    else:
        ordered = torch.sort(x.flatten()).values

    count = ordered.shape[0]
    positions = torch.tensor(
        [int(count * k / (num_bits + 1)) for k in range(1, num_bits + 1)]
    )
    picked = ordered[positions]

    # The selected thresholds sit on dim 0; rotate that axis to the end.
    axis_order = list(range(1, picked.ndim)) + [0]
    return picked.permute(*axis_order)

class Thresholds(torch.nn.Module):
    """Binarize each input feature against its own set of thresholds.

    Feature ``i`` (column ``x[:, i]``) is compared with ``self.thresholds[i]``.
    The boolean results for all features are concatenated along dim 1.

    Args:
        num_thresholds: Accepted for interface compatibility. It only
            satisfies the "at least one argument" check; no thresholds are
            created from it.
        thresholds: Per-feature thresholds. This is either an iterable with
            one tensor per feature, or a tensor whose dim 0 indexes features
            (e.g. a ``(num_features, num_bits)`` tensor such as the one
            ``linear_thresholds`` returns for 2-D data).

    Raises:
        ValueError: If neither ``num_thresholds`` nor ``thresholds`` is given.
    """

    def __init__(self, num_thresholds=None, thresholds=None):
        super().__init__()
        if num_thresholds is None and thresholds is None:
            raise ValueError("num_thresholds or thresholds need to be specified")
        if isinstance(thresholds, torch.Tensor):
            # nn.ParameterList refuses to unpack a bare tensor, and the
            # truthiness of a multi-element tensor is ambiguous, so split it
            # along dim 0 into one entry per feature.
            thresholds = list(thresholds.unbind(0))
        # thresholds=None yields an empty list (the previous behaviour when
        # only num_thresholds was given).
        self.thresholds = torch.nn.ParameterList(thresholds)

    def forward(self, x):
        """Return the concatenated per-feature comparisons ``x[:, i] > thresholds[i]``."""
        # The slice i:i+1 keeps the feature axis. For a 2-D x and a 1-D
        # thresholds[i] of length k, each comparison is (batch, 1, k) and
        # flattens to (batch, k).
        return torch.cat(
            [
                (x[:, i:i + 1].unsqueeze(-1) > self.thresholds[i]).flatten(start_dim=1)
                for i in range(x.shape[1])
            ],
            dim=1,
        )