import numpy as np
import itertools
import os
import string
from itertools import permutations


def mse(a, b):
    """"""
    return ((a - b) ** 2).mean()


def max_ae(a, b):
    """"""
    return np.max(np.abs(a - b))


def mse_perm(a, b, return_perm=False):
    valid_perms = []
    a_shape = tuple(a.shape)
    for perm in itertools.permutations(range(b.ndim)):
        if tuple(b.shape[p] for p in perm) == a_shape:
            valid_perms.append(perm)

    if not valid_perms:
        raise ValueError(
            f"No permutation of b with shape {b.shape} matches a.shape {a_shape}."
        )

    errs = []
    for perm in valid_perms:
        errs.append(mse(a, b.transpose(perm)))

    ix = np.argmin(errs)
    if return_perm:
        return errs[ix], valid_perms[ix]
    return errs[ix]


def max_ae_perm(a, b, return_perm=False):
    errs = [
        max_ae(a, b.transpose(perm)) for perm in itertools.permutations(np.arange(3))
    ]
    ix = np.argmin(errs)
    if return_perm:
        return errs[ix], list(itertools.permutations(np.arange(3)))[ix]
    return errs[ix]


def multimap(A, V_array):
    """Compute a tensor product as a multilinear map.(pg. 2778, Section 2)

    Parameters
    ----------
    A
        A multidimensional tensor
    V_array
        Array of vectors to compute tensor against

    """
    factors = []
    for V in V_array:
        V = np.asarray(V)
        if V.ndim == 1:
            V = V[:, None]
        factors.append(V)

    if not factors:
        raise ValueError("V_array must contain at least one factor.")

    n = factors[0].shape[0]
    if any(V.shape[0] != n for V in factors):
        raise ValueError("All factors must share the same number of rows.")

    p = len(factors)
    if p > 26:
        raise ValueError("multimap currently supports up to 26 factors.")

    input_idx = string.ascii_lowercase[:p]
    output_idx = string.ascii_uppercase[:p]

    # Legacy loop-based implementation for reference:
    # dims = [a.shape[1] for a in factors]
    # dim_ranges = [range(a.shape[1]) for a in factors]
    # B = np.zeros(dims)
    # all_indices = list(itertools.product(*dim_ranges))  # i_1,...,i_p
    # all_vectors = list(itertools.product(range(n), repeat=p))  # j_1,...,j_p
    # for ind in all_indices:
    #     for vec in all_vectors:
    #         tmp = A[vec]
    #         for k in range(p):
    #             tmp *= factors[k][vec[k], ind[k]]
    #         B[ind] += tmp
    # return B

    einsum_str = ",".join(
        [input_idx] + [f"{input_idx[i]}{output_idx[i]}" for i in range(p)]
    ) + "->" + output_idx
    return np.einsum(einsum_str, np.asarray(A), *factors)


def two_tensor_prod(w, x, y):
    """
    A type of outer product
    """
    w = np.asarray(w)
    x = np.asarray(x)
    y = np.asarray(y)
    # Legacy loop-based implementation for reference:
    # r = x.shape[0]
    # M2 = np.zeros([r, r])
    # for a in range(w.shape[0]):
    #     for i in range(x.shape[0]):
    #         for j in range(y.shape[0]):
    #             M2[i, j] += w[a] * x[i, a] * y[j, a]
    # return M2
    return np.einsum("a,ia,ja->ij", w, x, y)


def three_tensor_prod(w, x, y, z):
    """
    Three-way outer product
    """
    w = np.asarray(w)
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)

    # Legacy loop-based implementation for reference:
    # r = x.shape[0]
    # M3 = np.zeros([r, r, r])
    # if len(w.shape) == 0:
    #     for i in range(x.shape[0]):
    #         for j in range(y.shape[0]):
    #             for k in range(z.shape[0]):
    #                 M3[i, j, k] += w * x[i] * y[j] * z[k]
    # else:
    #     for a in range(w.shape[0]):
    #         for i in range(x.shape[0]):
    #             for j in range(y.shape[0]):
    #                 for k in range(z.shape[0]):
    #                     M3[i, j, k] += w[a] * x[i, a] * y[j, a] * z[k, a]
    # return M3

    if w.ndim == 0:
        return w * np.einsum("i,j,k->ijk", x, y, z)
    return np.einsum("a,ia,ja,ka->ijk", w, x, y, z)


def T_map(T, u):
    """Power method base transformation (pg. 2790, equation (5))

    Parameters
    ----------
    T
        A multidimensional tensor
    u
        A candidate eigenvector

    Returns
    -------
    t
        Transformed candidate

    """

    u = np.asarray(u)
    # Legacy loop-based implementation for reference:
    # d = u.shape[0]
    # t = np.zeros(d)
    # for i in range(d):
    #     for j in range(d):
    #         for k in range(d):
    #             t[i] += T[i, j, k] * u[j] * u[k]
    # return t
    return np.einsum("ijk,j,k->i", T, u, u)


def tensor_decomp(M2, M3, comps):
    """Tensor Decomposition Algorithm (pg. 2795, Algorithm 1)
    This is combined with reduction (4.3.1)

    Parameters
    ----------
    M2
        Symmetric matrix to aid the decomposition
    M3
        Symmetric tensor to be decomposed
    comps
        Number of eigencomponents to return

    Returns
    -------
    mu_rec
        Recovered eigenvectors (a matrix with #comps eigenvectors)
    lam_rec
        Recovered eigenvalues (a vector with #comps eigenvalues)

    """
    lam_rec = np.zeros(comps)
    mu_rec = np.zeros((M2.shape[0], comps))

    for b in range(comps):
        # initial eigendecomposition used in reduction (4.3.1)
        lam, v = np.linalg.eigh(M2)
        idx = lam.argsort()[::-1]
        lam = lam[idx]
        v = v[:, idx]

        # keep only the positive eigenvalues
        n_eigpos = np.sum(lam > 1e-2)
        if n_eigpos > 0:
            W = v[:, :n_eigpos] @ np.diag(1.0 / np.sqrt(np.abs(lam[:n_eigpos])))

            B = np.linalg.pinv(W.T)  # pseudo-inverse stabilizes the reduction
            M3_tilde = multimap(M3, [W, W, W])  # reduction complete

            # decomposition setup
            # adjust hyperparameters if this routine stalls
            N = 10  # number of power iterations
            restarts = 5  # number of random restarts # NOTE critical
            tau_star = 0  # best robust eigenvalue so far
            u_star = np.zeros(n_eigpos)  # best eigenvector so far

            # repeated restarts to find best eigenvector
            for j in range(restarts):
                # randomly draw from unit sphere (step 2)
                # u = np.random.randn(n_eigpos)
                u = np.random.multivariate_normal(np.zeros(n_eigpos), np.eye(n_eigpos))
                u /= np.linalg.norm(u)

                # power iteration for N iterations
                for i in range(N):
                    u = T_map(M3_tilde, u)
                    u /= np.linalg.norm(u)

                # check for best eigenvalue
                if j == 0 or (j > 0 and multimap(M3_tilde, [u, u, u]) > tau_star):
                    tau_star = multimap(M3_tilde, [u, u, u])
                    u_star = u

            # N more power iterations for best eigenvector found
            u = u_star
            for i in range(N):
                u = T_map(M3_tilde, u)
                u /= np.linalg.norm(u)

            # recovered modified (post-reduction) eigenvalue
            lamb = (T_map(M3_tilde, u) / u)[0]

            # recover original eigenvector and eigenvalue pair
            mu_rec[:, b] = lamb * B @ u
            lam_rec[b] = 1 / lamb**2

            # deflation: remove component, repeat
            M2 -= lam_rec[b] * np.outer(mu_rec[:, b], mu_rec[:, b])
            M3 -= three_tensor_prod(
                np.array(lam_rec[b]), mu_rec[:, b], mu_rec[:, b], mu_rec[:, b]
            )

    return mu_rec, lam_rec


def lowrank(x, k):
    u, s, vh = np.linalg.svd(x)
    s_abs = np.abs(s)
    inds = np.argsort(s_abs)[::-1][:k]
    rec = np.zeros_like(x)
    for i in inds:
        rec += s[i] * np.outer(u.T[i], vh[i])
    return rec


def tensor_decomp_x3(
    w, x1, x2, x3, k=None, debug=False, return_errs=False, savedir=None
):
    if k is None:
        k = w.shape[0]

    ex32 = lowrank(np.einsum("i,ji,ki->jk", w, x3, x2), k=k)
    ex12 = lowrank(np.einsum("i,ji,ki->jk", w, x1, x2), k=k)

    ex12_inv = np.linalg.pinv(ex12)  # pseudo-inverse handles rank deficiency

    ex31 = lowrank(np.einsum("i,ji,ki->jk", w, x3, x1), k=k)
    ex21 = lowrank(np.einsum("i,ji,ki->jk", w, x2, x1), k=k)

    ex21_inv = np.linalg.pinv(ex21)  # pseudo-inverse handles rank deficiency

    x_tilde_1 = ex32 @ ex12_inv @ x1
    x_tilde_2 = ex31 @ ex21_inv @ x2
    M2 = np.einsum("i,ji,ki->jk", w, x_tilde_1, x_tilde_2)
    M3 = np.einsum("i,ji,ki,li->jkl", w, x_tilde_1, x_tilde_2, x3)

    if savedir is not None:
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        np.save(savedir + "/" + "ex32", ex32)
        np.save(savedir + "/" + "ex12", ex12)
        np.save(savedir + "/" + "ex12_inv", ex12_inv)
        np.save(savedir + "/" + "ex31", ex31)
        np.save(savedir + "/" + "ex21", ex21)
        np.save(savedir + "/" + "ex21_inv", ex21_inv)
        np.save(savedir + "/" + "x_tilde_1", x_tilde_1)
        np.save(savedir + "/" + "x_tilde_2", x_tilde_2)
        np.save(savedir + "/" + "M2", M2)
        np.save(savedir + "/" + "M3", M3)
        print("saved")

    factors, weights = tensor_decomp(M2, M3, k)

    try:
        err = mse(factors, x3)
    except ValueError:
        # print("[TENSOR_DECOMP] cannot compute error due to rank...")
        err = -1.0
    if debug:
        print(f"[TENSOR_DECOMP] error:", err)
    if return_errs:
        return weights, factors, err
    return weights, factors


def mixture_tensor_decomp_full_inner(
    w, x1, x2, x3, k=None, debug=False, return_errs=False, savedir=None
):
    w_rec_1, x3_rec, err_3_12 = tensor_decomp_x3(
        w, x1, x2, x3, k=k, debug=debug, return_errs=True, savedir=savedir
    )
    #print('1', w_rec_1)
    w_rec_2, x2_rec, err_2_13 = tensor_decomp_x3(
        w, x1, x3, x2, k=k, debug=debug, return_errs=True, savedir=savedir
    )
    #print('2', w_rec_2)
    w_rec_3, x1_rec, err_1_23 = tensor_decomp_x3(
        w, x2, x3, x1, k=k, debug=debug, return_errs=True, savedir=savedir
    )
    #print('3', w_rec_3)
    w_rec = np.mean([w_rec_1,w_rec_2,w_rec_3],axis=0)

    return w_rec, x1_rec, x2_rec, x3_rec


def mixture_tensor_decomp_full(w, x1, x2, x3, k=None, debug=False, savedir=None):
    T = np.einsum("i,ji,ki,li->jkl", w, x1, x2, x3)
    T_hat = np.zeros_like(T)
    print(x1.shape,x2.shape,x3.shape)
    max_iters = 10  # iteration cap for alternating decomposition
    eps = 1e-2
    err_best = np.inf
    w_rec_best, x1_rec_best, x2_rec_best, x3_rec_best = None, None, None, None
    for i in range(max_iters):
        #print(f"ITERATION {i}")
        w_rec, x1_rec, x2_rec, x3_rec = mixture_tensor_decomp_full_inner(
            w, x1, x2, x3, k=k, debug=False, savedir=savedir
        )
        T_hat = np.einsum("i,ji,ki,li->jkl", w_rec, x1_rec, x2_rec, x3_rec)

        err, perm = mse_perm(T, T_hat, return_perm=True)
        if err <= err_best:
            w_rec_best, x1_rec_best, x2_rec_best, x3_rec_best = (
                w_rec,
                x1_rec,
                x2_rec,
                x3_rec,
            )
            err_best = err
        #print(err, perm)
        if err <= eps:
            return w_rec_best, x1_rec_best, x2_rec_best, x3_rec_best

    return w_rec_best, x1_rec_best, x2_rec_best, x3_rec_best


def main():
    # tolerance parameter is currently hard-coded

    dim = 10
    eps = 0.2
    for p in np.arange(0, 1.0 + eps, eps):
        w = np.array([p, 1.0 - p])
        k = len(w)
        x1 = np.random.normal(size=(dim, k)) + 1
        x2 = np.random.normal(size=(dim, k)) + 2
        x3 = np.random.normal(size=(dim, k)) + 3

        w_rec, x1_rec, x2_rec, x3_rec = mixture_tensor_decomp_full(
            w, x1, x2, x3, debug=True
        )

        T = np.einsum("i,ji,ki,li->jkl", w, x1, x2, x3)
        T_rec = np.einsum("i,ji,ki,li->jkl", w_rec, x1_rec, x2_rec, x3_rec)

        # print(np.sort(w_rec)[0], np.sort(w)[0])
        print(mse_perm(T, T_rec, return_perm=True))


if __name__ == "__main__":
    main()
