import torch
import numpy as np

def CUDA():
    """Pick the torch device string for this machine.

    Returns:
        ``'cuda'`` when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``'cpu'``.
    """
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

def norm(x, dim=None, except_dim=None, keepdim=False, no_sqrt=False):
    """Return the Euclidean (L2) norm of tensor ``x`` over selected dimensions.

    Args:
        x: Input tensor. For complex tensors the squared magnitude ``|x|**2``
            is summed, so the result is real.
        dim: Dimension or tuple of dimensions to reduce, passed through to
            ``torch.sum``. ``None`` (default) reduces over every dimension.
            Mutually exclusive with ``except_dim``.
        except_dim: If given, reduce over every dimension *except* this one.
            Negative indices are accepted. Mutually exclusive with ``dim``.
        keepdim: Keep reduced dimensions with size 1, as in ``torch.sum``.
        no_sqrt: If True, return the sum of squares instead of its square root.

    Returns:
        Tensor holding the norm (or the squared norm when ``no_sqrt`` is True).

    Raises:
        ValueError: If both ``dim`` and ``except_dim`` are given.
        IndexError: If ``except_dim`` is out of range for ``x``.
    """
    n_dims = len(x.shape)

    if except_dim is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim is not None:
            raise ValueError("norm: pass either `dim` or `except_dim`, not both")
        dims_to_reduce = list(range(n_dims))
        # list.pop accepts negative indices, matching torch's negative-dim convention.
        dims_to_reduce.pop(except_dim)
        reducing_dims = tuple(dims_to_reduce)
    elif dim is None:
        reducing_dims = tuple(range(n_dims))
    else:
        reducing_dims = dim

    # For complex input, x**2 is not the squared magnitude; use |x|**2 instead.
    squared = x.abs() ** 2 if torch.is_complex(x) else x ** 2

    if except_dim is not None and not reducing_dims:
        # Nothing is left to reduce (e.g. a 1-D input). torch.sum may treat an
        # empty dim tuple as "reduce everything", so skip the reduction to keep
        # the result elementwise along `except_dim`.
        sq_sum = squared
    else:
        sq_sum = torch.sum(squared, dim=reducing_dims, keepdim=keepdim)

    return sq_sum if no_sqrt else torch.sqrt(sq_sum)

