import numpy as np


# ========== Unified distance calculation function ==========
def pairwise_distances(Z, metric="l2", max_samples=None, return_triu=False):
    """
    Calculate a pairwise distance matrix (unified function).

    Parameters:
        Z: Sample matrix of shape (n, d), or a tuple (X, Y) of sample matrices of
            shapes (m, d) and (n, d) for cross-distances. A 1-D array of length n
            is treated as n one-dimensional samples, i.e. shape (n, 1).
        metric: "l1" (sum of absolute differences) or "l2" (Euclidean distance)
        max_samples: Optional limit on the number of samples; only the first
            max_samples rows are used (applied to both X and Y for tuple input)
        return_triu: If True, return only the strictly upper-triangular entries
            (diagonal excluded) as a 1-D array, in row-major order

    Returns:
        Distance matrix of shape (n, n), or (m, n) if Z is a tuple; if return_triu
        is True, a 1-D array of the strictly upper-triangular entries instead.

    Raises:
        ValueError: If metric is not "l1" or "l2".

    Note:
        All pairwise differences are materialised at once as an (m, n, d) array,
        so memory use grows as m * n * d.
    """
    if metric not in ("l1", "l2"):
        raise ValueError("metric must be 'l1' or 'l2'")

    def _as_samples(A):
        """Convert A to a 2-D array of samples (one per row) and apply max_samples."""
        A = np.asarray(A)
        if A.ndim == 1:
            # A flat vector is read as n scalar samples rather than rejected.
            A = A[:, None]
        if max_samples is not None:
            A = A[:max_samples]
        return A

    # Handle tuple input (X, Y) for cross-distances
    if isinstance(Z, tuple):
        X, Y = Z
        X, Y = _as_samples(X), _as_samples(Y)
    else:
        X = Y = _as_samples(Z)

    diff = X[:, None, :] - Y[None, :, :]
    if metric == "l1":
        output = np.sum(np.abs(diff), axis=-1)
    else:
        output = np.sqrt(np.sum(diff**2, axis=-1))

    if return_triu:
        # Pass the column count as well so rectangular (cross-distance) matrices
        # are indexed within bounds and no columns are silently dropped.
        rows, cols = output.shape
        return output[np.triu_indices(rows, k=1, m=cols)]
    return output


# ========== Unified kernel matrix function ==========
def kernel_matrix(pairwise_matrix, kernel, bandwidth, metric=None, rq_kernel_exponent=0.5):
    """
    Evaluate a radial kernel element-wise on pairwise distances (unified function).

    Parameters:
        pairwise_matrix: Pairwise distances, typically an (n, n) matrix; the kernel
            is applied element-wise to pairwise_matrix / bandwidth
        kernel: Kernel name: "gaussian", "laplace", "rq", "imq", or
            "matern_{0.5,1.5,2.5,3.5,4.5}_{l1,l2}"
        bandwidth: Bandwidth; distances are divided by it before the kernel is applied
        metric: "l1" or "l2", the metric the distances were computed with. If None,
            it is inferred: "l1" for "laplace" and the "matern_*_l1" kernels,
            "l2" for everything else.
        rq_kernel_exponent: Exponent alpha of the rational quadratic ("rq") kernel
            (default: 0.5)

    Returns:
        K: Kernel values, element-wise over pairwise_matrix / bandwidth

    Raises:
        ValueError: If the kernel name is unknown or not defined for the metric.
    """
    if metric is None:
        # Laplace and the "_l1" Matern variants are the kernels defined on L1 distances.
        l1_kernels = ("laplace", "matern_0.5_l1", "matern_1.5_l1", "matern_2.5_l1",
                      "matern_3.5_l1", "matern_4.5_l1")
        metric = "l1" if kernel in l1_kernels else "l2"

    scaled = pairwise_matrix / bandwidth

    alpha = rq_kernel_exponent
    # Half-integer Matern kernels, keyed by their smoothness order nu.
    matern_by_order = {
        "0.5": lambda r: np.exp(-r),
        "1.5": lambda r: (1 + np.sqrt(3) * r) * np.exp(-np.sqrt(3) * r),
        "2.5": lambda r: (1 + np.sqrt(5) * r + 5 / 3 * r**2) * np.exp(-np.sqrt(5) * r),
        "3.5": lambda r: (
            (1 + np.sqrt(7) * r + 2 * 7 / 5 * r**2 + 7 * np.sqrt(7) / 3 / 5 * r**3)
            * np.exp(-np.sqrt(7) * r)
        ),
        "4.5": lambda r: (
            (1 + 3 * r + 3 * (6**2) / 28 * r**2 + (6**3) / 84 * r**3 + (6**4) / 1680 * r**4)
            * np.exp(-3 * r)
        ),
    }
    # Maps each valid (kernel name, metric) pair to the kernel applied to scaled distances.
    formulas = {
        ("gaussian", "l2"): lambda r: np.exp(-(r**2) / 2),
        ("laplace", "l1"): lambda r: np.exp(-r * np.sqrt(2)),
        ("rq", "l2"): lambda r: (1 + r**2 / (2 * alpha)) ** (-alpha),
        ("imq", "l2"): lambda r: (1 + r**2) ** (-0.5),
    }
    for order, formula in matern_by_order.items():
        for suffix in ("l1", "l2"):
            # A Matern kernel is only accepted with the metric named in its suffix.
            formulas[(f"matern_{order}_{suffix}", suffix)] = formula

    try:
        chosen = formulas.get((kernel, metric))
    except TypeError:
        # Unhashable kernel/metric values cannot name any known kernel.
        chosen = None
    if chosen is None:
        raise ValueError(f'Invalid combination of metric="{metric}" and kernel="{kernel}"')
    return chosen(scaled)


# ========== Unified bandwidth calculation functions ==========
def compute_bandwidths_from_distances(distances, number_bandwidths):
    """
    Calculate an evenly spaced bandwidth grid from distance values (unified function).

    Zero distances are first replaced by the median distance. The grid then runs
    from half of the 5% order statistic to twice the 95% order statistic of the
    flattened distances.

    Parameters:
        distances: Array-like of distance values (any shape; it is flattened)
        number_bandwidths: Number of bandwidths in the grid

    Returns:
        bandwidths: 1-D NumPy array, np.linspace(lambda_min, lambda_max, number_bandwidths)

    Raises:
        ValueError: If distances is empty.
    """
    distances = np.asarray(distances, dtype=float)
    if distances.size == 0:
        # Otherwise np.median warns and the order-statistic lookup fails with an
        # opaque IndexError.
        raise ValueError("distances must contain at least one value")
    median = np.median(distances)
    # Zero distances (e.g. self-distances on the diagonal of a full matrix) could
    # pull the lower end of the grid down to 0, so they are replaced by the median.
    flat = (distances + (distances == 0) * median).ravel()
    n = flat.size
    lo = int(np.floor(n * 0.05))
    hi = int(np.floor(n * 0.95))
    # Only two order statistics are needed, so a partial sort suffices; the
    # values at positions lo and hi equal those of a full sort.
    dd = np.partition(flat, (lo, hi))
    lambda_min = dd[lo] / 2
    lambda_max = dd[hi] * 2
    return np.linspace(lambda_min, lambda_max, number_bandwidths)


def get_median_bandwidth(Z, metric="l2"):
    """
    Return the median pairwise distance of Z, for use as a default bandwidth.

    Only the strictly upper-triangular distances are used (return_triu=True), so
    the zero self-distances on the diagonal and the mirrored duplicates below it
    do not enter the median.

    Parameters:
        Z: Sample matrix, shape (n, d)
        metric: "l1" or "l2", forwarded to pairwise_distances

    Returns:
        median_bandwidth: Median of the strictly upper-triangular pairwise distances
    """
    return np.median(pairwise_distances(Z, metric=metric, return_triu=True))

