from typing import overload

import numpy as np
import torch


@overload
def sigmoid(x: np.ndarray) -> np.ndarray:
    ...


@overload
def sigmoid(x: torch.Tensor) -> torch.Tensor:
    ...


def sigmoid(x):
    """Apply the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    Args:
        x: A ``torch.Tensor``, or anything ``np.exp`` accepts (``np.ndarray``,
            NumPy scalar, Python float).

    Returns:
        The sigmoid of ``x``: a tensor for tensor input, otherwise the result
        of the NumPy computation.
    """
    # Select the backend before building the expression so that
    # ``1 + exp(-x)`` is always the denominator for both backends.
    # Dispatch on ``torch.Tensor`` rather than ``np.ndarray``: NumPy scalars
    # (e.g. from indexing an ndarray) are not ndarray instances and would
    # otherwise be passed to ``torch.exp``, which rejects them.
    exp = torch.exp if isinstance(x, torch.Tensor) else np.exp
    return 1 / (1 + exp(-x))


@overload
def logit(x: np.ndarray) -> np.ndarray:
    ...


@overload
def logit(x: torch.Tensor) -> torch.Tensor:
    ...


def logit(x):
    """Apply the logit (inverse sigmoid) ``log(x) - log(1 - x)`` element-wise.

    Args:
        x: A ``torch.Tensor``, or anything ``np.log`` accepts (``np.ndarray``,
            NumPy scalar, Python float). Meaningful for values in (0, 1);
            0 and 1 map to -inf and +inf respectively.

    Returns:
        The logit of ``x``: a tensor for tensor input, otherwise the result
        of the NumPy computation.
    """
    # Dispatch on ``torch.Tensor`` so NumPy scalars (which are not ndarray
    # instances) use ``np.log`` instead of being rejected by ``torch.log``.
    log = torch.log if isinstance(x, torch.Tensor) else np.log
    return log(x) - log(1 - x)
