import numpy as np

import torch
from torch_geometric.data import Data


# def gaussian_kernel(distance, sigma=1.0):
#     return np.exp(-distance**2 / (2 * sigma**2))


def gaussian_kernel_vectorized(distances, bandwidth):
    """
    Unnormalized Gaussian (RBF) kernel applied element-wise to distances.
    - distances: Array of distances, e.g. pairwise (query points - data points) with
      shape (n_queries, n_samples); any array shape is accepted element-wise
    - bandwidth: Kernel width (standard deviation) controlling the decay rate
    Returns exp(-distances**2 / (2 * bandwidth**2)) with the same shape as `distances`.
    The result equals 1 at distance 0 and is NOT normalized to a probability density.
    """
    denominator = 2 * bandwidth ** 2
    squared = distances ** 2
    return np.exp(-squared / denominator)


def _knn_edges(d, k, keep_edge_weight):
    """
    Directed k-nearest-neighbour edges (node -> neighbour) from a square distance matrix.
    - d: Pairwise distance matrix, shape (n_nodes, n_nodes)
    - k: Neighbours per node; clamped to the range [0, n_nodes - 1]
    - keep_edge_weight: Use the raw distances as edge weights; otherwise every weight is 1
    Returns (edge_index, edge_weight): a (2, E) long tensor ordered by source node, then
    by ascending neighbour id, and a float32 tensor of length E.
    """
    n_nodes = d.shape[0]
    k = max(0, min(int(k), n_nodes - 1))  # a node has at most n_nodes - 1 other nodes

    if k == 0:
        row = col = np.empty(0, dtype=np.int64)
    else:
        order = np.argsort(d, axis=1, kind='stable')
        # Remove each node's own index explicitly instead of assuming it sorts first in
        # its row: with duplicate points (zero off-diagonal distances) that assumption
        # yields a self-loop and silently drops a genuine neighbour.
        not_self = order != np.arange(n_nodes)[:, None]
        nearest = order[not_self].reshape(n_nodes, n_nodes - 1)[:, :k]
        # Sort neighbour ids per row so edges are emitted row-major, column-ascending.
        nearest = np.sort(nearest, axis=1)
        row = np.repeat(np.arange(n_nodes), k)
        col = nearest.ravel()

    # Weights are taken straight from `d` (never via an integer buffer), so fractional
    # distances are preserved and zero-distance neighbours are kept.
    weights = d[row, col] if keep_edge_weight else np.ones(row.shape[0])
    edge_index = torch.as_tensor(np.stack([row, col]), dtype=torch.long)
    edge_weight = torch.as_tensor(weights, dtype=torch.float32)
    return edge_index, edge_weight


def _gaussian_edges(d, sigma, keep_edge_weight):
    """
    Edges between node pairs whose Gaussian-kernel similarity is at least the median
    similarity of the upper triangle (diagonal excluded).
    - d: Pairwise distance matrix, shape (n_nodes, n_nodes)
    - sigma: Multiplier on the median distance, giving the kernel bandwidth
    - keep_edge_weight: Return kernel similarities as edge weights; otherwise None
    Returns (edge_index, edge_weight): a (2, E) long tensor and a float32 tensor or None.
    """
    # The bandwidth is the median of the full matrix, diagonal zeros included.
    # NOTE(review): a median of 0 (mostly-zero matrix) gives a zero bandwidth.
    bandwidth = np.median(d)
    similarity = gaussian_kernel_vectorized(d, sigma * bandwidth)

    upper = similarity[np.triu_indices_from(similarity, k=1)]
    if upper.size == 0:
        # Fewer than two nodes: no off-diagonal pairs, hence no edges.
        src = dst = np.empty(0, dtype=np.int64)
    else:
        threshold = np.median(upper)
        src, dst = np.nonzero(similarity >= threshold)
        off_diagonal = src != dst  # no self-loops
        src, dst = src[off_diagonal], dst[off_diagonal]

    edge_index = torch.as_tensor(np.stack([src, dst]), dtype=torch.long)
    if keep_edge_weight:
        edge_weight = torch.as_tensor(similarity[src, dst], dtype=torch.float32)
    else:
        edge_weight = None
    return edge_index, edge_weight


def create_adj_from_dist(d, method='knn', k=None, sigma=None, keep_edge_weight=True):
    """
    Build a PyG graph from a square pairwise distance matrix.
    - d: Pairwise distance matrix, shape (n_nodes, n_nodes); d[i, j] is the distance
      between node i and node j (array-like, converted with np.asarray)
    - method: 'knn' links every node to its k nearest other nodes (directed edges).
      'gaussian' links every off-diagonal pair whose Gaussian similarity is at least
      the median upper-triangle similarity.
    - k: Neighbours per node; required for 'knn'; clamped to n_nodes - 1
    - sigma: Bandwidth multiplier applied to the median distance; required for 'gaussian'
    - keep_edge_weight: 'knn' stores the distances as edge_weight (otherwise all ones).
      'gaussian' stores the kernel similarities (otherwise edge_weight is None).
    Returns a torch_geometric Data with edge_index (2, E, long), num_nodes and edge_weight.
    Raises ValueError for an unknown method, a non-square `d`, or a missing k / sigma.
    """
    if method not in ('knn', 'gaussian'):
        raise ValueError(f"method must be 'knn' or 'gaussian', got {method!r}")

    d = np.asarray(d)
    if d.ndim != 2 or d.shape[0] != d.shape[1]:
        raise ValueError(f"d must be a square 2-D distance matrix, got shape {d.shape}")

    if method == 'knn':
        if k is None or k < 0:
            raise ValueError("method='knn' requires a non-negative integer k")
        edge_index, edge_weight = _knn_edges(d, k, keep_edge_weight)
    else:
        if sigma is None:
            raise ValueError("method='gaussian' requires sigma")
        edge_index, edge_weight = _gaussian_edges(d, sigma, keep_edge_weight)

    return Data(edge_index=edge_index, num_nodes=d.shape[0], edge_weight=edge_weight)


