import torch

def ensure_2d(tensor):
    """Return *tensor* as a 2-D tensor of shape (n, d).

    Non-tensor inputs (lists, NumPy arrays, ...) are converted with
    ``torch.tensor(..., dtype=torch.float32)``; tensor inputs keep their dtype.
    A 1-D input of shape (n,) becomes (n, 1); a 2-D input is returned as-is.

    Raises:
        ValueError: if the input is 0-D (a scalar) or has more than 2 dims.
    """
    if not torch.is_tensor(tensor):
        tensor = torch.tensor(tensor, dtype=torch.float32)
    # A 0-D scalar has no sample axis; reject it here (same message as >2-D)
    # instead of letting it fail later inside torch.cdist.
    if tensor.ndim == 0 or tensor.ndim > 2:
        raise ValueError(f"Expected a 1D or 2D tensor, but got shape {tensor.shape}")
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(1)  # (n,) -> (n, 1)
    return tensor

def _weighted_double_center(D, W, w_row, w_col, W_sum, eps):
    """Double-centre matrix D with W-weighted row, column and grand means.

    Every denominator is clamped below at ``eps`` so rows/columns whose
    kernel weight vanishes do not divide by zero.
    """
    WD = W * D  # computed once and reused for all three weighted means
    row_mean = WD.sum(dim=1) / w_row.clamp(min=eps)
    col_mean = WD.sum(dim=0) / w_col.clamp(min=eps)
    grand_mean = WD.sum() / W_sum.clamp(min=eps)
    return D - row_mean.unsqueeze(1) - col_mean.unsqueeze(0) + grand_mean


def global_conditional_distance_correlation(
    X: torch.Tensor,
    Y: torch.Tensor,
    Z: torch.Tensor,
    h: float = 0.1,
    eps: float = 1e-12
) -> torch.Tensor:
    """One-shot O(n^2) global conditional distance correlation.

    Every pair (i, j) is weighted by a kernel on Z. The pairwise distance
    matrices of X and Y are double-centred using kernel-weighted
    row/column/grand means, and the weighted correlation of the two centred
    matrices is computed.

    Args:
        X, Y, Z: tensors or array-likes with the same number of rows n;
            1-D inputs are treated as shape (n, 1). Converted to float32.
        h: kernel scale, must be > 0. The kernel is
            exp(-||z_i - z_j||^2 / h). Note that this differs from the
            exp(-||z_i - z_j||^2 / (2 h^2)) used by
            global_conditional_distance_correlation_true, so the same ``h``
            does not mean the same bandwidth in both functions.
        eps: lower bound applied to every denominator.

    Returns:
        0-d tensor in [0, 1]: the square of num / sqrt(varX * varY), where
        num, varX and varY are the kernel-weighted (co)variances of the
        centred distance matrices.

    Raises:
        ValueError: if ``h <= 0`` or X, Y, Z have different numbers of rows.
    """
    if h <= 0:
        raise ValueError(f"Bandwidth h must be positive, got {h}")

    X = ensure_2d(X).float()
    Y = ensure_2d(Y).float()
    Z = ensure_2d(Z).float()
    if not (X.shape[0] == Y.shape[0] == Z.shape[0]):
        raise ValueError(
            "X, Y and Z must have the same number of rows, got "
            f"{X.shape[0]}, {Y.shape[0]} and {Z.shape[0]}"
        )

    DX = torch.cdist(X, X, p=2)
    DY = torch.cdist(Y, Y, p=2)

    D2Z = torch.cdist(Z, Z, p=2) ** 2
    W = torch.exp(-D2Z / h)  # pair weights from the conditioning variable

    W_sum = W.sum()
    w_row = W.sum(dim=1)
    w_col = W.sum(dim=0)

    A = _weighted_double_center(DX, W, w_row, w_col, W_sum, eps)
    B = _weighted_double_center(DY, W, w_row, w_col, W_sum, eps)

    total = W_sum.clamp(min=eps)
    num = (W * A * B).sum() / total
    varX = (W * A * A).sum() / total
    varY = (W * B * B).sum() / total

    # W >= 0, so by Cauchy-Schwarz |num| <= sqrt(varX * varY) and rho**2 <= 1.
    rho = num / torch.sqrt((varX * varY + 1e-12).clamp(min=eps))

    return rho**2


def _sample_reference_indices(n, c, Kz, subsample_method, device):
    """Draw ``c`` distinct indices from range(n) using ``subsample_method``."""
    if subsample_method == 'uniform':
        return torch.randperm(n, device=device)[:c]
    if subsample_method == 'importance':
        # Each row of Kz contains its diagonal term exp(-0) ~= 1, so every
        # row sum is positive and every point can be drawn.
        sums = Kz.sum(dim=1)
        probs = sums / sums.sum()
        return torch.multinomial(probs, num_samples=c, replacement=False)
    raise ValueError(f"Unknown subsample_method: {subsample_method}")


def _local_weighted_double_center(D, W_ref):
    """Double-centre D once per row of W_ref; returns shape (c, n, n).

    For reference r, entry [r, j, k] is
    D[j, k] - rowmean_r(j) - colmean_r(k) + grandmean_r, with all means
    weighted by W_ref[r].
    """
    c = W_ref.shape[0]
    col_means = W_ref @ D    # (c, n): [r, k] = sum_j W_ref[r, j] * D[j, k]
    row_means = W_ref @ D.T  # (c, n): [r, j] = sum_k W_ref[r, k] * D[j, k]
    grand = (W_ref * col_means).sum(dim=1)  # (c,)
    return (
        D.unsqueeze(0)
        - row_means.unsqueeze(2)
        - col_means.unsqueeze(1)
        + grand.view(c, 1, 1)
    )


def global_conditional_distance_correlation_true(
    X, Y, Z, h=1.0, eps=1e-12, ref_indices=None, subsample_method='uniform', c=None, method="max",
):
    """
    Compute a global conditional distance correlation by aggregating local
    correlations computed around a set of reference points.

    For each reference r, the pair weights are the normalised kernel row
    W_ref[r] = Kz[r] / (sum(Kz[r]) + eps), with
    Kz[i, j] = exp(-||z_i - z_j||^2 / (2 h^2)). The distance matrices of X
    and Y are double-centred under those weights, and the local value is
    num_r / sqrt(varX_r * varY_r) (not squared).

    Arguments:
    - X, Y, Z: tensors or array-likes with the same number of rows n;
               1-D inputs are treated as shape (n, 1). Converted to float32.
    - h: bandwidth of the RBF kernel on Z.
    - eps: small constant to avoid zero denominators.
    - ref_indices: optional sequence/tensor of integer reference indices in
                   [0, n) (a scalar is treated as one index). When given,
                   ``c`` and ``subsample_method`` are ignored.
    - subsample_method: how to draw references when ref_indices is None:
        - 'uniform': uniform random sample of c distinct indices
        - 'importance': sample without replacement with probability
          proportional to the kernel row sum of each point
    - c: number of references to draw when ref_indices is None.
         If None, defaults to n // 5 (or n when n < 5).
         Values >= n use every point.
    - method: 'max' (alias 'standard') returns the largest local value;
              'mean' returns their average.

    Returns:
    - rho_global: 0-d tensor, the aggregated local correlation.

    Raises:
    - ValueError: mismatched row counts, no reference points, unknown
                  subsample_method, or unknown method.

    Memory note: the centred matrices are materialised with shape (c, n, n).
    """
    X = ensure_2d(X).float()
    Y = ensure_2d(Y).float()
    Z = ensure_2d(Z).float()
    n = X.shape[0]
    if Y.shape[0] != n or Z.shape[0] != n:
        raise ValueError(
            "X, Y and Z must have the same number of rows, got "
            f"{n}, {Y.shape[0]} and {Z.shape[0]}"
        )
    device = X.device

    DX = torch.cdist(X, X, p=2)
    DY = torch.cdist(Y, Y, p=2)

    D2_Z = torch.cdist(Z, Z, p=2)**2
    Kz = torch.exp(-D2_Z / (2.0 * h * h))

    # Determine reference indices
    if ref_indices is None:
        if c is None:
            # Default: ~20% of the points (all of them when n < 5), which
            # keeps the (c, n, n) intermediates manageable.
            c = n // 5 if n >= 5 else n
        c = min(c, n)  # asking for >= n references means "use every point"
        if c < 1:
            raise ValueError(f"Need at least one reference point, got c={c}")
        ref_indices = _sample_reference_indices(n, c, Kz, subsample_method, device)
    else:
        ref_indices = torch.as_tensor(ref_indices, device=device, dtype=torch.long).reshape(-1)
        c = ref_indices.shape[0]
        if c < 1:
            raise ValueError("ref_indices must contain at least one index")

    K_ref = Kz[ref_indices, :]
    W_ref = K_ref / (K_ref.sum(dim=1, keepdim=True) + eps)  # (c, n), rows sum to ~1

    A = _local_weighted_double_center(DX, W_ref)
    B = _local_weighted_double_center(DY, W_ref)

    # W_outer[r, j, k] = W_ref[r, j] * W_ref[r, k]: pair weights per reference.
    W_outer = W_ref.unsqueeze(2) * W_ref.unsqueeze(1)

    num = (W_outer * A * B).sum(dim=(1, 2))
    varX = (W_outer * A * A).sum(dim=(1, 2))
    varY = (W_outer * B * B).sum(dim=(1, 2))

    # Local distance correlations:
    denom = torch.sqrt(varX.clamp(min=eps) * varY.clamp(min=eps) + 1e-12)
    rho_locals = num / denom

    # Global aggregation
    if method in ["max", "standard"]:
        rho_global = rho_locals.max()
    elif method == "mean":
        rho_global = rho_locals.mean()
    else:
        raise ValueError(f"Unknown method: {method}. Use 'max' or 'mean'.")

    return rho_global

def classic_dcor_center(A):
    """Double-centre matrix A with unweighted means.

    Returns A[i, j] - mean_row(i) - mean_col(j) + mean_all, so that every
    row and every column of the result sums to (numerically) zero.
    """
    grand_mean = A.mean()
    means_per_row = A.mean(dim=1, keepdim=True)  # shape (n, 1)
    means_per_col = A.mean(dim=0, keepdim=True)  # shape (1, n)
    centred = A - means_per_row
    centred = centred - means_per_col
    return centred + grand_mean

def classical_distance_correlation(X, Y, eps=1e-12):
    """Unconditional sample distance correlation between X and Y.

    Computes dCov^2(X, Y) / sqrt(dVar^2(X) * dVar^2(Y)) from the
    double-centred Euclidean distance matrices (V-statistic form, i.e. with
    1/n^2 normalisation). The returned value is therefore the *squared*
    sample distance correlation, not its square root.

    Args:
        X, Y: tensors or array-likes with the same number of rows n; 1-D
            inputs are treated as (n, 1). Integer tensors are promoted to
            float32; floating-point tensors keep their precision.
        eps: lower bound on the denominator.

    Returns:
        0-d tensor.

    Raises:
        ValueError: if X and Y have different numbers of rows.
    """
    X = ensure_2d(X)
    Y = ensure_2d(Y)
    # torch.cdist only accepts floating-point inputs; promote integer
    # tensors without down-casting float64 ones.
    if not X.is_floating_point():
        X = X.float()
    if not Y.is_floating_point():
        Y = Y.float()

    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError(
            f"X and Y must have the same number of rows, got {n} and {Y.shape[0]}"
        )

    AX = classic_dcor_center(torch.cdist(X, X, p=2))
    AY = classic_dcor_center(torch.cdist(Y, Y, p=2))

    numerator = torch.sum(AX * AY) / n**2   # dCov^2
    denom_X = torch.sum(AX * AX) / n**2     # dVar^2(X)
    denom_Y = torch.sum(AY * AY) / n**2     # dVar^2(Y)

    return numerator / torch.sqrt(denom_X * denom_Y + 1e-12).clamp(min=eps)


# Smoke test: run both conditional estimators on three independently drawn
# Gaussian samples and print the resulting scores.
if __name__ == "__main__":
    n_samples, n_features = 1000, 3
    X, Y, Z = (torch.randn(n_samples, n_features) for _ in range(3))

    one_shot_score = global_conditional_distance_correlation(X, Y, Z, h=0.5)
    print(one_shot_score)

    reference_score = global_conditional_distance_correlation_true(
        X, Y, Z, h=0.5, c=10, subsample_method='uniform'
    )
    print(reference_score)
