Module bitsandbytes.ops
Expand source code
import torch
import os
import ctypes as ct
# Load the compiled native library that ships next to this module. The path is
# anchored to the module's absolute directory: with plain string concatenation
# an empty dirname (relative __file__ without a directory part) would turn the
# path into '/libClusterNet.so' at the filesystem root.
lib = ct.cdll.LoadLibrary(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libClusterNet.so')
)
def get_ptr(A: torch.Tensor) -> ct.c_void_p:
    '''
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.Tensor
        The PyTorch tensor.

    Returns
    -------
    ctypes.c_void_p:
        Address of the first element of `A`.
    '''
    # Tensor.data_ptr() accounts for the tensor's storage offset. The storage's
    # own data_ptr() is the start of the underlying buffer, which is the wrong
    # address for sliced tensors and other views.
    return ct.c_void_p(A.data_ptr())
def estimate_quantiles(A: torch.Tensor, out: torch.Tensor=None, offset: float=1/512) -> torch.Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be circumvented by using the offset variable.
    Default offset value of 1/512 ensures minimum entropy encoding. An offset value
    of 0.01 to 0.02 usually has a much lower error. Given an offset of 0.02 equidistant
    points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Any shape.
    out : torch.Tensor
        Tensor with the 256 estimated quantiles. If None, a zero-initialized
        float32 tensor of 256 elements is allocated on the device of `A`.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/512

    Returns
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.

    Raises
    ------
    NotImplementedError
        If `A` is neither float32 nor float16.
    '''
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)

    # The native kernels receive raw pointers plus the element count.
    # NOTE(review): presumably `A` must be contiguous for this to be valid -
    # confirm against the native implementation.
    if A.dtype == torch.float32:
        lib.estimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    elif A.dtype == torch.float16:
        lib.estimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    else:
        # Was `NotImplementError`, an undefined name that surfaced as NameError.
        raise NotImplementedError(f'Not supported data type {A.dtype}')
    return out
Functions
def estimate_quantiles(A: torch.Tensor, out: torch.Tensor = None, offset: float = 0.001953125) ‑> torch.Tensor
-
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles via the eCDF of the input tensor A. This is a fast but approximate algorithm and the extreme quantiles close to 0 and 1 have high variance / large estimation errors. These large errors can be circumvented by using the offset variable. Default offset value of 1/512 ensures minimum entropy encoding. An offset value of 0.01 to 0.02 usually has a much lower error. Given an offset of 0.02 equidistant points in the range [0.02, 0.98] are used for the quantiles.
Parameters
A
:torch.Tensor
- The input tensor. Any shape.
out
:torch.Tensor
- Tensor with the 256 estimated quantiles.
offset
:float
- The offset for the first and last quantile from 0 and 1. Default: 1/512
Returns
torch.Tensor:
- The 256 quantiles in float32 datatype.
Expand source code
def estimate_quantiles(A: torch.Tensor, out: torch.Tensor=None, offset: float=1/512) -> torch.Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be circumvented by using the offset variable.
    Default offset value of 1/512 ensures minimum entropy encoding. An offset value
    of 0.01 to 0.02 usually has a much lower error. Given an offset of 0.02 equidistant
    points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Any shape.
    out : torch.Tensor
        Tensor with the 256 estimated quantiles. If None, a zero-initialized
        float32 tensor of 256 elements is allocated on the device of `A`.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/512

    Returns
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.

    Raises
    ------
    NotImplementedError
        If `A` is neither float32 nor float16.
    '''
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)

    # The native kernels receive raw pointers plus the element count.
    # NOTE(review): presumably `A` must be contiguous for this to be valid -
    # confirm against the native implementation.
    if A.dtype == torch.float32:
        lib.estimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    elif A.dtype == torch.float16:
        lib.estimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
    else:
        # Was `NotImplementError`, an undefined name that surfaced as NameError.
        raise NotImplementedError(f'Not supported data type {A.dtype}')
    return out
def get_ptr(A: torch.Tensor) ‑> ctypes.c_void_p
-
Get the ctypes pointer from a PyTorch Tensor.
Parameters
A
:torch.Tensor
- The PyTorch tensor.
Expand source code
def get_ptr(A: torch.Tensor) -> ct.c_void_p:
    '''
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.Tensor
        The PyTorch tensor.

    Returns
    -------
    ctypes.c_void_p:
        Address of the first element of `A`.
    '''
    # Tensor.data_ptr() accounts for the tensor's storage offset. The storage's
    # own data_ptr() is the start of the underlying buffer, which is the wrong
    # address for sliced tensors and other views.
    return ct.c_void_p(A.data_ptr())