import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from scipy.linalg import orthogonal_procrustes
from sklearn.neighbors import kneighbors_graph
from scipy.sparse import csr_matrix, diags
from scipy.sparse.linalg import eigsh
from scipy.linalg import eigh
from sklearn import datasets
import itertools
import matplotlib.pyplot as plt
import time
import pandas as pd
import os
from sklearn.manifold import SpectralEmbedding
import pacmap
from sklearn.neighbors import NearestNeighbors


import ot 


def pca_embedding(X, k=None):
    """Return the PCA scores of the column-centred data.

    Parameters
    ----------
    X : array of shape (n, d)
        Input points, one per row.
    k : int or None
        Number of principal components to keep. A falsy value (``None`` or
        ``0``) keeps ``X.shape[1]`` components.

    Returns
    -------
    ndarray
        The transformed data as returned by ``PCA.transform`` (one row per
        input point, one column per retained component).
    """
    centred = X - X.mean(axis=0)
    n_components = k or X.shape[1]
    model = PCA(n_components=n_components)
    model.fit(centred)
    return model.transform(centred)


def pca_embedding_with_eigvals_and_evecs(X, k=None):
    """Run PCA on the column-centred data and also expose its spectrum.

    Parameters
    ----------
    X : array of shape (n, d)
        Input points, one per row.
    k : int or None
        Number of principal components to keep. A falsy value keeps
        ``X.shape[1]`` components.

    Returns
    -------
    tuple
        ``(AX, eigvals, evecs)`` where ``AX`` is the transformed data,
        ``eigvals`` is the fitted model's ``explained_variance_`` and
        ``evecs`` is ``components_`` transposed (one component per column).
    """
    centred = X - X.mean(axis=0)
    n_components = k or X.shape[1]
    model = PCA(n_components=n_components).fit(centred)
    scores = model.transform(centred)
    return scores, model.explained_variance_, model.components_.T

def diffusion_embedding2(X, k=5, epsilon=None, t=1, symmetrize=True, dtype=np.float32):
    """Diffusion-map style embedding built on a sparse k-nearest-neighbour graph.

    Steps: centre ``X``; build a Euclidean kNN distance graph; turn the stored
    distances into Gaussian weights ``exp(-dist**2 / epsilon)``; optionally
    average the weight matrix with its transpose; normalise symmetrically as
    ``D^{-1/2} W D^{-1/2}``; take the leading eigenpairs, drop the first one
    and scale the remaining eigenvectors by ``eigenvalue ** t``.

    Parameters
    ----------
    X : array of shape (n, d)
        Input points, one per row.
    k : int
        Number of embedding coordinates; clamped to ``n - 1``.
    epsilon : float or None
        Kernel bandwidth. When ``None`` the median of the squared stored
        neighbour distances is used (falling back to ``1.0`` when that median
        is not positive or the graph is empty).
    t : number
        Exponent applied to the eigenvalues.
    symmetrize : bool
        When true, replace ``W`` by ``(W + W.T) / 2``.
    dtype : numpy floating dtype
        Precision used for the data, the graph and the kernel weights.

    Returns
    -------
    ndarray of shape (n, k)
        Embedding coordinates.
    """
    scalar = np.dtype(dtype).type  # scalar constructor matching the working precision

    Xc = (X - X.mean(axis=0)).astype(dtype, copy=False)
    n, n_features = Xc.shape
    k = min(k, n - 1)

    # Neighbourhood size grows like d * log(n), but can never exceed n - 1.
    n_neighbors = min(int(np.ceil(n_features * np.log(n))), n - 1)

    A = kneighbors_graph(
        Xc,
        n_neighbors=n_neighbors,
        mode='distance',
        include_self=False,
        metric='euclidean',
        n_jobs=1
    ).tocsr().astype(dtype)

    if epsilon is None:
        eps_base = np.median(A.data ** 2) if A.nnz > 0 else 1.0
        # Follow the requested dtype instead of hard-coding float32.
        epsilon = scalar(eps_base if eps_base > 0 else 1.0)

    W = A.copy()
    W.data = np.exp(-(W.data ** 2) / epsilon).astype(dtype, copy=False)

    if symmetrize:
        W = (W + W.T) * scalar(0.5)

    degrees = np.asarray(W.sum(axis=1)).ravel()
    inv_sqrt_deg = 1.0 / np.sqrt(degrees + 1e-12)  # small offset guards isolated rows
    Dinv = diags(inv_sqrt_deg.astype(dtype))
    P_tilde = Dinv @ W @ Dinv

    if k + 1 >= n:
        # eigsh cannot return all n eigenpairs of a sparse matrix, so fall back
        # to a dense decomposition (same strategy as diffusion_embedding).
        vals, vecs = eigh(P_tilde.toarray())
    else:
        # NOTE(review): 'LM' selects by magnitude; a kNN kernel matrix is not
        # guaranteed to be positive semi-definite, so 'LA' may be what is
        # wanted here -- confirm before changing.
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')

    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    # Skip the leading eigenpair and keep the next k.
    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]

    return phis * (lambdas ** t)

def diffusion_embedding(X, k=5, epsilon=None, t=1):
    """Dense diffusion-map style embedding with a Gaussian kernel.

    Builds the full kernel ``exp(-||x_i - x_j||^2 / epsilon)`` on the centred
    data, normalises it symmetrically as ``D^{-1/2} K D^{-1/2}``, takes the
    leading eigenpairs, drops the first (trivial) one and scales the remaining
    eigenvectors by ``eigenvalue ** t``.

    Parameters
    ----------
    X : array of shape (n, d)
        Input points, one per row.
    k : int
        Number of embedding coordinates requested. When ``k >= n - 1`` a full
        dense eigendecomposition is used and at most ``n - 1`` columns are
        returned.
    epsilon : float or None
        Kernel bandwidth. When ``None`` the median of all squared pairwise
        distances is used; if that median is zero (e.g. all points coincide)
        ``1.0`` is used instead so the kernel stays finite.
    t : number
        Exponent applied to the eigenvalues.

    Returns
    -------
    ndarray of shape (n, min(k, n - 1))
        Embedding coordinates.
    """
    Xc = X - X.mean(axis=0)
    n = Xc.shape[0]
    D_sq = cdist(Xc, Xc, 'sqeuclidean')
    if epsilon is None:
        median_sq = np.median(D_sq)
        # A zero bandwidth would turn the kernel into NaNs (0 / 0).
        epsilon = median_sq if median_sq > 0 else 1.0
    K = np.exp(-D_sq / epsilon)

    # Symmetric normalisation via broadcasting; avoids two dense n x n products
    # with a diagonal matrix.
    inv_sqrt_deg = 1.0 / np.sqrt(K.sum(axis=1))
    P_tilde = K * inv_sqrt_deg[:, None] * inv_sqrt_deg[None, :]

    if k >= n - 1:
        # eigsh cannot return (almost) the whole spectrum; use a full EVD.
        vals, vecs = eigh(P_tilde)
    else:
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')

    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    # Remove the trivial leading eigenpair and keep the next k.
    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]

    return phis * (lambdas ** t)



def get_sorted_columns(A):
    """Sort every column of the ``n x k`` array ``A`` independently.

    Returns a list of ``k`` one-dimensional arrays of length ``n``, each holding
    one column of ``A`` in ascending order.
    """
    # Iterating over A.T walks the columns of A in order.
    return [np.sort(column) for column in A.T]

from ot.lp import emd_1d_sorted

# accessing internal method to save time
def emd2_1d_nosort(x_sorted, y_sorted, a_sorted=None, b_sorted=None, metric="sqeuclidean", p=2.0):
    """Return the 1-D transport cost reported by POT's ``emd_1d_sorted``.

    The caller is expected to pass positions that are already sorted; no
    sorting is performed here.

    Parameters
    ----------
    x_sorted, y_sorted : 1-D arrays
        Support points of the two distributions.
    a_sorted, b_sorted : 1-D arrays or None
        Weights attached to ``x_sorted`` / ``y_sorted``. ``None`` means uniform
        weights ``1 / len(...)``.
    metric, p
        Forwarded unchanged to ``emd_1d_sorted``.

    Returns
    -------
    The ``cost`` value returned by ``emd_1d_sorted``.
    """
    size_x, size_y = len(x_sorted), len(y_sorted)
    if a_sorted is None:
        a_sorted = np.ones(size_x, dtype=float) / size_x
    if b_sorted is None:
        b_sorted = np.ones(size_y, dtype=float) / size_y

    # The solver is called with float64 inputs: weights first, then positions.
    weights_x = a_sorted.astype(np.float64)
    weights_y = b_sorted.astype(np.float64)
    pos_x = x_sorted.astype(np.float64)
    pos_y = y_sorted.astype(np.float64)

    _plan, _indices, cost = emd_1d_sorted(weights_x, weights_y, pos_x, pos_y, metric=metric, p=p)
    return cost


def compute_1d_wasserstein_sorted(u_sorted, v_sorted):
    """Squared 2-Wasserstein cost between two sorted 1-D samples.

    Both inputs must already be in ascending order and are treated as
    uniformly weighted. For equal lengths the cost is the mean squared
    difference of matching order statistics; otherwise the computation is
    delegated to ``emd2_1d_nosort`` with uniform weights and the
    ``"sqeuclidean"`` metric.
    """
    size_u = len(u_sorted)
    size_v = len(v_sorted)

    if size_u != size_v:
        weights_u = np.ones(size_u) / size_u
        weights_v = np.ones(size_v) / size_v
        return emd2_1d_nosort(u_sorted, v_sorted, weights_u, weights_v, metric="sqeuclidean", p=2.0)

    # Equal sizes with uniform weights: the sorted pairing is the optimal plan.
    gaps = u_sorted - v_sorted
    return np.mean(gaps ** 2)

def assignment_sliced_wasserstein(X, Y, AX, BY, eigvals_X=None, eigvals_Y=None, eigvecs_X=None, eigvecs_Y=None, print_pairings=False):
    """Match embedding axes one-to-one by 1-D Wasserstein cost and aggregate.

    For every pair (column ``i`` of ``AX``, column ``j`` of ``BY``) the cost is
    the smaller of the squared 1-D Wasserstein costs between column ``i`` and
    column ``j`` or its negation (so the sign of an axis does not matter). The
    resulting ``k x k`` cost matrix, with ``k = AX.shape[1]``, is solved as a
    linear assignment problem.

    ``X``, ``Y``, the eigenvalue/eigenvector arguments and ``print_pairings``
    are accepted but not used by this function.

    Returns
    -------
    float
        ``sqrt(total matched cost / k)``.
    """
    n_axes = AX.shape[1]

    # Sort each column once up front instead of once per pair.
    cols_a = get_sorted_columns(AX)
    cols_b = get_sorted_columns(BY)
    # Negating an ascending array reverses its order, so flip before negating
    # to keep the result ascending.
    cols_b_negated = [-col[::-1] for col in cols_b]

    pair_cost = np.zeros((n_axes, n_axes))
    for i, j in itertools.product(range(n_axes), repeat=2):
        same_sign = compute_1d_wasserstein_sorted(cols_a[i], cols_b[j])
        opposite_sign = compute_1d_wasserstein_sorted(cols_a[i], cols_b_negated[j])
        pair_cost[i, j] = min(same_sign, opposite_sign)

    rows, cols = linear_sum_assignment(pair_cost)
    matched_total = pair_cost[rows, cols].sum()

    # Square root of the average matched cost, to match the units of the data.
    return np.sqrt((1 / n_axes) * matched_total)


# a weighted version - not used in experiments but left for completeness
def assignment_weighted_after(X, Y, AX, BY, eigvals_X, eigvals_Y, eigvecs_X=None, eigvecs_Y=None, weight_type='geometric', print_pairings=False):
    """Axis-matching Wasserstein distance with eigenvalue-weighted aggregation.

    The axis pairing is found exactly as in the unweighted variant: a
    ``k x k`` matrix of sign-invariant squared 1-D Wasserstein costs
    (``k = AX.shape[1]``) is solved as a linear assignment problem. The matched
    costs are then combined with weights derived from the eigenvalues of the
    matched axes, normalised to sum to one.

    Parameters
    ----------
    AX, BY : arrays of shape (n, k) / (m, k)
        Embeddings whose columns are compared.
    eigvals_X, eigvals_Y : array-like of length k
        Per-axis eigenvalues used to weight matched pairs.
    weight_type : {'geometric', 'arithmetic'}
        ``'geometric'`` uses ``sqrt(ev_x * ev_y)``, ``'arithmetic'`` uses
        ``(ev_x + ev_y) / 2``.
    X, Y, eigvecs_X, eigvecs_Y, print_pairings
        Accepted but not used by this function.

    Returns
    -------
    float
        Square root of the weighted sum of matched costs.

    Raises
    ------
    ValueError
        If ``weight_type`` is not one of the supported values.
    """
    # Validate before doing the O(k^2) cost computation; previously an unknown
    # value only failed afterwards with an unbound-variable error.
    if weight_type not in ('geometric', 'arithmetic'):
        raise ValueError(
            f"unknown weight_type {weight_type!r}; expected 'geometric' or 'arithmetic'"
        )

    k = AX.shape[1]
    eigvals_X = np.asarray(eigvals_X)  # allow plain lists; fancy indexing is used below
    eigvals_Y = np.asarray(eigvals_Y)

    AX_sorted = get_sorted_columns(AX)
    BY_sorted = get_sorted_columns(BY)
    # Loop-invariant: the negated (and re-ascending) version of each BY column.
    BY_neg_sorted = [-col[::-1] for col in BY_sorted]

    C = np.zeros((k, k))
    for i in range(k):
        a_sorted = AX_sorted[i]
        for j in range(k):
            d_pos = compute_1d_wasserstein_sorted(a_sorted, BY_sorted[j])
            d_neg = compute_1d_wasserstein_sorted(a_sorted, BY_neg_sorted[j])
            C[i, j] = min(d_pos, d_neg)

    row_ind, col_ind = linear_sum_assignment(C)
    costs = C[row_ind, col_ind]

    if weight_type == 'geometric':
        weights = np.sqrt(eigvals_X[row_ind] * eigvals_Y[col_ind])
    else:  # 'arithmetic'
        weights = 0.5 * (eigvals_X[row_ind] + eigvals_Y[col_ind])

    # Normalise to a convex combination; fall back to uniform weights when the
    # raw weights do not have a positive sum.
    weights_sum = weights.sum()
    weights = weights / weights_sum if weights_sum > 0 else np.ones_like(weights) / k

    total_cost = (weights * costs).sum()
    return np.sqrt(total_cost)

def softmin_sign(C_pos, C_neg, beta=10.0):
    """Smooth minimum of two cost arrays, taken elementwise.

    Returns ``w * C_pos + (1 - w) * C_neg`` with
    ``w = 1 / (1 + exp(beta * (C_pos - C_neg)))``, i.e. a logistic weight that
    favours whichever cost is smaller. Larger ``beta`` makes the result closer
    to the hard ``min``.
    """
    # For large positive gaps exp() overflows to inf, which still yields the
    # correct limit w = 0; silence only that overflow so callers do not get a
    # spurious RuntimeWarning.
    with np.errstate(over="ignore"):
        w = 1.0 / (1.0 + np.exp(beta * (C_pos - C_neg)))
    return w * C_pos + (1 - w) * C_neg

# "soft" riswie
def assignment_sliced_wasserstein_soft_sinkhorn(X, Y, AX, BY, epsilon=0.08, beta=5.0, n_iter=50000):
    """Soft version of the axis-matching sliced Wasserstein distance.

    Two relaxations replace the hard choices of the assignment-based variant:

    * the choice between an axis and its negation uses ``softmin_sign`` with
      sharpness ``beta`` instead of a hard ``min``;
    * the one-to-one axis matching is replaced by an entropic transport plan
      between uniform distributions over the ``k = AX.shape[1]`` axes, computed
      with ``ot.sinkhorn`` (regularisation ``epsilon``, at most ``n_iter``
      iterations).

    ``X`` and ``Y`` are accepted but not used by this function.

    Returns
    -------
    float
        ``sqrt(sum(P * C_soft))`` where ``P`` is the Sinkhorn plan.
    """
    k = AX.shape[1]
    AX_sorted = get_sorted_columns(AX)
    BY_sorted = get_sorted_columns(BY)
    # Loop-invariant: negated BY columns, reversed so they stay ascending.
    BY_neg_sorted = [-col[::-1] for col in BY_sorted]

    C_pos = np.zeros((k, k))
    C_neg = np.zeros((k, k))
    for i in range(k):
        a_sorted = AX_sorted[i]
        for j in range(k):
            C_pos[i, j] = compute_1d_wasserstein_sorted(a_sorted, BY_sorted[j])
            C_neg[i, j] = compute_1d_wasserstein_sorted(a_sorted, BY_neg_sorted[j])

    C_soft = softmin_sign(C_pos, C_neg, beta=beta)

    a = np.ones(k) / k
    b = np.ones(k) / k

    # use pot's sinkhorn algorithm to solve
    P = ot.sinkhorn(a, b, C_soft, reg=epsilon, numItermax=n_iter)

    # The marginals a and b each sum to 1, so P has total mass 1 and
    # sum(P * C_soft) is already an average cost per axis. Dividing by k again
    # (as was done before) shrinks the result by sqrt(k) relative to the
    # hard-assignment distance, which returns sqrt(mean matched cost).
    total_cost = np.sum(P * C_soft)
    return np.sqrt(total_cost)

def gromov_wasserstein(X, Y):
    """Gromov-Wasserstein discrepancy between two point clouds.

    Each cloud is represented by its matrix of pairwise distances (``cdist``
    default metric) and uniform point weights. The value returned by POT's
    ``gromov_wasserstein2`` with the ``'square_loss'`` loss is clipped at zero
    and square-rooted.
    """
    intra_x = cdist(X, X)
    intra_y = cdist(Y, Y)

    n_x, n_y = len(X), len(Y)
    mass_x = np.ones(n_x) / n_x
    mass_y = np.ones(n_y) / n_y

    gw_value = ot.gromov.gromov_wasserstein2(
        intra_x, intra_y, mass_x, mass_y, 'square_loss', max_iter=100000000
    )
    # The solver can return a slightly negative value; clip before the root.
    return np.sqrt(np.maximum(gw_value, 0.0))


def standard_ot(X, Y):
    """Exact optimal-transport cost between two uniformly weighted point clouds.

    The ground cost is the pairwise distance matrix from ``cdist(X, Y)``
    (default metric); the plan comes from ``ot.emd``. The function returns the
    square root of the resulting transport cost.

    NOTE(review): the ground cost is the plain (non-squared) distance, so the
    returned value is the square root of a 1-Wasserstein-type cost rather than
    a 2-Wasserstein distance -- confirm this is intended.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` do not have the same number of columns.
    """
    # Both clouds must live in the same ambient space for cdist(X, Y).
    same_dim = X.shape[1] == Y.shape[1]
    if not same_dim:
        raise ValueError(f"OT: Feature dimension mismatch: {X.shape[1]} vs {Y.shape[1]}")

    ground_cost = cdist(X, Y)
    n_x, n_y = len(X), len(Y)
    mass_x = np.ones(n_x) / n_x
    mass_y = np.ones(n_y) / n_y

    plan = ot.emd(mass_x, mass_y, ground_cost, numItermax=100000000)
    transport_cost = (ground_cost * plan).sum()
    return np.sqrt(transport_cost)



def _wasserstein_1d_pp(xs, ys, p):
    """Return W_p^p between two ascending, uniformly weighted 1-D samples."""
    n, m = len(xs), len(ys)
    if n == m:
        # Equal sizes: the sorted pairing is optimal and each pair has mass 1/n.
        return np.mean(np.abs(xs - ys) ** p)
    if n == 0 or m == 0:
        return float("nan")

    # Unequal sizes: integrate |F_x^{-1}(q) - F_y^{-1}(q)|^p over q in (0, 1].
    # Both quantile functions are piecewise constant, with jumps at i/n and j/m.
    levels = np.union1d(np.arange(1, n + 1) / n, np.arange(1, m + 1) / m)
    lower = np.concatenate(([0.0], levels[:-1]))
    widths = levels - lower
    # Evaluate each quantile function at the interval midpoints so that
    # floating-point noise at the breakpoints cannot pick the wrong piece.
    mids = 0.5 * (levels + lower)
    ix = np.minimum((mids * n).astype(np.intp), n - 1)
    iy = np.minimum((mids * m).astype(np.intp), m - 1)
    return np.sum(widths * np.abs(xs[ix] - ys[iy]) ** p)


def sliced_wasserstein(X, Y, n_proj=100, p=2, seed=None):
    """Monte-Carlo sliced Wasserstein distance between two point clouds.

    Draws ``n_proj`` random unit directions, projects both clouds onto each
    direction, computes the 1-D ``W_p^p`` between the projected samples
    (uniform weights) and returns ``(mean over directions) ** (1 / p)``.

    Samples of different sizes are compared through their quantile functions;
    previously the larger sorted sample was simply truncated, which discarded
    its largest values. Results for equally sized inputs are unchanged.

    Parameters
    ----------
    X, Y : arrays of shape (n, d) and (m, d)
        Point clouds in the same ambient dimension.
    n_proj : int
        Number of random projection directions.
    p : number
        Order of the Wasserstein cost.
    seed : int or None
        Seed passed to ``np.random.default_rng`` for reproducible directions.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` do not have the same number of columns.
    """
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"SW: Feature dimension mismatch: {X.shape[1]} vs {Y.shape[1]}")
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    dists = []
    for _ in range(n_proj):
        # Normalised Gaussian vector = random direction on the unit sphere.
        v = rng.normal(size=d)
        v /= np.linalg.norm(v)

        X_proj_sorted = np.sort(np.dot(X, v))
        Y_proj_sorted = np.sort(np.dot(Y, v))
        dists.append(_wasserstein_1d_pp(X_proj_sorted, Y_proj_sorted, p))

    return np.mean(dists) ** (1 / p)
