from logging import raiseExceptions

import torch
import stork
from mpmath import zeros
from stork.connections import Connection
from stork.initializers import (
    FluctuationDrivenCenteredNormalInitializer,
    DistInitializer,
    Initializer,
)
from stork.layers import Layer

# Custom additions to stork
from .custom.readout import CustomReadoutGroup, AverageReadouts
from .custom.lif import CustomLIFGroup, Custom_multiSyn_LIFGroup
from .custom.custom_regularizer import LowerBoundL2, UpperBoundL2
from .custom.models import CustomRecurrentSpikingModel
from .custom.custom_connections import *

from .custom.RepConv1d import *
from .custom.custom_inputGroup import *
from .custom.custom_inputGroup import custom_InputGroup as InputGroup
# from stork.nodes import InputGroup
from torch import nn


# from .data import compute_input_firing_rates
from .data_foundationDemo1 import compute_input_firing_rates


import logging
import copy


logger = logging.getLogger(__name__)


def get_regularizers(cfg):
    """Build the activity regularizers used for the hidden neuron groups.

    Strengths and thresholds are read from ``cfg.training``.

    Returns:
        A two-element list: a ``stork.regularizers.LowerBoundL2`` term
        (created with ``dims=False``) followed by a
        ``stork.regularizers.UpperBoundL2`` term (created with ``dims=1``).
    """
    lower_bound = stork.regularizers.LowerBoundL2(
        cfg.training.LB_L2_strength,
        threshold=cfg.training.LB_L2_thresh,
        dims=False,
    )
    upper_bound = stork.regularizers.UpperBoundL2(
        cfg.training.UB_L2_strength,
        threshold=cfg.training.UB_L2_thresh,
        dims=1,
    )
    return [lower_bound, upper_bound]


def get_actfn(cfg):
    """Configure and return the spiking activation ``stork.activations.CustomSpike``.

    The settings are assigned as attributes on the ``CustomSpike`` object
    itself (no copy is made), so every user of that object sees them:

    * ``escape_noise_type`` is ``"sigmoid"`` when ``cfg.model.stochastic`` is
      truthy and ``"step"`` otherwise.
    * ``surrogate_type`` is ``"SuperSpike"``.
    * Both parameter dicts carry ``beta = cfg.training.SG_beta``.
    """
    spike_fn = stork.activations.CustomSpike

    spike_fn.escape_noise_type = "sigmoid" if cfg.model.stochastic else "step"
    spike_fn.escape_noise_params = {"beta": cfg.training.SG_beta}

    spike_fn.surrogate_type = "SuperSpike"
    spike_fn.surrogate_params = {"beta": cfg.training.SG_beta}

    return spike_fn


class ZeroInitializer(Initializer):
    """Initializer that sets all weights of a connection to zero."""

    def __init__(self):
        super().__init__()

    def _get_weights(self, connection):
        """Return an all-zero tensor with the shape of ``connection.op.weight``."""
        return torch.zeros(connection.op.weight.shape)

    def initialize_connection(self, connection):
        """Zero-initialize a single connection.

        The zero tensor is handed to the inherited ``_set_weights_and_bias``
        (defined in the base class, not visible here; presumably it also
        takes care of the bias).
        """
        zero_weights = self._get_weights(connection)
        self._set_weights_and_bias(connection, zero_weights)

def get_initializers(cfg, nu=None, dtype=torch.float32):
    """Create the hidden-layer and readout weight initializers.

    Args:
        cfg: Configuration object. Reads ``cfg.initializer.compute_nu``,
            ``cfg.initializer.nu``, ``cfg.initializer.sigma_u``,
            ``cfg.initializer.alpha`` and ``cfg.data.dt``.
        nu: Firing rate(s) handed to the fluctuation-driven initializer.
            When ``None``, or when ``cfg.initializer.compute_nu`` is
            ``False``, ``cfg.initializer.nu`` is used instead. May be a
            scalar or a ``list`` with two or three entries.
        dtype: Torch dtype passed to every initializer that accepts one.

    Returns:
        ``(hidden_init, readout_init)``. For a scalar ``nu``, ``hidden_init``
        is a single ``FluctuationDrivenCenteredNormalInitializer``. For a
        list, it is a list of initializers:

        * ``[init(nu[0]), init(nu[1]), init(mean(nu[0], nu[1]))]``, or
        * ``[init(nu[0]), ZeroInitializer(), init(nu[0])]`` when ``nu[1]``
          is ``0`` or ``None``,

        followed by ``init(nu[2])`` when ``nu`` has three entries.

    Raises:
        ValueError: If ``nu`` is a list whose length is not 2 or 3.
    """
    if nu is None or cfg.initializer.compute_nu is False:
        logger.info("using un-caculated nu")
        nu = cfg.initializer.nu

    def _make_hidden_init(rate):
        # Single place where the fluctuation-driven initializer is configured.
        return FluctuationDrivenCenteredNormalInitializer(
            sigma_u=cfg.initializer.sigma_u,
            nu=rate,
            timestep=cfg.data.dt,
            alpha=cfg.initializer.alpha,
            dtype=dtype,
        )

    if isinstance(nu, list):
        # Only two- and three-element lists are defined. Previously other
        # lengths either crashed in the log call below or silently produced
        # an empty initializer list.
        if len(nu) not in (2, 3):
            raise ValueError(
                f"nu must be a scalar or a list of length 2 or 3, got length {len(nu)}"
            )
        logger.info(f"Initializing with nu = {nu[0]} and {nu[1]}")

        first, second = nu[0], nu[1]
        if second is None or second == 0:
            # No usable second rate: zero weights for the second slot and
            # reuse the first rate for the third.
            hidden_init = [
                _make_hidden_init(first),
                ZeroInitializer(),
                _make_hidden_init(first),
            ]
        else:
            leading = nu[:2]
            hidden_init = [
                _make_hidden_init(first),
                _make_hidden_init(second),
                _make_hidden_init(sum(leading) / len(leading)),
            ]

        if len(nu) == 3:
            hidden_init.append(_make_hidden_init(nu[2]))
    else:
        logger.info(f"Initializing with nu = {nu}")
        hidden_init = _make_hidden_init(nu)

    readout_init = DistInitializer(
        dist=torch.distributions.Normal(0, 1),
        scaling="1/sqrt(k)",
        dtype=dtype,
    )

    return hidden_init, readout_init

def get_model(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the baseline spiking model: input group, hidden LIF layers, readout.

    Args:
        cfg: Experiment configuration. Reads ``cfg.data`` (``sample_duration``,
            ``dt``, ``nb_outputs``), ``cfg.training`` (batch sizes),
            ``cfg.model`` (layer sizes, neuron and readout settings) and
            ``cfg.device`` or, when ``cfg.multi_cuda`` is set, ``cfg.gpu_ids``.
        nb_inputs: Number of input units.
        dtype: Torch dtype passed to the model, the initializers and the
            readout connections.
        data: Optional dataset handed to ``compute_input_firing_rates``. When
            ``None``, ``get_initializers`` falls back to ``cfg.initializer.nu``.
        stateFlag: Selects the batch size: ``"default"``, ``"pretrain"`` or
            ``"fine-tune"``.

    Returns:
        The ``CustomRecurrentSpikingModel`` with all groups and connections
        added and initialized.

    Raises:
        ValueError: If ``stateFlag`` is not one of the supported values
            (previously this surfaced later as an ``UnboundLocalError``).
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, "multi_cuda") and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag == "fine-tune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        raise ValueError(
            f"Unknown stateFlag {stateFlag!r}; "
            "expected 'default', 'pretrain' or 'fine-tune'."
        )

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)

    # Regularizer list shared by all hidden layers.
    regs = get_regularizers(cfg)

    # Mean input firing rates for the initializer (only when data is given).
    mean1, mean2 = None, None
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)

    # Hidden-layer and readout initializers derived from cfg and the input rate.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # INPUT LAYER
    # # # # # # # #

    input_group = model.add_group(
        InputGroup(nb_inputs, dropout_p=cfg.model.dropout_p)
    )

    # HIDDEN LAYERS
    # # # # # # # #
    current_src_grp = input_group

    # Keyword arguments shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # Stack the hidden layers; each one uses CustomLIFGroup neurons and takes
    # the previous layer's output group as its input.
    for i in range(cfg.model.nb_hidden):

        hidden_layer = Layer(
            name="hidden",
            model=model,
            size=cfg.model.hidden_size[i],
            input_group=current_src_grp,
            recurrent=cfg.model.recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )

        current_src_grp = hidden_layer.output_group

        # NOTE(review): this call needs a single initializer; get_initializers
        # returns a list when nu is a list, which has no ``initialize`` method.
        hidden_init.initialize(hidden_layer)

        if i == 0 and nb_inputs == 192 and data is not None:
            # Rescale the first layer's weights for input columns 96 onward by
            # mean1 / mean2 (presumably the second half of a 192-channel input
            # has its own mean rate -- confirm against compute_input_firing_rates).
            # NOTE(review): a zero mean2 makes this a division by zero.
            with torch.no_grad():
                hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1

    # READOUT LAYER
    # # # # # # # #

    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(Connection(current_src_grp, g, dtype=dtype))
            readout_init.initialize(con_ro)

        # The readout groups just added are the trailing entries of model.groups.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            Connection(current_src_grp, readout_group, dtype=dtype)
        )

        readout_init.initialize(con_ro)

    return model


def get_custom_readouts(cfg, size=None, name=None):
    """Create one ``CustomReadoutGroup`` per entry in ``cfg.model["readouts"]``.

    Args:
        cfg: Configuration object. ``cfg.model["readouts"]`` maps a readout
            name to its spec; each spec needs a ``"type"`` of ``"default"``
            or ``"delta"`` and may override ``"tau_mem"`` / ``"tau_syn"``
            (otherwise ``cfg.model.tau_mem_readout`` /
            ``cfg.model.tau_syn_readout`` are used).
        size: Number of units per readout group; defaults to
            ``cfg.data.nb_outputs``.
        name: Name passed to the group constructor (``"Readout"`` when
            ``None``). It is immediately replaced via ``set_name`` with the
            readout's key from the config.

    Returns:
        List of readout groups, in the iteration order of the config mapping.

    Raises:
        ValueError: If a spec has an unsupported ``"type"``. Previously this
            raised ``UnboundLocalError`` on the first entry, or silently
            re-appended the previous entry's group on later ones.
    """
    if size is None:
        size = cfg.data.nb_outputs
    group_name = name if name is not None else "Readout"

    ro_list = []
    for ro, specs in cfg.model["readouts"].items():
        ro_type = specs["type"]
        if ro_type not in ("default", "delta"):
            raise ValueError(
                f"Unknown readout type {ro_type!r} for readout {ro!r}; "
                "expected 'default' or 'delta'."
            )

        # Per-readout time constants fall back to the model-wide readout values.
        tau_mem = specs["tau_mem"] if "tau_mem" in specs else cfg.model.tau_mem_readout
        tau_syn = specs["tau_syn"] if "tau_syn" in specs else cfg.model.tau_syn_readout

        # The two readout types differ only in the is_delta_syn flag.
        ro_group = CustomReadoutGroup(
            size,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            het_timescales=cfg.model.het_timescales_readout,
            learn_timescales=cfg.model.learn_timescales_readout,
            initial_state=-1e-2,
            is_delta_syn=(ro_type == "delta"),
            name=group_name,
        )

        ro_group.set_name(ro)
        ro_list.append(ro_group)

    return ro_list

# Variant with linear, single-head q/k/v projections.
def get_model_attention_V1(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the spiking model with linear, single-head q/k/v channel attention.

    The model is assembled in three stages:

    1. ``cfg.model.nb_linear_hidden`` linear hidden layers of ``CustomLIFGroup``.
    2. ``cfg.model.nb_Attention_hidden`` attention blocks. Each block creates
       q/k/v groups, a ``ChannelAttentionConnection`` into a new LIF group,
       and an MLP stage that also receives the block's input as a shortcut.
    3. One or several readout groups.

    Args:
        cfg: Experiment configuration. Reads ``cfg.data``, ``cfg.training``,
            ``cfg.model`` and ``cfg.device`` or, when ``cfg.multi_cuda`` is
            set, ``cfg.gpu_ids``.
        nb_inputs: Number of input units.
        dtype: Torch dtype passed to the model, the initializers and the
            attention/readout connections.
        data: Optional dataset handed to ``compute_input_firing_rates``. When
            ``None``, ``get_initializers`` falls back to ``cfg.initializer.nu``.
        stateFlag: Selects the batch size: ``"default"``, ``"pretrain"`` or
            ``"fine-tune"``.

    Returns:
        The ``CustomRecurrentSpikingModel`` with all groups and connections
        added and initialized.

    Raises:
        ValueError: If ``stateFlag`` is not one of the supported values
            (previously this surfaced later as an ``UnboundLocalError``).
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, "multi_cuda") and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Validate stateFlag before any model objects are created.
    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag == "fine-tune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        raise ValueError(
            f"Unknown stateFlag {stateFlag!r}; "
            "expected 'default', 'pretrain' or 'fine-tune'."
        )

    # Connection type used for the q/k/v projections and the MLP connection.
    if cfg.model.BatchNorm:
        connection_class = Connection_withBatchNorm
        batchNorm = True
    else:
        connection_class = Connection
        batchNorm = False

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)

    # Regularizer list shared by all hidden groups.
    regs = get_regularizers(cfg)

    # Mean input firing rates for the initializer (only when data is given).
    mean1, mean2 = None, None
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)

    # Hidden-layer and readout initializers derived from cfg and the input rate.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # INPUT LAYER
    # # # # # # # #

    input_group = model.add_group(
        InputGroup(nb_inputs, dropout_p=cfg.model.dropout_p)
    )

    # HIDDEN LAYERS
    # # # # # # # #
    current_src_grp = input_group

    # Keyword arguments shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # block1: linear hidden layers.
    # Each layer uses CustomLIFGroup neurons and takes the previous layer's
    # output group as its input.
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=cfg.model.linear_hidden_size[i],
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        if i == 0 and nb_inputs == 192 and data is not None and cfg.data.data_len == 192:
            # Rescale the first layer's weights for input columns 96 onward by
            # mean1 / mean2 (presumably the second half of a 192-channel input
            # has its own mean rate -- confirm against compute_input_firing_rates).
            # NOTE(review): a zero mean2 makes this a division by zero.
            with torch.no_grad():
                hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1

    # Every attention block keeps the width of the group feeding block2.
    attention_scr_size = current_src_grp.shape[0]

    # block2: channel-attention blocks.
    for i in range(cfg.model.nb_Attention_hidden):

        # q/k/v stage: built either through Layer-style helpers or from plain
        # groups and connections (helpers come from the star imports).
        if cfg.model.Layer_init:
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V1(
                model=model,
                size=attention_scr_size,
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={},
            )
        else:
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_group(
                model=model,
                size=attention_scr_size,
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={},
            )

        # Attention-score stage.
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        assert len(linear_group_v.shape)==1, \
            "The shape of input group q, k and v must be 1 demention."
        # group
        ChannelAttention_group = CustomLIFGroup(
            attention_scr_size,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        model.add_connection(
            ChannelAttentionConnection(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                # shortcut="VS",
                dtype=dtype,
            )
        )

        # MLP stage: combines the attention output with a shortcut from the
        # block's input group and becomes the source for the next stage.
        if cfg.model.Layer_init:
            if cfg.model.shortcut_Linear:
                # The layer is built on the shortcut group; the attention
                # output is wired into it by a separate connection.
                linear_layer_MLP = LinearLayer_of_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=attention_scr_size,
                    shortcut_group=current_src_grp,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs={},
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                # connection
                linear_connection_MLP = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=linear_layer_MLP.output_group,
                        name="block2_linera_connection_MLP",
                    )
                )
                hidden_init.initialize(linear_connection_MLP)  # initialize

                current_src_grp = linear_layer_MLP.output_group
            else:
                linear_layer_MLP = LinearLayer_with_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=attention_scr_size,
                    input_group=ChannelAttention_group,
                    shortcut_group=current_src_grp,
                    shortcut_opFlag=cfg.model.shortcut_opFlag,
                    batchNorm=batchNorm,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs={},
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                current_src_grp = linear_layer_MLP.output_group
        else:
            # group
            linear_group_MLP = CustomLIFGroup(
                attention_scr_size,
                name="block2_linera_group_MLP",
                regularizers=regs,
                **hidden_neuron_kwargs,
            )
            model.add_group(linear_group_MLP)
            # connection
            # NOTE(review): this branch always uses the BatchNorm connection
            # classes, regardless of cfg.model.BatchNorm -- confirm intended.
            linear_connection_MLP = model.add_connection(
                Connection_with_VS_shortcut_withBatchNorm(
                    src=ChannelAttention_group,
                    src_shortcut=current_src_grp,
                    shortcut_opFlag=cfg.model.shortcut_opFlag,
                    dst=linear_group_MLP,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_MLP)  # initialize
            # recurrent connection
            if cfg.model.Attention_hidden_recurrent_MLP[i]:
                linear_connection_MLP_recurrent = model.add_connection(
                    Connection_withBatchNorm(
                        linear_group_MLP,
                        linear_group_MLP,
                        name="block2_linera_connection_q_recurrent",
                    )
                )
                hidden_init.initialize(linear_connection_MLP_recurrent)
            current_src_grp = linear_group_MLP

    # block3: output stage.
    # READOUT LAYER
    # # # # # # # #
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(Connection(current_src_grp, g, dtype=dtype))
            readout_init.initialize(con_ro)

        # The readout groups just added are the trailing entries of model.groups.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            Connection(current_src_grp, readout_group, dtype=dtype)
        )

        readout_init.initialize(con_ro)

    return model


def get_model_attention_V2(cfg, nb_inputs, dtype, data=None, stateFlag="default"):

    # if hasattr(cfg.model, "output_feedback") and cfg.model.output_feedback:
    #     output_feedback=cfg.model.output_feedback
    #     nb_inputs += 2
    # else:
    #     output_feedback=False

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag=="fine-tune" or "finetune":
        batchsize_ = cfg.training.batchsize_finetuning
        requires_grad = True
    else:
        batchsize_ = cfg.training.batchsize
    assert requires_grad==True, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function 根据配置获取激活函数
    act_fn = get_actfn(cfg)

    # Regularizer list 获取正则化器
    regs = get_regularizers(cfg)

    # Compute mean firing rates for initializer 计算输入数据的平均发放率
    print("计算发放率:")
    # assert data is not None, "数据不能为空"
    if data is not None and cfg.initializer.compute_nu:
        # if output_feedback:
        #     mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
        # else:
        #     mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        Warning("数据为空，无法计算输入脉冲频率，使用默认值")
        mean1 = cfg.initializer.nu



    hidden_init, readout_init = get_initializers(cfg, mean1, dtype) #根据配置和输入脉冲频率获取隐藏层和输出层的初始化器
    print(mean1)



    # attention 的相关配置
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    # if cfg.model.BatchNorm:
    #     batchNorm=True
    # else:
    #     connection_class=Connection
    #     batchNorm=False
    model.seq_forward=cfg.model.seq_forward
    model.propagate_seq=[]




    # INPUT LAYER
    if cfg.model.reshape_num>1:
        if nb_inputs==192:
            inputshape=(int(cfg.model.reshape_num), int(nb_inputs/cfg.model.reshape_num))
        elif nb_inputs==96:
            inputshape=(int(cfg.model.reshape_num/2), int(2*nb_inputs/cfg.model.reshape_num))
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
            # output_feedback=output_feedback,
        )
    )
    current_src_grp = input_group
    model.propagate_seq.append(input_group)



    # 定义隐藏层神经元的参数。
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }
    connection_kwargs = {
        "requires_grad": requires_grad,
        "bias": cfg.model.Linear_bias,
        "adaptive_bn": cfg.model.adaptive_bn,
        "bn2ln": True if cfg.model.Repconv == "LayerNorm" else False,
    }
    if cfg.model.Repconv == "LayerNorm":
        print("LayerNorm")




    # block1: 线性隐藏层
    # 循环构建隐藏层，每层使用 CustomLIFGroup 作为神经元类型，并根据配置设置参数。
    # 初始化隐藏层，并根据输入脉冲频率调整权重。
    # assert cfg.model.nb_linear_hidden==0, "还没进行该部分的优化，目前不能用这个"
    for i in range(cfg.model.nb_linear_hidden):

        # hidden_layer = Layer(
        #     name="block1_linear_hidden",
        #     model=model,
        #     size=(cfg.model.linear_hidden_dim[i],cfg.model.linear_hidden_size[i]),
        #     input_group=current_src_grp,
        #     recurrent=cfg.model.linear_hidden_recurrent[i],
        #     regs=regs,
        #     neuron_class=CustomLIFGroup,
        #     neuron_kwargs=hidden_neuron_kwargs,
        #     connection_class=connection_class,
        #     connection_kwargs={
        #         "row": True
        #     },
        # )
        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            # size=(cfg.model.linear_hidden_dim[i],cfg.model.linear_hidden_size[i]),
            size=input_group.shape,
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={**connection_kwargs,
                               'row': True,
                               'name':'block1_linear_hidden_connection',
                               'shortcut': cfg.model.hidden_shortcut,
                               },
        )
        if cfg.model.hidden_shortcut:
            print("hidden_shortcut is True")
            hidden_layer.connections[1].shortcut=False
        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # # if output_feedback and data is not None:
        # #     assert  hidden_layer.connections[0].op.weight.shape[1] == 98 or hidden_layer.connections[0].op.weight.shape[1] == 194, \
        # #         "The shape of input group q, k and v must be 1 demention."
        # #     with torch.no_grad():
        # #         hidden_layer.connections[0].op.weight[:, -2:] /= mean3 / mean1
        # #     if i == 0 and nb_inputs == 194 and cfg.data.data_len == 192:
        # #         with torch.no_grad():
        # #             hidden_layer.connections[0].op.weight[:, 96:-2] /= mean2 / mean1
        # # elif i == 0 and nb_inputs == 192 and data is not None and cfg.data.data_len==192:
        # #     with torch.no_grad():
        # #         hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1
        # if i == 0 and nb_inputs == 192 and data is not None:
        #     with torch.no_grad():
        #         if mean2 == 0:
        #             hidden_layer.connections[0].op.weight[:, 96:193].zero_()
        #         else:
        #             hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

        for c in hidden_layer.connections:
            model.propagate_seq.append(c)
        model.propagate_seq.append(hidden_layer.output_group)





    attention_scr_size = current_src_grp.shape
    # block2：通道注意力层（）
    for i in range(cfg.model.nb_Attention_hidden):


        # QKV层
        if cfg.model.Attention_qkv=="conv_linear" :
            print("Attention_qkv is conv_linear")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V5(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d_ln if cfg.model.Repconv == "LayerNorm"
                    else RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs=connection_kwargs,
                cfg=cfg,
            )
        elif cfg.model.Attention_qkv=="conv":
            print("Attention_qkv is conv")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d_ln if cfg.model.Repconv=="LayerNorm"
                    else RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs=connection_kwargs,
            )
        elif cfg.model.Attention_qkv=="linear":
            if cfg.model.commom_linear:
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He正态分布

            print("Attention_qkv is linear")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer(
                model=model,
                size=(cfg.model.nb_Attention_embed[i],cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                # mean1=mean1,mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={**connection_kwargs, 'row': True}
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
        if not requires_grad:
            for group in [linear_group_q, linear_group_k, linear_group_v]:
                for param in group.parameters():
                    param.requires_grad = False
        if cfg.model.Attention_parameter_scale != 1:
            for c in model.connections:
                with torch.no_grad():  # 禁用梯度计算
                    c.op.weight *= cfg.model.Attention_parameter_scale
            # if cfg.model.nb_linear_hidden==0:
            #     for c in model.connections:
            #         with torch.no_grad():  # 禁用梯度计算
            #             c.op.weight *= 100
            # else:
            #     ValueError("Attention_parameter_scale 只能在nb_linear_hidden=0时使用, 但当前nb_linear_hidden={}".format(cfg.model.nb_linear_hidden))
            #





        # Attention-score层
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        atten_connection=model.add_connection(
            ChannelAttentionConnection_multiHead(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                # shortcut="VS",
                dtype=dtype,
            )
        )
        model.propagate_seq.append(atten_connection)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    **connection_kwargs
                    # requires_grad = requires_grad,
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
        if not requires_grad:
            for param in ChannelAttention_group.parameters():
                param.requires_grad = False
        model.propagate_seq.append(ChannelAttention_group)




        # MLP层
        if cfg.model.shortcut_Linear:
            # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear ,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs=connection_kwargs
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            # if nb_inputs == 192:
            #     with torch.no_grad():
            #         if mean2 == 0:
            #             linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
            #         else:
            #             linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

            # connection用于连接attention的output group以及MLPgroup
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    **connection_kwargs
                )
            )
            hidden_init.initialize(linear_connection_MLP) # initialize
            if cfg.model.MLP_id_map:
                with torch.no_grad():  # 禁用梯度计算
                    linear_connection_MLP.op.weight.data.zero_()  # 使用 in-place 操作 _zero_()


            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
            model.propagate_seq.append(linear_connection_MLP)
            # if cfg.model.forward_shortcut:
            #     linear_connection_MLP.shortcut = True
            #     if current_src_grp != input_group:
            #         linear_layer_MLP.connections[0].shortcut = True

            if current_src_grp!=input_group:
                input_connection_MLP = model.add_connection(
                    connection_class(
                        src=input_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_input_connection_MLP",
                        operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(input_connection_MLP)  # initialize
                model.propagate_seq.append(input_connection_MLP)

            model.propagate_seq.append(linear_layer_MLP.output_group)
            current_src_grp = linear_layer_MLP.output_group
        else:
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group
            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
            model.propagate_seq.append(linear_layer_MLP.output_group)




    # block3: 输出层
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    **connection_kwargs
                )
            )
            if not cfg.model.output_BN:
                print("Removing output batch normalization layer...")
                con_ro.bn = None
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp,
                             readout_group,
                             dtype=dtype,
                             flatten_input=True,
                             operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                             **connection_kwargs)
        )
        if not cfg.model.output_BN:
            print("Removing output batch normalization layer...")
            con_ro.bn = None
        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # if output_feedback:
    #     model.groups[0].add_src(model.groups[-1])
    if not cfg.model.BatchNorm:
        model.rep()

    # if stateFlag == "pretrain" or stateFlag == 'default':
    #     model.multi_BN_set(cfg)
    #     model.multiHidden_set(cfg)
    #     # if cfg.model.multi_BN:
    #     #
    #     #     model.multiBN = nn.ModuleDict()
    #     #     for key in cfg.pretrain_monkeys:
    #     #         model.multiBN[key] = nn.ModuleDict()
    #     #
    #     #     for c_idx, c in enumerate(model.connections):
    #     #         c_id = f"connection_{c_idx}"
    #     #
    #     #         if hasattr(c, "bn") and c.bn is not None:
    #     #             for key in cfg.pretrain_monkeys:
    #     #                 model.multiBN[key][c_id] = copy.deepcopy(c.bn).to(cfg.device)
    #     #             # 存储引用以便后续访问
    #     #             c.multiBN_id = c_id
    #     #
    #     #         if isinstance(c.op, RepClassModule):
    #     #             for key in cfg.pretrain_monkeys:
    #     #                 model.multiBN[key][c_id] = nn.ModuleDict()
    #     #
    #     #             # 存储引用以便后续访问
    #     #             c.multiBN_id = c_id
    #     #             # c.module_ids = {}
    #     #
    #     #             if hasattr(c.op, "rbr_1x1"):
    #     #                 for key in cfg.pretrain_monkeys:
    #     #                     model.multiBN[key][c_id]["rbr_1x1_bn"]=copy.deepcopy(c.op.rbr_1x1.bn).to(cfg.device)
    #     #             if hasattr(c.op, "rbr_dense"):
    #     #                 for key in cfg.pretrain_monkeys:
    #     #                     model.multiBN[key][c_id]["rbr_dense_bn"]=copy.deepcopy(c.op.rbr_dense.bn).to(cfg.device)
    #     #             if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
    #     #                 for key in cfg.pretrain_monkeys:
    #     #                     model.multiBN[key][c_id]["rbr_identity"]=copy.deepcopy(c.op.rbr_identity).to(cfg.device)
    #     #             if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
    #     #                 ValueError("MultiBN does not support nn.ModuleList in RepClassModule")
    #     #
    #     #
    #     #
    #     #             # for module_idx, module in enumerate(c.op.modules()):
    #     #             #     if isinstance(module, nn.BatchNorm1d):
    #     #             #         module_id = f"module_{module_idx}"
    #     #             #         c.module_ids[module] = module_id
    #     #             #
    #     #             #         for key in cfg.pretrain_monkeys:
    #     #             #             model.multiBN[key][c_id][module_id] = copy.deepcopy(module).to(cfg.device)
    #     #             #             # model.multiBN[key][c][module] = copy.deepcopy(module.state_dict())
    #     # else:
    #     #     model.multiBN = False
    #
    #     # if cfg.model.multi_hidden:
    #     #     assert cfg.model.multi_BN, "multiHidden 需要 multiBN 的支持"
    #     #     assert cfg.model.nb_linear_hidden==1, "multiHidden 只支持 nb_linear_hidden=1 的情况"
    #     #     model.multiHidden = nn.ModuleDict()
    #     #     for key in cfg.pretrain_monkeys:
    #     #         model.multiHidden[key] = nn.ModuleDict()
    #     #         model.multiHidden[key]["op"] = copy.deepcopy(hidden_layer.connections[0].op).to(cfg.device)
    #     #         model.multiHidden[key]["LIF"] = copy.deepcopy(hidden_layer.output_group).to(cfg.device)
    #     #         if cfg.model.linear_hidden_recurrent[0]:
    #     #             model.multiHidden[key]["recurrent"] = copy.deepcopy(hidden_layer.connections[1].op).to(cfg.device)
    #     # else:
    #     #     model.multiHidden = False
    # else:
    #     model.multiBN = False
    #     model.multiHidden = False
    model.multi_BN_set(cfg)
    model.multiHidden_set(cfg)

    return model

# Variant that uses convolutional hidden layers
def get_model_attention_V3(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build a spiking model made of conv hidden layers, channel-attention blocks and a readout.

    The model is assembled in three stages:

    1. ``cfg.model.nb_conv_hidden`` ``Channel1dConvConnectionLayer`` blocks.
    2. ``cfg.model.nb_Attention_hidden`` attention blocks. Each block creates
       Q/K/V groups via ``get_q_k_v_Layer_V2``, feeds them through a
       ``ChannelAttentionConnection_multiHead`` connection into a
       ``CustomLIFGroup``, and finishes with an MLP layer that has a shortcut
       from the block's source group.
    3. A readout stage: either the groups returned by ``get_custom_readouts``
       combined by ``AverageReadouts``, or a single ``CustomReadoutGroup``.

    Args:
        cfg: Configuration object. Fields under ``cfg.data``, ``cfg.training``
            and ``cfg.model`` are read, plus ``cfg.device`` or, when
            ``cfg.multi_cuda`` is set, ``cfg.gpu_ids``.
        nb_inputs: Number of input units; the input group is created with
            shape ``(1, nb_inputs)``.
        dtype: dtype forwarded to the model and to several connections.
        data: Optional data handed to ``compute_input_firing_rates``. When
            ``None`` the initializers are built with ``mean1=None``.
        stateFlag: ``"default"``, ``"pretrain"`` or ``"fine-tune"``; selects
            which batch size from ``cfg.training`` is used. Any other value
            falls back to ``cfg.training.batchsize``.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``cfg.model.Layer_init`` is not ``True`` or the
            Q/K/V groups of an attention block do not share one shape.
    """


    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    # With multi-GPU enabled the model is placed on the first listed GPU.
    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    # Pick the batch size that matches the training stage.
    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag=="fine-tune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize
    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function, configured from cfg
    act_fn = get_actfn(cfg)

    # Regularizer list
    regs = get_regularizers(cfg)

    # Compute mean input firing rates for the initializer
    # (mean2 is not used anywhere else in this variant).
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None

    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)  # initializers for hidden and readout layers, built from cfg and the input firing rate




    # Attention-related configuration
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    # BatchNorm switches the connection class used for the attention->MLP link.
    if cfg.model.BatchNorm:
        connection_class=Connection_withBatchNorm
        batchNorm=True
    else:
        connection_class=Connection
        batchNorm=False





    # INPUT LAYER
    # # # # # # # #
    input_group = model.add_group(
        InputGroup(
            (1,nb_inputs),
            dropout_p=cfg.model.dropout_p
        )
    )








    # HIDDEN LAYERS
    # # # # # # # #
    current_src_grp = input_group

    # Parameters shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }





    # assert cfg.model.nb_linear_hidden==0 or cfg.model.nb_conv_hidden==0, \
    #     "at most one kind of hidden can be used"
    # block1: convolutional hidden layers
    # Build the hidden layers in a loop; each one uses CustomLIFGroup neurons
    # configured from cfg and is initialized with hidden_init.
    for i in range(cfg.model.nb_conv_hidden):

        # if nb_inputs==192:
        #     conv_stride = 2
        # elif nb_inputs==96:
        #     conv_stride = 1
        # else:
        #     raise RuntimeError("Unsupported input size. Expected nb_inputs to be either 192 or 96.")
        # conv_stride = nb_inputs/cfg.model.conv_hidden_size[i]
        # assert conv_stride.is_integer(), "conv_stride must be an integer."
        # conv_stride=int(conv_stride)

        # NOTE(review): stride is read without a per-layer index, unlike
        # kernel_size / nb_filters / recurrent — confirm this is intended.
        hidden_layer = Channel1dConvConnectionLayer(
            name="block1_conv_hidden",
            model=model,
            kernel_size=cfg.model.conv_hidden_kernel_size[i],
            nb_filters=cfg.model.conv_hidden_nb_filters[i],

            stride=cfg.model.conv_hidden_stride,
            shape=None,

            input_group=current_src_grp,
            recurrent=cfg.model.conv_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )


        # The next layer reads from this layer's output.
        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # if i == 0 and nb_inputs == 192 and data is not None:
        #     with torch.no_grad():
        #         hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1





    # Shape of the group feeding the first attention block; reused as the MLP
    # size of every attention block below.
    attention_scr_size = current_src_grp.shape
    # block2: channel-attention layers
    for i in range(cfg.model.nb_Attention_hidden):


        # Q/K/V layer
        linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V2(
            model=model,
            kernel_size=cfg.model.Attention_kernel_size[i],
            nb_head=cfg.model.nb_Attention_embed[i],
            input_group=current_src_grp,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_init=hidden_init,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            # connection_class=connection_class,
            connection_kwargs={},
        )




        # Attention-score layer
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group that receives the attention output (same shape as Q/K/V)
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection combining Q, K and V into the attention group
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                # shortcut="VS",
                dtype=dtype,
            )
        )



        # MLP layer
        if cfg.model.shortcut_Linear:
            # The layer is built from the shortcut (block source) group; the
            # attention output is wired in by the extra connection below.
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size[1],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            # connection from the attention output group into the MLP group
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_MLP) # initialize


            current_src_grp = linear_layer_MLP.output_group
        else:
            # Here the layer takes the attention group as input and the block
            # source group as shortcut in a single constructor call.
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size,
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group






    # block3: output layer
    # READOUT LAYER
    # # # # # # # #
    # NOTE(review): the readout connections below always use the plain
    # Connection class, independent of cfg.model.BatchNorm — confirm intended.
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(Connection(current_src_grp, g, dtype=dtype))
            readout_init.initialize(con_ro)

        # Presumably the readout groups added above are the trailing entries
        # of model.groups; they are handed to AverageReadouts as a slice.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            Connection(current_src_grp, readout_group, dtype=dtype)
        )

        readout_init.initialize(con_ro)

    return model

# Variant with output feedback and cross-attention
def get_model_attention_V4(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build a spiking attention model with optional output feedback and cross-attention.

    The model is assembled in these stages:

    1. An input group of shape ``(1, nb_inputs)``. When
       ``cfg.model.output_feedback`` is set, ``nb_inputs`` is first increased
       by 2 and the first model group is finally given the last model group
       as an additional source.
    2. When ``cfg.model.output_crossAttention`` is set, the input is split via
       identity connections into an "input_input" group and an
       "input_feedback" group, and the latter feeds a LIF "feedback_group".
    3. ``cfg.model.nb_linear_hidden`` linear hidden layers.
    4. ``cfg.model.nb_Attention_hidden`` attention blocks (Q/K/V from
       ``get_q_k_v_Layer_V3``, multi-head channel attention, optional
       recurrent connection, then an MLP layer with or without shortcut).
    5. A readout stage: either the groups from ``get_custom_readouts``
       combined by ``AverageReadouts`` or a single ``CustomReadoutGroup``.

    Args:
        cfg: Configuration object. Fields under ``cfg.data``, ``cfg.training``
            and ``cfg.model`` are read, plus ``cfg.device`` or, when
            ``cfg.multi_cuda`` is set, ``cfg.gpu_ids``.
        nb_inputs: Number of input units before the optional 2 feedback
            channels are added.
        dtype: dtype forwarded to the model and to several connections.
        data: Optional data handed to ``compute_input_firing_rates``. When
            ``None`` the initializers are built with ``mean1=None``.
        stateFlag: ``"pretrain"`` or ``"fine-tune"`` select the matching batch
            size from ``cfg.training``; any other value (including
            ``"default"``) uses ``cfg.training.batchsize``.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``cfg.model.Layer_init`` is not ``True`` or the
            Q/K/V groups of an attention block do not share one shape.
        ValueError: If attention blocks are requested while
            ``cfg.model.output_crossAttention`` is disabled, because the
            feedback group they need is only built in that branch.
    """
    if hasattr(cfg.model, "output_feedback") and cfg.model.output_feedback:
        output_feedback = cfg.model.output_feedback
        nb_inputs += 2  # two extra input channels carry the fed-back output
    else:
        output_feedback = False

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    # With multi-GPU enabled the model is placed on the first listed GPU.
    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Pick the batch size that matches the training stage; "default" and any
    # unrecognized flag share the plain training batch size.
    if stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag == "fine-tune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function, configured from cfg
    act_fn = get_actfn(cfg)

    # Regularizer list
    regs = get_regularizers(cfg)

    # Compute mean input firing rates for the initializer. With output
    # feedback the helper returns a third value, which is not used here.
    if data is not None:
        if output_feedback:
            mean1, mean2, _ = compute_input_firing_rates(data, cfg, nb_inputs)
        else:
            mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None

    # Initializers for hidden and readout layers, built from cfg and the
    # input firing rate.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Attention-related configuration
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    if cfg.model.BatchNorm:
        connection_class = Connection_withBatchNorm
        batchNorm = True
    else:
        connection_class = Connection
        batchNorm = False

    # HIDDEN LAYERS
    # # # # # # # #
    # Parameters shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # INPUT LAYER
    # # # # # # # #
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            dropout_p=cfg.model.dropout_p,
            output_feedback=output_feedback,
        )
    )
    current_src_grp = input_group

    # Bound up front so the attention blocks below can detect that the
    # cross-attention branch was skipped instead of hitting an unbound local.
    feedback_group = None
    if cfg.model.output_crossAttention:
        # NOTE(review): `nb_inputs - 2` assumes the 2 feedback channels were
        # added above, i.e. that output_feedback is enabled as well — confirm.
        input_input_group = custom_fake_InputGroup(
            (1, nb_inputs - 2),
            name="input_input",
        )
        model.add_group(input_input_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=input_input_group,
                dtype=dtype,
                feedback=False,
                name="input_input_con",
            )
        )

        input_feedback_group = custom_feedback_InputGroup(
            (cfg.model.output_feedback_timestep, 2),
            name="input_feedback",
            output_feedback_timestep=cfg.model.output_feedback_timestep,
        )
        model.add_group(input_feedback_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=input_feedback_group,
                dtype=dtype,
                feedback=True,
                name="input_feedback_con",
            )
        )

        feedback_group = CustomLIFGroup(
            input_feedback_group.shape,
            name="feedback_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(feedback_group)
        model.add_connection(
            Connection_Identity(
                src=input_feedback_group,
                dst=feedback_group,
                dtype=dtype,
                feedback=True,
                name="feedback_con",
            )
        )

        # Downstream layers read from the input-only group.
        current_src_grp = input_input_group

    # block1: linear hidden layers
    # Each layer uses CustomLIFGroup neurons configured from cfg and is
    # initialized with hidden_init.
    for i in range(cfg.model.nb_linear_hidden):
        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i], cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # For 192-channel data, rescale the weights of input columns 96
        # onward of the first layer by mean1 / mean2.
        # NOTE(review): nb_inputs already includes the 2 feedback channels
        # when output_feedback is on, so this check then only matches a
        # caller-supplied nb_inputs of 190 — confirm whether 194 was meant.
        if i == 0 and nb_inputs == 192 and data is not None and cfg.data.data_len == 192:
            with torch.no_grad():
                hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1

    # The Q/K/V builder of this variant is always given the feedback group,
    # which only exists when the cross-attention branch above ran.
    if cfg.model.nb_Attention_hidden > 0 and feedback_group is None:
        raise ValueError(
            "get_model_attention_V4: cfg.model.nb_Attention_hidden > 0 requires "
            "cfg.model.output_crossAttention to be enabled, because the "
            "attention blocks need the feedback group built in that branch."
        )

    # Shape of the group feeding the first attention block; reused as the MLP
    # size of every attention block below.
    attention_scr_size = current_src_grp.shape
    # block2: channel-attention layers
    for i in range(cfg.model.nb_Attention_hidden):

        # Q/K/V layer
        linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V3(
            model=model,
            kernel_size=cfg.model.Attention_kernel_size[i],
            nb_head=cfg.model.nb_Attention_embed[i],
            stride=cfg.model.Attention_conv_stride[i],
            input_group=current_src_grp,
            input_feedback_group=feedback_group,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_init=hidden_init,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )

        # Attention-score layer
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group that receives the attention output (same shape as Q/K/V)
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection combining Q, K and V into the attention group
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        if cfg.model.Attention_recurrent[i]:
            # Optional recurrent connection on the attention group.
            # NOTE(review): the name string duplicates the MLP connection's
            # name below — looks like a copy/paste leftover; left unchanged.
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)

        # MLP layer
        if cfg.model.shortcut:
            if cfg.model.shortcut_Linear:
                # The layer is built from the shortcut (block source) group;
                # the attention output is wired in by the connection below.
                linear_layer_MLP = LinearLayer_of_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=attention_scr_size[1],
                    shortcut_group=current_src_grp,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    batchNorm=batchNorm,
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs={},
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                # connection from the attention output group into the MLP group
                linear_connection_MLP = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_linera_connection_MLP",
                    )
                )
                hidden_init.initialize(linear_connection_MLP)  # initialize

                current_src_grp = linear_layer_MLP.output_group
            else:
                # The layer takes the attention group as input and the block
                # source group as shortcut in a single constructor call.
                linear_layer_MLP = LinearLayer_with_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=attention_scr_size,
                    input_group=ChannelAttention_group,
                    shortcut_group=current_src_grp,
                    shortcut_opFlag=cfg.model.shortcut_opFlag,
                    batchNorm=batchNorm,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs={},
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                current_src_grp = linear_layer_MLP.output_group
        else:
            # No shortcut: a plain layer on top of the attention group.
            # NOTE(review): Connection_withBatchNorm is hard-coded here rather
            # than using connection_class, so cfg.model.BatchNorm is ignored
            # in this branch — confirm intended.
            linear_layer_MLP = Layer(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size,
                input_group=ChannelAttention_group,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=Connection_withBatchNorm,
                connection_kwargs={
                    "row": True
                },
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group

    # block3: output layer
    # READOUT LAYER
    # # # # # # # #
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

        # Presumably the readout groups added above are the trailing entries
        # of model.groups; they are handed to AverageReadouts as a slice.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True)
        )

        readout_init.initialize(con_ro)

    # Close the feedback loop: the first model group additionally receives
    # the last model group (the readout) as a source.
    if output_feedback:
        model.groups[0].add_src(model.groups[-1])

    return model

# 对应session code
def get_model_attention_V5(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the channel-attention spiking model for the session-code setup.

    The network is assembled in three stages: optional linear hidden layers
    (block 1), ``cfg.model.nb_Attention_hidden`` channel-attention blocks made
    of a Q/K/V stage, a multi-head attention connection and an MLP layer with
    a shortcut (block 2), and one or several readout groups (block 3).

    When ``cfg.session_encode`` is truthy, the input group is split through
    identity connections into the first ``nb_inputs - 10`` units (which feed
    the network) and the last 10 units (handed to the linear Q/K/V stage as
    ``session_encode_input``).

    Output feedback is not wired up in this variant.

    Args:
        cfg: Configuration object; its ``data``, ``model`` and ``training``
            sections as well as ``device`` / ``multi_cuda`` / ``gpu_ids`` /
            ``session_encode`` are read here.
        nb_inputs: Number of input units.
        dtype: dtype forwarded to the model and to the connections.
        data: Dataset forwarded to ``compute_input_firing_rates``; must not
            be ``None``.
        stateFlag: Selects the batch size: ``"pretrain"`` uses
            ``cfg.training.batchsize_pretrain``, ``"fine-tune"`` or
            ``"finetune"`` uses ``cfg.training.batchsize_finetuning``, and
            anything else (including ``"default"``) uses
            ``cfg.training.batchsize``.

    Returns:
        The ``CustomRecurrentSpikingModel`` with all groups and connections
        added.

    Raises:
        AssertionError: If ``data`` is ``None``, if ``cfg.model.Layer_init``
            is not ``True``, or if the Q, K and V groups differ in shape.
        ValueError: If ``cfg.model.Attention_qkv`` is neither ``"conv"`` nor
            ``"linear"``.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # NOTE: the previous condition `stateFlag == "fine-tune" or "finetune"`
    # was always truthy, so every unrecognised flag silently selected the
    # fine-tuning batch size. Unrecognised flags now use the default size.
    if stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag in ("fine-tune", "finetune"):
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)

    # Regularizer list.
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers and for rescaling.
    assert data is not None, "数据不能为空"
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        # Only reachable when assertions are stripped (``python -O``); bind
        # both names so later references do not raise NameError.
        mean1 = mean2 = None

    # Initializers for the hidden layers and the readout.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Attention-related configuration.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    if cfg.model.BatchNorm:
        connection_class = Connection_withBatchNorm
        batchNorm = True
    else:
        connection_class = Connection
        batchNorm = False

    # INPUT LAYER
    # # # # # # # #
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            dropout_p=cfg.model.dropout_p,
        )
    )
    current_src_grp = input_group
    if cfg.session_encode:
        # The last `session_encode_len` input units carry the session code.
        session_encode_len = 10

        # First part of the input: everything except the session code.
        input_input_group = custom_fake_InputGroup(
            (1, nb_inputs - session_encode_len),
            name="input_input",
        )
        model.add_group(input_input_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=input_input_group,
                dtype=dtype,
                feedback=False,
                name="input_input_con",
                input_range=(0, nb_inputs - session_encode_len),
            )
        )
        current_src_grp = input_input_group

        # Second part of the input: the session code itself.
        # NOTE(review): group and connection reuse the names of the first
        # split ("input_input" / "input_input_con") — confirm this is intended.
        session_code_input_group = custom_fake_InputGroup(
            (1, session_encode_len),
            name="input_input",
        )
        model.add_group(session_code_input_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=session_code_input_group,
                dtype=dtype,
                feedback=False,
                name="input_input_con",
                input_range=(nb_inputs - session_encode_len, nb_inputs),
            )
        )

    # HIDDEN LAYERS
    # # # # # # # #
    # Neuron parameters shared by all hidden groups.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # block1: linear hidden layers.
    # Each layer uses CustomLIFGroup neurons and is initialized right after
    # construction; the first layer is additionally rescaled by the ratio of
    # the two mean firing rates when the input has 192 units.
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i], cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                # Columns 96 onwards are rescaled by mean1 / mean2, or
                # silenced when mean2 is zero (avoids a division by zero).
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

    attention_scr_size = current_src_grp.shape
    # block2: channel-attention layers.
    for i in range(cfg.model.nb_Attention_hidden):

        # Q/K/V stage
        if cfg.model.Attention_qkv == "conv":
            print("Attention_qkv is conv")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
        elif cfg.model.Attention_qkv == "linear":
            # Bind both names up front: previously `session_encode_linear`
            # was unbound (UnboundLocalError) when cfg.session_encode was set
            # but cfg.model.commom_linear was not.
            common_linear = None
            session_encode_linear = None
            if cfg.model.commom_linear:
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                # He (Kaiming) normal initialization.
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')
                if cfg.session_encode:
                    session_encode_linear = nn.Linear(session_encode_len, current_src_grp.nb_units)
                    # He (Kaiming) normal initialization.
                    nn.init.kaiming_normal_(session_encode_linear.weight, mode='fan_in', nonlinearity='relu')

            print("Attention_qkv is linear")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V1(
                model=model,
                size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                mean1=mean1, mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={
                    "row": True,
                    "common_linear": common_linear,
                    "session_encode_linear": session_encode_linear,
                    "session_encode_input": session_code_input_group if cfg.session_encode else None,
                },
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))

        # Attention-score stage
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        if cfg.model.Attention_recurrent[i]:
            # Optional recurrent connection on the attention group.
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)

        # MLP stage
        if cfg.model.shortcut_Linear:
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size[1],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            # Same firing-rate rescaling as in block 1; guarded on `data`
            # for consistency with that block.
            if nb_inputs == 192 and data is not None:
                with torch.no_grad():
                    if mean2 == 0:
                        linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
                    else:
                        linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

            # Connection from the attention group into the MLP output group.
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_MLP)  # initialize

            current_src_grp = linear_layer_MLP.output_group
        else:
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size,
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group

    # block3: output layer
    # READOUT LAYER
    # # # # # # # #
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

        # Average over the readout groups that were just appended.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True)
        )

        readout_init.initialize(con_ro)

    return model

# Corresponds to the brain-area-wise setup
def get_model_attention_V6(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the brain-area-wise channel-attention spiking model.

    The 192-unit input is split through identity connections into two halves
    (labelled M1 and S1 in the connection names). Each half gets its own
    Q/K/V stage and channel-attention group, initialized with its own
    initializer ``hidden_inits[area]``. Both attention groups and both input
    halves are then connected to one shared MLP group, which feeds the
    readout.

    Args:
        cfg: Configuration object; its ``data``, ``model`` and ``training``
            sections as well as ``device`` / ``multi_cuda`` / ``gpu_ids`` are
            read here.
        nb_inputs: Number of input units; must equal 192.
        dtype: dtype forwarded to the model and to the connections.
        data: Dataset forwarded to ``compute_input_firing_rates``; must not
            be ``None``.
        stateFlag: Selects the batch size: ``"pretrain"`` uses
            ``cfg.training.batchsize_pretrain``, ``"fine-tune"`` or
            ``"finetune"`` uses ``cfg.training.batchsize_finetuning``, and
            anything else (including ``"default"``) uses
            ``cfg.training.batchsize``.

    Returns:
        The ``CustomRecurrentSpikingModel`` with all groups and connections
        added.

    Raises:
        AssertionError: If ``data`` is ``None``, ``cfg.model.Layer_init`` is
            not ``True``, ``cfg.model.area_wise`` or
            ``cfg.model.shortcut_Linear`` is falsy, ``nb_inputs`` is not 192,
            ``cfg.model.nb_Attention_hidden`` is not 1, or the Q, K and V
            groups differ in shape.
        ValueError: If ``cfg.model.Attention_qkv`` is neither ``"conv"`` nor
            ``"linear"``.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # NOTE: the previous condition `stateFlag == "fine-tune" or "finetune"`
    # was always truthy, so every unrecognised flag silently selected the
    # fine-tuning batch size. Unrecognised flags now use the default size.
    if stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag in ("fine-tune", "finetune"):
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)
    # Regularizer list.
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers.
    assert data is not None, "数据不能为空"
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        # Only reachable when assertions are stripped (``python -O``); bind
        # both names so the list below does not raise NameError.
        mean1 = mean2 = None
    mean = [mean1, mean2]
    # Here `get_initializers` receives a list of means; the first result is
    # indexed per area below (and with -1 for the shared recurrent
    # connection).
    hidden_inits, readout_init = get_initializers(cfg, mean, dtype)

    # Attention-related configuration.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    if cfg.model.BatchNorm:
        connection_class = Connection_withBatchNorm
        batchNorm = True
    else:
        connection_class = Connection
        batchNorm = False

    # INPUT LAYER
    # # # # # # # #
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            dropout_p=cfg.model.dropout_p,
        )
    )
    assert cfg.model.area_wise, "目前只支持area_wise"
    if cfg.model.area_wise:
        current_src_grps = []
        assert nb_inputs == 192, "目前只支持192维的输入数据"
        nb_inputs_per_area = int(nb_inputs / 2)

        # First half of the input units.
        input_area1_group = custom_fake_InputGroup(
            shape=(1, nb_inputs_per_area),
            name="input_input",
        )
        model.add_group(input_area1_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=input_area1_group,
                dtype=dtype,
                feedback=False,
                name="input_brain_area1(M1)_con",
                input_range=(0, nb_inputs_per_area),
            )
        )
        current_src_grps.append(input_area1_group)

        # Second half of the input units.
        # NOTE(review): the connection name still says "area1" and both
        # groups are named "input_input" — confirm this is intended.
        input_area2_group = custom_fake_InputGroup(
            shape=(1, nb_inputs_per_area),
            name="input_input",
        )
        model.add_group(input_area2_group)
        model.add_connection(
            Connection_Identity(
                src=input_group,
                dst=input_area2_group,
                dtype=dtype,
                feedback=False,
                name="input_brain_area1(S1)_con",
                input_range=(nb_inputs_per_area, nb_inputs),
            )
        )
        current_src_grps.append(input_area2_group)

    # HIDDEN LAYERS
    # # # # # # # #
    # Neuron parameters shared by all hidden groups.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # Channel-attention layer (exactly one is supported, hence i = 0).
    assert cfg.model.nb_Attention_hidden == 1, "目前只支持一个注意力层"
    i = 0
    # Initializer used for the shared recurrent MLP connection.
    hidden_init_mean = hidden_inits[-1]
    attention_output_group = []
    attention_scr_size = nb_inputs_per_area
    for area_indx in range(2):  # only two brain areas are supported for now
        current_src_grp = current_src_grps[area_indx]
        assert current_src_grp.nb_units == attention_scr_size, \
            "The shape of input group must be the same."
        hidden_init = hidden_inits[area_indx]

        # Q/K/V stage
        if cfg.model.Attention_qkv == "conv":
            print("Attention_qkv is conv")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
        elif cfg.model.Attention_qkv == "linear":
            common_linear = None
            if cfg.model.commom_linear:
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                # He (Kaiming) normal initialization.
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')
            print("Attention_qkv is linear")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V1(
                model=model,
                size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                mean1=mean1, mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={
                    "row": True,
                    "common_linear": common_linear,
                },
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))

        # Attention-score stage
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        if cfg.model.Attention_recurrent[i]:
            # Optional recurrent connection on the attention group.
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
        attention_output_group.append(ChannelAttention_group)

    # MLP stage
    assert cfg.model.shortcut_Linear, "目前只支持shortcut_Linear"
    if cfg.model.shortcut_Linear:

        # Shared MLP group receiving input from both areas.
        linear_group_MLP = CustomLIFGroup(
            attention_scr_size,
            name="block2_linera_group_MLP",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(linear_group_MLP)
        # Per area: one connection from the attention output and one
        # shortcut connection from the raw input half, both initialized with
        # that area's initializer.
        for area_indx in range(2):
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=attention_output_group[area_indx],
                    dst=linear_group_MLP,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_inits[area_indx].initialize(linear_connection_MLP)  # initialize

            linear_connection_MLP_shortcut = model.add_connection(
                connection_class(
                    src=current_src_grps[area_indx],
                    dst=linear_group_MLP,
                    flatten_input=True,
                    name="block2_linera_connection_MLP_shortcut",
                )
            )
            hidden_inits[area_indx].initialize(linear_connection_MLP_shortcut)  # initialize

        # recurrent connection
        if cfg.model.Attention_hidden_recurrent_MLP[i]:
            linear_connection_MLP_recurrent = model.add_connection(
                Connection_withBatchNorm(
                    linear_group_MLP,
                    linear_group_MLP,
                    name="block2_linera_connection_recurrent",
                )
            )
            hidden_init_mean.initialize(linear_connection_MLP_recurrent)
        current_src_grp = linear_group_MLP

    else:
        # Only reachable when assertions are stripped (``python -O``); uses
        # the attention group, source group and initializer left over from
        # the last loop iteration above.
        linear_layer_MLP = LinearLayer_with_shortcut(
            name="block2_attention_linear_MLP",
            model=model,
            size=attention_scr_size,
            input_group=ChannelAttention_group,
            shortcut_group=current_src_grp,
            shortcut_opFlag=cfg.model.shortcut_opFlag,
            batchNorm=batchNorm,
            recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )
        # initialize
        hidden_init.initialize(linear_layer_MLP)
        current_src_grp = linear_layer_MLP.output_group

    # block3: output layer
    # READOUT LAYER
    # # # # # # # #
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

        # Average over the readout groups that were just appended.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True)
        )

        readout_init.initialize(con_ro)

    return model

# Corresponds to the session-classification setup
def get_model_attention_V7(cfg, nb_inputs, dtype, data=None, stateFlag="default", nb_class=None):


    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag=="fine-tune" or "finetune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function 根据配置获取激活函数
    act_fn = get_actfn(cfg)

    # Regularizer list 获取正则化器
    regs = get_regularizers(cfg)

    # Compute mean firing rates for initializer 计算输入数据的平均发放率
    assert data is not None, "数据不能为空"
    if data is not None:
        # if output_feedback:
        #     mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
        # else:
        #     mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None



    hidden_init, readout_init = get_initializers(cfg, mean1, dtype) #根据配置和输入脉冲频率获取隐藏层和输出层的初始化器




    # attention 的相关配置
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    if cfg.model.BatchNorm:
        connection_class=Connection_withBatchNorm
        batchNorm=True
    else:
        connection_class=Connection
        batchNorm=False





    # INPUT LAYER
    # # # # # # # #
    input_group = model.add_group(
        InputGroup(
            (1,nb_inputs),
            dropout_p=cfg.model.dropout_p,
            # output_feedback=output_feedback,
        )
    )
    current_src_grp = input_group




    # HIDDEN LAYERS
    # # # # # # # #
    # 定义隐藏层神经元的参数。
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }





    # block1: 线性隐藏层
    # 循环构建隐藏层，每层使用 CustomLIFGroup 作为神经元类型，并根据配置设置参数。
    # 初始化隐藏层，并根据输入脉冲频率调整权重。
    # assert cfg.model.nb_linear_hidden==0, "还没进行该部分的优化，目前不能用这个"
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i],cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # if output_feedback and data is not None:
        #     assert  hidden_layer.connections[0].op.weight.shape[1] == 98 or hidden_layer.connections[0].op.weight.shape[1] == 194, \
        #         "The shape of input group q, k and v must be 1 demention."
        #     with torch.no_grad():
        #         hidden_layer.connections[0].op.weight[:, -2:] /= mean3 / mean1
        #     if i == 0 and nb_inputs == 194 and cfg.data.data_len == 192:
        #         with torch.no_grad():
        #             hidden_layer.connections[0].op.weight[:, 96:-2] /= mean2 / mean1
        # elif i == 0 and nb_inputs == 192 and data is not None and cfg.data.data_len==192:
        #     with torch.no_grad():
        #         hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1







    attention_scr_size = current_src_grp.shape
    # block2：通道注意力层（）
    for i in range(cfg.model.nb_Attention_hidden):


        # QKV层
        if cfg.model.Attention_qkv=="conv":
            print("Attention_qkv is conv")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                # connection_class=connection_class,
                connection_kwargs={},
            )
        elif cfg.model.Attention_qkv=="linear":
            if cfg.model.commom_linear:
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He正态分布

            print("Attention_qkv is linear")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V1(
                model=model,
                size=(cfg.model.nb_Attention_embed[i],cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                mean1=mean1,mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={
                    "row": True,
                    "common_linear": common_linear if cfg.model.commom_linear else None
                },
            )


        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))




        # Attention-score层
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                # shortcut="VS",
                dtype=dtype,
            )
        )
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)




        # MLP层
        if cfg.model.shortcut_Linear:
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size[1],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            if nb_inputs == 192:
                with torch.no_grad():
                    if mean2 == 0:
                        linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
                    else:
                        linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

            # connection
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_MLP) # initialize


            current_src_grp = linear_layer_MLP.output_group
        else:
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=attention_scr_size,
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group






    # READOUT LAYER block3: 输出层
    if nb_class is not None:
        if cfg.model.multiple_readouts:
            logger.info("Adding session_classfication readout groups")

            custom_readouts_classfication = get_custom_readouts(cfg, size=nb_class, name="session_classfication_readout")
            for g in custom_readouts_classfication:
                model.add_group(g)
                con_ro = model.add_connection(
                    connection_class(
                        current_src_grp, g, dtype=dtype, flatten_input=True,name="session_classfication_readout",
                    )
                )
                readout_init.initialize(con_ro)
            model.class_readout=model.add_group(
                AverageReadouts(model.groups[-len(custom_readouts_classfication):], name="session_classfication_readout")
            )

        else:
            classfication_readout = model.add_group(
                CustomReadoutGroup(
                    nb_class,
                    tau_mem=cfg.model.tau_mem_readout,
                    tau_syn=cfg.model.tau_syn_readout,
                    het_timescales=cfg.model.het_timescales_readout,
                    learn_timescales=cfg.model.learn_timescales_readout,
                    initial_state=-1e-2,
                    is_delta_syn=cfg.model.delta_synapses,
                    name="session_classfication_readout",
                )
            )
            con_ro = model.add_connection(
                Connection_withBatchNorm_with_GradientReverse(
                    current_src_grp,
                    classfication_readout,
                    dtype=dtype,
                    flatten_input=True,
                    name="session_classfication_readout"
                )
            )
            readout_init.initialize(con_ro)
            model.class_readout=classfication_readout

    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))

    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True)
        )

        readout_init.initialize(con_ro)

    return model

# 去除MLP layer
def get_model_attention_V8(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the V8 channel-attention spiking model (variant without the MLP layer).

    The network is assembled in three stages:

    1. ``cfg.model.nb_linear_hidden`` linear ``Layer`` blocks of
       ``CustomLIFGroup`` neurons on top of the (optionally reshaped) input
       group.
    2. ``cfg.model.nb_Attention_hidden`` channel-attention blocks. Each block
       builds q/k/v groups, a ``LinearLayer_of_shortcut`` fed by the current
       source group, a ``ChannelAttentionConnection_multiHead`` into that
       layer's output group and, when the current source is not the input
       group, an additional connection coming straight from the input group.
    3. A readout: either the groups returned by ``get_custom_readouts`` plus
       an ``AverageReadouts`` group, or a single ``CustomReadoutGroup``.

    The function also fills ``model.propagate_seq`` with the groups and
    connections it creates itself and copies ``cfg.model.seq_forward`` onto
    the model.

    Args:
        cfg: Configuration object; its ``data``, ``training`` and ``model``
            sections are read by attribute access.
        nb_inputs: Number of input units. With ``cfg.model.reshape_num > 1``
            only 192 and 96 are supported.
        dtype: dtype forwarded to the model and to the readout / attention
            connections.
        data: Dataset handed to ``compute_input_firing_rates``; must not be
            None.
        stateFlag: ``"default"``, ``"pretrain"``, ``"fine-tune"`` or
            ``"finetune"``; selects the batch size. Any other value falls
            back to ``cfg.training.batchsize`` (with a warning).

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``data`` is None, if ``cfg.model.Layer_init`` is
            not True, or if the pretrain stage requests frozen weights.
        ValueError: If ``cfg.model.Attention_qkv`` is not one of
            ``"conv_linear"``, ``"conv"``, ``"linear"``, or if ``nb_inputs``
            is not supported together with ``cfg.model.reshape_num > 1``.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Weights are trainable unless the pretrain stage asks to freeze them.
    # Binding the default up front keeps stateFlag="default" from reaching
    # the assertion below with an unbound name.
    requires_grad = True
    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag in ("fine-tune", "finetune"):
        # A membership test is required here: `x == "fine-tune" or "finetune"`
        # is always truthy and would swallow every remaining value.
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        logger.warning(
            "Unknown stateFlag %r; falling back to cfg.training.batchsize", stateFlag
        )
        batchsize_ = cfg.training.batchsize
    assert requires_grad==True, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)

    # Regularizer list shared by all hidden groups.
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers and the rescaling below.
    assert data is not None, "数据不能为空"
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        # Only reachable when assertions are disabled (python -O).
        mean1, mean2 = None, None

    # Initializers for the hidden layers and for the readout.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Attention-related settings. This variant always builds batch-norm
    # connections; their `bn` attribute is cleared at the end when
    # cfg.model.BatchNorm is off.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    model.seq_forward = cfg.model.seq_forward
    model.propagate_seq = []

    # INPUT LAYER
    if cfg.model.reshape_num > 1:
        if nb_inputs == 192:
            inputshape = (int(cfg.model.reshape_num), int(nb_inputs / cfg.model.reshape_num))
        elif nb_inputs == 96:
            inputshape = (int(cfg.model.reshape_num / 2), int(2 * nb_inputs / cfg.model.reshape_num))
        else:
            # Previously this fell through and failed later with an unbound
            # `inputshape`; fail here with an explicit message instead.
            raise ValueError(
                "nb_inputs must be 192 or 96 when cfg.model.reshape_num > 1, "
                "but get {}".format(nb_inputs)
            )
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
        )
    )
    current_src_grp = input_group
    model.propagate_seq.append(input_group)

    # Parameters shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }
    # Parameters shared by the connections built in this function.
    connection_kwargs = {
        "requires_grad": requires_grad,
        "bias": cfg.model.Linear_bias,
        "adaptive_bn": cfg.model.adaptive_bn,
    }
    # Linear operator used by the shortcut, input and readout connections.
    linear_op = RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear

    # block1: linear hidden layers of CustomLIFGroup neurons.
    for i in range(cfg.model.nb_linear_hidden):
        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.reshape_num, cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={**connection_kwargs, 'row': True},
        )
        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # For 192 inputs, rescale weight columns 96: of the first layer's first
        # connection by mean1 / mean2 (or zero them when mean2 is 0).
        # NOTE(review): presumably columns 96: carry a second input stream
        # whose mean rate is mean2 -- confirm against compute_input_firing_rates.
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

        model.propagate_seq.extend(hidden_layer.connections)
        model.propagate_seq.append(hidden_layer.output_group)

    # block2: channel-attention layers.
    for i in range(cfg.model.nb_Attention_hidden):

        # Q/K/V layer
        qkv_mode = cfg.model.Attention_qkv
        if qkv_mode == "conv_linear" or qkv_mode == "conv":
            # Both convolutional modes take the same arguments and differ only
            # in the helper that builds the q/k/v groups.
            if qkv_mode == "conv_linear":
                print("Attention_qkv is conv_linear")
                build_qkv = get_q_k_v_Layer_V5
            else:
                print("Attention_qkv is conv")
                build_qkv = get_q_k_v_Layer_V2
            linear_group_q, linear_group_k, linear_group_v = build_qkv(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs=connection_kwargs,
            )
        elif qkv_mode == "linear":
            if cfg.model.commom_linear:
                # NOTE(review): common_linear is not forwarded to
                # get_q_k_v_Layer in this variant. It is still built because
                # constructing and initialising it draws from the global RNG,
                # so dropping it would change seeded initialisation.
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He normal init

            print("Attention_qkv is linear")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer(
                model=model,
                size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs=connection_kwargs,
            )
        else:
            raise ValueError(
                "Attention_qkv must be conv_linear, conv or linear, but get {}".format(
                    cfg.model.Attention_qkv
                )
            )

        # Attention-score layer
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # The shortcut layer provides the destination group of the attention
        # connection and is itself driven by the current source group.
        linear_layer_shortcut = LinearLayer_of_shortcut(
            name="block2_channel_attention_layer",
            model=model,
            size=linear_group_q.shape,
            shortcut_group=current_src_grp,
            recurrent=cfg.model.Attention_recurrent[i],
            batchNorm=batchNorm,
            regs=regs,
            neuron_class=CustomLIFGroup,
            operation=linear_op,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs=connection_kwargs,
        )
        # initialize
        hidden_init.initialize(linear_layer_shortcut)

        # Multi-head attention connection from q/k/v into the shortcut layer.
        atten_connection = model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=linear_layer_shortcut.output_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )

        model.propagate_seq.extend(linear_layer_shortcut.connections)
        model.propagate_seq.append(atten_connection)
        model.propagate_seq.append(linear_layer_shortcut.output_group)

        # When the block is not fed by the input group directly, add an extra
        # connection from the input group to the block's output group.
        if current_src_grp != input_group:
            input_connection_MLP = model.add_connection(
                connection_class(
                    src=input_group,
                    dst=linear_layer_shortcut.output_group,
                    flatten_input=True,
                    name="block2_input_connection_MLP",
                    operation=linear_op,
                    **connection_kwargs
                )
            )
            hidden_init.initialize(input_connection_MLP)  # initialize
            model.propagate_seq.append(input_connection_MLP)

        current_src_grp = linear_layer_shortcut.output_group

    # block3: readout layer.
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                    operation=linear_op,
                    **connection_kwargs
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        # The readouts just added are the trailing entries of model.groups.
        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp,
                             readout_group,
                             dtype=dtype,
                             flatten_input=True,
                             operation=linear_op,
                             **connection_kwargs)
        )

        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # Batch norm is switched off after construction by clearing `bn` on every
    # connection registered with the model.
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn = None

    return model

# MLP与readout合并
def get_model_attention_V9(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the V9 channel-attention spiking model (MLP merged into the readout).

    The network is assembled in three stages:

    1. ``cfg.model.nb_linear_hidden`` linear ``Layer`` blocks of
       ``CustomLIFGroup`` neurons on top of a ``(1, nb_inputs)`` input group.
    2. ``cfg.model.nb_Attention_hidden`` channel-attention blocks, each made
       of q/k/v groups, a ``CustomLIFGroup`` attention group, a
       ``ChannelAttentionConnection_multiHead`` into it and an optional
       recurrent connection on the attention group.
    3. A readout driven by the last attention group: either the groups
       returned by ``get_custom_readouts`` plus an ``AverageReadouts`` group,
       or a single ``CustomReadoutGroup``.

    NOTE(review): ``current_src_grp`` is not advanced inside the attention
    loop, so every attention block reads from the output of block1 and only
    the last attention group feeds the readout -- confirm this is intended
    when ``nb_Attention_hidden > 1``.

    Args:
        cfg: Configuration object; its ``data``, ``training`` and ``model``
            sections are read by attribute access.
        nb_inputs: Number of input units.
        dtype: dtype forwarded to the model and to the readout / attention
            connections.
        data: Dataset handed to ``compute_input_firing_rates``; must not be
            None.
        stateFlag: ``"default"``, ``"pretrain"``, ``"fine-tune"`` or
            ``"finetune"``; selects the batch size. Any other value falls
            back to ``cfg.training.batchsize`` (with a warning).

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``data`` is None or ``cfg.model.Layer_init`` is
            not True.
        ValueError: If ``cfg.model.Attention_qkv`` is neither ``"conv"`` nor
            ``"linear"``, or if no attention block was built.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag in ("fine-tune", "finetune"):
        # A membership test is required here: `x == "fine-tune" or "finetune"`
        # is always truthy and would swallow every remaining value.
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        logger.warning(
            "Unknown stateFlag %r; falling back to cfg.training.batchsize", stateFlag
        )
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration.
    act_fn = get_actfn(cfg)

    # Regularizer list shared by all hidden groups.
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers and the rescaling below.
    assert data is not None, "数据不能为空"
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        # Only reachable when assertions are disabled (python -O); bind both
        # names because mean2 is forwarded to get_q_k_v_Layer_V1 below.
        mean1, mean2 = None, None

    # Initializers for the hidden layers and for the readout.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Attention-related settings: pick the connection type from the config.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    if cfg.model.BatchNorm:
        connection_class = Connection_withBatchNorm
    else:
        connection_class = Connection

    # INPUT LAYER
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            dropout_p=cfg.model.dropout_p,
        )
    )
    current_src_grp = input_group

    # Parameters shared by every hidden neuron group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # block1: linear hidden layers of CustomLIFGroup neurons.
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i], cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # For 192 inputs, rescale weight columns 96: of the first layer's first
        # connection by mean1 / mean2 (or zero them when mean2 is 0).
        # NOTE(review): presumably columns 96: carry a second input stream
        # whose mean rate is mean2 -- confirm against compute_input_firing_rates.
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

    # block2: channel-attention layers.
    # Stays None when no attention block is built, so the readout stage can
    # fail with a clear message instead of an unbound-name error.
    ChannelAttention_group = None
    for i in range(cfg.model.nb_Attention_hidden):

        # Q/K/V layer
        if cfg.model.Attention_qkv == "conv":
            print("Attention_qkv is conv")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
        elif cfg.model.Attention_qkv == "linear":
            if cfg.model.commom_linear:
                # Linear operator handed to the q/k/v helper via connection_kwargs.
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He normal init

            print("Attention_qkv is linear")
            linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V1(
                model=model,
                size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                mean1=mean1, mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={
                    "row": True,
                    "common_linear": common_linear if cfg.model.commom_linear else None
                },
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))

        # Attention-score layer
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        # Optional recurrent connection on the attention group itself.
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)

    if ChannelAttention_group is None:
        raise ValueError(
            "get_model_attention_V9 needs at least one attention layer "
            "(cfg.model.nb_Attention_hidden >= 1), but get {}".format(
                cfg.model.nb_Attention_hidden
            )
        )

    # block3: readout layer, driven by the last attention group.
    # NOTE(review): in both branches `con_shortcut` is constructed and
    # initialised but never passed to model.add_connection -- confirm whether
    # the shortcut is meant to be part of the model.
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    ChannelAttention_group, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

            con_shortcut = connection_class(
                src=current_src_grp,
                dst=g,
                name="readout2_linera_connection_shortcut",
                flatten_input=True,
            )
            readout_init.initialize(con_shortcut)

        # The readouts just added are the trailing entries of model.groups.
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(ChannelAttention_group, readout_group, dtype=dtype, flatten_input=True)
        )
        readout_init.initialize(con_ro)

        con_shortcut = connection_class(
            src=current_src_grp,
            dst=readout_group,
            name="readout2_linera_connection_shortcut",
            flatten_input=True,
        )
        readout_init.initialize(con_shortcut)

    return model

# 在tiny和big的基础上加attention
def add_model_attention(cfg, nb_inputs, dtype, model, data=None, stateFlag="default"):
    """Insert a channel-attention stage between the first two groups of ``model``.

    Q/K/V groups are built on top of ``model.groups[0]``, combined by a
    multi-head channel-attention connection into a new spiking group, and that
    group is linearly connected to ``model.groups[1]``. Every group and
    connection handled here is configured with the model's batch size, number
    of time steps, time step, device and dtype and moved to ``model.device``.
    ``model`` is modified in place.

    Args:
        cfg: Experiment configuration. ``cfg.model.Attention_qkv`` selects the
            Q/K/V flavour (``"conv"`` or ``"linear"``); only index 0 of the
            per-layer attention settings is used.
        nb_inputs: Number of input units, forwarded to
            ``compute_input_firing_rates``.
        dtype: dtype handed to the initializers and the attention connection.
        model: Model to extend; it must already expose ``batch_size``,
            ``nb_time_steps``, ``time_step``, ``device`` and ``dtype``.
        data: Dataset used to estimate the mean input firing rates. Required.
        stateFlag: Training stage. ``"pretrain"`` creates the new connections
            with ``requires_grad=False``; every other value (including the
            default) creates them with ``requires_grad=True``.

    Returns:
        None.

    Raises:
        AssertionError: If ``data`` is ``None`` or ``cfg.model.Layer_init`` is
            not ``True``.
        ValueError: If ``cfg.model.Attention_qkv`` is neither ``"conv"`` nor
            ``"linear"``.
    """
    # The former `stateFlag == "fine-tune" or "finetune"` test was always true
    # and `requires_grad` was never bound for the default stage, so the default
    # call crashed. Only "pretrain" ever produced False; keep exactly that.
    requires_grad = stateFlag != "pretrain"

    # Activation function and regularizers shared by all new spiking groups.
    act_fn = get_actfn(cfg)
    regs = get_regularizers(cfg)

    # Mean input firing rates drive the initializers.
    assert data is not None, "数据不能为空"
    mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)

    # No readout is created in this function, so only the hidden initializer is used.
    hidden_init, _readout_init = get_initializers(cfg, mean1, dtype)

    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm if cfg.model.BatchNorm else Connection

    # Parameters shared by every hidden spiking group created below.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    def _attach(module):
        """Configure ``module`` with the model's runtime settings and move it to the model device."""
        module.configure(
            model.batch_size,
            model.nb_time_steps,
            model.time_step,
            model.device,
            model.dtype,
        )
        module.to(model.device)

    current_src_grp = model.groups[0]
    dst_grp = model.groups[1]
    i = 0  # only the first entry of each per-layer attention setting is used

    # Q/K/V layer
    qkv_kind = cfg.model.Attention_qkv
    if qkv_kind == "conv":
        print("Attention_qkv is conv")
        linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V2(
            model=model,
            kernel_size=cfg.model.Attention_kernel_size[i],
            nb_head=cfg.model.nb_Attention_embed[i],
            stride=cfg.model.Attention_conv_stride[i],
            input_group=current_src_grp,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_init=hidden_init,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={
                "requires_grad": requires_grad,
            },
        )
    elif qkv_kind == "linear":
        # Optional projection shared by Q, K and V, He-normal initialised.
        common_linear = None
        if cfg.model.commom_linear:
            common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
            nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')

        print("Attention_qkv is linear")
        linear_group_q, linear_group_k, linear_group_v = get_q_k_v_Layer_V1(
            model=model,
            size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_init=hidden_init,
            mean1=mean1, mean2=mean2,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True,
                "requires_grad": requires_grad,
                "common_linear": common_linear,
            },
        )
    else:
        raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))

    for group in (linear_group_q, linear_group_k, linear_group_v):
        _attach(group)
    # NOTE(review): assumes the Q/K/V helper appended exactly three connections
    # to the model; if it adds more (e.g. recurrent ones) the earlier ones are
    # not configured here -- confirm against get_q_k_v_Layer_V1/V2.
    for connection in model.connections[-3:]:
        _attach(connection)

    # Attention-score layer
    assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
        "The shape of input group q, k and v must be the same."
    ChannelAttention_group = CustomLIFGroup(
        linear_group_q.shape,
        name="channel_attention_group",
        regularizers=regs,
        **hidden_neuron_kwargs,
    )
    model.add_group(ChannelAttention_group)

    attention_connection = model.add_connection(
        ChannelAttentionConnection_multiHead(
            src_q=linear_group_q,
            src_k=linear_group_k,
            src_v=linear_group_v,
            dst=ChannelAttention_group,
            num_heads=cfg.model.nb_Attention_head[i],
            dtype=dtype,
        )
    )
    _attach(attention_connection)

    # Optional recurrent connection of the attention group onto itself.
    if cfg.model.Attention_recurrent[i]:
        linear_connection_Attention_recurrent = model.add_connection(
            connection_class(
                src=ChannelAttention_group,
                dst=ChannelAttention_group,
                row=True,
                flatten_input=True,
                name="block2_linera_connection_MLP",
                requires_grad=requires_grad,
            )
        )
        hidden_init.initialize(linear_connection_Attention_recurrent)
        _attach(linear_connection_Attention_recurrent)

    _attach(ChannelAttention_group)

    # MLP layer: feed the attention output into the pre-existing second group.
    linear_connection_MLP = model.add_connection(
        connection_class(
            src=ChannelAttention_group,
            dst=dst_grp,
            flatten_input=True,
            name="block2_linera_connection_MLP",
            requires_grad=requires_grad,
        )
    )
    hidden_init.initialize(linear_connection_MLP)
    _attach(linear_connection_MLP)

    model.to(model.device)

# Cross-attention combined with self-attention
def get_model_attention_V10(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the cross-attention + self-attention spiking model with output feedback.

    The input group receives ``max_epochs`` equal to the epoch count of the
    current stage (``cfg.training.nb_epochs_pretrain`` for ``"pretrain"``,
    otherwise ``cfg.training.nb_epochs_train``).

    Args:
        cfg: Experiment configuration. ``cfg.model.output_feedback`` is set to
            ``True`` as a side effect.
        nb_inputs: Number of spike input units; two more input units are added
            internally.
        dtype: dtype used for the model and its connections.
        data: Dataset passed to ``compute_input_firing_rates``.
        stateFlag: Training stage (``"default"``, ``"pretrain"``,
            ``"fine-tune"``/``"finetune"``); selects batch size and epoch count.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``cfg.model.Layer_init`` is not ``True`` or the
            Q/K/V group shapes differ.
        ValueError: If ``cfg.model.reshape_num > 1`` and ``nb_inputs`` is
            neither 192 nor 96.
    """
    return _build_cross_self_attention_model(
        cfg, nb_inputs, dtype, data, stateFlag, scheduled_teacher_forcing=False
    )


def get_model_attention_V11(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the same model as ``get_model_attention_V10`` with teacher-forcing settings.

    In addition to V10, the input group receives
    ``cfg.feedback_teacher_forcing_ratio_start`` and
    ``cfg.feedback_teacher_forcing_ratio_end``, and its ``max_epochs`` is the
    stage epoch count multiplied by ``cfg.feedback_max_epoch_ratio``.

    Args:
        cfg: Experiment configuration. ``cfg.model.output_feedback`` is set to
            ``True`` as a side effect.
        nb_inputs: Number of spike input units; two more input units are added
            internally.
        dtype: dtype used for the model and its connections.
        data: Dataset passed to ``compute_input_firing_rates``.
        stateFlag: Training stage (``"default"``, ``"pretrain"``,
            ``"fine-tune"``/``"finetune"``); selects batch size and epoch count.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``cfg.model.Layer_init`` is not ``True`` or the
            Q/K/V group shapes differ.
        ValueError: If ``cfg.model.reshape_num > 1`` and ``nb_inputs`` is
            neither 192 nor 96.
    """
    return _build_cross_self_attention_model(
        cfg, nb_inputs, dtype, data, stateFlag, scheduled_teacher_forcing=True
    )


def _cross_self_attention_input_shape(nb_spike_inputs, reshape_num):
    """Return the 2-D shape of the spike-input group for the V10/V11 builders.

    Raises:
        ValueError: If ``reshape_num > 1`` and ``nb_spike_inputs`` is not one
            of the two supported sizes (192 or 96). The original code left the
            shape unbound in that case and failed later with a NameError.
    """
    if reshape_num > 1:
        if nb_spike_inputs == 192:
            return (int(reshape_num), int(nb_spike_inputs / reshape_num))
        if nb_spike_inputs == 96:
            return (int(reshape_num / 2), int(2 * nb_spike_inputs / reshape_num))
        raise ValueError(
            f"Unsupported number of spike inputs {nb_spike_inputs} for "
            f"reshape_num={reshape_num}; expected 192 or 96 when reshape_num > 1"
        )
    return (1, nb_spike_inputs)


def _build_cross_self_attention_model(
    cfg, nb_inputs, dtype, data, stateFlag, *, scheduled_teacher_forcing
):
    """Shared implementation of ``get_model_attention_V10`` and ``_V11``.

    Layout: input group -> (spike-input group, feedback groups) ->
    ``cfg.model.nb_Attention_hidden`` blocks of [Q/K/V, self-attention group,
    cross-attention group, MLP group with shortcuts] -> readout. The last
    group is finally registered as a source of the input group.

    Args:
        cfg, nb_inputs, dtype, data, stateFlag: See ``get_model_attention_V10``.
        scheduled_teacher_forcing: When true, pass the teacher-forcing ratios
            from ``cfg`` to the input group and scale ``max_epochs`` by
            ``cfg.feedback_max_epoch_ratio`` (V11 behaviour).

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.
    """
    output_feedback = True
    cfg.model.output_feedback = output_feedback
    # Two extra input units are appended; presumably they carry the fed-back
    # output (the spike-input group below reads only the first nb_inputs - 2).
    nb_inputs += 2

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Both spellings of the fine-tuning stage are accepted, in line with the
    # other builders in this file; unknown stages use the default batch size.
    if stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag in ("fine-tune", "finetune"):
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function and regularizers shared by all hidden groups.
    act_fn = get_actfn(cfg)
    regs = get_regularizers(cfg)

    # Mean input firing rates drive the initializers.
    mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
    hidden_inits, readout_init = get_initializers(cfg, [mean1, mean2, mean3], dtype)

    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm if cfg.model.BatchNorm else Connection

    # Parameters shared by every hidden spiking group.
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # INPUT LAYER
    if scheduled_teacher_forcing:
        input_group_kwargs = {
            "teacher_forcing_ratio_start": cfg.feedback_teacher_forcing_ratio_start,
            "teacher_forcing_ratio_end": cfg.feedback_teacher_forcing_ratio_end,
            "max_epochs": cfg.feedback_max_epoch_ratio * (
                cfg.training.nb_epochs_pretrain if stateFlag == "pretrain"
                else cfg.training.nb_epochs_train),
        }
    else:
        input_group_kwargs = {
            "max_epochs": (
                cfg.training.nb_epochs_pretrain if stateFlag == "pretrain"
                else cfg.training.nb_epochs_train),
        }
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            output_feedback=output_feedback,
            **input_group_kwargs,
        )
    )

    # Spike input: identity view onto the first nb_inputs - 2 input units.
    inputshape = _cross_self_attention_input_shape(nb_inputs - 2, cfg.model.reshape_num)
    input_input_group = custom_fake_InputGroup(
        inputshape,
        dropout_p=cfg.model.dropout_p,
        name="input_input",
    )
    model.add_group(input_input_group)
    model.add_connection(
        Connection_Identity(
            src=input_group,
            dst=input_input_group,
            dtype=dtype,
            feedback=False,
            name="input_input_con",
            input_range=(0, nb_inputs - 2),
        )
    )

    # Feedback input and the spiking group it drives.
    input_feedback_group = custom_feedback_InputGroup(
        (cfg.model.output_feedback_timestep, 2),
        name="input_feedback",
        output_feedback_timestep=cfg.model.output_feedback_timestep,
    )
    model.add_group(input_feedback_group)
    model.add_connection(
        Connection_Identity(
            src=input_group,
            dst=input_feedback_group,
            dtype=dtype,
            feedback=True,
            name="input_feedback_con",
        )
    )

    feedback_group = CustomLIFGroup(
        input_feedback_group.shape,
        name="feedback_group",
        regularizers=regs,
        **hidden_neuron_kwargs,
    )
    model.add_group(feedback_group)
    model.add_connection(
        Connection_Identity(
            src=input_feedback_group,
            dst=feedback_group,
            dtype=dtype,
            feedback=True,
            name="feedback_con",
        )
    )

    def _add_attention_branch(q_group, k_group, v_group, layer_idx):
        """Add one attention group fed by (q, k, v) plus its optional recurrent connection."""
        branch = CustomLIFGroup(
            q_group.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(branch)
        model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=q_group,
                src_k=k_group,
                src_v=v_group,
                dst=branch,
                num_heads=cfg.model.nb_Attention_head[layer_idx],
                dtype=dtype,
            )
        )
        if cfg.model.Attention_recurrent[layer_idx]:
            recurrent_con = model.add_connection(
                connection_class(
                    src=branch,
                    dst=branch,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_inits[0].initialize(recurrent_con)
        return branch

    # block2: channel-attention blocks
    current_src_grp = input_input_group
    for i in range(cfg.model.nb_Attention_hidden):
        kernel_size = cfg.model.Attention_kernel_size[i]
        if not cfg.model.Repconv:
            conv_class = nn.Conv1d
        elif kernel_size == 3:
            conv_class = RepVGGplusBlock1d
        else:
            conv_class = RepVGGplusBlock1dV2

        # Q/K/V layer: K and V plus one query each for self- and cross-attention.
        linear_group_k, linear_group_v, linear_layer_q_self, linear_layer_q_cross = get_q_k_v_Layer_V4(
            model=model,
            kernel_size=kernel_size,
            nb_head=cfg.model.nb_Attention_embed[i],
            stride=cfg.model.Attention_conv_stride[i],
            input_group=current_src_grp,
            input_feedback_group=feedback_group,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_inits=hidden_inits,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            conv=conv_class,
            connection_class=ConvConnection_withBatchNorm,
            connection_kwargs={},
        )

        # Attention-score layer
        assert linear_layer_q_self.shape==linear_layer_q_cross.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        ChannelAttention_group_self = _add_attention_branch(
            linear_layer_q_self, linear_group_k, linear_group_v, i
        )
        ChannelAttention_group_cross = _add_attention_branch(
            linear_layer_q_cross, linear_group_k, linear_group_v, i
        )

        # MLP layer
        linear_group_MLP = CustomLIFGroup(
            cfg.model.MLP_size[i],
            name="block2_linera_group_MLP",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(linear_group_MLP)

        linear_connection_MLP_self = model.add_connection(
            connection_class(
                src=ChannelAttention_group_self,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP",
            )
        )
        hidden_inits[2].initialize(linear_connection_MLP_self)

        linear_connection_MLP_cross = model.add_connection(
            connection_class(
                src=ChannelAttention_group_cross,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP",
            )
        )
        # He-normal initialisation for the cross-attention path.
        nn.init.kaiming_normal_(linear_connection_MLP_cross.op.weight, mode='fan_in', nonlinearity='relu')

        # Shortcut from the raw spike input.
        linear_connection_MLP_shortcut_spike = model.add_connection(
            connection_class(
                src=input_input_group,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP_shortcut",
            )
        )
        hidden_inits[2].initialize(linear_connection_MLP_shortcut_spike)

        # Shortcut from the feedback group, He-normal initialised.
        linear_connection_MLP_shortcut_feedback = model.add_connection(
            connection_class(
                src=feedback_group,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP_shortcut",
            )
        )
        nn.init.kaiming_normal_(linear_connection_MLP_shortcut_feedback.op.weight, mode='fan_in', nonlinearity='relu')

        # Optional recurrent connection on the MLP group.
        if cfg.model.Attention_hidden_recurrent_MLP[i]:
            linear_connection_MLP_recurrent = model.add_connection(
                Connection_withBatchNorm(
                    linear_group_MLP,
                    linear_group_MLP,
                    name="block2_linera_connection_recurrent",
                )
            )
            hidden_inits[2].initialize(linear_connection_MLP_recurrent)

        current_src_grp = linear_group_MLP

    # block3: readout layer
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                )
            )
            readout_init.initialize(con_ro)

        model.add_group(AverageReadouts(model.groups[-len(custom_readouts):]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True)
        )
        readout_init.initialize(con_ro)

    # Output feedback: register the last group as a source of the input group.
    model.groups[0].add_src(model.groups[-1])

    return model

# No attention (ablation): the attention block is replaced by linear layers
def get_model_attention_V12(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the attention-free ablation model.

    The attention stage used by the other ``get_model_attention_*`` builders
    is replaced by plain linear connections around a ``CustomLIFGroup``.
    The network is assembled in three stages:

    1. ``cfg.model.nb_linear_hidden`` linear hidden layers (block1),
    2. ``cfg.model.nb_Attention_hidden`` stand-in "channel attention" groups,
       each followed by an MLP layer with a shortcut from the block input
       (block2),
    3. a single readout group or several averaged readout groups (block3).

    Args:
        cfg: Experiment configuration; the ``data``, ``training`` and
            ``model`` sections (plus ``device`` / ``multi_cuda`` /
            ``gpu_ids``) are read.
        nb_inputs: Number of input channels.
        dtype: dtype handed to the model and to the readout connections.
        data: Dataset forwarded to ``compute_input_firing_rates``. Must not
            be ``None``.
        stateFlag: ``"default"``, ``"pretrain"`` or ``"fine-tune"`` /
            ``"finetune"``. Selects the batch size; any other value falls
            back to ``cfg.training.batchsize``.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``data`` is ``None``, if ``cfg.model.Layer_init``
            is not true, or if pre-training is configured with frozen
            weights (``cfg.model.pretrain_forze``).
        ValueError: If ``cfg.model.reshape_num > 1`` and ``nb_inputs`` is
            neither 96 nor 192.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Bind requires_grad up front so every stateFlag (including the default
    # one) reaches the assertion below with a defined value.
    requires_grad = True
    if stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag in ("fine-tune", "finetune"):
        # Membership test instead of `== "fine-tune" or "finetune"`, which
        # was always true and swallowed every remaining stateFlag.
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize
    # This variant only supports trainable weights.
    assert requires_grad, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration
    act_fn = get_actfn(cfg)

    # Regularizer list
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers
    assert data is not None, "数据不能为空"
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        # Only reachable when assertions are disabled (python -O).
        mean1 = mean2 = None

    # Initializers for the hidden and readout connections
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Connection setup: batch-norm connections are always constructed; the
    # bn modules are stripped at the end when cfg.model.BatchNorm is false.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True

    # INPUT LAYER
    if cfg.model.reshape_num > 1:
        if nb_inputs == 192:
            inputshape = (
                int(cfg.model.reshape_num),
                int(nb_inputs / cfg.model.reshape_num),
            )
        elif nb_inputs == 96:
            inputshape = (
                int(cfg.model.reshape_num / 2),
                int(2 * nb_inputs / cfg.model.reshape_num),
            )
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `inputshape`.
            raise ValueError(
                "reshape_num > 1 is only supported for nb_inputs of 96 or 192, "
                f"got {nb_inputs}"
            )
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
        )
    )
    current_src_grp = input_group

    # Parameters shared by all hidden neuron groups
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # block1: linear hidden layers built from CustomLIFGroup neurons
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i], cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # For 192-channel input on the first layer, the weight columns from
        # index 96 onwards are rescaled by mean1 / mean2, or zeroed when
        # mean2 is 0 (the slice end 193 is clipped to the actual width).
        # Presumably mean1 / mean2 are the firing rates of the two input
        # halves — TODO confirm against compute_input_firing_rates.
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

    # block2: stand-in for the channel-attention layers
    for i in range(cfg.model.nb_Attention_hidden):

        # Group that takes the place of the attention-score group
        ChannelAttention_group = CustomLIFGroup(
            current_src_grp.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # NOTE(review): src and dst are both ChannelAttention_group, so no
        # connection visible here feeds current_src_grp into this group —
        # confirm this is intended for the ablation.
        linear_connection_noAttention = model.add_connection(
            connection_class(
                src=ChannelAttention_group,
                dst=ChannelAttention_group,
                row=True,
                flatten_input=True,
                name="block2_linera_connection_MLP",
                requires_grad=requires_grad,
            )
        )
        hidden_init.initialize(linear_connection_noAttention)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    requires_grad=requires_grad,
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
        if not requires_grad:
            for param in ChannelAttention_group.parameters():
                param.requires_grad = False

        # MLP layer
        if cfg.model.shortcut_Linear:
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={
                    "bias": cfg.model.Linear_bias,
                },
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)

            # Connection from the attention stand-in group into the MLP group
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    requires_grad=requires_grad,
                    bias=cfg.model.Linear_bias,
                )
            )
            hidden_init.initialize(linear_connection_MLP)  # initialize

            current_src_grp = linear_layer_MLP.output_group
        else:
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group

    # block3: readout layer
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True, bias=cfg.model.Linear_bias,
                )
            )
            readout_init.initialize(con_ro)

        # Average over the readout groups that were just appended
        model.add_group(AverageReadouts(model.groups[-len(custom_readouts):]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True, bias=cfg.model.Linear_bias,)
        )

        readout_init.initialize(con_ro)

    # Batch norm disabled in the config: drop the bn attribute of every
    # connection that was registered on the model.
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn = None

    return model

# Cross-attention combined with self-attention
def get_model_attention_V13(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the model that combines self-attention and cross-attention.

    Output feedback is always enabled: ``cfg.model.output_feedback`` is set
    to ``True`` (the passed ``cfg`` is mutated) and two extra input channels
    are added to ``nb_inputs`` for the fed-back outputs. The input group is
    split into a spike-input group and a feedback group; each attention block
    computes a self-attention and a cross-attention group from the q/k/v
    groups returned by ``get_q_k_v_Layer_V4`` and merges both, plus a
    shortcut from the spike input, in an MLP group.

    Every group and connection is also appended, in construction order, to
    ``model.propagate_seq``.

    Args:
        cfg: Experiment configuration; the ``data``, ``training`` and
            ``model`` sections (plus ``device`` / ``multi_cuda`` /
            ``gpu_ids``) are read and ``cfg.model.output_feedback`` is
            written.
        nb_inputs: Number of spike input channels, without the two feedback
            channels.
        dtype: dtype handed to the model and to the connections.
        data: Dataset forwarded to ``compute_input_firing_rates``.
        stateFlag: ``"default"``, ``"pretrain"`` or ``"fine-tune"``. Selects
            the batch size and the epoch count used for the teacher-forcing
            schedule.

    Returns:
        The assembled ``CustomRecurrentSpikingModel``.

    Raises:
        AssertionError: If ``cfg.model.Layer_init`` is not true, or if the
            q/k/v groups of an attention block differ in shape.
        ValueError: If ``cfg.model.reshape_num > 1`` and the number of spike
            input channels is neither 96 nor 192.
    """
    output_feedback = True
    cfg.model.output_feedback = output_feedback
    # Keep the spike-channel count before the two feedback channels are added.
    nb_spike_inputs = nb_inputs
    nb_inputs += 2

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # NOTE(review): only the exact spelling "fine-tune" selects the
    # fine-tuning batch size; "finetune" falls through to the default —
    # confirm this is intended.
    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
    elif stateFlag == "fine-tune":
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize
    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function selected from the configuration
    act_fn = get_actfn(cfg)

    # Regularizer list
    regs = get_regularizers(cfg)

    # Mean input firing rates, used by the initializers
    mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
    mean = [mean1, mean2, mean3]
    # get_initializers is given the list of rates here; hidden_inits is
    # indexed below ([0] and [2]).
    hidden_inits, readout_init = get_initializers(cfg, mean, dtype)

    # Connection setup: batch-norm connections are always constructed; the
    # bn modules are stripped at the end when cfg.model.BatchNorm is false.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    model.seq_forward = cfg.model.seq_forward
    model.propagate_seq = []

    # Parameters shared by all hidden neuron groups
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # INPUT LAYER
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            output_feedback=output_feedback,
            teacher_forcing_ratio_start=cfg.model.feedback_teacher_forcing_ratio_start,
            teacher_forcing_ratio_end=cfg.model.feedback_teacher_forcing_ratio_end,
            max_epochs=cfg.model.feedback_max_epoch_ratio * (
                cfg.training.nb_epochs_pretrain if stateFlag == "pretrain"
                else cfg.training.nb_epochs_train),
        )
    )
    model.propagate_seq.append(input_group)

    # The split below is unconditional: the attention blocks always need
    # both input_input_group and feedback_group.

    # spike input
    if cfg.model.reshape_num > 1:
        if nb_spike_inputs == 192:
            inputshape = (
                int(cfg.model.reshape_num),
                int(nb_spike_inputs / cfg.model.reshape_num),
            )
        elif nb_spike_inputs == 96:
            inputshape = (
                int(cfg.model.reshape_num / 2),
                int(2 * nb_spike_inputs / cfg.model.reshape_num),
            )
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `inputshape`.
            raise ValueError(
                "reshape_num > 1 is only supported for 96 or 192 spike input "
                f"channels, got {nb_spike_inputs}"
            )
    else:
        inputshape = (1, nb_spike_inputs)
    input_input_group = custom_fake_InputGroup(
        inputshape,
        dropout_p=cfg.model.dropout_p,
        name="input_input",
    )
    model.add_group(input_input_group)
    input_input_group_con = Connection_Identity(
        src=input_group,
        dst=input_input_group,
        dtype=dtype,
        feedback=False,
        name="input_input_con",
        input_range=(0, nb_inputs - 2),
    )
    input_input_group_con = model.add_connection(input_input_group_con)
    model.propagate_seq.append(input_input_group_con)
    model.propagate_seq.append(input_input_group)

    # feedback input
    input_feedback_group = custom_feedback_InputGroup(
        (cfg.model.output_feedback_timestep, 2),
        name="input_feedback",
        output_feedback_timestep=cfg.model.output_feedback_timestep,
    )
    model.add_group(input_feedback_group)
    input_feedback_group_con = Connection_Identity(
        src=input_group,
        dst=input_feedback_group,
        dtype=dtype,
        feedback=True,
        name="input_feedback_con",
    )
    model.add_connection(input_feedback_group_con)
    model.propagate_seq.append(input_feedback_group_con)
    model.propagate_seq.append(input_feedback_group)

    # feedback spike code: LIF group driven by the feedback input
    feedback_group = CustomLIFGroup(
        input_feedback_group.shape,
        name="feedback_group",
        regularizers=regs,
        **hidden_neuron_kwargs,
    )
    model.add_group(feedback_group)
    feedback_group_con = Connection_Identity(
        src=input_feedback_group,
        dst=feedback_group,
        dtype=dtype,
        feedback=True,
        name="feedback_con",
    )
    model.add_connection(feedback_group_con)
    model.propagate_seq.append(feedback_group_con)
    model.propagate_seq.append(feedback_group)

    current_src_grp = input_input_group

    # block2: channel-attention layers
    for i in range(cfg.model.nb_Attention_hidden):

        # Q/K/V layers: k and v plus one query for self-attention and one
        # for cross-attention.
        linear_group_k, linear_group_v, linear_layer_q_self, linear_layer_q_cross = get_q_k_v_Layer_V4(
            model=model,
            kernel_size=cfg.model.Attention_kernel_size[i],
            nb_head=cfg.model.nb_Attention_embed[i],
            stride=cfg.model.Attention_conv_stride[i],
            input_group=current_src_grp,
            input_feedback_group=feedback_group,
            recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
            regs=regs,
            hidden_inits=hidden_inits,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            conv=(
                RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                else RepVGGplusBlock1dV2 if cfg.model.Repconv
                else nn.Conv1d),
            connection_class=ConvConnection_withBatchNorm,
            connection_kwargs={},
        )

        # Attention-score layer
        assert linear_layer_q_self.shape==linear_layer_q_cross.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # self-attention group
        ChannelAttention_group_self = CustomLIFGroup(
            linear_layer_q_self.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group_self)
        # connection
        ChannelAttention_con_self = model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_layer_q_self,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group_self,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        model.propagate_seq.append(ChannelAttention_con_self)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group_self,
                    dst=ChannelAttention_group_self,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_inits[0].initialize(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
        model.propagate_seq.append(ChannelAttention_group_self)

        # cross-attention group
        ChannelAttention_group_cross = CustomLIFGroup(
            linear_layer_q_cross.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group_cross)
        # connection
        ChannelAttention_con_cross = model.add_connection(
            ChannelAttentionConnection_multiHead(
                src_q=linear_layer_q_cross,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group_cross,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
            )
        )
        model.propagate_seq.append(ChannelAttention_con_cross)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group_cross,
                    dst=ChannelAttention_group_cross,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                )
            )
            hidden_inits[0].initialize(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
        model.propagate_seq.append(ChannelAttention_group_cross)

        # MLP layer merging both attention paths and the input shortcut
        linear_group_MLP = CustomLIFGroup(
            cfg.model.MLP_size[i],
            name="block2_linera_group_MLP",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(linear_group_MLP)
        # from the self-attention group
        linear_connection_MLP_self = model.add_connection(
            connection_class(
                src=ChannelAttention_group_self,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP",
            )
        )
        hidden_inits[2].initialize(linear_connection_MLP_self)  # initialize
        model.propagate_seq.append(linear_connection_MLP_self)

        # from the cross-attention group; He-normal init is used here
        # instead of one of the hidden_inits initializers
        linear_connection_MLP_cross = model.add_connection(
            connection_class(
                src=ChannelAttention_group_cross,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP",
            )
        )
        nn.init.kaiming_normal_(linear_connection_MLP_cross.op.weight, mode='fan_in', nonlinearity='relu')
        model.propagate_seq.append(linear_connection_MLP_cross)

        # Shortcut from the spike input group (the same source for every
        # block, not the previous block's output)
        linear_connection_MLP_shortcut_spike = model.add_connection(
            connection_class(
                src=input_input_group,
                dst=linear_group_MLP,
                flatten_input=True,
                name="block2_linera_connection_MLP_shortcut",
            )
        )
        hidden_inits[2].initialize(linear_connection_MLP_shortcut_spike)  # initialize
        model.propagate_seq.append(linear_connection_MLP_shortcut_spike)

        # recurrent connection
        if cfg.model.Attention_hidden_recurrent_MLP[i]:
            linear_connection_MLP_recurrent = model.add_connection(
                Connection_withBatchNorm(
                    linear_group_MLP,
                    linear_group_MLP,
                    name="block2_linera_connection_recurrent",
                )
            )
            hidden_inits[2].initialize(linear_connection_MLP_recurrent)
            model.propagate_seq.append(linear_connection_MLP_recurrent)
        model.propagate_seq.append(linear_group_MLP)

        current_src_grp = linear_group_MLP

    # block3: readout layer
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True, bias=cfg.model.Linear_bias,
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        # Average over the readout groups that were just appended
        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts):]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True, bias=cfg.model.Linear_bias,)
        )

        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # Register the last group (the readout) as a source of the first group
    # (the input group) for output feedback.
    if output_feedback:
        model.groups[0].add_src(model.groups[-1])

    # Batch norm disabled in the config: drop the bn attribute of every
    # connection that was registered on the model.
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn = None

    return model

# multi syn
def get_model_attention_V14(cfg, nb_inputs, dtype, data=None, stateFlag="default"):

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag=="fine-tune" or "finetune":
        batchsize_ = cfg.training.batchsize_finetuning
        requires_grad = True
    else:
        batchsize_ = cfg.training.batchsize
    assert requires_grad==True, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function 根据配置获取激活函数
    act_fn = get_actfn(cfg)

    # Regularizer list 获取正则化器
    regs = get_regularizers(cfg)

    # Compute mean firing rates for initializer 计算输入数据的平均发放率
    assert data is not None, "数据不能为空"
    if data is not None:
        # if output_feedback:
        #     mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
        # else:
        #     mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None



    hidden_init, readout_init = get_initializers(cfg, mean1, dtype) #根据配置和输入脉冲频率获取隐藏层和输出层的初始化器




    # attention 的相关配置
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    model.seq_forward=cfg.model.seq_forward
    model.propagate_seq=[]




    # INPUT LAYER
    if cfg.model.reshape_num>1:
        if nb_inputs==192:
            inputshape=(int(cfg.model.reshape_num), int(nb_inputs/cfg.model.reshape_num))
        elif nb_inputs==96:
            inputshape=(int(cfg.model.reshape_num/2), int(2*nb_inputs/cfg.model.reshape_num))
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
            # output_feedback=output_feedback,
        )
    )
    current_src_grp = input_group
    model.propagate_seq.append(input_group)



    # 定义隐藏层神经元的参数。
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }





    # block1: 线性隐藏层
    # 循环构建隐藏层，每层使用 Custom_multiSyn_LIFGroup 作为神经元类型，并根据配置设置参数。
    # 初始化隐藏层，并根据输入脉冲频率调整权重。
    # assert cfg.model.nb_linear_hidden==0, "还没进行该部分的优化，目前不能用这个"
    for i in range(cfg.model.nb_linear_hidden):

        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=(cfg.model.linear_hidden_dim[i],cfg.model.linear_hidden_size[i]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=Custom_multiSyn_LIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={
                "row": True
            },
        )
        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # if output_feedback and data is not None:
        #     assert  hidden_layer.connections[0].op.weight.shape[1] == 98 or hidden_layer.connections[0].op.weight.shape[1] == 194, \
        #         "The shape of input group q, k and v must be 1 demention."
        #     with torch.no_grad():
        #         hidden_layer.connections[0].op.weight[:, -2:] /= mean3 / mean1
        #     if i == 0 and nb_inputs == 194 and cfg.data.data_len == 192:
        #         with torch.no_grad():
        #             hidden_layer.connections[0].op.weight[:, 96:-2] /= mean2 / mean1
        # elif i == 0 and nb_inputs == 192 and data is not None and cfg.data.data_len==192:
        #     with torch.no_grad():
        #         hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

        for c in hidden_layer.connections:
            model.propagate_seq.append(c)
        model.propagate_seq.append(hidden_layer.output_group)







    attention_scr_size = current_src_grp.shape
    # block2：通道注意力层（）
    for i in range(cfg.model.nb_Attention_hidden):


        # QKV层
        if cfg.model.Attention_qkv=="conv":
            print("Attention_qkv is conv")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V6(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=Custom_multiSyn_LIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs={
                    "requires_grad": requires_grad,
                    "bias": cfg.model.Linear_bias,
                },
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
        if not requires_grad:
            for group in [linear_group_q, linear_group_k, linear_group_v]:
                for param in group.parameters():
                    param.requires_grad = False





        # Attention-score层
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = Custom_multiSyn_LIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        atten_connection=model.add_connection(
            ChannelAttentionConnection_multiHead(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                name="block2_attention_connection",
                # shortcut="VS",
                dtype=dtype,
            )
        )
        model.propagate_seq.append(atten_connection)
        ChannelAttention_group.add_syn(atten_connection)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_attention_connection_recurrent",
                    requires_grad = requires_grad,
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
            ChannelAttention_group.add_syn(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
        if not requires_grad:
            for param in ChannelAttention_group.parameters():
                param.requires_grad = False
        model.propagate_seq.append(ChannelAttention_group)





        # MLP层
        if cfg.model.shortcut_Linear:
            # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=Custom_multiSyn_LIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={
                    "bias": cfg.model.Linear_bias,
                },
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)


            # connection用于连接attention的output group以及MLPgroup
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    requires_grad = requires_grad,
                    bias = cfg.model.Linear_bias,
                )
            )
            hidden_init.initialize(linear_connection_MLP) # initialize
            if cfg.model.MLP_id_map:
                with torch.no_grad():  # 禁用梯度计算
                    linear_connection_MLP.op.weight.data.zero_()  # 使用 in-place 操作 _zero_()


            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
                linear_layer_MLP.output_group.add_syn(c)
            model.propagate_seq.append(linear_connection_MLP)
            linear_layer_MLP.output_group.add_syn(linear_connection_MLP)
            model.propagate_seq.append(linear_layer_MLP.output_group)

            current_src_grp = linear_layer_MLP.output_group






    # block3: 输出层
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True, bias = cfg.model.Linear_bias,
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp, readout_group, dtype=dtype, flatten_input=True, bias = cfg.model.Linear_bias,)
        )

        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # if output_feedback:
    #     model.groups[0].add_src(model.groups[-1])
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn=None

    return model

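# Two-brain-area variant: the input is split into M1 plus a second area (S1 or PMd).
# NOTE(review): in the stateFlag handling below, `stateFlag=="fine-tune" or "finetune"`
# is always truthy, and `requires_grad` is never assigned when stateFlag=="default".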
def get_model_attention_V15(cfg, nb_inputs, dtype, brain_area=['M1', 'S1'] ,data=None, stateFlag="default"):

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag=="fine-tune" or "finetune":
        batchsize_ = cfg.training.batchsize_finetuning
        requires_grad = True
    else:
        batchsize_ = cfg.training.batchsize
    assert requires_grad==True, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function 根据配置获取激活函数
    act_fn = get_actfn(cfg)

    # Regularizer list 获取正则化器
    regs = get_regularizers(cfg)

    # Compute mean firing rates for initializer 计算输入数据的平均发放率
    assert data is not None, "数据不能为空"
    if data is not None:
        # if output_feedback:
        #     mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
        # else:
        #     mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None



    hidden_init, readout_init = get_initializers(cfg, mean1, dtype) #根据配置和输入脉冲频率获取隐藏层和输出层的初始化器




    # attention 的相关配置
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    model.seq_forward=cfg.model.seq_forward
    model.propagate_seq=[]




    # INPUT LAYER
    input_group = model.add_group(
        InputGroup(
            (1, nb_inputs),
            dropout_p=cfg.model.dropout_p,
        )
    )
    model.propagate_seq.append(input_group)

    shape_of_each_area = (4,24)
    M1_input_group = model.add_group(
        custom_fake_InputGroup(
            shape_of_each_area,
            name="M1_input_group",
        )
    )
    M1_input_con=model.add_connection(
        Connection_Identity(
            src=input_group,
            dst=M1_input_group,
            dtype=dtype,
            name="M1_input_con",
            input_range=(0, 96),
        )
    )
    model.brain_area=['M1']
    model.propagate_seq.append(M1_input_con)
    model.propagate_seq.append(M1_input_group)

    if nb_inputs == 192:
        if brain_area[1] == 'S1':
            S1_input_group = model.add_group(
                custom_fake_InputGroup(
                    shape_of_each_area,
                    name="S1_input_group",
                )
            )
            S1_input_con=model.add_connection(
                Connection_Identity(
                    src=input_group,
                    dst=S1_input_group,
                    dtype=dtype,
                    feedback=False,
                    name="S1_input_con",
                    input_range=(96, 192),
                )
            )
            model.propagate_seq.append(S1_input_con)
            model.propagate_seq.append(S1_input_group)
            model.brain_area.append('S1')
        if brain_area[1] == 'PMd':
            PMd_input_group = model.add_group(
                custom_fake_InputGroup(
                    shape_of_each_area,
                    name="PMd_input_group",
                )
            )
            PMd_input_con=model.add_connection(
                Connection_Identity(
                    src=input_group,
                    dst=PMd_input_group,
                    dtype=dtype,
                    feedback=False,
                    name="PMd_input_con",
                    input_range=(96, 192),
                )
            )
            model.propagate_seq.append(PMd_input_con)
            model.propagate_seq.append(PMd_input_group)
            model.brain_area.append('PMd')


    # 定义隐藏层神经元的参数。
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }
    connection_kwargs = {
        "requires_grad": requires_grad,
        "bias": cfg.model.Linear_bias,
        "adaptive_bn": cfg.model.adaptive_bn,
    }

    MLP_group = []
    # block2：通道注意力层（M1）
    if 'M1' in model.brain_area:
        for i in range(cfg.model.nb_Attention_hidden):

            # QKV层
            if cfg.model.Attention_qkv == "conv_linear":
                print("Attention_qkv is conv_linear")
                M1_group_q, M1_group_k, M1_group_v = get_q_k_v_Layer_V5(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=M1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "conv":
                print("Attention_qkv is conv")
                M1_group_q, M1_group_k, M1_group_v = get_q_k_v_Layer_V2(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=M1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "linear":
                if cfg.model.commom_linear:
                    common_linear = nn.Linear(M1_input_group.nb_units, M1_input_group.nb_units)
                    nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He正态分布

                print("Attention_qkv is linear")
                M1_group_q, M1_group_k, M1_group_v = get_q_k_v_Layer(
                    model=model,
                    size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                    input_group=M1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    # mean1=mean1,mean2=mean2,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_class=connection_class,
                    connection_kwargs=connection_kwargs,
                )
            else:
                raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
            if not requires_grad:
                for group in [M1_group_q, M1_group_k, M1_group_v]:
                    for param in group.parameters():
                        param.requires_grad = False

            # Attention-score层
            assert M1_group_q.shape == M1_group_k.shape == M1_group_v.shape, \
                "The shape of input group q, k and v must be the same."
            # group
            ChannelAttention_group = CustomLIFGroup(
                M1_group_q.shape,
                name="M1_channel_attention_group",
                regularizers=regs,
                **hidden_neuron_kwargs,
            )
            model.add_group(ChannelAttention_group)
            # connection
            atten_connection = model.add_connection(
                ChannelAttentionConnection_multiHead(
                    src_q=M1_group_q,
                    src_k=M1_group_k,
                    src_v=M1_group_v,
                    dst=ChannelAttention_group,
                    num_heads=cfg.model.nb_Attention_head[i],
                    # shortcut="VS",
                    dtype=dtype,
                )
            )
            model.propagate_seq.append(atten_connection)
            if cfg.model.Attention_recurrent[i]:
                linear_connection_Attention_recurrent = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=ChannelAttention_group,
                        row=True,
                        flatten_input=True,
                        name="block2_linera_connection_MLP_M1",
                        **connection_kwargs
                        # requires_grad = requires_grad,
                    )
                )
                hidden_init.initialize(linear_connection_Attention_recurrent)
                model.propagate_seq.append(linear_connection_Attention_recurrent)
            if not requires_grad:
                for param in ChannelAttention_group.parameters():
                    param.requires_grad = False
            model.propagate_seq.append(ChannelAttention_group)

            # MLP层
            if cfg.model.shortcut_Linear:
                # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
                linear_layer_MLP = LinearLayer_of_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=cfg.model.MLP_size[i],
                    shortcut_group=M1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    batchNorm=batchNorm,
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs=connection_kwargs
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                # if nb_inputs == 192:
                #     with torch.no_grad():
                #         if mean2 == 0:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
                #         else:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

                # connection用于连接attention的output group以及MLPgroup
                linear_connection_MLP = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_linera_connection_MLP",
                        operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(linear_connection_MLP)  # initialize
                if cfg.model.MLP_id_map:
                    with torch.no_grad():  # 禁用梯度计算
                        linear_connection_MLP.op.weight.data.zero_()  # 使用 in-place 操作 _zero_()

                for c in linear_layer_MLP.connections:
                    model.propagate_seq.append(c)
                model.propagate_seq.append(linear_connection_MLP)
                model.propagate_seq.append(linear_layer_MLP.output_group)
                M1_MLP_group = linear_layer_MLP.output_group
                MLP_group.append(M1_MLP_group)

    # block2：通道注意力层（S1）
    if 'S1' in model.brain_area:
        for i in range(cfg.model.nb_Attention_hidden):

            # QKV层
            if cfg.model.Attention_qkv == "conv_linear":
                print("Attention_qkv is conv_linear")
                S1_group_q, S1_group_k, S1_group_v = get_q_k_v_Layer_V5(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=S1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "conv":
                print("Attention_qkv is conv")
                S1_group_q, S1_group_k, S1_group_v = get_q_k_v_Layer_V2(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=S1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "linear":
                if cfg.model.commom_linear:
                    common_linear = nn.Linear(S1_input_group.nb_units, S1_input_group.nb_units)
                    nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')

                print("Attention_qkv is linear")
                S1_group_q, S1_group_k, S1_group_v = get_q_k_v_Layer(
                    model=model,
                    size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                    input_group=S1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    # mean1=mean1,mean2=mean2,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_class=connection_class,
                    connection_kwargs=connection_kwargs,
                )
            else:
                raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
            if not requires_grad:
                for group in [S1_group_q, S1_group_k, S1_group_v]:
                    for param in group.parameters():
                        param.requires_grad = False

            # Attention-score层
            assert S1_group_q.shape == S1_group_k.shape == S1_group_v.shape, \
                "The shape of input group q, k and v must be the same."
            # group
            ChannelAttention_group = CustomLIFGroup(
                S1_group_q.shape,
                name="S1_channel_attention_group",
                regularizers=regs,
                **hidden_neuron_kwargs,
            )
            model.add_group(ChannelAttention_group)
            # connection
            atten_connection = model.add_connection(
                ChannelAttentionConnection_multiHead(
                    src_q=S1_group_q,
                    src_k=S1_group_k,
                    src_v=S1_group_v,
                    dst=ChannelAttention_group,
                    num_heads=cfg.model.nb_Attention_head[i],
                    # shortcut="VS",
                    dtype=dtype,
                )
            )
            model.propagate_seq.append(atten_connection)
            if cfg.model.Attention_recurrent[i]:
                linear_connection_Attention_recurrent = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=ChannelAttention_group,
                        row=True,
                        flatten_input=True,
                        name="block2_linera_connection_MLP_S1",
                        **connection_kwargs
                        # requires_grad = requires_grad,
                    )
                )
                hidden_init.initialize(linear_connection_Attention_recurrent)
                model.propagate_seq.append(linear_connection_Attention_recurrent)
            if not requires_grad:
                for param in ChannelAttention_group.parameters():
                    param.requires_grad = False
            model.propagate_seq.append(ChannelAttention_group)

            # MLP层
            if cfg.model.shortcut_Linear:
                # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
                linear_layer_MLP = LinearLayer_of_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=cfg.model.MLP_size[i],
                    shortcut_group=S1_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    batchNorm=batchNorm,
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs=connection_kwargs
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                # if nb_inputs == 192:
                #     with torch.no_grad():
                #         if mean2 == 0:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
                #         else:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

                # connection用于连接attention的output group以及MLPgroup
                linear_connection_MLP = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_linera_connection_MLP",
                        operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(linear_connection_MLP)  # initialize
                if cfg.model.MLP_id_map:
                    with torch.no_grad():
                        linear_connection_MLP.op.weight.data.zero_()

                for c in linear_layer_MLP.connections:
                    model.propagate_seq.append(c)
                model.propagate_seq.append(linear_connection_MLP)
                model.propagate_seq.append(linear_layer_MLP.output_group)
                S1_MLP_group = linear_layer_MLP.output_group
                MLP_group.append(S1_MLP_group)

    # block2：通道注意力层（PMd）
    if 'PMd' in model.brain_area:
        for i in range(cfg.model.nb_Attention_hidden):

            # QKV层
            if cfg.model.Attention_qkv == "conv_linear":
                print("Attention_qkv is conv_linear")
                PMd_group_q, PMd_group_k, PMd_group_v = get_q_k_v_Layer_V5(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=PMd_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "conv":
                print("Attention_qkv is conv")
                PMd_group_q, PMd_group_k, PMd_group_v = get_q_k_v_Layer_V2(
                    model=model,
                    kernel_size=cfg.model.Attention_kernel_size[i],
                    nb_head=cfg.model.nb_Attention_embed[i],
                    stride=cfg.model.Attention_conv_stride[i],
                    input_group=PMd_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    conv=(
                        RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                        else RepVGGplusBlock1dV2 if cfg.model.Repconv
                        else nn.Conv1d),
                    connection_kwargs=connection_kwargs,
                )
            elif cfg.model.Attention_qkv == "linear":
                if cfg.model.commom_linear:
                    common_linear = nn.Linear(PMd_input_group.nb_units, PMd_input_group.nb_units)
                    nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')

                print("Attention_qkv is linear")
                PMd_group_q, PMd_group_k, PMd_group_v = get_q_k_v_Layer(
                    model=model,
                    size=(cfg.model.nb_Attention_embed[i], cfg.model.Attention_hidden_size[i]),
                    input_group=PMd_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                    regs=regs,
                    hidden_init=hidden_init,
                    # mean1=mean1,mean2=mean2,
                    neuron_class=CustomLIFGroup,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_class=connection_class,
                    connection_kwargs=connection_kwargs,
                )
            else:
                raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
            if not requires_grad:
                for group in [PMd_group_q, PMd_group_k, PMd_group_v]:
                    for param in group.parameters():
                        param.requires_grad = False

            # Attention-score层
            assert PMd_group_q.shape == PMd_group_k.shape == PMd_group_v.shape, \
                "The shape of input group q, k and v must be the same."
            # group
            ChannelAttention_group = CustomLIFGroup(
                PMd_group_q.shape,
                name="PMd_channel_attention_group",
                regularizers=regs,
                **hidden_neuron_kwargs,
            )
            model.add_group(ChannelAttention_group)
            # connection
            atten_connection = model.add_connection(
                ChannelAttentionConnection_multiHead(
                    src_q=PMd_group_q,
                    src_k=PMd_group_k,
                    src_v=PMd_group_v,
                    dst=ChannelAttention_group,
                    num_heads=cfg.model.nb_Attention_head[i],
                    # shortcut="VS",
                    dtype=dtype,
                )
            )
            model.propagate_seq.append(atten_connection)
            if cfg.model.Attention_recurrent[i]:
                linear_connection_Attention_recurrent = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=ChannelAttention_group,
                        row=True,
                        flatten_input=True,
                        name="block2_linera_connection_MLP_PMd",
                        **connection_kwargs
                        # requires_grad = requires_grad,
                    )
                )
                hidden_init.initialize(linear_connection_Attention_recurrent)
                model.propagate_seq.append(linear_connection_Attention_recurrent)
            if not requires_grad:
                for param in ChannelAttention_group.parameters():
                    param.requires_grad = False
            model.propagate_seq.append(ChannelAttention_group)

            # MLP层
            if cfg.model.shortcut_Linear:
                # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
                linear_layer_MLP = LinearLayer_of_shortcut(
                    name="block2_attention_linear_MLP",
                    model=model,
                    size=cfg.model.MLP_size[i],
                    shortcut_group=PMd_input_group,
                    recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                    batchNorm=batchNorm,
                    regs=regs,
                    neuron_class=CustomLIFGroup,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    neuron_kwargs=hidden_neuron_kwargs,
                    connection_kwargs=connection_kwargs
                )
                # initialize
                hidden_init.initialize(linear_layer_MLP)
                # if nb_inputs == 192:
                #     with torch.no_grad():
                #         if mean2 == 0:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
                #         else:
                #             linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

                # connection用于连接attention的output group以及MLPgroup
                linear_connection_MLP = model.add_connection(
                    connection_class(
                        src=ChannelAttention_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_linera_connection_MLP",
                        operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(linear_connection_MLP)  # initialize
                if cfg.model.MLP_id_map:
                    with torch.no_grad():  # 禁用梯度计算
                        linear_connection_MLP.op.weight.data.zero_()  # 使用 in-place 操作 _zero_()

                for c in linear_layer_MLP.connections:
                    model.propagate_seq.append(c)
                model.propagate_seq.append(linear_connection_MLP)
                model.propagate_seq.append(linear_layer_MLP.output_group)
                PMd_MLP_group = linear_layer_MLP.output_group
                MLP_group.append(PMd_MLP_group)


    MLP_shape = 0
    for g in MLP_group:
        MLP_shape += g.nb_units

    MLP_group_concat = custom_fake_InputGroup(
                shape=MLP_shape,
                name="MLP_group_concat",
            )
    model.add_group(MLP_group_concat)
    MLP_con = model.add_connection(
        Connection_identity_with_multi_src(
            src=MLP_group,
            dst=MLP_group_concat,
            feedback=False,
            name="M1_input_con",
        )
    )
    model.propagate_seq.append(MLP_con)
    model.propagate_seq.append(MLP_group_concat)
    current_src_grp=MLP_group_concat

    # block3: 输出层
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    **connection_kwargs
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp,
                             readout_group,
                             dtype=dtype,
                             flatten_input=True,
                             operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                             **connection_kwargs)
        )

        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # if output_feedback:
    #     model.groups[0].add_src(model.groups[-1])
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn=None

    return model

# qkv合并 qkvmixed
def get_model_attention_V16(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Build the spiking channel-attention model that uses a merged q/k/v layer.

    Groups and connections are created on a ``CustomRecurrentSpikingModel``
    and appended to ``model.propagate_seq`` in construction order:

    1. the input group,
    2. block 1: ``cfg.model.nb_linear_hidden`` linear hidden layers
       (currently asserted to be 0),
    3. block 2: ``cfg.model.nb_Attention_hidden`` attention blocks, each made
       of a merged qkv layer, an attention group and, when
       ``cfg.model.shortcut_Linear`` is set, a shortcut MLP layer,
    4. block 3: one readout group, or several custom readouts plus their
       average when ``cfg.model.multiple_readouts`` is set.

    Args:
        cfg: Configuration object read through attribute access
            (``cfg.data``, ``cfg.model``, ``cfg.training``, ``cfg.device``...).
        nb_inputs: Number of input units. When ``cfg.model.reshape_num > 1``
            only 192 and 96 are supported.
        dtype: Forwarded to the model and to several connections.
        data: Optional data handed to ``compute_input_firing_rates``. When it
            is ``None`` a default mean input rate of 0.1 is used.
        stateFlag: ``"default"``, ``"pretrain"``, ``"fine-tune"`` or
            ``"finetune"``; selects the batch size. Any other value falls
            back to ``cfg.training.batchsize``.

    Returns:
        The assembled model. ``model.multiBN`` is an ``nn.ModuleDict`` keyed
        by the entries of ``cfg.pretrain_monkeys`` when ``cfg.model.multi_BN``
        is set, otherwise ``False``.

    Raises:
        AssertionError: If ``requires_grad`` resolves to ``False`` (pretrain
            with ``cfg.model.pretrain_forze``), if ``cfg.model.Layer_init`` is
            not true, if ``cfg.model.nb_linear_hidden`` is not 0, or if the
            qkv output group's first dimension is not divisible by 3.
        ValueError: For an unsupported ``nb_inputs`` / ``reshape_num``
            combination, an unsupported ``cfg.model.Attention_qkv``, a
            ``RepClassModule`` operation while ``cfg.model.BatchNorm`` is
            off, or a ``RepClassModule`` wrapping an ``nn.ModuleList`` while
            ``cfg.model.multi_BN`` is on.
    """
    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    # Trainable unless the pretrain configuration freezes the weights.
    # Initialising here keeps the "default" and fallback branches from
    # leaving ``requires_grad`` unbound.
    requires_grad = True
    if stateFlag == "default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag == "pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag in ("fine-tune", "finetune"):
        # A membership test is required: ``x == "fine-tune" or "finetune"``
        # is always truthy because a non-empty string is true.
        batchsize_ = cfg.training.batchsize_finetuning
    else:
        batchsize_ = cfg.training.batchsize
    assert requires_grad, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function and regularizers derived from the configuration.
    act_fn = get_actfn(cfg)
    regs = get_regularizers(cfg)

    # Mean input firing rates used by the initializers.
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        logger.warning("数据为空，无法计算输入脉冲频率，使用默认值")
        mean1 = 0.1

    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # Attention-related settings: batch-norm connections are always used here;
    # the ``bn`` modules are stripped again at the end when
    # ``cfg.model.BatchNorm`` is off.
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    model.seq_forward = cfg.model.seq_forward
    model.propagate_seq = []

    # INPUT LAYER
    if cfg.model.reshape_num > 1:
        if nb_inputs == 192:
            inputshape = (int(cfg.model.reshape_num), int(nb_inputs / cfg.model.reshape_num))
        elif nb_inputs == 96:
            inputshape = (int(cfg.model.reshape_num / 2), int(2 * nb_inputs / cfg.model.reshape_num))
        else:
            raise ValueError(
                "nb_inputs must be 192 or 96 when reshape_num > 1, but get {}".format(nb_inputs)
            )
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
        )
    )
    current_src_grp = input_group
    model.propagate_seq.append(input_group)

    # Parameters shared by all hidden neuron groups and linear connections.
    neuron_class = Custom_multiSyn_LIFGroup if cfg.model.multiSyn_LIF else CustomLIFGroup
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }
    connection_kwargs = {
        "requires_grad": requires_grad,
        "bias": cfg.model.Linear_bias,
        "adaptive_bn": cfg.model.adaptive_bn,
    }

    # block 1: linear hidden layers. Not optimised yet, hence the assertion;
    # the loop below only runs if assertions are disabled.
    assert cfg.model.nb_linear_hidden==0, "还没进行该部分的优化，目前不能用这个"
    for i in range(cfg.model.nb_linear_hidden):
        hidden_layer = Layer(
            name="block1_linear_hidden",
            model=model,
            size=input_group.shape,
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[i],
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={**connection_kwargs, 'row': True},
        )
        current_src_grp = hidden_layer.output_group

        hidden_init.initialize(hidden_layer)

        # Rescale (or silence) the weights of the second half of a 192-wide
        # input by the ratio of the two mean firing rates.
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                if mean2 == 0:
                    hidden_layer.connections[0].op.weight[:, 96:193].zero_()
                else:
                    hidden_layer.connections[0].op.weight[:, 96:193] /= mean2 / mean1

        for c in hidden_layer.connections:
            model.propagate_seq.append(c)
        if cfg.model.multiSyn_LIF:
            for c in hidden_layer.connections:
                hidden_layer.output_group.add_syn(c)
        model.propagate_seq.append(hidden_layer.output_group)

    # Operation used by the MLP and readout connections.
    linear_op = RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear

    # block 2: channel-attention blocks.
    for i in range(cfg.model.nb_Attention_hidden):

        # --- merged QKV layer ---
        qkv_mode = cfg.model.Attention_qkv
        if qkv_mode == "conv_linear" or qkv_mode == "conv":
            qkv_layer = Channel1dConvConnection_multihead_qkv_Layer(
                name="block2_attention_conv_qkv",
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_filters=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                neuron_class=neuron_class,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs=connection_kwargs,
            )
        elif qkv_mode == "linear":
            qkv_layer = Layer(
                name="block2_attention_linear_qkv",
                model=model,
                size=(cfg.model.nb_Attention_embed[i] * 3, cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                neuron_class=neuron_class,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={**connection_kwargs, 'row': True},
            )
        else:
            # Same error as the sibling model builders instead of a NameError
            # on ``qkv_layer`` further down.
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(qkv_mode))

        # Post-processing shared by the conv and linear variants.
        hidden_init.initialize(qkv_layer)
        if cfg.model.MLP_id_map:
            with torch.no_grad():
                qkv_layer.connections[0].op.weight *= 100
        for c in qkv_layer.connections:
            model.propagate_seq.append(c)
        if cfg.model.multiSyn_LIF:
            for c in qkv_layer.connections:
                qkv_layer.output_group.add_syn(c)

        if qkv_mode == "conv_linear":
            # Extra linear path from the block input into the qkv group.
            linear_connection_qkv = model.add_connection(
                Connection_withBatchNorm(
                    src=current_src_grp,
                    dst=qkv_layer.output_group,
                    row=True,
                    flatten_input=True,
                    name="block2_attention_linear_qkv",
                    bias=False,
                )
            )
            hidden_init.initialize(linear_connection_qkv)
            if cfg.model.MLP_id_map:
                with torch.no_grad():
                    linear_connection_qkv.op.weight *= 100
            model.propagate_seq.append(linear_connection_qkv)
            if cfg.model.multiSyn_LIF:
                qkv_layer.output_group.add_syn(linear_connection_qkv)

        model.propagate_seq.append(qkv_layer.output_group)
        assert qkv_layer.output_group.shape[0] % 3 ==0, \
            "The output group shape of qkv_layer must be divisible by 3, but get {}".format(qkv_layer.output_group.shape[0])

        # --- attention-score group: one third of the merged qkv rows ---
        attention_shape = (int(qkv_layer.output_group.shape[0] / 3), qkv_layer.output_group.shape[1])
        ChannelAttention_group = neuron_class(
            attention_shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        atten_connection = model.add_connection(
            ChannelAttentionConnection_multiHead_qkv(
                src=qkv_layer.output_group,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                dtype=dtype,
                name="block2_attention_connection",
            )
        )
        model.propagate_seq.append(atten_connection)
        if cfg.model.multiSyn_LIF:
            ChannelAttention_group.add_syn(atten_connection)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_attention_connection_recurrent",
                    **connection_kwargs
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
            if cfg.model.multiSyn_LIF:
                ChannelAttention_group.add_syn(linear_connection_Attention_recurrent)
        if not requires_grad:
            for param in ChannelAttention_group.parameters():
                param.requires_grad = False
        model.propagate_seq.append(ChannelAttention_group)

        # --- MLP layer with shortcut ---
        if cfg.model.shortcut_Linear:
            # This layer contributes the connection from the block input to
            # the MLP group plus the optional recurrent connection.
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=neuron_class,
                operation=linear_op,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs=connection_kwargs
            )
            hidden_init.initialize(linear_layer_MLP)

            # Connection from the attention output group to the MLP group.
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    operation=linear_op,
                    **connection_kwargs
                )
            )
            hidden_init.initialize(linear_connection_MLP)
            if cfg.model.MLP_id_map:
                # Start with the attention path switched off.
                with torch.no_grad():
                    linear_connection_MLP.op.weight.data.zero_()

            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
            model.propagate_seq.append(linear_connection_MLP)
            if cfg.model.multiSyn_LIF:
                for c in linear_layer_MLP.connections:
                    linear_layer_MLP.output_group.add_syn(c)
                linear_layer_MLP.output_group.add_syn(linear_connection_MLP)

            # From the second block on, also feed the raw input into the MLP.
            if current_src_grp != input_group:
                input_connection_MLP = model.add_connection(
                    connection_class(
                        src=input_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_input_connection_MLP",
                        operation=linear_op,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(input_connection_MLP)
                model.propagate_seq.append(input_connection_MLP)
                if cfg.model.multiSyn_LIF:
                    linear_layer_MLP.output_group.add_syn(input_connection_MLP)

            model.propagate_seq.append(linear_layer_MLP.output_group)
            current_src_grp = linear_layer_MLP.output_group

    # block 3: readout layer(s).
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                    operation=linear_op,
                    **connection_kwargs
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts):]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(
                current_src_grp,
                readout_group,
                dtype=dtype,
                flatten_input=True,
                operation=linear_op,
                **connection_kwargs
            )
        )
        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    # Without batch norm, drop every connection-level BN module; the
    # re-parameterisable operations cannot be used in that mode.
    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn = None
            if isinstance(c.op, RepClassModule):
                raise ValueError("RepClassModule should not be used in this model, please check the configuration.")

    if cfg.model.multi_BN:
        # One private copy of every batch-norm module per pretraining key.
        model.multiBN = nn.ModuleDict()
        for key in cfg.pretrain_monkeys:
            model.multiBN[key] = nn.ModuleDict()

        for c_idx, c in enumerate(model.connections):
            c_id = f"connection_{c_idx}"

            if hasattr(c, "bn") and c.bn is not None:
                for key in cfg.pretrain_monkeys:
                    model.multiBN[key][c_id] = copy.deepcopy(c.bn).to(cfg.device)
                # Keep the id on the connection for later lookup.
                c.multiBN_id = c_id

            if isinstance(c.op, RepClassModule):
                # NOTE(review): this replaces the entry stored just above for
                # the same ``c_id`` when the connection also has a ``bn``
                # module -- confirm that this is intended.
                for key in cfg.pretrain_monkeys:
                    model.multiBN[key][c_id] = nn.ModuleDict()

                # Keep the id on the connection for later lookup.
                c.multiBN_id = c_id

                if hasattr(c.op, "rbr_1x1"):
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_1x1_bn"] = copy.deepcopy(c.op.rbr_1x1.bn).to(cfg.device)
                if hasattr(c.op, "rbr_dense"):
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_dense_bn"] = copy.deepcopy(c.op.rbr_dense.bn).to(cfg.device)
                if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_identity"] = copy.deepcopy(c.op.rbr_identity).to(cfg.device)
                if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
                    raise ValueError("MultiBN does not support nn.ModuleList in RepClassModule")
    else:
        model.multiBN = False

    return model

# hidden=1，作为encoder
def get_model_attention_V17(cfg, nb_inputs, dtype, data=None, stateFlag="default"):

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, 'multi_cuda') and cfg.multi_cuda:
        device_= torch.device(cfg.gpu_ids[0])
    else:
        device_= cfg.device

    if stateFlag=="default":
        batchsize_ = cfg.training.batchsize
    elif stateFlag=="pretrain":
        batchsize_ = cfg.training.batchsize_pretrain
        requires_grad = (not cfg.model.pretrain_forze)
    elif stateFlag=="fine-tune" or "finetune":
        batchsize_ = cfg.training.batchsize_finetuning
        requires_grad = True
    else:
        batchsize_ = cfg.training.batchsize
    assert requires_grad==True, "requires_grad 只能为True"

    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function 根据配置获取激活函数
    act_fn = get_actfn(cfg)

    # Regularizer list 获取正则化器
    regs = get_regularizers(cfg)

    # Compute mean firing rates for initializer 计算输入数据的平均发放率
    if data is not None:
        # if output_feedback:
        #     mean1, mean2, mean3 = compute_input_firing_rates(data, cfg, nb_inputs)
        # else:
        #     mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        Warning("数据为空，无法计算输入脉冲频率，使用默认值")
        mean1 = 0.1



    hidden_init, readout_init = get_initializers(cfg, mean1, dtype) #根据配置和输入脉冲频率获取隐藏层和输出层的初始化器




    # attention 的相关配置
    assert cfg.model.Layer_init==True, "默认使用层初始化"
    connection_class = Connection_withBatchNorm
    batchNorm = True
    model.seq_forward=cfg.model.seq_forward
    model.propagate_seq=[]




    # INPUT LAYER
    if cfg.model.reshape_num>1:
        if nb_inputs==192:
            inputshape=(int(cfg.model.reshape_num), int(nb_inputs/cfg.model.reshape_num))
        elif nb_inputs==96:
            inputshape=(int(cfg.model.reshape_num/2), int(2*nb_inputs/cfg.model.reshape_num))
    else:
        inputshape = (1, nb_inputs)
    input_group = model.add_group(
        InputGroup(
            inputshape,
            dropout_p=cfg.model.dropout_p,
            # output_feedback=output_feedback,
        )
    )
    current_src_grp = input_group
    model.propagate_seq.append(input_group)



    # 定义隐藏层神经元的参数。
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
        "spike_after_dynamic": cfg.model.spike_after_dynamic,
    }
    connection_kwargs = {
        "requires_grad": requires_grad,
        "bias": cfg.model.Linear_bias,
        "adaptive_bn": cfg.model.adaptive_bn,
    }




    # block1: 线性隐藏层
    # 循环构建隐藏层，每层使用 CustomLIFGroup 作为神经元类型，并根据配置设置参数。
    # 初始化隐藏层，并根据输入脉冲频率调整权重。
    if cfg.model.firingRate_Decoder:

        hidden_layer = Layer(
            name="block1_linear_encoder",
            model=model,
            size=(cfg.model.linear_encode_size[0],cfg.model.linear_encode_size[1]),
            input_group=current_src_grp,
            recurrent=cfg.model.linear_hidden_recurrent[0],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs={**connection_kwargs,
                               'row': True,
                               # 'operation': RepVGGLinearBlock1d,
                               'bn_type': "units",
                               },
        )
        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        for c in hidden_layer.connections:
            model.propagate_seq.append(c)
        model.propagate_seq.append(hidden_layer.output_group)







    attention_scr_size = current_src_grp.shape
    # block2：通道注意力层（）
    for i in range(cfg.model.nb_Attention_hidden):


        # QKV层
        if cfg.model.Attention_qkv=="conv_linear" :
            print("Attention_qkv is conv_linear")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V5(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs=connection_kwargs,
            )
        elif cfg.model.Attention_qkv=="conv":
            print("Attention_qkv is conv")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer_V2(
                model=model,
                kernel_size=cfg.model.Attention_kernel_size[i],
                nb_head=cfg.model.nb_Attention_embed[i],
                stride=cfg.model.Attention_conv_stride[i],
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                conv=(
                    RepVGGplusBlock1d if cfg.model.Repconv and cfg.model.Attention_kernel_size[i] == 3
                    else RepVGGplusBlock1dV2 if cfg.model.Repconv
                    else nn.Conv1d),
                connection_kwargs=connection_kwargs,
            )
        elif cfg.model.Attention_qkv=="linear":
            if cfg.model.commom_linear:
                common_linear = nn.Linear(current_src_grp.nb_units, current_src_grp.nb_units)
                nn.init.kaiming_normal_(common_linear.weight, mode='fan_in', nonlinearity='relu')  # He正态分布

            print("Attention_qkv is linear")
            linear_group_q,linear_group_k,linear_group_v= get_q_k_v_Layer(
                model=model,
                size=(cfg.model.nb_Attention_embed[i],cfg.model.Attention_hidden_size[i]),
                input_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_qkv[i],
                regs=regs,
                hidden_init=hidden_init,
                # mean1=mean1,mean2=mean2,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_class=connection_class,
                connection_kwargs={**connection_kwargs, 'row': True}
            )
        else:
            raise ValueError("Attention_qkv must be conv or linear, but get {}".format(cfg.model.Attention_qkv))
        if not requires_grad:
            for group in [linear_group_q, linear_group_k, linear_group_v]:
                for param in group.parameters():
                    param.requires_grad = False
        if cfg.model.Attention_parameter_scale != 1:
            if cfg.model.nb_linear_hidden==0:
                for c in model.connections:
                    with torch.no_grad():  # 禁用梯度计算
                        c.op.weight *= cfg.model.Attention_parameter_scale
            else:
                ValueError("Attention_parameter_scale 只能在nb_linear_hidden=0时使用, 但当前nb_linear_hidden={}".format(cfg.model.nb_linear_hidden))





        # Attention-score层
        assert linear_group_q.shape==linear_group_k.shape==linear_group_v.shape, \
            "The shape of input group q, k and v must be the same."
        # group
        ChannelAttention_group = CustomLIFGroup(
            linear_group_q.shape,
            name="channel_attention_group",
            regularizers=regs,
            **hidden_neuron_kwargs,
        )
        model.add_group(ChannelAttention_group)
        # connection
        atten_connection=model.add_connection(
            ChannelAttentionConnection_multiHead(
                # src_shortcut=current_src_grp,
                src_q=linear_group_q,
                src_k=linear_group_k,
                src_v=linear_group_v,
                dst=ChannelAttention_group,
                num_heads=cfg.model.nb_Attention_head[i],
                # shortcut="VS",
                dtype=dtype,
            )
        )
        model.propagate_seq.append(atten_connection)
        if cfg.model.Attention_recurrent[i]:
            linear_connection_Attention_recurrent = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=ChannelAttention_group,
                    row=True,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    **connection_kwargs
                    # requires_grad = requires_grad,
                )
            )
            hidden_init.initialize(linear_connection_Attention_recurrent)
            model.propagate_seq.append(linear_connection_Attention_recurrent)
        if not requires_grad:
            for param in ChannelAttention_group.parameters():
                param.requires_grad = False
        model.propagate_seq.append(ChannelAttention_group)




        # MLP层
        if cfg.model.shortcut_Linear:
            # 这个地方其实只是增加了inputgroup和MLPgroup之间的connection以及recurrent connection
            linear_layer_MLP = LinearLayer_of_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                shortcut_group=current_src_grp,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                batchNorm=batchNorm,
                regs=regs,
                neuron_class=CustomLIFGroup,
                operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear ,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs=connection_kwargs
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            # if nb_inputs == 192:
            #     with torch.no_grad():
            #         if mean2 == 0:
            #             linear_layer_MLP.connections[0].op.weight[:, 96:193].zero_()
            #         else:
            #             linear_layer_MLP.connections[0].op.weight[:, 96:193] /= mean2 / mean1

            # connection用于连接attention的output group以及MLPgroup
            linear_connection_MLP = model.add_connection(
                connection_class(
                    src=ChannelAttention_group,
                    dst=linear_layer_MLP.output_group,
                    flatten_input=True,
                    name="block2_linera_connection_MLP",
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    **connection_kwargs
                )
            )
            hidden_init.initialize(linear_connection_MLP) # initialize
            if cfg.model.MLP_id_map:
                with torch.no_grad():  # 禁用梯度计算
                    linear_connection_MLP.op.weight.data.zero_()  # 使用 in-place 操作 _zero_()


            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
            model.propagate_seq.append(linear_connection_MLP)

            if current_src_grp!=input_group:
                input_connection_MLP = model.add_connection(
                    connection_class(
                        src=input_group,
                        dst=linear_layer_MLP.output_group,
                        flatten_input=True,
                        name="block2_input_connection_MLP",
                        operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                        **connection_kwargs
                    )
                )
                hidden_init.initialize(input_connection_MLP)  # initialize
                model.propagate_seq.append(input_connection_MLP)

            model.propagate_seq.append(linear_layer_MLP.output_group)
            current_src_grp = linear_layer_MLP.output_group
        else:
            linear_layer_MLP = LinearLayer_with_shortcut(
                name="block2_attention_linear_MLP",
                model=model,
                size=cfg.model.MLP_size[i],
                input_group=ChannelAttention_group,
                shortcut_group=current_src_grp,
                shortcut_opFlag=cfg.model.shortcut_opFlag,
                batchNorm=batchNorm,
                recurrent=cfg.model.Attention_hidden_recurrent_MLP[i],
                regs=regs,
                neuron_class=CustomLIFGroup,
                neuron_kwargs=hidden_neuron_kwargs,
                connection_kwargs={},
            )
            # initialize
            hidden_init.initialize(linear_layer_MLP)
            current_src_grp = linear_layer_MLP.output_group
            for c in linear_layer_MLP.connections:
                model.propagate_seq.append(c)
            model.propagate_seq.append(linear_layer_MLP.output_group)






    # block3: 输出层
    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(
                connection_class(
                    current_src_grp, g, dtype=dtype, flatten_input=True,
                    operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                    **connection_kwargs
                )
            )
            readout_init.initialize(con_ro)
            model.propagate_seq.append(con_ro)
            model.propagate_seq.append(g)

        average_ro_g = model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
        model.propagate_seq.append(average_ro_g)
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            connection_class(current_src_grp,
                             readout_group,
                             dtype=dtype,
                             flatten_input=True,
                             operation=RepVGGLinearBlock1d if cfg.model.Rep_Linear else nn.Linear,
                             **connection_kwargs)
        )

        readout_init.initialize(con_ro)

        model.propagate_seq.append(con_ro)
        model.propagate_seq.append(readout_group)

    if cfg.model.firingRate_Decoder:
        model.firingRate_Decoder = Connection_mem_spike_decoder(
            src=hidden_layer.output_group,
            output_size=nb_inputs,
            bias=cfg.model.Linear_bias,
            name="firingRate_Decoder",
            dtype=dtype,
        )
        model.propagate_seq.append(model.firingRate_Decoder)

    if not cfg.model.BatchNorm:
        for c in model.connections:
            c.bn=None
        ValueError("BatchNorm is not used in this model, please check the configuration.")

    if cfg.model.multi_BN:

        model.multiBN = nn.ModuleDict()
        for key in cfg.pretrain_monkeys:
            model.multiBN[key] = nn.ModuleDict()

        for c_idx, c in enumerate(model.connections):
            c_id = f"connection_{c_idx}"

            if hasattr(c, "bn") and c.bn is not None:
                for key in cfg.pretrain_monkeys:
                    model.multiBN[key][c_id] = copy.deepcopy(c.bn).to(cfg.device)
                # 存储引用以便后续访问
                c.multiBN_id = c_id

            if isinstance(c.op, RepClassModule):
                for key in cfg.pretrain_monkeys:
                    model.multiBN[key][c_id] = nn.ModuleDict()

                # 存储引用以便后续访问
                c.multiBN_id = c_id
                # c.module_ids = {}

                if hasattr(c.op, "rbr_1x1"):
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_1x1_bn"]=copy.deepcopy(c.op.rbr_1x1.bn).to(cfg.device)
                if hasattr(c.op, "rbr_dense"):
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_dense_bn"]=copy.deepcopy(c.op.rbr_dense.bn).to(cfg.device)
                if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
                    for key in cfg.pretrain_monkeys:
                        model.multiBN[key][c_id]["rbr_identity"]=copy.deepcopy(c.op.rbr_identity).to(cfg.device)
                if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
                    ValueError("MultiBN does not support nn.ModuleList in RepClassModule")

        if cfg.model.firingRate_Decoder:
            for key in cfg.pretrain_monkeys:
                model.multiBN[key]["linear1_bn"] = copy.deepcopy(model.firingRate_Decoder.linear1_bn).to(cfg.device)


    else:
        model.multiBN = False



    return model


# 对应conv qkv multiSyn
def get_q_k_v_group_V2(
            # name,
            model,
            kernel_size,
            nb_head,
            input_group,
            recurrent,
            regs,
            hidden_init,

            stride=1,

            neuron_class=CustomLIFGroup,
            neuron_kwargs=None,  # None instead of {} so no dict is shared between calls

            conv=nn.Conv1d,
            connection_class=ConvConnection_withBatchNorm,
            connection_kwargs=None,  # None instead of {} so no dict is shared between calls
        ):
    """Create the query, key and value neuron groups, each fed by its own conv connection.

    For each of q, k and v (in that order) this function:
      1. instantiates ``neuron_class`` with ``input_group.shape``,
      2. adds a ``connection_class`` connection from ``input_group`` to the new
         group through ``model.add_connection``,
      3. initializes that connection with ``hidden_init``,
      4. appends the connection and then the group to ``model.propagate_seq``,
      5. calls ``group.add_syn(connection)``.

    The key and value groups used to be created with the query group's name
    ("block2_conv_group_q"); they are now named "block2_conv_group_k" and
    "block2_conv_group_v" so the three groups can be told apart.

    NOTE(review): the groups are not passed to ``model.add_group`` here --
    confirm whether the caller is expected to register them.
    NOTE(review): ``add_syn`` is called on every group, so ``neuron_class``
    must provide that method -- confirm the default ``CustomLIFGroup`` does.

    Args:
        model: Model exposing ``add_connection`` and ``propagate_seq``.
        kernel_size: Forwarded to ``connection_class`` as ``kernel_size``.
        nb_head: Accepted for signature parity with the sibling builders; unused.
        input_group: Source group; its ``shape`` is reused for q, k and v.
        recurrent: Accepted for signature parity with the sibling builders; unused.
        regs: Passed to every group as ``regularizers``.
        hidden_init: Initializer whose ``initialize`` is applied to each connection.
        stride: Forwarded to ``connection_class`` as ``stride``.
        neuron_class: Class used to build the three groups.
        neuron_kwargs: Extra keyword arguments for ``neuron_class``.
        conv: Forwarded to ``connection_class`` as ``conv``.
        connection_class: Class used to build the three connections.
        connection_kwargs: Extra keyword arguments for ``connection_class``.

    Returns:
        Tuple ``(group_q, group_k, group_v)``.
    """
    if neuron_kwargs is None:
        neuron_kwargs = {}
    if connection_kwargs is None:
        connection_kwargs = {}

    groups = []
    for role in ("q", "k", "v"):
        group = neuron_class(
            input_group.shape,
            name=f"block2_conv_group_{role}",
            regularizers=regs,
            **neuron_kwargs,
        )
        connection = model.add_connection(
            connection_class(
                input_group,
                group,
                conv=conv,
                kernel_size=kernel_size,
                stride=stride,
                name=f"block2_connection_{role}",
                **connection_kwargs
            )
        )
        hidden_init.initialize(connection)
        # Connection first, then its target group, as in the original order.
        model.propagate_seq.append(connection)
        model.propagate_seq.append(group)
        group.add_syn(connection)
        groups.append(group)

    group_q, group_k, group_v = groups
    return group_q, group_k, group_v

# 对应conv qkv multiSyn
def get_q_k_v_Layer_V6(
    model,
    kernel_size,
    nb_head,
    input_group,
    recurrent,
    regs,
    hidden_init,
    stride=1,
    neuron_class=Custom_multiSyn_LIFGroup,
    neuron_kwargs=None,  # None stands in for an empty dict
    conv=nn.Conv1d,
    connection_class=ConvConnection_withBatchNorm,
    connection_kwargs=None,  # None stands in for an empty dict
):
    """Build the q, k and v conv layers and attach each connection as a synapse.

    Three ``Channel1dConvConnectionLayer`` objects are created from the same
    ``input_group`` (names "block2_attention_conv_q/_k/_v", in that order).
    Each layer is initialized with ``hidden_init``; each of its connections is
    appended to ``model.propagate_seq`` and handed to the layer's output
    group through ``add_syn``; finally the output group itself is appended to
    ``model.propagate_seq``.

    Args:
        model: Model passed to every layer; its ``propagate_seq`` is extended.
        kernel_size: Forwarded to every layer as ``kernel_size``.
        nb_head: Forwarded to every layer as ``nb_filters``.
        input_group: Shared input group of the three layers.
        recurrent: Forwarded to every layer as ``recurrent``.
        regs: Forwarded to every layer as ``regs``.
        hidden_init: Initializer applied to each layer.
        stride: Forwarded to every layer as ``stride``.
        neuron_class: Neuron group class forwarded to every layer.
        neuron_kwargs: Keyword arguments forwarded as ``neuron_kwargs``.
        conv: Forwarded to every layer as ``conv``.
        connection_class: Forwarded to every layer as ``connection_class``.
        connection_kwargs: Keyword arguments forwarded as ``connection_kwargs``.

    Returns:
        Tuple of the output groups in the order (q, k, v).
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    layer_names = (
        "block2_attention_conv_q",
        "block2_attention_conv_k",
        "block2_attention_conv_v",
    )

    output_groups = []
    for layer_name in layer_names:
        layer = Channel1dConvConnectionLayer(
            name=layer_name,
            model=model,
            kernel_size=kernel_size,
            nb_filters=nb_head,
            stride=stride,
            shape=None,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            conv=conv,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(layer)

        for conn in layer.connections:
            model.propagate_seq.append(conn)
            # Every connection of the layer is also attached as a synapse.
            layer.output_group.add_syn(conn)
        model.propagate_seq.append(layer.output_group)

        output_groups.append(layer.output_group)

    out_q, out_k, out_v = output_groups
    return out_q, out_k, out_v

# 对应conv qkv+linear qkv
def get_q_k_v_Layer_V5(
            # name,
            model,
            nb_head,
            input_group,
            recurrent,
            regs,
            hidden_init,

            # for conv
            kernel_size,
            stride=1,

            neuron_class=CustomLIFGroup,
            neuron_kwargs=None,  # None instead of {} so no dict is shared between calls

            conv=nn.Conv1d,
            connection_class=ConvConnection_withBatchNorm,
            connection_kwargs=None,  # None instead of {} so no dict is shared between calls
            cfg=None,
        ):
    """Build q, k and v layers that each combine a conv path and a linear path.

    For each of q, k and v (in that order) a ``Channel1dConvConnectionLayer``
    is created from ``input_group`` and initialized with ``hidden_init``; its
    connections are appended to ``model.propagate_seq``.  A second,
    ``Connection_withBatchNorm`` connection (``row=True``,
    ``flatten_input=True``, ``bias=False``) is then added from the same
    ``input_group`` to the layer's output group, initialized with
    ``hidden_init`` and appended to ``model.propagate_seq``, followed by the
    output group itself.

    ``cfg`` is optional.  Previously ``cfg.model.forward_shortcut`` was read
    unconditionally, so calling this function without ``cfg`` (its declared
    default) raised ``AttributeError``.  Now the ``shortcut`` argument is only
    forwarded when ``cfg`` is given; otherwise it is omitted so that
    ``Connection_withBatchNorm`` applies its own default, and a warning is
    logged.

    Args:
        model: Model passed to every layer; provides ``add_connection`` and
            ``propagate_seq``.
        nb_head: Forwarded to every conv layer as ``nb_filters``.
        input_group: Shared source group of the conv and linear paths.
        recurrent: Forwarded to every conv layer as ``recurrent``.
        regs: Forwarded to every conv layer as ``regs``.
        hidden_init: Initializer applied to each layer and each linear connection.
        kernel_size: Forwarded to every conv layer as ``kernel_size``.
        stride: Forwarded to every conv layer as ``stride``.
        neuron_class: Neuron group class forwarded to every conv layer.
        neuron_kwargs: Keyword arguments forwarded as ``neuron_kwargs``.
        conv: Forwarded to every conv layer as ``conv``.
        connection_class: Forwarded to every conv layer as ``connection_class``.
        connection_kwargs: Keyword arguments forwarded as ``connection_kwargs``
            to the conv layers (not to the linear connections).
        cfg: Optional config; ``cfg.model.forward_shortcut`` is passed as
            ``shortcut`` to each linear connection when provided.

    Returns:
        Tuple of the output groups in the order (q, k, v).
    """
    if neuron_kwargs is None:
        neuron_kwargs = {}
    if connection_kwargs is None:
        connection_kwargs = {}

    # Keyword arguments shared by the three linear connections.
    linear_kwargs = {"row": True, "flatten_input": True, "bias": False}
    if cfg is not None:
        linear_kwargs["shortcut"] = cfg.model.forward_shortcut
    else:
        logging.getLogger(__name__).warning(
            "get_q_k_v_Layer_V5 called without cfg; "
            "using the default shortcut setting of Connection_withBatchNorm."
        )

    output_groups = []
    for role in ("q", "k", "v"):
        conv_layer = Channel1dConvConnectionLayer(
            name=f"block2_attention_conv_{role}",
            model=model,
            kernel_size=kernel_size,
            nb_filters=nb_head,
            stride=stride,
            shape=None,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            conv=conv,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(conv_layer)
        for c in conv_layer.connections:
            model.propagate_seq.append(c)

        # Parallel linear path from the same input into the conv layer's output group.
        linear_connection = model.add_connection(
            Connection_withBatchNorm(
                src=input_group,
                dst=conv_layer.output_group,
                name=f"block2_linera_connection_{role}",  # spelling kept: runtime name
                **linear_kwargs
            )
        )
        hidden_init.initialize(linear_connection)
        model.propagate_seq.append(linear_connection)
        model.propagate_seq.append(conv_layer.output_group)

        output_groups.append(conv_layer.output_group)

    group_q, group_k, group_v = output_groups
    return group_q, group_k, group_v

# Corresponds to self-attention and cross-attention with conv q/k/v
def get_q_k_v_Layer_V4(
    model,
    kernel_size,
    nb_head,
    input_group,
    input_feedback_group,
    recurrent,
    regs,
    hidden_inits,
    stride=1,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,  # None stands in for an empty dict
    conv=nn.Conv1d,
    connection_class=ConvConnection_withBatchNorm,
    connection_kwargs=None,  # None stands in for an empty dict
):
    """Build k, v and self-query conv layers plus a cross-query linear layer.

    Three ``Channel1dConvConnectionLayer`` objects are created from
    ``input_group`` in the order k, v, q_self.  Each one is initialized with
    ``hidden_inits[0]``, and its connections followed by its output group are
    appended to ``model.propagate_seq``.

    A fourth ``Layer`` is then created from ``input_feedback_group`` with the
    size of the key layer's output group, using ``Connection_withBatchNorm``
    with ``row=True``.  The weight of its first connection is initialized with
    ``nn.init.kaiming_normal_`` (``mode='fan_in'``, ``nonlinearity='relu'``)
    instead of one of ``hidden_inits``, and it is appended to
    ``model.propagate_seq`` in the same way.

    Args:
        model: Model passed to every layer; its ``propagate_seq`` is extended.
        kernel_size: Forwarded to the conv layers as ``kernel_size``.
        nb_head: Forwarded to the conv layers as ``nb_filters``.
        input_group: Input group of the k, v and q_self conv layers.
        input_feedback_group: Input group of the cross-query layer.
        recurrent: Forwarded to all four layers as ``recurrent``.
        regs: Forwarded to all four layers as ``regs``.
        hidden_inits: Indexable collection of initializers; only element 0 is used.
        stride: Forwarded to the conv layers as ``stride``.
        neuron_class: Neuron group class forwarded to all four layers.
        neuron_kwargs: Keyword arguments forwarded as ``neuron_kwargs``.
        conv: Forwarded to the conv layers as ``conv``.
        connection_class: Forwarded to the conv layers as ``connection_class``.
        connection_kwargs: Forwarded to the conv layers as ``connection_kwargs``.

    Returns:
        Tuple of output groups in the order (k, v, q_self, q_cross).
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    conv_initializer_index = 0
    conv_outputs = {}
    for role in ("k", "v", "q_self"):
        conv_layer = Channel1dConvConnectionLayer(
            name="block2_attention_conv_" + role,
            model=model,
            kernel_size=kernel_size,
            nb_filters=nb_head,
            stride=stride,
            shape=None,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            conv=conv,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_inits[conv_initializer_index].initialize(conv_layer)
        for conn in conv_layer.connections:
            model.propagate_seq.append(conn)
        model.propagate_seq.append(conv_layer.output_group)
        conv_outputs[role] = conv_layer.output_group

    # Cross-query path: linear layer driven by the feedback group, sized like the key output.
    cross_layer = Layer(
        name="block2_attention_linear_q_q_cross",
        model=model,
        size=conv_outputs["k"].shape,
        input_group=input_feedback_group,
        recurrent=recurrent,
        regs=regs,
        neuron_class=neuron_class,
        neuron_kwargs=neuron_kwargs,
        connection_class=Connection_withBatchNorm,
        connection_kwargs={"row": True},
    )
    # He-normal initialization of the first connection's weight.
    first_weight = cross_layer.connections[0].op.weight
    nn.init.kaiming_normal_(first_weight, mode='fan_in', nonlinearity='relu')
    for conn in cross_layer.connections:
        model.propagate_seq.append(conn)
    model.propagate_seq.append(cross_layer.output_group)

    return (
        conv_outputs["k"],
        conv_outputs["v"],
        conv_outputs["q_self"],
        cross_layer.output_group,
    )

# 对应cross attention qkv
def get_q_k_v_Layer_V3(
    model,
    kernel_size,
    nb_head,
    input_group,
    input_feedback_group,
    recurrent,
    regs,
    hidden_init,
    stride=1,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,  # None stands in for an empty dict
    connection_class=ConvConnection_withBatchNorm,
    connection_kwargs=None,  # None stands in for an empty dict
):
    """Build conv key/value layers and a linear query layer for cross attention.

    The key and value layers are ``Channel1dConvConnectionLayer`` objects fed
    by ``input_group``; the query layer is a ``Layer`` fed by
    ``input_feedback_group`` whose size is the value layer's output shape and
    whose connections are ``Connection_withBatchNorm`` with ``row=True``.
    All three layers are initialized with ``hidden_init``.

    Nothing is appended to ``model.propagate_seq`` by this function.
    ``connection_class`` is accepted but not forwarded to any layer.

    Args:
        model: Model passed to every layer.
        kernel_size: Forwarded to the conv layers as ``kernel_size``.
        nb_head: Forwarded to the conv layers as ``nb_filters``.
        input_group: Input group of the key and value layers.
        input_feedback_group: Input group of the query layer.
        recurrent: Forwarded to all three layers as ``recurrent``.
        regs: Forwarded to all three layers as ``regs``.
        hidden_init: Initializer applied to all three layers.
        stride: Forwarded to the conv layers as ``stride``.
        neuron_class: Neuron group class forwarded to all three layers.
        neuron_kwargs: Keyword arguments forwarded as ``neuron_kwargs``.
        connection_class: Unused.
        connection_kwargs: Forwarded to the conv layers as ``connection_kwargs``.

    Returns:
        Tuple of output groups in the order (q, k, v).
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    # Key and value share every setting except the layer name.
    conv_layers = {}
    for role in ("k", "v"):
        conv_layer = Channel1dConvConnectionLayer(
            name="block2_attention_linear_" + role,
            model=model,
            kernel_size=kernel_size,
            nb_filters=nb_head,
            stride=stride,
            shape=None,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(conv_layer)
        conv_layers[role] = conv_layer

    query_layer = Layer(
        name="block2_attention_linear_q",
        model=model,
        size=conv_layers["v"].output_group.shape,
        input_group=input_feedback_group,
        recurrent=recurrent,
        regs=regs,
        neuron_class=neuron_class,
        neuron_kwargs=neuron_kwargs,
        connection_class=Connection_withBatchNorm,
        connection_kwargs={"row": True},
    )
    hidden_init.initialize(query_layer)

    return (
        query_layer.output_group,
        conv_layers["k"].output_group,
        conv_layers["v"].output_group,
    )

# 对应conv qkv
def get_q_k_v_Layer_V2(
    model,
    kernel_size,
    nb_head,
    input_group,
    recurrent,
    regs,
    hidden_init,
    stride=1,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,  # None stands in for an empty dict
    conv=nn.Conv1d,
    connection_class=ConvConnection_withBatchNorm,
    connection_kwargs=None,  # None stands in for an empty dict
):
    """Build the q, k and v conv layers of an attention block.

    Three ``Channel1dConvConnectionLayer`` objects named
    "block2_attention_conv_q", "..._k" and "..._v" are created in that order
    from the same ``input_group``.  Each is initialized with ``hidden_init``;
    its connections and then its output group are appended to
    ``model.propagate_seq``.

    Args:
        model: Model passed to every layer; its ``propagate_seq`` is extended.
        kernel_size: Forwarded to every layer as ``kernel_size``.
        nb_head: Forwarded to every layer as ``nb_filters``.
        input_group: Shared input group of the three layers.
        recurrent: Forwarded to every layer as ``recurrent``.
        regs: Forwarded to every layer as ``regs``.
        hidden_init: Initializer applied to each layer.
        stride: Forwarded to every layer as ``stride``.
        neuron_class: Neuron group class forwarded to every layer.
        neuron_kwargs: Keyword arguments forwarded as ``neuron_kwargs``.
        conv: Forwarded to every layer as ``conv``.
        connection_class: Forwarded to every layer as ``connection_class``.
        connection_kwargs: Keyword arguments forwarded as ``connection_kwargs``.

    Returns:
        Tuple of the output groups in the order (q, k, v).
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    def _build_and_register(layer_name):
        # Create one conv layer, initialize it and queue it for propagation.
        layer = Channel1dConvConnectionLayer(
            name=layer_name,
            model=model,
            kernel_size=kernel_size,
            nb_filters=nb_head,
            stride=stride,
            shape=None,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            conv=conv,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(layer)
        for conn in layer.connections:
            model.propagate_seq.append(conn)
        model.propagate_seq.append(layer.output_group)
        return layer.output_group

    out_q = _build_and_register("block2_attention_conv_q")
    out_k = _build_and_register("block2_attention_conv_k")
    out_v = _build_and_register("block2_attention_conv_v")
    return out_q, out_k, out_v

# 对应linear qkv
def get_q_k_v_Layer_V1(
    model,
    size,
    input_group,
    recurrent,
    regs,
    hidden_init,
    mean1,
    mean2,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,
    connection_class=Connection,
    connection_kwargs=None,
):
    """Create the q, k and v linear layers fed by ``input_group``.

    Three ``Layer`` objects named ``block2_attention_linear_q``, ``..._k`` and
    ``..._v`` are built with identical settings. Each one is initialized with
    ``hidden_init``, optionally has part of its first connection's weight
    rescaled (see below), and then has its connections followed by its output
    group appended to ``model.propagate_seq``.

    Weight adjustment: only when ``input_group.nb_units == 192``, the columns
    ``96:193`` of ``connections[0].op.weight`` are zeroed if ``mean2 == 0``,
    otherwise divided in place by ``mean2 / mean1`` (so ``mean1`` is expected
    to be non-zero in that case).

    Args:
        model: Model the layers are attached to; its ``propagate_seq`` is
            extended in place.
        size: Forwarded to ``Layer`` as ``size``.
        input_group: Source group shared by all three layers.
        recurrent: Forwarded to ``Layer`` as ``recurrent``.
        regs: Forwarded to ``Layer`` as ``regs``.
        hidden_init: Initializer; ``hidden_init.initialize(layer)`` is called
            once per layer.
        mean1: Denominator of the ratio used to rescale the selected columns.
        mean2: Numerator of that ratio; a value of 0 zeroes the columns.
        neuron_class: Neuron group class forwarded to ``Layer``.
        neuron_kwargs: Keyword arguments for the neuron class (``None`` means
            an empty dict).
        connection_class: Connection class forwarded to ``Layer``.
        connection_kwargs: Keyword arguments for the connection class
            (``None`` means an empty dict).

    Returns:
        Tuple ``(q_group, k_group, v_group)`` of the layers' output groups.
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    # Read once up front; decides whether the weight adjustment is applied.
    adjust_weights = input_group.nb_units == 192

    layer_names = (
        "block2_attention_linear_q",
        "block2_attention_linear_k",
        "block2_attention_linear_v",
    )

    output_groups = []
    for layer_name in layer_names:
        layer = Layer(
            name=layer_name,
            model=model,
            size=size,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(layer)

        if adjust_weights:
            # Modify the freshly initialized weights without recording
            # the operation for autograd.
            with torch.no_grad():
                weight = layer.connections[0].op.weight
                # NOTE(review): the end bound 193 exceeds the 192 input units;
                # slicing clamps it, so this presumably covers columns
                # 96..191 — confirm the weight layout is (out, in).
                if mean2 == 0:
                    weight[:, 96:193].zero_()
                else:
                    weight[:, 96:193] /= mean2 / mean1

        # Connections first, then the group they feed.
        for connection in layer.connections:
            model.propagate_seq.append(connection)
        model.propagate_seq.append(layer.output_group)

        output_groups.append(layer.output_group)

    return tuple(output_groups)

def get_q_k_v_Layer(
    model,
    size,
    input_group,
    recurrent,
    regs,
    hidden_init,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,
    connection_class=Connection,
    connection_kwargs=None,
):
    """Create the q, k and v linear layers fed by ``input_group``.

    Three ``Layer`` objects named ``block2_attention_linear_q``, ``..._k`` and
    ``..._v`` are built with identical settings. Each one is initialized with
    ``hidden_init`` and then has its connections followed by its output group
    appended to ``model.propagate_seq``.

    Args:
        model: Model the layers are attached to; its ``propagate_seq`` is
            extended in place.
        size: Forwarded to ``Layer`` as ``size``.
        input_group: Source group shared by all three layers.
        recurrent: Forwarded to ``Layer`` as ``recurrent``.
        regs: Forwarded to ``Layer`` as ``regs``.
        hidden_init: Initializer; ``hidden_init.initialize(layer)`` is called
            once per layer.
        neuron_class: Neuron group class forwarded to ``Layer``.
        neuron_kwargs: Keyword arguments for the neuron class (``None`` means
            an empty dict).
        connection_class: Connection class forwarded to ``Layer``.
        connection_kwargs: Keyword arguments for the connection class
            (``None`` means an empty dict).

    Returns:
        Tuple ``(q_group, k_group, v_group)`` of the layers' output groups.
    """
    neuron_kwargs = {} if neuron_kwargs is None else neuron_kwargs
    connection_kwargs = {} if connection_kwargs is None else connection_kwargs

    layer_names = (
        "block2_attention_linear_q",
        "block2_attention_linear_k",
        "block2_attention_linear_v",
    )

    output_groups = []
    for layer_name in layer_names:
        layer = Layer(
            name=layer_name,
            model=model,
            size=size,
            input_group=input_group,
            recurrent=recurrent,
            regs=regs,
            neuron_class=neuron_class,
            neuron_kwargs=neuron_kwargs,
            connection_class=connection_class,
            connection_kwargs=connection_kwargs,
        )
        hidden_init.initialize(layer)

        # Connections first, then the group they feed.
        for connection in layer.connections:
            model.propagate_seq.append(connection)
        model.propagate_seq.append(layer.output_group)

        output_groups.append(layer.output_group)

    return tuple(output_groups)


def get_q_k_v_group(
    model,
    size,
    input_group,
    recurrent,
    regs,
    hidden_init,
    neuron_class=CustomLIFGroup,
    neuron_kwargs=None,
    connection_class=Connection,
    connection_kwargs=None,
):
    """Create the q, k and v neuron groups and wire them to ``input_group``.

    For each of q, k and v a neuron group is instantiated, added to the model,
    connected to ``input_group`` through ``connection_class`` and that
    connection is initialized with ``hidden_init``. When ``recurrent`` is
    truthy, each group additionally gets a self-connection (built with
    ``Connection``, not ``connection_class``) that is initialized the same way.

    The recurrent connection of the v group is named
    ``block2_linera_connection_v_recurrent``; it previously reused the k
    connection's name, leaving two connections with the same name.

    Args:
        model: Model receiving the groups (``add_group``) and connections
            (``add_connection``).
        size: First positional argument passed to ``neuron_class``.
        input_group: Source group of the three feed-forward connections.
        recurrent: Whether to add a self-connection on each group.
        regs: Passed to ``neuron_class`` as ``regularizers``.
        hidden_init: Initializer; ``hidden_init.initialize(connection)`` is
            called for every connection created here.
        neuron_class: Class used to build the three groups.
        neuron_kwargs: Extra keyword arguments for ``neuron_class`` (``None``
            means an empty dict).
        connection_class: Class used for the feed-forward connections.
        connection_kwargs: Extra keyword arguments for every connection,
            feed-forward and recurrent (``None`` means an empty dict).

    Returns:
        Tuple ``(q_group, k_group, v_group)``.
    """
    if neuron_kwargs is None:
        neuron_kwargs = {}
    if connection_kwargs is None:
        connection_kwargs = {}

    tags = ("q", "k", "v")

    # NOTE: the "linera" spelling in the names is kept on purpose; these are
    # runtime identifiers that may be referenced elsewhere (unverified).
    groups = []
    for tag in tags:
        group = neuron_class(
            size,
            name=f"block2_linera_group_{tag}",
            regularizers=regs,
            **neuron_kwargs,
        )
        model.add_group(group)

        feedforward = model.add_connection(
            connection_class(
                input_group,
                group,
                name=f"block2_linera_connection_{tag}",
                **connection_kwargs,
            )
        )
        hidden_init.initialize(feedforward)
        groups.append(group)

    if recurrent:
        # Added after all feed-forward connections, in q, k, v order.
        # Deriving the name from the tag gives every recurrent connection a
        # unique name.
        for tag, group in zip(tags, groups):
            self_connection = model.add_connection(
                Connection(
                    group,
                    group,
                    name=f"block2_linera_connection_{tag}_recurrent",
                    **connection_kwargs,
                )
            )
            hidden_init.initialize(self_connection)

    return tuple(groups)