import warnings
import tensorly as tl
from ..tenalg import khatri_rao
from ..cp_tensor import CPTensor, validate_cp_rank, cp_to_tensor, cp_normalize
from ._cp import initialize_cp


# Authors: Isabell Lehmann <isabell.lehmann94@outlook.de>

# License: BSD 3 clause


def coupled_matrix_tensor_3d_factorization(
    tensor_3d,
    matrix,
    rank,
    init="svd",
    n_iter_max=100,
    tol=1e-6,
    normalize_factors=False,
):
    """
    Calculates a coupled matrix and tensor factorization of 3rd order tensor and matrix which are
    coupled in first mode.

    Assume you have tensor_3d = [[lambda; A, B, C]] and matrix = [[gamma; A, V]], which are
    coupled in 1st mode. With coupled matrix and tensor factorization (CMTF), the factor
    matrices A, B, C for the CP decomposition of the tensor and the matrix V are found. This
    implementation only works for a coupling in the first mode.

    Solution is found via alternating least squares (ALS) as described in Figure 5 of
    @article{acar2011all,
      title={All-at-once optimization for coupled matrix and tensor factorizations},
      author={Acar, Evrim and Kolda, Tamara G and Dunlavy, Daniel M},
      journal={arXiv preprint arXiv:1105.3422},
      year={2011}
    }

    Notes
    -----
    In the paper, the columns of the factor matrices are not normalized and therefore weights are
    not included in the algorithm. Here the factors are only normalized (and the weights
    lambda / gamma filled in by ``cp_normalize``) when ``normalize_factors`` is True.

    Parameters
    ----------
    tensor_3d : tl.tensor
        3rd order tensor X = [[A, B, C]]. It is unfolded and subtracted from the reconstruction
        directly, so it has to be a dense tensor.
    matrix : tl.tensor
        matrix that is coupled with tensor in first mode: Y = [[A, V]]. It must have as many
        rows as the first mode of ``tensor_3d`` since both are concatenated along axis 1.
    rank : int
        rank for CP decomposition of X (validated with ``validate_cp_rank``)
    init : optional
        (Default: "svd") Initialization method, forwarded unchanged to ``initialize_cp``.
    n_iter_max : int, optional
        (Default: 100) Maximum number of ALS iterations.
    tol : float, optional
        (Default: 1e-6) Convergence tolerance. After the first iteration the algorithm stops
        as soon as the reconstruction error is below `tol` or its relative change with respect
        to the previous iteration is at most `tol`.
    normalize_factors : bool, optional
        (Default: False) If True, both returned CP tensors are passed through
        ``cp_normalize`` before being returned.

    Returns
    -------
    tensor_3d_pred : CPTensor
        tensor_3d_pred = [[lambda; A,B,C]]
    matrix_pred : CPTensor
        matrix_pred = [[gamma; A,V]]
    rec_errors : list
        contains the reconstruction error of each performed iteration (including the one at
        which convergence was detected):
        error = | X - [[ lambda_; A, B, C ]] | ^ 2 + | Y - [[ gamma; A, V ]] | ^ 2

    Warns
    -----
    UserWarning
        If ``n_iter_max`` iterations were performed without meeting the convergence criterion.

    Examples
    --------
    A = tl.tensor([[1.0, 2.0], [3.0, 4.0]])
    B = tl.tensor([[1.0, 0.0], [0.0, 2.0]])
    C = tl.tensor([[2.0, 0.0], [0.0, 1.0]])
    V = tl.tensor([[2.0, 0.0], [0.0, 1.0]])
    R = 2

    X = cp_to_tensor((None, [A, B, C]))
    Y = cp_to_tensor((None, [A, V]))

    tensor_3d_pred, matrix_pred, rec_errors = coupled_matrix_tensor_3d_factorization(X, Y, R)

    """
    rank = validate_cp_rank(tl.shape(tensor_3d), rank=rank)

    # initialize values
    tensor_cp = initialize_cp(tensor_3d, rank, init=init)
    # the coupled factor should be initialized with the concatenated dataset
    coupled_unfold = tl.concatenate((tl.unfold(tensor_3d, 0), matrix), axis=1)
    coupled_init = initialize_cp(coupled_unfold, rank, init=init)
    tensor_cp.factors[0] = coupled_init.factors[0]

    rec_errors = []
    V = None  # stays None only if the loop below never runs (n_iter_max <= 0)
    error_old = None  # no previous error available before the first iteration

    # alternating least squares
    # note that the order of the khatri rao product is reversed since tl.unfold has another order
    # than assumed in paper
    for iteration in range(n_iter_max):
        V = tl.transpose(tl.lstsq(tensor_cp.factors[0], matrix)[0])

        # Loop over modes of the tensor
        # We want to solve for mode 0 last, since the coupled factor matrix is most influential
        # and SVD gave us a good approximation
        for ii in reversed(range(tl.ndim(tensor_3d))):
            kr = khatri_rao(tensor_cp.factors, skip_matrix=ii)

            if ii == 0:
                # Coupled mode: stack V under the Khatri-Rao product and reuse the
                # [unfold(X, 0), Y] concatenation that was already built for initialization.
                kr = tl.concatenate((kr, V), axis=0)
                unfolded = coupled_unfold
            else:
                unfolded = tl.unfold(tensor_3d, ii)

            tensor_cp.factors[ii] = tl.transpose(
                tl.lstsq(kr, tl.transpose(unfolded))[0]
            )

        error_new = (
            tl.norm(tensor_3d - cp_to_tensor(tensor_cp)) ** 2
            + tl.norm(matrix - cp_to_tensor((None, [tensor_cp.factors[0], V]))) ** 2
        )
        # Record the error before testing for convergence so that the final iteration is
        # part of the returned history as well.
        rec_errors.append(error_new)

        # The relative criterion is written multiplicatively (|delta| <= tol * error_old)
        # so that an exact fit (error_old == 0) does not trigger a division by zero.
        if error_old is not None and (
            error_new < tol or tl.abs(error_new - error_old) <= tol * error_old
        ):
            break

        error_old = error_new

        if iteration == n_iter_max - 1:
            warnings.warn("Reached maximum iteration number without convergence.")

    if V is None:
        # No ALS iteration was performed: derive V from the initial coupled factor so that
        # a valid matrix decomposition can still be returned.
        V = tl.transpose(tl.lstsq(tensor_cp.factors[0], matrix)[0])

    matrix_pred = CPTensor((None, [tensor_cp.factors[0], V]))

    if normalize_factors:
        tensor_cp = cp_normalize(tensor_cp)
        matrix_pred = cp_normalize(matrix_pred)

    return tensor_cp, matrix_pred, rec_errors
