from functools import partial

import torch
import torch.utils._pytree as pytree

# Thin aliases exposing a small array API on top of PyTorch; presumably
# other code is written against these names rather than torch directly
# (NOTE(review): confirm against importers of this module).
#
# NOTE: `min` and `max` below shadow the Python builtins at module scope.
# Any later code in this module that calls `min`/`max` gets the torch
# versions; use `builtins.min` / `builtins.max` if the builtins are needed.

# Core tensor type.
Tensor = torch.Tensor

# Element-wise arithmetic and comparison.
add = torch.Tensor.add  # method descriptor, called as add(tensor, other)
equal = torch.equal
# Approximate comparison with tolerances much looser than torch's defaults
# (rtol=1e-5, atol=1e-8). Because this is a functools.partial, callers can
# still override rtol/atol by passing them as keyword arguments.
allclose = partial(torch.allclose, rtol=5e-3, atol=5e-3)

# Reductions.
min, max = torch.min, torch.max

# Construction and copying.
zeros_like = torch.zeros_like
clone = torch.clone
from_numpy = torch.from_numpy

# Joining, splitting and slicing.
concatenate, chunk, narrow = torch.concatenate, torch.chunk, torch.narrow

# Pytree (nested container) flatten / unflatten helpers.
tree_flatten, tree_unflatten = pytree.tree_flatten, pytree.tree_unflatten