import scipy.stats
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from torch.nn import functional as F
from typing import List
from spikingjelly.clock_driven.neuron import MultiStepLIFNode,MultiStepParametricLIFNode
from functools import partial
from timm.models.layers import trunc_normal_   #timm==0.5.4
import numpy as np
import scipy
from spikingjelly.clock_driven import functional
from collections import Counter
 
class Betascheduler:
    """Linearly anneal a scalar ``beta`` from ``initial`` to ``final``.

    Each call to :meth:`step` moves ``beta`` by
    ``(final - initial) / total_epochs``.  After ``total_epochs`` steps the
    value is pinned to ``final`` so that extra calls can no longer overshoot
    the target.

    Args:
        initial: Starting value of beta.
        final: Value of beta once the schedule has finished.
        total_epochs: Number of steps over which beta is annealed.  A value
            ``<= 0`` makes the first ``step()`` jump straight to ``final``.
    """

    def __init__(self,initial,final,total_epochs):
        self.initial = initial
        self.final = final
        self.total_epochs = total_epochs
        # NOTE: the attribute name (with three 'r') is kept unchanged because
        # code outside this class may read it.
        self.currrent_epochs = 0
        self.beta = initial

    def step(self):
        """Advance the schedule by one epoch."""
        if self.currrent_epochs < self.total_epochs:
            self.beta += (self.final - self.initial) * (1 / self.total_epochs)
        self.currrent_epochs += 1

        # Pin to the exact target once the schedule is over: this removes
        # accumulated floating-point drift and prevents overshooting when
        # step() is called more than total_epochs times.
        if self.currrent_epochs >= self.total_epochs:
            self.beta = self.final

    def get(self):
        """Return the current value of beta."""
        return self.beta



class Decoder_MLP(nn.Module):
    """Two-layer MLP decoder: Linear -> BatchNorm1d -> LIF, applied twice.

    The first stage maps ``in_channels`` to ``hidden_channels`` and the second
    maps back to ``in_channels``.

    Args:
        in_channels: Size of the last dimension of the input (and output).
        hidden_channels: Width of the hidden layer.
        T: Number of time steps folded into the leading ``T*B`` dimension.
        lif_mode: Accepted for signature parity with the other decoders;
            not used by this class.
        norm_layer_td: Accepted for signature parity with the other decoders;
            not used by this class.
    """

    def __init__(self,in_channels,hidden_channels,T,lif_mode ='lif',norm_layer_td = 'batch'):
        super().__init__()

        self.T = T
        self.hidden = hidden_channels

        self.linear_1 = nn.Linear(in_channels,hidden_channels,bias=False)
        self.bn_1 = nn.BatchNorm1d(hidden_channels)
        self.lif_1 =  MultiStepLIFNode(tau = 2.0,detach_reset = True,backend = 'torch')

        self.linear_2 = nn.Linear(hidden_channels,in_channels,bias=False)
        self.bn_2 = nn.BatchNorm1d(in_channels)
        self.lif_2 =  MultiStepLIFNode(tau = 2.0,detach_reset = True,backend = 'torch')

    def forward(self,x):
        """Decode ``x`` of shape ``(T*B, N, C)`` into the same shape."""
        _ ,N,C = x.shape

        # BatchNorm1d accepts (B, C) or (B, C, L) input, so channels are moved
        # in front of the token axis for normalization only.
        x = self.linear_1(x).transpose(-1,-2).contiguous()  # TB H N
        x = self.bn_1(x).transpose(-1,-2).reshape(self.T,-1,N,self.hidden).contiguous()  # T B N H

        # The LIF output is channels-last, so merge (T, B) with flatten.
        # Reshaping it to (-1, H, N) would reinterpret the memory layout and
        # scramble token/channel entries before the next linear layer.
        x = self.lif_1(x).flatten(0,1)  # TB N H

        x = self.linear_2(x).transpose(-1,-2).contiguous()  # TB C N
        x = self.bn_2(x).transpose(-1,-2).reshape(self.T,-1,N,C).contiguous()  # T B N C

        x = self.lif_2(x).flatten(0,1).contiguous()  # TB N C

        return x














class Decoder1(nn.Module):
    """Top-down decoder: Linear -> (optional norm) -> spiking neuron.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
        T: Number of time steps folded into the leading ``T*B`` dimension.
        lif_mode: ``'lif'`` or ``'plif'``.  Any other value disables the
            spiking neuron (``nn.Identity``), matching the sibling decoders.
        norm_layer_td: ``'batch'``, ``'layer'`` or ``'none'``.

    Raises:
        ValueError: If ``norm_layer_td`` is not one of the supported names.
    """

    def __init__(self,in_channels,out_channels,T,lif_mode ='lif',norm_layer_td = 'none'):
        super().__init__()

        if norm_layer_td == 'batch':
            print("using batch")
            self.bn_td = nn.BatchNorm1d(out_channels)
        elif norm_layer_td == 'layer':
            print("using layer")
            self.bn_td = nn.LayerNorm(out_channels,eps=1e-06)
        elif norm_layer_td == 'none':
            print("do not use norm layer")
            self.bn_td = nn.Identity()
        else:
            # A bad configuration is an error; exiting with status 0 would
            # report success to the caller of the process.
            raise ValueError(f"no available norm_layer: {norm_layer_td!r}")

        self.out_channels = out_channels
        self.linear_td = nn.Linear(in_channels,out_channels,bias=False)

        self.T = T

        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau = 2.0,detach_reset = True,backend = 'torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0,detach_reset=True,backend='torch')
        else:
            print("no available lif mode")
            # Without this fallback forward() fails with AttributeError.
            self.lif_td = nn.Identity()

    def forward(self,x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``."""
        _ ,N,_ = x.shape

        x = self.linear_td(x)  # TB N C_out

        if isinstance(self.bn_td, nn.LayerNorm):
            # LayerNorm normalizes the last dimension, i.e. the channels.
            x = self.bn_td(x)
        else:
            # BatchNorm1d accepts (B, C) or (B, C, L) input: put channels
            # first for the norm, then restore the channels-last layout.
            x = self.bn_td(x.transpose(-1,-2).contiguous()).transpose(-1,-2)

        # Use out_channels (not the input width) so that in != out works.
        x = x.reshape(self.T,-1,N,self.out_channels).contiguous()  # T B N C_out

        # The neuron output is channels-last, so merge (T, B) with flatten.
        # Reshaping it to (-1, C, N) would reinterpret the memory layout and
        # scramble token/channel entries.
        x = self.lif_td(x).flatten(0,1).contiguous()  # TB N C_out
        return x
    


class Decoder2(nn.Module):
    """Top-down decoder: (optional norm) -> Linear -> spiking neuron.

    Unlike a post-norm design, the normalization here is applied to the
    *input* features before the linear projection.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
        T: Number of time steps folded into the leading ``T*B`` dimension.
        lif_mode: ``'lif'`` or ``'plif'``; any other value disables the
            spiking neuron (``nn.Identity``).
        norm_layer_td: ``'batch'``, ``'layer'`` or ``'none'``.

    Raises:
        ValueError: If ``norm_layer_td`` is not one of the supported names.
    """

    def __init__(self,in_channels,out_channels,T,lif_mode ='lif',norm_layer_td = 'batch'):
        super().__init__()

        # The norm runs before linear_td, so it must be sized for the input
        # features.  This is identical to the previous sizing whenever
        # in_channels == out_channels, the only case that used to work.
        if norm_layer_td == 'batch':
            print("using batch")
            self.bn_td = nn.BatchNorm1d(in_channels)
        elif norm_layer_td == 'layer':
            print("using layer")
            self.bn_td = nn.LayerNorm(in_channels,eps=1e-06)
        elif norm_layer_td == 'none':
            print("do not use norm layer")
            self.bn_td = nn.Identity()
        else:
            # A bad configuration is an error; exiting with status 0 would
            # report success to the caller of the process.
            raise ValueError(f"no available norm_layer: {norm_layer_td!r}")

        self.out_channels = out_channels
        self.linear_td = nn.Linear(in_channels,out_channels,bias=False)

        self.T = T
        print("using Decoder2")
        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau = 2.0,detach_reset = True,backend = 'torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0,detach_reset=True,backend='torch')
        else:
            print("no available lif mode")
            self.lif_td = nn.Identity()

    def forward(self,x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``."""
        _ ,N,_ = x.shape

        if isinstance(self.bn_td, nn.LayerNorm):
            # LayerNorm normalizes the last dimension, i.e. the channels.
            x = self.bn_td(x)  # TB N C_in
        else:
            # BatchNorm1d accepts (B, C) or (B, C, L) input: put channels
            # first for the norm, then restore the channels-last layout.
            x = self.bn_td(x.transpose(-1,-2).contiguous()).transpose(-1,-2)  # TB N C_in

        # Reshape with out_channels: the projection changes the channel count
        # whenever in_channels != out_channels.
        x = self.linear_td(x).reshape(self.T,-1,N,self.out_channels).contiguous()  # T B N C_out

        x = self.lif_td(x).contiguous()  # T B N C_out

        x = x.flatten(0,1)   # TB N C_out

        return x
    



class Decoder_sdt(nn.Module):
    """Convolutional top-down decoder: spiking neuron -> 1x1 Conv2d -> BatchNorm2d.

    The ``N`` tokens of the input are laid out on a square ``H x W`` grid
    (``H == W == sqrt(N)``) for the 2-D convolution and flattened back
    afterwards.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
        T: Number of time steps folded into the leading ``T*B`` dimension.
        lif_mode: ``'lif'`` or ``'plif'``; any other value disables the
            spiking neuron (``nn.Identity``).
        norm_layer_td: Accepted for signature parity with the other decoders;
            not used by this class.
    """

    def __init__(self,in_channels,out_channels,T,lif_mode ='lif',norm_layer_td = 'batch'):
        super().__init__()

        self.conv_td = nn.Conv2d(in_channels,out_channels, kernel_size=1, stride=1)

        self.bn_td = nn.BatchNorm2d(out_channels)

        self.out_channels = out_channels
        self.T = T

        print("using Decoder_sdt")
        if lif_mode == 'lif':
            print("using lif")
            self.lif_td = MultiStepLIFNode(tau = 2.0,detach_reset = True,backend = 'torch')
        elif lif_mode == 'plif':
            print("using plif")
            self.lif_td = MultiStepParametricLIFNode(init_tau=2.0,detach_reset=True,backend='torch')
        else:
            print("no available lif mode")
            self.lif_td = nn.Identity()

    def forward(self,x):
        """Map ``x`` of shape ``(T*B, N, in_channels)`` to ``(T*B, N, out_channels)``.

        Raises:
            ValueError: If ``N`` is not a perfect square.
        """
        TB,N,C = x.shape

        # Derive the grid from the token count instead of assuming 8x8;
        # N == 64 still yields H = W = 8.
        side = int(round(N ** 0.5))
        if side * side != N:
            raise ValueError(f"Decoder_sdt expects a square number of tokens, got N={N}")
        H,W = side,side

        x = x.reshape(self.T,TB//self.T,N,C).transpose(-1,-2)   # T B C N
        x = x.reshape(self.T,TB//self.T,C,H,W)          # T B C H W

        x = self.lif_td(x)

        x = self.conv_td(x.flatten(0, 1))       # TB C_out H W

        # The convolution changes the channel count, so fold the grid back
        # using out_channels rather than the input width C.
        x = self.bn_td(x).reshape(TB,self.out_channels,N).transpose(-1,-2).contiguous()   # TB N C_out

        return x


# is a hook function
def activation_hook(module, input, output):
    """Forward hook that prints the entropy of a module's output histogram.

    The output tensor is flattened, binned into a 100-bin density histogram,
    and the entropy of that histogram is printed together with the module's
    class name.  ``input`` is accepted to match the hook signature but is not
    used.  Nothing is returned.
    """
    # Move the activations to host memory as one flat NumPy vector.
    values = output.detach().cpu().numpy().flatten()

    # Only the bin heights are needed; the bin edges are discarded.
    density = np.histogram(values, bins=100, density=True)[0]

    entropy = scipy.stats.entropy(density)

    print("The entropy of ",type(module).__name__,"is",entropy)

    return



def TET_loss(outputs, labels, criterion):
    """Return the criterion averaged over the leading (time) dimension.

    ``criterion(outputs[t], labels)`` is evaluated for every index ``t`` along
    dimension 0 of ``outputs`` and the results are averaged.
    """
    num_steps = outputs.size(0)
    per_step = (criterion(outputs[step, ...], labels) for step in range(num_steps))
    return sum(per_step) / num_steps



def LTS_loss(outputs, labels, criterion):
    """Return the criterion evaluated on the last entry of ``outputs`` only."""
    final_step = outputs[-1]
    return criterion(final_step, labels)


def spike_pattern(model,device,loader_eval,validate,args,amp_autocast,
                  layer_name="stage3.1.mlp.mlp2_lif",T=4):
    """Print statistics of the binary spike patterns emitted by one layer.

    A forward hook records the output of the module called ``layer_name``
    while ``validate`` runs the model.  The first recorded output is reshaped
    to ``(T, -1)`` and every column is counted as one length-``T`` pattern.
    The count and ratio of each of the ``2**T`` possible 0/1 patterns are
    printed.

    This is a terminal diagnostic: the process exits with status 0 once the
    statistics have been printed.

    Args:
        model: Network to inspect; must provide ``named_modules()``.
        device: Unused; kept for signature compatibility.
        loader_eval: Passed through to ``validate``.
        validate: Callable invoked as
            ``validate(model, loader_eval, criterion, args, amp_autocast=..., td=...)``.
        args: Namespace providing ``top_down``.
        amp_autocast: Passed through to ``validate``.
        layer_name: Qualified name of the module to hook.
        T: Pattern length used for the reshape.

    Raises:
        RuntimeError: If the hooked layer produced no output (for example
            because ``layer_name`` does not match any module).
        SystemExit: Always, with code 0, after the report is printed.
    """
    captured = []

    def direct_coding_hook(module, inp, out):
        captured.append(out.data.cpu().numpy())

    handles = []
    for name, layer in model.named_modules():
        print(name)
        if name == layer_name:  # the name must match exactly
            handles.append(layer.register_forward_hook(direct_coding_hook))
            print(f"Hook registered on: {name}")

    criterion = torch.nn.CrossEntropyLoss()

    # switch to evaluation mode
    model.eval()

    try:
        validate(model, loader_eval, criterion, args, amp_autocast=amp_autocast,
                 td=args.top_down)
    finally:
        # Do not leave the hook attached to the model.
        for handle in handles:
            handle.remove()

    if not captured:
        raise RuntimeError(
            f"no output captured from layer {layer_name!r}; "
            "check that the name matches a module that is executed"
        )

    # NOTE(review): only the first slice of the first recorded output is
    # analysed, and it is assumed to hold T * K values - confirm this matches
    # the layout of the hooked layer's output.
    first = captured[0][0]
    print(first.shape)
    patterns = first.reshape(T, -1).T  # shape becomes (K, T)
    counter = Counter(map(tuple, patterns))

    # All 2**T possible 0/1 patterns, in ascending binary order.
    possible_combinations = [
        tuple(int(bit) for bit in format(i, f"0{T}b")) for i in range(2 ** T)
    ]

    combination_counts = {comb: counter.get(comb, 0) for comb in possible_combinations}
    total_count = sum(combination_counts.values())

    for comb, count in combination_counts.items():
        # Guard against data that contains no purely binary pattern.
        ratio = count / total_count if total_count else 0.0
        print(f"Combination: {comb}, Count: {count}, Ratio: {ratio:.6f}")

    # Equivalent to exit(0) without depending on the site-provided builtin.
    raise SystemExit(0)
        
        
        
        
def calculate_firing_rate(td):
    """Return the fraction of elements of ``td`` that are strictly positive.

    Returns ``None`` when ``td`` is ``None``.
    """
    if td is None:
        return None

    # Elements greater than zero are counted as spikes.
    active = (td > 0).sum()
    return active / td.numel()
    
    
    
def check_frozen_params(model):
    """Print, for every named parameter of ``model``, whether it is trainable or frozen."""
    for name, param in model.named_parameters():
        message = f"{name} is trainable." if param.requires_grad else f"{name} is frozen."
        print(message)

# Assuming 'model' is your Feedbackstage instance


def print_tensor_variance(tensor):
    """Print the variance computed by ``torch.var`` over all elements of ``tensor``."""
    value = torch.var(tensor).item()
    print(f"The variance of the tensor is: {value}")
    
    
    
def tensor_distribution(tensor, bins=10, min_value=None, max_value=None):
    """Print a histogram of ``tensor`` over ``bins`` equal-width intervals.

    Args:
        tensor: Values to summarize; converted to ``float32`` first.
        bins: Number of intervals.
        min_value: Lower bound of the histogram; defaults to the tensor minimum.
        max_value: Upper bound of the histogram; defaults to the tensor maximum.
    """
    values = tensor.to(torch.float32)

    data_min = torch.min(values)
    data_max = torch.max(values)

    lower = data_min.item() if min_value is None else min_value
    upper = data_max.item() if max_value is None else max_value

    # Boundaries of the equal-width intervals (bins + 1 edges).
    width = (upper - lower) / bins
    edges = [lower + k * width for k in range(bins + 1)]

    # Per-interval element counts.
    counts = torch.histc(values, bins=bins, min=lower, max=upper)

    # Report every interval together with its count.
    print(f"Data distribution with {bins} intervals:")
    for left, right, count in zip(edges, edges[1:], counts.tolist()):
        print(f"Interval [{left:.2f}, {right:.2f}): {int(count)} items")

