import torch


def remove_nan(x):
    """Zero out every NaN entry of ``x`` in place and return ``x``.

    The input tensor itself is modified; the returned object is the same
    tensor, not a copy.
    """
    nan_mask = torch.isnan(x)
    # In-place fill so callers holding a reference to ``x`` see the change.
    x.masked_fill_(nan_mask, 0.)
    return x

def rbf_func(values, v_min=0., v_max=20., n_bins=16):
    """Expand ``values`` into a radial-basis encoding along a new last axis.

    ``n_bins`` centers are spaced evenly from ``v_min`` to ``v_max``
    (both endpoints included), and every center shares the width
    ``(v_max - v_min) / n_bins``.

    Args:
        values: Tensor of any shape; centers are created on its device
            and with its dtype.
        v_min: Position of the first center.
        v_max: Position of the last center.
        n_bins: Number of centers, i.e. size of the appended axis.

    Returns:
        Tensor of shape ``values.shape + (n_bins,)`` holding
        ``exp(-((value - center) / width) ** 2)`` (note: no 1/2 factor
        in the exponent).
    """
    width = (v_max - v_min) / n_bins
    centers = torch.linspace(
        v_min, v_max, n_bins, device=values.device, dtype=values.dtype
    )
    # The trailing singleton axis on ``values`` broadcasts against the
    # 1-D ``centers``, giving one column per center.
    offsets = values.unsqueeze(-1) - centers
    scaled = offsets / width
    return torch.exp(-scaled ** 2)

def decouple(U, v_min=0., v_max=20., n_bins=16):
    """Split vectors in ``U`` into a direction part and an encoded length.

    The norm is taken over the last axis of ``U``. The direction is ``U``
    divided by that norm (plus ``1e-6`` so a zero vector does not divide
    by zero), and the norm itself is expanded with ``rbf_func``.

    Args:
        U: Tensor whose last axis holds the vector components.
        v_min: Forwarded to ``rbf_func``.
        v_max: Forwarded to ``rbf_func``.
        n_bins: Forwarded to ``rbf_func``.

    Returns:
        Concatenation along the last axis of the direction and the RBF
        features of the norm.
    """
    magnitude = U.norm(dim=-1, keepdim=True)
    unit_dir = U / (magnitude + 1e-6)
    # Drop the kept singleton axis so rbf_func appends the bin axis in its place.
    radial = rbf_func(magnitude.squeeze(-1), v_min, v_max, n_bins)
    return torch.cat((unit_dir, radial), dim=-1)