import time
from typing import Any, Dict, Optional, Tuple

import numpy as np
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize


def get_kth_value(unsorted, k, axis=-1):
    """
    Return the k-th smallest value along an axis (k is 1-based).

    Args:
        unsorted: numpy.ndarray of any dimensionality.
        k: int with 1 <= k <= unsorted.shape[axis]; ``k=1`` yields the minimum.
        axis: axis along which to select (defaults to the last axis).
    Returns:
        kth values along the designated axis (that axis is removed).
    Raises:
        ValueError: if k is outside [1, unsorted.shape[axis]].
    """
    unsorted = np.asarray(unsorted)
    n = unsorted.shape[axis]
    if not 1 <= k <= n:
        raise ValueError(f"k must be in [1, {n}], got {k}")
    # After partitioning around position k-1, that slot holds exactly the value
    # a full sort would place there, i.e. the k-th smallest. Using k-1 (not k)
    # keeps k == n valid, and np.take honours `axis` for any dimensionality.
    partitioned = np.partition(unsorted, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)


def pr_knn_from_embeddings(
    E_ref: np.ndarray,
    E_out: np.ndarray,
    pca_var: float = 0.90,
    n_samples: int = 3000,
    k: int = 4,
    eps: float = 1e-12,
    rng_seed_ref: int = 0,
    rng_seed_out: int = 1,
    target_block_mem_mb: float = 64.0,
) -> Tuple[float, float, Dict[str, Any]]:
    """
    Precision and Recall via kNN supports with optional PCA.
    NumPy only. No sklearn, no scipy.

    Both sets are subsampled (without replacement) to the common size
    ``min(N_ref, N_out, n_samples)``, optionally projected with a PCA fitted on
    their union, and compared with k-th nearest-neighbour balls:

    * precision: fraction of generated points lying inside at least one
      reference ball (radius = distance to the k-th nearest other ref point);
    * recall: fraction of reference points lying inside at least one
      generated ball (radius = distance to the k-th nearest other out point).

    Args:
        E_ref: (N_ref, d) reference embeddings.
        E_out: (N_out, d) generated embeddings.
        pca_var: explained variance to keep in PCA on the union. Use 1.0 to skip PCA.
        n_samples: cap on per-set sample size before PCA and search (must be >= 1).
        k: k for within-set neighbour radii (clipped to [1, n - 1]).
        eps: minimum radius to avoid zero balls.
        rng_seed_ref, rng_seed_out: seeds for the independent subsampling.
        target_block_mem_mb: approximate memory budget per distance block.

    Returns:
        (precision, recall, info). If either input is empty, the metric that
        cannot be evaluated is NaN, the other is 0.0, and ``info["reason"]`` is
        ``"empty set"``.

    Raises:
        ValueError: if the inputs are not 2D, their dimensions differ, or
            ``n_samples < 1``.
    """
    # Basic checks
    start = time.time()
    if E_ref.ndim != 2 or E_out.ndim != 2:
        raise ValueError("E_ref and E_out must be 2D arrays")
    if E_ref.shape[1] != E_out.shape[1]:
        raise ValueError("E_ref and E_out must have the same embedding dimension")
    if n_samples < 1:
        # A zero cap would produce empty sets and NaN metrics further down.
        raise ValueError("n_samples must be >= 1")

    E_ref = np.ascontiguousarray(E_ref, dtype=np.float32)
    E_out = np.ascontiguousarray(E_out, dtype=np.float32)
    N_ref, d = E_ref.shape
    N_out, _ = E_out.shape

    if N_ref == 0 or N_out == 0:
        info = {"reason": "empty set", "n_components": 0, "explained_variance": 0.0}
        return (
            0.0 if N_out > 0 else float("nan"),
            0.0 if N_ref > 0 else float("nan"),
            info,
        )

    # Cap sizes: both sets are brought to the same size so the two ball
    # families are built from equally many points.
    cap = min(N_ref, N_out, n_samples)
    if N_ref > cap:
        rng = np.random.default_rng(rng_seed_ref)
        E_ref = E_ref[rng.choice(N_ref, size=cap, replace=False)]
        N_ref = cap
    if N_out > cap:
        rng = np.random.default_rng(rng_seed_out)
        E_out = E_out[rng.choice(N_out, size=cap, replace=False)]
        N_out = cap

    # PCA via SVD on the union of both sets.
    def pca_union(Xr: np.ndarray, Xo: np.ndarray, var_keep: float):
        if var_keep >= 1.0:
            return Xr, Xo, Xr.shape[1], 1.0
        X = np.vstack([Xr, Xo]).astype(np.float32, copy=False)
        mean = X.mean(axis=0, keepdims=True)
        Xc = X - mean
        _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
        ev = (S.astype(np.float64) ** 2) / max(Xc.shape[0] - 1, 1)
        cumsum = np.cumsum(ev)
        ev_sum = cumsum[-1] if cumsum.size else 0.0
        if ev_sum <= 0.0:
            # Degenerate (constant) data: keep every dimension, just centred.
            return Xr - mean, Xo - mean, Xr.shape[1], 0.0
        ratio = cumsum / ev_sum
        # Smallest m with ratio[m-1] >= var_keep ("left" so that an exact hit
        # does not pull in one extra component).
        m = int(np.searchsorted(ratio, var_keep, side="left")) + 1
        m = max(1, min(m, Vt.shape[0]))
        Z = (Xc @ Vt[:m].T).astype(np.float32, copy=False)
        n_r = Xr.shape[0]
        return Z[:n_r], Z[n_r:], m, float(ratio[m - 1])

    Z_ref, Z_out, m, explained = pca_union(E_ref, E_out, pca_var)

    # Rows per block so that one (rows, n_other) float block is ~target_block_mem_mb.
    def block_size(n_other: int, bytes_per_float: int = 4) -> int:
        max_f = max(1, int((target_block_mem_mb * 1024 * 1024) // bytes_per_float))
        b = max_f // max(1, n_other)
        return int(max(1, min(4096, b)))

    # Squared row norms, accumulated in float64 for accuracy.
    def norms2(X: np.ndarray) -> np.ndarray:
        return np.sum(X.astype(np.float64) ** 2, axis=1).astype(np.float32)

    # Squared Euclidean distances between rows of A and B via ||a||^2 + ||b||^2 - 2ab;
    # negatives from round-off are clamped to 0.
    def d2_block(
        A: np.ndarray, B: np.ndarray, nA: np.ndarray, nB: np.ndarray
    ) -> np.ndarray:
        G = A @ B.T
        D2 = nA[:, None] + nB[None, :] - 2.0 * G
        return np.maximum(D2, 0.0, out=D2)

    # Squared k-th neighbour radii within one set (self excluded), floored at eps^2.
    def kth_radii2_within(X: np.ndarray, k_val: int, eps_val: float) -> np.ndarray:
        n = X.shape[0]
        if n == 1:
            return np.full(1, eps_val**2, dtype=np.float32)
        kk = int(np.clip(k_val, 1, n - 1))
        n2 = norms2(X)
        r2 = np.empty(n, dtype=np.float32)
        bsz = block_size(n)
        for i in range(0, n, bsz):
            j = min(i + bsz, n)
            D2 = d2_block(X[i:j], X, n2[i:j], n2)
            r = np.arange(j - i)
            # Mask the self-distance so it is never picked as a neighbour.
            D2[r, i + r] = np.inf
            part = np.partition(D2, kth=kk - 1, axis=1)
            r2[i:j] = part[:, kk - 1]
            np.maximum(r2[i:j], eps_val**2, out=r2[i:j])
        return r2

    r2_ref = kth_radii2_within(Z_ref, k, eps)
    r2_out = kth_radii2_within(Z_out, k, eps)

    # Cross coverage using squared comparisons, one block of ref rows at a time.
    covered_out = np.zeros(N_out, dtype=bool)
    covered_ref = np.zeros(N_ref, dtype=bool)
    n2_ref = norms2(Z_ref)
    n2_out = norms2(Z_out)

    bsz = block_size(N_out)
    for i in range(0, N_ref, bsz):
        j = min(i + bsz, N_ref)
        D2 = d2_block(Z_ref[i:j], Z_out, n2_ref[i:j], n2_out)

        # Precision: an out point is covered if it lies in some ref ball. Once
        # every out point is covered, later blocks cannot change the result.
        if not covered_out.all():
            covered_out |= np.any(D2 <= r2_ref[i:j, None], axis=0)

        # Recall: each ref row is decided by its own block, so every block must
        # be visited (there is no valid early exit for this side).
        covered_ref[i:j] = np.any(D2 <= r2_out[None, :], axis=1)

    precision = float(covered_out.mean())
    recall = float(covered_ref.mean())

    info = {
        "precision": precision,
        "recall": recall,
        "n_components": int(m),
        "explained_variance": float(explained),
        "N_ref": int(N_ref),
        "N_out": int(N_out),
        "d_after": int(Z_ref.shape[1]),
        "impl": "NumPy, squared distances, early exit",
    }
    print(f"PR computation took {time.time() - start:.1f} seconds")
    return precision, recall, info


def PCA_(
    p,
    q,
    whiten=False,
    explained_variance=0.9,
):
    """
    Project two point sets with a single sklearn PCA fitted on their union.

    The smallest number of leading components whose cumulative explained
    variance ratio reaches ``explained_variance`` is kept; if rounding keeps
    the cumulative ratio just below the target, all components are kept.

    Args:
        p: (N_p, d) array.
        q: (N_q, d) array.
        whiten: forwarded to ``PCA(whiten=...)``.
        explained_variance: target cumulative variance ratio, strictly in (0, 1).
    Returns:
        (p_pca, q_pca): float32 projections with the same row counts as p and q.
    Raises:
        ValueError: if explained_variance is not strictly between 0 and 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 < explained_variance < 1:
        raise ValueError(
            f"explained_variance must be in (0, 1), got {explained_variance}"
        )

    data = np.vstack([q, p])
    pca = PCA(n_components=None, whiten=whiten)
    pca.fit(data)

    cumulative = np.cumsum(pca.explained_variance_ratio_)
    reached = cumulative >= explained_variance
    # argmax of an all-False mask would return 0 (one component); fall back to
    # keeping every component in that case.
    idx = int(np.argmax(reached)) if reached.any() else cumulative.size - 1

    projected = pca.transform(data)[:, : idx + 1].astype(np.float32)

    # Rows were stacked as [q; p], so split in that order.
    q_pca = projected[: len(q)]
    p_pca = projected[len(q) :]
    print(f"From dimension {p.shape} to {p_pca.shape}")
    return p_pca, q_pca


def compute_pairwise_distance(data_x, data_y=None, n_jobs=8):
    """
    Euclidean distances between every row of ``data_x`` and every row of ``data_y``.

    Args:
        data_x: numpy.ndarray([N_x, feature_dim]).
        data_y: numpy.ndarray([N_y, feature_dim]), or None to compare
            ``data_x`` with itself.
        n_jobs: parallel jobs forwarded to sklearn ``pairwise_distances``
            (default 8, the previously hard-coded value).
    Returns:
        numpy.ndarray([N_x, N_y]) of pairwise distances.
    """
    if data_y is None:
        data_y = data_x
    return pairwise_distances(data_x, data_y, metric="euclidean", n_jobs=n_jobs)


def compute_nearest_neighbour_distances(input_features, nearest_k):
    """
    Distance from each row to its ``nearest_k``-th nearest other row.

    Args:
        input_features: numpy.ndarray([N, feature_dim]).
        nearest_k: int, neighbour rank not counting the row itself.
    Returns:
        numpy.ndarray([N]) of k-th nearest-neighbour distances.
    """
    # Each row's zero distance to itself is among its smallest entries, so the
    # rank is shifted by one to skip past it.
    self_inclusive_rank = nearest_k + 1
    full_distances = compute_pairwise_distance(input_features)
    return get_kth_value(full_distances, k=self_inclusive_rank, axis=-1)


def compute_prdc2(real_features, fake_features, nearest_k):
    """
    Precision, recall, density and coverage from k-th nearest-neighbour balls.

    Each point's radius is the distance to its ``nearest_k``-th nearest
    neighbour within its own set; ``kneighbors()`` is called without a query,
    so each sample is not counted as its own neighbour.

    Args:
        real_features: (N_real, d) array.
        fake_features: (N_fake, d) array.
        nearest_k: neighbour rank used for the radii.
    Returns:
        (precision, recall, density, coverage) as Python floats.
    """
    real_nbrs = NearestNeighbors(n_neighbors=nearest_k).fit(real_features)
    fake_nbrs = NearestNeighbors(n_neighbors=nearest_k).fit(fake_features)

    # Column -1 of the sorted neighbour distances is the k-th neighbour.
    real_radii = real_nbrs.kneighbors()[0][:, -1]
    fake_radii = fake_nbrs.kneighbors()[0][:, -1]

    # (N_real, N_fake) distances between the two sets.
    distance_real_fake = euclidean_distances(real_features, fake_features)

    # inside_real[i, j]: fake j lies strictly inside the ball of real i.
    # Shared by precision and density, so it is computed once.
    inside_real = distance_real_fake < real_radii[:, None]

    # Precision: fraction of fakes inside at least one real ball.
    precision = inside_real.any(axis=0).mean()

    # Recall: fraction of reals inside at least one fake ball.
    recall = (distance_real_fake < fake_radii[None, :]).any(axis=1).mean()

    # Density: average number of real balls containing each fake, divided by k.
    density = (1.0 / float(nearest_k)) * inside_real.sum(axis=0).mean()

    # Coverage: fraction of reals whose ball contains at least one fake.
    coverage = (distance_real_fake.min(axis=1) < real_radii).mean()

    return float(precision), float(recall), float(density), float(coverage)


def compute_pr_from_embeddings(
    E_ref: np.ndarray,
    E_out: np.ndarray,
    k: int,
    pca_var: float,
    n_samples: int,
    rng_seed_ref: int = 0,
    rng_seed_out: int = 1,
) -> Dict[str, Any]:
    """
    Compute PRDC metrics (sklearn-based) on optionally subsampled, PCA-reduced embeddings.

    Both sets are subsampled without replacement to the common size
    ``min(N_ref, N_out, n_samples)``, using the same cap rule and default
    seeds as ``pr_knn_from_embeddings``. When ``pca_var < 1`` they are then
    projected with ``PCA_``, and finally scored with ``compute_prdc2``.

    Args:
        E_ref: (N_ref, d) reference embeddings.
        E_out: (N_out, d) generated embeddings.
        k: neighbour rank for the kNN radii.
        pca_var: explained variance to keep; values >= 1.0 skip PCA.
        n_samples: cap on per-set sample size (must be >= 1).
        rng_seed_ref, rng_seed_out: seeds for the independent subsampling.
    Returns:
        Dict with keys "precision", "recall", "density", "coverage".
    Raises:
        ValueError: if ``n_samples < 1``.
    """
    t1 = time.time()
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    E_ref = np.asarray(E_ref)
    E_out = np.asarray(E_out)

    # Previously `n_samples` was ignored, so large inputs produced a full
    # N_ref x N_out distance matrix. Honour the cap here.
    cap = min(len(E_ref), len(E_out), n_samples)
    if len(E_ref) > cap:
        rng = np.random.default_rng(rng_seed_ref)
        E_ref = E_ref[rng.choice(len(E_ref), size=cap, replace=False)]
    if len(E_out) > cap:
        rng = np.random.default_rng(rng_seed_out)
        E_out = E_out[rng.choice(len(E_out), size=cap, replace=False)]

    if pca_var < 1.0:
        E_ref, E_out = PCA_(
            E_ref,
            E_out,
            whiten=False,
            explained_variance=pca_var,
        )

    precision, recall, density, coverage = compute_prdc2(E_ref, E_out, nearest_k=k)
    t2 = time.time()
    print("total PRDC time:", round(t2 - t1, 2), "seconds")

    return {
        "precision": precision,
        "recall": recall,
        "density": density,
        "coverage": coverage,
    }


# Example: compare both estimators on two synthetic Gaussian clouds, where the
# generated cloud's mean is shifted by 0.1 in every dimension.
if __name__ == "__main__":
    generator = np.random.default_rng(46)
    ref_embeddings = generator.normal(size=(3000, 500)).astype(np.float32)
    out_embeddings = generator.normal(loc=0.1, scale=1.0, size=(3000, 500)).astype(
        np.float32
    )
    shared_settings = dict(k=4, pca_var=0.90, n_samples=3000)

    knn_precision, knn_recall, knn_info = pr_knn_from_embeddings(
        ref_embeddings, out_embeddings, **shared_settings
    )
    prdc_results = compute_pr_from_embeddings(
        ref_embeddings, out_embeddings, **shared_settings
    )

    print(f"Precision: {knn_precision:.4f}  Recall: {knn_recall:.4f}")
    print(
        f"PCA components kept: {knn_info['n_components']}  Explained variance: {knn_info['explained_variance']:.3f}"
    )
    print(f"Results from PRDC method: {prdc_results}")
