from .AsymRectangle import asym_rectangle_fn
from .Triangle import triangle_fn

import torch

def threshold(x: torch.Tensor, method: str = "rectangle", alpha: float | tuple[float, float] = 1.0) -> torch.Tensor:
    if method == "rectangle":
        return asym_rectangle_fn(x, alpha)
    elif method == "triangle":
        return triangle_fn(x, alpha)
    else:
        raise NotImplementedError(f"Threshold method {method} not implemented.")
    

    
