import torch
import dgl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import networkx as nx


def gaussian_kernel_graph_cut(pairwise_dist, k, sigma):
    """Sparsify a distance matrix into a DGL graph with Gaussian edge weights.

    The kernel bandwidth is ``median(pairwise_dist) * sigma`` (median over all
    N x N entries, diagonal included). Every kept edge (i, j) carries the weight
    ``exp(-d_ij^2 / (2 * bandwidth^2))``. Weights are NOT normalised per node.

    Args:
        pairwise_dist: (N, N) tensor of pairwise distances. The k-NN branch
            assumes a floating-point dtype (the diagonal is masked with +inf).
        k: If > 0, connect every node i to its k nearest *other* nodes
            (directed edges i -> neighbour, N * k edges in total). If <= 0,
            keep every off-diagonal pair (i, j) whose weight is >= the median
            of the strictly upper-triangular weights.
        sigma: Multiplier applied to the median distance to form the bandwidth.

    Returns:
        A DGL graph with N nodes whose ``edata['weight']`` holds the float32
        Gaussian weight of each edge.

    Raises:
        ValueError: if ``k > N - 1`` (not enough other nodes to choose from).
    """
    N = pairwise_dist.shape[0]
    device = pairwise_dist.device

    bandwidth = torch.median(pairwise_dist) * sigma  # median distance scaled by sigma
    gaussian_weights = torch.exp(-pairwise_dist ** 2 / (2 * bandwidth ** 2))

    if k > 0:
        if k > N - 1:
            raise ValueError(f"k={k} must be at most N - 1 = {N - 1}")
        # Mask the diagonal instead of discarding topk's first column: with
        # duplicate points several entries tie at distance 0, so column 0 is not
        # guaranteed to be the node itself (that kept a self-loop and dropped a
        # genuine neighbour).
        self_mask = torch.eye(N, dtype=torch.bool, device=device)
        masked_dist = pairwise_dist.masked_fill(self_mask, float('inf'))
        _, knn_indices = torch.topk(masked_dist, k=k, dim=1, largest=False)
        # Create src on the same device as knn_indices so the indexing below
        # also works for CUDA inputs.
        src = torch.arange(N, device=device).repeat_interleave(k)
        dst = knn_indices.reshape(-1)
    else:
        # Threshold at the median weight over distinct pairs (i < j).
        # torch.triu_indices stays on-device and avoids the NumPy round-trip,
        # which fails for CUDA tensors and tensors that require grad.
        rows, cols = torch.triu_indices(N, N, offset=1, device=device)
        median_threshold = torch.median(gaussian_weights[rows, cols])
        src, dst = torch.where(gaussian_weights >= median_threshold)
        off_diagonal = src != dst  # drop self-loops
        src, dst = src[off_diagonal], dst[off_diagonal]

    edge_weights = gaussian_weights[src, dst]
    g = dgl.graph((src, dst), num_nodes=N)
    g.edata['weight'] = edge_weights.to(torch.float32)  # raw (un-normalised) Gaussian weights
    return g


def get_complet_graph(pairwise_dist, sigma=1.0, method='gaussian'):
    """Build a dense (fully connected) weighted edge list from a distance matrix.

    Every entry (i, j) of the matrix, diagonal included, becomes an edge unless
    its weight evaluates to NaN. Infinite weights are kept.

    Args:
        pairwise_dist: Square tensor of pairwise distances; cast to float32.
        sigma: Multiplier on the median distance used as the Gaussian
            bandwidth. Only used by ``method='gaussian'``.
        method: Similarity kernel, one of
            ``'gaussian'``  -- exp(-d^2 / (2 * (median(d) * sigma)^2));
            ``'student_t'`` -- 1 / (1 + d^2);
            ``'linear'``    -- 1 / d (inverse distance), diagonal set to 1;
            ``'softmax'``   -- row-wise softmax of -d over off-diagonal entries
                               plus 1e-9, diagonal set to 1.

    Returns:
        ``(edge_indx, edge_weights)``: an int64 tensor of shape (2, E) holding
        (src, dst) rows, and a float32 tensor of shape (E,) with the matching
        weights, in row-major (src, dst) order.

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.
    """
    pairwise_dist = pairwise_dist.to(torch.float32)
    # Identity mask, used to overwrite the diagonal for kernels that are
    # undefined or degenerate at distance 0.
    diagonal = torch.eye(pairwise_dist.shape[0], dtype=torch.bool, device=pairwise_dist.device)

    if method == 'gaussian':
        bandwidth = torch.median(pairwise_dist) * sigma
        # If the bandwidth is 0, zero-distance entries become 0/0 = NaN and are
        # dropped by the NaN filter below.
        weights_a = torch.exp(-pairwise_dist ** 2 / (2 * bandwidth ** 2))
    elif method == 'student_t':
        weights_a = (1 + pairwise_dist ** 2) ** (-1)
    elif method == 'linear':
        # Inverse distance (not linear in d despite the name). The diagonal 1/0 = inf
        # is pinned to 1; off-diagonal zero distances (duplicate points) stay inf.
        weights_a = (1 / pairwise_dist).masked_fill(diagonal, 1.0)
    elif method == 'softmax':
        # Exclude self from each row's softmax, then pin the diagonal to 1.
        logits = (-pairwise_dist).masked_fill(diagonal, float('-inf'))
        weights_a = torch.nn.functional.softmax(logits, dim=1) + 1e-9
        weights_a = weights_a.masked_fill(diagonal, 1.0)
    else:
        raise NotImplementedError(
            f"Unknown method {method!r}; expected 'gaussian', 'student_t', 'linear' or 'softmax'."
        )

    src, dst = torch.where(~torch.isnan(weights_a))
    edge_weights = weights_a[src, dst].to(torch.float32)
    edge_indx = torch.stack((src, dst)).long()
    return edge_indx, edge_weights


if __name__ == "__main__":
    # Demo: build a k-NN Gaussian graph over random 2-D points and plot it.
    N = 100  # Number of nodes
    D = 2  # Feature dimension

    # Random node features and their pairwise Euclidean distances.
    node_features = torch.rand(N, D)
    pairwise_dist = torch.cdist(node_features, node_features, p=2)

    g = gaussian_kernel_graph_cut(pairwise_dist, k=20, sigma=1)
    edge_weights = g.edata['weight']

    # Build the undirected NetworkX graph directly so each drawn edge keeps its
    # own weight. g.to_networkx().to_undirected() merges reciprocal k-NN edges
    # and reorders them, so a colour list in DGL edge order would not line up
    # with the edges NetworkX actually draws.
    src, dst = g.edges()
    nx_graph = nx.Graph()
    nx_graph.add_nodes_from(range(g.num_nodes()))
    nx_graph.add_weighted_edges_from(zip(src.tolist(), dst.tolist(), edge_weights.tolist()))

    # Min-max scale the weights (in NetworkX edge order) to [0, 1] for the
    # colormap, guarding against the all-equal case.
    draw_weights = np.array([w for _, _, w in nx_graph.edges(data='weight')])
    w_span = np.ptp(draw_weights)
    if w_span > 0:
        draw_weights = (draw_weights - draw_weights.min()) / w_span
    else:
        draw_weights = np.zeros_like(draw_weights)
    edge_colors = [mcolors.to_rgba(plt.cm.jet(w)) for w in draw_weights]

    # Node positions: first two feature dimensions, keyed by node id.
    pos = {i: tuple(xy) for i, xy in enumerate(node_features[:, :2].tolist())}

    plt.figure(figsize=(6, 6))
    nx.draw(nx_graph, pos, node_size=50, edge_color=edge_colors, alpha=0.6)
    plt.title("Sparse k-NN Graph with Normalized Gaussian Reweighted Edges")
    plt.show()

    # Print Summary
    print(f"Graph has {g.num_nodes()} nodes and {g.num_edges()} edges")
    print("Sample edge weights after reweighting:", edge_weights[:10])
