import torch
from torch import Tensor

from src.selection.scores import register_selector_score
from src.selection.scores.base import BaseSelectorScore


def doctor(logits: Tensor, temperature: float = 1):
    """Return the Gini impurity of the temperature-scaled softmax of ``logits``.

    The logits are divided by ``temperature``. They are turned into
    probabilities with a softmax over dimension 1. The score is one minus the
    sum of the squared probabilities along that same dimension.

    Args:
        logits: Tensor of raw scores whose dimension 1 is reduced over.
            It is presumably shaped ``(batch, num_classes)``; confirm with callers.
        temperature: Divisor applied to ``logits`` before the softmax.

    Returns:
        Tensor with dimension 1 reduced away. Each value lies in
        ``[0, 1 - 1/C]``, where ``C`` is the size of dimension 1. A one-hot
        distribution gives 0 and a uniform distribution gives the maximum.
    """
    scaled = logits / temperature
    probs = scaled.softmax(dim=1)
    # Sum of squared probabilities: 1 for a one-hot row, 1/C for a uniform row.
    purity = probs.pow(2).sum(dim=1)
    return 1 - purity


@register_selector_score("gini")
class GiniSelector(BaseSelectorScore):
    """Selector score registered under the key ``"gini"`` that uses ``doctor``.

    Args:
        temperature: Forwarded to ``BaseSelectorScore`` together with
            ``fn=doctor``. It is presumably handed on to ``doctor`` as its
            softmax temperature; confirm against ``BaseSelectorScore``.
    """

    def __init__(self, temperature: float = 1):
        score_fn = doctor
        super().__init__(temperature=temperature, fn=score_fn)
