import torch
import torch.nn as nn
from stork.connections import BaseConnection, Connection, ConvConnection, Conv2dConnection
from stork.layers import AbstractLayer
from stork import utils
import stork.nodes as nodes
from torch.nn.parameter import Parameter
import numpy as np
from stork import core
from stork import constraints as stork_constraints
import torch.autograd as autograd
from torch.autograd import Function
from .RepConv1d import *



class Connection_withBatchNorm(BaseConnection):
    """Connection ``src -> dst`` through a learnable op followed by normalization.

    The op (``nn.Linear`` by default) maps the presynaptic output onto the
    ``target`` state of ``dst``. Unless the op is a ``RepClassModule`` (which
    handles its own normalization), its output is normalized by a
    ``BatchNorm1d`` (over units or channels, see ``bn_type``) or by a
    ``LayerNorm`` (``bn2ln=True``).

    Args:
        src: presynaptic group; ``src.out`` is read on every forward pass.
        dst: postsynaptic group; the result is passed to ``dst.add_to_state``.
        operation: module class used to build the op.
        target: name of the ``dst`` state the result is added to.
        bias: whether the op gets a bias term.
        requires_grad: ``requires_grad`` for op and norm parameters.
        propagate_gradients: if False, ``src.out`` is detached.
        flatten_input: flatten the input to ``(batch, -1)`` before the op.
        name, regularizers, constraints: forwarded to ``BaseConnection``.
        row: map all ``src`` units to all ``dst`` units, then reshape the
            output to ``(batch, dst.shape[0], dst.shape[1])``.
        common_linear: optional callable applied to the input before the op.
        session_encode_linear: optional callable whose output (computed from
            ``session_encode_input.out``) is added to the op input.
        session_encode_input: group providing the session code.
        adaptive_bn: output ``a * norm(x) + (1 - a) * x`` with a learnable
            scalar ``a`` initialised to 0.5.
        bn_type: ``"units"`` -> ``BatchNorm1d(dst.nb_units)``, applied before
            the row reshape; ``"channel"`` -> ``BatchNorm1d(dst.shape[0])``,
            applied after it.
        bn2ln: use ``LayerNorm(dst.shape[-1])`` instead of batch norm.
        shortcut: add the (preprocessed) input back onto the output; the
            input must then have the same number of elements as the output.
        **kwargs: forwarded to ``operation``.

    Raises:
        ValueError: if batch norm is required and ``bn_type`` is neither
            ``"units"`` nor ``"channel"``.
    """

    def __init__(
        self,
        src,
        dst,
        operation=nn.Linear,
        target=None,
        bias=False,
        requires_grad=True,
        propagate_gradients=True,
        flatten_input=False,
        name=None,
        regularizers=None,
        constraints=None,
        row=False,
        common_linear=None,
        session_encode_linear=None,
        session_encode_input=None,
        adaptive_bn=False,
        bn_type="channel",
        bn2ln=False,
        shortcut=False,
        **kwargs
    ):
        super(Connection_withBatchNorm, self).__init__(
            src,
            dst,
            name=name,
            target=target,
            regularizers=regularizers,
            constraints=constraints,
        )

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        self.flatten_input = flatten_input
        self.row = row
        self.adaptive_bn = adaptive_bn
        self.bn_type = bn_type
        self.bn2ln = bn2ln
        self.shortcut = shortcut

        # Build the op; each branch fixes the input/output feature sizes.
        if operation == RepVGGLinearBlock1d:
            self.flatten_input = True
            self.op = operation(src.nb_units, dst.nb_units, dst.shape[0], bias=bias, **kwargs)
        elif self.row:
            # All src units -> all dst units; forward reshapes to output_shape.
            self.flatten_input = True
            self.op = operation(src.nb_units, dst.nb_units, bias=bias, **kwargs)
            self.output_shape = (dst.shape[0], dst.shape[1])
        elif flatten_input:
            assert len(dst.shape) == 1, "flatten_input=True requires a 1-D dst shape"
            self.op = operation(src.nb_units, dst.shape[0], bias=bias, **kwargs)
        else:
            self.op = operation(src.shape[0], dst.shape[0], bias=bias, **kwargs)

        # Normalization layer (RepClassModule ops carry their own).
        if isinstance(self.op, RepClassModule):
            self.bn = None
        elif self.bn2ln:
            self.bn = nn.LayerNorm(dst.shape[-1])
        elif self.bn_type == "units":
            self.bn = nn.BatchNorm1d(dst.nb_units)
        elif self.bn_type == "channel":
            self.bn = nn.BatchNorm1d(dst.shape[0])
        else:
            # Previously self.bn was left unset and the constructor failed
            # later with an AttributeError.
            raise ValueError(
                f"Unknown bn_type {self.bn_type!r}; expected 'units' or 'channel'."
            )

        for param in self.op.parameters():
            param.requires_grad = requires_grad
        if self.bn is not None:
            for param in self.bn.parameters():
                param.requires_grad = requires_grad

        self.common_linear = common_linear if common_linear is not None else False
        if session_encode_linear is not None:
            self.session_encode_linear = session_encode_linear
            self.session_encode_input = session_encode_input
        else:
            self.session_encode_linear = False
            self.session_encode_input = False

        if self.adaptive_bn:
            # Learnable mixing weight between normalized and raw output.
            self.a = nn.Parameter(0.5 * torch.ones(1))

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def add_diagonal_structure(self, width=1.0, ampl=1.0):
        """Add a Gaussian band ``ampl * exp(-(x - i)^2 / width^2)`` to row ``i`` of the weight.

        Raises:
            ValueError: if the op is not an ``nn.Linear``.
        """
        if not isinstance(self.op, nn.Linear):
            raise ValueError("Expected op to be nn.Linear to add diagonal structure.")
        weight = self.op.weight
        rows, cols = weight.shape
        x = np.linspace(0, rows, cols)
        band = ampl * np.exp(-((x[None, :] - np.arange(rows)[:, None]) ** 2) / width**2)
        # Match the weight's device/dtype so this also works for CUDA weights.
        weight.data += torch.from_numpy(band).to(device=weight.device, dtype=weight.dtype)

    def get_weights(self):
        return self.op.weight

    def get_regularizer_loss(self):
        reg_loss = torch.tensor(0.0, device=self.device)
        for reg in self.regularizers:
            reg_loss += reg(self.get_weights())
        return reg_loss

    def _apply_norm(self, out):
        """Apply ``self.bn`` to ``out``, blended via ``a`` when ``adaptive_bn`` is set."""
        if self.bn is None:
            return out
        if self.adaptive_bn:
            return self.bn(out) * self.a + out * (1 - self.a)
        return self.bn(out)

    def forward(self):
        """Run op and normalization on ``src.out`` and add the result to ``dst``."""
        preact = self.src.out
        if not self.propagate_gradients:
            preact = preact.detach()
        # Kept so the shortcut can be restored to the source shape whether or
        # not the input is flattened below.
        src_shape = preact.shape
        if self.flatten_input:
            preact = preact.reshape(src_shape[:1] + (-1,))

        if self.common_linear:
            preact = self.common_linear(preact)

        if self.session_encode_linear:
            session_code_preact = self.session_encode_input.out
            if self.flatten_input:
                session_shape = session_code_preact.shape
                session_code_preact = session_code_preact.reshape(session_shape[:1] + (-1,))
            # Out-of-place: preact may alias src.out (reshape can return a
            # view), and an in-place add would corrupt the source state.
            preact = preact + self.session_encode_linear(session_code_preact)

        out = self.op(preact)

        if self.bn_type == "units":
            # Unit-wise BN works on the flat (batch, nb_units) output.
            out = self._apply_norm(out)
            if self.row:
                out = out.reshape(out.shape[0], self.output_shape[0], self.output_shape[1])
        else:
            # Channel-wise BN needs the (batch, C, L) layout.
            if self.row:
                out = out.reshape(out.shape[0], self.output_shape[0], self.output_shape[1])
            out = self._apply_norm(out)

        if self.shortcut:
            out = out + preact.reshape(src_shape)

        self.dst.add_to_state(self.target, out)

    def propagate(self):
        self.forward()

    def apply_constraints(self):
        for const in self.constraints:
            const.apply(self.op.weight)

    def get_equivalent_weight_bias(self):
        """Return ``(weight, bias)`` of a single op equivalent to op + norm in eval mode.

        With ``adaptive_bn`` the fused parameters are blended with the raw op
        parameters by ``a``. This matches
        ``a * bn(Wx + b) + (1 - a) * (Wx + b)`` element-wise.
        """
        fused_weight, fused_bias = self._fuse_bn_tensor(self.op, self.bn)
        if self.adaptive_bn:
            raw_weight = self.op.weight
            if getattr(self.op, "bias", None) is not None:
                raw_bias = self.op.bias
            else:
                raw_bias = torch.zeros_like(fused_bias)
            fused_weight = self.a * fused_weight + (1 - self.a) * raw_weight
            fused_bias = self.a * fused_bias + (1 - self.a) * raw_bias
        return fused_weight, fused_bias

    def _fuse_bn_tensor(self, op, bn):
        """Fold an eval-mode ``BatchNorm1d`` into the weight/bias of a linear ``op``.

        Returns ``(weight * t, (bias - running_mean) * gamma / std + beta)``
        with ``t = gamma / std`` broadcast over the output dimension. When
        ``row`` is set and the norm has fewer features than the op outputs
        (channel BN applied after the row reshape), the BN statistics are
        repeated per element to match the row-major flattening.

        Raises:
            NotImplementedError: if ``bn`` is not batch norm (e.g. LayerNorm,
                whose statistics are input-dependent), or ``op`` is not
                ``nn.Linear``.
        """
        if not isinstance(bn, nn.BatchNorm1d):
            raise NotImplementedError(
                f"Only BatchNorm1d can be fused into the op; got {type(bn).__name__}."
            )
        weight, running_mean, running_var, gamma, beta, eps = (
            op.weight, bn.running_mean, bn.running_var, bn.weight, bn.bias, bn.eps
        )
        if hasattr(op, 'bias') and op.bias is not None:
            bias = op.bias
        else:
            bias = torch.zeros(op.weight.shape[0], device=weight.device, dtype=weight.dtype)
        std = (running_var + eps).sqrt()

        weight_shape_mismatch = weight.shape[0] != gamma.shape[0]
        bias_shape_mismatch = bias.shape[0] != beta.shape[0]

        if weight_shape_mismatch or bias_shape_mismatch:
            if self.row:
                # Output unit index is c * L + l after the row reshape, so each
                # channel statistic is repeated L times consecutively.
                repeat_factor = weight.shape[0] // gamma.shape[0]
                gamma = gamma.repeat_interleave(repeat_factor)
                beta = beta.repeat_interleave(repeat_factor)
                running_mean = running_mean.repeat_interleave(repeat_factor)
                running_var = running_var.repeat_interleave(repeat_factor)
                std = (running_var + eps).sqrt()

        if isinstance(op, nn.Linear):
            t = (gamma / std)
            if weight.dim() > 1:
                reshape_dims = [t.shape[0]] + [1] * (weight.dim() - 1)
                t = t.reshape(*reshape_dims)
        else:
            raise NotImplementedError("Unsupported operation type for batch normalization fusion")

        assert t.shape[0] == weight.shape[0] , f"权重维度不匹配: t.shape={t.shape}, weight.shape={weight.shape}"
        assert bias.shape[0] == beta.shape[0], f"偏置维度不匹配: bias.shape={bias.shape}, beta.shape={beta.shape}"

        return weight * t, ((bias - running_mean) * gamma / std) + beta

    def switch_to_deploy(self):
        """Fold the normalization into the op (eval-mode equivalent) and drop it."""
        if self.bn is None:
            return
        weight, bias = self.get_equivalent_weight_bias()
        self.op.weight = Parameter(weight.detach())
        # Replaces the existing bias, or creates one if the op had none.
        self.op.bias = Parameter(bias.detach())
        self.bn = None






# 带batchNorm卷积的connection
class ConvConnection_withBatchNorm(Connection_withBatchNorm):
    """``Connection_withBatchNorm`` whose op is a convolution (``nn.Conv1d`` by default).

    All other keyword arguments are forwarded unchanged to the parent.
    """

    def __init__(self, src, dst, conv=nn.Conv1d, **kwargs):
        super().__init__(src, dst, operation=conv, **kwargs)

# 不带batchNorm卷积的connection
class Channel1dConvConnection(Connection):
    """Plain ``Connection`` (no batch norm) whose op is a convolution, ``nn.Conv1d`` by default."""

    def __init__(self, src, dst, conv=nn.Conv1d, **kwargs):
        super().__init__(src, dst, operation=conv, **kwargs)

class Connection_with_VS_shortcut_withBatchNorm(Connection_withBatchNorm):
    """``Connection_withBatchNorm`` plus a shortcut read from a second group.

    ``dst`` receives ``norm(op(src.out)) + s``, where ``s`` is either
    ``src_shortcut.out`` (``shortcut_opFlag`` False) or
    ``op(src_shortcut.out)`` (``shortcut_opFlag`` True, same op, not
    normalized).

    ``common_linear`` / session encoding of the parent are not applied here.
    """

    def __init__(
        self,
        src,
        src_shortcut,
        shortcut_opFlag,
        dst,
        operation=nn.Linear,
        row=False,
        target=None,
        bias=False,
        requires_grad=True,
        propagate_gradients=True,
        flatten_input=False,
        name=None,
        regularizers=None,
        constraints=None,
        **kwargs
    ):
        super(Connection_with_VS_shortcut_withBatchNorm, self).__init__(
            src,
            dst,
            operation=operation,
            row=row,
            target=target,
            bias=bias,
            requires_grad=requires_grad,
            propagate_gradients=propagate_gradients,
            flatten_input=flatten_input,
            name=name,
            regularizers=regularizers,
            constraints=constraints,
            **kwargs
        )

        self.src_shortcut = src_shortcut
        self.shortcut_opFlag = shortcut_opFlag

    def _vs_to_row_shape(self, x):
        """Reshape a flat op output to ``(batch, *output_shape)`` when ``row`` is set."""
        if self.row:
            return x.reshape(x.shape[0], self.output_shape[0], self.output_shape[1])
        return x

    def _vs_apply_norm(self, x):
        """Apply ``self.bn`` (if any), blended via ``a`` when ``adaptive_bn`` is set."""
        if self.bn is None:
            # RepClassModule ops, or after switch_to_deploy().
            return x
        if self.adaptive_bn:
            return self.bn(x) * self.a + x * (1 - self.a)
        return self.bn(x)

    def forward(self):
        preact = self.src.out
        preact_shortcut = self.src_shortcut.out
        if not self.propagate_gradients:
            preact = preact.detach()
            preact_shortcut = preact_shortcut.detach()
        if self.flatten_input:
            preact = preact.reshape(preact.shape[:1] + (-1,))

        out = self.op(preact)

        # Same ordering as the parent: unit-wise BN on the flat output,
        # channel-wise BN on the (batch, C, L) layout.
        if self.bn_type == "units":
            out = self._vs_to_row_shape(self._vs_apply_norm(out))
        else:
            out = self._vs_apply_norm(self._vs_to_row_shape(out))

        if self.shortcut_opFlag:
            # Feed the shortcut through the op exactly like the main input so
            # the result has the same shape as `out`.
            shortcut_in = preact_shortcut
            if self.flatten_input:
                shortcut_in = shortcut_in.reshape(shortcut_in.shape[:1] + (-1,))
            out = out + self._vs_to_row_shape(self.op(shortcut_in))
        else:
            out = out + preact_shortcut

        self.dst.add_to_state(self.target, out)

class Connection_with_VS_shortcut(Connection):
    """Plain ``Connection`` whose output also receives a shortcut from ``src_shortcut``.

    ``dst`` receives ``op(src.out) + s``, where ``s`` is ``src_shortcut.out``
    or, if ``shortcut_opFlag`` is set, ``op(src_shortcut.out)``. With
    ``flatten_input`` both inputs are flattened using the batch size of
    ``src.out``.
    """

    def __init__(
        self,
        src,
        src_shortcut,
        shortcut_opFlag,
        dst,
        operation=nn.Linear,
        target=None,
        bias=False,
        requires_grad=True,
        propagate_gradients=True,
        flatten_input=True,
        name=None,
        regularizers=None,
        constraints=None,
        **kwargs
    ):
        super().__init__(
            src,
            dst,
            operation=operation,
            target=target,
            bias=bias,
            requires_grad=requires_grad,
            propagate_gradients=propagate_gradients,
            flatten_input=flatten_input,
            name=name,
            regularizers=regularizers,
            constraints=constraints,
            **kwargs
        )
        self.src_shortcut = src_shortcut
        self.shortcut_opFlag = shortcut_opFlag

    def forward(self):
        main_in = self.src.out
        skip_in = self.src_shortcut.out
        if not self.propagate_gradients:
            main_in, skip_in = main_in.detach(), skip_in.detach()
        if self.flatten_input:
            flat_shape = main_in.shape[:1] + (-1,)
            main_in = main_in.reshape(flat_shape)
            skip_in = skip_in.reshape(flat_shape)

        main_out = self.op(main_in)
        skip_term = self.op(skip_in) if self.shortcut_opFlag else skip_in
        self.dst.add_to_state(self.target, main_out + skip_term)

# 拼接多个src的，恒等映射的connection, src should be a list
class Connection_identity_with_multi_src(BaseConnection):
    """Concatenate several source groups (flattened) into one target, unchanged.

    ``src`` must be a list of groups. Each ``group.out`` is reshaped to
    ``(batch, -1)`` and the pieces are concatenated along dim 1. An
    identity-initialised ``nn.Linear`` is built as ``self.op`` (for
    ``get_weights`` / constraints), but ``forward`` does not apply it.

    NOTE(review): ``propagate_gradients`` is stored but ``forward`` never
    detaches the inputs, so gradients always flow back to the sources.
    Confirm whether this is intended.
    """

    def __init__(
            self,
            src,
            dst,
            requires_grad=False,  # weights frozen by default
            propagate_gradients=False,
            flatten_input=True,
            name=None,
            regularizers=None,
            constraints=None,
            **kwargs
    ):
        super().__init__(
            src,
            dst,
            name=name,
            target=None,
            regularizers=regularizers,
            constraints=constraints,
        )

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        assert flatten_input==True, "flatten_input should be True for Connection_identity_with_multi_src"
        self.flatten_input = flatten_input

        total_units = sum(group.nb_units for group in src)

        # Square linear map initialised to the identity matrix.
        self.op = nn.Linear(total_units, total_units, bias=False)
        nn.init.eye_(self.op.weight)
        self.op.requires_grad_(requires_grad)

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def get_weights(self):
        return self.op.weight

    def get_regularizer_loss(self):
        total = torch.tensor(0.0, device=self.device)
        for regularizer in self.regularizers:
            total += regularizer(self.get_weights())
        return total

    def forward(self):
        pieces = []
        for group in self.src:
            state = group.out
            pieces.append(state.reshape(state.shape[:1] + (-1,)))
        self.dst.add_to_state(self.target, torch.cat(pieces, dim=1))

    def propagate(self):
        self.forward()

    def apply_constraints(self):
        for constraint in self.constraints:
            constraint.apply(self.op.weight)


# 单头注意力的operation
class attention_operation(nn.Module):
    """Single-head linear attention over the outputs of three groups.

    On each call it reads ``q.out``, ``k.out`` and ``v.out``, which must have
    identical shapes. They are treated as ``(batch, N)``: the final reshape
    to ``(out.shape[0], out.shape[1])`` requires 2-D inputs. It returns
    ``q * (k . v) * scale``, i.e. ``(q^T @ (k @ v^T)) * scale`` per sample.

    Args:
        q, k, v: objects exposing an ``out`` tensor (read lazily on forward).
        scale: multiplicative factor applied to the result.
        propagate_gradients: if False, the inputs are detached.
    """

    def __init__(self, q, k, v, scale=1, propagate_gradients=True):
        super(attention_operation, self).__init__()
        self.q = q
        self.k = k
        self.v = v
        self.scale = scale
        self.propagate_gradients = propagate_gradients

    def forward(self):
        # The groups are stored as self.q/k/v. The previous version read the
        # never-assigned self.src_q/k/v and always raised AttributeError.
        preact_q = self.q.out
        preact_k = self.k.out
        preact_v = self.v.out
        assert preact_q.shape == preact_k.shape == preact_v.shape, \
            "The shape of input group q, k, v must be the same."

        if not self.propagate_gradients:
            preact_q = preact_q.detach()
            preact_k = preact_k.detach()
            preact_v = preact_v.detach()

        # (batch, N) -> (batch, N, 1) column vectors.
        preact_q = preact_q.unsqueeze(2)
        preact_k = preact_k.unsqueeze(2)
        preact_v = preact_v.unsqueeze(2)

        x = preact_k.transpose(-2, -1) @ preact_v  # (batch, 1, 1): k . v
        out = (preact_q @ x) * self.scale           # (batch, N, 1)
        out = out.reshape(out.shape[0], out.shape[1])
        return out

# 单头注意力的connection
class ChannelAttentionConnection(core.NetworkNode):
    """Connection adding single-head linear attention of three groups to ``dst``.

    ``attention_operation`` reads ``src_q.out``, ``src_k.out`` and
    ``src_v.out``, and the result is passed to ``dst.add_to_state``. The
    connection itself holds no trainable weights.

    Args:
        src_q, src_k, src_v: query/key/value groups (same output shape).
        dst: target group.
        target: ``dst`` state name; defaults to ``dst.default_target``.
        requires_grad: stored only (there are no parameters to toggle).
        propagate_gradients: if False, the op detaches its inputs.
        name, regularizers: forwarded to ``core.NetworkNode``.
        scale: multiplicative factor of the attention output (default 1).
    """

    def __init__(
        self,
        src_q,
        src_k,
        src_v,
        dst,
        target=None,
        requires_grad=True,
        propagate_gradients=True,
        name=None,
        regularizers=None,
        scale=1,
        **kwargs
    ):
        super(ChannelAttentionConnection, self).__init__(name=name, regularizers=regularizers)
        self.src_q = src_q
        self.src_k = src_k
        self.src_v = src_v
        self.dst = dst
        self.scale = scale
        self.op = attention_operation(src_q, src_k, src_v, scale, propagate_gradients)

        if target is None:
            self.target = dst.default_target
        else:
            self.target = target

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def get_regularizer_loss(self):
        """Return a zero loss: this connection has no weights to regularize.

        The former implementation called ``self.get_weights()``, which this
        class does not define, so it failed whenever regularizers were set.
        NOTE(review): confirm ``core.NetworkNode`` provides no ``get_weights``.
        """
        return torch.tensor(0.0, device=self.device)

    def apply_constraints(self):
        pass

    def forward(self):
        out = self.op()
        self.dst.add_to_state(self.target, out)

    def propagate(self):
        self.forward()

# 多头注意力的operation
class multihead_attention_operation(nn.Module):
    """Multi-head attention (no softmax) over channel-major inputs.

    Inputs are ``(B, C, N)`` tensors. The channels are split into
    ``num_heads`` contiguous groups of ``C // num_heads``, and each head
    computes ``Q K^T V * scale``. Because no softmax is applied,
    ``(Q K^T) V`` and ``Q (K^T V)`` are the same quantity up to
    floating-point rounding. ``linearAttention`` only chooses the evaluation
    order (``True`` -> ``Q (K^T V)``).

    Args:
        num_heads: number of heads; ``C`` is assumed divisible by it.
        scale: multiplicative factor applied to the result.
        propagate_gradients: if False, the inputs are detached.
        linearAttention: evaluation order, see above.
    """

    def __init__(self, num_heads, scale=1, propagate_gradients=True, linearAttention=True,):
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.propagate_gradients = propagate_gradients
        self.LinearAttention = linearAttention

    def _split_heads(self, tensor):
        """``(B, C, N)`` -> ``(B, heads, N, C // heads)``."""
        batch, channels, length = tensor.shape
        head_dim = channels // self.num_heads
        per_position = tensor.transpose(-1, -2).reshape(batch, length, self.num_heads, head_dim)
        return per_position.permute(0, 2, 1, 3).contiguous()

    def forward(self, preact_q, preact_k, preact_v):
        assert preact_q.shape == preact_k.shape == preact_v.shape, \
            "The shape of input group q, k, v must be the same."
        B, C, N = preact_q.shape

        inputs = (preact_q, preact_k, preact_v)
        if not self.propagate_gradients:
            inputs = tuple(t.detach() for t in inputs)
        q, k, v = (self._split_heads(t) for t in inputs)

        if self.LinearAttention:
            attended = (q @ (k.transpose(-2, -1) @ v)) * self.scale
        else:
            attended = ((q @ k.transpose(-2, -1)) @ v) * self.scale
        # (B, heads, N, head_dim) -> (B, heads, head_dim, N) -> (B, C, N)
        return attended.transpose(2, 3).reshape(B, C, N).contiguous()

# 多头注意力的connection
class ChannelAttentionConnection_multiHead(core.NetworkNode):
    """Connection adding multi-head attention of three groups to ``dst``.

    On each forward pass, ``multihead_attention_operation`` is applied to
    ``(src_q.out, src_k.out, src_v.out)`` (``(B, C, N)`` tensors), and the
    result is passed to ``dst.add_to_state``. The connection holds no
    trainable weights.

    Args:
        src_q, src_k, src_v: query/key/value groups (same output shape).
        dst: target group.
        num_heads: number of attention heads.
        target: ``dst`` state name; defaults to ``dst.default_target``.
        requires_grad: stored only (there are no parameters to toggle).
        propagate_gradients: if False, the op detaches its inputs.
        name, regularizers: forwarded to ``core.NetworkNode``.
        scale: multiplicative factor of the attention output (default 1).
        linearAttention: evaluation order used by the op.
    """

    def __init__(
        self,
        src_q,
        src_k,
        src_v,
        dst,
        num_heads,
        target=None,
        requires_grad=True,
        propagate_gradients=True,
        name=None,
        regularizers=None,
        scale=1,
        linearAttention=True,
        **kwargs
    ):
        super(ChannelAttentionConnection_multiHead, self).__init__(name=name, regularizers=regularizers)
        self.src_q = src_q
        self.src_k = src_k
        self.src_v = src_v
        self.dst = dst
        self.scale = scale
        self.num_heads = num_heads
        # linearAttention was previously stored but never forwarded, so the op
        # always used its default.
        self.op = multihead_attention_operation(
            num_heads, scale, propagate_gradients, linearAttention=linearAttention
        )

        if target is None:
            self.target = dst.default_target
        else:
            self.target = target

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        self.LinearAttention = linearAttention

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def get_regularizer_loss(self):
        """Return a zero loss: this connection has no weights to regularize.

        The former implementation called ``self.get_weights()``, which this
        class does not define, so it failed whenever regularizers were set.
        NOTE(review): confirm ``core.NetworkNode`` provides no ``get_weights``.
        """
        return torch.tensor(0.0, device=self.device)

    def apply_constraints(self):
        pass

    def forward(self):
        out = self.op(self.src_q.out, self.src_k.out, self.src_v.out)
        self.dst.add_to_state(self.target, out)

    def propagate(self):
        self.forward()


# 恒等映射的connection
class Connection_Identity(BaseConnection):
    """Parameter-free pass-through from ``src`` to ``dst``.

    ``src.out`` is indexed as ``(batch, d1, d2)``.

    - Non-feedback mode: the slice ``[:, :, input_start:input_end]`` is
      reshaped to ``(batch, *dst.shape[:2])``.
    - Feedback mode: if ``d1 == 1``, the last two features are broadcast to
      ``dst.shape[0]`` rows. If the input already has ``dst``'s shape, it is
      forwarded as-is. Any other shape raises ``NotImplementedError``.

    An identity-initialised ``nn.Linear`` is kept as ``self.op`` for
    ``get_weights`` / constraints, but ``forward`` does not apply it.

    NOTE(review): ``input_range="all"`` sets ``input_end = src.shape[0]``,
    yet the slice is taken on the last axis. Confirm this is intended when
    ``src.shape[0] != src.shape[-1]``.
    """

    def __init__(
        self,
        src,
        dst,
        bias=False,
        requires_grad=False,  # weights frozen by default
        propagate_gradients=False,
        flatten_input=True,
        name=None,
        regularizers=None,
        constraints=None,
        input_range="all",
        feedback=False,
        **kwargs
    ):
        super().__init__(
            src,
            dst,
            name=name,
            target=None,
            regularizers=regularizers,
            constraints=constraints,
        )

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        self.flatten_input = flatten_input
        self.feedback = feedback

        if input_range == "all":
            start, end = 0, src.shape[0]
            n_features = src.shape[0]
        else:
            start, end = input_range[0], input_range[1]
            n_features = int(end) - int(start)
        self.input_start = start
        self.input_end = end

        # Square linear map initialised to the identity, frozen by default.
        self.op = nn.Linear(n_features, n_features, bias=bias)
        nn.init.eye_(self.op.weight)
        self.op.requires_grad_(requires_grad)

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def get_weights(self):
        return self.op.weight

    def get_regularizer_loss(self):
        total = torch.tensor(0.0, device=self.device)
        for regularizer in self.regularizers:
            total += regularizer(self.get_weights())
        return total

    def forward(self):
        x = self.src.out
        if not self.propagate_gradients:
            x = x.detach()

        if not self.feedback:
            x = x[:, :, self.input_start:self.input_end]
            out = x.reshape(x.shape[0], self.dst.shape[0], self.dst.shape[1])
        elif x.shape[1] == 1:
            # Broadcast the last two features to every dst row.
            out = x[:, :, -2:].expand(-1, self.dst.shape[0], -1)
        elif tuple(x.shape[1:]) == self.dst.shape:
            out = x
        else:
            raise NotImplementedError

        self.dst.add_to_state(self.target, out)

    def propagate(self):
        self.forward()

    def apply_constraints(self):
        for constraint in self.constraints:
            constraint.apply(self.op.weight)

# src should be a list
class Connection_withBatchNorm_with_multi_src(BaseConnection):
    """Combine several source groups into one target, then batch-normalize.

    ``src`` must be a list of groups, and each group gets its own op
    (Kaiming-normal initialised). The per-source outputs are combined in one
    of two ways:

    - summed (equivalent to one op over the concatenated inputs), or
    - concatenated and projected by ``common_linear``. ``True`` builds an
      ``nn.Linear(sum(out_features), dst.shape[0])``; any other non-None
      value is used as the module itself.

    The combined output is passed through ``BatchNorm1d(dst.shape[0])``.

    Args:
        src: list of presynaptic groups.
        dst: postsynaptic group.
        operation: op class instantiated once per source.
        target, name, regularizers, constraints: forwarded to ``BaseConnection``.
        bias: whether each op has a bias.
        requires_grad: ``requires_grad`` for all op parameters.
        propagate_gradients: if False, the source outputs are detached.
        flatten_input: flatten each source output to ``(batch, -1)``.
        row: map all units to all ``dst`` units and reshape the result to
            ``(batch, dst.shape[0], dst.shape[1])`` before batch norm.
        common_linear: see above.
        session_encode_linear / session_encode_input: optional session code
            whose projection is added to every per-source output.
        **kwargs: forwarded to ``operation``.
    """

    def __init__(
        self,
        src,
        dst,
        operation=nn.Linear,
        target=None,
        bias=False,
        requires_grad=True,
        propagate_gradients=True,
        flatten_input=True,
        name=None,
        regularizers=None,
        constraints=None,
        row=False,
        common_linear=None,
        session_encode_linear=None,
        session_encode_input=None,
        **kwargs
    ):
        super(Connection_withBatchNorm_with_multi_src, self).__init__(
            src,
            dst,
            name=name,
            target=target,
            regularizers=regularizers,
            constraints=constraints,
        )

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        self.flatten_input = flatten_input
        self.row = row
        if self.row:
            self.output_shape = (dst.shape[0], dst.shape[1])

        # nn.ModuleList (not a plain list) so the per-source ops are registered
        # submodules: optimized, moved by .to(), and saved in state_dict.
        self.op = nn.ModuleList()
        for each_src in src:
            if self.row:
                in_features, out_features = each_src.nb_units, dst.nb_units
            elif flatten_input:
                in_features, out_features = each_src.nb_units, dst.shape[0]
            else:
                in_features, out_features = each_src.shape[0], dst.shape[0]
            each_op = operation(in_features, out_features, bias=bias, **kwargs)
            nn.init.kaiming_normal_(each_op.weight, mode='fan_in', nonlinearity='relu')  # He-normal init
            self.op.append(each_op)

        for each_op in self.op:
            for param in each_op.parameters():
                param.requires_grad = requires_grad
        self.bn = nn.BatchNorm1d(dst.shape[0])

        if common_linear is not None:
            if common_linear is True:
                m_features = sum(each_op.out_features for each_op in self.op)
                self.common_linear = nn.Linear(m_features, dst.shape[0])
            else:
                self.common_linear = common_linear
        else:
            self.common_linear = False
        if session_encode_linear is not None:
            self.session_encode_linear = session_encode_linear
            self.session_encode_input = session_encode_input
        else:
            self.session_encode_linear = False
            self.session_encode_input = False

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def add_diagonal_structure(self, width=1.0, ampl=1.0):
        """Add a Gaussian diagonal band to every per-source weight matrix.

        Raises:
            ValueError: if any per-source op is not an ``nn.Linear``.
        """
        if not all(isinstance(each_op, nn.Linear) for each_op in self.op):
            raise ValueError("Expected op to be nn.Linear to add diagonal structure.")
        for each_op in self.op:
            weight = each_op.weight
            rows, cols = weight.shape
            x = np.linspace(0, rows, cols)
            band = ampl * np.exp(-((x[None, :] - np.arange(rows)[:, None]) ** 2) / width**2)
            weight.data += torch.from_numpy(band).to(device=weight.device, dtype=weight.dtype)

    def get_weights(self):
        """Return the per-source weights concatenated along the input dimension.

        For summed linear ops this is the weight of the single equivalent op
        over the concatenated inputs.
        """
        return torch.cat([each_op.weight for each_op in self.op], dim=1)

    def get_regularizer_loss(self):
        reg_loss = torch.tensor(0.0, device=self.device)
        for reg in self.regularizers:
            reg_loss += reg(self.get_weights())
        return reg_loss

    def forward(self):
        # self.src is a list of groups, so there is no single `self.src.out`.
        outs = []
        for each_src, each_op in zip(self.src, self.op):
            each_preact = each_src.out
            if not self.propagate_gradients:
                each_preact = each_preact.detach()
            if self.flatten_input:
                each_preact = each_preact.reshape(each_preact.shape[:1] + (-1,))
            outs.append(each_op(each_preact))

        if self.session_encode_linear:
            session_code_preact = self.session_encode_input.out
            if self.flatten_input:
                session_code_preact = session_code_preact.reshape(session_code_preact.shape[:1] + (-1,))
            session_code_out = self.session_encode_linear(session_code_preact)
            outs = [each_out + session_code_out for each_out in outs]

        if self.common_linear:
            out = self.common_linear(torch.cat(outs, dim=1))
        else:
            out = torch.stack(outs, dim=0).sum(dim=0)

        if self.row:
            # Reshape before BN so BatchNorm1d(dst.shape[0]) normalizes per
            # channel; on the flat (batch, nb_units) output it would only
            # work when dst.shape[1] == 1.
            out = out.reshape(out.shape[0], self.output_shape[0], self.output_shape[1])
        out = self.bn(out)

        self.dst.add_to_state(self.target, out)

    def propagate(self):
        self.forward()

    def apply_constraints(self):
        for each_op in self.op:
            for const in self.constraints:
                const.apply(each_op.weight)

# 梯度反转层
class GradientReversalFunction(Function):
    """Gradient reversal layer: identity forward, gradient scaled by ``-lambda_`` backward."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.coeff = lambda_
        # Return a copy so the output is a distinct tensor, not the input itself.
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output * (-ctx.coeff)
        # lambda_ is a plain coefficient and receives no gradient.
        return reversed_grad, None

class Connection_withBatchNorm_with_GradientReverse(Connection_withBatchNorm):
    """Batch-normalised connection with a gradient-reversal layer on its input.

    Forward values are unchanged by the reversal. Gradients flowing back into
    the presynaptic activity are multiplied by ``-lambda_``
    (via ``GradientReversalFunction``).
    """

    def __init__(
            self,
            src,
            dst,
            operation=nn.Linear,
            target=None,
            bias=False,
            requires_grad=True,
            propagate_gradients=True,
            flatten_input=False,
            name=None,
            regularizers=None,
            constraints=None,
            row=False,
            lambda_=1.0,  # gradient-reversal strength
            **kwargs
    ):
        super().__init__(
            src,
            dst,
            operation=operation,
            target=target,
            bias=bias,
            requires_grad=requires_grad,
            propagate_gradients=propagate_gradients,
            flatten_input=flatten_input,
            name=name,
            regularizers=regularizers,
            constraints=constraints,
            row=row,
            **kwargs
        )
        # Coefficient of the (sign-flipped) gradient passed back to src.
        self.lambda_ = lambda_

    def forward(self):
        """Reverse gradients on ``src.out``, apply ``op`` then ``bn``, add to ``dst``.

        The input is optionally detached (``propagate_gradients``) and flattened
        to ``(batch, -1)`` (``flatten_input``). With ``row`` set, the result is
        reshaped to ``(batch, output_shape[0], output_shape[1])``.
        """
        x = self.src.out
        if not self.propagate_gradients:
            x = x.detach()
        if self.flatten_input:
            x = x.reshape(x.shape[:1] + (-1,))

        reversed_x = GradientReversalFunction.apply(x, self.lambda_)
        y = self.bn(self.op(reversed_x))

        if self.row:
            y = y.reshape(y.shape[0], self.output_shape[0], self.output_shape[1])

        self.dst.add_to_state(self.target, y)



# 卷积Layer
class Channel1dConvConnectionLayer(AbstractLayer):
    """Spiking 1-D convolutional layer.

    1-D convolution: the input signal here must be one-dimensional.
    ``nb_filters`` is the number of output channels, i.e. the number of
    attention heads.

    Geometry:
      * ``conv`` is a ``RepClassModule`` (class or instance): ``kernel_size``
        must be odd, stride is forced to 1 and padding to
        ``(kernel_size - 1) // 2``. The neuron shape is computed with
        ``utils.convlayer_size``.
      * otherwise, if ``shape == "same"`` or ``stride == 1``: stride 1 with
        ``padding="same"`` and neuron shape ``(nb_filters, input_group.shape[1])``.
      * otherwise the given ``stride``/``padding`` are used and the neuron
        shape is computed with ``utils.convlayer_size``.

    ``neuron_kwargs`` / ``connection_kwargs`` default to empty dicts;
    ``recurrent_connection_kwargs`` is accepted for API compatibility but unused.
    """

    def __init__(
        self,
        name,
        model,
        input_group,
        kernel_size,
        nb_filters,
        stride=1,
        padding=0,
        shape="same",
        recurrent=True,
        regs=None,
        w_regs=None,
        connection_class=ConvConnection_withBatchNorm,
        neuron_class=nodes.LIFGroup,
        neuron_kwargs=None,
        conv=nn.Conv1d,
        connection_kwargs=None,
        recurrent_connection_kwargs=None,
    ) -> None:

        super().__init__(name, model, recurrent)
        neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
        connection_kwargs = {} if connection_kwargs is None else connection_kwargs

        if self._is_rep_conv(conv):
            # Rep blocks get an explicit symmetric integer padding with stride 1
            # (for odd kernels this keeps the length unchanged).
            assert kernel_size % 2 != 0, "kernel_size must be odd"
            stride = 1
            padding = (kernel_size - 1) // 2
            shape = self._conv_shape(input_group, nb_filters, kernel_size, padding, stride)
        elif shape == "same" or stride == 1:
            stride = 1
            padding = 'same'  # output length equals input length
            shape = (nb_filters, input_group.shape[1])
            print("shape: ", shape)
        else:
            shape = self._conv_shape(input_group, nb_filters, kernel_size, padding, stride)

        # Make neuron group
        neurons = neuron_class(shape, name=self.name, regularizers=regs, **neuron_kwargs)
        self.add_neurons(neurons)

        # Make afferent connection
        con = connection_class(
            input_group,
            neurons,
            conv=conv,
            regularizers=w_regs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            name=self.name + "_con",
            **connection_kwargs
        )
        self.add_connection(con)

        # Make "recurrent" connection
        if recurrent:
            # NOTE(review): despite its name, this is a flattened dense
            # connection from ``input_group`` (not from this layer's neurons)
            # onto the neurons. Confirm this is intended; a true recurrent
            # connection would use the neurons as its source.
            con = Connection_withBatchNorm(
                input_group,
                neurons,
                row=True,
                flatten_input=True,
                regularizers=w_regs,
                name=self.name + "_recurrent_con",
                **connection_kwargs
            )
            self.add_connection(con)

        self.output_group = neurons

    @staticmethod
    def _is_rep_conv(conv):
        """Return True when ``conv`` is a ``RepClassModule`` subclass or instance.

        ``conv`` is normally passed as a class (e.g. the default ``nn.Conv1d``),
        so ``isinstance`` alone never matches it; ``issubclass`` is needed.
        """
        if isinstance(conv, RepClassModule):
            return True
        return isinstance(conv, type) and issubclass(conv, RepClassModule)

    @staticmethod
    def _conv_shape(input_group, nb_filters, kernel_size, padding, stride):
        """Return the neuron shape ``(nb_filters, *spatial)`` after convolution."""
        assert isinstance(
            nb_filters, int
        ), "Must provide nb_filters to calculate ConvLayer shape"
        spatial = utils.convlayer_size(
            nb_inputs=input_group.shape[1:],
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        if len(input_group.shape) - 1 == 1:
            return tuple([nb_filters, int(spatial[0])])
        return tuple([nb_filters] + [int(i) for i in spatial])

class LinearLayer_with_shortcut(AbstractLayer):
    """Spiking dense layer whose afferent connection also receives a shortcut input.

    The afferent connection takes ``input_group`` as ``src`` and
    ``shortcut_group`` as ``src_shortcut``. It is a
    ``Connection_with_VS_shortcut_withBatchNorm`` when ``batchNorm`` is true
    (and then also gets ``row``), otherwise a ``Connection_with_VS_shortcut``.
    With ``recurrent`` a ``Connection_withBatchNorm`` from the layer's neurons
    onto themselves is added. ``flatten_input_layer`` is accepted but unused.
    """

    def __init__(
        self,
        name,
        model,
        size,
        input_group,
        shortcut_group,
        batchNorm=False,
        shortcut_opFlag=False,
        recurrent=True,
        regs=None,
        w_regs=None,
        neuron_class=nodes.LIFGroup,
        flatten_input_layer=True,
        neuron_kwargs={},
        connection_kwargs={},
    ) -> None:
        super().__init__(name, model, recurrent)

        # A multi-dimensional ``size`` is passed on to the connections as ``row``.
        row = len(size) > 1
        group = neuron_class(size, name=self.name, regularizers=regs, **neuron_kwargs)
        self.add_neurons(group)

        afferent_kwargs = dict(
            src=input_group,
            src_shortcut=shortcut_group,
            shortcut_opFlag=shortcut_opFlag,
            dst=group,
        )
        if batchNorm:
            afferent_cls = Connection_with_VS_shortcut_withBatchNorm
            afferent_kwargs["row"] = row
        else:
            afferent_cls = Connection_with_VS_shortcut
        afferent = afferent_cls(
            name="block2_linera_connection_MLP",
            **afferent_kwargs,
            **connection_kwargs
        )
        self.add_connection(afferent)

        if recurrent:
            self.add_connection(
                Connection_withBatchNorm(
                    group, group, row=row, regularizers=w_regs, **connection_kwargs
                )
            )

        self.output_group = group

class LinearLayer_of_shortcut(AbstractLayer):
    """Spiking dense layer fed only by ``shortcut_group``.

    With ``batchNorm`` the afferent connection is a ``Connection_withBatchNorm``
    built with ``operation``. Otherwise it is a plain ``Connection``, and
    ``operation`` is not used. With ``recurrent`` a ``Connection_withBatchNorm``
    from the layer's neurons onto themselves is added.
    """

    def __init__(
        self,
        name,
        model,
        size,
        shortcut_group,
        batchNorm=False,
        recurrent=True,
        regs=None,
        w_regs=None,
        neuron_class=nodes.LIFGroup,
        flatten_input_layer=True,
        operation=nn.Linear,
        neuron_kwargs={},
        connection_kwargs={},
    ) -> None:
        super().__init__(name, model, recurrent)

        group = neuron_class(size, name=self.name, regularizers=regs, **neuron_kwargs)
        self.add_neurons(group)

        common = dict(
            src=shortcut_group,
            dst=group,
            name="block2_linera_connection_shortcut",
            flatten_input=flatten_input_layer,
        )
        if batchNorm:
            afferent = Connection_withBatchNorm(
                operation=operation, **common, **connection_kwargs
            )
        else:
            afferent = Connection(**common, **connection_kwargs)
        self.add_connection(afferent)

        if recurrent:
            self.add_connection(
                Connection_withBatchNorm(
                    group,
                    group,
                    regularizers=w_regs,
                    name="block2_connection_shortcut_recurrent",
                    **connection_kwargs
                )
            )

        self.output_group = group


# 多头注意力qkv合并的connection，使用Repconv
class Connection_of_multihead_qkv(Connection_withBatchNorm):
    """Fused query/key/value connection for multi-head attention.

    Identical to ``Connection_withBatchNorm`` except that ``out_groups=3`` is
    forced into the extra keyword arguments, and the default ``operation`` is
    ``RepVGGplusBlock1dV2``. ``flatten_input`` and ``row`` must both be False
    (convolution needs neither).
    """

    def __init__(
        self,
        src,
        dst,
        operation=RepVGGplusBlock1dV2,
        target=None,
        bias=False,
        requires_grad=True,
        propagate_gradients=True,
        flatten_input=False,
        name=None,
        regularizers=None,
        constraints=None,
        row=False,
        common_linear=None,
        session_encode_linear=None,
        session_encode_input=None,
        **kwargs
    ):
        assert (flatten_input, row) == (False, False), "卷积不需要flatten_input和row"
        kwargs = {**kwargs, "out_groups": 3}

        super().__init__(
            src,
            dst,
            operation=operation,
            target=target,
            bias=bias,
            requires_grad=requires_grad,
            propagate_gradients=propagate_gradients,
            flatten_input=flatten_input,
            name=name,
            regularizers=regularizers,
            constraints=constraints,
            row=row,
            common_linear=common_linear,
            session_encode_linear=session_encode_linear,
            session_encode_input=session_encode_input,
            **kwargs
        )

# 多头注意力qkv合并的conv_Layer
class Channel1dConvConnection_multihead_qkv_Layer(AbstractLayer):
    """Spiking 1-D convolutional layer producing fused query/key/value channels.

    1-D convolution: the input signal here must be one-dimensional.
    ``nb_filters`` is the number of attention heads. The neuron group has
    ``3 * nb_filters`` channels (query, key and value blocks stacked in the
    first shape dimension). The afferent connection is always a
    ``Connection_of_multihead_qkv``, and the conv type used for the geometry
    below is ``RepVGGplusBlock1dV2``.

    Geometry (before tripling the channels):
      * ``RepVGGplusBlock1dV2`` is a ``RepClassModule``: ``kernel_size`` must be
        odd, stride is forced to 1 and padding to ``(kernel_size - 1) // 2``.
        The shape is computed with ``utils.convlayer_size``.
      * otherwise, if ``shape == "same"`` or ``stride == 1``: stride 1 with
        ``padding="same"`` and shape ``(nb_filters, input_group.shape[1])``.
      * otherwise the given ``stride``/``padding`` are used with
        ``utils.convlayer_size``.

    ``neuron_kwargs`` / ``connection_kwargs`` default to empty dicts;
    ``recurrent_connection_kwargs`` is accepted for API compatibility but unused.
    """

    def __init__(
        self,
        name,
        model,
        input_group,
        kernel_size,
        nb_filters,
        stride=1,
        padding=0,
        shape="same",
        recurrent=True,
        regs=None,
        w_regs=None,
        neuron_class=nodes.LIFGroup,
        neuron_kwargs=None,
        connection_kwargs=None,
        recurrent_connection_kwargs=None,
    ) -> None:

        super().__init__(name, model, recurrent)
        neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
        connection_kwargs = {} if connection_kwargs is None else connection_kwargs

        connection_class = Connection_of_multihead_qkv
        conv = RepVGGplusBlock1dV2

        if self._is_rep_conv(conv):
            # Rep blocks get an explicit symmetric integer padding with stride 1
            # (for odd kernels this keeps the length unchanged).
            assert kernel_size % 2 != 0, "kernel_size must be odd"
            stride = 1
            padding = (kernel_size - 1) // 2
            shape = self._conv_shape(input_group, nb_filters, kernel_size, padding, stride)
        elif shape == "same" or stride == 1:
            stride = 1
            padding = 'same'  # output length equals input length
            shape = (nb_filters, input_group.shape[1])
            print("shape: ", shape)
        else:
            shape = self._conv_shape(input_group, nb_filters, kernel_size, padding, stride)

        # Three channel blocks: query, key and value.
        shape = (shape[0] * 3, *shape[1:])
        neurons = neuron_class(shape, name=self.name, regularizers=regs, **neuron_kwargs)
        self.add_neurons(neurons)

        # Make afferent connection
        con = connection_class(
            input_group,
            neurons,
            regularizers=w_regs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            name=self.name + "_con",
            **connection_kwargs
        )
        self.add_connection(con)

        # Make "recurrent" connection
        if recurrent:
            # NOTE(review): despite its name, this is a flattened dense
            # connection from ``input_group`` (not from this layer's neurons)
            # onto the neurons. Confirm this is intended.
            con = Connection_withBatchNorm(
                input_group,
                neurons,
                row=True,
                flatten_input=True,
                regularizers=w_regs,
                name=self.name + "_recurrent_con",
                **connection_kwargs
            )
            self.add_connection(con)

        self.output_group = neurons

    @staticmethod
    def _is_rep_conv(conv):
        """Return True when ``conv`` is a ``RepClassModule`` subclass or instance.

        ``conv`` here is a class (``RepVGGplusBlock1dV2``), so ``isinstance``
        alone never matches it; ``issubclass`` is needed.
        """
        if isinstance(conv, RepClassModule):
            return True
        return isinstance(conv, type) and issubclass(conv, RepClassModule)

    @staticmethod
    def _conv_shape(input_group, nb_filters, kernel_size, padding, stride):
        """Return the shape ``(nb_filters, *spatial)`` after convolution."""
        assert isinstance(
            nb_filters, int
        ), "Must provide nb_filters to calculate ConvLayer shape"
        spatial = utils.convlayer_size(
            nb_inputs=input_group.shape[1:],
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        if len(input_group.shape) - 1 == 1:
            return tuple([nb_filters, int(spatial[0])])
        return tuple([nb_filters] + [int(i) for i in spatial])


# 多头注意力qkv合并的attention_connection
class ChannelAttentionConnection_multiHead_qkv(core.NetworkNode):
    """Multi-head attention connection reading fused q/k/v from one group.

    Dim 1 of ``src.out`` is split into query ``[0, num_heads)``, key
    ``[num_heads, 2*num_heads)`` and value ``[2*num_heads, end)``. These are
    combined by ``multihead_attention_operation`` and added to the ``target``
    state of ``dst``. ``scale`` defaults to 1.
    """

    def __init__(
        self,
        src,
        dst,
        num_heads,
        target=None,
        requires_grad=True,
        propagate_gradients=True,
        name=None,
        regularizers=None,
        scale=1,
        linearAttention=True,
        **kwargs
    ):

        super(ChannelAttentionConnection_multiHead_qkv, self).__init__(name=name, regularizers=regularizers)
        # Attribute assignment order is kept as-is (module registration order).
        self.src = src
        self.dst = dst
        self.scale = scale
        self.num_heads = num_heads
        self.op = multihead_attention_operation(num_heads, scale, propagate_gradients)

        self.target = dst.default_target if target is None else target

        self.requires_grad = requires_grad
        self.propagate_gradients = propagate_gradients
        # NOTE(review): stored but not consulted by forward() or passed to the op.
        self.LinearAttention = linearAttention

    def configure(self, batch_size, nb_steps, time_step, device, dtype):
        super().configure(batch_size, nb_steps, time_step, device, dtype)

    def get_regularizer_loss(self):
        # NOTE(review): this class defines no get_weights(); confirm the base
        # class provides one before registering regularizers here.
        total = torch.tensor(0.0, device=self.device)
        for regularizer in self.regularizers:
            total += regularizer(self.get_weights())
        return total

    def apply_constraints(self):
        # The attention operation has no constrained weights.
        pass

    def forward(self):
        activity = self.src.out
        heads = self.num_heads
        query = activity[:, :heads, :]
        key = activity[:, heads:2 * heads, :]
        value = activity[:, 2 * heads:, :]
        self.dst.add_to_state(self.target, self.op(query, key, value))

    def propagate(self):
        self.forward()

# 用于还原firing rate的decoder，输入是group的膜电位
class Connection_mem_spike_decoder(nn.Module):
    """Decoder mapping a group's membrane potential to positive (firing-rate) values.

    Architecture: Linear -> BatchNorm1d -> ReLU -> Linear -> exp. The input is
    flattened to ``(batch, -1)``, and its width must equal ``src.nb_units``.
    Every ``forward`` result is appended to ``self.output`` until ``reset`` is
    called.

    Args:
        src: group providing ``nb_units`` and, by default, ``mem``.
        output_size: width of the decoded output.
        bias: whether both linear layers have a (zero-initialised) bias.
        propagate_gradients: if False the input is detached.
        name: optional identifier stored on the module.
        **kwargs: factory keyword arguments (e.g. ``device``/``dtype``) passed to
            both ``nn.Linear`` layers and to the ``nn.BatchNorm1d``.
    """

    def __init__(
        self,
        src,
        output_size,
        bias=False,
        propagate_gradients=True,
        name=None,
        **kwargs
    ):
        super(Connection_mem_spike_decoder, self).__init__()
        self.src = src
        self.input_size = int(src.nb_units)
        self.output_size = int(output_size)
        self.bias = bias
        self.propagate_gradients = propagate_gradients
        self.name = name

        self.linear1 = nn.Linear(self.input_size, self.output_size, bias=bias, **kwargs)
        nn.init.kaiming_normal_(self.linear1.weight, mode='fan_in', nonlinearity='relu')
        # Same factory kwargs as the Linears, so device/dtype stay consistent.
        self.linear1_bn = nn.BatchNorm1d(self.output_size, **kwargs)
        self.ReLu1 = nn.ReLU()
        self.linear2 = nn.Linear(self.output_size, self.output_size, bias=bias, **kwargs)
        nn.init.kaiming_normal_(self.linear2.weight, mode='fan_in', nonlinearity='relu')
        if bias:
            nn.init.zeros_(self.linear1.bias)
            nn.init.zeros_(self.linear2.bias)

        # Output history; created here so forward() works before any reset().
        self.output = []

    def reset(self):
        """Clear the stored output history."""
        self.output = []

    def forward(self, x=None):
        """Decode ``x`` (default: ``self.src.mem``), record and return the result."""
        if x is None:
            x = self.src.mem

        if not self.propagate_gradients:
            x = x.detach()

        shp = x.shape
        x = x.reshape(shp[:1] + (-1,))
        # First layer: linear, batch norm, ReLU.
        x = self.linear1(x)
        x = self.linear1_bn(x)
        x = self.ReLu1(x)

        # Second layer: linear followed by exp, so the output is strictly positive.
        x = self.linear2(x)
        x = torch.exp(x)

        self.output.append(x)

        return x

# class Connection_of_feedback_input_withBatchNorm(BaseConnection):
#     def __init__(
#         self,
#         src,
#         dst,
#         operation=nn.Linear,
#         target=None,
#         bias=False,
#         requires_grad=True,
#         propagate_gradients=True,
#         flatten_input=True,
#         name=None,
#         regularizers=None,
#         constraints=None,
#         **kwargs
#     ):
#         super(Connection_of_feedback_input_withBatchNorm, self).__init__(
#             src,
#             dst,
#             name=name,
#             target=target,
#             regularizers=regularizers,
#             constraints=constraints,
#         )
#
#         self.requires_grad = requires_grad
#         self.propagate_gradients = propagate_gradients
#         self.flatten_input = flatten_input
#
#
#         self.op = operation(2, dst.shape[0], bias=bias, **kwargs)
#
#
#         for param in self.op.parameters():
#             param.requires_grad = requires_grad
#         # for param in self.op.parameters():
#         #     param.requires_grad = requires_grad
#         self.bn = nn.BatchNorm1d(dst.shape[0])
#
#
#     def configure(self, batch_size, nb_steps, time_step, device, dtype):
#         super().configure(batch_size, nb_steps, time_step, device, dtype)
#
#     def add_diagonal_structure(self, width=1.0, ampl=1.0):
#         if not isinstance(self.op, nn.Linear):
#             raise ValueError("Expected op to be nn.Linear to add diagonal structure.")
#         A = np.zeros(self.op.weight.shape)
#         x = np.linspace(0, A.shape[0], A.shape[1])
#         for i in range(len(A)):
#             A[i] = ampl * np.exp(-((x - i) ** 2) / width**2)
#         self.op.weight.data += torch.from_numpy(A)
#
#     def get_weights(self):
#         return self.op.weight
#
#     def get_regularizer_loss(self):
#         reg_loss = torch.tensor(0.0, device=self.device)
#         for reg in self.regularizers:
#             reg_loss += reg(self.get_weights())
#         return reg_loss
#
#     def forward(self):
#         preact = self.src.out
#         if not self.propagate_gradients:
#             preact = preact.detach()
#         if self.flatten_input:
#             shp = preact.shape
#             preact = preact.reshape(shp[:1] + (-1,))
#
#         out = self.op(preact)
#
#         if self.row:
#             out = out.reshape(out.shape[0], self.output_shape[0], self.output_shape[1] )
#
#         out = self.bn(out)
#
#         self.dst.add_to_state(self.target, out)
#
#     def propagate(self):
#         self.forward()
#
#     def apply_constraints(self):
#         for const in self.constraints:
#             const.apply(self.op.weight)








# # ChannelAttentionLayer_V1
# # 仅一维注意力（效果可能不好），无batch Norm，仅单头注意力，缩放因子默认为1
# # 膜电位shortcut 有点麻烦，先使用常规shortcut(Vanilla shortcut)
# class ChannelAttentionLayer_V1(AbstractLayer):
#
#     def __init__(
#         self,
#         name,
#         model,
#         input_group_shortcut,
#         input_group_q,
#         input_group_k,
#         input_group_v,
#         shortcut=None,
#         attention_scale=1,
#         regs=None,
#         w_regs=None,
#         # connection_class=ChannelAttentionConnection,
#         neuron_class=nodes.LIFGroup,
#         # flatten_input_layer=True,
#         neuron_kwargs={},
#         connection_kwargs={},
#     ) -> None:
#
#         assert input_group_q.shape==input_group_k.shape==input_group_v.shape, \
#             "The shape of input group q, k and v must be the same."
#         assert len(input_group_q.shape)==1, "The shape of input group q, k and v must be 1 demention."
#         input_group_size = input_group_q.shape[0]
#
#         super().__init__(name, model, recurrent=False)
#
#         # Make neuron group
#         nodes = neuron_class(
#             shape=input_group_size,
#             name=self.name,
#             regularizers=regs,
#             **neuron_kwargs
#         )
#         self.add_neurons(nodes)
#
#         # Make afferent connection
#         con = ChannelAttentionConnection(
#             input_group_shortcut,
#             src_q=input_group_q,
#             src_k=input_group_k,
#             src_v=input_group_v,
#             shortcut=shortcut,
#             dst=nodes,
#             regularizers=w_regs,
#             scale=attention_scale,
#         )
#         self.add_connection(con)
#
#         self.output_group = nodes
#
#
#
# class Connection_with_VS_shortcut_no_attention_input(Connection):
#     def __init__(
#         self,
#         src,
#         dst,
#         # src_shortcut,
#         # shortcut=None,
#         operation=nn.Linear,
#         target=None,
#         bias=False,
#         requires_grad=True,
#         propagate_gradients=True,
#         flatten_input=False,
#         name=None,
#         regularizers=None,
#         constraints=None,
#         **kwargs
#     ):
#         super(Connection_with_VS_shortcut, self).__init__(
#             src,
#             dst,
#             operation=operation,
#             target=target,
#             bias=bias,
#             requires_grad=requires_grad,
#             propagate_gradients=propagate_gradients,
#             flatten_input=flatten_input,
#             name=name,
#             regularizers=regularizers,
#             constraints=constraints,
#             **kwargs
#         )
#         # self.src_shortcut = src_shortcut
#         # self.shortcut = shortcut
#
#     def forward(self):
#         preact = self.src.out
#         if not self.propagate_gradients:
#             preact = preact.detach()
#         if self.flatten_input:
#             shp = preact.shape
#             preact = preact.reshape(shp[:1] + (-1,))
#
#         out = self.op(preact)
#         out = out + preact
#
#         self.dst.add_to_state(self.target, out)
#
#
# class LinearLayer_with_shortcut(AbstractLayer):
#
#     def __init__(
#         self,
#         name,
#         model,
#         size,
#         input_group,
#         recurrent=True,
#         regs=None,
#         w_regs=None,
#         connection_class=connections.Connection,
#         neuron_class=nodes.LIFGroup,
#         flatten_input_layer=True,
#         neuron_kwargs={},
#         connection_kwargs={},
#     ) -> None:
#         super().__init__(name, model, recurrent)
#
#         # Make neuron group
#         nodes = neuron_class(size, name=self.name, regularizers=regs, **neuron_kwargs)
#         self.add_neurons(nodes)
#
#         # Make afferent connection
#         con = connection_class(
#             input_group,
#             nodes,
#             regularizers=w_regs,
#             flatten_input=flatten_input_layer,
#             **connection_kwargs
#         )
#         self.add_connection(con)
#
#         # Make recurrent connection
#         if recurrent:
#             con = connection_class(
#                 nodes, nodes, regularizers=w_regs, **connection_kwargs
#             )
#             self.add_connection(con)
#
#         self.output_group = nodes
