from abc import ABC
from typing import Union

import numpy as np
import torch
from torch import Tensor


class BaseSelectorScore(ABC):
    """Wrap a scoring callable so it accepts either NumPy arrays or torch tensors.

    The wrapped ``fn`` is invoked as ``fn(x, *args, **fn_kwargs)``, where
    ``args`` and ``fn_kwargs`` are captured at construction time. NumPy inputs
    are converted with ``torch.from_numpy`` and cast to float32 before the call.
    Any other input (e.g. a ``torch.Tensor``) is passed through unchanged.
    """

    def __init__(self, fn, *args, **fn_kwargs):
        """Store the scoring callable and the extra arguments bound to it.

        Args:
            fn: Callable that takes the input tensor as its first argument.
            *args: Extra positional arguments forwarded to ``fn`` after ``x``.
            **fn_kwargs: Keyword arguments forwarded to ``fn`` on every call.
        """
        super().__init__()
        self.fn = fn
        # Positional extras must be kept and forwarded; dropping them would
        # silently call ``fn`` with a different signature than requested.
        self.fn_args = args
        self.fn_kwargs = fn_kwargs

    @staticmethod
    def _to_tensor(x: Union[Tensor, np.ndarray]) -> Tensor:
        """Convert a NumPy array to a float32 tensor; return anything else as-is."""
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).float()
        return x

    def forward(self, x: Union[Tensor, np.ndarray]) -> Tensor:
        """Apply ``fn`` to ``x`` (converted to a tensor if it is an ndarray).

        Args:
            x: Input data, as a ``torch.Tensor`` or ``np.ndarray``.

        Returns:
            Whatever ``fn`` returns for the converted input.
        """
        return self.fn(self._to_tensor(x), *self.fn_args, **self.fn_kwargs)

    def forward_logits(self, x: Union[Tensor, np.ndarray]) -> Tensor:
        """Score ``x``; currently identical to :meth:`forward`.

        NOTE(review): kept as a separate entry point, presumably so subclasses
        can specialise logit scoring independently of ``forward`` — confirm.
        """
        return self.forward(x)
