# Additional utilities that are only included in the appendix.

def get_eigenbasis(X, k):
    """Return the top-k PCA directions of X together with their variances.

    Args:
        X: data matrix, one sample per row.
        k: number of principal components to keep.

    Returns:
        (eigvecs, eigvals): ``eigvecs`` is ``PCA.components_`` transposed, so
        each column is one principal direction. ``eigvals`` is
        ``PCA.explained_variance_`` for those components.
    """
    centred = X - X.mean(axis=0)
    # PCA.fit returns the fitted estimator, so construction and fitting chain.
    model = PCA(n_components=k).fit(centred)
    return model.components_.T, model.explained_variance_

def riswie_axis_matching(X, Y, k):
    """Match the top-k principal axes of X to those of Y, allowing sign flips.

    Both point clouds are centred and projected onto their own top-k PCA
    eigenbases. For every pair (axis i of X, axis j of Y), the 1-D W2^2 cost
    between the projected marginals is computed twice: once for Y's axis as
    is and once for it negated. The cheaper orientation is kept. The resulting
    k x k cost matrix is solved as a linear assignment problem.

    Args:
        X: point cloud, one sample per row.
        Y: point cloud with the same number of columns as X; the number of
            rows may differ (uniform weights are used on each side).
        k: number of principal axes to match.

    Returns:
        row_ind, col_ind: the assignment from ``linear_sum_assignment``. Axis
            ``row_ind[m]`` of X is matched to axis ``col_ind[m]`` of Y.
        best_signs: list of +1 / -1, the orientation of Y's axis chosen for
            each matched pair.
        eigvecs_X, eigvecs_Y: the eigenbases returned by ``get_eigenbasis``.
    """
    # Get projections onto each cloud's own eigenbasis.
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    eigvecs_X, _ = get_eigenbasis(Xc, k)
    eigvecs_Y, _ = get_eigenbasis(Yc, k)
    AX = Xc @ eigvecs_X
    BY = Yc @ eigvecs_Y

    # Everything below depends on only one side of an (i, j) pair. Compute it
    # once up front instead of k^2 times inside the double loop.
    n_a = AX.shape[0]
    n_b = BY.shape[0]
    # Uniform weights for the two empirical measures.
    a_weights = np.ones(n_a) / n_a
    b_weights = np.ones(n_b) / n_b
    a_sorted_cols = [np.sort(AX[:, i]) for i in range(k)]
    b_pos_sorted_cols = [np.sort(BY[:, j]) for j in range(k)]
    b_neg_sorted_cols = [np.sort(-BY[:, j]) for j in range(k)]

    # Build the cost matrix and the sign that achieved each cost.
    C = np.zeros((k, k))
    signs = np.zeros((k, k), dtype=int)
    for i in range(k):
        a_sorted = a_sorted_cols[i]
        for j in range(k):
            # W2^2 between the marginals, for both orientations of Y's axis j.
            d_pos = ot.lp.emd2_1d(a_sorted, b_pos_sorted_cols[j], a_weights, b_weights)
            d_neg = ot.lp.emd2_1d(a_sorted, b_neg_sorted_cols[j], a_weights, b_weights)
            # Keep the cheaper orientation. Record which one won so the sign
            # can be applied to the matched axis after the assignment.
            if d_pos <= d_neg:
                C[i, j] = d_pos
                signs[i, j] = +1
            else:
                C[i, j] = d_neg
                signs[i, j] = -1
    row_ind, col_ind = linear_sum_assignment(C)
    # For each matched pair, look up the orientation that produced its cost.
    best_signs = [signs[i, j] for i, j in zip(row_ind, col_ind)]
    return row_ind, col_ind, best_signs, eigvecs_X, eigvecs_Y

def compute_alignment_rotation(X, Y, k):
    """Build the matrix that maps Y's matched eigenbasis onto X's.

    Axes are paired, and Y's axes sign-flipped, according to
    ``riswie_axis_matching``. The result is ``R = Y_basis @ X_basis.T``, so
    for centred Y the product ``Yc @ R`` expresses Y's coordinates along its
    matched axes in X's corresponding axes.

    Note: both bases have k orthonormal columns. R is therefore orthogonal
    (a rotation or, depending on the signs, a reflection) only when k equals
    the ambient dimension. For k < d, R has rank k, so ``Yc @ R`` drops Y's
    components outside its top-k subspace.
    """
    rows, cols, flips, basis_X, basis_Y = riswie_axis_matching(X, Y, k)
    # Column m of each reordered basis belongs to matched pair m. Y's column
    # is negated where the matching preferred the flipped orientation.
    matched_X = basis_X[:, rows]
    matched_Y = basis_Y[:, cols] * flips
    return matched_Y @ matched_X.T


# Boosted, rigid-invariant versions of the other distances defined here.

def boosted_ot(X, Y, k=2):
    """OT distance between X and Y after aligning Y's principal axes to X's.

    Both clouds are centred. Y is then multiplied by
    ``compute_alignment_rotation(X, Y, k)`` and compared to X with
    ``standard_ot``.
    """
    rotation = compute_alignment_rotation(X, Y, k)
    x_centred = X - X.mean(axis=0)
    y_centred = Y - Y.mean(axis=0)
    return standard_ot(x_centred, y_centred @ rotation)

def boosted_sw(X, Y, k=2, n_proj=64):
    """Sliced Wasserstein distance after aligning Y's principal axes to X's.

    Both clouds are centred. Y is then multiplied by
    ``compute_alignment_rotation(X, Y, k)``, and ``sliced_wasserstein`` is
    evaluated with ``n_proj`` projections.
    """
    rotation = compute_alignment_rotation(X, Y, k)
    x_centred = X - X.mean(axis=0)
    y_centred = Y - Y.mean(axis=0)
    return sliced_wasserstein(x_centred, y_centred @ rotation, n_proj=n_proj)

