import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import functional, layer, neuron
from torch.nn.parameter import Parameter
from torch_sparse import SparseTensor
class HypergraphProcessor(nn.Module):
    """Spiking kNN-hypergraph convolution over all (time step, vertex) nodes.

    For every sample, each of the ``T * V`` nodes (vertex ``v`` at time step
    ``t``, flattened time-major as ``t * V + v``) is linked to its ``k``
    nearest nodes in feature space (Euclidean distance; typically including
    the node itself). The sparse matrix ``H`` with weights
    ``exp(-distance)`` aggregates node features, followed by a linear
    projection, BatchNorm1d and a multi-step ParametricLIF neuron.

    Args:
        in_channels: feature size ``C`` of the input.
        out_channels: feature size of the output.
        k: number of nearest neighbours per node (clamped to ``T * V``).
    """

    def __init__(self, in_channels=256, out_channels=256, k=5):
        super(HypergraphProcessor, self).__init__()
        self.k = k
        self.linear = nn.Linear(in_channels, out_channels)
        self.bn = nn.BatchNorm1d(out_channels)
        self.sn = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

    def generate_hypergraph_batch(self, X_batch):
        """Build the kNN hypergraph of a single sample.

        Args:
            X_batch: tensor of shape ``(T, C, V)``.

        Returns:
            ``(H, row, col, value)`` where ``H`` is a ``(T*V, T*V)``
            SparseTensor with ``H[row[i], col[i]] = value[i]``. ``col`` holds
            each centre node repeated ``k`` times and ``row`` its ``k``
            nearest nodes; ``value = exp(-distance)``.
        """
        T, C, V = X_batch.size()
        device = X_batch.device
        num_nodes = T * V
        k = min(self.k, num_nodes)  # topk raises if k exceeds the node count

        # (T, C, V) -> (T, V, C) -> (T*V, C): time-major node order, the same
        # layout ``forward`` uses for the features multiplied by H.
        X_flat = X_batch.permute(0, 2, 1).reshape(num_nodes, C)

        dist_matrix = torch.cdist(X_flat.unsqueeze(0), X_flat.unsqueeze(0), p=2).squeeze(0)  # [T*V, T*V]
        knn_edges = torch.topk(-dist_matrix, k=k, dim=1).indices  # [T*V, k] nearest nodes

        row = knn_edges.flatten()
        col = torch.arange(num_nodes, device=device).repeat_interleave(k)
        value = torch.exp(-dist_matrix[row, col])

        H = SparseTensor(row=row, col=col, value=value, sparse_sizes=(num_nodes, num_nodes))
        return H, row, col, value

    def forward(self, X):
        """Apply the hypergraph convolution to a multi-step input.

        Args:
            X: tensor of shape ``(T, N, C, V)`` (time, batch, channels, vertices).

        Returns:
            ``(X_out, hypergraph_info)``: ``X_out`` has shape
            ``(T, N, out_channels, V)``; ``hypergraph_info`` holds the
            ``H_row`` / ``H_col`` / ``H_values`` NumPy arrays of the LAST
            sample in the batch only.
        """
        T, N, C, V = X.shape
        X = X.permute(1, 0, 2, 3)  # (N, T, C, V): one hypergraph per sample
        num_nodes = T * V

        H_list = []
        for sample in X:
            H, row, col, value = self.generate_hypergraph_batch(sample)
            H_list.append(H)

        # Only the last sample's edges are exported, so convert once instead
        # of paying a device->host copy on every iteration.
        hypergraph_info = {
            'H_row': row.cpu().numpy(),
            'H_col': col.cpu().numpy(),
            'H_values': value.detach().cpu().numpy(),
        }

        # (N, T, C, V) -> (N, T, V, C) -> (N, T*V, C), time-major like H.
        X_flat = X.permute(0, 1, 3, 2).reshape(N, num_nodes, C)
        X_out_flat = torch.stack(
            [H.matmul(x) for H, x in zip(H_list, X_flat)], dim=0
        )  # (N, T*V, C)

        X_out = X_out_flat.view(N * T, V, C)
        X_out = self.linear(X_out)                  # (N*T, V, C_out)
        X_out = self.bn(X_out.transpose(-1, -2))    # (N*T, C_out, V)
        # Rows are ordered (sample, time): restore (N, T, ...) and swap to the
        # time-first layout the multi-step neuron integrates over.
        X_out = X_out.view(N, T, -1, V).transpose(0, 1).contiguous()  # (T, N, C_out, V)
        X_out = self.sn(X_out)
        return X_out, hypergraph_info
    
class MLP(nn.Module):
    """Two-layer spiking MLP applied independently at every point.

    Each layer is Linear -> BatchNorm1d (over the feature axis) -> multi-step
    ParametricLIF neuron.

    Shapes:
        input ``(T, B, C_in, N)`` -> output ``(T, B, C_out, N)``.

    Args:
        in_features: input feature size ``C_in``.
        hidden_features: hidden size (defaults to ``in_features``).
        out_features: output size ``C_out`` (defaults to ``in_features``).
        drop: accepted for interface compatibility; not used.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_linear = nn.Linear(in_features, hidden_features)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy')

        self.fc2_linear = nn.Linear(hidden_features, out_features)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        x = x.transpose(-1, -2)  # (T, B, N, C_in): features last for Linear
        T, B, N, _ = x.shape

        x = self.fc1_linear(x.flatten(0, 1))  # (T*B, N, hidden)
        x = self.fc1_bn(x.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, self.c_hidden).contiguous()
        x = self.fc1_lif(x)

        x = self.fc2_linear(x.flatten(0, 1))  # (T*B, N, C_out)
        # Reshape to the OUTPUT width; using the input width here failed
        # whenever out_features != in_features.
        x = self.fc2_bn(x.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, self.c_output).contiguous()
        x = self.fc2_lif(x).transpose(-1, -2)  # back to (T, B, C_out, N)
        return x



class GSA(nn.Module):
    """
    Global Spiking Attention module as described in the paper:
    - Input: (T, B, C, N)
    - Spatial attention: 1x1 conv to reduce channels (C -> 1), BN, spiking activation
    - Channel attention: grouped 1x1 conv (groups=G), BN, spiking activation
    - Fusion: element-wise multiplication + residual connections

    The spiking neurons run in multi-step mode, so their input is kept in
    (T, B, ...) layout; convs and BNs run on the flattened (T*B, ...) view.
    """
    def __init__(self, channel=128, G=8, backend='cupy'):
        super().__init__()
        assert channel % G == 0, "channel must be divisible by G"
        self.C = channel
        self.G = G

        # Spatial attention branch
        self.conv_spatial = nn.Conv1d(channel, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_spatial   = nn.BatchNorm1d(1)
        self.sn_spatial   = neuron.ParametricLIFNode(step_mode='m', backend=backend)

        # Channel attention branch (grouped conv along channels)
        self.conv_channel = nn.Conv1d(channel, channel, kernel_size=1, stride=1, padding=0,
                                      groups=G, bias=False)
        self.bn_channel   = nn.BatchNorm1d(channel)
        self.sn_channel   = neuron.ParametricLIFNode(step_mode='m', backend=backend)

    def forward(self, x):
        """Args: x of shape (T, B, C, N). Returns a tensor of the same shape."""
        T, B, C, N = x.shape
        x_tb = x.flatten(0, 1)  # (T*B, C, N)

        # Spatial attention. The multi-step neuron iterates over dim 0, so the
        # conv/BN output is unflattened to (T, B, ...) first; feeding (T*B, ...)
        # would carry membrane state from one batch sample into the next.
        att_s = self.bn_spatial(self.conv_spatial(x_tb)).view(T, B, 1, N)
        att_s = self.sn_spatial(att_s).flatten(0, 1)  # (T*B, 1, N)
        x_sp = x_tb * att_s + x_tb                    # (T*B, C, N)

        # Channel attention (same time-first handling for the neuron)
        att_c = self.bn_channel(self.conv_channel(x_sp)).view(T, B, C, N)
        att_c = self.sn_channel(att_c).flatten(0, 1)  # (T*B, C, N)
        x_ch = x_sp * att_c                           # (T*B, C, N)

        # Fusion
        out = x_ch + x_sp
        return out.view(T, B, C, N)



class GCNLayer(nn.Module):
    """Batched sparse graph aggregation -> 1x1 projection -> BN -> spiking neuron.

    Args:
        in_channels: input feature size ``C``.
        out_channels: output feature size.
    """

    def __init__(self, in_channels, out_channels):
        super(GCNLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1)
        self.bn = nn.BatchNorm1d(out_channels)
        self.sn = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

    def forward(self, x, edge_indices, time_steps=1):
        """Aggregate neighbour features and emit spikes.

        Args:
            x: node features of shape ``(B, C, N)``.
            edge_indices: integer tensor of shape ``(B, 2, E)``;
                ``edge_indices[b, 0]`` are the aggregating (row) nodes and
                ``edge_indices[b, 1]`` the source (column) nodes, both local
                to sample ``b``. Must be on the same device as ``x``.
            time_steps: number of time steps ``T`` packed time-major into the
                node axis (``N = T * V``). The default of 1 matches the
                previous behaviour, whose ``N // x.size(2)`` was always 1.

        Returns:
            Tensor of shape ``(T, B, out_channels, V)``.

        Raises:
            ValueError: if ``N`` is not divisible by ``time_steps``.
        """
        B, C, N = x.size()
        if time_steps <= 0 or N % time_steps != 0:
            raise ValueError(f"N={N} is not divisible by time_steps={time_steps}")

        row, col = edge_indices[:, 0], edge_indices[:, 1]  # each (B, E)

        # Shift each sample's local node ids so all samples form one
        # block-diagonal (B*N, B*N) adjacency.
        offset = torch.arange(B, device=x.device).repeat_interleave(row.size(1)) * N
        row = row.flatten() + offset
        col = col.flatten() + offset
        adj = SparseTensor(row=row, col=col, sparse_sizes=(B * N, B * N))  # no edge values given

        x_flat = x.permute(0, 2, 1).reshape(B * N, C)  # (B*N, C)
        out_flat = adj.matmul(x_flat)                   # (B*N, C)

        # Conv1d on the unbatched (C, B*N) view acts as a per-node linear map.
        out = self.linear(out_flat.permute(1, 0)).permute(1, 0)  # (B*N, C_out)
        out = self.bn(out)

        # Use -1 for the channel axis: it is C_out, not the input width C.
        out = out.view(B, N, -1).permute(0, 2, 1)  # (B, C_out, N)
        V = N // time_steps
        out = out.view(B, -1, time_steps, V).permute(2, 0, 1, 3).contiguous()  # (T, B, C_out, V)
        out = self.sn(out)
        return out

    
    
class SparseSemanticExtractor(nn.Module):
    """Hypergraph -> global spiking attention -> hypergraph feature extractor.

    Args:
        in_channels: channel size ``C`` of input and output (GSA requires it
            to be divisible by its default group count of 8).
        k_neighbors: neighbours per node in both hypergraph layers.
        init_sparsity: initial value of ``sparsity_threshold``.
    """

    def __init__(self, in_channels, k_neighbors=5, init_sparsity=0.1):
        super(SparseSemanticExtractor, self).__init__()
        self.in_channels = in_channels
        self.k_neighbors = k_neighbors
        # NOTE(review): cond, bn, proj_lif_1, sparsity_threshold and proj_lif
        # are not used by forward(); kept to preserve the state_dict layout.
        self.cond = nn.Conv1d(in_channels, in_channels, 1)
        self.bn = nn.BatchNorm1d(in_channels)
        self.proj_lif_1 = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

        self.sparsity_threshold = nn.Parameter(torch.tensor(init_sparsity, dtype=torch.float32))

        # GSA's constructor takes ``channel``; the former
        # ``in_channels=/out_channels=`` keywords raised TypeError.
        self.gla = GSA(channel=in_channels)

        self.gcn1 = HypergraphProcessor(in_channels, in_channels, self.k_neighbors)
        self.proj_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.gcn2 = HypergraphProcessor(in_channels, in_channels, self.k_neighbors)

    def forward(self, x):
        """Args: x of shape (T, B, C, V). Returns a tensor of shape (T, B, C, V).

        The input is already in the (time, batch, ...) layout the sub-layers
        expect, so it is passed through unchanged. The former
        ``permute(1, 0, 2, 3)`` followed by ``reshape`` back to (T, B, ...)
        reordered memory without swapping axes and so scrambled samples and
        time steps.
        """
        time_steps, batch_size, _, num_points = x.shape
        x, _ = self.gcn1(x)
        x = self.gla(x.reshape(time_steps, batch_size, -1, num_points))
        x, _ = self.gcn2(x.reshape(time_steps, batch_size, -1, num_points))
        return x.reshape(time_steps, batch_size, -1, num_points)

def main(device_id=1):
    """Smoke-test SparseSemanticExtractor on random data on ``cuda:<device_id>``.

    Args:
        device_id: CUDA device index (defaults to 1, the previously
            hard-coded value). The neurons in this file use
            ``backend='cupy'``, which presumably needs a CUDA device.
    """
    device = torch.device(f"cuda:{device_id}")
    print(f"Using device: {device}")

    torch.manual_seed(42)

    batch_size = 4
    time_steps = 3
    channels = 256
    num_points = 25
    k_neighbors = 10

    # Shape (time_steps, batch_size, channels, num_points); sampled on the CPU
    # and then moved, so the seeded values are device independent.
    x = torch.randn(time_steps, batch_size, channels, num_points).to(device)

    extractor = SparseSemanticExtractor(
        in_channels=channels, k_neighbors=k_neighbors, init_sparsity=0.1
    ).to(device)

    output = extractor(x)
    print(f"Output shape: {tuple(output.shape)}")


if __name__ == "__main__":
    main()
