import torch
import torch.nn.functional as F
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.stats import t, skew, kurtosis
from typing import List, Dict, Callable, Any

from tqdm import tqdm

from clustering.one_dimensional_clustering import OneDimensionalClustering


def summarize_values(values: List[float]) -> Dict[str, float]:
    """
    Compute summary statistics for a list of floats.
    Returns a dict with descriptive metrics.
    """
    arr = np.array(values)
    q25, q50, q75 = np.percentile(arr, [25, 50, 75])
    return {
        'count': float(arr.size),
        'min': float(np.min(arr)),
        '25%': float(q25),
        'median': float(q50),
        '75%': float(q75),
        'max': float(np.max(arr)),
        'mean': float(np.mean(arr)),
        'std': float(np.std(arr, ddof=1)),
        'variance': float(np.var(arr, ddof=1)),
        'IQR': float(q75 - q25),
        'skewness': float(skew(arr)),
        'kurtosis': float(kurtosis(arr, fisher=False))
    }


def test_partial_correlation(
    X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor
) -> float:
    """Gaussian partial-correlation p-values summary."""
    n, d = X.shape
    Z1 = torch.cat([Z, torch.ones(n, 1).to(Z.device)], dim=1)
    coef_y = torch.linalg.lstsq(Z1, Y.view(-1, 1), ).solution
    resid_y = Y.squeeze() - Z1 @ coef_y[:Z1.shape[1]].squeeze()

    coef_x = torch.linalg.lstsq(Z1, X.view(-1, 1), ).solution
    resid_x = X.squeeze() - Z1.squeeze() @ coef_x[:Z1.shape[1]].squeeze()

    rx_centered = resid_x - resid_x.mean()
    ry_centered = resid_y - resid_y.mean()

    corr = (rx_centered.T @ ry_centered) / (rx_centered.norm() * ry_centered.norm())

    return corr.item()



def test_cmi_knn(
    X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor, k: int = 5
) -> float:
    """KSG k-NN estimate of conditional mutual information."""
    from scipy.special import psi
    data_xyz = torch.cat([X, Y.unsqueeze(1), Z], dim=1).detach().cpu().numpy()
    data_xz = torch.cat([X, Z], dim=1).detach().cpu().numpy()
    data_yz = torch.cat([Y.unsqueeze(1), Z], dim=1).detach().cpu().numpy()
    data_z = Z.detach().cpu().numpy()
    nbrs = NearestNeighbors(n_neighbors=k+1)
    nbrs.fit(data_xyz)
    dist, _ = nbrs.kneighbors(data_xyz)
    eps = dist[:, k] - 1e-10

    def count(data):
        nbrs = NearestNeighbors().fit(data)
        counts = []
        for i, radius in enumerate(eps):
            neighbors = nbrs.radius_neighbors(data[i:i + 1], radius=radius, return_distance=False)[0]
            counts.append(len(neighbors) - 1)  # exclude the point itself
        return np.array(counts)

    nxz, nyz = count(data_xz), count(data_yz)#, count(data_z)
    return float(psi(k) + np.log(len(Z)) - np.mean(psi(nxz + 1) + psi(nyz + 1)))


def rbf_kernel(A: torch.Tensor, B: torch.Tensor, gamma: float = None) -> torch.Tensor:
    """Gaussian (RBF) kernel matrix with entries exp(-gamma * ||a_i - b_j||^2).

    A has shape (n, d) and B has shape (m, d); the result has shape (n, m).
    When gamma is None, it defaults to 1 / A.shape[1].
    """
    # Broadcast to (n, m, d) pairwise differences, then reduce the feature axis.
    diffs = A[:, None, :] - B[None, :, :]
    squared_dist = (diffs ** 2).sum(dim=2)
    if gamma is None:
        gamma = 1.0 / A.shape[1]
    return (-gamma * squared_dist).exp()


def test_kci(
    X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor, num_perm: int = 500,
    epsilon: float = 1e-3,
) -> float:
    """Kernel conditional-independence p-value: residualized HSIC with permutations.

    RBF kernels Kx, Ky and Kz are built on all n samples and centred with
    H = I - 11^T / n. The centred Kx and Ky are then residualized on Z with
    the regularized operator

        R_z = epsilon * (Kz_c + epsilon * I)^{-1}

    The statistic is

        T = trace(Kx_r @ Ky_r) / n,   Kx_r = R_z Kx_c R_z,   Ky_r = R_z Ky_c R_z.

    The null is approximated by permuting the sample order of Ky_r (rows and
    columns together) and recomputing T with exactly the same formula. The
    observed and permuted statistics are therefore the same quantity on the
    same n samples. This permutation null is a heuristic, not the exact KCI
    null distribution.

    Args:
        X: shape (n, d_x).
        Y: 1-D tensor of length n (unsqueezed to a single column).
        Z: shape (n, d_z).
        num_perm: number of permutations.
        epsilon: ridge term of R_z. Eigen-directions of Kz_c with
            eigenvalue lambda are scaled by epsilon / (lambda + epsilon), so a
            larger epsilon removes less of Z's influence.

    Returns:
        Permutation p-value (count + 1) / (num_perm + 1), where count is the
        number of permuted statistics >= the observed one.

    Cost: three n x n kernels and one n x n inverse, plus O(n^2) work per
    permutation.
    """
    n = X.shape[0]
    Kz = rbf_kernel(Z, Z)
    # Align dtype/device with Kz so the matrix products below are valid.
    Kx = rbf_kernel(X, X).to(Kz)
    Ky = rbf_kernel(Y.unsqueeze(1), Y.unsqueeze(1)).to(Kz)

    eye = torch.eye(n, dtype=Kz.dtype, device=Kz.device)
    H = eye - 1.0 / n  # centring matrix I - 11^T / n
    Kx_c = H @ Kx @ H
    Ky_c = H @ Ky @ H
    Kz_c = H @ Kz @ H

    # Kz_c is PSD, so Kz_c + epsilon * I is invertible for epsilon > 0.
    Rz = epsilon * torch.linalg.inv(Kz_c + epsilon * eye)
    Kx_r = Rz @ Kx_c @ Rz
    Ky_r = Rz @ Ky_c @ Rz

    # trace(A @ B) == sum(A * B.T): O(n^2) instead of an O(n^3) matmul.
    stat = (Kx_r * Ky_r.T).sum() / n
    count = 0
    for _ in tqdm(range(num_perm)):
        perm = torch.randperm(n, device=Ky_r.device)
        Ky_p = Ky_r[perm][:, perm]  # P Ky_r P^T
        perm_stat = (Kx_r * Ky_p.T).sum() / n
        if perm_stat >= stat:
            count += 1
    return float((count + 1) / (num_perm + 1))


def test_crt(
    X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor,
    sample_x_given_z: Callable[[torch.Tensor], torch.Tensor],
    test_stat: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], float],
    B: int = 100
) -> float:
    """Conditional Randomization Test p-value."""
    t0 = test_stat(X, Y, Z)
    count = sum(1 for _ in range(B) if test_stat(sample_x_given_z(Z), Y, Z) >= t0)
    return float((count + 1) / (B + 1))


def clustering_correlation(X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor):
    """Pearson correlation of X and Y computed separately within clusters of Z.

    Z is clustered with OneDimensionalClustering (fit, then predict_cluster).
    The labels are treated as integers 0..max(label). For each label, the
    Pearson correlation (np.corrcoef) of the member rows of X and Y is
    computed. Clusters whose correlation is nan are skipped; np.corrcoef
    gives nan e.g. for clusters with fewer than two members or constant X/Y.

    Returns:
        A tuple (strongest, mean_abs):
        - strongest: the signed correlation with the largest magnitude.
        - mean_abs: the mean absolute correlation over the kept clusters.
        Both are nan when no cluster yields a defined correlation. Previously
        this case raised ValueError from max() on an empty list.
    """
    clustering = OneDimensionalClustering()
    clustering.fit(Z)
    clusters = clustering.predict_cluster(Z)
    n_clusters = torch.max(clusters).int().item() + 1

    # Copy X and Y to host memory once rather than once per cluster.
    x_np = X.detach().cpu().numpy()
    y_np = Y.detach().cpu().numpy()

    correlations = []
    for c in range(n_clusters):
        in_bin = (clusters == c).squeeze().detach().cpu().numpy()
        r = np.corrcoef(x_np[in_bin].squeeze(), y_np[in_bin].squeeze())[0, 1]
        if not np.isnan(r):
            correlations.append(r.item())

    if not correlations:
        return float('nan'), float('nan')
    return max(correlations, key=abs), np.mean(np.abs(correlations)).item()

def run_all_tests(
    X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor,
    sample_x_given_z: Callable[[torch.Tensor], torch.Tensor] = None,
    test_stat: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], float] = None
) -> Dict[str, Any]:
    """
    Execute all CI tests and return a dict mapping metric names to values.

    Keys, in insertion order:
        'partial_correlation'         -- test_partial_correlation(X, Y, Z)
        'clustering_correlation'      -- first value of clustering_correlation(X, Y, Z)
        'mean_clustering_correlation' -- second value of clustering_correlation(X, Y, Z)
        'cmi_knn'                     -- test_cmi_knn(X, Y, Z)
        'kci_pval'                    -- test_kci(X, Y, Z)
        'crt_pval'                    -- test_crt(...); present only when both
                                         sample_x_given_z and test_stat are given
    """
    results: Dict[str, Any] = {}
    pc = test_partial_correlation(X, Y, Z)
    clustering_corr, mean_corr = clustering_correlation(X, Y, Z)
    results['partial_correlation'] = pc
    results['clustering_correlation'] = clustering_corr
    results['mean_clustering_correlation'] = mean_corr
    results['cmi_knn'] = test_cmi_knn(X, Y, Z)
    results['kci_pval'] = test_kci(X, Y, Z)
    # The CRT needs a conditional sampler for X and a statistic; skip otherwise.
    if sample_x_given_z is not None and test_stat is not None:
        results['crt_pval'] = test_crt(X, Y, Z, sample_x_given_z, test_stat)
    return results
