import torch
import torch_sparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import functional, layer, neuron
class HypergraphProcessor(nn.Module):
    """
    kNN hypergraph aggregation followed by a 1x1 convolution, batch norm and
    a multi-step spiking neuron.

    Each sample is treated as a set of ``T * V`` nodes (one per time step and
    position along ``V``), every node carrying a ``C``-dimensional feature.
    """

    def __init__(self, in_channels=256, out_channels=256, k=5):
        """
        Initialize hypergraph processor
        :param in_channels: feature channels C of the input
        :param out_channels: feature channels produced by the 1x1 convolution
        :param k: K nearest neighbors per node
        """
        super(HypergraphProcessor, self).__init__()
        self.k = k
        self.linear = nn.Conv1d(in_channels, out_channels, 1)
        self.bn = nn.BatchNorm1d(out_channels)
        self.sn = neuron.IFNode(step_mode='m', backend='cupy')  # spiking neuron node

    def generate_hypergraph_batch(self, X_batch):
        """
        Generate the sparse kNN hypergraph matrix H for a single sample
        :param X_batch: features of one sample [T, C, V]
        :return: torch_sparse.SparseTensor H of size [T * V, T * V]; column i
                 holds the k nearest neighbors of node i (node i itself
                 included), weighted by exp(-distance)
        """
        T, C, V = X_batch.size()
        device = X_batch.device
        num_nodes = T * V

        # [T, C, V] -> [T, V, C] -> [T * V, C]: one C-dim feature row per
        # node, node index = t * V + v (same ordering as used in forward()).
        X_flat = X_batch.permute(0, 2, 1).reshape(num_nodes, C)

        # pairwise Euclidean distance between all nodes
        dist_matrix = torch.cdist(X_flat.unsqueeze(0), X_flat.unsqueeze(0), p=2).squeeze(0)

        # topk cannot return more neighbors than there are nodes
        k = min(self.k, num_nodes)

        # k smallest distances per node (largest negated distances)
        neg_dist, knn_edges = torch.topk(-dist_matrix, k=k, dim=1)

        # row = neighbor node, col = node whose neighborhood it belongs to
        row = knn_edges.flatten()
        col = torch.arange(num_nodes, device=device).repeat_interleave(k)

        # weights decay exponentially with distance: exp(-d)
        value = torch.exp(neg_dist).flatten()

        # build sparse adjacency H
        H = torch_sparse.SparseTensor(row=row,
                                       col=col,
                                       value=value,
                                       sparse_sizes=(num_nodes, num_nodes))
        return H

    def forward(self, X):
        """
        Forward: build one sparse hypergraph per sample, aggregate node
        features through it, then apply conv -> batch norm -> spiking neuron
        :param X: input tensor [T, N, C, V] (time first, then batch);
                  C must equal in_channels
        :return: output of the spiking neuron [T, N, out_channels, V]
        """
        T, N, C, V = X.shape
        num_nodes = T * V

        # one hypergraph per sample; X[:, n] is that sample's [T, C, V] slice
        H_list = [self.generate_hypergraph_batch(X[:, n]) for n in range(N)]

        # [T, N, C, V] -> [N, T, V, C] -> [N, T * V, C]
        X_flat = X.permute(1, 0, 3, 2).reshape(N, num_nodes, C)

        # per-sample sparse matmul: [T * V, T * V] @ [T * V, C]
        X_out_flat = torch.stack(
            [H.matmul(x) for H, x in zip(H_list, X_flat)], dim=0)

        # [N, T * V, C] -> [N, T, V, C] -> [T, N, C, V] -> [T * N, C, V];
        # time is moved in front BEFORE flattening so that the final view
        # back to [T, N, ...] does not mix samples and time steps.
        X_out = X_out_flat.view(N, T, V, C).permute(1, 0, 3, 2).reshape(T * N, C, V)

        X_out = self.linear(X_out)
        # -1 instead of C: the conv may change the channel count
        X_out = self.bn(X_out).reshape(T, N, -1, V)
        X_out = self.sn(X_out)
        return X_out

def main():
    """
    Main: run HypergraphProcessor once on a random input and print the
    input and output shapes.
    """
    # check GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # parameters
    N, T, C, V = 32, 16, 256, 25  # batch, time, channels, joints
    k = 3  # kNN per node

    # input X, time-first layout [T, N, C, V]
    X = torch.randn(T, N, C, V).to(device)

    # init processor; its parameters must be on the same device as the
    # input, otherwise the conv / batch norm fail with a device mismatch
    processor = HypergraphProcessor(k=k).to(device)

    # forward
    X_out = processor(X)

    # print shapes
    print("input shape:", X.shape)
    print("output shape:", X_out.shape)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
