import math
import numpy as np


# Bit reversal permutation


def bitreversal_po2(n):
    """
    Return the bit-reversal permutation of ``range(n)`` for a power-of-two ``n``.

    From S4 codebase.

    Entry ``i`` of the result is ``i`` with its ``log2(n)``-bit binary
    representation reversed, e.g. ``n = 8`` gives ``[0, 4, 2, 6, 1, 5, 3, 7]``.

    :param n: length of the permutation; must be a positive power of two.
    :return: 1-D integer ndarray of length ``n``.
    :raises ValueError: if ``n`` is not a positive power of two.
    """
    size = int(n)
    # ``size & (size - 1)`` is zero only when a single bit is set.
    if size != n or size < 1 or size & (size - 1):
        raise ValueError("n must be a positive power of two, got %r" % (n,))

    # Exact integer log2; a ratio of float logarithms may round to the wrong
    # side of an integer and drop (or add) one interleaving step.
    num_bits = size.bit_length() - 1

    # Start with one index per row and repeatedly place the lower half of the
    # rows next to the upper half; after ``num_bits`` steps a single row is left.
    perm = np.arange(size).reshape(size, 1)
    for _ in range(num_bits):
        half = perm.shape[0] // 2
        perm = np.hstack((perm[:half], perm[half:]))
    return perm.squeeze(0)


def bitreversal_permutation(n):
    """
    Return the indices ``0 .. n - 1`` in bit-reversal order, for any ``n >= 1``.

    From S4 codebase.

    The bit-reversal permutation of the smallest power of two ``N >= n`` is
    built with ``bitreversal_po2`` and the entries ``>= n`` are dropped, keeping
    the remaining entries in their original order.

    :param n: length of the permutation; must be at least 1.
    :return: 1-D integer ndarray holding the entries of the size-``N``
        permutation that are smaller than ``n``.
    :raises ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got %r" % (n,))

    # For an integer c >= 1, (c - 1).bit_length() equals ceil(log2(c)) exactly,
    # so an exact power of two is never bumped to the next power of two by
    # floating-point rounding of a logarithm ratio.
    m = (math.ceil(n) - 1).bit_length()
    N = 1 << m
    perm = bitreversal_po2(N)
    return np.extract(perm < n, perm)


def transpose_permutation(h, w):
    """
    Build the index permutation that transposes a row-major ``h x w`` grid.

    From S4 codebase.

    :param h: number of rows of the grid.
    :param w: number of columns of the grid.
    :return: 1-D integer ndarray of length ``h * w`` listing the row-major
        indices of the grid in column-by-column reading order.
    """
    total = h * w
    grid = np.arange(total).reshape((h, w))
    # Flattening the transposed grid walks the original grid column by column.
    return np.transpose(grid).reshape(total)


def snake_permutation(h, w):
    """
    Build the index permutation that reads a row-major ``h x w`` grid in
    snake (boustrophedon) order.

    From S4 codebase.

    Rows 0, 2, 4, ... are read left to right; rows 1, 3, 5, ... right to left.

    :param h: number of rows of the grid.
    :param w: number of columns of the grid.
    :return: 1-D integer ndarray of length ``h * w`` with the row-major indices
        in snake order.
    """
    total = h * w
    grid = np.arange(total).reshape((h, w))
    # Mirror every second row (starting with row 1) in place.
    grid[1::2, :] = np.fliplr(grid[1::2, :])
    return grid.reshape(total)


def hilbert_permutation(n):
    """
    Return the row-major indices of an ``n x n`` grid in Hilbert-curve order.

    From S4 codebase.

    The Hilbert integers ``0 .. n * n - 1`` are decoded into 2-D coordinates
    with ``decode`` and each coordinate pair is mapped to its row-major index.

    :param n: side length of the grid; must be a power of two (asserted).
    :return: 1-D ndarray of length ``n * n``; entry ``k`` is the row-major
        index of the grid cell that ``decode`` assigns to Hilbert integer ``k``.
    """
    bits_per_axis = int(math.log2(n))
    assert n == 2**bits_per_axis

    num_cells = n * n
    coords = decode(list(range(num_cells)), 2, bits_per_axis)
    # One row of coordinates per dimension after the transpose.
    first_axis, second_axis = coords.T

    grid = np.arange(num_cells).reshape((n, n))
    return grid[first_axis, second_axis]


def decode(hilberts, num_dims, num_bits):
    """
    From S4 codebase.

    Hilbert curve utilities taken from https://github.com/PrincetonLIPS/numpy-hilbert-curve

    Decode an array of Hilbert integers into locations in a hypercube.
    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:
    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
      Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
    Params:
    -------
     hilberts - An ndarray of Hilbert integers.  Must be an integer dtype and
                cannot have fewer bits than num_dims * num_bits.
     num_dims - The dimensionality of the hypercube. Integer.
     num_bits - The number of bits for each dimension. Integer.
    Returns:
    --------
     The output is an ndarray of unsigned integers with the same shape as hilberts
     but with an additional dimension of size num_dims.
    Raises:
    -------
     ValueError - If num_dims * num_bits is larger than 64.
    """

    total_bits = num_dims * num_bits
    if total_bits > 64:
        # Report the number of bits actually requested, not the uint64 limit.
        raise ValueError(
            "num_dims=%d and num_bits=%d for %d bits total, " % (num_dims, num_bits, total_bits)
            + "which can't be encoded into a uint64. Are you sure you need that many points on your "
            + "Hilbert curve?"
        )

    # Handle the case where we got handed a naked integer.
    hilberts = np.atleast_1d(hilberts)

    # Keep around the shape for later.
    orig_shape = hilberts.shape

    # Treat each of the hilberts as a sequence of eight uint8.
    # This treats all of the inputs as uint64 and makes things uniform.
    # The big-endian cast puts the most significant byte first in each row.
    hh_uint8 = np.reshape(hilberts.ravel().astype(">u8").view(np.uint8), (-1, 8))

    # Turn these lists of uints into lists of bits and then truncate to the size
    # we actually need for using Skilling's procedure.
    hh_bits = np.unpackbits(hh_uint8, axis=1)[:, -num_dims * num_bits :]

    # Take the sequence of bits and Gray-code it.
    gray = binary2gray(hh_bits)

    # There has got to be a better way to do this.
    # I could index them differently, but the eventual packbits likes it this way.
    # After the swap the layout is (point, dimension, bit).
    gray = np.swapaxes(
        np.reshape(gray, (-1, num_bits, num_dims)),
        axis1=1,
        axis2=2,
    )

    # Iterate backwards through the bits.
    for bit in range(num_bits - 1, -1, -1):
        # Iterate backwards through the dimensions.
        for dim in range(num_dims - 1, -1, -1):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = np.logical_xor(gray[:, 0, bit + 1 :], mask[:, np.newaxis])

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = np.logical_and(
                np.logical_not(mask[:, np.newaxis]),
                np.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = np.logical_xor(gray[:, dim, bit + 1 :], to_flip)
            gray[:, 0, bit + 1 :] = np.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits
    padded = np.pad(gray, ((0, 0), (0, 0), (extra_dims, 0)), mode="constant", constant_values=0)

    # Now chop these up into blocks of 8.
    # The bit axis is reversed first so that each block can be packed with
    # little bit order below.
    locs_chopped = np.reshape(padded[:, :, ::-1], (-1, num_dims, 8, 8))

    # Take those blocks and turn them unto uint8s.
    locs_uint8 = np.squeeze(np.packbits(locs_chopped, bitorder="little", axis=3))

    # Finally, treat these as uint64s.
    flat_locs = locs_uint8.view(np.uint64)

    # Return them in the expected shape.
    return np.reshape(flat_locs, (*orig_shape, num_dims))


def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.
    Parameters:
    -----------
    binary: An ndarray of binary values.
    k: The number of bits to shift. Must be non-negative. Default 1.
    axis: The axis along which to shift.  Default -1.
    Returns:
    --------
    Returns a new ndarray with the same shape and dtype as binary, with k zeros
    prepended and the last k entries truncated along whatever axis was
    specified. k=0 yields an unshifted copy.
    Raises:
    -------
    ValueError: If k is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative, got %r" % (k,))

    length = binary.shape[axis]
    shifted = np.zeros_like(binary)

    # If we're shifting the whole thing, just return zeros.
    if k >= length:
        return shifted

    # Copy the first (length - k) entries into positions k.. along the axis.
    # Using an explicit stop of (length - k) instead of -k keeps k=0 correct:
    # slice(None, -0) would select nothing.
    source = [slice(None)] * binary.ndim
    target = list(source)
    source[axis] = slice(None, length - k)
    target[axis] = slice(k, None)
    shifted[tuple(target)] = binary[tuple(source)]

    return shifted


def binary2gray(binary, axis=-1):
    """Gray-code an array of binary digits.
    Applies the classic X ^ (X >> 1) identity: every entry is XORed with the
    entry that precedes it along the chosen axis (the first entry is XORed
    with zero).
    Parameters:
    -----------
    binary: An ndarray of binary values.
    axis: The axis along which to compute the gray code. Default=-1.
    Returns:
    --------
    Returns an ndarray of Gray codes, as produced by np.logical_xor.
    """
    # X ^ (X >> 1), with the shift done by right_shift's default of one place.
    return np.logical_xor(binary, right_shift(binary, axis=axis))
