import tensorly as tl
from ._base_decomposition import DecompositionMixin
from ..tr_tensor import validate_tr_rank, TRTensor
from ..tenalg.svd import svd_interface


def tensor_ring(input_tensor, rank, mode=0, svd="truncated_svd", verbose=False):
    """Tensor Ring decomposition via recursive SVD

    Decomposes `input_tensor` into a sequence of order-3 tensors (factors) [1]_.
    The k-th factor has shape ``(R_k, input_tensor.shape[k], R_{k+1})`` with
    ``R_k <= rank[k]``; the ring is closed by the last factor, whose trailing
    rank equals the leading rank of the first factor.

    Parameters
    ----------
    input_tensor : tensorly.tensor
    rank : Union[int, List[int]]
            maximum allowable TR rank of the factors
            if int, then this is the same for all the factors
            if int list, then rank[k] is the rank of the kth factor
            (intermediate ranks are lowered when they exceed the size of the
            corresponding unfolding)
    mode : int, default is 0
            index of the first factor to compute; must satisfy
            ``0 <= mode < ndim(input_tensor)``
    svd : str, default is 'truncated_svd'
        function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
    verbose : boolean, optional
            if True, print the shape of each factor as soon as it is computed

    Returns
    -------
    factors : TR factors
              order-3 tensors of the TR decomposition, ordered to match the
              modes of `input_tensor` whatever the value of `mode`

    Raises
    ------
    ValueError
        If `mode` is out of range, or if the product of the two ranks of the
        first computed factor exceeds the smaller dimension of the first
        matricization.

    References
    ----------
    .. [1] Qibin Zhao et al. "Tensor Ring Decomposition" arXiv preprint arXiv:1606.05535, (2016).
    """
    rank = validate_tr_rank(tl.shape(input_tensor), rank=rank)
    n_dim = len(input_tensor.shape)

    if not 0 <= mode < n_dim:
        raise ValueError(
            f"mode should be between 0 and {n_dim - 1} (number of dimensions - 1), "
            f"but got mode={mode}."
        )

    # Change order so that the factor of `mode` is computed first
    if mode:
        order = tuple(range(mode, n_dim)) + tuple(range(mode))
        input_tensor = tl.transpose(input_tensor, order)
        # Only the first n_dim entries are distinct bond ranks: the ring is
        # closed by rank[0] (used as the last factor's trailing rank below).
        # If `rank` also carries a closing copy of rank[0] (n_dim + 1 entries),
        # rotating the whole list would move that copy into the middle and
        # misassign the bond ranks after it. Rotate the n_dim distinct ranks
        # and re-append the new rank[0] to close the ring.
        cyclic_rank = list(rank[:n_dim])
        rank = cyclic_rank[mode:] + cyclic_rank[:mode] + [cyclic_rank[mode]]

    tensor_size = input_tensor.shape

    factors = [None] * n_dim

    # Getting the first factor
    unfolding = tl.reshape(input_tensor, (tensor_size[0], -1))

    n_row, n_column = unfolding.shape
    if rank[0] * rank[1] > min(n_row, n_column):
        raise ValueError(
            f"rank[{mode}] * rank[{mode + 1}] = {rank[0] * rank[1]} is larger than "
            f"first matricization dimension {n_row}×{n_column}.\n"
            "Failed to compute first factor with specified rank. "
            "Reduce specified ranks or change first matricization `mode`."
        )

    # SVD of unfolding matrix: the leading rank[0] * rank[1] singular vectors
    # hold both boundary ranks of the first factor at once.
    U, S, V = svd_interface(unfolding, n_eigenvecs=rank[0] * rank[1], method=svd)

    # Get first TR factor: (I_0, r0, r1) -> (r0, I_0, r1)
    factor = tl.reshape(U, (tensor_size[0], rank[0], rank[1]))
    factors[0] = tl.transpose(factor, (1, 0, 2))
    if verbose:
        print(
            "TR factor " + str(mode) + " computed with shape " + str(factors[0].shape)
        )

    # Get new unfolding matrix for the remaining factors; rank[0] is moved to
    # the last axis so it is carried along until the last factor closes the ring.
    unfolding = tl.reshape(S, (-1, 1)) * V
    unfolding = tl.reshape(unfolding, (rank[0], rank[1], -1))
    unfolding = tl.transpose(unfolding, (1, 2, 0))

    # Getting the TR factors up to n_dim - 1
    for k in range(1, n_dim - 1):
        # Reshape the unfolding matrix of the remaining factors
        n_row = int(rank[k] * tensor_size[k])
        unfolding = tl.reshape(unfolding, (n_row, -1))

        # SVD of unfolding matrix; the rank cannot exceed the matrix dimensions
        n_row, n_column = unfolding.shape
        current_rank = min(n_row, n_column, rank[k + 1])
        U, S, V = svd_interface(unfolding, n_eigenvecs=current_rank, method=svd)
        rank[k + 1] = current_rank

        # Get kth TR factor
        factors[k] = tl.reshape(U, (rank[k], tensor_size[k], rank[k + 1]))

        if verbose:
            print(
                "TR factor "
                + str((mode + k) % n_dim)
                + " computed with shape "
                + str(factors[k].shape)
            )

        # Get new unfolding matrix for the remaining factors
        unfolding = tl.reshape(S, (-1, 1)) * V

    # Getting the last factor: its trailing rank is rank[0], closing the ring
    prev_rank = unfolding.shape[0]
    factors[-1] = tl.reshape(unfolding, (prev_rank, -1, rank[0]))

    if verbose:
        print(
            "TR factor "
            + str((mode - 1) % n_dim)
            + " computed with shape "
            + str(factors[-1].shape)
        )

    # Reorder factors to match input
    if mode:
        factors = factors[-mode:] + factors[:-mode]

    return TRTensor(factors)


class TensorRing(DecompositionMixin):
    """Tensor Ring decomposition via recursive SVD

    Estimator-style wrapper around :func:`tensor_ring`. The decomposition
    hyper-parameters are stored at construction; the tensor is given to
    :meth:`fit_transform`, which decomposes it into a sequence of order-3
    tensors (factors) [1]_.

    Parameters
    ----------
    rank : Union[int, List[int]]
            maximum allowable TR rank of the factors
            if int, then this is the same for all the factors
            if int list, then rank[k] is the rank of the kth factor
    mode : int, default is 0
            index of the first factor to compute
    svd : str, default is 'truncated_svd'
        function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
    verbose : boolean, optional
            level of verbosity

    Attributes
    ----------
    decomposition_ : TRTensor
        decomposition computed by the last call to :meth:`fit_transform`

    References
    ----------
    .. [1] Qibin Zhao et al. "Tensor Ring Decomposition" arXiv preprint arXiv:1606.05535, (2016).
    """

    def __init__(self, rank, mode=0, svd="truncated_svd", verbose=False):
        self.rank = rank
        self.mode = mode
        self.svd = svd
        self.verbose = verbose

    def fit_transform(self, tensor):
        """Compute the TR decomposition of `tensor`.

        Parameters
        ----------
        tensor : tensorly.tensor
            tensor to decompose

        Returns
        -------
        TRTensor
            the decomposition returned by :func:`tensor_ring`; it is also
            stored in ``self.decomposition_``
        """
        self.decomposition_ = tensor_ring(
            tensor, rank=self.rank, mode=self.mode, svd=self.svd, verbose=self.verbose
        )
        return self.decomposition_
