import torch.nn as nn
import torch
from functools import partial
from spikingjelly.clock_driven.neuron import (
    MultiStepLIFNode,
    MultiStepParametricLIFNode,
)
from sklearn.metrics import mutual_info_score
import numpy as np


class Decoder1(nn.Module):
    """Spiking projection layer: Linear -> normalization -> multi-step neuron.

    The leading dimension of the input folds time and batch together:
    ``x`` has shape ``(T * B, N, in_channels)`` and the output has shape
    ``(T * B, N, out_channels)``.

    Args:
        in_channels: Size of the last input dimension.
        out_channels: Size of the last output dimension.
        T: Number of time steps folded into ``x.shape[0]``; that size is
            assumed to be a multiple of ``T``.
        lif_mode: ``'lif'`` (``MultiStepLIFNode``, tau=2.0) or ``'plif'``
            (``MultiStepParametricLIFNode``, init_tau=2.0). Any other value
            falls back to ``nn.Identity`` (no spiking non-linearity).
        norm_layer_td: ``'batch'`` (``BatchNorm1d`` over channels),
            ``'layer'`` (``LayerNorm`` over channels, eps=1e-6) or ``'none'``.
            Any other value also falls back to ``nn.Identity``.
        res: If True and ``in_channels == out_channels``, the input is added
            to the output (residual connection).
    """

    def __init__(self, in_channels, out_channels, T, lif_mode='lif', norm_layer_td='batch', res=False):
        super().__init__()

        # Normalization over the channel dimension of the projected features.
        if norm_layer_td == 'batch':
            print("using batch")
            self.bn_td = nn.BatchNorm1d(out_channels)
        elif norm_layer_td == 'layer':
            print("using layer")
            self.bn_td = nn.LayerNorm(out_channels, eps=1e-06)
        elif norm_layer_td == 'none':
            print("do not use norm layer")
            self.bn_td = nn.Identity()
        else:
            print("no available norm_layer")
            # Fall back to no normalization so forward() always has bn_td.
            self.bn_td = nn.Identity()

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.linear_td = nn.Linear(in_channels, out_channels, bias=False)

        self.res = res
        if res:
            print("using res")
        self.T = T

        # Spiking non-linearity applied over the time axis.
        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True, backend='torch')
        else:
            print("no available lif mode")
            # Fall back to a pass-through so forward() always has lif_td.
            self.lif_td = nn.Identity()

    def forward(self, x):
        """Project, normalize and spike ``x``.

        Args:
            x: Tensor of shape ``(T * B, N, in_channels)`` (exactly 3 dims).

        Returns:
            Tensor of shape ``(T * B, N, out_channels)``. When ``res`` is set
            and the channel counts match, the input ``x`` is added to it.
        """
        T_B, N, _ = x.shape
        B = T_B // self.T
        identity = x

        x = self.linear_td(x)  # (T*B, N, C_out)

        # BatchNorm1d expects channels at dim 1, whereas LayerNorm normalizes
        # the last dim; apply each in the layout where C_out is the channel
        # axis, ending up with (T*B, C_out, N) either way.
        if isinstance(self.bn_td, nn.LayerNorm):
            x = self.bn_td(x).transpose(-1, -2)
        else:
            x = self.bn_td(x.transpose(-1, -2))

        # Unfold time so the multi-step neuron sees the time axis first.
        x = x.reshape(self.T, B, self.out_channels, -1).contiguous()  # (T, B, C_out, N)

        x = self.lif_td(x).flatten(0, 1).transpose(-1, -2).contiguous()  # (T*B, N, C_out)

        if self.res and self.in_channels == self.out_channels:
            return identity + x
        return x
    
    
def mutual_info(x, y, bins=20):
    """Estimate the mutual information between two tensors from a 2-D histogram.

    Both tensors are flattened, and element ``i`` of ``x`` is paired with
    element ``i`` of ``y``, so they must contain the same number of elements.

    Args:
        x: Tensor on any device; it may require grad.
        y: Tensor with the same ``numel()`` as ``x``.
        bins: Forwarded to ``np.histogram2d`` as the binning specification.

    Returns:
        The value of ``sklearn.metrics.mutual_info_score`` computed on the
        binned joint-count (contingency) matrix.
    """
    # detach(): Tensor.numpy() raises on tensors that require grad, e.g.
    # activations captured while the model is training.
    x = x.detach().flatten().cpu().numpy()
    y = y.detach().flatten().cpu().numpy()

    joint_counts, _, _ = np.histogram2d(x, y, bins=bins)
    # With an explicit contingency matrix, the label arguments are ignored.
    return mutual_info_score(None, None, contingency=joint_counts)

# Function to calculate the mutual information matrix
def calculate_mutual_info_matrix(output_list, bins=20):
    """Build the symmetric pairwise mutual-information matrix of ``output_list``.

    Entry ``(i, j)`` holds ``mutual_info(output_list[i], output_list[j], bins)``.
    Only the upper triangle (diagonal included) is computed; each value is
    mirrored into the lower triangle. The finished matrix is also printed.

    Args:
        output_list: Sequence of tensors accepted by ``mutual_info``.
        bins: Histogram bin specification forwarded to ``mutual_info``.

    Returns:
        A ``torch.zeros``-initialized tensor of shape ``(n, n)`` where
        ``n = len(output_list)``.
    """
    count = len(output_list)
    mi_matrix = torch.zeros((count, count))

    for row, first in enumerate(output_list):
        for col in range(row, count):
            score = mutual_info(first, output_list[col], bins)
            # Fill both halves at once to keep the matrix symmetric.
            mi_matrix[row, col] = mi_matrix[col, row] = score

    print("Mutual Information Matrix (symmetric):\n", mi_matrix)
    return mi_matrix
    
def calculate_firing_rate(td):
    """Return the fraction of entries of ``td`` that are strictly positive.

    For a spike tensor with values in {0, 1}, this is the mean firing rate.
    The input is not validated: any positive value counts as a spike, and an
    empty tensor yields NaN.

    Args:
        td: A tensor, or None.

    Returns:
        A 0-dim floating-point tensor, or None when ``td`` is None.
    """
    if td is None:
        return None
    spike_count = (td > 0).sum()
    return spike_count / td.numel()


