import time
import warnings

from abc import ABC, abstractmethod
from typing import Tuple, Optional, Union, Callable, Any

import cupy as cp
import cupyx.scipy.sparse as cpsp
import cupyx.scipy.sparse.linalg as cpla
import numba
import numpy as np

from cuml.neighbors import NearestNeighbors
from numba import cuda, uint64, uint32, int32
from scipy.sparse.csgraph import connected_components
from sklearn.manifold import spectral_embedding
from tqdm import tqdm


def _cuml_gpu_knn(
    X: np.ndarray, k: int, batch_size: int = 10000
) -> tuple[np.ndarray, np.ndarray]:
    """
    Find the ``k`` nearest neighbours of every row of ``X`` with cuML on the GPU.

    Inputs with fewer than 100 000 rows use exact brute-force search. Larger
    inputs use an IVF-Flat index whose ``nlist`` grows as ``4 * sqrt(n)``
    (capped at 4096) and ``nprobe = min(nlist // 2, 128)``. Queries are issued
    in chunks of ``batch_size`` rows.

    NOTE(review): callers treat column 0 of the result as the query point
    itself; with the approximate IVF-Flat index that is not guaranteed.

    Returns
    -------
    distances : np.ndarray, float64, shape (n, k)
    indices : np.ndarray, int64, shape (n, k)
    """
    points = cp.asarray(X.astype(np.float32))
    n_rows, _n_features = points.shape

    use_ivf = n_rows >= 100_000
    if use_ivf:
        n_lists = min(int(4 * np.sqrt(n_rows)), 4096)
        search_params = {"nlist": n_lists, "nprobe": min(n_lists // 2, 128)}
    else:
        search_params = {}

    model = NearestNeighbors(
        n_neighbors=k,
        algorithm="ivfflat" if use_ivf else "brute",
        metric="euclidean",
        output_type="cupy",
        algo_params=search_params,
    )
    model.fit(points)

    # Preallocated device buffers filled chunk by chunk.
    dist_buf = cp.empty((n_rows, k), dtype=cp.float32)
    idx_buf = cp.empty((n_rows, k), dtype=cp.int32)

    print("Building kNN graph...")
    for lo in tqdm(range(0, n_rows, batch_size)):
        hi = min(lo + batch_size, n_rows)
        chunk_dist, chunk_idx = model.kneighbors(points[lo:hi], n_neighbors=k)
        dist_buf[lo:hi] = chunk_dist
        idx_buf[lo:hi] = chunk_idx

    return (
        cp.asnumpy(dist_buf).astype(np.float64),
        cp.asnumpy(idx_buf).astype(np.int64),
    )


def _cuda_self_tuning_sigma(
    rho: np.ndarray,
    dists: np.ndarray,
    k: int,
    kernel_function,
    kernel_params=None,
    tol: float = 1e-5,
    max_iter: int = 1000,
) -> np.ndarray:
    """
    GPU-accelerated self-tuning bandwidth (sigma) search using CuPy.

    For every point ``i`` a per-point ``sigma_i`` is found by vectorised
    bisection so that

        sum_j kernel_function(max(dists[i, j] - rho[i], 0) / sigma_i, kernel_params)

    is within ``tol`` of ``log2(k)``.

    Parameters
    ----------
    rho : np.ndarray
        Per-point offset subtracted from the distances (shape: n_samples,).
    dists : np.ndarray
        Distances to the neighbours of each point (shape: n_samples, k-1).
    k : int
        Number of neighbours; the target row sum is ``log2(k)``.
    kernel_function : callable
        Function ``(r, params)`` applied elementwise to the normalised
        distances. It must accept CuPy arrays.
    kernel_params : any
        Passed unchanged as the second argument of ``kernel_function``.
    tol : float
        Convergence tolerance on ``|row_sum - log2(k)|``.
    max_iter : int
        Maximum number of bisection iterations.

    Returns
    -------
    sigmas : np.ndarray
        float32 host array of per-point sigmas, clipped below at 1e-8.
    """
    rho_gpu = cp.asarray(rho, dtype=cp.float32)
    dists_gpu = cp.asarray(dists, dtype=cp.float32)

    n_points = dists_gpu.shape[0]
    sigmas_gpu = cp.ones(n_points, dtype=cp.float32)

    # Target row sum: log2(k)
    target = cp.float32(np.log(k) / np.log(2))

    # Bisection bracket per point.
    lo = cp.zeros(n_points, dtype=cp.float32)
    hi = cp.maximum(cp.max(dists_gpu, axis=1), 1.0)

    # The shifted distances do not depend on sigma, so compute them once
    # instead of on every iteration.
    shifted = cp.maximum(dists_gpu - rho_gpu[:, cp.newaxis], 0.0)

    for _ in range(max_iter):
        weights = kernel_function(shifted / sigmas_gpu[:, cp.newaxis], kernel_params)
        psum = cp.sum(weights, axis=1)

        if cp.all(cp.abs(psum - target) < tol):
            break

        # Row sum too large -> sigma is an upper bound; otherwise a lower bound.
        too_high = psum > target
        hi = cp.where(too_high, sigmas_gpu, hi)
        lo = cp.where(~too_high, sigmas_gpu, lo)

        # Once a positive lower bound exists, bisect; until then halve or
        # double sigma depending on the direction of the error.
        sigmas_gpu = cp.where(
            lo > 0,
            0.5 * (lo + hi),
            cp.where(too_high, sigmas_gpu * 0.5, sigmas_gpu * 2.0),
        )

    sigmas_gpu = cp.maximum(sigmas_gpu, 1e-8)

    # Transfer back to CPU, as documented.
    return cp.asnumpy(sigmas_gpu)


# Example kernel functions that users can define
def exponential_kernel(r, params=None):
    """Return ``exp(-r)`` elementwise on a CuPy array.

    ``params`` exists only to match the ``(r, params)`` kernel interface and
    is ignored.
    """
    negated = -r
    return cp.exp(negated)


def gaussian_kernel(r, params=None):
    """Return ``exp(-0.5 * r^2)`` elementwise on a CuPy array.

    ``params`` exists only to match the ``(r, params)`` kernel interface and
    is ignored.
    """
    exponent = -0.5 * r * r
    return cp.exp(exponent)


def cauchy_kernel(r, params=None):
    """Return ``1 / (1 + r^2)`` elementwise; ``params`` is ignored.

    Uses only arithmetic operators, so it works on floats, NumPy and CuPy
    arrays.

    NOTE(review): a later ``def cauchy_kernel`` in this module rebinds this
    name at import time, so this definition is shadowed.
    """
    denominator = 1.0 + r * r
    return 1.0 / denominator


def student_t_kernel(r, params=None):
    """Student-t kernel ``(1 + r^2 / nu) ** (-(nu + 1) / 2)``.

    Parameters
    ----------
    r : float or array
        Normalised distances (NumPy or CuPy array, or scalar).
    params : float, optional
        Degrees of freedom ``nu``; ``None`` means ``nu = 1``. It defaults to
        ``None`` like the other kernels, so the kernel can be called with ``r``
        alone.

    Returns
    -------
    Same array type as ``r``. The ``**`` operator is used instead of
    ``cp.power`` so the kernel is backend-agnostic.
    """
    nu = params if params is not None else 1.0
    return (1.0 + r * r / nu) ** (-(nu + 1.0) / 2.0)


def _cuda_spectral_init(W, n_components=2, check_connectivity=True):
    """GPU-accelerated spectral initialization using CuPy.

    Computes the leading non-trivial eigenvectors of the normalised affinity
    ``D^-1/2 W D^-1/2`` with ``cupyx`` ``eigsh``. It then maps them to the
    random-walk (Laplacian-eigenmap) coordinates ``D^-1/2 u``, matching the
    division by ``sqrt(degree)`` that scikit-learn's ``spectral_embedding``
    applies when ``norm_laplacian=True``.

    Parameters
    ----------
    W : scipy sparse matrix or CuPy sparse matrix, shape (n, n)
        Symmetric affinity matrix. A CuPy object is detected by its module
        name and used in place; anything else is converted with
        ``cupyx.scipy.sparse.csr_matrix``.
    n_components : int
        Number of embedding dimensions to return.
    check_connectivity : bool
        If True, count connected components on the CPU and warn when there
        is more than one.

    Returns
    -------
    embedding : np.ndarray, float64, shape (n, n_components)
        Each column's sign is chosen so that row 0 is non-negative.
    eigenvals : np.ndarray
        The ``n_components + 1`` computed eigenvalues, largest first, with the
        trivial one included.

    Raises
    ------
    RuntimeError
        If the GPU eigensolver fails. The original error is chained.
    """
    is_gpu_matrix = hasattr(W, "__module__") and "cupy" in W.__module__

    if check_connectivity:
        # connected_components runs on the CPU; CuPy sparse matrices have .get().
        W_cpu = W.get() if is_gpu_matrix else W
        n_cc, _ = connected_components(W_cpu, directed=False, return_labels=True)
        if n_cc > 1:
            warnings.warn(
                f"Graph has {n_cc} connected components. "
                "Embedding will mix components."
            )

    if is_gpu_matrix:
        W_gpu = W if W.dtype == np.float32 else W.astype(np.float32)
    else:
        W_gpu = cpsp.csr_matrix(W.astype(np.float32))

    degrees = cp.asarray(W_gpu.sum(axis=1)).flatten()
    # Isolated nodes get degree 1 so the D^-1/2 scaling stays finite.
    degrees[degrees == 0] = 1
    sqrt_degrees = cp.sqrt(degrees)

    d_sqrt_inv = cpsp.diags(1.0 / sqrt_degrees, format="csr")
    W_norm = d_sqrt_inv @ W_gpu @ d_sqrt_inv

    try:
        # One extra eigenvector: the largest one is the trivial D^1/2 * 1.
        eigenvals, eigenvecs = cpla.eigsh(
            W_norm, k=n_components + 1, which="LA", maxiter=300, tol=1e-4
        )
    except Exception as e:
        raise RuntimeError(f"GPU eigensolver failed: {e}") from e

    order = cp.argsort(eigenvals)[::-1]
    eigenvals = eigenvals[order]
    eigenvecs = eigenvecs[:, order]

    # If u solves D^-1/2 W D^-1/2 u = mu u, then v = D^-1/2 u solves
    # W v = mu D v. Dividing (not multiplying) by sqrt(degree) gives the
    # Laplacian-eigenmap coordinates; the trivial vector becomes constant.
    embedding_gpu = eigenvecs[:, 1 : n_components + 1] / sqrt_degrees[:, None]

    embedding = cp.asnumpy(embedding_gpu).astype(np.float64)

    # Eigenvectors are defined up to sign; make row 0 non-negative.
    for col in range(n_components):
        if embedding[0, col] < 0:
            embedding[:, col] *= -1

    return embedding, cp.asnumpy(eigenvals)


def cpu_spectral_init(W, n_components=2):
    """CPU reference spectral initialization using scikit-learn.

    Returns the ``n_components`` leading non-trivial Laplacian eigenmap
    coordinates of the affinity matrix ``W`` (normalised Laplacian), with the
    trivial (constant) eigenvector dropped. The shape is
    ``(n_samples, n_components)``.
    """
    # With drop_first=True, sklearn computes n_components + 1 eigenvectors and
    # discards the trivial first one. With drop_first=False it returns exactly
    # n_components columns, including the trivial vector, so a post-hoc
    # "drop if wider than n_components" check could never trigger.
    return spectral_embedding(
        W,
        n_components=n_components,
        drop_first=True,
        norm_laplacian=True,
    )


def _cuda_compute_weights_vectorized(
    neigh_idx,
    neigh_dist,
    rho,
    sigmas,
    sym_code,
    n_samples,
    k,
    kernel_function,
    kernel_params=None,
):
    """
    GPU-accelerated computation of directed kNN edge weights.

    For the directed edge ``i -> neigh_idx[i, j]`` the weight is

        kernel_function(max(neigh_dist[i, j] - rho[i], 0) / sigma_eff, kernel_params)

    Parameters
    ----------
    neigh_idx : np.ndarray
        Neighbour indices, shape (n_samples, k-1).
    neigh_dist : np.ndarray
        Distances to those neighbours, shape (n_samples, k-1).
    rho : np.ndarray
        Per-point offset subtracted from the distances.
    sigmas : np.ndarray
        Per-point bandwidths.
    sym_code : int
        Chooses ``sigma_eff``:

        - 1: ``max(sigma_i, sigma_j)``.
        - -10: ``sqrt(sigma_i * sigma_j)``.
        - Anything else: ``sigma_i``.

        NOTE(review): ``_cuda_build_weighted_graph`` maps symmetrize names to
        codes 0-5, so -10 is not reachable from that caller.
    n_samples : int
        Number of points (rows of ``neigh_idx``).
    k : int
        Number of neighbours including self; ``k - 1`` columns are used.
    kernel_function : callable
        ``(r, params)`` applied to the normalised distance array. It must
        accept CuPy arrays.
    kernel_params : any
        Passed unchanged to ``kernel_function``.

    Returns
    -------
    rows, cols, vals : np.ndarray
        COO data (int32, int32, float32) with edges grouped by neighbour
        rank: entry ``j * n_samples + i`` is the edge ``(i, neigh_idx[i, j])``.
    """
    n_neighbors = k - 1

    neigh_idx_gpu = cp.asarray(neigh_idx, dtype=cp.int32)[:, :n_neighbors]
    neigh_dist_gpu = cp.asarray(neigh_dist, dtype=cp.float32)
    rho_gpu = cp.asarray(rho, dtype=cp.float32)
    sigmas_gpu = cp.asarray(sigmas, dtype=cp.float32)

    # Rank-major edge ordering: block j holds rows 0..n-1 and their j-th neighbour.
    rows_gpu = cp.tile(cp.arange(n_samples, dtype=cp.int32), n_neighbors)
    cols_gpu = neigh_idx_gpu.T.ravel()

    # Per-row quantities broadcast against the (n_samples, k-1) distance matrix.
    rho_i = rho_gpu[:, None]
    sigma_i = sigmas_gpu[:, None]

    diff = cp.maximum(neigh_dist_gpu - rho_i, 0.0)

    if sym_code == 1:  # max
        sigma_eff = cp.maximum(sigma_i, sigmas_gpu[neigh_idx_gpu])
    elif sym_code == -10:  # geometric mean
        sigma_eff = cp.sqrt(sigma_i * sigmas_gpu[neigh_idx_gpu])
    else:  # default: the source point's own sigma
        sigma_eff = sigma_i

    weights = kernel_function(diff / sigma_eff, kernel_params)

    # Transpose before flattening so values follow the same rank-major order.
    vals_gpu = weights.T.ravel()

    return (
        cp.asnumpy(rows_gpu),
        cp.asnumpy(cols_gpu),
        cp.asnumpy(vals_gpu).astype(np.float32),
    )


def _cuda_symmetrize_graph_fast(
    rows,
    cols,
    vals,
    n_samples,
    symmetrize="mean",
    balance=True,
    balance_target="mean",
    sink_max_iter=100,
    sink_tol=1e-3,
):
    """
    GPU-accelerated graph symmetrization with optional symmetric Sinkhorn balancing.

    Parameters
    ----------
    rows, cols, vals : np.ndarray
        COO data of the directed graph ``A``. Duplicate entries are summed
        when converting to CSR.
    n_samples : int
        Number of nodes.
    symmetrize : str
        Symmetrization strategy:

        - "mean": ``(A + A^T) / 2``.
        - "max": ``max(A, A^T)``.
        - "min": ``min(A, A^T)``.
        - "umap": ``A + A^T - A * A^T`` (fuzzy union).
        - "harm": harmonic mean ``2 a b / (a + b)``, on entries where both
          directions exist.
        - "geom": geometric mean ``sqrt(a b)``, on entries where both
          directions exist.
    balance : bool
        Whether to apply symmetric Sinkhorn balancing.
    balance_target : str or float
        Target for balancing: "mean" or a positive value.
    sink_max_iter : int
        Maximum iterations for Sinkhorn.
    sink_tol : float
        Convergence tolerance.

    Returns
    -------
    rows, cols, vals : np.ndarray
        Symmetrized COO data (int32, int32, float32).

    Raises
    ------
    ValueError
        If ``symmetrize`` is not one of the modes above.
    """
    # Validate before doing any GPU work.
    if symmetrize not in ("mean", "max", "min", "umap", "harm", "geom"):
        raise ValueError(f"Unknown symmetrize mode: {symmetrize}")

    rows_gpu = cp.asarray(rows, dtype=cp.int32)
    cols_gpu = cp.asarray(cols, dtype=cp.int32)
    vals_gpu = cp.asarray(vals, dtype=cp.float32)

    A = cpsp.coo_matrix(
        (vals_gpu, (rows_gpu, cols_gpu)), shape=(n_samples, n_samples), dtype=cp.float32
    )
    A_csr = A.tocsr()
    A_T = A_csr.T.tocsr()

    if symmetrize == "mean":
        S = (A_csr + A_T) / 2

    elif symmetrize == "max":
        S = A_csr.maximum(A_T)

    elif symmetrize == "min":
        S = A_csr.minimum(A_T)

    elif symmetrize == "umap":
        S = A_csr + A_T - A_csr.multiply(A_T)

    elif symmetrize == "geom":
        S = A_csr.multiply(A_T)
        S.data = cp.sqrt(S.data)

    else:  # "harm"
        # Exact harmonic mean 2ab / (a + b). The elementwise product is only
        # stored where both a_ij and a_ji exist, which fixes the sparsity
        # pattern to the intersection of A and A^T.
        prod = A_csr.multiply(A_T).tocsr()
        if prod.nnz == 0:
            S = cpsp.csr_matrix((n_samples, n_samples), dtype=cp.float32)
        else:
            inv_total = (A_csr + A_T).tocsr()
            positive = inv_total.data > 0
            # 1 / (a + b) where the sum is positive, 0 otherwise (avoids inf * 0).
            safe_total = cp.where(positive, inv_total.data, 1.0)
            inv_total.data = cp.where(positive, 1.0 / safe_total, 0.0).astype(
                cp.float32
            )
            S = prod.multiply(inv_total).tocsr().multiply(2.0).tocsr()
            S.eliminate_zeros()

    if balance:
        S = _cuda_sym_sinkhorn_balance(
            S, target=balance_target, max_iter=sink_max_iter, tol=sink_tol
        )

    S_coo = S.tocoo()

    return (
        cp.asnumpy(S_coo.row).astype(np.int32),
        cp.asnumpy(S_coo.col).astype(np.int32),
        cp.asnumpy(S_coo.data).astype(np.float32),
    )


def _cuda_sym_sinkhorn_balance(
    mat_csr, target="mean", max_iter=10, tol=1e-3, eps=1e-12, w_max=1e9
):
    """
    GPU-accelerated symmetric Sinkhorn balancing.

    Finds a positive vector ``w`` such that ``P = diag(w) @ A @ diag(w)`` has
    row sums approximately equal to ``s``. Here ``A`` is ``mat_csr``
    symmetrized as ``(A + A^T) / 2``.

    The fixed point ``w * (A w) = s`` is reached with the damped
    (geometric-mean) update ``w <- sqrt(w * s / (A w))``. The undamped update
    ``w <- s / (A w)`` alternates between two vectors ``lambda*x`` and
    ``x/lambda``. With it the tolerance test never triggers and the returned
    row sums land at ``s / lambda^2`` instead of ``s``.

    Parameters
    ----------
    mat_csr : cupyx sparse matrix
        Square non-negative matrix.
    target : None, "mean" or float
        Desired row sum ``s``:

        - ``None``: 1.
        - ``"mean"``: the mean row sum of the symmetrized matrix.
        - A positive number: that value.
    max_iter : int
        Maximum number of fixed-point iterations.
    tol : float
        Stop when the max relative change of ``w`` is below this.
    eps : float
        Added to ``A w`` and to the denominator of the relative change to
        avoid division by zero.
    w_max : float
        Upper cap applied to ``w`` at every step.

    Returns
    -------
    cupyx csr_matrix
        The balanced matrix ``P``.

    Raises
    ------
    ValueError
        If a numeric ``target`` is not positive.
    FloatingPointError
        If ``w`` becomes non-finite.
    """
    # Force exact symmetry (cheap, and avoids an elementwise != comparison).
    mat_csr = (mat_csr + mat_csr.T) / 2

    n = mat_csr.shape[0]
    if n == 0:
        return mat_csr.tocsr()
    w = cp.ones(n, dtype=cp.float32)

    if target is None:
        s = 1.0
    elif target == "mean":
        s = mat_csr.sum() / n
    else:
        s = float(target)
        if s <= 0:
            raise ValueError("target must be positive")

    for _ in range(max_iter):
        Aw = mat_csr.dot(w) + eps

        # Damped symmetric update; same fixed point as w = s / (A w).
        w_new = cp.sqrt(w * (s / Aw))
        w_new = cp.minimum(w_new, w_max)

        rel_change = cp.max(cp.abs(w_new - w) / (w + eps))
        w = w_new

        if rel_change < tol:
            break

        if not cp.isfinite(w).all():
            raise FloatingPointError("w blew up during symmetric Sinkhorn")

    # P = diag(w) @ A @ diag(w): scale rows and columns by w.
    w_col = w[:, None]
    P = mat_csr.multiply(w_col).multiply(w_col.T)

    return P.tocsr()


def _cuda_build_weighted_graph(
    X: np.ndarray,
    k: int,
    kernel_function,
    kernel_params,
    symmetrize: str = "max",
    use_zero_rho: bool = False,
    balance_graph: bool = True,
    balance_target: Union[str, float] = "mean",
    batch_size: int = 10000,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a symmetric, optionally Sinkhorn-balanced, weighted kNN graph on the GPU.

    Parameters
    ----------
    X : np.ndarray
        Data matrix; rows are samples.
    k : int
        Neighbours requested from the kNN search. Column 0 of the result is
        dropped (assumed to be the point itself), leaving ``k - 1`` edges per
        point.
    kernel_function : callable
        ``(r, params)`` kernel used both for the sigma search and for the
        edge weights.
    kernel_params : any
        Passed unchanged to ``kernel_function``.
    symmetrize : str
        Mode forwarded to the symmetrization step ("mean", "max", "min",
        "umap", "harm", "geom").
    use_zero_rho : bool
        If True, do not subtract the nearest-neighbour distance.
    balance_graph : bool
        Apply symmetric Sinkhorn balancing after symmetrization.
    balance_target : str or float
        Balancing target ("mean" or a positive value).
    batch_size : int
        Query chunk size for the kNN search.

    Returns
    -------
    rho : np.ndarray
        float32 distance to the first non-self neighbour (zeros if
        ``use_zero_rho``).
    sigmas : np.ndarray
        float32 host array of per-point bandwidths.
    edges_i : np.ndarray
        Row indices of edges.
    edges_j : np.ndarray
        Column indices of edges.
    P_vals : np.ndarray
        Edge weights.
    neigh_idx : np.ndarray
        Neighbour indices, shape (n_samples, k-1).
    """
    n_samples = X.shape[0]
    print(f"Building weighted graph on GPU with {n_samples} samples...")
    # 1. k-NN search on GPU
    print("  1. GPU k-NN search...")
    distances, indices = _cuml_gpu_knn(X, k, batch_size=batch_size)
    # 2. Local connectivity: rho and sigma
    print("  2. Computing rho and sigma on GPU...")
    rho = distances[:, 1].astype(np.float32)
    if use_zero_rho:
        rho = np.zeros_like(rho)

    # cp.asnumpy guarantees the documented host array even if the sigma
    # helper hands back a device array (it is a no-op for NumPy input).
    sigmas = cp.asnumpy(
        _cuda_self_tuning_sigma(
            rho,
            distances[:, 1:],
            k,
            kernel_function=kernel_function,
            kernel_params=kernel_params,
        )
    ).astype(np.float32)

    # 3. Compute weights on GPU
    print("  3. Computing edge weights on GPU...")
    neigh_idx = indices[:, 1:]
    neigh_dist = distances[:, 1:]

    # Encode symmetrize mode; unknown names fall back to 0 here and are
    # rejected later by the symmetrization step.
    _SYM_MAP = {"mean": 0, "max": 1, "umap": 2, "geom": 3, "min": 4, "harm": 5}
    sym_code = _SYM_MAP.get(symmetrize, 0)

    rows, cols, vals = _cuda_compute_weights_vectorized(
        neigh_idx,
        neigh_dist,
        rho,
        sigmas,
        sym_code,
        n_samples,
        k,
        kernel_function=kernel_function,
        kernel_params=kernel_params,
    )

    # 4. Symmetrize the graph on GPU
    print("  4. Symmetrizing graph on GPU...")
    edges_i, edges_j, P_vals = _cuda_symmetrize_graph_fast(
        rows,
        cols,
        vals,
        n_samples,
        symmetrize=symmetrize,
        balance=balance_graph,
        balance_target=balance_target,
    )

    print("  ✓ GPU graph construction complete!")
    print(f"    Final graph: {len(P_vals)} edges")

    return rho, sigmas, edges_i, edges_j, P_vals, neigh_idx


# Numba optimized training functions
@numba.njit(fastmath=True)
def cauchy_attractive_force(r2, a, b):
    """Cauchy attractive coefficient ``-a*b / (a*r2 + 1)`` (= d log Q / d r2)."""
    denom = a * r2 + 1.0
    return -(a * b) / denom


@numba.njit(fastmath=True)
def cauchy_repulsive_force(r2, a, b):
    """Cauchy repulsive coefficient ``a*b / (a*r2 + 1)^(b+1)`` (= -dQ / d r2)."""
    base = a * r2 + 1.0
    return (a * b) / base ** (b + 1.0)


@numba.njit(fastmath=True)
def student_t_attractive_force(r2, nu):
    """Student-t attractive coefficient ``-((nu+1)/2) / (nu + r2)``."""
    exponent = (nu + 1.0) / 2.0
    denom = nu + r2
    return -exponent / denom


@numba.njit(fastmath=True)
def student_t_repulsive_force(r2, nu):
    """Student-t repulsive coefficient ``p / (nu * (1 + r2/nu)^(p+1))`` with ``p = (nu+1)/2``."""
    exponent = (nu + 1.0) / 2.0
    base = 1.0 + r2 / nu
    return exponent / (nu * base ** (exponent + 1.0))


@numba.njit(inline="always")
def fast_randint(k, epoch, neg_idx, n_pts, seed):
    """Stateless PCG-style hash mapping ``(k, epoch, neg_idx, seed)`` to an index.

    The inputs are XOR-mixed into a 64-bit state (``epoch`` shifted by 16,
    ``neg_idx`` by 32). One LCG step is applied with multiplier
    6364136223846793005 and increment 1442695040888963407. The PCG XSH-RR
    output permutation (xorshift, then rotate right by the top 5 bits) follows,
    and the result is reduced modulo ``n_pts``. The same inputs always give the
    same output; no state is kept between calls.

    NOTE(review): the rotate mixes uint32 ``xorshifted`` with int32 ``rot``.
    Depending on Numba's integer promotion this may be evaluated in 64-bit
    arithmetic, in which case the left-shifted term is not truncated to
    32 bits. The final ``%`` still yields a value in ``[0, n_pts)``. Confirm
    this if bit-exact PCG32 output matters.
    """
    # Mix all inputs into one 64-bit state.
    state = (
        np.uint64(k)
        ^ (np.uint64(epoch) << 16)
        ^ (np.uint64(neg_idx) << 32)
        ^ np.uint64(seed)
    )
    # LCG advance (constants from PCG's 64-bit generator).
    state = state * np.uint64(6364136223846793005) + np.uint64(1442695040888963407)
    # XSH-RR output: xorshift high bits down, rotate by the top 5 bits.
    xorshifted = np.uint32(((state >> 18) ^ state) >> 27)
    rot = np.int32(state >> 59)
    result = (xorshifted >> rot) | (xorshifted << ((-rot) & 31))
    return int(result % np.uint32(n_pts))


@cuda.jit(device=True, inline=True)
def fast_randint_cuda(k, epoch, neg_idx, n_pts, seed):
    """Device-side twin of ``fast_randint``: a stateless PCG-style hash to an index.

    It performs the same mixing, LCG step and XSH-RR output permutation as the
    CPU version, using Numba CUDA integer types, then reduces modulo ``n_pts``.
    The cuda_training_epoch kernel calls it to pick negative samples from
    ``(edge, epoch, negative-sample index, seed)``.

    NOTE(review): as in the CPU version, the uint32/int32 mix in the rotate
    may be promoted to 64-bit arithmetic; the ``%`` keeps the result in
    ``[0, n_pts)``. Confirm this if bit-exact PCG32 output matters.
    """
    # build a 64-bit state
    state = uint64(k) ^ (uint64(epoch) << 16) ^ (uint64(neg_idx) << 32) ^ uint64(seed)
    # PCG step
    state = state * uint64(6364136223846793005) + uint64(1442695040888963407)
    # xorshift and rotate
    xorshifted = uint32(((state >> 18) ^ state) >> 27)
    rot = int32(state >> 59)
    # final result
    rnd = (xorshifted >> rot) | (xorshifted << ((-rot) & 31))
    return int(rnd % uint32(n_pts))


@cuda.jit(device=True, inline=True)
def cauchy_attractive_force_cuda(r2, a, b):
    """Device Cauchy attractive coefficient ``-(a*b) / (a*r2 + 1)``.

    Callers are expected to clamp ``r2`` away from zero first.
    """
    denom = a * r2 + 1.0
    return -(a * b) / denom


@cuda.jit(device=True, inline=True)
def cauchy_repulsive_force_cuda(r2, a, b):
    """Device Cauchy repulsive coefficient ``(a*b) / (a*r2 + 1)^(b+1)``."""
    base = a * r2 + 1.0
    return (a * b) / base ** (b + 1.0)


@cuda.jit(device=True, inline=True)
def student_t_attractive_force_cuda(r2, nu):
    """Device Student-t attractive coefficient ``-((nu+1)/2) / (nu + r2)``."""
    exponent = (nu + 1.0) / 2.0
    denom = nu + r2
    return -exponent / denom


@cuda.jit(device=True, inline=True)
def student_t_repulsive_force_cuda(r2, nu):
    """Device Student-t repulsive coefficient ``p / (nu * (1 + r2/nu)^(p+1))``, ``p = (nu+1)/2``."""
    exponent = (nu + 1.0) / 2.0
    base = 1.0 + r2 / nu
    return exponent / (nu * base ** (exponent + 1.0))


@cuda.jit
def cuda_training_epoch(
    Y,
    ei,
    ej,
    ui,
    us,
    uir,
    usr,
    epoch,
    lr,
    grad_clip,
    negative_sample_rate,
    kernel_type,
    kernel_params,
    random_seed,
):
    """
    Run one epoch of edge-sampled SGD on the embedding ``Y`` (updated in place).

    Each thread walks the edge list with a grid-stride loop
    (``start = threadIdx + blockDim * blockIdx``, ``stride = blockDim * gridDim``).
    For edge ``k`` between ``ei[k]`` and ``ej[k]`` it does two steps:

    * Attractive: if the edge has a positive sampling interval
      (``ui[k] > 0``) and is due (``us[k] <= epoch``), pull the two points
      together and advance ``us[k]`` by ``ui[k]``. Edges with a
      non-positive interval are never sampled; otherwise ``us = -1`` would
      stay "due" forever.
    * Repulsive: if ``uir[k] > 0``, draw the negative samples that became due
      since ``usr[k]`` with ``fast_randint_cuda``, push ``ei[k]`` and each
      sample apart, and advance ``usr[k]``.

    ``kernel_type == 0`` selects the Cauchy forces (``kernel_params[0:2] = a, b``);
    any other value selects Student-t (``kernel_params[0] = nu``). If
    ``grad_clip > 0``, each gradient component is clipped to
    ``[-grad_clip, grad_clip]``. ``negative_sample_rate`` is accepted for
    interface stability but not read here.

    Updates are Hogwild-style (no atomics). Threads that touch the same point
    concurrently can race, which trades exactness for speed.
    """
    n_edges = ei.shape[0]
    n_pts, dim = Y.shape

    tidx = cuda.threadIdx.x
    bidx = cuda.blockIdx.x

    grid_dim = cuda.gridDim.x
    block_dim = cuda.blockDim.x

    start = tidx + block_dim * bidx
    stride = block_dim * grid_dim

    for k in range(start, n_edges, stride):
        # Attractive forces (only for edges with a valid sampling interval).
        if ui[k] > 0 and us[k] <= epoch:
            i = ei[k]
            j = ej[k]

            r2 = 0.0
            for d in range(dim):
                diff = Y[i, d] - Y[j, d]
                r2 += diff * diff
            r2 = max(r2, 1e-12)

            if kernel_type == 0:
                coef = cauchy_attractive_force_cuda(
                    r2, kernel_params[0], kernel_params[1]
                )
            else:
                coef = student_t_attractive_force_cuda(r2, kernel_params[0])

            for d in range(dim):
                grad = 2.0 * coef * (Y[i, d] - Y[j, d])
                if grad_clip > 0:
                    grad = max(-grad_clip, min(grad_clip, grad))
                # Hogwild update: deliberately non-atomic (may race, faster).
                # Atomic alternative:
                #   cuda.atomic.add(Y, (i, d), lr * grad)
                #   cuda.atomic.sub(Y, (j, d), lr * grad)
                Y[i, d] += lr * grad
                Y[j, d] -= lr * grad

            us[k] += ui[k]

        # Repulsive forces via negative sampling.
        if uir[k] > 0:
            n_neg = int((epoch - usr[k]) / uir[k])

            if n_neg > 0:
                i = ei[k]

                for neg_idx in range(n_neg):
                    l = fast_randint_cuda(k, epoch, neg_idx, n_pts, random_seed)
                    if i == l:
                        continue

                    r2 = 0.0
                    for d in range(dim):
                        diff = Y[i, d] - Y[l, d]
                        r2 += diff * diff
                    r2 = max(r2, 1e-12)

                    if kernel_type == 0:
                        coef_r = cauchy_repulsive_force_cuda(
                            r2, kernel_params[0], kernel_params[1]
                        )
                    else:
                        coef_r = student_t_repulsive_force_cuda(r2, kernel_params[0])

                    for d in range(dim):
                        grad = 2.0 * coef_r * (Y[i, d] - Y[l, d])
                        if grad_clip > 0:
                            grad = max(-grad_clip, min(grad_clip, grad))
                        # Push both points apart symmetrically (Hogwild, non-atomic).
                        Y[i, d] += lr * grad
                        Y[l, d] -= lr * grad

                usr[k] += n_neg * uir[k]


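# Force kernels for the SGD layout. cuCTMCEmbeddings.fit accepts CauchyKernel or
# StudentTKernel and maps them to the CUDA kernel's kernel_type (0 or 1).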
class Kernel(ABC):
    """Abstract interface for SGD force kernels.

    A concrete kernel supplies attractive and repulsive force coefficients as
    functions of the squared distance ``r2``, plus a packed parameter array
    for the compiled training kernel.
    """

    @abstractmethod
    def attractive_force(self, r2):
        """Return the attractive force coefficient at squared distance ``r2``."""

    @abstractmethod
    def repulsive_force(self, r2):
        """Return the repulsive force coefficient at squared distance ``r2``."""

    @abstractmethod
    def get_numba_params(self):
        """Return parameters for Numba kernel"""


class CauchyKernel(Kernel):
    """Cauchy-type force kernel with similarity ``Q(r) = 1 / (1 + a r^2)^b``.

    Parameters
    ----------
    a, b : float
        Shape parameters. ``get_numba_params`` packs them as ``[a, b]``
        (float32) for the CUDA training kernel.
    """

    def __init__(self, a=2.0, b=1.0):
        self.a = a
        self.b = b
        # Cached product; it is not recomputed if ``a`` or ``b`` is reassigned later.
        self.ab = a * b

    def attractive_force(self, r2):
        """F_attr = -ab / (a r^2 + 1), i.e. d log Q / d(r^2)."""
        return -self.ab / (self.a * r2 + 1.0)

    def repulsive_force(self, r2):
        """F_rep = ab / (a r^2 + 1)^(b + 1), i.e. -dQ / d(r^2).

        Uses ``**`` rather than ``cp.power`` so it accepts Python floats,
        NumPy arrays and CuPy arrays alike, like ``attractive_force``.
        """
        return self.ab / (self.a * r2 + 1.0) ** (self.b + 1.0)

    def get_numba_params(self):
        return np.array([self.a, self.b], dtype=np.float32)


class StudentTKernel(Kernel):
    """Student-t force kernel with similarity ``Q(r) = (1 + r^2/nu)^(-(nu+1)/2)``.

    Parameters
    ----------
    nu : float
        Degrees of freedom. ``get_numba_params`` packs ``[nu, 0.0]`` (float32),
        padded to the same length as the Cauchy parameters.
    """

    def __init__(self, nu=1.0):
        self.nu = nu
        # Exponent p = (nu + 1) / 2 of Q.
        self.power = (nu + 1.0) / 2.0

    def attractive_force(self, r2):
        """F_attr = -p / (nu + r^2), i.e. d log Q / d(r^2)."""
        return -(self.power) / (self.nu + r2)

    def repulsive_force(self, r2):
        """F_rep = p / (nu * (1 + r^2/nu)^(p + 1)), i.e. -dQ / d(r^2).

        Uses ``**`` rather than ``cp.power`` so it accepts Python floats,
        NumPy arrays and CuPy arrays alike.
        """
        denominator = (1.0 + r2 / self.nu) ** (self.power + 1.0)
        return self.power / (self.nu * denominator)

    def get_numba_params(self):
        return np.array([self.nu, 0.0], dtype=np.float32)  # Pad to same size


def compute_intervals_cuda(weights, n_epochs, data_precision=1.0):
    """Return per-edge sampling intervals (in epochs) for CuPy edge weights.

    For a positive weight ``w`` the interval is
    ``budget / (budget * w / w_max)`` with ``budget = data_precision * n_epochs``.
    Algebraically this is ``w_max / w``, so ``data_precision`` cancels out and
    the strongest edge gets interval 1. Non-positive weights receive the
    sentinel ``-1``. The result is a float32 CuPy array.
    """
    budget = data_precision * n_epochs
    strongest = cp.max(weights)
    positive = weights > 0

    result = cp.empty_like(weights, dtype=cp.float32)
    samples = budget * (weights[positive] / strongest)
    result[positive] = budget / samples
    result[~positive] = cp.float32(-1.0)
    return result


def initialize_schedule(weights, n_epochs, negative_sample_rate=5.0):
    """Build the per-edge sampling schedule from edge weights.

    Returns ``(ui, uir, us, usr)``:

    - ``ui``: epochs between positive samples.
    - ``uir``: ``ui / negative_sample_rate``, the epochs between negative
      samples.
    - ``us``, ``usr``: independent copies of ``ui`` and ``uir``, used as the
      epoch at which each edge is next due.
    """
    per_sample = compute_intervals_cuda(weights, n_epochs)
    per_negative = per_sample / negative_sample_rate
    # Every edge first becomes due one full interval after the start.
    return per_sample, per_negative, per_sample.copy(), per_negative.copy()


class cuCTMCEmbeddings:
    """GPU graph-embedding estimator.

    Pipeline: kNN graph and weights (CuPy/cuML), spectral initialization,
    then SGD with edge sampling in a Numba CUDA kernel.

    Parameters
    ----------
    n_neighbors : int
        Neighbours per point used to build the graph.
    n_components : int
        Embedding dimension.
    n_epochs : int
        Number of SGD epochs (one kernel launch per epoch).
    learning_rate : float
        Initial learning rate, decayed linearly towards 0.
    force_kernel : Kernel, optional
        ``CauchyKernel`` (default) or ``StudentTKernel``.
    graph_kernel : callable, optional
        ``(r, params)`` kernel used to weight kNN edges. ``None`` means
        ``exponential_kernel``.
    kernel_params : any
        Forwarded to ``graph_kernel``.
    negative_sample_rate : float
        Negative samples per positive sample.
    gradient_clip : float, optional
        Per-component gradient clip; ``None`` (stored as 0.0) disables it.
    random_state : int, optional
        Seed for negative sampling; ``None`` means 42.
    verbose : bool
        Print progress and show a progress bar.
    """

    def __init__(
        self,
        n_neighbors: int = 15,
        n_components: int = 2,
        n_epochs: int = 200,
        learning_rate: float = 1.0,
        force_kernel: Optional["Kernel"] = None,
        graph_kernel: Optional[Callable] = None,
        kernel_params: Any = None,
        negative_sample_rate: float = 5.0,
        gradient_clip: Optional[float] = None,
        random_state: Optional[int] = None,
        verbose: bool = True,
    ):
        self.n_neighbors = n_neighbors
        self.n_components = n_components
        self.n_epochs = n_epochs
        self.learning_rate = learning_rate
        self.force_kernel = force_kernel or CauchyKernel()
        self.graph_kernel = graph_kernel
        self.kernel_params = kernel_params
        self.negative_sample_rate = negative_sample_rate
        self.gradient_clip = gradient_clip if gradient_clip is not None else 0.0
        self.random_state = random_state if random_state is not None else 42
        self.verbose = verbose

        # Fitted attributes
        self.embedding_ = None
        self.n_samples_ = None
        self._is_fitted = False

    def fit(self, X, y=None):
        """Fit the embedding on ``X`` and store it in ``embedding_``.

        Builds the graph and spectral initialization on the GPU, then launches
        the Numba CUDA kernel ``cuda_training_epoch`` once per epoch with a
        linearly decaying learning rate. ``y`` is ignored (scikit-learn
        convention).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``force_kernel`` is neither ``CauchyKernel`` nor ``StudentTKernel``.
        """
        X = np.asarray(X, dtype=np.float32)

        # Graph, weights and initial positions as CuPy arrays.
        ei, ej, weights, Y = self._prepare_graph_and_init(X)

        ui, uir, us, usr = initialize_schedule(
            weights, self.n_epochs, self.negative_sample_rate
        )

        if isinstance(self.force_kernel, CauchyKernel):
            kernel_type = 0
        elif isinstance(self.force_kernel, StudentTKernel):
            kernel_type = 1
        else:
            raise ValueError("Unknown kernel type")
        kernel_params = self.force_kernel.get_numba_params()

        if self.verbose:
            print("Training with Numba CUDA backend...")
        start_time = time.time()

        # Device buffers for the Numba kernel; the host copies pin dtypes.
        d_Y = cuda.to_device(Y.get().astype(np.float32))
        d_ei = cuda.to_device(ei.get().astype(np.int32))
        d_ej = cuda.to_device(ej.get().astype(np.int32))
        d_ui = cuda.to_device(ui.get().astype(np.float32))
        d_us = cuda.to_device(us.get().astype(np.float32))
        d_uir = cuda.to_device(uir.get().astype(np.float32))
        d_usr = cuda.to_device(usr.get().astype(np.float32))
        d_kernel_params = cuda.to_device(kernel_params)

        blocks_per_grid = 256
        threads_per_block = 256
        lr_denominator = float(max(1, self.n_epochs))

        for epoch in tqdm(range(self.n_epochs), disable=not self.verbose):
            lr = self.learning_rate * (1.0 - epoch / lr_denominator)
            cuda_training_epoch[blocks_per_grid, threads_per_block](
                d_Y,
                d_ei,
                d_ej,
                d_ui,
                d_us,
                d_uir,
                d_usr,
                epoch,
                lr,
                self.gradient_clip,
                self.negative_sample_rate,
                kernel_type,
                d_kernel_params,
                self.random_state,
            )

        # Kernel launches are asynchronous: wait for the last epoch before
        # measuring elapsed time.
        cuda.synchronize()

        if self.verbose:
            elapsed = time.time() - start_time
            print(f"Training complete in {elapsed:.2f}s")

        self.embedding_ = d_Y.copy_to_host()
        self._is_fitted = True
        return self

    def fit_transform(self, X, y=None):
        """Fit the model and return the embedding"""
        self.fit(X, y)
        return self.embedding_

    def _prepare_graph_and_init(self, X):
        """Build the weighted graph and the scaled spectral initialization.

        Returns
        -------
        ei, ej : cupy int32 arrays
            Edge endpoints.
        P_vals : cupy float32 array
            Edge weights.
        Y : cupy float32 array, shape (n_samples, n_components)
            Initial positions, rescaled so their overall std matches the mean
            per-feature std of ``X``.
        """
        from scipy.sparse import coo_matrix

        self.n_samples_ = X.shape[0]

        if self.verbose:
            print("Building graph...")

        # A None graph kernel would fail inside the sigma search; default it.
        graph_kernel = (
            self.graph_kernel if self.graph_kernel is not None else exponential_kernel
        )

        _, _, ei, ej, P_vals, _ = _cuda_build_weighted_graph(
            X,
            self.n_neighbors,
            kernel_function=graph_kernel,
            kernel_params=self.kernel_params,
            symmetrize="umap",
            use_zero_rho=False,
            balance_graph=True,
        )

        W = coo_matrix((P_vals, (ei, ej)), shape=(self.n_samples_, self.n_samples_))

        if self.verbose:
            n_cc, _ = connected_components(W, directed=False)
            print(f"Connected components: {n_cc}")

        Y, _ = _cuda_spectral_init(
            W, n_components=self.n_components, check_connectivity=False
        )

        scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
        Y *= scale

        return (
            cp.asarray(ei, dtype=cp.int32),
            cp.asarray(ej, dtype=cp.int32),
            cp.asarray(P_vals, dtype=cp.float32),
            cp.asarray(Y, dtype=cp.float32),
        )


def cauchy_kernel(r, params=None):
    """Heavy-tailed Cauchy-type kernel ``1 / (1 + a * r^2) ** b``.

    This definition appears after the earlier ``cauchy_kernel`` in this module,
    so it is the one bound to the name at import time. It uses only arithmetic
    operators, so it works on floats, NumPy arrays and CuPy arrays.

    Parameters
    ----------
    r : float or array
        Normalised distances.
    params : tuple of (a, b), optional
        Shape parameters. ``None`` selects ``a = 2.0, b = 0.67``, the values
        this kernel has always used.
    """
    a, b = (2.0, 0.67) if params is None else params
    return 1.0 / (1.0 + a * r * r) ** b
