import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from scipy.linalg import orthogonal_procrustes
from sklearn.neighbors import kneighbors_graph
from scipy.sparse import csr_matrix, diags
from scipy.sparse.linalg import eigsh
from scipy.linalg import eigh
from sklearn import datasets
import itertools
import matplotlib.pyplot as plt
import time
import pandas as pd
import os
from sklearn.manifold import SpectralEmbedding
import pacmap
from sklearn.neighbors import NearestNeighbors


import ot 


def pca_embedding(X, k=None):
    """Project X onto its principal components.

    X is centered column-wise, a PCA with ``k`` components is fitted on the
    centered data, and the centered data is returned in PCA coordinates
    (an n x k array, like diffusion_embedding).

    A falsy ``k`` (``None`` or ``0``) means "use X.shape[1] components".
    """
    centered = X - X.mean(axis=0)
    n_components = k if k else X.shape[1]
    model = PCA(n_components=n_components).fit(centered)
    return model.transform(centered)


def pca_embedding_with_eigvals_and_evecs(X, k=None):
    """PCA embedding of X together with the fitted spectrum.

    Returns a 3-tuple ``(AX, eigvals, evecs)``:
      * AX      -- the centered data in PCA coordinates,
      * eigvals -- the fitted model's ``explained_variance_``,
      * evecs   -- the fitted model's ``components_`` transposed, so that each
                   column corresponds to one component.

    A falsy ``k`` (``None`` or ``0``) means "use X.shape[1] components".
    """
    centered = X - X.mean(axis=0)
    if not k:
        k = X.shape[1]
    model = PCA(n_components=k).fit(centered)
    projected = model.transform(centered)
    return projected, model.explained_variance_, model.components_.T

def diffusion_embedding2(X, k=5, epsilon=None, t=1, symmetrize=True, dtype=np.float32):
    """Diffusion-map style embedding built on a sparse k-nearest-neighbour graph.

    Steps: center X, build a Euclidean kNN distance graph with
    ``ceil(n_features * log(n))`` neighbours (capped at n - 1), turn distances
    into Gaussian affinities ``exp(-dist**2 / epsilon)``, optionally symmetrize,
    normalize as ``D^-1/2 W D^-1/2`` and keep the eigenpairs that follow the
    leading one.

    Parameters
    ----------
    X : array of shape (n, n_features)
    k : number of embedding coordinates requested (capped at n - 1).
    epsilon : kernel bandwidth. When None, the median squared neighbour
        distance is used (1.0 if that median is not positive or the graph
        has no stored entries).
    t : exponent applied to the eigenvalues that scale the eigenvectors.
    symmetrize : average W with its transpose before normalizing.
        NOTE(review): with ``symmetrize=False`` the kNN affinity matrix is in
        general not symmetric, while eigsh/eigh assume a symmetric input --
        confirm this mode is really intended to be used.
    dtype : floating dtype used for the graph and kernel; anything accepted
        by ``np.dtype`` (type object, string, dtype instance).

    Returns
    -------
    Array with one row per sample and up to ``k`` columns:
    ``eigenvectors[:, 1:k+1] * eigenvalues[1:k+1] ** t``.
    """
    # Scalar constructor for `dtype`; works for np.float32, "float32", np.dtype(...).
    scalar = np.dtype(dtype).type

    Xc = X - X.mean(axis=0)
    Xc = Xc.astype(dtype, copy=False)
    n, n_features = Xc.shape
    k = min(k, n - 1)

    n_neighbors = int(np.ceil(n_features * np.log(n)))
    n_neighbors = min(n_neighbors, n - 1)

    A = kneighbors_graph(
        Xc,
        n_neighbors=n_neighbors,
        mode='distance',
        include_self=False,
        metric='euclidean',
        n_jobs=1,
    ).tocsr().astype(dtype)

    if epsilon is None:
        eps_base = np.median(A.data ** 2) if A.nnz > 0 else 1.0
        # Keep the bandwidth in the requested dtype (it used to be forced to float32).
        epsilon = scalar(eps_base if eps_base > 0 else 1.0)

    # Gaussian affinities on the stored (neighbour) entries only.
    W = A.copy()
    W.data = np.exp(-(W.data ** 2) / epsilon).astype(dtype, copy=False)

    if symmetrize:
        W = (W + W.T) * scalar(0.5)

    degrees = np.asarray(W.sum(axis=1)).ravel()
    dinv = 1.0 / np.sqrt(degrees + 1e-12)  # small offset guards isolated rows
    Dinv = diags(dinv.astype(dtype))
    P_tilde = Dinv @ W @ Dinv

    if k + 1 >= n:
        # eigsh cannot return n (or more) eigenpairs of a sparse matrix, so
        # fall back to a dense full decomposition, as diffusion_embedding does.
        vals, vecs = eigh(P_tilde.toarray())
    else:
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')

    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    # Skip the leading eigenpair and keep the next k.
    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]

    return phis * (lambdas ** t)

def diffusion_embedding(X, k=5, epsilon=None, t=1):
    """Dense diffusion-map style embedding of X.

    A Gaussian kernel ``exp(-||xi - xj||^2 / epsilon)`` is built on all pairs
    of centered samples, normalized as ``D^-1/2 K D^-1/2``, and the eigenpairs
    following the leading (trivial) one are returned as
    ``eigenvectors * eigenvalues ** t``.

    Parameters
    ----------
    X : array of shape (n, d)
    k : number of embedding coordinates requested. When ``k >= n - 1`` a full
        dense eigendecomposition is used and at most n - 1 columns come back.
    epsilon : kernel bandwidth. When None, the median of all pairwise squared
        distances is used; if that median is not positive (e.g. duplicated
        points or a single sample) 1.0 is used instead, so the kernel stays
        finite.
    t : exponent applied to the eigenvalues.

    Returns
    -------
    Array of shape (n, min(k, n - 1)).
    """
    Xc = X - X.mean(axis=0)
    n = Xc.shape[0]
    D_sq = cdist(Xc, Xc, 'sqeuclidean')
    if epsilon is None:
        epsilon = np.median(D_sq)
        if not epsilon > 0:
            # A zero median would make the kernel 0/0; mirror the fallback
            # used by diffusion_embedding2.
            epsilon = 1.0
    K = np.exp(-D_sq / epsilon)

    # D^-1/2 K D^-1/2 via broadcasting: same products as multiplying by the
    # dense diagonal matrices, without the O(n^3) matmuls.
    inv_sqrt_deg = 1.0 / np.sqrt(K.sum(axis=1))
    P_tilde = (inv_sqrt_deg[:, None] * K) * inv_sqrt_deg[None, :]

    if k >= n - 1:
        # Everything is requested: full dense eigendecomposition.
        vals, vecs = eigh(P_tilde)
    else:
        vals, vecs = eigsh(P_tilde, k=k + 1, which='LM')

    order = np.argsort(vals)[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    # Remove the trivial leading eigenvector and keep the next k.
    lambdas = vals[1:k + 1]
    phis = vecs[:, 1:k + 1]
    return phis * (lambdas ** t)



def get_sorted_columns(A):
    """Sort every column of the n x k array A independently.

    Returns a list of k one-dimensional arrays of length n, each holding one
    column of A in ascending order. A itself is left untouched.
    """
    return [np.sort(column) for column in A.T]

from ot.lp import emd_1d_sorted

# accessing internal method to save time
def emd2_1d_nosort(x_sorted, y_sorted, a_sorted=None, b_sorted=None, metric="sqeuclidean", p=2.0):
    """1D transport cost between two samples, skipping any sorting step.

    Inputs are handed straight to ``emd_1d_sorted`` (weights first, then
    positions) after a cast to float64; only the cost it returns is passed on.
    The caller is expected to supply positions -- and matching weights --
    already in sorted order, as the parameter names indicate.

    Missing weights default to uniform: ``1/len(x_sorted)`` and
    ``1/len(y_sorted)`` respectively.
    """
    n_x, n_y = len(x_sorted), len(y_sorted)
    weights_x = np.ones(n_x, dtype=float) / n_x if a_sorted is None else a_sorted
    weights_y = np.ones(n_y, dtype=float) / n_y if b_sorted is None else b_sorted

    operands = [
        arr.astype(np.float64)
        for arr in (weights_x, weights_y, x_sorted, y_sorted)
    ]
    # The solver also returns the plan and its indices; only the cost is needed.
    _plan, _plan_indices, cost = emd_1d_sorted(*operands, metric=metric, p=p)
    return cost


def compute_1d_wasserstein_sorted(u_sorted, v_sorted):
    """Squared 1D transport cost between two ascending-sorted samples.

    Equal lengths (uniform weights): the mean squared difference of the
    sorted values, computed directly.
    Different lengths: delegated to emd2_1d_nosort with uniform weights and
    the "sqeuclidean" metric.
    """
    size_u, size_v = len(u_sorted), len(v_sorted)
    if size_u != size_v:
        uniform_u = np.ones(size_u) / size_u
        uniform_v = np.ones(size_v) / size_v
        return emd2_1d_nosort(u_sorted, v_sorted, uniform_u, uniform_v, metric="sqeuclidean", p=2.0)
    # Sorted order already pairs the i-th smallest values of both samples.
    gap = u_sorted - v_sorted
    return np.mean(gap ** 2)
    # alternative kept for reference: ot.lp.emd2_1d(u_sorted, v_sorted, a, b)  # W2^2

def assignment_sliced_wasserstein(X, Y, AX, BY, eigvals_X=None, eigvals_Y=None, eigvecs_X=None, eigvecs_Y=None, print_pairings=False):
    """Match embedding axes of AX and BY by 1D transport cost and aggregate.

    For every pair (column i of AX, column j of BY) the cost is the smaller of
    the squared 1D cost against column j and against its negation (sign flip).
    A minimum-cost one-to-one assignment between the axes is then solved and
    the square root of the mean matched cost is returned.

    Only AX and BY are read. X, Y, the eigenvalue/eigenvector arguments and
    print_pairings are accepted but not used by this function. The cost matrix
    is k x k with k = AX.shape[1], so only the first k columns of BY take part.
    """
    n_axes = AX.shape[1]
    cols_x = get_sorted_columns(AX)  # sort once up front instead of per pair
    cols_y = get_sorted_columns(BY)
    # Negating a sorted column reverses its order, so flip it to stay ascending.
    cols_y_flipped = [-col[::-1] for col in cols_y]

    pair_cost = np.zeros((n_axes, n_axes))
    for i, j in np.ndindex(n_axes, n_axes):
        same_sign = compute_1d_wasserstein_sorted(cols_x[i], cols_y[j])
        opposite_sign = compute_1d_wasserstein_sorted(cols_x[i], cols_y_flipped[j])
        pair_cost[i, j] = min(same_sign, opposite_sign)

    matched_rows, matched_cols = linear_sum_assignment(pair_cost)
    matched_total = pair_cost[matched_rows, matched_cols].sum()

    # Square root brings the averaged squared cost back to the data's units.
    return np.sqrt((1 / n_axes) * matched_total)


# a weighted version - not used in experiments but left for completeness
def assignment_weighted_after(X, Y, AX, BY, eigvals_X, eigvals_Y, eigvecs_X=None, eigvecs_Y=None, weight_type='geometric', print_pairings=False):
    """Axis-matched sliced cost where matched pairs are weighted by eigenvalues.

    The axis matching is the same as in the unweighted variant: a k x k matrix
    of sign-invariant squared 1D costs (k = AX.shape[1]) is solved as an
    assignment problem. Afterwards each matched pair (i, j) receives a weight
    computed from eigvals_X[i] and eigvals_Y[j]; the weights are normalized to
    sum to one (uniform 1/k if their sum is not positive) and the square root
    of the weighted cost is returned.

    Parameters
    ----------
    AX, BY : embeddings whose columns are compared.
    eigvals_X, eigvals_Y : per-axis values used for the weights (any
        array-like indexable after ``np.asarray``).
    weight_type : 'geometric' -> sqrt(ex * ey); 'arithmetic' -> (ex + ey) / 2.
    X, Y, eigvecs_X, eigvecs_Y, print_pairings : accepted but not used here.

    Raises
    ------
    ValueError
        If weight_type is neither 'geometric' nor 'arithmetic'.
    """
    # Validate first: previously an unknown value only surfaced as an
    # UnboundLocalError after the whole cost matrix had been computed.
    if weight_type not in ('geometric', 'arithmetic'):
        raise ValueError(
            f"Unknown weight_type {weight_type!r}; expected 'geometric' or 'arithmetic'"
        )

    k = AX.shape[1]
    AX_sorted = get_sorted_columns(AX)
    BY_sorted = get_sorted_columns(BY)
    # Negating a sorted column reverses its order, so flip it to stay ascending.
    BY_neg_sorted = [-col[::-1] for col in BY_sorted]

    C = np.zeros((k, k))
    for i in range(k):
        a_sorted = AX_sorted[i]
        for j in range(k):
            d_pos = compute_1d_wasserstein_sorted(a_sorted, BY_sorted[j])
            d_neg = compute_1d_wasserstein_sorted(a_sorted, BY_neg_sorted[j])
            C[i, j] = min(d_pos, d_neg)

    row_ind, col_ind = linear_sum_assignment(C)
    costs = C[row_ind, col_ind]

    matched_x = np.asarray(eigvals_X)[row_ind]
    matched_y = np.asarray(eigvals_Y)[col_ind]
    if weight_type == 'geometric':
        weights = np.sqrt(matched_x * matched_y)
    else:  # 'arithmetic'
        weights = 0.5 * (matched_x + matched_y)

    weights_sum = weights.sum()
    weights = weights / weights_sum if weights_sum > 0 else np.ones_like(weights) / k
    total_cost = (weights * costs).sum()
    return np.sqrt(total_cost)

def softmin_sign(C_pos, C_neg, beta=10.0):
    """Smooth stand-in for the element-wise minimum of C_pos and C_neg.

    Each entry is a convex combination of the two costs. The weight on C_pos
    is a logistic function of ``beta * (C_pos - C_neg)``: it is 0.5 when the
    costs tie and moves towards the cheaper option as beta grows.
    """
    gap = C_pos - C_neg
    weight_pos = 1.0 / (1.0 + np.exp(beta * gap))
    weight_neg = 1 - weight_pos
    return weight_pos * C_pos + weight_neg * C_neg

# "soft" riswie
def assignment_sliced_wasserstein_soft_sinkhorn(X, Y, AX, BY, epsilon=0.08, beta=5.0, n_iter=50000):
    """Soft variant of the axis-matched sliced cost.

    Two k x k matrices (k = AX.shape[1]) hold the squared 1D costs between
    column i of AX and column j of BY, once as-is and once with the BY column
    negated. They are blended per entry with softmin_sign(beta), and instead
    of a hard assignment an entropic transport plan between uniform weights on
    the axes is computed with ot.sinkhorn (regularization ``epsilon``, at most
    ``n_iter`` iterations).

    Returns ``sqrt(sum(P * C_soft) / k)``. X and Y are accepted but not used.

    NOTE(review): if the Sinkhorn plan meets its uniform marginals it already
    sums to 1, so the additional division by k makes this scale differ from
    assignment_sliced_wasserstein -- confirm this is intended.
    """
    n_axes = AX.shape[1]
    cols_x = get_sorted_columns(AX)
    cols_y = get_sorted_columns(BY)
    # Negating a sorted column reverses its order, so flip it to stay ascending.
    cols_y_flipped = [-col[::-1] for col in cols_y]

    cost_same = np.zeros((n_axes, n_axes))
    cost_flip = np.zeros((n_axes, n_axes))
    for i, j in np.ndindex(n_axes, n_axes):
        cost_same[i, j] = compute_1d_wasserstein_sorted(cols_x[i], cols_y[j])
        cost_flip[i, j] = compute_1d_wasserstein_sorted(cols_x[i], cols_y_flipped[j])

    soft_cost = softmin_sign(cost_same, cost_flip, beta=beta)

    # Uniform mass on the axes of both embeddings.
    mass_x = np.ones(n_axes) / n_axes
    mass_y = np.ones(n_axes) / n_axes

    # POT's Sinkhorn solver provides the soft matching between axes.
    plan = ot.sinkhorn(mass_x, mass_y, soft_cost, reg=epsilon, numItermax=n_iter)
    blended_total = np.sum(plan * soft_cost)
    return np.sqrt(blended_total / n_axes)

def gromov_wasserstein(X, Y):
    """Gromov-Wasserstein discrepancy between point clouds X and Y.

    Each cloud is represented by its own pairwise distance matrix and uniform
    weights; POT's gromov_wasserstein2 is called with 'square_loss' and a very
    large iteration cap. The returned value is the square root of that cost.
    """
    intra_x = cdist(X, X)  # distances within X
    intra_y = cdist(Y, Y)  # distances within Y
    n_x, n_y = len(X), len(Y)
    mass_x = np.ones(n_x) / n_x
    mass_y = np.ones(n_y) / n_y
    gw_cost = ot.gromov.gromov_wasserstein2(
        intra_x, intra_y, mass_x, mass_y, 'square_loss', max_iter=100000000
    )
    # Clamp at zero first: the solver's cost can come out slightly negative.
    return np.sqrt(np.maximum(gw_cost, 0.0))


def standard_ot(X, Y):
    """Exact optimal-transport cost between point clouds X and Y.

    Both clouds carry uniform weights and must live in the same feature space.
    The ground cost is ``cdist(X, Y)`` with its default metric; the plan comes
    from ot.emd with a very large iteration cap. The square root of the
    resulting transport cost is returned.

    Raises
    ------
    ValueError
        If X and Y have a different number of columns.
    """
    dim_x, dim_y = X.shape[1], Y.shape[1]
    if dim_x != dim_y:  # the clouds must share one ambient space
        raise ValueError(f"OT: Feature dimension mismatch: {dim_x} vs {dim_y}")

    ground_cost = cdist(X, Y)
    n_x, n_y = len(X), len(Y)
    mass_x = np.ones(n_x) / n_x
    mass_y = np.ones(n_y) / n_y
    plan = ot.emd(mass_x, mass_y, ground_cost, numItermax=100000000)
    return np.sqrt((ground_cost * plan).sum())



def _wasserstein_1d_pth_power(xs_sorted, ys_sorted, p):
    """W_p^p between two uniformly weighted 1D samples given in ascending order."""
    n, m = len(xs_sorted), len(ys_sorted)
    if min(n, m) == 0:
        return np.nan  # no mass to compare
    if n == m:
        # Equal sizes: the i-th smallest values are matched to each other.
        return np.mean(np.abs(xs_sorted - ys_sorted) ** p)

    # Unequal sizes: integrate |F_x^{-1}(q) - F_y^{-1}(q)|^p over q in (0, 1].
    # Both quantile functions are piecewise constant, so it suffices to
    # evaluate them once on every piece of the merged breakpoint grid.
    grid = np.unique(np.concatenate((np.arange(1, n + 1) / n,
                                     np.arange(1, m + 1) / m)))
    widths = np.diff(grid, prepend=0.0)
    mids = grid - 0.5 * widths  # strictly inside each piece
    ix = np.minimum((mids * n).astype(int), n - 1)
    iy = np.minimum((mids * m).astype(int), m - 1)
    return np.sum(widths * np.abs(xs_sorted[ix] - ys_sorted[iy]) ** p)


def sliced_wasserstein(X, Y, n_proj=100, p=2, seed=None):
    """Monte-Carlo sliced Wasserstein-p distance between point clouds X and Y.

    ``n_proj`` directions are drawn (normalized Gaussian vectors), both clouds
    are projected onto each direction, and the 1D W_p^p costs of the projected
    samples are averaged; the p-th root of that average is returned.

    Samples of different sizes are compared by matching quantile functions.
    (Earlier, only the smallest ``min(n, m)`` sorted values of each sample were
    compared, which discarded the upper tail of the larger sample.) Results
    for equally sized samples are unchanged.

    Parameters
    ----------
    X, Y : arrays of shape (n, d) and (m, d).
    n_proj : number of random directions.
    p : order of the distance.
    seed : seed for ``np.random.default_rng``.

    Raises
    ------
    ValueError
        If X and Y have a different number of columns.
    """
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"SW: Feature dimension mismatch: {X.shape[1]} vs {Y.shape[1]}")
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    dists = []
    for _ in range(n_proj):
        v = rng.normal(size=d)
        v /= np.linalg.norm(v)  # unit vector

        X_proj_sorted = np.sort(np.dot(X, v))
        Y_proj_sorted = np.sort(np.dot(Y, v))
        dists.append(_wasserstein_1d_pth_power(X_proj_sorted, Y_proj_sorted, p))

    return np.mean(dists) ** (1 / p)

## START NON-BOILERPLATE CODE HERE

import pandas as pd
import numpy as np
from itertools import combinations
from collections import Counter
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
import seaborn as sns

# Load the cell table; keep rows with a cell type whose region is one of 1..8.
df = pd.read_csv("Cells/j.csv")
df = df[df["cell_type"].notna()]
df = df[df["region_new"].isin(range(1, 9))]
FEATURE_COLS = ["Xcorr", "Ycorr"]
N_PER_REGION = 50

# all_regions maps a running index to the centered coordinates of one
# (file, region) sample; region_keys[i] is the (file, region) pair of index i.
all_regions = {}
region_keys = []
for file in sorted(df["filename"].unique()):
    in_file = df["filename"] == file
    for r in range(1, 9):
        X_full = df.loc[in_file & (df["region_new"] == r), FEATURE_COLS].values
        X = X_full
        if len(X_full) >= N_PER_REGION:
            # A fresh RandomState(0) per region keeps the subsample deterministic.
            picks = np.random.RandomState(0).choice(len(X_full), size=N_PER_REGION, replace=False)
            X = X_full[picks]
        # Center the sample -- pretend you just got it.
        mean_before = X.mean(axis=0)
        X_centered = X - mean_before
        mean_after = X_centered.mean(axis=0)
        print(f"Tissue {file[-6:]}, R{r} | Mean before: {mean_before.round(3)}  after: {mean_after.round(3)}")
        all_regions[len(region_keys)] = X_centered
        region_keys.append((file, r))

# The rest of the script hard-codes 48 samples.
assert len(all_regions) == 48

def apply_random_2d_rotation(region_dict, seed=0):
    """Rotate every 2D point set in region_dict by its own random angle.

    Angles are drawn uniformly from [0, 360) degrees with a generator seeded
    by ``seed``, one draw per entry in the dict's iteration order. Each point
    set (rows = points) is multiplied by the transposed rotation matrix. The
    mean of every set is printed before and after rotating.

    Returns a new dict with the same keys; the input dict is not modified.
    """
    rng = np.random.default_rng(seed)
    rotated = {}
    for key, points in region_dict.items():
        print(f"Tissue {key}: mean before rotation = {points.mean(axis=0)}")
        theta = np.deg2rad(rng.uniform(0, 360))
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t],
                             [sin_t, cos_t]])
        spun = points @ rotation.T
        rotated[key] = spun
        print(f"Tissue {key}: mean after rotation  = {spun.mean(axis=0)}\n")
    return rotated

# Replace every region sample with a randomly rotated copy (default seed=0).
all_regions = apply_random_2d_rotation(all_regions)


import pandas as pd
import numpy as np
from itertools import combinations
from collections import Counter
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA

def pca_embedding_with_eigvals_and_evecs(X, k=None):
    """PCA embedding of X plus the fitted variances and component vectors.

    NOTE: a function with this name and the same logic is already defined
    earlier in this file; this later definition is the one in effect from
    here on.

    Returns ``(AX, eigvals, evecs)``: the centered data in PCA coordinates,
    the model's ``explained_variance_``, and ``components_`` transposed.
    A falsy ``k`` selects X.shape[1] components.
    """
    n_components = k or X.shape[1]
    centered = X - X.mean(axis=0)
    pca = PCA(n_components=n_components)
    pca.fit(centered)
    return pca.transform(centered), pca.explained_variance_, pca.components_.T

def riswie_distance(X, Y, k=50):
    """RISWIE distance between X and Y.

    Both inputs are embedded with pca_embedding (all components) and the
    embeddings are compared with assignment_sliced_wasserstein.

    NOTE(review): ``k`` is accepted but never used -- the embeddings always
    keep every component. Confirm whether it should limit the PCA rank.
    """
    embedded_x, embedded_y = pca_embedding(X), pca_embedding(Y)
    return assignment_sliced_wasserstein(X, Y, embedded_x, embedded_y)

# Symmetric matrix of pairwise RISWIE distances between the 48 region samples;
# only the upper triangle is computed and then mirrored (diagonal stays 0).
n_regions = 48
D = np.zeros((n_regions, n_regions))
for i in tqdm(range(n_regions), desc="Computing RISWIE"):
    for j in range(i + 1, n_regions):
        pair_distance = riswie_distance(all_regions[i], all_regions[j])
        D[i, j] = pair_distance
        D[j, i] = pair_distance
        
def farthest_point_seeds(D, n_buckets=8, first_idx=0):
    """Pick n_buckets seed indices by farthest-point traversal of D.

    Starting from first_idx, the next seed is always the remaining index whose
    distance to its closest already-chosen seed is largest. D is a square
    matrix indexed as D[i, j]. Returns the seeds in the order they were picked.
    """
    seeds = [first_idx]
    if n_buckets <= 1:
        return seeds

    candidates = set(range(D.shape[0])) - {first_idx}
    # Distance from every candidate to its closest seed so far.
    closest = {cand: D[cand, first_idx] for cand in candidates}
    while len(seeds) < n_buckets:
        chosen = max(closest, key=closest.get)
        seeds.append(chosen)
        del closest[chosen]
        # Only the new seed can shrink a candidate's closest-seed distance.
        for cand, current in closest.items():
            to_new_seed = D[cand, chosen]
            if to_new_seed < current:
                closest[cand] = to_new_seed
    return seeds

def globally_greedy_balanced_assignment(D, seeds, bucket_size=6, n_buckets=8):
    """Grow size-limited buckets around the seeds, one cheapest move at a time.

    Bucket b starts with seeds[b]. At every step, the cost of putting each
    unassigned index into each bucket that still has room is the sum of its
    distances (from D) to the bucket's current members; the globally cheapest
    (index, bucket) pair is applied. This repeats until nothing is unassigned.

    Returns the list of buckets, each a list of indices.
    """
    n_points = D.shape[0]
    buckets = [[] for _ in range(n_buckets)]
    for bucket_id, seed in enumerate(seeds):
        buckets[bucket_id].append(seed)

    unassigned = set(range(n_points)) - set(seeds)
    while unassigned:
        pending = list(unassigned)
        # Buckets that are full keep an infinite cost and so are not chosen
        # while any bucket with room exists.
        open_buckets = [b for b, members in enumerate(buckets) if len(members) < bucket_size]
        costs = np.full((len(pending), n_buckets), np.inf)
        for row, point in enumerate(pending):
            for b in open_buckets:
                costs[row, b] = sum(D[point, member] for member in buckets[b])

        best_row, best_bucket = np.unravel_index(np.argmin(costs), costs.shape)
        chosen = pending[best_row]
        buckets[best_bucket].append(chosen)
        unassigned.remove(chosen)
    return buckets

def stack_total_within_sum(buckets, D):
    """Total of D[i, j] over all unordered pairs that share a bucket.

    Each bucket contributes the sum over its 2-element combinations; the
    per-bucket sums are then added up.
    """
    def within_bucket(members):
        return sum(D[a, b] for a, b in combinations(members, 2))

    return sum(map(within_bucket, buckets))

# Try many seedings; for each, build the stacks greedily and record the total
# within-stack distance. all_assignments[i] and all_scores[i] belong together.
all_assignments = []
all_scores = []

# 1) Farthest-point seeding started from every one of the 48 samples.
print("Running farthest-point seeding + global greedy for all 48 starts...")
for start in tqdm(range(48)):
    fp_seeds = farthest_point_seeds(D, n_buckets=8, first_idx=start)
    fp_buckets = globally_greedy_balanced_assignment(D, fp_seeds, bucket_size=6, n_buckets=8)
    all_assignments.append(fp_buckets)
    all_scores.append(stack_total_within_sum(fp_buckets, D))

# 2) Purely random seedings, in case the farthest-point seeding is too greedy.
n_random_inits = 10  # we run with 10000 for robustness in our experiments
for _ in range(n_random_inits):
    random_seeds = np.random.choice(np.arange(48), size=8, replace=False).tolist()
    random_buckets = globally_greedy_balanced_assignment(D, random_seeds, bucket_size=6, n_buckets=8)
    all_assignments.append(random_buckets)
    all_scores.append(stack_total_within_sum(random_buckets, D))


# Keep the candidate with the smallest total within-stack distance.
best_idx = int(np.argmin(all_scores))
stacks = all_assignments[best_idx]
print(f"\nBest assignment (total within-stack sum): {all_scores[best_idx]:.2f}")

# Ground truth: for every region label, the sample indices carrying it.
regions = sorted({region for _, region in region_keys})
true_stacks = []
for r in regions:
    true_stacks.append([i for i, (_, region) in enumerate(region_keys) if region == r])

# Stack order is arbitrary, so align predicted stacks with true regions by a
# Hungarian matching that maximizes overlap (costs are negated overlaps).
pred_sets = [set(stack) for stack in stacks]
truth_sets = [set(stack) for stack in true_stacks]
cost = np.zeros((8, 8))
for i, pred in enumerate(pred_sets):
    for j, truth in enumerate(truth_sets):
        cost[i, j] = -len(pred & truth)
row_ind, col_ind = linear_sum_assignment(cost)

# Report each matched stack and accumulate the number of correct members.
total_correct = 0
print("\n--- STACK ASSIGNMENTS ALIGNED TO REGIONS (HUNGARIAN MATCHED) ---")
for pred_k, true_k in zip(row_ind, col_ind):
    n_overlap = len(pred_sets[pred_k] & truth_sets[true_k])
    total_correct += n_overlap
    print(f"Stack {pred_k} (matched to region {regions[true_k]}):")
    for member in stacks[pred_k]:
        print(f"  {region_keys[member]}")
    print(f"  → Overlap with true region: {n_overlap}/6\n")
print(f"\nTotal correct: {total_correct}/48\n")

# Confusion counts: rows are predicted stacks, columns are ground-truth regions
# (in the order of `regions`).
conf_mat = np.zeros((8, 8), dtype=int)
region_to_idx = {label: col for col, label in enumerate(regions)}
for pred_k in row_ind:
    for member in stacks[pred_k]:
        true_col = region_to_idx[region_keys[member][1]]
        conf_mat[pred_k, true_col] += 1

# Annotated heatmap of the confusion counts.
region_labels = list(map(str, regions))
stack_labels = [f"Stack {s}" for s in range(8)]
plt.figure(figsize=(8, 7))
sns.heatmap(
    conf_mat,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=region_labels,
    yticklabels=stack_labels,
)
plt.xlabel("Ground Truth Region")
plt.ylabel("Predicted Stack")
plt.title("Stack/Region Overlap")
plt.tight_layout()
plt.show()

# the same procedure can be repeated for other distance functions; only RISWIE is included here as an example
