import numpy as np
import numpy.linalg as la
import copy
from tqdm import tqdm


def get_channel_structures(encoder_type, enc_krn, xyz_dec_krn):
    """Derive encoder/decoder size lists and latent sizes from kernel lists.

    Both input lists are deep-copied first, so the caller's lists are never
    modified.

    :param encoder_type: ``'xyz'`` (single branch, value 3 prepended),
        ``'cfan'`` (single branch, value 4 prepended); any other value builds
        two branches from the same list with 1 and 3 prepended respectively.
    :param enc_krn: list of encoder sizes; its last entry is removed and
        reported as the latent size of each branch.
    :param xyz_dec_krn: list of decoder sizes; the value 3 is appended.
    :return: ``(num_channels, num_latent, transform_flag)`` where
        ``num_channels`` is ``[list_of_encoder_branches, decoder_list]``,
        ``num_latent`` holds one latent size per encoder branch, and
        ``transform_flag`` is ``None`` for ``'xyz'`` and ``True`` otherwise.
    """
    encoder = copy.deepcopy(enc_krn)
    decoder = copy.deepcopy(xyz_dec_krn)

    if encoder_type == 'xyz':
        transform_flag = None
        # The latent size is taken before the input size is prepended.
        latent = encoder.pop()
        encoder.insert(0, 3)
        branches, num_latent = [encoder], [latent]
    elif encoder_type == 'cfan':
        transform_flag = True
        # Here the input size is prepended before the latent size is taken.
        encoder.insert(0, 4)
        latent = encoder.pop()
        branches, num_latent = [encoder], [latent]
    else:
        transform_flag = True
        branches, num_latent = [], []
        # Two branches share the same layout but differ in the first entry.
        for first_entry in (1, 3):
            branch = encoder.copy()
            branch.insert(0, first_entry)
            num_latent.append(branch.pop())
            branches.append(branch)

    decoder.append(3)
    return [branches, decoder], num_latent, transform_flag


def sparse_cca2(X, Y, split_idx=None):
    """Greedily build block-structured projections for X and the two halves of Y.

    Y is split column-wise at ``split_idx`` into a first group (labelled
    "Conformal" in the progress output) and a second group ("Normal"). At
    every step one candidate direction pair per group is obtained from
    ``sparse_cca_iterate2`` and the one with the larger correlation is stored
    (ties go to the first group).

    NOTE: ``X`` and ``Y`` are centred, divided by their per-column standard
    deviation and by ``sqrt(n_rows)`` IN PLACE, so the caller's arrays are
    modified. A column with zero standard deviation results in a division
    by zero.

    :param X: 2-D array, rows are samples.
    :param Y: 2-D array with the same number of rows as X.
    :param split_idx: column at which Y is split; defaults to
        ``int(X.shape[1] / 2)``.
    :return: ``(A, B, S_cancorr)``. ``A`` is ``X.shape[1]`` square, ``B`` is
        ``Y.shape[1]`` square and only its two diagonal blocks are written;
        ``S_cancorr`` holds the correlation stored for each column. First-group
        results occupy columns ``0..split_idx-1``, second-group results the
        columns from ``split_idx`` on.
    """
    if split_idx is None:
        split_idx = int(X.shape[1] / 2)

    # In-place standardisation; the statement order is kept as is on purpose.
    X -= np.mean(X, axis=0)
    Y -= np.mean(Y, axis=0)
    X /= np.std(X, axis=0)
    Y /= np.std(Y, axis=0)
    X /= X.shape[0] ** 0.5
    Y /= Y.shape[0] ** 0.5

    n_x_cols = X.shape[1]
    n_y_cols = Y.shape[1]
    nrm_total = n_x_cols - split_idx

    A = np.zeros([n_x_cols, n_x_cols])
    B = np.zeros([n_y_cols, n_y_cols])
    S_cancorr = np.zeros(n_x_cols)

    # Views: they always reflect the columns written into Y / B so far.
    Y_cf, Y_nrm = Y[:, :split_idx], Y[:, split_idx:]
    B_cf, B_nrm = B[:split_idx, :split_idx], B[split_idx:, split_idx:]

    def solve_cf():
        return sparse_cca_iterate2(X, Y_cf, A, B_cf)

    def solve_nrm():
        return sparse_cca_iterate2(X, Y_nrm, A, B_nrm)

    def keep_best(cf_cand, nrm_cand, cf_done, nrm_done):
        """Store the better candidate and return the updated counters."""
        a_c, b_c, rho_c = cf_cand
        a_n, b_n, rho_n = nrm_cand
        print('Index: %02d, Conformal Correlation: %.3f, Normal Correlation: %.3f' % (cf_done + nrm_done, rho_c, rho_n))
        if rho_c >= rho_n:
            A[:, cf_done] = a_c
            B[:split_idx, cf_done] = b_c
            S_cancorr[cf_done] = rho_c
            return cf_done + 1, nrm_done
        col = split_idx + nrm_done
        A[:, col] = a_n
        B[split_idx:, col] = b_n
        S_cancorr[col] = rho_n
        return cf_done, nrm_done + 1

    cf_done = 0
    nrm_done = 0
    cf_a, cf_b, cf_rho = solve_cf()
    nrm_a, nrm_b, nrm_rho = solve_nrm()
    cf_done, nrm_done = keep_best((cf_a, cf_b, cf_rho), (nrm_a, nrm_b, nrm_rho), cf_done, nrm_done)

    while cf_done < split_idx or nrm_done < nrm_total:
        # Keep taking first-group directions while they are at least as good
        # as the pending second-group candidate.
        while cf_done < split_idx and cf_rho >= nrm_rho:
            cf_a, cf_b, cf_rho = solve_cf()
            if cf_rho >= nrm_rho:
                cf_done, nrm_done = keep_best((cf_a, cf_b, cf_rho), (nrm_a, nrm_b, nrm_rho), cf_done, nrm_done)
        if cf_done >= split_idx:
            # Sentinel: an exhausted group can never win a comparison again.
            cf_rho = -2
        # Same for the second group, with a strict comparison.
        while nrm_done < nrm_total and nrm_rho > cf_rho:
            nrm_a, nrm_b, nrm_rho = solve_nrm()
            if nrm_rho > cf_rho:
                cf_done, nrm_done = keep_best((cf_a, cf_b, cf_rho), (nrm_a, nrm_b, nrm_rho), cf_done, nrm_done)
        if nrm_done >= nrm_total:
            nrm_rho = -2
    return A, B, S_cancorr


def sparse_cca_iterate2(X, Y, A, B, max_iter=2000, tol=1E-8):
    """Find one more direction pair (a, b) by alternating least squares.

    Starting from all-ones vectors, the iteration alternates
    ``a <- pinv(X) (I - XA XA') Y b`` (rescaled so ``||X a|| = 1``) and
    ``b <- pinv(Y) (I - YB YB') X a`` (rescaled so ``||Y b|| = 1``) until the
    squared change of both vectors drops below ``tol`` or ``max_iter`` steps
    were taken.

    The deflation terms are applied as ``Y - XA (XA' Y)`` instead of building
    the n-by-n matrices ``XA XA'`` and ``I``, so memory stays linear in the
    number of samples.

    :param X: 2-D array, rows are samples.
    :param Y: 2-D array with the same number of rows as X.
    :param A: 2-D array of previously found X-directions; columns that are
        entirely zero are ignored.
        NOTE(review): the deflation is only a projection if ``X @ A`` has
        orthonormal columns - confirm that callers guarantee this.
    :param B: 2-D array of previously found Y-directions, handled like ``A``.
    :param max_iter: maximum number of alternating updates.
    :param tol: threshold on the squared change of ``a`` and of ``b``.
    :return: ``(a, b, cancor)`` with ``a`` and ``b`` squeezed and
        ``cancor = a' X' Y b``.
    """
    Cov = np.transpose(X) @ Y
    a = np.ones([X.shape[1], 1])
    b = np.ones([Y.shape[1], 1])

    # Only columns that already hold a direction take part in the deflation.
    A = A[:, np.unique(np.nonzero(A)[1])]
    B = B[:, np.unique(np.nonzero(B)[1])]

    XA = X @ A
    YB = Y @ B

    # (I - XA XA') Y and (I - YB YB') X without forming any n-by-n matrix.
    Y_deflated = Y - XA @ (np.transpose(XA) @ Y)
    X_deflated = X - YB @ (np.transpose(YB) @ X)

    # Least-squares maps: X a ~= deflated(Y b) and Y b ~= deflated(X a).
    A_trans = np.linalg.pinv(X) @ Y_deflated
    B_trans = np.linalg.pinv(Y) @ X_deflated

    for _ in range(max_iter):
        a_old = a.copy()
        b_old = b.copy()

        a = A_trans @ b
        a /= np.sum((X @ a) ** 2) ** 0.5

        b = B_trans @ a
        b /= np.sum((Y @ b) ** 2) ** 0.5

        a_diff = np.sum((a - a_old) ** 2)
        b_diff = np.sum((b - b_old) ** 2)
        if a_diff < tol and b_diff < tol:
            break
    cancor = np.transpose(a) @ Cov @ b

    return a.squeeze(), b.squeeze(), cancor.squeeze()


def sparse_cca(X, Y):
    """
    Greedily fill projection matrices A (for X) and B (for Y), one column per
    loop iteration, choosing between two kinds of candidates.

    Each candidate comes from ``sparse_cca_iterate``:

    * the "conformal" candidate is computed with ``sparse_idx = np.arange(16, 32)``
      (those entries of ``b`` are forced to zero);
    * the "normal" candidate is computed with ``sparse_idx = np.arange(16)``.

    The candidate with the larger correlation is stored (ties go to the
    conformal one). Conformal results go to columns ``0..15``, normal results
    to columns ``16..``. Once a group has 16 stored columns its correlation is
    replaced by the sentinel ``-2`` so it can no longer be selected.

    NOTE(review): the split ``16`` and the range end ``32`` are hard-coded;
    presumably X and Y both have 32 columns - confirm against callers.

    Objective stated by the original author:
    maximize trace(a^T x^T y b) s.t. a^T x^T x a = I, b^T y^T y b = I.

    :param X: 2-D array; ``X.shape[1]`` sets the number of iterations and the size of A.
    :param Y: 2-D array; ``Y.shape[1]`` sets the size of B.
    :return A: array of shape ``(X.shape[1], X.shape[1])``.
    :return B: array of shape ``(Y.shape[1], Y.shape[1])``.
    :return S_cancorr: array of length ``X.shape[1]`` with the stored correlation per column.
    """

    A = np.zeros([X.shape[1], X.shape[1]])
    B = np.zeros([Y.shape[1], Y.shape[1]])
    S_cancorr = np.zeros(X.shape[1])

    # Number of columns stored so far for each group.
    cf_count = 0
    nrm_count = 0

    # -2 is used as a sentinel that is smaller than any value compared against it.
    last_corr_cf = -2
    last_corr_nrm = -2
    # True: the conformal candidate is refreshed first in the next iteration;
    # False: the normal candidate is refreshed first.
    corr_cf_update = True
    for idx in range(X.shape[1]):
        if cf_count < 16 and corr_cf_update:
            # Refresh the conformal candidate against the current A, B.
            a_cf, b_cf, corr_cf = sparse_cca_iterate(X, Y, A, B, np.arange(16, 32))
            last_corr_cf = corr_cf.copy()
            # The normal candidate is also refreshed on the very first pass
            # (nothing computed yet) or when it would currently win.
            if last_corr_cf < last_corr_nrm or idx == 0:
                a_nrm, b_nrm, corr_nrm = sparse_cca_iterate(X, Y, A, B, np.arange(16))
                last_corr_nrm = corr_nrm.copy()
        elif corr_cf_update:
            # Conformal group is full: disable it and switch to the normal path below.
            corr_cf = -2
            corr_cf_update = False
        if nrm_count < 16 and not corr_cf_update:
            # Refresh the normal candidate against the current A, B.
            a_nrm, b_nrm, corr_nrm = sparse_cca_iterate(X, Y, A, B, np.arange(16))
            last_corr_nrm = corr_nrm.copy()
            # The conformal candidate is refreshed too when it would currently win.
            if last_corr_nrm < last_corr_cf:
                a_cf, b_cf, corr_cf = sparse_cca_iterate(X, Y, A, B, np.arange(16, 32))
                last_corr_cf = corr_cf.copy()
        elif not corr_cf_update:
            # Normal group is full: disable it and refresh the conformal candidate instead.
            corr_nrm = -2
            a_cf, b_cf, corr_cf = sparse_cca_iterate(X, Y, A, B, np.arange(16, 32))
            last_corr_cf = corr_cf.copy()
            corr_cf_update = True
        print('Index: %02d, Conformal Correlation: %.3f, Normal Correlation: %.3f' % (idx, corr_cf, corr_nrm))
        if corr_cf >= corr_nrm:
            # Store the conformal pair in the next free column of the first block.
            A[:, cf_count] = a_cf
            B[:, cf_count] = b_cf
            S_cancorr[cf_count] = corr_cf
            cf_count += 1
            corr_cf_update = True
        else:
            # Store the normal pair in the next free column starting at 16.
            A[:, 16 + nrm_count] = a_nrm
            B[:, 16 + nrm_count] = b_nrm
            S_cancorr[16 + nrm_count] = corr_nrm
            nrm_count += 1
            corr_cf_update = False
    return A, B, S_cancorr


def sparse_cca_iterate(X, Y, A, B, sparse_idx, lr=1E-2, p=1E1, tol=1E-6, max_iter_out=2000, max_iter_in=200):
    """Compute one direction pair (a, b) with selected entries of b forced to zero.

    Outer loop (at most ``max_iter_out`` passes, shown with a tqdm bar):

    * ``a`` is set to ``X'Y b`` and then orthonormalised with respect to
      ``X'X`` against the non-zero columns of ``A`` (Cholesky of the Gram
      matrix ``[A_prev, a]' X'X [A_prev, a]``).
    * ``b`` is refined by at most ``max_iter_in`` first-order steps on an
      augmented-Lagrangian objective with auxiliary variable ``z_b`` and
      multipliers ``lam_b`` (the original author describes the scheme as
      ADMM for ``min -a'X'Y b  s.t.  Y b = z, B'Y'Y b = 0``). After every
      step the entries ``b[sparse_idx]`` are reset to zero.

    The outer loop stops early when the correlation ``a'X'Y b`` changes by less
    than ``1E-6`` and the squared change of both ``a`` and ``b`` is below ``tol``.
    Finally ``a`` and ``b`` are re-orthonormalised against the stored columns,
    ``b[sparse_idx]`` is zeroed again and ``b`` is rescaled so ``||Y b|| = 1``.

    :param X: 2-D array, rows are samples.
    :param Y: 2-D array with the same number of rows as X.
    :param A: 2-D array of previously stored X-directions (all-zero columns are ignored).
    :param B: 2-D array of previously stored Y-directions; the columns used are
        those at the non-zero column positions of ``A``.
    :param sparse_idx: indices of ``b`` that must be zero.
    :param lr: step size of the inner updates.
    :param p: penalty weight of the augmented Lagrangian.
    :param tol: threshold on squared changes of ``a`` / ``b``.
    :param max_iter_out: maximum number of outer passes.
    :param max_iter_in: maximum number of inner steps per outer pass.
    :return: ``(a, b, can_corr)`` with ``a`` and ``b`` squeezed and
        ``can_corr = a' X'Y b`` evaluated after the final normalisation.
    """
    X_Var = np.transpose(X) @ X
    Y_Var = np.transpose(Y) @ Y
    Cov = np.transpose(X) @ Y
    a = np.ones([X.shape[1], 1])
    b = np.ones([Y.shape[1], 1])
    b[sparse_idx] = 0

    # Stacked constraint operator for b: rows of Y (for Y b = z_b) followed by
    # B' Y'Y (orthogonality against the stored directions).
    Y_bar = np.vstack([Y, np.transpose(B) @ Y_Var])
    Y_bar_t = np.transpose(Y_bar)

    # A is not modified inside this function, so its non-zero columns are
    # determined once.
    non_zero_cols = np.unique(np.nonzero(A)[1])
    A_prev = A[:, non_zero_cols]

    tqdm_iterator = tqdm(range(max_iter_out), desc='Computing canonical variates')
    can_corr_old = None
    for _ in tqdm_iterator:
        a_old_outer = a.copy()
        b_old_outer = b.copy()
        if can_corr_old is None:
            # Sentinel for the first pass so the convergence test cannot fire.
            can_corr_old = -2
        else:
            can_corr_old = np.transpose(a) @ Cov @ b

        # a-update: closed form followed by X'X-orthonormalisation; the last
        # column of the orthonormalised stack is the new a.
        a = Cov @ b
        A_full = np.hstack([A_prev, a])
        a_con = np.transpose(A_full) @ X_Var @ A_full
        A_full = A_full @ np.linalg.inv(np.transpose(np.linalg.cholesky(a_con)))
        a = np.expand_dims(A_full[:, -1], -1)

        # b-update: inner iterations on the augmented Lagrangian.
        z_b = Y @ b
        lam_b = np.zeros([Y.shape[0] + B.shape[1], 1])
        for _ in range(max_iter_in):
            b_old = b.copy()
            # Constraint residual Y_bar b - (z_b; 0).
            prod = Y_bar @ b
            prod[:Y.shape[0]] = prod[:Y.shape[0]] - z_b
            b -= lr * (Y_bar_t @ lam_b + p * Y_bar_t @ prod)
            b -= lr * (- np.transpose(Cov) @ a)

            # Enforce the sparsity pattern.
            b[sparse_idx] = 0

            # z-update followed by rescaling.
            prod = Y_bar @ b
            prod[:Y.shape[0]] = prod[:Y.shape[0]] - z_b
            prod_it = - prod[:Y.shape[0]]
            lam_it = - lam_b[:Y.shape[0]]
            z_b -= lr * (lam_it + p * prod_it)
            # NOTE(review): z_norm is the SQUARED norm, so z_b is divided by
            # ||z_b||^2 rather than ||z_b|| - confirm this is intended.
            z_norm = np.sum(z_b ** 2)
            z_b = z_b / z_norm

            # Multiplier update.
            prod = Y_bar @ b
            prod[:Y.shape[0]] = prod[:Y.shape[0]] - z_b
            lam_b = lam_b + p * prod
            if np.sum((b - b_old) ** 2) < tol:
                break
        can_corr_new = np.transpose(a) @ Cov @ b
        if (np.abs(can_corr_new - can_corr_old) < 1E-6) and (np.sum((a - a_old_outer) ** 2) < tol) and (np.sum((b - b_old_outer) ** 2) < tol):
            tqdm_iterator.close()
            break

    # Final orthonormalisation of both vectors against the stored directions.
    A_full = np.hstack([A_prev, a])
    B_full = np.hstack([B[:, non_zero_cols], b])
    a_con = np.transpose(A_full) @ X_Var @ A_full
    b_con = np.transpose(B_full) @ Y_Var @ B_full
    A_full = A_full @ np.linalg.inv(np.transpose(np.linalg.cholesky(a_con)))
    B_full = B_full @ np.linalg.inv(np.transpose(np.linalg.cholesky(b_con)))
    a = np.expand_dims(A_full[:, -1], -1)
    b = np.expand_dims(B_full[:, -1], -1)
    # The orthonormalisation can re-introduce non-zeros; zero them and
    # rescale so that ||Y b|| = 1.
    b[sparse_idx] = 0
    b = b / np.sum((Y @ b) ** 2) ** 0.5

    can_corr = np.transpose(a) @ Cov @ b
    can_corr = can_corr.squeeze()
    a = a.squeeze()
    b = b.squeeze()
    return a, b, can_corr


def cca(x, y):
    """Canonical correlation analysis of two data matrices via SVD.

    Both inputs are decomposed with a thin SVD; the singular values of the
    product of their left singular bases are the canonical correlations.
    No centring or scaling is applied here. The singular values of ``x`` and
    ``y`` are inverted, so a zero singular value yields a division by zero.

    :param x: 2-D array, rows are samples.
    :param y: 2-D array with the same number of rows as x.
    :return: ``(a, b, S)`` where ``x @ a`` and ``y @ b`` are the canonical
        variates and ``S`` holds the canonical correlations.
    """
    u_x, s_x, vt_x = la.svd(x, full_matrices=False)
    u_y, s_y, vt_y = la.svd(y, full_matrices=False)

    # SVD of the overlap between the two orthonormal column bases.
    left, corr, right_t = la.svd(u_x.T @ u_y, full_matrices=False)

    # Map the rotations back to the original coordinates (whitening by 1/s).
    a = vt_x.T @ np.diag(1 / s_x) @ left
    b = vt_y.T @ np.diag(1 / s_y) @ right_t.T
    return a, b, corr


def pca(x, y):
    """Paired orthonormal directions that diagonalise ``x.T @ y``.

    The cross-product ``x.T @ y`` is expressed in the right-singular bases of
    ``x`` and ``y`` and decomposed there; the resulting rotations are mapped
    back to the original coordinates. No centring is applied here.

    :param x: 2-D array, rows are samples.
    :param y: 2-D array with the same number of rows as x.
    :return: ``(a, b, s_xy)`` such that ``a.T @ x.T @ y @ b`` equals
        ``diag(s_xy)`` up to rounding.
    """
    u_x, s_x, vt_x = la.svd(x, full_matrices=False)
    u_y, s_y, vt_y = la.svd(y, full_matrices=False)

    # diag(s_x) U_x' U_y diag(s_y) is x'y written in the V_x / V_y bases.
    overlap = u_x.T @ u_y
    cross = np.diag(s_x) @ overlap @ np.diag(s_y)
    left, s_xy, right_t = la.svd(cross, full_matrices=False)

    a = vt_x.T @ left
    b = vt_y.T @ right_t.T
    return a, b, s_xy


def sparse_pca(x, y, split_idx=16):
    """Block-structured variant of ``pca`` with y split into two column groups.

    ``x`` and ``y`` are centred and divided by ``sqrt(n_rows)`` (on copies, the
    caller's arrays are left untouched). ``pca`` is then run separately for
    ``x`` against ``y[:, :split_idx]`` and ``x`` against ``y[:, split_idx:]``;
    the two results are placed side by side in ``a`` and block-diagonally
    in ``b``.

    :param x: 2-D array, rows are samples.
    :param y: 2-D array with the same number of rows as x.
    :param split_idx: column at which ``y`` (and the output columns) are split
        into the two groups; defaults to 16.
    :return: ``(a_disent, b_disent, x_mean, y_mean)``. The output arrays have
        the shapes returned by ``pca(x, y)`` on the full data.
    """
    x_mean = np.mean(x, axis=0)
    y_mean = np.mean(y, axis=0)
    x = x - x_mean
    y = y - y_mean
    x /= x.shape[0] ** 0.5
    y /= y.shape[0] ** 0.5

    # The full decomposition only provides output shapes and dtype.
    a_full, b_full, _ = pca(x, y)
    a_first, b_first, _ = pca(x, y[:, :split_idx])
    a_second, b_second, _ = pca(x, y[:, split_idx:])

    a_disent = np.zeros_like(a_full)
    b_disent = np.zeros_like(b_full)
    a_disent[:, :split_idx] = a_first
    a_disent[:, split_idx:] = a_second
    # Off-diagonal blocks of b stay zero, so each output column only uses one
    # of the two groups of y columns.
    b_disent[:split_idx, :split_idx] = b_first
    b_disent[split_idx:, split_idx:] = b_second

    return a_disent, b_disent, x_mean, y_mean


def ortho_procrustes(x, y, std_flag=True):
    """Orthogonal matrix R from the SVD of the cross-product of x and y.

    Both inputs are centred (and divided by their per-column standard
    deviation when ``std_flag`` is true), scaled by ``1 / sqrt(n_rows)``, and
    ``R = U @ Vt`` is formed from the SVD ``U S Vt`` of ``x.T @ y``. The
    caller's arrays are not modified. A zero standard deviation leads to a
    division by zero when ``std_flag`` is true.

    :param x: 2-D array, rows are samples.
    :param y: 2-D array with the same number of rows as x.
    :param std_flag: whether to divide the centred data by the column
        standard deviations.
    :return: ``(R, x_mean, x_std, y_mean, y_std)``; the statistics are those
        of the inputs as passed in and are returned regardless of ``std_flag``.
    """
    x_mean, y_mean = np.mean(x, axis=0), np.mean(y, axis=0)
    x_std, y_std = np.std(x, axis=0), np.std(y, axis=0)

    x_c = x - x_mean
    y_c = y - y_mean
    if std_flag:
        x_c = x_c / x_std
        y_c = y_c / y_std
    x_c = x_c / x_c.shape[0] ** 0.5
    y_c = y_c / y_c.shape[0] ** 0.5

    # Orthogonal polar factor of the cross-product matrix.
    u, _, v_t = np.linalg.svd(np.transpose(x_c) @ y_c, full_matrices=False)
    return u @ v_t, x_mean, x_std, y_mean, y_std
