import torch.nn as nn
import torch.nn.functional as F
import torch
# import matplotlib.pyplot as plt
import numpy as np

# Module-level sink for membrane potentials: every LIFSpike.forward that runs
# with autograd enabled appends one tensor (its membrane potentials stacked
# along dim 1) to this list. Nothing visible in this chunk clears it.
# NOTE(review): presumably the training loop consumes and clears it each
# step — otherwise it grows without bound; confirm against callers.
mem_distill =[]

class TensorNormalization(nn.Module):
    """Per-channel input normalisation module: ``(X - mean) / std``.

    The work is delegated to ``normalizex``, which broadcasts ``mean``/``std``
    as shape ``(1, C, 1, 1)``. The channel axis of ``X`` is therefore the
    third from the end (e.g. NCHW).

    Args:
        mean: per-channel means, as a tensor or anything ``torch.tensor``
            accepts (e.g. a list of floats).
        std: per-channel standard deviations, same form as ``mean``.
    """

    def __init__(self, mean, std):
        super(TensorNormalization, self).__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        # Registered as buffers (previously plain attributes). This makes
        # ``.to()`` / ``.cuda()`` / ``.half()`` on the model move and cast
        # them too, instead of ``normalizex`` copying them to the input's
        # device on every forward pass. ``persistent=False`` keeps them out
        # of ``state_dict``, so existing checkpoints still load unchanged.
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std', std, persistent=False)

    def forward(self, X):
        return normalizex(X, self.mean, self.std)

def normalizex(tensor, mean, std):
    """Return ``(tensor - mean) / std`` with per-channel statistics.

    Args:
        tensor: input whose channel axis is the third from the end
            (e.g. NCHW), since the stats are broadcast as ``(1, C, 1, 1)``.
        mean: 1-D tensor of per-channel means, or a 0-d tensor that is
            applied to every channel.
        std: per-channel standard deviations, same form as ``mean``.

    Returns:
        A new tensor; ``tensor`` itself is not modified.
    """
    # A 0-d stat becomes shape (1,), so the indexing below yields (1, 1, 1, 1)
    # and broadcasts over every channel instead of raising IndexError.
    if mean.dim() == 0:
        mean = mean.reshape(1)
    if std.dim() == 0:
        std = std.reshape(1)
    # Move each stat on its own (``.to`` is a no-op when already there); the
    # previous version only checked ``mean``'s device, so a ``std`` living on
    # a different device was never moved.
    mean = mean[None, :, None, None].to(tensor.device)
    std = std[None, :, None, None].to(tensor.device)
    return tensor.sub(mean).div(std)


class SeqToANNContainer(nn.Module):
    """Run a per-sample module independently at every time step.

    Adapted from spikingjelly (https://github.com/fangwei123456/spikingjelly).
    The first two input dimensions are merged into one batch dimension. The
    wrapped module runs once on that merged batch, and its output is then
    split back into the original two leading dimensions.

    Args:
        *args: a single module to wrap, or several modules that are chained
            with ``nn.Sequential``.
    """

    def __init__(self, *args):
        super().__init__()
        self.module = args[0] if len(args) == 1 else nn.Sequential(*args)

    def forward(self, x_seq: torch.Tensor):
        dim0, dim1 = x_seq.shape[0], x_seq.shape[1]
        merged = x_seq.flatten(0, 1).contiguous()
        result = self.module(merged)
        return result.view(dim0, dim1, *result.shape[1:])

class Layer(nn.Module):
    """Conv2d + BatchNorm2d applied per time step, followed by an LIF spike.

    Args:
        in_plane: input channels of the convolution.
        out_plane: output channels of the convolution / batch-norm width.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding):
        super(Layer, self).__init__()
        conv = nn.Conv2d(in_plane, out_plane, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_plane)
        self.fwd = SeqToANNContainer(conv, norm)
        self.act = LIFSpike()

    def forward(self, x):
        return self.act(self.fwd(x))

class APLayer(nn.Module):
    """Average pooling applied per time step, followed by an LIF spike.

    Args:
        kernel_size: pooling window passed to ``nn.AvgPool2d``.
    """

    def __init__(self, kernel_size):
        super(APLayer, self).__init__()
        pool = nn.AvgPool2d(kernel_size)
        self.fwd = SeqToANNContainer(pool)
        self.act = LIFSpike()

    def forward(self, x):
        pooled = self.fwd(x)
        return self.act(pooled)


class ZIF(torch.autograd.Function):
    """Heaviside spike function with a triangular surrogate gradient.

    Forward: ``1.0`` where ``input > 0`` else ``0.0`` (cast with ``.float()``).
    Backward: ``grad_input = grad_output * (1/gama)**2 * max(0, gama - |input|)``,
    i.e. a triangle of half-width ``gama`` centred on 0. ``gama`` itself
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, input, gama):
        out = (input > 0).float()
        # Only the pre-activation is needed for the surrogate gradient. The
        # binary output is no longer saved (it was stored but never read),
        # and ``gama`` is kept as a plain Python float on ctx instead of
        # being boxed into a CPU tensor and unboxed with ``.item()``.
        ctx.save_for_backward(input)
        ctx.gama = float(gama)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        gama = ctx.gama
        surrogate = (1 / gama) * (1 / gama) * (gama - input.abs()).clamp(min=0)
        # The product allocates a new tensor, so cloning grad_output first
        # is unnecessary.
        return grad_output * surrogate, None

# class LIFSpike(nn.Module):    # original
#     def __init__(self, thresh=1.0, tau=0.5, gama=1.0):
#         super(LIFSpike, self).__init__()
#         self.act = ZIF.apply
#         self.thresh = thresh
#         self.tau = tau
#         self.gama = gama
#
#     def forward(self, x):
#         mem = 0
#         spike_pot = []
#         T = x.shape[1]
#
#         for t in range(T):
#             mem = mem * self.tau + x[:, t, ...]
#             spike = self.act(mem - self.thresh, self.gama)
#             # spike = self.act((mem - self.thresh)*self.k)
#             mem = (1 - spike) * mem
#             spike_pot.append(spike)
#             out = torch.stack(spike_pot, dim=1)
#         # print(out.mean())
#         return out

class LIFSpike(nn.Module):
    """Leaky integrate-and-fire spiking activation unrolled over dim 1.

    For each step ``t`` it reads ``x[:, t, ...]`` and updates:

        mem   = mem * tau + x[:, t]
        spike = ZIF(mem - thresh, gama)   # 1.0 where mem > thresh
        mem   = (1 - spike) * mem         # reset to 0 where a spike fired

    Args:
        thresh: firing threshold.
        tau: factor multiplying the membrane potential at each step.
        gama: surrogate-gradient width forwarded to ``ZIF``.

    ``forward`` returns the spikes stacked along dim 1, with the same shape
    as ``x``. Side effect: when ``torch.is_grad_enabled()`` is true, the
    pre-reset membrane potentials (stacked along dim 1) are appended to the
    module-level ``mem_distill`` list.
    """

    def __init__(self, thresh=1.0, tau=0.5, gama=1.0):
        super(LIFSpike, self).__init__()
        self.act = ZIF.apply
        self.thresh = thresh
        self.tau = tau
        self.gama = gama

    def forward(self, x):
        # Grad mode, not ``self.training``, decides whether membrane
        # potentials are recorded (the original called this ``is_training``).
        record = torch.is_grad_enabled()
        mem = 0
        spike_pot = []
        mem_recod = []
        for t in range(x.shape[1]):
            mem = mem * self.tau + x[:, t, ...]
            if record:
                # ``mem`` is a freshly allocated tensor each step, and nothing
                # below modifies it in place, so no ``clone()`` is needed.
                mem_recod.append(mem)
            spike = self.act(mem - self.thresh, self.gama)
            mem = (1 - spike) * mem
            spike_pot.append(spike)
        if record:
            # Mutating (not rebinding) the module-level list needs no
            # ``global`` statement.
            mem_distill.append(torch.stack(mem_recod, dim=1))
        return torch.stack(spike_pot, dim=1)


def add_dimention(x, T):
    """Insert a time axis at dim 1 and repeat the input ``T`` times along it.

    Args:
        x: input tensor. The hard-coded ``repeat(1, T, 1, 1, 1)`` assumes a
            4-D input such as ``(B, C, H, W)``, which yields ``(B, T, C, H, W)``.
        T: number of time steps.

    Returns:
        A new tensor in which every ``[:, t]`` slice equals ``x``. The
        caller's ``x`` is left unchanged.
    """
    # Out-of-place unsqueeze: the previous ``x.unsqueeze_(1)`` reshaped the
    # caller's tensor as a hidden side effect.
    return x.unsqueeze(1).repeat(1, T, 1, 1, 1)


# ----- For ResNet19 code -----




def plot_tensor_histogram(tensor, bins=100, title='Tensor Distribution'):
    """Show a density histogram of every value in ``tensor``.

    The tensor is flattened, detached and copied to CPU/NumPy. The plot title
    shows the value range, and a text box shows mean, std, min and max. Ends
    with ``plt.show()``.

    Args:
        tensor: torch tensor of any shape.
        bins: number of histogram bins.
        title: title text; the value range is appended on a second line.
    """
    # Imported lazily: the module-level matplotlib import is commented out,
    # so ``plt`` was an undefined name and every call raised NameError. A
    # local import keeps matplotlib optional for code that never plots.
    import matplotlib.pyplot as plt

    # Flatten the tensor into a 1-D NumPy array.
    flattened = tensor.flatten().detach().cpu().numpy()

    plt.figure(figsize=(10, 6))
    plt.hist(flattened, bins=bins, density=True, alpha=0.7)

    # Value range, shown in the title and in the stats box.
    min_val = np.min(flattened)
    max_val = np.max(flattened)

    plt.title(f'{title}\nRange: [{min_val:.4f}, {max_val:.4f}]')
    plt.xlabel('Value')
    plt.ylabel('Density')

    # Summary statistics box in the top-left corner (axes coordinates).
    mean = np.mean(flattened)
    std = np.std(flattened)
    plt.text(0.02, 0.95,
             f'Mean: {mean:.4f}\nStd: {std:.4f}\n'
             f'Min: {min_val:.4f}\nMax: {max_val:.4f}',
             transform=plt.gca().transAxes,
             bbox=dict(facecolor='white', alpha=0.8))

    plt.grid(True, alpha=0.3)

    plt.show()


def plot_channel_histograms(tensor, channel_indices=None):
    """Plot a separate value histogram for each selected channel.

    Channels are taken along dim 2, and each one is sliced as
    ``tensor[:, :, idx, :, :]``, so ``tensor`` is expected to be 5-D
    (presumably batch, time, channel, H, W — TODO confirm).

    Args:
        tensor: tensor to inspect.
        channel_indices: iterable of channel indices to plot. ``None`` means
            every channel ``0 .. tensor.shape[2] - 1``.
    """
    selected = range(tensor.shape[2]) if channel_indices is None else channel_indices
    for idx in selected:
        plot_tensor_histogram(
            tensor[:, :, idx, :, :],
            title=f'Channel {idx} Distribution',
        )
class tdLayer(nn.Module):
    """Run ``layer`` per time step, then optionally apply ``bn``.

    Args:
        layer: module wrapped in ``SeqToANNContainer``, so it sees the first
            two input dimensions merged into one.
        bn: optional module applied to the wrapped layer's output as-is
            (it is not wrapped); skipped when ``None``.
    """

    def __init__(self, layer, bn=None):
        super(tdLayer, self).__init__()
        self.layer = SeqToANNContainer(layer)
        self.bn = bn

    def forward(self, x):
        out = self.layer(x)
        return out if self.bn is None else self.bn(out)


class tdBatchNorm(nn.Module):
    """``BatchNorm2d`` over ``out_panel`` channels, applied per time step.

    The same ``BatchNorm2d`` instance is attached twice: as ``bn`` and as
    ``seqbn.module`` (inside the ``SeqToANNContainer``). Removing either
    attribute would change the module tree and hence the checkpoint keys.
    """

    def __init__(self, out_panel):
        super(tdBatchNorm, self).__init__()
        norm = nn.BatchNorm2d(out_panel)
        self.bn = norm
        self.seqbn = SeqToANNContainer(norm)

    def forward(self, x):
        return self.seqbn(x)