import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Optional


class BinaryConcrete:
    """
    Binary concrete (relaxed Bernoulli) relaxation.

    Produces differentiable, approximately-binary samples from Bernoulli
    probabilities by adding logistic noise (the two-class Gumbel-Softmax
    trick) to the logits and squashing with a tempered sigmoid.
    """

    def __init__(self, temperature: float = 0.1):
        # Divides the noisy logits before the sigmoid: smaller values push
        # samples closer to {0, 1}.
        self.temperature = temperature

    def relax_binary_concrete(self, p: torch.Tensor, eps: float = 1e-10):
        """
        Apply binary concrete (relaxed Bernoulli) relaxation.

        In effect the same as gumbel_softmax restricted to two classes.

        Args:
            p: Bernoulli probabilities, expected to lie in [0, 1].
            eps: small constant added inside every logarithm so that
                p == 0, p == 1 and a uniform draw of exactly 0 give finite
                logits (and finite gradients) instead of +/-inf or NaN.

        Returns:
            Tensor with the same shape as ``p`` and values in [0, 1].
        """
        u = torch.rand_like(p)
        # Logistic noise log(u) - log(1 - u). `1 - u` is formed before adding
        # eps so the addition is not lost to rounding when u is near 0.
        gumbel = torch.log(u + eps) - torch.log(1 - u + eps)
        # Same construction for the logits, so p == 1 yields log(1/eps)
        # rather than +inf.
        logits = torch.log(p + eps) - torch.log(1 - p + eps)
        relaxed_sample = torch.sigmoid((logits + gumbel) / self.temperature)
        return relaxed_sample

    def hard_binary_concrete(self, p: torch.Tensor, eps: float = 1e-10):
        """
        Straight-through binary sample.

        The forward value is the relaxed sample thresholded at 0.5 (exactly
        0.0 or 1.0). Gradients flow through the relaxed sample, because the
        hard-minus-relaxed correction is detached.

        Args:
            p: Bernoulli probabilities, expected to lie in [0, 1].
            eps: forwarded to ``relax_binary_concrete``.
        """
        relaxed = self.relax_binary_concrete(p, eps=eps)
        hard = (relaxed > 0.5).float()
        return (hard - relaxed).detach() + relaxed


def top_k_mask(scores: torch.Tensor, k: int, dim: int = 1):
    """
    Top-k sparsification.

    Build a 0/1 mask that keeps the ``k`` largest entries of ``scores`` in
    every slice along ``dim``.

    Args:
        scores: real-valued logits.
        k: number of entries (edges) to keep per slice along ``dim``.
        dim: dimension along which the top-k selection is made (default 1).

    Returns:
        Tensor shaped and typed like ``scores`` with 1 at the kept positions
        and 0 elsewhere. When ``k == 0`` the mask is all zeros.
    """
    if k == 0:
        return torch.zeros_like(scores)

    top_k_indices = torch.topk(scores, k=k, dim=dim).indices

    mask = torch.zeros_like(scores)
    # Scatter along the same dimension that topk selected from. A hard-coded
    # dimension marks the wrong positions (or fails) whenever dim != 1.
    mask.scatter_(dim, top_k_indices, 1)
    return mask
    

def prob_eps_mask(probs: torch.Tensor, eps: float):
    """
    Epsilon sparsification.

    Keep every entry whose probability reaches the threshold.

    Args:
        probs: tensor of probabilities, compared elementwise.
        eps: inclusive threshold; entries with ``probs >= eps`` are kept.

    Returns:
        float32 tensor shaped like ``probs``: 1.0 where kept, 0.0 elsewhere.
    """
    keep = torch.ge(probs, eps)
    return keep.float()
    

def add_self_loop_adj_sparse(A: torch.Tensor):
    """
    Add self-loops to a square sparse COO adjacency matrix.

    Computes A + I and then caps every stored value at 1. Existing diagonal
    entries therefore do not become 2, and any weight above 1 is clipped.

    Args:
        A: sparse COO tensor of shape (n, n).

    Returns:
        Coalesced sparse COO tensor of shape (n, n), with A's dtype and device.
    """
    n_nodes = A.size(0)
    node_ids = torch.arange(n_nodes, device=A.device)
    # Sparse identity built in A's dtype so the sum does not depend on
    # mixed-dtype sparse addition.
    I = torch.sparse_coo_tensor(
        torch.stack([node_ids, node_ids], dim=0),
        torch.ones(n_nodes, dtype=A.dtype, device=A.device),
        (n_nodes, n_nodes)
    )
    # coalesce() merges the new self-loops with any existing diagonal
    # entries, so each (i, j) appears once before clamping.
    A_hat = (A + I).coalesce()
    clamped_vals = A_hat.values().clamp_max(1.0)
    return torch.sparse_coo_tensor(
        A_hat.indices(), clamped_vals, (n_nodes, n_nodes)).coalesce()


def get_deg_adj_sparse(A: torch.Tensor):
    """
    Weighted degree (row sums) of a sparse COO matrix.

    Args:
        A: sparse COO tensor, coalesced or not. An uncoalesced input is
            coalesced first, because ``indices()`` requires it.

    Returns:
        Dense 1-D tensor of length ``A.size(0)`` with the same dtype as A's
        values. Entry i is the sum of the stored values in row i (0 for
        empty rows).
    """
    if not A.is_coalesced():
        A = A.coalesce()
    row = A.indices()[0]
    vals = A.values()
    # Accumulator must match the values' dtype; index_add_ rejects mismatches.
    deg = torch.zeros(A.size(0), dtype=vals.dtype, device=A.device)
    deg.index_add_(0, row, vals)
    return deg


def normalize_adj_sparse(A: torch.Tensor):
    """
    Symmetrically normalize a sparse adjacency matrix with self-loops.

    Returns D^{-1/2} A_hat D^{-1/2} as a coalesced sparse COO tensor of
    shape (n, n) on A's device. Here A_hat = add_self_loop_adj_sparse(A) and
    D is the row-sum degree of A_hat from get_deg_adj_sparse. A zero degree
    yields a zero scale factor instead of infinity.
    """
    size = A.size(0)
    with_loops = add_self_loop_adj_sparse(A)
    inv_sqrt = get_deg_adj_sparse(with_loops).pow(-0.5)
    # 0 ** -0.5 is inf; such rows/columns are scaled by 0 instead.
    inv_sqrt = inv_sqrt.masked_fill(torch.isinf(inv_sqrt), 0)

    indices = with_loops.indices()
    rows, cols = indices[0], indices[1]
    # Each stored value v at (r, c) becomes d[r] * v * d[c].
    scaled = inv_sqrt[rows] * with_loops.values() * inv_sqrt[cols]
    return torch.sparse_coo_tensor(
        indices, scaled, (size, size), device=A.device).coalesce()


def normalize_adj(A: torch.Tensor):
    """
    Symmetrically normalize a dense adjacency matrix with self-loops.

    Returns D^{-1/2} A_hat D^{-1/2}, where A_hat = A + I with every entry
    capped at 1, and D is the row-sum degree of A_hat. Rows with zero degree
    get a zero scale factor instead of infinity. The input is not modified.
    """
    I = torch.eye(A.size(0), device=A.device)
    A_hat = (A + I).clamp_max(1)

    D_inv_sqrt = A_hat.sum(dim=1).pow(-0.5)
    D_inv_sqrt = D_inv_sqrt.masked_fill(torch.isinf(D_inv_sqrt), 0)

    # Scaling row i by d[i] and column j by d[j] via broadcasting equals
    # diag(d) @ A_hat @ diag(d), but costs O(n^2) instead of two O(n^3)
    # dense matmuls and does not materialize the diagonal matrices.
    return D_inv_sqrt.unsqueeze(1) * A_hat * D_inv_sqrt.unsqueeze(0)

def inverse_sigmoid(y):
    """
    Inverse of the logistic sigmoid (the logit): log(y / (1 - y)).

    y == 0 maps to -inf and y == 1 maps to +inf. No clamping is applied.
    """
    odds = y / (1 - y)
    return torch.log(odds)


def remove_flipped_duplicates(edge_index):
    """
    Deduplicate undirected edges in an edge index whose columns are edges
    (presumably shape (2, E)).

    Each column is sorted so that (u, v) and (v, u) both become
    (min, max). The distinct canonical columns are then returned in the
    sorted order produced by ``unique``.
    """
    canonical, _ = edge_index.sort(dim=0)
    return canonical.unique(dim=1)


def filter_edges(edge1: torch.Tensor,
                 edge2: torch.Tensor,
                 n_nodes: Optional[int] = None,) -> torch.Tensor:
    """
    Mask the edges of ``edge1`` that do not appear in ``edge2``, in either
    direction.

    Returns a boolean mask over the columns of ``edge1``, not the filtered
    edges. The mask is True where neither (u, v) nor (v, u) occurs in
    ``edge2``. Use ``edge1[:, mask]`` to obtain the remaining edges.

    Args:
        edge1: edge index; row 0 holds sources, row 1 targets.
        edge2: edges to exclude, same layout as ``edge1``.
        n_nodes: upper bound on node ids. Every id must be < n_nodes so the
            (u, v) -> u * n_nodes + v encoding cannot collide. When omitted,
            it is inferred from the largest id in either input.

    Returns:
        Boolean tensor with one entry per column of ``edge1``.
    """
    if n_nodes is None:
        # Reduce over both inputs together so an empty edge list on either
        # side does not trigger max() on an empty tensor.
        all_ids = torch.cat([edge1.reshape(-1), edge2.reshape(-1)])
        n_nodes = int(all_ids.max()) + 1 if all_ids.numel() > 0 else 1
    # linearize (u, v) -> u * N + v
    lin1 = edge1[0] * n_nodes + edge1[1]
    lin2 = edge2[0] * n_nodes + edge2[1]
    lin2_rev = edge2[1] * n_nodes + edge2[0]

    remove = torch.cat([lin2, lin2_rev], dim=0)

    # True for entries of edge1 that are in neither orientation of edge2
    return ~torch.isin(lin1, remove)


def get_knn_edge_index(x: torch.Tensor,
                       k: int,
                       metric: str = "euclidean") -> torch.Tensor:
    """
    Build a k-nearest-neighbour edge index over the rows of ``x``.

    Args:
        x: node features, one row per node.
        k: neighbours per node. Capped at ``n_nodes - 1``, so a node is never
            picked as its own neighbour (previously this happened when
            ``k == n_nodes``) and ``topk`` does not fail when
            ``k > n_nodes``.
        metric: "euclidean", "manhattan", or "cosine" (distance is
            1 - cosine similarity).

    Returns:
        int64 tensor of shape (2, n_nodes * k_eff), where k_eff is the
        capped k. Columns are (node, neighbour) pairs, grouped by source
        node, with neighbours ordered from nearest to farthest.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    n_nodes = x.size(0)
    # Neighbour selection yields integer indices with no gradient, so the
    # distance matrix does not need to be tracked by autograd.
    with torch.no_grad():
        if metric == "euclidean":
            dist = torch.cdist(x, x, p=2)
        elif metric == "manhattan":
            dist = torch.cdist(x, x, p=1)
        elif metric == "cosine":
            x_norm = F.normalize(x, p=2, dim=1)
            sim = x_norm @ x_norm.t()
            dist = 1.0 - sim
        else:
            raise ValueError(f"Unsupported metric '{metric}'. Choose from 'euclidean', 'manhattan', 'cosine'.")

        # Exclude each node from its own neighbourhood.
        dist.fill_diagonal_(float("inf"))

        # Only n_nodes - 1 other nodes exist; asking for more would either
        # select the node itself (inf distance) or make topk fail.
        k = min(k, max(n_nodes - 1, 0))

        # k smallest distances per row, ascending
        _, idx = dist.topk(k, dim=1, largest=False)

    # Each source i repeated k times, paired with its k neighbours.
    src = torch.arange(n_nodes, device=x.device).unsqueeze(1).expand(-1, k)
    edge_index = torch.stack([src.reshape(-1), idx.reshape(-1)], dim=0)
    return edge_index
