import numpy as np

def _to_repeated_list(a, length):
    if isinstance(a, list):
        return a
    elif isinstance(a, tuple):
        return list(a)
    else:
        a = [a] * length
        return a
    
def pad_if_needed(im, min_shape, mode):
    """Pad the last two axes of ``im`` so they are at least ``min_shape``.

    Args:
        im: Array with two or more dimensions. Only the trailing two axes
            (``im.shape[-2]`` and ``im.shape[-1]``) are padded; any leading
            axes are left unpadded.
        min_shape: Minimum size for the trailing two axes, given either as a
            single value applied to both axes or as a list/tuple whose first
            two entries are ``(rows, cols)``.
        mode: Padding mode forwarded unchanged to ``np.pad``.

    Returns:
        ``im`` itself (not a copy) when both trailing axes already reach
        ``min_shape``; otherwise a new array padded symmetrically on the
        axes that were too small.

    Raises:
        ValueError: if ``im`` has fewer than two dimensions.
    """
    ndim = len(im.shape)
    if ndim < 2:
        raise ValueError(
            f"pad_if_needed expects an array with at least 2 dimensions, got shape {im.shape}"
        )

    min_shape = _to_repeated_list(min_shape, 2)
    rows, cols = im.shape[-2], im.shape[-1]
    if rows >= min_shape[0] and cols >= min_shape[1]:
        return im

    # Each too-small axis gets (deficit // 2 + 1) on BOTH sides. That always
    # reaches the target but overshoots it by 1 (odd deficit) or 2 (even
    # deficit) pixels; the formula is kept so existing output sizes don't change.
    pad_rows = (min_shape[0] - rows) // 2 + 1 if rows < min_shape[0] else 0
    pad_cols = (min_shape[1] - cols) // 2 + 1 if cols < min_shape[1] else 0

    # Leading axes are never padded; any number of them is supported.
    pad_width = [(0, 0)] * (ndim - 2) + [(pad_rows, pad_rows), (pad_cols, pad_cols)]
    return np.pad(im, pad_width=pad_width, mode=mode)
    
    
def ifft2_np(x):
    """Centered, orthonormal 2-D inverse FFT over the last two axes of ``x``.

    The input is cast to complex64 and ``fftshift``-ed. It is then transformed
    with ``np.fft.ifft2(..., norm='ortho')`` and ``ifftshift``-ed. All three
    steps act only on the last two axes, and the result is returned as
    complex64 with the same shape as ``x``.

    Any leading axes are batch axes: they are neither transformed nor shifted.
    Shifting them would only be undone by the final shift.

    NOTE(review): for even-sized axes this equals the common centered form
    ``fftshift(ifft2(ifftshift(x)))``. For odd sizes the two orders differ.
    Confirm odd-sized inputs are handled as intended.
    """
    axes = (-2, -1)
    shifted = np.fft.fftshift(x.astype(np.complex64), axes=axes)
    transformed = np.fft.ifft2(shifted, axes=axes, norm='ortho')
    # Cast back: depending on the NumPy version the FFT may compute in
    # complex128.
    return np.fft.ifftshift(transformed, axes=axes).astype(np.complex64)


def fft2_np(x):
    """Centered, orthonormal 2-D forward FFT over the last two axes of ``x``.

    The input is cast to complex64 and ``fftshift``-ed. It is then transformed
    with ``np.fft.fft2(..., norm='ortho')`` and ``ifftshift``-ed. All three
    steps act only on the last two axes, and the result is returned as
    complex64 with the same shape as ``x``.

    Any leading axes are batch axes: they are neither transformed nor shifted.
    Shifting them would only be undone by the final shift.

    NOTE(review): for even-sized axes this equals the common centered form
    ``fftshift(fft2(ifftshift(x)))``. For odd sizes the two orders differ.
    Confirm odd-sized inputs are handled as intended.
    """
    axes = (-2, -1)
    shifted = np.fft.fftshift(x.astype(np.complex64), axes=axes)
    transformed = np.fft.fft2(shifted, axes=axes, norm='ortho')
    # Cast back: depending on the NumPy version the FFT may compute in
    # complex128.
    return np.fft.ifftshift(transformed, axes=axes).astype(np.complex64)

def split_coils(x):
    """Split an array into a list of per-coil 2-D arrays.

    Args:
        x: 2-D array (treated as single-coil data) or 3-D array whose first
            axis indexes coils.

    Returns:
        list: ``[x]`` for 2-D input (the same object, not a copy). For 3-D
        input it is ``[x[0], x[1], ...]``, the sub-arrays along the first axis
        (views for NumPy arrays).

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D.
    """
    ndim = len(x.shape)
    if ndim == 2:
        return [x]
    if ndim == 3:
        # Iterating an array yields its sub-arrays along the first axis.
        return list(x)
    raise ValueError(f"split_coils expects a 2-D or 3-D array, got shape {x.shape}")

def stack_coils(x):
    """Join a list of per-coil arrays into one array.

    A list with a single element is treated as single-coil data: that element
    is returned as-is, with no new leading axis added. Longer lists are
    stacked along a new first axis with ``np.stack``. An empty list raises
    ``IndexError``.
    """
    if len(x) <= 1:
        # Single-coil (or empty): return the lone element directly.
        return x[0]
    return np.stack(x, axis=0)