from typing import Union

import torch
import numpy as np


def multidimensional_unfold(tensor: torch.Tensor, kernel_size: tuple, stride: tuple,
                            device: torch.device = torch.device('cpu')) -> torch.Tensor:
    r"""Unfolds `tensor` by extracting patches of shape `kernel_size`.

    Reshaping and traversal for patch extraction both follow C-order convention (last index changes the fastest):
    column ``j`` of the output is the ``j``-th patch in C-order over patch start positions, and every patch is
    flattened in C-order over the kernel dimensions.

    Args:
        tensor: Input tensor to be unfolded with shape [N, *spatial_dims] (N is batch dimension).
        kernel_size: Patch size, one entry per spatial dimension.
        stride: Stride of multidimensional traversal, one entry per spatial dimension.
        device: Device on which the output tensor is returned.

    Returns:
       Unfolded tensor with shape [N, :math:`\prod_k kernel_size[k]`, L], where L is the number of patch
       positions (0 when the kernel does not fit along some axis). Floating-point and complex inputs keep
       their dtype; any other dtype is converted to ``torch.get_default_dtype()``.

    Raises:
        ValueError: If `kernel_size` or `stride` does not have exactly one entry per spatial dimension.
    """
    spatial_dims = tuple(tensor.shape[1:])
    kernel_size = tuple(int(k) for k in kernel_size)
    stride = tuple(int(s) for s in stride)
    if len(kernel_size) != len(spatial_dims) or len(stride) != len(spatial_dims):
        raise ValueError(
            f"kernel_size {kernel_size} and stride {stride} must have one entry per spatial dimension "
            f"of a tensor with spatial shape {spatial_dims}"
        )

    batch = tensor.size(0)
    patch_numel = int(np.prod(kernel_size))
    # Number of patch positions per axis; clamped at 0 because a kernel larger than the axis fits nowhere.
    num_positions = [max((size - k) // s + 1, 0) for size, k, s in zip(spatial_dims, kernel_size, stride)]
    num_columns = int(np.prod(num_positions))

    # Integer/bool inputs are promoted to the default float dtype (as a zero-filled float buffer would be).
    if tensor.is_floating_point() or tensor.is_complex():
        dtype = tensor.dtype
    else:
        dtype = torch.get_default_dtype()

    if num_columns == 0:
        return torch.zeros(batch, patch_numel, 0, dtype=dtype, device=device)

    # Each Tensor.unfold replaces axis `d` by its position count and appends the kernel axis at the end,
    # leaving patches with shape [N, *num_positions, *kernel_size].
    patches = tensor
    for axis, (k, s) in enumerate(zip(kernel_size, stride), start=1):
        patches = patches.unfold(axis, k, s)

    # Flatten positions and kernel entries in C-order, then put the kernel entries on dim 1.
    patches = patches.reshape(batch, num_columns, patch_numel).transpose(1, 2)
    return patches.to(device=device, dtype=dtype).contiguous()
def multidimensional_slice(tensor: torch.Tensor, start: torch.Tensor, stop: torch.Tensor) -> torch.Tensor:
    """Returns A[start_1:stop_1, ..., start_n:stop_n] for tensor A.

    Dimensions beyond ``len(start)`` are kept whole.

    Args:
        tensor: Input tensor `A` (a NumPy array is sliced the same way).
        start: Start indices, one per leading dimension (1-D integer tensor or any sequence of integers).
        stop: Stop indices (exclusive), same length as `start`.

    Returns:
         A[start_1:stop_1, ..., start_n:stop_n]

    Raises:
        ValueError: If `start` and `stop` do not have the same length.
    """
    if len(start) != len(stop):
        raise ValueError(f"start and stop must have the same length, got {len(start)} and {len(stop)}")
    # Index with a tuple of slices: indexing with a *list* of slices is legacy NumPy behaviour
    # and raises for NumPy arrays in modern NumPy.
    index = tuple(slice(int(lo), int(hi)) for lo, hi in zip(start, stop))
    return tensor[index]


if __name__ == "__main__":
    # Smoke run. It must come after multidimensional_slice is defined, which multidimensional_unfold may call.
    tensori = torch.randn(1, 3, 3, 3)
    # One kernel/stride entry per spatial dimension (all dims after the batch dim): 2x2 patches per leading slice.
    kernel_size = (1, 2, 2)
    stride = (1, 1, 1)
    unfolded = multidimensional_unfold(tensori, kernel_size, stride)
    print(unfolded.shape)


def kron(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
    """Sum of Kronecker products of paired factors: ``sum_k kron(a[k], b[k])``.

    Factors are stacked along dimension 0, in the ``[rank, *shape]`` layout returned by ``gkpd``.
    With a single pair (leading dimension 1) this is the plain Kronecker product.

    Args:
        a: First factors, stacked along dimension 0.
        b: Second factors, stacked along dimension 0 (same count as `a`).

    Returns:
        Tensor containing ``sum_k kron(a[k], b[k])``. When there are no factor pairs, this is an all-zero
        tensor with the shape one Kronecker product would have.

    Raises:
        ValueError: If `a` and `b` do not contain the same number of factors (size of dimension 0).
    """

    a = torch.from_numpy(a) if isinstance(a, np.ndarray) else a
    b = torch.from_numpy(b) if isinstance(b, np.ndarray) else b

    if a.shape[0] != b.shape[0]:
        raise ValueError(
            f"kron expects the same number of factors in a and b, got {a.shape[0]} and {b.shape[0]}"
        )

    terms = [torch.kron(a_k, b_k) for a_k, b_k in zip(a, b)]
    if not terms:
        # Empty sum: a zero tensor with the shape/dtype torch.kron would produce for one pair.
        terms = [torch.kron(a.new_zeros(a.shape[1:]), b.new_zeros(b.shape[1:]))]
    return torch.stack(terms).sum(dim=0)

def gkpd(tensor: torch.Tensor, a_shape: Union[list, tuple], b_shape: Union[list, tuple],
         atol: float = 1e-3) -> tuple:
    """Finds Kronecker decomposition of `tensor` via SVD.

    Patch traversal and reshaping operations all follow a C-order convention (last dimension changing fastest).
    The tensor is rearranged so that each row is one flattened `b_shape`-sized block. A truncated SVD of that
    matrix yields factors with ``tensor ~= sum_k torch.kron(a_hat[k], b_hat[k])``, exact up to the discarded
    singular values.

    Args:
        tensor (torch.Tensor): Tensor to be decomposed.
        a_shape (list, tuple): Shape of first Kronecker factor.
        b_shape (list, tuple): Shape of second Kronecker factor.
        atol (float): Tolerance for determining tensor rank (singular values <= atol are dropped).

    Returns:
        a_hat: [rank, *a_shape]
        b_hat: [rank, *b_shape]
        (rank may be 0, in which case both are empty along dimension 0.)

    Raises:
        ValueError: If `a_shape` and `b_shape` do not both have one entry per tensor dimension, or if
            ``a_shape[i] * b_shape[i] != tensor.shape[i]`` for some ``i``.
    """

    a_shape = tuple(a_shape)
    b_shape = tuple(b_shape)
    # Explicit per-dimension check (NumPy broadcasting would accept wrong-length shapes).
    if (len(a_shape) != tensor.dim() or len(b_shape) != tensor.dim()
            or any(a_dim * b_dim != t_dim for a_dim, b_dim, t_dim in zip(a_shape, b_shape, tensor.shape))):
        raise ValueError("Received invalid factorization dimensions for tensor during its GKPD decomposition")

    with torch.no_grad():
        # Row r is the r-th b_shape block of `tensor` (C-order over blocks), flattened in C-order.
        w_unf = multidimensional_unfold(
            tensor.unsqueeze(0), kernel_size=b_shape, stride=b_shape, device=tensor.device
        )[0].T  # [prod(a_shape), prod(b_shape)]

        u, s, vh = torch.linalg.svd(w_unf, full_matrices=False)
        # Computed on-device: avoids .numpy(), which fails for CUDA tensors.
        rank = int((s.abs() > atol).sum().item())

        # Note: pytorch reshaping follows C-order as well.
        # Scale column i of U by s[i]; row i of Vh is the i-th right singular vector.
        a_hat = (u[:, :rank] * s[:rank]).T.reshape(rank, *a_shape)  # [rank, *a_shape]
        b_hat = vh[:rank].reshape(rank, *b_shape)  # [rank, *b_shape]

    return a_hat, b_hat