import torch
import torch_dct as dct

# ##############################################################################
# # SSRFT
# Adapted from https://github.com/andres-fr/skerch/tree/main
# ##############################################################################
class SSRFT:
    """Scrambled Subsampled Randomized Fourier Transform (SSRFT).

    This class encapsulates the left- and right-SSRFT transforms into a single
    linear operator, which is deterministic for the same shape and seed
    (particularly, also across different torch devices).
    """

    def __init__(self, shape, seed=0b1110101001010101011):
        """:param shape: ``(height, width)`` of linear operator."""
        if len(shape) != 2:
            raise ValueError("Shape must be a (height, width) pair!")
        self.shape = shape
        self.seed = seed
        h, w = shape
        if h > w:
            raise ValueError("Height > width not supported!")
        # :param scale: Ideally, ``1/l``, where ``l`` is the average diagonal
        #   value of the covmat ``A.T @ A``, where ``A`` is a FastJLT operator,
        #   so that ``l2norm(x)`` approximates ``l2norm(Ax)``.
        self.scale = NotImplemented

    def check_input(self, x, adjoint):
        """Checking that input has compatible shape.

        :param x: The input to this linear operator.
        :param bool adjoint: If true, ``x @ self`` is assumed, otherwise
          ``self @ x``.
        """
        try:
            assert len(x.shape) in {
                1,
                2,
            }, "Only vector or matrix input supported"
            #
            if adjoint:
                assert (
                    x.shape[-1] == self.shape[0]
                ), f"Mismatching shapes! {x.shape} <--> {self.shape}"
            else:
                assert (
                    x.shape[0] == self.shape[1]
                ), f"Mismatching shapes! {self.shape} <--> {x.shape}"
        except AssertionError as ae:
            raise ValueError from ae

    def __matmul__(self, x):
        """Forward (right) matrix-vector multiplication ``SSRFT @ x``.

        See parent class for more details.
        """
        self.check_input(x, adjoint=False)
        scale = (self.shape[1] / self.shape[0])**0.5
        return scale * ssrft(x, self.shape[0], seed=self.seed, dct_norm="ortho")

    def __rmatmul__(self, x):
        """Adjoint (left) matrix-vector multiplication ``x @ SSRFT``.

        See parent class for more details.
        """
        self.check_input(x, adjoint=True)
        scale = (self.shape[1] / self.shape[0])**0.5
        return scale * ssrft_adjoint(x, self.shape[1], seed=self.seed, dct_norm="ortho")

    def get_row(self, idx, dtype, device):
        """Returns SSRFT[idx, :] via left-matmul with a one-hot vector."""
        in_buff = torch.zeros(self.shape[0], dtype=dtype, device=device)
        in_buff[idx] = 1
        return in_buff @ self
    

def ssrft(x, out_dims, seed=0b1110101001010101011, dct_norm="ortho"):
    r"""Right (forward) matrix multiplication of the SSRFT.

    This function implements a matrix-free, right-matmul operator of the
    Scrambled Subsampled Randomized Fourier Transform (SSRFT) for real-valued
    signals, from `[TYUC2019, 3.2] <https://arxiv.org/abs/1902.08651>`_.

    .. math::

      \text{SSRFT} = R\,\mathcal{F}\,\Pi\,\mathcal{F}\,\Pi'

    Where :math:`R` is a random index-picker, \mathcal{F} is a Discrete Cosine
    Transform, and :math:`\Pi, \Pi'` are random permutations.

    :param x: Vector to be projected, such that ``y = SSRFT @ x``
    :param out_dims: Dimensions of output ``y``, must be less than ``dim(x)``
    :param seed: Random seed
    """
    # make sure all sources of randomness are CPU, to ensure cross-device
    # consistency of the operator
    if len(x.shape) != 1:
        raise ValueError("Only flat tensors supported!")
    x_len = len(x)
    assert out_dims <= x_len, "Projection to larger dimensions not supported!"
    seeds = [seed + i for i in range(5)]
    # first scramble: permute, rademacher, and DCT
    perm1 = randperm(x_len, seed=seeds[0], device="cpu")
    x, rad1 = rademacher_flip(x[perm1], seed=seeds[1], inplace=False)
    del perm1, rad1
    x = dct.dct(x, norm=dct_norm)
    # second scramble: permute, rademacher and DCT
    perm2 = randperm(x_len, seed=seeds[2], device="cpu")
    x, rad2 = rademacher_flip(x[perm2], seeds[3], inplace=False)
    del perm2, rad2
    x = dct.dct(x, norm=dct_norm)
    # extract random indices and return
    out_idxs = randperm(x_len, seed=seeds[4], device="cpu")[:out_dims]
    x = x[out_idxs]
    return x


def ssrft_adjoint(x, out_dims, seed=0b1110101001010101011, dct_norm="ortho"):
    r"""Left (adjoint) matrix multiplication of the SSRFT.

    Adjoint operator of SSRFT, such that ``x @ SSRFT = y``. See :func:`.ssrft`
    for more details. Note the following implementation detail:

    * Permutations are orthogonal transforms
    * Rademacher transforms are also orthogonal (also diagonal and self-inverse)
    * DCT/DFT are also orthogonal transforms
    * The index-picker :math:`R` is a subset of rows of I.

    With orthogonal operators, transform and inverse are the same. Therefore,
    this adjoint operator takes the following form:

    .. math::

       \text{SSRFT}^T =& (R\,\mathcal{F}\,\Pi\,\mathcal{F}\,\Pi')^T \\
       =& \Pi'^T \, \mathcal{F}^T \, \Pi^T \, \mathcal{F}^T \, R^T \\
       =& \Pi'^{-1} \, \mathcal{F}^{-1} \, \Pi^{-1} \, \mathcal{F}^{-1} \, R^T

    So we can make use of the inverses, except for :math:`R^T`, which is a
    column-truncated identity, so we embed the entries picked by :math:`R` into
    the corresponding indices, and leave the rest as zeros.
    """
    
    # make sure all sources of randomness are CPU, to ensure cross-device
    # consistency of the operator
    x_len = len(x)
    assert (
        out_dims >= x_len
    ), "Backprojection into smaller dimensions not supported!"
    #
    seeds = [seed + i for i in range(5)]
    result = torch.zeros(
        out_dims,
        dtype=x.dtype,
    ).to(x.device)
    # first embed signal into original indices
    out_idxs = randperm(out_dims, seed=seeds[4], device="cpu")[:x_len]
    result = torch.scatter(result, 0, out_idxs, x)
    #result[out_idxs] = x
    del x
    # then do the idct, followed by rademacher and inverse permutation
    result = dct.idct(result, norm=dct_norm)
    result, _ = rademacher_flip(result, seeds[3], inplace=False)
    perm2_inv = randperm(out_dims, seed=seeds[2], device="cpu", inverse=True)
    result = result[perm2_inv]
    del perm2_inv
    # second inverse pass
    result = dct.idct(result, norm=dct_norm)
    result, _ = rademacher_flip(result, seeds[1], inplace=False)
    perm1_inv = randperm(out_dims, seed=seeds[0], device="cpu", inverse=True)
    result = result[perm1_inv]
    #
    return result


def randperm(n, seed=None, device="cpu", inverse=False):
    """Reproducible randperm of ``n`` integers from  0 to (n-1) (both included).

    :param bool inverse: If False, a random permutation ``P`` is provided. If
      true, an inverse permutation ``Q`` is provided, such that both
      permutations are inverse to each other, i.e. ``v == v[P][Q] == v[Q][P]``.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    #
    perm = torch.randperm(n, generator=rng, device=device)
    if inverse:
        # we take the O(N) approach since we anticipate large N
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(perm.size(0), device=perm.device)
        perm = inv
    return perm


def rademacher_flip(x, seed=None, inplace=True, rng_device="cpu"):
    """Reproducible random sign flip using Rademacher noise.

    .. note::
      This function makes use of :func:`uniform_noise` to sample the Rademacher
      noise. If ``x`` itself has been generated using ``uniform_noise``, make
      sure to use a different seed to mitigate correlations.

    .. warning::
      See :func:`rademacher_noise` for notes on reproducibility and more info.
    """
    mask = rademacher_noise(x.shape, seed, device=rng_device).to(x.device)
    if inplace:
        x *= mask
        return x, mask
    else:
        return x * mask, mask
    
def rademacher_noise(shape, seed=None, device="cpu"):
    """Reproducible Rademacher noise.

    .. note::
      This function makes use of :func:`uniform_noise` to sample the Rademacher
      noise. If ``x`` itself has been generated using ``uniform_noise``, make
      sure to use a different seed to mitigate correlations.

    .. warning::
      PyTorch does not ensure RNG reproducibility across
      devices. This parameter determines the device to generate the noise from.
      If you want cross-device reproducibility, make sure that the noise is
      always generated from the same device.

    :param shape: Shape of the output tensor with Rademacher noise.
    :param seed: Seed for the randomness.
    :param device: Device of the output tensor and also source for the noise.
      See warning.
    """
    noise = (
        uniform_noise(shape, seed=seed, dtype=torch.float32, device=device)
        > 0.5
    ) * 2 - 1
    return noise

def uniform_noise(shape, seed=None, dtype=torch.float64, device="cpu"):
    """Reproducible ``torch.rand`` uniform noise.

    :returns: A tensor of given shape, dtype and device, containing uniform
      random noise between 0 and 1 (analogous to ``torch.rand``), but with
      reproducible behaviour fixed to given random seed.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    noise = torch.rand(shape, generator=rng, dtype=dtype, device=device)
    return noise


def gaussian_noise(
    shape, mean=0.0, std=1.0, seed=None, dtype=torch.float64, device="cpu"
):
    """Reproducible ``torch.normal`` Gaussian noise.

    :returns: A tensor of given shape, dtype and device, containing gaussian
      noise with given mean and std (analogous to ``torch.normal``), but with
      reproducible behaviour fixed to given random seed.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    #
    noise = torch.zeros(shape, dtype=dtype, device=device)
    noise.normal_(mean=mean, std=std, generator=rng)
    return noise


