import torch
import numpy as np
from math import sqrt
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import minimum_spanning_tree, dijkstra
from scipy.sparse import csr_matrix
import sys


def compute_mst_geodesics(x_test, x_train, k=30, random_state=None):
    """
    Compute test-to-train geodesic distances along the MST of a k-NN graph.

    Train and test points are stacked (train rows first, then test rows) and a
    k-NN distance graph is built over all of them. Its minimum spanning tree
    (a forest if the k-NN graph is disconnected) defines the geodesics, and
    shortest paths are computed only from the test nodes.

    Args:
        x_test (ndarray): High-dimensional test data of shape (n_test, d).
        x_train (ndarray): High-dimensional training data of shape (n_train, d).
        k (int): Number of neighbors to consider for the k-NN graph. Defaults to 30.
        random_state (int or None): If not None, passed to ``np.random.seed``.
            Nothing in this function draws random numbers; the global seed is
            still set so callers relying on that side effect are unaffected.

    Returns:
        geodesics (ndarray): Array of shape (n_test, n_train), geodesic distances
        from each test point to each train point. Pairs that are not connected
        in the spanning forest are ``np.inf`` (scipy's dijkstra convention).
    """
    if random_state is not None:
        np.random.seed(random_state)

    n_train = x_train.shape[0]
    n_test = x_test.shape[0]

    # Train rows occupy indices [0, n_train), test rows [n_train, n_train + n_test).
    x = np.vstack([x_train, x_test])
    test_indices = np.arange(n_train, n_train + n_test)

    # Step 1: k-NN distance graph over all stacked points.
    knn_graph = kneighbors_graph(x, n_neighbors=k, mode='distance', include_self=False, n_jobs=-1)

    # Step 2: MST of that graph (scipy treats the input as undirected).
    mst = minimum_spanning_tree(knn_graph)

    # Step 3: shortest paths from the test nodes only. This gives an
    # (n_test, n_total) array instead of the full (n_total, n_total) all-pairs
    # matrix, which was mostly discarded.
    dist_from_test = dijkstra(csgraph=mst, directed=False, indices=test_indices)

    # Step 4: keep the columns that correspond to training points.
    return dist_from_test[:, :n_train]


def _average_ranks(values):
    """Return 0-based ranks of a 1-D tensor; tied values share their mean rank."""
    sorted_vals, order = torch.sort(values, stable=True)
    _, group_of_sorted, counts = torch.unique_consecutive(
        sorted_vals, return_inverse=True, return_counts=True
    )
    # Each run of equal values occupies sorted positions [start, end).
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    group_rank = (starts + ends - 1).float() / 2
    ranks = torch.empty(values.shape, dtype=group_rank.dtype, device=values.device)
    ranks[order] = group_rank[group_of_sorted]
    return ranks


def _correlation(a, b):
    """Pearson correlation of two 1-D tensors, with 1e-8 added to the denominator."""
    a_centered = a - a.mean()
    b_centered = b - b.mean()
    cov = (a_centered * b_centered).sum()
    scale = torch.sqrt((a_centered ** 2).sum()) * torch.sqrt((b_centered ** 2).sum())
    return cov / (scale + 1e-8)


def dist_preservation(x, y):
    """
    Compare high- and low-dimensional distances with stress and correlation scores.

    Both inputs are flattened before comparison, so non-square distance
    matrices (e.g. n_anchors x n) are supported.

    Args:
        x (torch.Tensor): Pairwise distances in high-dimensional space (n_anchors x n).
        y (torch.Tensor): Pairwise distances in low-dimensional space (same shape as x).

    Returns:
        tuple[float, float, float]: (Pearson correlation, Spearman correlation,
        Stress loss), in that order.
            - Pearson: correlation of the raw distances.
            - Spearman: Pearson correlation of the ranks; tied values receive
              their average rank (the same tie convention as scipy.stats.spearmanr).
            - Stress: normalized Stress-1, sqrt(sum((x - y)^2) / sum(x^2));
              nan/inf if every entry of x is zero.
    """
    x_flat = x.flatten()
    y_flat = y.flatten()

    # Stress-1 normalized
    stress_loss = torch.sqrt(((x_flat - y_flat) ** 2).sum() / (x_flat ** 2).sum())

    # Spearman: average ranks for ties. Ordinal ranking (argsort of argsort)
    # would give tied distances distinct ranks in arbitrary order, which biases
    # the coefficient.
    spearman_corr = _correlation(_average_ranks(x_flat), _average_ranks(y_flat))

    # Pearson on the raw distances
    pearson_corr = _correlation(x_flat, y_flat)

    return pearson_corr.item(), spearman_corr.item(), stress_loss.item()



def qnx_trust_cont(dist_hd, dist_ld, device='cpu'):
    """
    Average QNX, Trustworthiness and Continuity over a sweep of neighbourhood sizes.

    Each row of the (n_anchors, n) matrices holds the distances from one anchor
    to the same n candidate points, in the high- and low-dimensional spaces.
    For every k, the neighbourhoods HD_k / LD_k are the first k columns of a
    stable ascending sort of each row, and ranks are 0-based positions in that
    same sort. Deriving neighbourhoods and ranks from one ordering keeps the
    three scores consistent when distances tie.

        QNX(k) = |LD_k ∩ HD_k| / (n_anchors * k)
        T(k)   = 1 - sum_{j in LD_k} max(rank_HD(j) - k, 0) / Z(k)
        C(k)   = 1 - sum_{j in HD_k} max(rank_LD(j) - k, 0) / Z(k)
        Z(k)   = n_anchors * k * (2n - 3k - 1) / 2

    k runs over the distinct integers of linspace(5, floor(sqrt(n)), 10).
    Values with k > n or Z(k) <= 0 (which only happens for n <= 8) are skipped.

    Self-distances are NOT masked. If a row contains the anchor's own distance
    (typically 0, so it sorts first), that column counts as one of the k
    neighbours in both spaces and therefore always matches in QNX.

    Parameters:
        dist_hd (torch.Tensor): High-dimensional distances (n_anchors, n).
        dist_ld (torch.Tensor): Low-dimensional distances (same shape).
        device (str): Device to perform the computations.

    Returns:
        tuple: (average QNX, average Trustworthiness, average Continuity)

    Raises:
        ValueError: if no valid neighbourhood size exists (n too small).
    """
    # Inputs are only read, never modified, so no defensive clone is needed.
    dist_hd = dist_hd.to(device)
    dist_ld = dist_ld.to(device)

    n_anchors, n_total = dist_hd.shape

    # One stable sort per space gives both the neighbour order and the ranks.
    hd_order = torch.sort(dist_hd, dim=1, stable=True).indices
    ld_order = torch.sort(dist_ld, dim=1, stable=True).indices
    positions = torch.arange(n_total, device=device).expand(n_anchors, n_total)
    hd_ranks = torch.empty_like(hd_order).scatter_(1, hd_order, positions)
    ld_ranks = torch.empty_like(ld_order).scatter_(1, ld_order, positions)

    max_k = int(sqrt(n_total))
    candidate_ks = torch.linspace(5, max_k, steps=10, device=device).long().unique().tolist()
    # Drop k values that would make topk-style selection impossible or the
    # normaliser non-positive (only reachable for very small n).
    ks = [k for k in candidate_ks if 1 <= k <= n_total and 2 * n_total - 3 * k - 1 > 0]
    if not ks:
        raise ValueError(
            f"qnx_trust_cont: no valid neighbourhood size for n={n_total} candidate points"
        )

    qnx_scores = []
    trust_scores = []
    cont_scores = []

    for k in ks:
        ld_knn = ld_order[:, :k]
        hd_knn = hd_order[:, :k]

        # HD rank of each LD neighbour, and LD rank of each HD neighbour.
        hd_rank_of_ld_knn = hd_ranks.gather(1, ld_knn)
        ld_rank_of_hd_knn = ld_ranks.gather(1, hd_knn)

        norm = n_anchors * k * (2 * n_total - 3 * k - 1) / 2

        # QNX: an LD neighbour is shared iff its HD rank is below k.
        qnx_scores.append((hd_rank_of_ld_knn < k).sum() / (n_anchors * k))

        # Trustworthiness: penalise LD neighbours that sit far away in HD.
        trust_penalty = torch.clamp(hd_rank_of_ld_knn - k, min=0).sum()
        trust_scores.append(1 - trust_penalty / norm)

        # Continuity: penalise HD neighbours that sit far away in LD.
        cont_penalty = torch.clamp(ld_rank_of_hd_knn - k, min=0).sum()
        cont_scores.append(1 - cont_penalty / norm)

    avg_qnx = torch.stack(qnx_scores).mean().item()
    avg_trust = torch.stack(trust_scores).mean().item()
    avg_cont = torch.stack(cont_scores).mean().item()

    return avg_qnx, avg_trust, avg_cont