from BACKEND import cp, np, to_cpu, to_gpu
import matplotlib.pyplot as plt

from .sampling_metrics import SamplingMetric
from .sampling_metrics import PairwiseL2, PairwiseDot

# Sampling step of the S-SWIM algorithm


def top_k_2d(dists, k, shape):
    """
    Return the multi-dimensional indices of the k largest entries of an array.

    Parameters:
    dists (cp.ndarray): Array of scores; it is flattened before the search.
    k (int): The number of top elements to retrieve.
    shape (tuple of int): Shape used to convert flat positions in ``dists``
        back into multi-dimensional indices. It must describe the same number
        of elements as ``dists`` for the result to be meaningful.

    Returns:
    indices (cp.ndarray): Array of shape (k, len(shape)); each row is the index
        of one of the k largest entries. The rows are in no particular order
        (argpartition does not sort the selected elements).

    Raises:
    ValueError: If k is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Flatten the matrix to a 1D array
    flat = dists.ravel()

    if k == 0:
        # flat[-0:] would select *every* element, so handle k == 0 explicitly.
        top_flat = cp.empty(0, dtype=cp.int64)
    else:
        # argpartition places the k largest elements in the last k slots.
        top_flat = cp.argpartition(flat, -k)[-k:]

    # Convert flat indices back to multi-dimensional indices, one row per hit.
    return cp.column_stack(cp.unravel_index(top_flat, shape))


def _plot_sampling_diagnostics(out_dists, in_dists, ratios, probabilities):
    """Show the two distance matrices, the masked ratio matrix and a log-log histogram of the pair probabilities."""
    for matrix, title in (
        (out_dists, "Out distance"),
        (in_dists, "In distance"),
        (ratios, "Masked ratios"),
    ):
        plt.matshow(to_cpu(matrix), cmap="binary")
        plt.title(title)
        plt.colorbar()
        plt.show()
    plt.hist(probabilities, bins=200)
    plt.xlim(left=1e-6)
    plt.xscale('log')
    plt.yscale('log')
    plt.show()


def sample_pairs(X: cp.ndarray, Y: cp.ndarray, k: int, d_in:SamplingMetric=PairwiseL2, d_out:SamplingMetric=PairwiseL2, min_norm=1e-3, seed=42,
                 plot=False, demean_x=True, demean_y=True, top_k=False, return_probabilities=False,):
    """
    Select k index pairs (i, j), i < j, weighted by output distance over input distance.

    The score of a pair is ``d_out(Y)[i, j] / (d_in(X)[i, j] + 1e-6)``. Only
    the strict upper triangle of the score matrix is considered, so every
    unordered pair appears once.

    Parameters:
    X (cp.ndarray): Input samples. Copied before use. Assumed to be 3-D
        (the code reduces over axis 2 and then over the last two axes) —
        TODO confirm the exact layout against callers.
    Y (cp.ndarray): Output samples, same leading dimension as X. Copied
        before use.
    k (int): Number of pairs to return.
    d_in, d_out: Callables applied to X and Y respectively; each must return
        a matrix of pairwise distances indexable as [i, j] (presumably
        (n, n) — confirm against SamplingMetric).
    min_norm (float): If > 0, rows of the score matrix whose sample has a
        minimum norm (over axis 1, norm taken over the last axis) below this
        threshold are zeroed.
    seed (int): Seed for the NumPy generator used for sampling.
    plot (bool): If True, show diagnostic plots before sampling. Ignored
        when ``top_k`` or ``return_probabilities`` is set, because those
        return earlier.
    demean_x, demean_y (bool): Subtract the mean over axis 2 from X / Y.
    top_k (bool): If True, deterministically return the k highest-scoring
        pairs instead of sampling (order of the rows is unspecified).
    return_probabilities (bool): If True, return the normalized scores of
        all upper-triangular pairs (in ``triu_indices`` order) instead of
        sampling.

    Returns:
    cp.ndarray: Array of shape (k, 2) holding the selected (i, j) pairs, or
        the 1-D probability vector when ``return_probabilities`` is True.

    Raises:
    ValueError: If all scores sum to zero, so no distribution can be formed.
    """
    X = to_gpu(X).copy()
    Y = to_gpu(Y).copy()
    if demean_x:
        X -= X.mean(axis=2, keepdims=True)
    if demean_y:
        Y -= Y.mean(axis=2, keepdims=True)
    out_dists = d_out(Y)
    in_dists = d_in(X)
    # The small constant avoids division by zero for coincident inputs.
    G = out_dists / (in_dists + 1e-6)
    if min_norm > 0:
        # NOTE(review): only the row sample i of a pair (i, j) is checked
        # against min_norm; confirm whether column j should be masked too.
        G *= (cp.linalg.norm(X, axis=-1).min(axis=-1) >= min_norm)[:, None]
    G = cp.array(G, dtype=cp.float64)

    n = G.shape[0]

    # Extract upper triangular indices (i < j)
    i_upper, j_upper = cp.triu_indices(n, k=1)
    distances = G[i_upper, j_upper]

    if top_k:
        # `distances` is the condensed upper-triangular vector, so the top-k
        # positions index into (i_upper, j_upper), not into the n x n matrix.
        top = top_k_2d(distances, k, shape=distances.shape)[:, 0]
        return cp.column_stack((i_upper[top], j_upper[top]))

    # Normalize distances to create a probability distribution
    total_distance = cp.sum(distances)
    if total_distance == 0:
        raise ValueError("Sum of distances is zero. Cannot sample.")

    probabilities = distances / total_distance
    if return_probabilities:
        return probabilities
    # Sample on CPU because CUPY only implements with replacement
    probabilities = to_cpu(probabilities)
    if plot:
        _plot_sampling_diagnostics(out_dists, in_dists, G, probabilities)
    rng = np.random.default_rng(seed)
    sampled_indices = rng.choice(len(probabilities), size=k, replace=False, p=probabilities)
    # Vectorized gather of the chosen pairs; yields a (k, 2) array.
    sampled_pairs = cp.column_stack((i_upper[sampled_indices], j_upper[sampled_indices]))

    return sampled_pairs
