import torch

def calculate_quantile(arr, value):
    """Return the fraction of entries of ``arr`` that are ``<= value``.

    This is the empirical CDF of the samples in ``arr`` evaluated at
    ``value``.

    Args:
        arr: 1D tensor of samples.
        value: Threshold to compare against; presumably a scalar (a tensor
            that broadcasts against ``arr`` would be summed over the whole
            broadcast result — TODO confirm callers only pass scalars).

    Returns:
        Python float: count of ``arr <= value`` divided by ``arr.numel()``.
        An empty ``arr`` raises ``ZeroDivisionError``.
    """
    assert arr.dim() == 1, "Input array must be a 1D tensor"
    n_total = arr.numel()
    n_at_or_below = (arr <= value).sum().item()
    return n_at_or_below / n_total


def compute_class_weights(y_pl_cls, num_classes, return_counts=False, clamp=False):
    """Compute smoothed inverse-frequency class weights and per-sample weights.

    Class ``c`` gets weight ``N / ((count_c + 1) * num_classes)``, where ``N``
    is the number of labels in ``y_pl_cls``. The ``+ 1`` keeps classes that
    never occur from causing a division by zero. Every sample then receives
    the weight of its class.

    Args:
        y_pl_cls: 1D integer tensor of shape (N,) with class indices in
            ``[0, num_classes)``.
        num_classes: Total number of classes.
        return_counts: If True, also return the per-class label counts.
        clamp: Upper bound applied to the per-sample weights. Any falsy value
            (the default ``False``, ``None``, ``0``) disables clamping.

    Returns:
        ``(weights, class_weights)``, or ``(weights, class_weights,
        class_counts)`` when ``return_counts`` is True. ``weights`` has shape
        (N,); ``class_weights`` and ``class_counts`` are float tensors of shape
        (num_classes,). All results live on ``y_pl_cls.device``.
    """
    # Allocate on the labels' device: indexing a tensor with indices that live
    # on a different device raises, so a CPU-only buffer breaks CUDA labels.
    class_counts = torch.zeros(
        num_classes, dtype=torch.float, device=y_pl_cls.device
    )
    unique_classes, counts = torch.unique(y_pl_cls, return_counts=True)
    # Labels outside [0, num_classes) are out of range for this buffer and
    # cause an indexing error here.
    class_counts[unique_classes] = counts.float()

    total_samples = len(y_pl_cls)
    class_weights = total_samples / ((class_counts + 1.0) * num_classes)

    weights = class_weights[y_pl_cls]
    if clamp:
        weights = torch.clamp(weights, max=clamp)
    if return_counts:
        return weights, class_weights, class_counts
    return weights, class_weights
    

def select_prob(y_pred_prob, threshold=0.68):
    """Keep the top entries of each row until their mass reaches ``threshold``.

    For every row (the last dimension), entries are visited in descending
    order and kept until the running sum first reaches ``threshold``. The
    entry that crosses the threshold is kept as well, and the top entry is
    always kept. All other entries are zeroed and each row is divided by its
    sum, similar to top-p (nucleus) truncation.

    The computation is fully vectorized: no per-element Python loop and no
    host/device synchronization per element.

    Args:
        y_pred_prob: Tensor whose last dimension holds per-class scores,
            typically probabilities of shape (N, C). Leading dimensions are
            treated as independent rows; a 1D tensor is a single row.
        threshold: Cumulative mass at which selection stops.

    Returns:
        Tensor of the same shape as ``y_pred_prob``. A row whose kept entries
        sum to zero becomes NaN (0 / 0).
    """
    # A stable sort makes tie-breaking deterministic (lower index first).
    sorted_vals, sorted_idx = torch.sort(
        y_pred_prob, dim=-1, descending=True, stable=True
    )
    # Accumulate in float64 to mirror a Python-float running sum.
    running = sorted_vals.to(torch.float64).cumsum(dim=-1)
    # hits[..., k] = how many positions 0..k already reached the threshold.
    hits = (running >= threshold).to(torch.int64).cumsum(dim=-1)
    # Position k is kept iff no earlier position reached the threshold, so the
    # crossing entry is kept and the top entry is always kept.
    keep = torch.ones_like(sorted_vals, dtype=torch.bool)
    keep[..., 1:] = hits[..., :-1] == 0
    kept_vals = sorted_vals.masked_fill(~keep, 0)
    # Undo the sort: write the kept values back to their original positions.
    selection = torch.zeros_like(y_pred_prob).scatter(-1, sorted_idx, kept_vals)
    return selection / torch.sum(selection, dim=-1, keepdim=True)