import numpy as np
import time

import torch

from stork.models import RecurrentSpikingModel

import stork.nodes.base
from stork import generators
from stork import loss_stacks
from stork import monitors

import warnings
from itertools import chain
from typing import TypeVar, Sequence, Union, Optional
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
)
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import gather

from torch.cuda._utils import _get_device_index
from challenge.utils import save_model_state

from .Parallel_for_stork import scatter, _check_balance, parallel_apply, get_replicas, data_scatter
# from .test_function import get_model_gpu_memory, print_memory_usage, print_tensor_details
from datetime import datetime
import copy
from .RepConv1d import *
from .custom_connections import ChannelAttentionConnection
from .custom_connections import (
    ChannelAttentionConnection_multiHead,
    ChannelAttentionConnection_multiHead_qkv,
    Connection_mem_spike_decoder,
                                 )


# Type variable bound to ``torch.nn.Module``.
# NOTE(review): in the visible code it is only referenced by the commented-out
# ``Generic[T]`` base of CustomRecurrentSpikingModel, so it currently appears unused.
T = TypeVar("T", bound=Module)

class CustomRecurrentSpikingModel(RecurrentSpikingModel,
                                            # Generic[T]
                                           ):
    """
    Custom model wrapper that implements an additional training loop with
    a custom mask for weight parameters used for an iterative pruning
    strategy.
    """

    def __init__(self, *args, **kwargs):
        """Forward every construction argument unchanged to ``RecurrentSpikingModel``."""
        super(CustomRecurrentSpikingModel, self).__init__(*args, **kwargs)

    def configure(
            self,
            input,
            output,
            loss_stack=None,
            optimizer=None,
            optimizer_kwargs=None,
            scheduler=None,
            scheduler_kwargs=None,
            generator=None,
            time_step=1e-3,
            wandb=None,
            multi_cuda=False,
            device_ids=None,
            earlystop=False,
            earlystop_patience=5,
            earlystop_min_delta=0.001,
            earlystop_restore_best_weights=True,
            earlystop_min_ep=0,
            earlystop_ema_alpha=0.6,  # EMA smoothing coefficient (0~1)
            earlystop_restart_epochs=0,  # number of fine-tuning epochs after early stop
            earlystop_restart_lr_factor=0.1,  # learning-rate factor used on restart
            output_feedback=False,
            pretrain_forze=False,

    ):
        """Configure groups, connections, data generator, optimizer and scheduler.

        Args:
            input: input cell group (stored as ``self.input_group``).
            output: output cell group (stored as ``self.output_group``).
            loss_stack: loss stack; defaults to
                ``loss_stacks.TemporalCrossEntropyReadoutStack()``.
            optimizer: optimizer class; defaults to ``torch.optim.Adam``.
            optimizer_kwargs: optimizer kwargs; defaults to
                ``dict(lr=1e-3, betas=(0.9, 0.999))``.
            scheduler, scheduler_kwargs: passed to ``configure_scheduler``.
            generator: data generator; defaults to
                ``generators.StandardGenerator()``.
            time_step: simulation time step.
            wandb: logger object stored as ``self.wandb``.
            multi_cuda: if True, initialise multi-GPU execution over
                ``device_ids`` with ``device_ids[0]`` as output device.
            device_ids: GPU ids; at least two are required when ``multi_cuda``.
            earlystop, earlystop_*: early-stopping settings; the ``earlystop_*``
                values are only stored on ``self`` when ``earlystop`` is True.
            output_feedback: stored as ``self.output_feedback``.
            pretrain_forze: if True, set ``requires_grad=False`` on every
                parameter of the attention groups listed below (name kept as-is
                for backward compatibility).

        Raises:
            ValueError: if ``multi_cuda`` is True and fewer than two
                ``device_ids`` are given. This is checked before any state is
                modified, so it is not stripped under ``python -O`` the way an
                ``assert`` would be.
        """
        if multi_cuda and (device_ids is None or len(device_ids) < 2):
            raise ValueError(
                "Not enough GPUs are specified: multi_cuda requires at least 2 "
                "device_ids, please modify cfg.gpu_ids."
            )

        self.input_group = input
        self.output_group = output
        self.time_step = time_step
        self.wandb = wandb
        self.output_feedback = output_feedback

        if loss_stack is not None:
            self.loss_stack = loss_stack
        else:
            self.loss_stack = loss_stacks.TemporalCrossEntropyReadoutStack()

        if generator is None:
            self.data_generator_ = generators.StandardGenerator()
        else:
            self.data_generator_ = generator

        self.earlystop = earlystop
        if earlystop:
            self.earlystop_patience = earlystop_patience
            self.earlystop_delta = earlystop_min_delta
            self.earlystop_restore_best_weights = earlystop_restore_best_weights
            self.earlystop_min_ep = earlystop_min_ep
            self.earlystop_ema_alpha = earlystop_ema_alpha
            self.earlystop_restart_epochs = earlystop_restart_epochs
            self.earlystop_restart_lr_factor = earlystop_restart_lr_factor

        # configure data generator
        self.data_generator_.configure(
            self.batch_size,
            self.nb_time_steps,
            self.nb_inputs,
            self.time_step,
            device=self.device,
            dtype=self.dtype,
        )

        for o in self.groups + self.connections:
            o.configure(
                self.batch_size,
                self.nb_time_steps,
                self.time_step,
                self.device,
                self.dtype,
            )
        # Optional per-key hidden LIF groups kept outside self.groups.
        if hasattr(self, "multiHidden") and self.multiHidden:
            for key in self.multiHidden:
                self.multiHidden[key]['LIF'].configure(
                    self.batch_size,
                    self.nb_time_steps,
                    self.time_step,
                    self.device,
                    self.dtype,
                )

        if pretrain_forze:
            # Groups whose parameters are kept fixed (presumably pretrained weights).
            frozen_group_names = {
                "channel_attention_group",
                "block2_attention_linear_v",
                "block2_attention_linear_k",
                "block2_attention_linear_q",
                "block2_attention_conv_q",
                "block2_attention_conv_k",
                "block2_attention_conv_v",
            }
            for g in self.groups:
                if g.name in frozen_group_names:
                    for param in g.parameters():
                        param.requires_grad = False

        if optimizer is None:
            optimizer = torch.optim.Adam

        if optimizer_kwargs is None:
            optimizer_kwargs = dict(lr=1e-3, betas=(0.9, 0.999))

        self.optimizer_class = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        self.configure_optimizer(self.optimizer_class, self.optimizer_kwargs)

        self.scheduler_class = scheduler
        self.scheduler_kwargs = scheduler_kwargs
        self.configure_scheduler(self.scheduler_class, self.scheduler_kwargs)

        if multi_cuda:
            self.multiGPU_init(device_ids=device_ids, output_device=device_ids[0])
            self.device = self.multiGPU['output_device']

        self.to(self.device)

    def reconfigure(self):
        """Re-apply the configuration of the last ``configure`` call.

        ``configure`` is called again on the first and last groups,
        ``reconfigure`` on every group in between, and ``configure`` on every
        connection. This is meant to reset the model and reinitialize its
        trainable variables. The optimizer is not re-created.

        If the input or output group is still None, only a warning is printed.
        """
        if self.input_group is None or self.output_group is None:
            print(
                "Warning! No input or output group has been assigned yet. Run configure first."
            )
            return

        settings = (
            self.batch_size,
            self.nb_time_steps,
            self.time_step,
            self.device,
            self.dtype,
        )
        # Boundary groups get a full configure; interior groups a reconfigure.
        self.groups[0].configure(*settings)
        for group in self.groups[1:-1]:
            group.reconfigure(*settings)
        self.groups[-1].configure(*settings)
        for connection in self.connections:
            connection.configure(*settings)

    # def reset_store_state_sequences(self, stored_sequences_Flag=True):
    #     if stored_sequences_Flag:
    #         for g in self.groups:
    #             if hasattr(g, 'store_state_sequences_key'):
    #                 g.store_state_sequences = g.store_state_sequences_key
    #     else:
    #         for g in self.groups:
    #             if hasattr(g, 'store_state_sequences_key'):
    #                 g.store_state_sequences = []


    def compute_regularizer_losses(self):
        """Accumulate the regularizer losses of all groups, then all connections.

        Returns:
            A 1-element tensor on ``self.device`` holding the summed loss.
        """
        total = torch.zeros(1, device=self.device)
        for component in chain(self.groups, self.connections):
            total += component.get_regularizer_loss()
        return total

    def run(self, x_batch, cur_batch_size=None, record=False):
        """Simulate the network on one input batch and return the output sequence.

        Args:
            x_batch: input batch fed to ``self.input_group``.
            cur_batch_size: batch size used to reset states; defaults to
                ``len(x_batch)``.
            record: if True, call ``monitor_all`` after every time step.

        Returns:
            The output group's sequence, concatenated on the last dim with the
            ``class_readout`` sequence when that attribute is set. Also stored
            as ``self.out``. If a ``firingRate_Decoder`` exists, its per-step
            outputs are stacked on dim 1 and stored as ``self.firingRate_out``.
        """
        batch_size = len(x_batch) if cur_batch_size is None else cur_batch_size
        self.reset_states(batch_size)
        self.input_group.feed_data(x_batch)

        for step in range(self.nb_time_steps):
            # Class-level clock read by all cell groups.
            stork.nodes.base.CellGroup.clk = step
            self.evolve_propagate()
            self.execute_all()
            if record:
                self.monitor_all()

        self.out = self.output_group.get_out_sequence()
        readout = getattr(self, "class_readout", None)
        if readout is not None:
            self.out = torch.cat((self.out, readout.get_out_sequence()), dim=-1)

        if hasattr(self, 'firingRate_Decoder'):
            decoder = self.firingRate_Decoder
            decoder.output = torch.stack(decoder.output, dim=1)
            self.firingRate_out = decoder.output
        return self.out

    def reset_states(self, batch_size=None):
        """Reset the state of every group and of the optional extra modules.

        Resets all groups, then the ``firingRate_Decoder`` if present. Then,
        when ``multiHidden`` is set and truthy, it resets the ``'LIF'`` entry
        of every ``multiHidden`` item.
        """
        for group in self.groups:
            group.reset_state(batch_size)
        if hasattr(self, 'firingRate_Decoder'):
            self.firingRate_Decoder.reset()
        hidden = getattr(self, "multiHidden", None)
        if not hidden:
            return
        for hidden_key in hidden:
            hidden[hidden_key]['LIF'].reset_state(batch_size)

    def get_total_loss(self, output, target_y, regularized=True):
        """Compute the loss for one batch and cache its components on ``self``.

        Args:
            output: network output, passed directly to ``self.loss_stack``.
            target_y: target tensor, or a list/tuple of tensors; every tensor
                is moved to ``self.device``.
            regularized: if True, also compute ``self.reg_loss`` and add it.

        Returns:
            ``self.out_loss`` when ``regularized`` is False, otherwise
            ``self.out_loss + self.reg_loss``. Note that ``self.reg_loss`` is
            left untouched when ``regularized`` is False.
        """
        # isinstance also accepts list/tuple subclasses (e.g. namedtuples).
        if isinstance(target_y, (list, tuple)):
            target_y = [ty.to(self.device) for ty in target_y]
        else:
            target_y = target_y.to(self.device)

        self.out_loss = self.loss_stack(output, target_y)

        if not regularized:
            return self.out_loss

        if hasattr(self, 'multiGPU'):
            # NOTE(review): presumably filled in by the multi-GPU forward pass
            # (not visible in this chunk) — confirm.
            self.reg_loss = self.multiGPU['reg_loss']
        else:
            self.reg_loss = self.compute_regularizer_losses()
        return self.out_loss + self.reg_loss

    def evaluate(self, test_dataset, train_mode=False, stored_sequences_Flag=False):
        """Evaluate on ``test_dataset`` without gradients and average the metrics.

        Args:
            test_dataset: data consumed by the data generator (``shuffle=False``).
            train_mode: value passed to ``self.train`` before evaluating.
            stored_sequences_Flag: forwarded to ``forward_pass``.

        Returns:
            Column-wise mean over batches of
            ``[out_loss, reg_loss] (+ [firingRate_loss]) + loss_stack.metrics``.
            The firing-rate column is present only when the model has a
            ``firingRate_Decoder``.
        """
        self.train(train_mode)
        rows = []

        with torch.no_grad():
            for x_batch, y_batch in self.data_generator(test_dataset, shuffle=False):
                prediction = self.forward_pass(
                    x_batch,
                    cur_batch_size=len(x_batch),
                    target_y=y_batch,
                    stored_sequences_Flag=stored_sequences_Flag,
                )
                objective = self.get_total_loss(prediction, y_batch)

                use_fr = hasattr(self, 'firingRate_Decoder')
                if use_fr:
                    # NOTE(review): poisson_nll_loss is not defined in this chunk;
                    # presumably provided by the star import from .RepConv1d.
                    self.firingRate_loss = poisson_nll_loss(x_batch, self.firingRate_Decoder.output)
                    objective += self.firingRate_loss

                row = [self.out_loss.item(), self.reg_loss.item()]
                if use_fr:
                    row.append(self.firingRate_loss.item())
                rows.append(row + self.loss_stack.metrics)

        return np.mean(np.array(rows), axis=0)

    def evaluate_multiBN(self, test_datasets, train_mode=False, stored_sequences_Flag = False):
        """Evaluate on several datasets, switching per-dataset modules first.

        Args:
            test_datasets: mapping ``key -> dataset``. Before a dataset is
                evaluated, ``BN_switch(key)`` / ``hidden_switch(key)`` are
                called when ``multiBN`` / ``multiHidden`` are set and truthy.
            train_mode: value passed to ``self.train`` before evaluating.
            stored_sequences_Flag: forwarded to ``forward_pass``.

        Returns:
            Column-wise mean over all batches of all datasets of
            ``[out_loss, reg_loss] (+ [firingRate_loss]) + loss_stack.metrics``.
        """
        self.train(train_mode)
        metrics = []

        with torch.no_grad():  # no gradients needed for evaluation
            for key, test_dataset in test_datasets.items():
                # getattr(): these attributes are optional (configure and
                # reset_states also guard multiHidden), so a model without
                # them must not raise AttributeError here.
                if getattr(self, "multiBN", None):
                    self.BN_switch(key)
                if getattr(self, "multiHidden", None):
                    self.hidden_switch(key)

                for local_X, local_y in self.data_generator(test_dataset, shuffle=False):
                    output = self.forward_pass(
                        local_X,
                        cur_batch_size=len(local_X),
                        target_y=local_y,
                        stored_sequences_Flag=stored_sequences_Flag,
                    )
                    total_loss = self.get_total_loss(output, local_y)

                    has_fr = hasattr(self, 'firingRate_Decoder')
                    if has_fr:
                        # NOTE(review): poisson_nll_loss is not defined in this chunk;
                        # presumably provided by the star import from .RepConv1d.
                        self.firingRate_loss = poisson_nll_loss(local_X, self.firingRate_Decoder.output)
                        total_loss += self.firingRate_loss
                    self.reset_states()

                    # Per the original comments: out_loss, reg_loss, [fr_loss], R2[0], R2[1], R2(mean)
                    row = [self.out_loss.item(), self.reg_loss.item()]
                    if has_fr:
                        row.append(self.firingRate_loss.item())
                    metrics.append(row + self.loss_stack.metrics)

        return np.mean(np.array(metrics), axis=0)

    def regtrain_epoch(self, dataset, shuffle=True):
        """Train one epoch optimizing ONLY the regularizer loss.

        The output loss is still computed and reported, but it is not
        back-propagated.

        Args:
            dataset: training data consumed by the data generator.
            shuffle: whether the generator shuffles batches.

        Returns:
            Mean over batches of ``[out_loss, reg_loss] + loss_stack.metrics``.
        """
        self.train(True)
        self.prepare_data(dataset)
        metrics = []
        for local_X, local_y in self.data_generator(dataset, shuffle=shuffle):
            output = self.forward_pass(local_X, cur_batch_size=len(local_X), target_y=local_y)
            # get_total_loss(regularized=True) sets both self.out_loss and
            # self.reg_loss, so no separate compute_regularizer_losses() call
            # is needed (it would be overwritten immediately).
            self.get_total_loss(output, local_y)
            loss = self.reg_loss

            # store loss and other metrics
            metrics.append(
                [self.out_loss.item(), self.reg_loss.item()] + self.loss_stack.metrics
            )

            # Use autograd to compute the backward pass.
            self.optimizer_instance.zero_grad()
            loss.backward()
            self.optimizer_instance.step()
            self.apply_constraints()

        if self.scheduler_instance is not None:
            self.scheduler_instance.step()

        return np.mean(np.array(metrics), axis=0)

    def train_epoch(self, dataset, shuffle=True):
        """Train for one epoch over ``dataset`` and return the mean metric row.

        Args:
            dataset: training data consumed by the data generator.
            shuffle: whether the generator shuffles batches.

        Returns:
            Column-wise mean over batches of
            ``[out_loss, reg_loss] (+ [firingRate_loss]) + loss_stack.metrics``.
            The firing-rate column is present only when the model has a
            ``firingRate_Decoder``. The LR scheduler, if any, is stepped once
            at the end of the epoch.
        """
        self.train(True)
        rows = []

        for x_batch, y_batch in self.data_generator(dataset, shuffle=shuffle):
            prediction = self.forward_pass(x_batch, cur_batch_size=len(x_batch), target_y=y_batch)
            objective = self.get_total_loss(prediction, y_batch)

            use_fr = hasattr(self, 'firingRate_Decoder')
            if use_fr:
                # Auxiliary Poisson NLL between the input batch and the
                # firing-rate decoder output. NOTE(review): poisson_nll_loss is
                # not defined in this chunk; presumably from .RepConv1d.
                self.firingRate_loss = poisson_nll_loss(x_batch, self.firingRate_Decoder.output)
                objective += self.firingRate_loss
            # States are reset before backward(); this ordering is intentional.
            self.reset_states()

            row = [self.out_loss.item(), self.reg_loss.item()]
            if use_fr:
                row.append(self.firingRate_loss.item())
            rows.append(row + self.loss_stack.metrics)

            # Backward pass and parameter update.
            self.optimizer_instance.zero_grad()
            objective.backward()
            self.optimizer_instance.step()
            self.apply_constraints()  # e.g. weight clipping constraints

            # Drop references early so memory can be released.
            prediction = None
            objective = None
            del x_batch, y_batch

            self.check_for_nan_parameters()

        if self.scheduler_instance is not None:
            self.scheduler_instance.step()

        return np.mean(np.array(rows), axis=0)

    def train_epoch_multiBN(self, datasets, shuffle=True):
        """Train one epoch over several datasets, switching per-dataset modules.

        Datasets are visited in ``datasets`` iteration order. Before each one,
        ``BN_switch(key)`` / ``hidden_switch(key)`` are called when the model
        has a truthy ``multiBN`` / ``multiHidden`` attribute. The LR scheduler
        is stepped once after all datasets.

        Args:
            datasets: mapping ``key -> dataset``.
            shuffle: whether the data generator shuffles batches.

        Returns:
            Column-wise mean over all batches of all datasets of
            ``[out_loss, reg_loss] (+ [firingRate_loss]) + loss_stack.metrics``.
        """
        self.train(True)
        metrics = []

        for key, dataset in datasets.items():
            # getattr(): these attributes are optional (configure and
            # reset_states also guard multiHidden), so a model without them
            # must not raise AttributeError here.
            if getattr(self, "multiBN", None):
                self.BN_switch(key)
            if getattr(self, "multiHidden", None):
                self.hidden_switch(key)

            for local_X, local_y in self.data_generator(dataset, shuffle=shuffle):
                output = self.forward_pass(local_X, cur_batch_size=len(local_X), target_y=local_y)
                total_loss = self.get_total_loss(output, local_y)

                has_fr = hasattr(self, 'firingRate_Decoder')
                if has_fr:
                    # NOTE(review): poisson_nll_loss is not defined in this chunk;
                    # presumably provided by the star import from .RepConv1d.
                    self.firingRate_loss = poisson_nll_loss(local_X, self.firingRate_Decoder.output)
                    total_loss += self.firingRate_loss
                # States are reset before backward(); this ordering is intentional.
                self.reset_states()

                row = [self.out_loss.item(), self.reg_loss.item()]
                if has_fr:
                    row.append(self.firingRate_loss.item())
                metrics.append(row + self.loss_stack.metrics)

                # Backward pass and parameter update.
                self.optimizer_instance.zero_grad()
                total_loss.backward()
                self.optimizer_instance.step()
                self.apply_constraints()

                # Drop references early so memory can be released.
                output = None
                total_loss = None
                del local_X, local_y

                self.check_for_nan_parameters()

        # Step the LR scheduler once per epoch.
        if self.scheduler_instance is not None:
            self.scheduler_instance.step()

        return np.mean(np.array(metrics), axis=0)


    def check_for_nan_parameters(self, epoch=None, batch=None):
        """Warn about trainable parameters that contain NaN values.

        Only parameters with ``requires_grad=True`` are inspected. For each one
        that contains NaNs, a ``UserWarning`` is emitted. The message includes
        the parameter name, shape, NaN count/percentage and, when given, the
        epoch and batch. Training is NOT interrupted; raising an exception was
        deliberately disabled.

        Args:
            epoch: current epoch, appended to the message if not None.
            batch: current batch index, appended to the message if not None.

        Returns:
            Always True.
        """
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            nan_mask = torch.isnan(param.data)  # computed once and reused
            if not nan_mask.any():
                continue

            nan_count = nan_mask.sum().item()
            total_count = param.data.numel()
            nan_percent = (nan_count / total_count) * 100

            # Message fields: NaN found in parameter / shape / NaN count /
            # epoch / batch (runtime text kept unchanged).
            error_msg = f"参数中检测到NaN值: {name}\n"
            error_msg += f"形状: {param.data.shape}\n"
            error_msg += f"NaN数量: {nan_count}/{total_count} ({nan_percent:.2f}%)"
            if epoch is not None:
                error_msg += f"\n训练轮次: {epoch}"
            if batch is not None:
                error_msg += f"\n批次编号: {batch}"

            # stacklevel=2 attributes the warning to the calling training loop.
            warnings.warn(error_msg, UserWarning, stacklevel=2)
        return True

    def get_probabilities(self, x_input):
        """Return class probabilities ``exp(loss_stack.log_py_given_x(output))``.

        Args:
            x_input: input samples. Zero placeholder labels are paired with
                them only so the standard data generator can be reused.

        Returns:
            Per-batch probabilities concatenated along dim 0, row-aligned
            with ``x_input``.
        """
        probs = []
        fake_labels = torch.zeros(len(x_input))
        self.prepare_data((x_input, fake_labels))
        # shuffle=False: the result must stay aligned with x_input, and the
        # generator's default may shuffle (every other order-sensitive method
        # here passes shuffle=False explicitly).
        for local_X, local_y in self.data_generator((x_input, fake_labels), shuffle=False):
            output = self.forward_pass(local_X, cur_batch_size=len(local_y), target_y=local_y)
            probs.append(torch.exp(self.loss_stack.log_py_given_x(output)))
        return torch.cat(probs, dim=0)

    def get_predictions(self, dataset):
        """Return ``loss_stack.predict`` results for all of ``dataset`` as a NumPy array.

        Batches are generated in order (``shuffle=False``); each batch's
        predictions are moved to CPU and the batches are concatenated on axis 0.
        """
        self.prepare_data(dataset)
        batch_predictions = [
            self.loss_stack.predict(
                self.forward_pass(x_batch, cur_batch_size=len(x_batch))
            ).cpu().numpy()
            for x_batch, _ in self.data_generator(dataset, shuffle=False)
        ]
        return np.concatenate(batch_predictions)

    def predict(self, data, train_mode=False):
        """Predict labels for a single batch or for a whole dataset.

        Args:
            data: a single batch (``torch.Tensor`` / ``np.ndarray``, including
                subclasses), fed to ``forward_pass`` as-is; or any other
                dataset object, iterated with the data generator
                (``shuffle=False``) with every batch moved to ``self.device``.
            train_mode: value passed to ``self.train`` before predicting.

        Returns:
            For a single batch, ``loss_stack.predict(output)`` unchanged: not
            detached, and left on its device. For a dataset, the detached
            per-batch predictions moved to CPU and concatenated along dim 0.
        """
        self.train(train_mode)
        if isinstance(data, (torch.Tensor, np.ndarray)):
            output = self.forward_pass(data, cur_batch_size=len(data))
            return self.loss_stack.predict(output)

        self.prepare_data(data)
        pred = []
        for local_X, _ in self.data_generator(data, shuffle=False):
            data_local = local_X.to(self.device)
            output = self.forward_pass(data_local, cur_batch_size=len(local_X))
            pred.append(self.loss_stack.predict(output).detach().cpu())
        return torch.cat(pred, dim=0)

    def monitor(self, dataset):
        """Run ``dataset`` with recording enabled and collect every monitor's data.

        Monitors are reset before each batch. Returns one tensor per entry of
        ``self.monitors`` (same order), with the per-batch data concatenated
        along dim 0.
        """
        self.prepare_data(dataset)

        collected = [[] for _ in self.monitors]
        for x_batch, y_batch in self.data_generator(dataset, shuffle=False):
            for mon in self.monitors:
                mon.reset()

            self.forward_pass(
                x_batch, cur_batch_size=len(x_batch), target_y=y_batch, record=True
            )

            for bucket, mon in zip(collected, self.monitors):
                bucket.append(mon.get_data())

        return [torch.cat(bucket, dim=0) for bucket in collected]

    def monitor_backward(self, dataset):
        """
        Allows monitoring of gradients with GradientMonitor
            - If there are no GradientMonitors, this runs the usual `monitor` method
            - Returns both normal monitor output and backward monitor output

        For every batch, a recorded forward pass is run and the total loss is
        back-propagated. No optimizer step is taken, so weights are unchanged,
        but the computed gradients are left in ``.grad``. Every monitor's data
        is then collected. ``set_hook`` is called on each GradientMonitor
        before the loop, and ``remove_hook`` is always called afterwards, even
        if an exception is raised mid-way.

        Returns:
            One tensor per entry of ``self.monitors`` (same order), with the
            per-batch data concatenated along dim 0.
        """
        gradient_monitors = [
            m for m in self.monitors if isinstance(m, monitors.GradientMonitor)
        ]
        if not gradient_monitors:
            return self.monitor(dataset)

        self.prepare_data(dataset)

        # Set monitors to record gradients
        for gm in gradient_monitors:
            gm.set_hook()

        results = [[] for _ in self.monitors]
        try:
            for local_X, local_y in self.data_generator(dataset, shuffle=False):
                for m in self.monitors:
                    m.reset()

                output = self.forward_pass(
                    local_X, record=True, cur_batch_size=len(local_X), target_y=local_y
                )
                total_loss = self.get_total_loss(output, local_y)

                # Backward only: no optimizer step, so the weights are not updated.
                self.optimizer_instance.zero_grad()
                total_loss.backward()

                for k, mon in enumerate(self.monitors):
                    results[k].append(mon.get_data())
        finally:
            # Turn gradient recording off even if the loop raised, so no hooks leak.
            for gm in gradient_monitors:
                gm.remove_hook()

        return [torch.cat(res, dim=0) for res in results]

    def record_group_outputs(self, group, x_input):
        """Run ``x_input`` through the network and collect ``group``'s output sequence.

        Args:
            group: a cell group whose ``get_out_sequence()`` is read after
                every batch.
            x_input: input samples. Zero placeholder labels are paired with
                them only so the standard data generator can be reused.

        Returns:
            Per-batch output sequences concatenated along dim 0, row-aligned
            with ``x_input``.
        """
        res = []
        fake_labels = torch.zeros(len(x_input))
        self.prepare_data((x_input, fake_labels))
        # shuffle=False: the rows must stay aligned with x_input, and the
        # generator's default may shuffle.
        for local_X, _ in self.data_generator((x_input, fake_labels), shuffle=False):
            self.forward_pass(local_X, cur_batch_size=len(local_X))
            res.append(group.get_out_sequence())
        return torch.cat(res, dim=0)

    def train_epoch_masked(self, dataset, shuffle=True, mask=None):
        """Train one epoch with gradient masking (used for iterative pruning).

        Args:
            dataset: training data consumed by the data generator.
            shuffle: whether the generator shuffles batches.
            mask: optional mapping ``parameter name -> tensor``. After
                back-propagation, the gradient of every parameter whose name
                contains ``'weight'`` is multiplied in place by ``mask[name]``.
                Parameters without a gradient are skipped, e.g. ones frozen
                via ``requires_grad=False`` or not used in the forward pass.

        Returns:
            Mean over batches of ``[out_loss, reg_loss] + loss_stack.metrics``.
        """
        self.train(True)
        self.prepare_data(dataset)
        metrics = []
        for local_X, local_y in self.data_generator(dataset, shuffle=shuffle):
            output = self.forward_pass(local_X, cur_batch_size=len(local_X), target_y=local_y)
            total_loss = self.get_total_loss(output, local_y)

            # store loss and other metrics
            metrics.append(
                [self.out_loss.item(), self.reg_loss.item()] + self.loss_stack.metrics
            )

            # Use autograd to compute the backward pass.
            self.optimizer_instance.zero_grad()
            total_loss.backward()

            if mask is not None:
                for name, param in self.named_parameters():
                    # param.grad is None for frozen/unused parameters; calling
                    # mul_ on it would raise AttributeError.
                    if 'weight' in name and param.grad is not None:
                        param.grad.data.mul_(mask[name])

            self.optimizer_instance.step()
            self.apply_constraints()

        if self.scheduler_instance is not None:
            self.scheduler_instance.step()

        return np.mean(np.array(metrics), axis=0)

    def fit_validate_masked(
            self, dataset, valid_dataset, nb_epochs=10, verbose=True, wandb=None, mask=None
    ):
        """Train with a gradient mask for ``nb_epochs`` epochs, validating after each.

        Each epoch calls ``self.train_epoch_masked(dataset, mask=mask)`` and
        then ``self.evaluate(valid_dataset)``. Per-epoch bookkeeping (history
        lists, wandb logging, verbose printing) is delegated to
        ``self._update_history``. The final history is built by
        ``self._finalize_training``, the same helpers used by ``fit_validate``.

        Args:
            dataset: training data.
            valid_dataset: validation data.
            nb_epochs: number of epochs to run.
            verbose: print per-epoch metrics and record wall-clock times.
            wandb: unused; logging goes through ``self.wandb``.
            mask: pruning mask forwarded to ``train_epoch_masked``.

        Returns:
            The metrics-history dict produced by ``self._finalize_training``.
        """
        self.hist_train = []
        self.hist_valid = []
        self.wall_clock_time = []
        for ep in range(nb_epochs):
            t_start = time.time()
            self.train()
            ret_train = self.train_epoch_masked(dataset, mask=mask)

            self.train(False)
            ret_valid = self.evaluate(valid_dataset)

            self._update_history(ret_train, ret_valid, t_start, ep, verbose)

        # No early stopping here, so there are no "best" weights to hand over.
        return self._finalize_training(None)

    def evaluate_continuos_testdata(self, test_dataset, train_mode=False, stored_sequences_Flag=False):
        """Evaluate on test data that is one long continuous sample.

        Every batch from ``self.data_generator(test_dataset, shuffle=False)``
        is run without gradient tracking. Besides the averaged metrics, the
        first sample (index 0) of the batch with the lowest ``self.out_loss``
        is kept as an example prediction / ground-truth pair.

        Args:
            test_dataset: data consumed by ``self.data_generator``.
            train_mode: passed to ``self.train``; the default ``False``
                evaluates in eval mode.
            stored_sequences_Flag: currently unused.

        Returns:
            ``(mean_metrics, best_prediction, best_ground_truth)``:
            ``mean_metrics`` is the column-wise mean over batches of
            ``[out_loss, reg_loss, *loss_stack.metrics]``.
            ``best_prediction`` / ``best_ground_truth`` are the NumPy arrays
            ``output[0, :, :]`` and ``local_y[0, :, :]`` of the lowest-loss
            batch (``None`` if no batch was produced).
        """
        self.train(train_mode)
        self.prepare_data(test_dataset)

        batch_metrics = []
        best_loss = np.inf
        best_pred = None
        best_truth = None
        with torch.no_grad():
            for x_batch, y_batch in self.data_generator(test_dataset, shuffle=False):
                prediction = self.forward_pass(x_batch, cur_batch_size=len(x_batch), target_y=y_batch)
                # Called for its side effects: it sets self.out_loss / self.reg_loss.
                self.get_total_loss(prediction, y_batch)

                row = [self.out_loss.item(), self.reg_loss.item()] + self.loss_stack.metrics
                batch_metrics.append(row)

                if not self.out_loss < best_loss:
                    continue
                best_loss = self.out_loss
                best_pred = prediction.cpu().detach().numpy()[0, :, :]
                best_truth = y_batch.cpu().detach().numpy()[0, :, :]

        return np.mean(np.array(batch_metrics), axis=0), best_pred, best_truth

    def forward_pass(self, x_batch, cur_batch_size, target_y=None, record=False, stored_sequences_Flag=False):
        """Run the network on one batch, splitting it across GPUs when configured.

        If ``multiGPU_init`` has not been called (no ``self.multiGPU``), or the
        batch has fewer samples than configured devices, the batch runs on a
        single device via ``self.run``. Otherwise it is dispatched to
        ``forward_pass_multiGPU``, and the result is also stored in ``self.out``.

        Args:
            x_batch: input batch.
            cur_batch_size: number of samples in ``x_batch``.
            target_y: targets; required (same length as ``x_batch``) on the
                multi-GPU path.
            record: forwarded to ``run`` / ``forward_pass_multiGPU``.
            stored_sequences_Flag: forwarded to ``forward_pass_multiGPU``.

        Returns:
            The network output for the batch.

        Raises:
            AssertionError: on the multi-GPU path, if ``target_y`` is missing or
                its length differs from ``len(x_batch)``.
        """
        # Single-device path: multi-GPU not configured, or too few samples to split.
        if not hasattr(self, 'multiGPU') or cur_batch_size < len(self.multiGPU['device_ids']):
            return self.run(x_batch, cur_batch_size, record=record)

        assert target_y is not None, "target_y is None"
        assert len(target_y) == len(x_batch), "len(target_y) != len(x_batch)"

        outputs = forward_pass_multiGPU(self, x_batch, target_y, cur_batch_size, record, stored_sequences_Flag)
        self.out = outputs
        return outputs


    def fit_validate(
        self,
        dataset,
        valid_dataset,
        nb_epochs=10,
        verbose=True,
        wandb=None,
    ):
        """Train for ``nb_epochs`` epochs, validating after each, with optional early stopping.

        Early stopping is controlled by ``self.earlystop``:

        * falsy: disabled;
        * ``'r2'``: maximise ``ret_valid[4]``;
        * any other truthy value: minimise the validation loss ``ret_valid[0]``.

        The monitored value is smoothed with an exponential moving average
        (weight ``self.earlystop_ema_alpha`` on the newest value). For epochs
        ``ep <= self.earlystop_min_ep`` of the main phase (warm-up), the latest
        epoch is simply tracked as the best one. When the EMA fails to improve
        for ``self.earlystop_patience`` consecutive epochs, the best weights
        are restored. Then either:

        * training ends (``self.earlystop_restart_epochs == 0``), or
        * the learning rate is scaled by ``self.earlystop_restart_lr_factor``
          and a single fine-tuning phase of ``self.earlystop_restart_epochs``
          epochs runs. Early stopping does not trigger again in that phase.

        At the end, the best weights are restored if
        ``self.earlystop_restore_best_weights`` is set.

        Args:
            dataset: training data. When it is a ``dict`` and multi-BN /
                multi-hidden mode is active, ``train_epoch_multiBN`` /
                ``evaluate_multiBN`` are used instead.
            valid_dataset: validation data for ``self.evaluate``.
            nb_epochs: number of epochs in the main phase.
            verbose: print per-epoch metrics.
            wandb: unused; logging goes through ``self.wandb``.

        Returns:
            The metrics-history dict produced by ``self._finalize_training``.
        """
        self.hist_train = []
        self.hist_valid = []
        self.wall_clock_time = []

        # Early-stopping state.
        maximize = self.earlystop == 'r2'
        best_ema = -np.inf if maximize else np.inf
        no_improvement_count = 0
        best_model_weights = None
        ema_val = None  # EMA-smoothed monitored value
        is_restart_phase = False  # True while in the fine-tuning (warm restart) phase

        current_lr = self.optimizer_instance.param_groups[0]['lr']
        print(f"Current learning rate before scheduler update: {current_lr:.6f}")

        # Main loop: one pass for normal training, plus at most one warm-restart pass.
        while nb_epochs > 0:
            restart_requested = False
            for ep in range(nb_epochs):
                t_start = time.time()

                # Training.
                self.train()
                if self.output_feedback:
                    self.input_group.set_epoch(ep)
                if isinstance(dataset, dict) and (self.multiBN or self.multiHidden):
                    ret_train = self.train_epoch_multiBN(dataset)
                else:
                    ret_train = self.train_epoch(dataset)

                # Validation.
                self.train(False)
                if isinstance(dataset, dict) and (self.multiBN or self.multiHidden):
                    # NOTE(review): this evaluates on the *training* dict, not on
                    # valid_dataset — confirm this is intended.
                    ret_valid = self.evaluate_multiBN(dataset)
                else:
                    ret_valid = self.evaluate(valid_dataset)

                self._update_history(ret_train, ret_valid, t_start, ep, verbose)

                if not self.earlystop:
                    continue

                current_val = ret_valid[4] if maximize else ret_valid[0]

                if not is_restart_phase and ep <= self.earlystop_min_ep:
                    # Warm-up: treat the latest epoch as the best one, without smoothing.
                    ema_val = current_val
                    best_ema = ema_val
                    no_improvement_count = 0
                    best_model_weights = copy.deepcopy(self.state_dict())
                    continue

                ema_val = current_val if ema_val is None else \
                    self.earlystop_ema_alpha * current_val + (1 - self.earlystop_ema_alpha) * ema_val

                is_improved = (ema_val > best_ema) if maximize else (ema_val < best_ema)
                if is_improved:
                    best_ema = ema_val
                    no_improvement_count = 0
                    best_model_weights = copy.deepcopy(self.state_dict())
                    print("Best model weights updated!")
                else:
                    no_improvement_count += 1

                # Early stopping triggers only in the main phase.
                if is_restart_phase or no_improvement_count < self.earlystop_patience:
                    continue

                print(f"Early stopping at epoch {ep}! Best {'R²' if maximize else 'Loss'}: {best_ema:.4f}")

                if self.earlystop_restart_epochs > 0:
                    # Enter the fine-tuning phase from the best weights with a reduced LR.
                    print(f"Starting warm restart for {self.earlystop_restart_epochs} epochs")
                    if best_model_weights is not None:
                        self.load_state_dict(best_model_weights)
                    self._adjust_learning_rate(self.earlystop_restart_lr_factor)
                    nb_epochs = self.earlystop_restart_epochs
                    is_restart_phase = True
                    restart_requested = True
                    break

                if best_model_weights is not None:
                    self.load_state_dict(best_model_weights)
                return self._finalize_training(best_model_weights)

            # Stop unless the inner loop just requested the fine-tuning phase. This
            # also ends the fine-tuning phase itself, whatever its length.
            if not restart_requested:
                break

        if self.earlystop and best_model_weights is not None and self.earlystop_restore_best_weights:
            self.load_state_dict(best_model_weights)

        return self._finalize_training(best_model_weights)

    def fit_validate_step(
        self,
        dataset,
        valid_dataset,
        nb_epochs=10,
        verbose=True,
        wandb=None,
    ):
        """Train/validate for ``nb_epochs`` epochs with optional early stopping.

        This is a simplified variant of ``fit_validate``: it always uses
        ``self.train_epoch`` / ``self.evaluate`` and has no warm-up tracking.
        Early stopping (controlled by ``self.earlystop``; ``'r2'`` maximises
        ``ret_valid[4]``, any other truthy value minimises ``ret_valid[0]``) is
        checked for epochs ``ep > self.earlystop_min_ep`` of the main phase and
        for every epoch of the fine-tuning phase. The monitored value is
        EMA-smoothed with ``self.earlystop_ema_alpha``. After
        ``self.earlystop_patience`` epochs without improvement, the best weights
        are restored. Then training either ends, or (if
        ``self.earlystop_restart_epochs > 0``) the LR is scaled by
        ``self.earlystop_restart_lr_factor`` and one fine-tuning phase of that
        many epochs runs.

        Args:
            dataset: training data.
            valid_dataset: validation data.
            nb_epochs: number of epochs in the main phase.
            verbose: print per-epoch metrics.
            wandb: unused; logging goes through ``self.wandb``.

        Returns:
            The metrics-history dict produced by ``self._finalize_training``.
        """
        self.hist_train = []
        self.hist_valid = []
        self.wall_clock_time = []

        # Early-stopping state.
        maximize = self.earlystop == 'r2'
        best_ema = -np.inf if maximize else np.inf
        no_improvement_count = 0
        best_model_weights = None
        ema_val = None  # EMA-smoothed monitored value
        is_restart_phase = False  # True while in the fine-tuning (warm restart) phase

        current_lr = self.optimizer_instance.param_groups[0]['lr']
        print(f"Current learning rate before scheduler update: {current_lr:.6f}")

        # Main loop: one pass for normal training, plus at most one warm-restart pass.
        while nb_epochs > 0:
            restart_requested = False
            for ep in range(nb_epochs):
                t_start = time.time()

                self.train()
                ret_train = self.train_epoch(dataset)

                self.train(False)
                ret_valid = self.evaluate(valid_dataset)

                self._update_history(ret_train, ret_valid, t_start, ep, verbose)

                if not self.earlystop or not (is_restart_phase or ep > self.earlystop_min_ep):
                    continue

                current_val = ret_valid[4] if maximize else ret_valid[0]
                ema_val = current_val if ema_val is None else \
                    self.earlystop_ema_alpha * current_val + (1 - self.earlystop_ema_alpha) * ema_val

                is_improved = (ema_val > best_ema) if maximize else (ema_val < best_ema)
                if is_improved:
                    best_ema = ema_val
                    no_improvement_count = 0
                    best_model_weights = copy.deepcopy(self.state_dict())
                    print("Best model weights updated!")
                else:
                    no_improvement_count += 1

                # Early stopping triggers only in the main phase.
                if is_restart_phase or no_improvement_count < self.earlystop_patience:
                    continue

                print(f"Early stopping at epoch {ep}! Best {'R²' if maximize else 'Loss'}: {best_ema:.4f}")

                if self.earlystop_restart_epochs > 0:
                    print(f"Starting warm restart for {self.earlystop_restart_epochs} epochs")
                    # best_model_weights may be None if the metric never improved (e.g. NaN).
                    if best_model_weights is not None:
                        self.load_state_dict(best_model_weights)
                    self._adjust_learning_rate(self.earlystop_restart_lr_factor)
                    nb_epochs = self.earlystop_restart_epochs
                    is_restart_phase = True
                    restart_requested = True
                    break

                if best_model_weights is not None:
                    self.load_state_dict(best_model_weights)
                return self._finalize_training(best_model_weights)

            # Stop unless a warm restart was just requested; this also ends the
            # fine-tuning phase regardless of its length.
            if not restart_requested:
                break

        if best_model_weights is not None and self.earlystop_restore_best_weights:
            self.load_state_dict(best_model_weights)

        return self._finalize_training(best_model_weights)

    # Additional helper methods
    def _adjust_learning_rate(self, factor):
        """调整学习率（余弦退火示例）"""
        for param_group in self.optimizer_instance.param_groups:
            param_group['lr'] *= factor
        # # 可替换为余弦退火调度器
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
        # scheduler_kwargs = {"T_max": restart_epochs}
        # self.scheduler_instance = scheduler(
        #     self.optimizer_instance, **scheduler_kwargs
        # )
        # # self.scheduler = CosineAnnealingLR(self.optimizer, T_max=restart_epochs)

    def _update_history(self, ret_train, ret_valid, t_start, ep, verbose):
        """更新训练历史"""
        self.hist_train.append(ret_train)
        self.hist_valid.append(ret_valid)

        if self.wandb is not None:
            self.wandb.log(
                {
                    key: value
                    for (key, value) in zip(
                        self.get_metric_names()
                        + self.get_metric_names(prefix="val_"),
                        ret_train.tolist() + ret_valid.tolist(),
                    )
                }
            )

        if verbose:
            t_iter = time.time() - t_start
            self.wall_clock_time.append(t_iter)
            print(
                "%02i %s --%s t_iter=%.2f"
                % (
                    ep,
                    self.get_metrics_string(ret_train),
                    self.get_metrics_string(ret_valid, prefix="val_"),
                    t_iter,
                )
            )

    def _finalize_training(self, best_weights):
        """最终处理"""
        self.hist = np.concatenate((np.array(self.hist_train), np.array(self.hist_valid)))
        self.fit_runs.append(self.hist)
        return {
            **self.get_metrics_history_dict(np.array(self.hist_train), prefix=""),
            **self.get_metrics_history_dict(np.array(self.hist_valid), prefix="val_")
        }

    def multiGPU_init(
            self,
            device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
            output_device: Optional[Union[int, torch.device]] = None,
            dim: int = 0,
    ) -> None:
        """Configure data-parallel execution across several GPUs.

        All arguments are validated before ``self.multiGPU`` is assigned. If
        validation fails, the model keeps its previous (typically absent)
        configuration. This matters because ``forward_pass`` takes the
        multi-GPU path whenever ``self.multiGPU`` exists.

        Args:
            device_ids: devices to use; ``None`` uses all available devices
                (with a warning). At least two are required, and no more than
                are available.
            output_device: device for gathered outputs; defaults to the first
                entry of ``device_ids``.
            dim: stored as ``self.multiGPU['dim']``.

        Raises:
            AssertionError: if no accelerator is available, fewer than two
                devices are requested, or more are requested than exist.
        """
        device_type = _get_available_device_type()
        assert device_type is not None, "No GPU devices found"

        # If no GPU is specified, all GPUs will be used.
        if device_ids is None:
            device_ids = _get_all_device_indices()
            warnings.warn("No GPU is specified, using all GPUs.!", category=UserWarning)
        device_ids = list(device_ids)
        assert len(device_ids) > 1, \
            "At least two GPUs are required for DataParallel training."
        assert len(device_ids) <= len(_get_all_device_indices()), \
            "Not enough GPUs available. len(device_ids) > len(_get_all_device_indices())"

        if output_device is None:
            output_device = device_ids[0]

        devices = [torch.device(_get_device_index(x, True)) for x in device_ids]
        config = {
            'output': [],
            'device_ids': devices,
            'dim': dim,
            'output_device': torch.device(_get_device_index(output_device, True)),
            'src_device_obj': devices[0],
        }

        if device_type == "cuda":
            # Helper from .Parallel_for_stork; presumably checks/warns about
            # unbalanced GPUs, as torch DataParallel does — confirm there.
            _check_balance(devices)

        self.multiGPU = config

    def state_sequence_toTensor(self):
        """Convert each group's stored ``'out'`` sequence from a list to a tensor.

        For every group in ``self.groups`` whose ``stored_sequences_['out']``
        is a list (of per-step tensors), replace it with
        ``torch.stack(seq, dim=1)``; the new dimension 1 indexes the list
        entries. Entries that are no longer lists are left untouched, so
        repeated calls are harmless.
        """
        for group in self.groups:
            stored = group.stored_sequences_
            seq = stored['out']
            if isinstance(seq, list):
                stored['out'] = torch.stack(seq, dim=1)

    def get_metric_names(self, prefix="", postfix=""):
        """Return the metric names, each wrapped as ``prefix + name + postfix``.

        The base names are ``loss`` and ``reg_loss`` (plus ``fr_loss`` when the
        model has a ``firingRate_Decoder``), followed by
        ``self.loss_stack.get_metric_names(n_outputs)``.

        NOTE(review): ``n_outputs`` is inferred from the length of the first
        recorded training-metrics vector (length 4 -> 1 output, otherwise 2).
        This reads ``self.hist_train[0]``, so at least one epoch must have been
        recorded before calling this.
        """
        n_outputs = 1 if self.hist_train[0].shape[0] == 4 else 2

        base_names = ["loss", "reg_loss"]
        if hasattr(self, 'firingRate_Decoder'):
            base_names.append("fr_loss")
        all_names = base_names + self.loss_stack.get_metric_names(n_outputs)

        return [f"{prefix}{name}{postfix}" for name in all_names]

    def get_metrics_string(self, metrics_array, prefix="", postfix=""):
        """Format metric values as a string of ``" name=value"`` pairs.

        Names come from ``self.get_metric_names(prefix, postfix)``. Values are
        formatted with ``%.3g``. Pairing stops at the shorter of the two
        sequences.
        """
        labels = self.get_metric_names(prefix, postfix)
        return "".join(" %s=%.3g" % (label, value) for value, label in zip(metrics_array, labels))

    def rep(self):
        """Switch re-parameterisable parts of the network to deploy mode.

        First, every connection whose ``op`` is a ``RepClassModule`` has
        ``op.switch_to_deploy()`` called; ``RepVGGplusBlock1dV2`` and
        ``RepVGGLinearBlock1d`` ops are rejected by assertion. Then every
        connection with a non-``None`` ``bn`` attribute has its own
        ``switch_to_deploy()`` called. Prints "Rep Done" when finished.
        """
        for conn in self.connections:
            op = conn.op
            if not isinstance(op, RepClassModule):
                continue
            assert not isinstance(op, RepVGGplusBlock1dV2), \
                "RepVGGplusBlock1dV2 is not supported in RepClassModule, please use RepVGGplusBlock1d or RepVGGLinearBlock1d"
            # NOTE(review): this message rejects RepVGGLinearBlock1d yet recommends
            # it as an alternative — confirm the intended wording.
            assert not isinstance(op, RepVGGLinearBlock1d), \
                "RepVGGLinearBlock1d is not supported in RepClassModule, please use RepVGGplusBlock1d or RepVGGLinearBlock1d"
            op.switch_to_deploy()

        # Fold connection-level BatchNorms only after all ops have been converted.
        for conn in self.connections:
            if getattr(conn, "bn", None) is not None:
                conn.switch_to_deploy()

        print("Rep Done")

    def evolve_propagate(self):
        """Advance every group and connection of the network by one step.

        When ``self.seq_forward`` exists and is truthy, the elements of
        ``self.propagate_seq`` are processed strictly in order:

        * ``Connection_mem_spike_decoder``: ``forward()``;
        * ``CellGroup``: ``evolve()`` followed by ``clear_input()``;
        * stork ``BaseConnection`` or one of the channel-attention connections:
          ``propagate()``.

        Any other element raises ``TypeError``. Without sequential mode, all
        groups are evolved (``evolve_all``) before all connections propagate
        (``propagate_all``).
        """
        if not (hasattr(self, "seq_forward") and self.seq_forward):
            self.evolve_all()
            self.propagate_all()
            return

        for node in self.propagate_seq:
            if isinstance(node, Connection_mem_spike_decoder):
                node.forward()
            elif isinstance(node, stork.nodes.base.CellGroup):
                node.evolve()
                node.clear_input()
            elif isinstance(
                node,
                (
                    stork.connections.BaseConnection,
                    ChannelAttentionConnection,
                    ChannelAttentionConnection_multiHead,
                    ChannelAttentionConnection_multiHead_qkv,
                ),
            ):
                node.propagate()
            else:
                raise TypeError(
                    "propagate_seq must be a list of CellGroup or Connection objects"
                )

    @staticmethod
    def _average_batchnorm(bns, device=None):
        """Return a new ``nn.BatchNorm1d`` initialised with the element-wise mean of ``bns``.

        ``weight``, ``bias``, ``running_mean`` and ``running_var`` are averaged
        over all modules in ``bns``; ``num_features`` is taken from the first.
        The modules are assumed to be affine BatchNorms of equal size.

        Args:
            bns: iterable of BatchNorm modules to average.
            device: device of the new module; defaults to the device of the
                first module's running mean.

        Raises:
            ValueError: if ``bns`` is empty.
        """
        bns = list(bns)
        if not bns:
            raise ValueError("Cannot average an empty collection of BatchNorm modules")
        if device is None:
            device = bns[0].running_mean.device
        new_bn = nn.BatchNorm1d(num_features=bns[0].num_features, device=device)
        with torch.no_grad():
            new_bn.weight.copy_(torch.stack([bn.weight.data for bn in bns]).mean(dim=0))
            new_bn.bias.copy_(torch.stack([bn.bias.data for bn in bns]).mean(dim=0))
            new_bn.running_mean.copy_(torch.stack([bn.running_mean for bn in bns]).mean(dim=0))
            new_bn.running_var.copy_(torch.stack([bn.running_var for bn in bns]).mean(dim=0))
        return new_bn

    def BN_switch(self, key):
        """Install the per-subject BatchNorm set stored under ``key``.

        ``self.multiBN[key]`` is indexed by each connection's ``multiBN_id``.
        The following are replaced by the stored modules:

        * a connection's ``bn``;
        * for ``RepClassModule`` ops, ``rbr_1x1.bn`` / ``rbr_dense.bn`` /
          ``rbr_identity`` (stored under ``"rbr_1x1_bn"``, ``"rbr_dense_bn"``,
          ``"rbr_identity"``);
        * ``self.firingRate_Decoder.linear1_bn`` (stored under ``"linear1_bn"``).

        If ``key`` is unknown, a new set is created. Each module is
        initialised with the mean parameters and running statistics of the
        corresponding modules of all existing keys, installed in the model,
        and stored as ``self.multiBN[key]``.

        NOTE(review): modules created for a new key are not added to
        ``self.optimizer_instance`` here — confirm the caller registers them if
        they should be trained.

        Raises:
            ValueError: if a ``RepClassModule`` op wraps an ``nn.ModuleList``
                (unsupported), or if a new set must be averaged but no key
                exists yet.
        """
        if key in self.multiBN:
            stored = self.multiBN[key]
            for c in self.connections:
                if hasattr(c, "bn") and c.bn is not None:
                    c.bn = stored[c.multiBN_id]

                if isinstance(c.op, RepClassModule):
                    if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
                        # The original constructed this error without raising it.
                        raise ValueError("MultiBN does not support nn.ModuleList in RepClassModule")
                    c_id = c.multiBN_id
                    if hasattr(c.op, "rbr_1x1"):
                        c.op.rbr_1x1.bn = stored[c_id]["rbr_1x1_bn"]
                    if hasattr(c.op, "rbr_dense"):
                        c.op.rbr_dense.bn = stored[c_id]["rbr_dense_bn"]
                    if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
                        c.op.rbr_identity = stored[c_id]["rbr_identity"]

            if hasattr(self, 'firingRate_Decoder'):
                self.firingRate_Decoder.linear1_bn = stored["linear1_bn"]
            return

        # Unknown key: build a new BN set from the average of all existing keys.
        existing_keys = list(self.multiBN.keys())
        new_multiBN = nn.ModuleDict()
        with torch.no_grad():
            for c in self.connections:
                if hasattr(c, "bn") and c.bn is not None:
                    c_id = c.multiBN_id
                    if c_id not in new_multiBN:
                        new_multiBN[c_id] = self._average_batchnorm(
                            [self.multiBN[k][c_id] for k in existing_keys], device=c.device
                        )
                    # Connections sharing an id share one module, as in the known-key branch.
                    c.bn = new_multiBN[c_id]

                if isinstance(c.op, RepClassModule):
                    if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
                        raise ValueError("MultiBN does not support nn.ModuleList in RepClassModule")
                    c_id = c.multiBN_id
                    rep_bn_dict = nn.ModuleDict()

                    if hasattr(c.op, "rbr_1x1"):
                        rep_bn_dict["rbr_1x1_bn"] = self._average_batchnorm(
                            [self.multiBN[k][c_id]["rbr_1x1_bn"] for k in existing_keys], device=c.device
                        )
                        c.op.rbr_1x1.bn = rep_bn_dict["rbr_1x1_bn"]

                    if hasattr(c.op, "rbr_dense"):
                        rep_bn_dict["rbr_dense_bn"] = self._average_batchnorm(
                            [self.multiBN[k][c_id]["rbr_dense_bn"] for k in existing_keys], device=c.device
                        )
                        c.op.rbr_dense.bn = rep_bn_dict["rbr_dense_bn"]

                    if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
                        rep_bn_dict["rbr_identity"] = self._average_batchnorm(
                            [self.multiBN[k][c_id]["rbr_identity"] for k in existing_keys], device=c.device
                        )
                        c.op.rbr_identity = rep_bn_dict["rbr_identity"]

                    if rep_bn_dict:
                        new_multiBN[c_id] = rep_bn_dict

            if hasattr(self, 'firingRate_Decoder'):
                # Device follows the stored modules rather than a leaked loop variable.
                new_bn = self._average_batchnorm(
                    [self.multiBN[k]["linear1_bn"] for k in existing_keys]
                )
                self.firingRate_Decoder.linear1_bn = new_bn
                new_multiBN["linear1_bn"] = new_bn

        # Store the new subject's BN set in multiBN.
        self.multiBN[key] = new_multiBN

    def hidden_switch(self, key):
        # print(f"Training on dataset: {key}")

        assert hasattr(self.connections[0], 'dst') and 'block1_linear_hidden' in self.connections[0].dst.name
        assert hasattr(self.connections[0], 'src') and 'Input' in self.connections[0].src.name

        if key in self.multiHidden:
            self.connections[0].op = self.multiHidden[key]["op"]
            self.connections[0].bn = self.multiHidden[key]["op_bn"]
            if "recurrent" in self.multiHidden[key]:
                assert hasattr(self.connections[1], 'dst') and 'block1_linear_hidden' in self.connections[1].dst.name
                assert hasattr(self.connections[1], 'src') and 'block1_linear_hidden' in self.connections[1].src.name
                self.connections[1].op = self.multiHidden[key]["recurrent"]
                self.connections[1].bn = self.multiHidden[key]["recurrent_bn"]

            # self.group2.load_state_dict(self.multiHidden[key].LIFdata)
            self.group2 = self.multiHidden[key]["LIF"]
            self.propagate_seq[3] = self.group2
            self.groups[1] = self.group2
            for c in self.connections:
                if hasattr(c, 'dst') and 'block1_linear_hidden' in c.dst.name:
                    c.dst = self.group2
                if hasattr(c, 'src') and 'block1_linear_hidden' in c.src.name:
                    c.src = self.group2
        else:
            # 计算所有现有被试的隐藏层参数平均值
            with torch.no_grad():
                existing_keys = list(self.multiHidden.keys())

                assert len(existing_keys) > 0, "No existing hidden layers found in multiHidden"
                # 处理主连接的 op 和 bn
                if len(existing_keys) > 0:
                    # 计算 op 参数平均值
                    avg_op_state = {}
                    for param_name in self.multiHidden[existing_keys[0]]["op"].state_dict():
                        avg_op_state[param_name] = torch.stack([
                            self.multiHidden[k]["op"].state_dict()[param_name]
                            for k in existing_keys
                        ]).mean(dim=0)

                    # 计算 bn 参数平均值
                    # 计算所有现有被试的BN参数平均值
                    avg_weight = torch.stack([self.multiHidden[k]["op_bn"].weight.data for k in existing_keys]).mean(dim=0)
                    avg_bias = torch.stack([self.multiHidden[k]["op_bn"].bias.data for k in existing_keys]).mean(dim=0)
                    avg_running_mean = torch.stack([self.multiHidden[k]["op_bn"].running_mean for k in existing_keys]).mean(dim=0)
                    avg_running_var = torch.stack([self.multiHidden[k]["op_bn"].running_var for k in existing_keys]).mean(dim=0)
                    # 创建新的BN并用平均参数初始化
                    tmp_bn = self.multiHidden[existing_keys[0]]["op_bn"]
                    new_bn = nn.BatchNorm1d(num_features=tmp_bn.num_features, device=self.device)
                    new_bn.weight.data.copy_(avg_weight)
                    new_bn.bias.data.copy_(avg_bias)
                    new_bn.running_mean.copy_(avg_running_mean)
                    new_bn.running_var.copy_(avg_running_var)
                    # avg_bn_state = {}
                    # for param_name in self.multiHidden[existing_keys[0]]["op_bn"].state_dict():
                    #     avg_bn_state[param_name] = torch.stack([
                    #         self.multiHidden[k]["op_bn"].state_dict()[param_name]
                    #         for k in existing_keys
                    #     ]).mean(dim=0)

                    # 创建新的 op 和 bn，用平均参数初始化
                    tmp_op = self.multiHidden[existing_keys[0]]["op"]
                    tmp_bn = self.multiHidden[existing_keys[0]]["op_bn"]
                    self.connections[0].op = copy.deepcopy(tmp_op).to(self.device)
                    self.connections[0].bn = copy.deepcopy(tmp_bn).to(self.device)
                    self.connections[0].op.load_state_dict(avg_op_state)
                    self.connections[0].bn=new_bn

                    # 处理 recurrent 连接（如果存在）
                    if "recurrent" in self.multiHidden[existing_keys[0]]:
                        # 计算 recurrent 参数平均值
                        avg_recurrent_state = {}
                        for param_name in self.multiHidden[existing_keys[0]]["recurrent"].state_dict():
                            avg_recurrent_state[param_name] = torch.stack([
                                self.multiHidden[k]["recurrent"].state_dict()[param_name]
                                for k in existing_keys
                            ]).mean(dim=0)

                        # 计算所有现有被试的BN参数平均值
                        avg_weight = torch.stack([self.multiHidden[k]["recurrent_bn"].weight.data for k in existing_keys]).mean(dim=0)
                        avg_bias = torch.stack([self.multiHidden[k]["recurrent_bn"].bias.data for k in existing_keys]).mean(dim=0)
                        avg_running_mean = torch.stack([self.multiHidden[k]["recurrent_bn"].running_mean for k in existing_keys]).mean(dim=0)
                        avg_running_var = torch.stack([self.multiHidden[k]["recurrent_bn"].running_var for k in existing_keys]).mean(dim=0)
                        # 创建新的BN并用平均参数初始化
                        tmp_bn = self.multiHidden[existing_keys[0]]["recurrent_bn"]
                        new_bn = nn.BatchNorm1d(num_features=tmp_bn.num_features, device=self.device)
                        new_bn.weight.data.copy_(avg_weight)
                        new_bn.bias.data.copy_(avg_bias)
                        new_bn.running_mean.copy_(avg_running_mean)
                        new_bn.running_var.copy_(avg_running_var)
                        # avg_recurrent_bn_state = {}
                        # for param_name in self.multiHidden[existing_keys[0]]["recurrent_bn"].state_dict():
                        #     avg_recurrent_bn_state[param_name] = torch.stack([
                        #         self.multiHidden[k]["recurrent_bn"].state_dict()[param_name]
                        #         for k in existing_keys
                        #     ]).mean(dim=0)

                        tmp_recurrent = self.multiHidden[existing_keys[0]]["recurrent"]
                        tmp_recurrent_bn = self.multiHidden[existing_keys[0]]["recurrent_bn"]
                        self.connections[1].op = copy.deepcopy(tmp_recurrent).to(self.device)
                        self.connections[1].bn = copy.deepcopy(tmp_recurrent_bn).to(self.device)
                        self.connections[1].op.load_state_dict(avg_recurrent_state)
                        self.connections[1].bn=new_bn

                    # 计算 LIF 神经元组参数平均值
                    avg_lif_state = {}
                    for param_name in self.multiHidden[existing_keys[0]]["LIF"].state_dict():
                        avg_lif_state[param_name] = torch.stack([
                            self.multiHidden[k]["LIF"].state_dict()[param_name]
                            for k in existing_keys
                        ]).mean(dim=0)

                    # 创建新的 LIF 神经元组，用平均参数初始化
                    self.group2 = copy.deepcopy(self.multiHidden[existing_keys[0]]["LIF"]).to(self.device)
                    self.group2.load_state_dict(avg_lif_state)

                else:
                    # 如果没有现有被试，使用随机初始化（fallback）
                    tmp_op = self.multiHidden[list(self.multiHidden.keys())[-1]]["op"]
                    tmp_bn = self.multiHidden[list(self.multiHidden.keys())[-1]]["op_bn"]
                    self.connections[0].op = copy.deepcopy(tmp_op).to(self.device)
                    self.connections[0].bn = nn.BatchNorm1d(num_features=tmp_bn.num_features, device=self.device)

                    if "recurrent" in self.multiHidden[list(self.multiHidden.keys())[-1]]:
                        tmp_recurrent = self.multiHidden[list(self.multiHidden.keys())[-1]]["recurrent"]
                        tmp_recurrent_bn = self.multiHidden[list(self.multiHidden.keys())[-1]]["recurrent_bn"]
                        self.connections[1].op = copy.deepcopy(tmp_recurrent).to(self.device)
                        self.connections[1].bn = nn.BatchNorm1d(num_features=tmp_recurrent_bn.num_features, device=self.device)

                    # LIF 神经元组保持原有逻辑（重新初始化）
                    self.group2 = copy.deepcopy(self.multiHidden[list(self.multiHidden.keys())[-1]]["LIF"]).to(self.device)

            self.propagate_seq[3] = self.group2
            self.groups[1] = self.group2
            for c in self.connections:
                if hasattr(c, 'dst') and 'block1_linear_hidden' in c.dst.name:
                    c.dst = self.group2
                if hasattr(c, 'src') and 'block1_linear_hidden' in c.src.name:
                    c.src = self.group2
        # else:
        #     # 新建对应的隐藏层
        #     tmp_op = self.multiHidden[list(self.multiHidden.keys())[-1]]["op"]
        #     tmp_bn = self.multiHidden[list(self.multiHidden.keys())[-1]]["op_bn"]
        #     self.connections[0].op = copy.deepcopy(tmp_op).to(self.device)
        #     self.connections[0].bn = copy.deepcopy(tmp_bn).to(self.device)
        #
        #     if "recurrent" in self.multiHidden[list(self.multiHidden.keys())[-1]]:
        #         tmp_recurrent = self.multiHidden[list(self.multiHidden.keys())[-1]]["recurrent"]
        #         tmp_recurrent_bn = self.multiHidden[list(self.multiHidden.keys())[-1]]["recurrent_bn"]
        #         self.connections[1].op = copy.deepcopy(tmp_recurrent).to(self.device)
        #         self.connections[1].bn = copy.deepcopy(tmp_recurrent_bn).to(self.device)
        #
        #     # self.group2.load_state_dict(self.multiHidden[list(self.multiHidden.keys())[-1]].LIFdata)
        #     self.group2 = copy.deepcopy(self.multiHidden[list(self.multiHidden.keys())[-1]]["LIF"]).to(self.device)
        #     self.propagate_seq[3] = self.group2
        #     self.groups[1] = self.group2
        #     for c in self.connections:
        #         if hasattr(c, 'dst') and 'block1_linear_hidden' in c.dst.name:
        #             c.dst = self.group2
        #         if hasattr(c, 'src') and 'block1_linear_hidden' in c.src.name:
        #             c.src = self.group2

        return

    def multi_BN_set(self, cfg):
        """Create per-subject copies of the BatchNorm layers of every connection.

        When ``cfg.model.multi_BN`` is true, ``self.multiBN`` becomes an
        ``nn.ModuleDict`` keyed by each entry of ``cfg.pretrain_monkeys``. Each
        value is an ``nn.ModuleDict`` keyed by ``"connection_{idx}"``:

        * for a connection with a non-``None`` ``bn``, the entry is a deep copy
          of that BN moved to ``cfg.device``;
        * for a connection whose ``op`` is a ``RepClassModule``, the entry is a
          nested ``nn.ModuleDict`` with deep copies of whichever of
          ``op.rbr_1x1.bn`` (``"rbr_1x1_bn"``), ``op.rbr_dense.bn``
          (``"rbr_dense_bn"``) and ``op.rbr_identity`` (``"rbr_identity"``)
          exist.

        Every connection that receives an entry gets ``c.multiBN_id`` set to
        its ``"connection_{idx}"`` key. When ``cfg.model.multi_BN`` is false,
        ``self.multiBN`` is set to ``False``.

        Raises:
            ValueError: if a ``RepClassModule`` op exposes ``model`` as an
                ``nn.ModuleList``, which this scheme does not support.
        """
        if not cfg.model.multi_BN:
            self.multiBN = False
            return

        subjects = cfg.pretrain_monkeys
        self.multiBN = nn.ModuleDict()
        for key in subjects:
            self.multiBN[key] = nn.ModuleDict()

        for c_idx, c in enumerate(self.connections):
            c_id = f"connection_{c_idx}"

            if hasattr(c, "bn") and c.bn is not None:
                for key in subjects:
                    self.multiBN[key][c_id] = copy.deepcopy(c.bn).to(cfg.device)
                # Keep the key on the connection so its per-subject BN can be found later.
                c.multiBN_id = c_id

            if isinstance(c.op, RepClassModule):
                # The original code only constructed this exception without
                # raising it, so unsupported modules were silently accepted.
                if hasattr(c.op, "model") and isinstance(c.op.model, nn.ModuleList):
                    raise ValueError("MultiBN does not support nn.ModuleList in RepClassModule")

                # Collect the re-parameterisation branches whose BN state is per subject.
                branch_sources = []
                if hasattr(c.op, "rbr_1x1"):
                    branch_sources.append(("rbr_1x1_bn", c.op.rbr_1x1.bn))
                if hasattr(c.op, "rbr_dense"):
                    branch_sources.append(("rbr_dense_bn", c.op.rbr_dense.bn))
                if hasattr(c.op, "rbr_identity") and c.op.rbr_identity is not None:
                    branch_sources.append(("rbr_identity", c.op.rbr_identity))

                # NOTE(review): if this connection also has a non-None ``c.bn``,
                # the copy stored above under ``c_id`` is replaced here by the
                # nested dict -- confirm this is intended.
                for key in subjects:
                    per_subject = nn.ModuleDict()
                    for branch_name, module in branch_sources:
                        per_subject[branch_name] = copy.deepcopy(module).to(cfg.device)
                    self.multiBN[key][c_id] = per_subject

                c.multiBN_id = c_id
        return

    def multiHidden_set(self, cfg):
        """Snapshot the first hidden block once per pretraining subject.

        With ``cfg.model.multi_hidden`` enabled, ``self.multiHidden`` becomes an
        ``nn.ModuleDict`` mapping each name in ``cfg.pretrain_monkeys`` to an
        ``nn.ModuleDict`` of deep copies (moved to ``cfg.device``) of:

        * ``"op"`` / ``"op_bn"``: ``op`` and ``bn`` of ``self.connections[0]``,
          which must connect ``Input`` to ``block1_linear_hidden``;
        * ``"recurrent"`` / ``"recurrent_bn"``: ``op`` and ``bn`` of
          ``self.connections[1]`` (``block1_linear_hidden`` to itself), only
          when ``cfg.model.linear_hidden_recurrent[0]`` is truthy;
        * ``"LIF"``: ``self.group2``, with the subject name prefixed to its
          ``name``.

        With it disabled, ``self.multiHidden`` is set to ``False``.
        """
        if not cfg.model.multi_hidden:
            self.multiHidden = False
            return

        # Only a single linear hidden block can be swapped per subject.
        assert cfg.model.nb_linear_hidden==1, "multiHidden 只支持 nb_linear_hidden=1 的情况"
        self.multiHidden = nn.ModuleDict()
        for subject in cfg.pretrain_monkeys:
            self.multiHidden[subject] = nn.ModuleDict()
            bank = self.multiHidden[subject]

            feed = self.connections[0]
            assert hasattr(feed, 'dst') and 'block1_linear_hidden' in feed.dst.name
            assert hasattr(feed, 'src') and 'Input' in feed.src.name
            bank["op"] = copy.deepcopy(feed.op).to(cfg.device)
            bank["op_bn"] = copy.deepcopy(feed.bn).to(cfg.device)

            if cfg.model.linear_hidden_recurrent[0]:
                loop = self.connections[1]
                assert hasattr(loop, 'dst') and 'block1_linear_hidden' in loop.dst.name
                assert hasattr(loop, 'src') and 'block1_linear_hidden' in loop.src.name
                bank["recurrent"] = copy.deepcopy(loop.op).to(cfg.device)
                bank["recurrent_bn"] = copy.deepcopy(loop.bn).to(cfg.device)

            neurons = copy.deepcopy(self.group2).to(cfg.device)
            bank["LIF"] = neurons
            # Prefix the subject so each copied neuron group has a distinct name.
            neurons.name = subject + neurons.name
        return

    def output_BN_del(self):
        """Remove the BatchNorm from the last five connections.

        These trailing connections are treated as the output layers. Each of
        them that has a non-``None`` ``bn`` attribute gets ``bn`` set to
        ``None``. Connections without a ``bn`` attribute are left untouched.
        """
        tail = self.connections[-5:]
        carrying_bn = [conn for conn in tail if getattr(conn, "bn", None) is not None]
        for conn in carrying_bn:
            conn.bn = None

    def prepare_for_deepcopy(self):
        """Get the model ready to be deep-copied and return it.

        First calls ``prepare_for_deepcopy()`` on every entry of ``self.groups``
        that defines it. Then, if ``self.multiHidden`` exists and is truthy,
        does the same for each per-subject ``"LIF"`` entry. Finally, drops the
        cached ``out``, ``out_loss`` and ``reg_loss`` references by setting
        them to ``None``.

        Returns:
            ``self``, so the call can be chained.
        """
        def _prepare(module):
            # Only some group types implement the hook.
            if hasattr(module, 'prepare_for_deepcopy'):
                module.prepare_for_deepcopy()

        for group in self.groups:
            _prepare(group)

        subject_banks = getattr(self, "multiHidden", None)
        if subject_banks:
            for subject in subject_banks:
                _prepare(subject_banks[subject]["LIF"])

        self.out = self.out_loss = self.reg_loss = None
        return self

def _merge_replica_sequences(model, replicas, n_devices, output_device):
    """Concatenate the replicas' recorded state sequences back into ``model.groups``.

    For each group and each key in its ``store_state_sequences``, the first
    ``n_devices`` replicas' ``stored_sequences_[key]`` tensors are moved to
    ``output_device`` and concatenated along dim 0. The replicas' copies are
    then reset to empty lists. The group's ``flat_seq_shape`` is rebuilt with
    dim 0 summed over replicas and dims 1-2 taken from the first replica.
    """
    active = [replicas[i] for i in range(n_devices)]
    reference = active[0]

    for group_indx, group in enumerate(model.groups):
        ref_group = reference.groups[group_indx]

        # Every replica (not just the first two) must record the same state
        # variables as replica 0; with a single device this check is trivially true.
        assert all(
            r.groups[group_indx].store_state_sequences == ref_group.store_state_sequences
            for r in active
        ), "The store_state_sequences of the two modules is not equal"
        assert group.store_state_sequences == ref_group.store_state_sequences, \
            "The store_state_sequences has been changed"

        for key in group.store_state_sequences:
            assert not group.stored_sequences_[key], \
                "The stored_sequences_ is not empty"
            assert all(
                isinstance(r.groups[group_indx].stored_sequences_[key], torch.Tensor)
                for r in active
            ), "The type of stored_sequences_ is not torch.Tensor"
            # Dims 1 and 2 of replica 0 are used for the merged shape below,
            # so both must agree across replicas (the old check compared dim 1 only).
            assert all(
                r.groups[group_indx].flat_seq_shape[1:3] == ref_group.flat_seq_shape[1:3]
                for r in active
            ), "The flat_seq_shape_each_thread is not equal"

            group.stored_sequences_[key] = torch.cat(
                [r.groups[group_indx].stored_sequences_[key].to(output_device) for r in active],
                dim=0,
            )
            for r in active:
                r.groups[group_indx].stored_sequences_[key] = []

        group.flat_seq_shape = (
            sum(r.groups[group_indx].flat_seq_shape[0] for r in active),
            ref_group.flat_seq_shape[1],
            ref_group.flat_seq_shape[2],
        )


def forward_pass_multiGPU(model, x_batch_, target_y, cur_batch_size, record=False, stored_sequences_Flag=False):
    """Run one forward pass of ``model`` split across the devices in ``model.multiGPU``.

    The batch is split over ``model.multiGPU['device_ids']`` with
    ``data_scatter``. Each shard runs on its own replica (from ``get_replicas``)
    via ``parallel_apply``, and the outputs are gathered onto
    ``model.multiGPU['output_device']`` along ``model.multiGPU['dim']``.

    Side effects:
      * ``model.reset_states()`` is called before the pass, and every replica's
        ``reset_states()`` after it.
      * ``model.multiGPU['reg_loss']`` is set to the mean of the replicas'
        ``reg_loss`` values, unsqueezed to shape ``(1,)``, on the output device.
      * If ``stored_sequences_Flag`` is true, each group's ``stored_sequences_``
        and ``flat_seq_shape`` are filled from the replicas
        (see ``_merge_replica_sequences``).

    Args:
        model: Model with a ``multiGPU`` dict providing ``src_device_obj``,
            ``device_ids``, ``output_device`` and ``dim``.
        x_batch_: Input batch to split across devices.
        target_y: Targets, passed to ``data_scatter`` with the inputs. The
            scattered targets are not used here.
        cur_batch_size: Unused; per-device sizes come from ``data_scatter``.
            Kept for interface compatibility.
        record: Passed (one copy per device) to ``parallel_apply``.
        stored_sequences_Flag: Whether to merge the replicas' recorded state
            sequences back into ``model``.

    Returns:
        The gathered outputs.

    Raises:
        RuntimeError: if any parameter or buffer of ``model`` is not on
            ``model.multiGPU['src_device_obj']``.
    """
    with torch.autograd.profiler.record_function("DataParallel.forward_pass_RSNN"):
        src_device = model.multiGPU['src_device_obj']
        for t in chain(model.parameters(), model.buffers()):
            if t.device != src_device:
                raise RuntimeError(
                    "module must have its parameters and buffers "
                    f"on device {src_device} (device_ids[0]) but found one of "
                    f"them on device: {t.device}"
                )

        model.reset_states()

        device_ids = model.multiGPU['device_ids']
        output_device = model.multiGPU['output_device']
        n_devices = len(device_ids)

        # Split the batch across devices; the scattered targets are not needed here.
        x_batches, _, batches_size = data_scatter(
            x_batch_,
            target_y,
            device_ids=device_ids,
        )
        replicas = get_replicas(
            model,
            batchsize=batches_size,
            detach=not torch.is_grad_enabled(),
        )
        outputs = parallel_apply(
            replicas,
            x_batches,
            batches_size,
            [record] * n_devices,
            device_ids,
        )
        outputs = gather(outputs, output_device, model.multiGPU['dim'])

        model.multiGPU['reg_loss'] = torch.mean(
            torch.stack(
                [replica.reg_loss.to(output_device) for replica in replicas]
            )
        ).unsqueeze(0)

        if stored_sequences_Flag:
            _merge_replica_sequences(model, replicas, n_devices, output_device)

        for replica in replicas:
            replica.reset_states()

    return outputs


def poisson_nll_loss(x, r, eps=1e-8):
    """Compute the Poisson negative log-likelihood of observed spike counts under predicted rates.

    Computes ``sum(r - x * log(r)) / x.numel()``, i.e. the Poisson NLL with
    the constant ``log(x!)`` term omitted, averaged over every element.
    Before the log, rates are clamped to at least ``eps`` so that zero rates
    do not produce ``-inf``.

    Args:
        x: Observed spike counts (ground truth). Intended to have the same
            shape as ``r``, e.g. ``(batch_size, seq_len, num_neurons)`` or
            ``(seq_len, num_neurons)``. Moved to ``r``'s device.
        r: Reconstructed firing rates (prediction). Must be non-negative.
        eps: Lower clamp applied to ``r`` before the log. Defaults to
            ``1e-8``, the previously hard-coded value.

    Returns:
        Scalar tensor with the averaged loss.

    Raises:
        ValueError: if ``r`` contains a negative value or a NaN.
    """
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    # ``not (min >= 0)`` rather than ``min < 0`` also rejects a NaN minimum,
    # matching the original assertion.
    if not bool(r.min() >= 0):
        raise ValueError("发放率 r 必须大于等于0")

    r_safe = torch.clamp(r, min=eps)
    x = x.to(r_safe.device)

    # Core term r - x*log(r); the log(x!) constant does not depend on r.
    pointwise_nll = r_safe - x * torch.log(r_safe)

    # Average over all elements (batch * time steps * neurons).
    total_elements = torch.numel(x)
    return torch.sum(pointwise_nll) / total_elements

