from torch import Tensor
import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from torch.nn import functional as F
from typing import List
from spikingjelly.clock_driven.neuron import MultiStepLIFNode,MultiStepParametricLIFNode
from functools import partial
from timm.models.layers import trunc_normal_   #timm==0.5.4



import torch
from sklearn.metrics import mutual_info_score
import numpy as np

 
class Betascheduler:
    """Linearly anneal a scalar ``beta`` from ``initial`` to ``final``.

    Each call to :meth:`step` advances the schedule by one epoch.  After
    ``total_epochs`` steps ``beta`` equals ``final`` exactly, and any further
    steps leave it there instead of extrapolating past the target.

    Args:
        initial: Value of ``beta`` before the first step.
        final: Value of ``beta`` once ``total_epochs`` steps have been taken.
        total_epochs: Number of steps over which to interpolate.  A value
            ``<= 0`` makes the first step jump straight to ``final``.
    """

    def __init__(self, initial, final, total_epochs):
        self.initial = initial
        self.final = final
        self.total_epochs = total_epochs
        # NOTE: the attribute name is misspelled; it is kept as-is because
        # external code may already read or write it.
        self.currrent_epochs = 0
        self.beta = initial

    @property
    def current_epochs(self):
        """Correctly spelled, read-only alias for ``currrent_epochs``."""
        return self.currrent_epochs

    def step(self):
        """Advance the schedule by one epoch and update ``beta``."""
        self.currrent_epochs += 1

        # Clamp at the end of the schedule (also covers total_epochs <= 0,
        # which would otherwise divide by zero).
        if self.total_epochs <= 0 or self.currrent_epochs >= self.total_epochs:
            self.beta = self.final
            return

        # Closed form instead of repeated accumulation: no float drift, and
        # the value stays consistent if the epoch counter is restored.
        progress = self.currrent_epochs / self.total_epochs
        self.beta = self.initial + (self.final - self.initial) * progress

    def get(self):
        """Return the current value of ``beta``."""
        return self.beta



class Decoder_MLP(nn.Module):
    """Two-layer token-wise MLP decoder with BatchNorm and LIF after each layer.

    Pipeline: ``Linear(in -> hidden) -> BatchNorm1d -> LIF ->
    Linear(hidden -> in) -> BatchNorm1d -> LIF``.

    Args:
        in_channels: Feature width of the input (and of the output).
        hidden_channels: Feature width between the two linear layers.
        T: Number of time steps folded into the leading ``T*B`` axis.
        lif_mode: Accepted for signature parity with the other decoders;
            currently unused (a ``MultiStepLIFNode`` is always created).
        norm_layer_td: Accepted for signature parity; currently unused
            (``BatchNorm1d`` is always created).
    """

    def __init__(self, in_channels, hidden_channels, T, lif_mode='lif', norm_layer_td='batch'):
        super().__init__()

        self.T = T
        self.hidden = hidden_channels

        self.linear_1 = nn.Linear(in_channels, hidden_channels, bias=False)
        self.bn_1 = nn.BatchNorm1d(hidden_channels)
        self.lif_1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.linear_2 = nn.Linear(hidden_channels, in_channels, bias=False)
        self.bn_2 = nn.BatchNorm1d(in_channels)
        self.lif_2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

    def forward(self, x):
        """Decode ``x`` of shape ``(T*B, N, C)`` into a tensor of the same shape.

        The previous implementation transposed to channel-last before the LIF
        node and then called ``reshape(-1, channels, N).transpose(-1, -2)``,
        which reinterprets memory rather than undoing the transpose and so
        mixed up token and channel positions.  Here every tensor stays in an
        explicit channel-last layout except while inside BatchNorm.
        """
        _, N, C = x.shape

        # Stage 1 ---------------------------------------------------------
        x = self.linear_1(x)                                   # TB, N, hidden
        # BatchNorm1d accepts (B, C) or (B, C, L): channels go on dim 1.
        x = self.bn_1(x.transpose(-1, -2).contiguous())        # TB, hidden, N
        x = x.transpose(-1, -2).reshape(self.T, -1, N, self.hidden)  # T, B, N, hidden
        x = self.lif_1(x).flatten(0, 1)                        # TB, N, hidden

        # Stage 2 ---------------------------------------------------------
        x = self.linear_2(x)                                   # TB, N, C
        x = self.bn_2(x.transpose(-1, -2).contiguous())        # TB, C, N
        x = x.transpose(-1, -2).reshape(self.T, -1, N, C)      # T, B, N, C
        x = self.lif_2(x).flatten(0, 1).contiguous()           # TB, N, C

        return x

class Decoder1(nn.Module):
    """Token-wise decoder: ``Linear -> norm -> optional spiking neuron``.

    Args:
        in_channels: Feature width of the input tokens.
        out_channels: Feature width of the output tokens.
        T: Number of time steps folded into the leading ``T*B`` axis.
        lif_mode: ``'lif'`` or ``'plif'`` selects the spiking node; any other
            value falls back to ``nn.Identity``.
        norm_layer_td: ``'batch'`` (BatchNorm1d), ``'layer'`` (LayerNorm) or
            ``'none'`` (Identity).

    Raises:
        ValueError: If ``norm_layer_td`` is not one of the supported values.
    """

    def __init__(self, in_channels, out_channels, T, lif_mode='none', norm_layer_td='batch'):
        super().__init__()

        if norm_layer_td == 'batch':
            print("using batch")
            self.bn_td = nn.BatchNorm1d(out_channels)
        elif norm_layer_td == 'layer':
            print("using layer")
            self.bn_td = nn.LayerNorm(out_channels, eps=1e-06)
        elif norm_layer_td == 'none':
            print("do not use norm layer")
            self.bn_td = nn.Identity()
        else:
            # Previously printed a message and called exit(0), which ended the
            # whole process with a *success* status on a configuration error.
            raise ValueError(
                f"no available norm_layer: {norm_layer_td!r} "
                "(expected 'batch', 'layer' or 'none')"
            )
        # BatchNorm1d wants channels on dim 1; LayerNorm/Identity work on the
        # channel-last layout directly.
        self._norm_channels_first = norm_layer_td == 'batch'

        self.out_channels = out_channels
        self.linear_td = nn.Linear(in_channels, out_channels, bias=False)

        self.T = T

        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True, backend='torch')
        else:
            self.lif_td = nn.Identity()
            print("no available lif mode")

    def forward(self, x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``.

        Fixes relative to the previous implementation:
        * reshapes use ``out_channels`` (the width after the linear layer)
          rather than the input width, so ``in_channels != out_channels`` works;
        * the tensor is no longer transposed to channel-last and then reshaped
          as if it were channel-first, which scrambled token/channel positions;
        * LayerNorm is applied over the channel axis.
        """
        _, N, _ = x.shape
        C_out = self.out_channels

        x = self.linear_td(x)                                        # TB, N, C_out

        if self._norm_channels_first:
            # BatchNorm1d accepts (B, C) or (B, C, L) inputs.
            x = self.bn_td(x.transpose(-1, -2).contiguous())         # TB, C_out, N
            x = x.transpose(-1, -2)                                  # TB, N, C_out
        else:
            x = self.bn_td(x)                                        # TB, N, C_out

        # Expose the time axis for the multi-step neuron, then fold it back.
        x = self.lif_td(x.reshape(self.T, -1, N, C_out))             # T, B, N, C_out
        return x.reshape(-1, N, C_out).contiguous()                  # TB, N, C_out
    


class Decoder2(nn.Module):
    """Token-wise decoder: ``norm -> Linear -> optional spiking neuron``.

    Unlike ``Decoder1`` the normalisation runs *before* the linear layer, so
    it operates on ``in_channels`` features.

    Args:
        in_channels: Feature width of the input tokens.
        out_channels: Feature width of the output tokens.
        T: Number of time steps folded into the leading ``T*B`` axis.
        lif_mode: ``'lif'`` or ``'plif'`` selects the spiking node; any other
            value falls back to ``nn.Identity``.
        norm_layer_td: ``'batch'`` (BatchNorm1d), ``'layer'`` (LayerNorm) or
            ``'none'`` (Identity).

    Raises:
        ValueError: If ``norm_layer_td`` is not one of the supported values.
    """

    def __init__(self, in_channels, out_channels, T, lif_mode='none', norm_layer_td='batch'):
        super().__init__()

        # The norm is applied to the *input* of linear_td, so it must be sized
        # with in_channels (it used to be sized with out_channels, which only
        # worked when the two widths were equal).
        if norm_layer_td == 'batch':
            print("using batch")
            self.bn_td = nn.BatchNorm1d(in_channels)
        elif norm_layer_td == 'layer':
            print("using layer")
            self.bn_td = nn.LayerNorm(in_channels, eps=1e-06)
        elif norm_layer_td == 'none':
            print("do not use norm layer")
            self.bn_td = nn.Identity()
        else:
            # Previously printed a message and called exit(0), which ended the
            # whole process with a *success* status on a configuration error.
            raise ValueError(
                f"no available norm_layer: {norm_layer_td!r} "
                "(expected 'batch', 'layer' or 'none')"
            )
        # BatchNorm1d wants channels on dim 1; LayerNorm/Identity work on the
        # channel-last layout directly.
        self._norm_channels_first = norm_layer_td == 'batch'

        self.out_channels = out_channels
        self.linear_td = nn.Linear(in_channels, out_channels, bias=False)

        self.T = T
        print("using Decoder2")
        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True, backend='torch')
        else:
            print("no available lif mode")
            self.lif_td = nn.Identity()

    def forward(self, x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``."""
        _, N, _ = x.shape

        if self._norm_channels_first:
            # BatchNorm1d accepts (B, C) or (B, C, L) inputs.
            x = self.bn_td(x.transpose(-1, -2).contiguous())     # TB, C_in, N
            x = x.transpose(-1, -2)                              # TB, N, C_in
        else:
            x = self.bn_td(x)                                    # TB, N, C_in

        # Reshape with out_channels: that is the width after linear_td.
        x = self.linear_td(x).reshape(self.T, -1, N, self.out_channels)  # T, B, N, C_out

        x = self.lif_td(x)                                       # T, B, N, C_out

        return x.flatten(0, 1).contiguous()                      # TB, N, C_out
    



class Decoder_sdt(nn.Module):
    """Decoder built as ``spiking neuron -> 1x1 Conv2d -> BatchNorm2d``.

    Args:
        in_channels: Feature width of the input tokens.
        out_channels: Feature width of the output tokens.
        T: Number of time steps folded into the leading ``T*B`` axis.
        lif_mode: ``'lif'`` or ``'plif'`` selects the spiking node; any other
            value falls back to ``nn.Identity``.
        norm_layer_td: Accepted for signature parity with the other decoders;
            currently unused (``BatchNorm2d`` is always created).
    """

    def __init__(self, in_channels, out_channels, T, lif_mode='lif', norm_layer_td='batch'):
        super().__init__()

        self.conv_td = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.bn_td = nn.BatchNorm2d(out_channels)

        self.out_channels = out_channels
        self.T = T

        print("using Decoder_sdt")
        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True, backend='torch')
        else:
            print("no available lif mode")
            self.lif_td = nn.Identity()

    def forward(self, x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``.

        The token axis is laid out as an ``H x W`` grid for the 2-D layers.
        The grid used to be hard-coded to 8x8 (``N == 64`` only); it is now
        derived from ``N``: a square grid when ``N`` is a perfect square
        (so ``N == 64`` still yields 8x8), otherwise ``N x 1``.  The
        convolution has a 1x1 kernel, so the grid shape does not affect
        which values are combined.
        """
        TB, N, C = x.shape
        B = TB // self.T

        side = int(round(N ** 0.5))
        H, W = (side, side) if side * side == N else (N, 1)

        x = x.reshape(self.T, B, N, C).transpose(-1, -2)    # T, B, C, N
        x = x.reshape(self.T, B, C, H, W)                   # T, B, C, H, W

        x = self.lif_td(x)

        x = self.conv_td(x.flatten(0, 1))                   # TB, C_out, H, W

        # Use out_channels here: the conv may have changed the channel width.
        x = self.bn_td(x).reshape(TB, self.out_channels, N)  # TB, C_out, N
        return x.transpose(-1, -2).contiguous()              # TB, N, C_out



def TET_loss(outputs, labels, criterion):
    """Average ``criterion`` over the leading (time) axis of ``outputs``.

    Args:
        outputs: Tensor whose first dimension indexes time steps.
        labels: Targets passed unchanged to ``criterion`` at every step.
        criterion: Callable ``criterion(step_output, labels)`` returning a loss.

    Returns:
        The mean of the per-step losses.
    """
    num_steps = outputs.size(0)
    per_step_total = sum(
        criterion(outputs[step, ...], labels) for step in range(num_steps)
    )
    return per_step_total / num_steps



def LTS_loss(outputs, labels, criterion):
    """Apply ``criterion`` to the last entry of ``outputs`` only.

    Args:
        outputs: Indexable sequence/tensor; only ``outputs[-1]`` is used.
        labels: Targets passed unchanged to ``criterion``.
        criterion: Callable ``criterion(output, labels)`` returning a loss.
    """
    final_output = outputs[-1]
    return criterion(final_output, labels)


def compute_mutual_info_matrix(output_list, bins=20):
    """Build the symmetric matrix of pairwise mutual information.

    Args:
        output_list: Sequence of tensors; each is flattened before comparison.
        bins: Histogram bin count forwarded to ``mutual_info``.

    Returns:
        A ``(len(output_list), len(output_list))`` tensor whose ``[i, j]``
        entry is ``mutual_info`` of the i-th and j-th flattened outputs.
    """
    flattened = [out.flatten() for out in output_list]
    size = len(flattened)
    matrix = torch.zeros((size, size))

    # Only the upper triangle (diagonal included) is evaluated; each score is
    # mirrored into the lower triangle.
    for row, first in enumerate(flattened):
        for col in range(row, size):
            score = mutual_info(first, flattened[col], bins)
            matrix[row, col] = score
            matrix[col, row] = score

    return matrix

def mutual_info(x, y, bins):
    """Estimate the mutual information between two tensors via a 2-D histogram.

    Args:
        x: Tensor of samples; flattened to 1-D before binning.
        y: Tensor of samples with the same number of elements as ``x``.
        bins: Bin specification forwarded to ``numpy.histogram2d``.

    Returns:
        The value returned by sklearn's ``mutual_info_score`` for the joint
        histogram of ``x`` and ``y``.
    """
    # detach(): Tensor.numpy() raises on tensors that require grad.
    # ravel(): histogram2d needs 1-D sample arrays.
    x = x.detach().cpu().numpy().ravel()
    y = y.detach().cpu().numpy().ravel()

    # The joint histogram acts as the contingency table, so the label
    # arguments of mutual_info_score are left as None.
    joint_counts = np.histogram2d(x, y, bins=bins)[0]
    return mutual_info_score(None, None, contingency=joint_counts)



