"""
Neural network layer implementations for HelioX training.
This module provides modular layer classes for building neural networks.
"""

import numpy as np
from neuron import h
from typing import Any, Dict, List, Optional


class DenseLayer:
    """
    Dense (fully connected) layer for the neural network.
    Can be used for hidden layers with PassiveHPC neurons or output layers with point neurons.
    """
    
    def __init__(self, n_neurons, neuron_type='point', layer_name='dense', morph_file=None, cm=1.0, layer_role='hidden', learning_rate=0.005):
        """
        Store the layer configuration and create empty runtime containers.

        Only plain attributes are assigned here; neurons, NetCons and IDs are
        filled in later by ``phase1_create_neurons``.

        Args:
            n_neurons: Number of neurons in this layer.
            neuron_type: Type of neuron ('point' or 'passive_hpc').
            layer_name: Name of the layer (e.g., 'hidden', 'output').
            morph_file: Morphology file for PassiveHPC neurons
                (required if neuron_type='passive_hpc').
            cm: Membrane capacitance (used for PassiveHPC neurons).
            layer_role: Role of the layer ('hidden' or 'output'); a 'hidden'
                role sets ``needs_aggregator`` to True.
            learning_rate: Learning rate for this layer (default 0.005).
        """
        # --- Static configuration, stored exactly as given ---------------
        self.n_neurons = n_neurons
        self.neuron_type = neuron_type
        self.morph_file = morph_file
        self.cm = cm
        self.layer_name = layer_name
        self.layer_role = layer_role  # 'hidden' or 'output'
        self.learning_rate = learning_rate
        self.trainable = True  # dense layers carry trainable parameters
        self.backend = None  # attached later through set_backend()

        # --- Populated while the network is being built ------------------
        self.neurons = []
        self.nclist = []  # one NetCon per neuron, kept for voltage recording
        self.neuron_gids = None
        self.voltage_sids = None

        # --- Gradient aggregation (used by hidden layers) ----------------
        self.needs_aggregator = layer_role == 'hidden'
        self.gradient_aggregators = []
        self.gradient_aggregator_wrappers = []

        # --- Per-synapse bookkeeping, filled when layers are connected ---
        self.weight_impedance_factors = []  # impedance correction per weight synapse
        self.bias_impedance_factors = []  # impedance correction per bias synapse (original note: usually 1.0)
        self.weight_synapse_wrappers = []
        self.bias_synapse_wrappers = []

    def set_backend(self, backend):
        """
        Remember the hybrid backend controller for this layer.

        Args:
            backend: Controller object. Other methods of this class read its
                ``enable_heliox`` flag and call ``register_optimizer_param``
                / ``wrap_obj`` on it.
        """
        self.backend = backend

    def _register_heliox_optimizer_params(self):
        """
        Hand every (wrapper, impedance factor) pair to the backend optimizer.

        Does nothing when no backend is attached or the backend does not
        report ``enable_heliox``. Weight wrappers are walked row by row,
        bias wrappers as a flat list; any pair with a ``None`` member is
        skipped.
        """
        backend = getattr(self, "backend", None)
        if backend is None or not getattr(backend, "enable_heliox", False):
            return

        def register_pairs(wrapper_seq, factor_seq):
            # zip() stops at the shorter sequence, so unmatched tail entries
            # are ignored.
            for wrapper, factor in zip(wrapper_seq, factor_seq):
                if wrapper is not None and factor is not None:
                    backend.register_optimizer_param(wrapper, factor)

        wrapper_rows = getattr(self, "weight_synapse_wrappers", []) or []
        factor_rows = getattr(self, "weight_impedance_factors", []) or []
        for wrapper_row, factor_row in zip(wrapper_rows, factor_rows):
            # Empty / missing rows carry nothing to register.
            if wrapper_row and factor_row:
                register_pairs(wrapper_row, factor_row)

        register_pairs(
            getattr(self, "bias_synapse_wrappers", []) or [],
            getattr(self, "bias_impedance_factors", []) or [],
        )

    def phase1_create_neurons(self, pc, id_manager, ihost):
        """
        Phase 1: Create neurons, register GIDs/SIDs, and setup source_var.

        For every neuron the method assigns its GID to ``ihost``, builds the
        cell, creates a NetCon on the soma voltage, publishes that voltage
        under the neuron's SID via ``pc.source_var`` and registers the cell
        with ``pc.cell`` / ``pc.outputcell``.

        Args:
            pc: ParallelContext
            id_manager: NetworkIDManager for GID management
            ihost: Host ID for parallel execution

        Raises:
            ValueError: If ``self.neuron_type`` is neither 'passive_hpc' nor
                'point'. Raised before any ID is registered, so the layer
                and the ID manager are left untouched.
        """
        # Reject unknown types up front. Previously an unknown type left
        # `neuron`/`nc` unbound and failed inside the loop, after GIDs had
        # already been registered and assigned to a host.
        if self.neuron_type not in ('passive_hpc', 'point'):
            raise ValueError(
                f"Unsupported neuron_type '{self.neuron_type}' for layer "
                f"'{self.layer_name}'; expected 'passive_hpc' or 'point'"
            )
        is_hpc = self.neuron_type == 'passive_hpc'

        # Register GIDs and SIDs
        self.neuron_gids = id_manager.register_gids_batch(self.layer_name, self.n_neurons)
        self.voltage_sids = id_manager.register_sids_batch(f'{self.layer_name}_voltage', self.n_neurons)

        # Load the HOC templates needed by the selected neuron type
        if is_hpc:
            h.load_file("import3d.hoc")
            h.load_file("PassiveHPC.hoc")
        else:
            h.load_file("PointNeuron.hoc")

        # Create neurons and setup source_var
        for gid, sid in zip(self.neuron_gids, self.voltage_sids):
            pc.set_gid2node(gid, ihost)

            if is_hpc:
                neuron = self._setup_hpc("PassiveHPC", self.morph_file)
                for sec in neuron.all:
                    sec.cm = self.cm
                # For PassiveHPC, soma is indexable; the first entry is used
                soma = neuron.soma[0]
            else:
                neuron = h.PointNeuron()
                # For point neurons, soma is a single section
                soma = neuron.soma

            # NetCon without a target: it only serves as the spike source
            # handed to pc.cell() below.
            nc = h.NetCon(soma(0.5)._ref_v, None, sec=soma)
            pc.source_var(soma(0.5)._ref_v, sid, sec=soma)

            self.neurons.append(neuron)
            self.nclist.append(nc)
            pc.cell(gid, nc)
            pc.outputcell(gid)
    
    def set_training_mode(self, is_training):
        """
        Open or close the learning window on every synapse of this layer.

        In training mode ``lr_start`` is set to 20, otherwise to 1e9;
        ``lr_end`` is always 1e9. The values are written to weight synapses,
        bias synapses and gradient aggregators, and mirrored onto the
        matching wrapper objects when the backend reports ``enable_heliox``.

        Args:
            is_training: True for training mode, False for test mode
        """
        # Nothing to configure until the synapses have been created.
        if not hasattr(self, 'weight_synapses'):
            return

        window_open = 20 if is_training else 1e9
        window_close = 1e9

        def heliox_enabled():
            return self.backend and self.backend.enable_heliox

        def apply_window(native_objs, wrapper_objs):
            """Write the window to each native object and its wrapper, if any."""
            for pos, native in enumerate(native_objs):
                if native:
                    native.lr_start = window_open
                    native.lr_end = window_close
                if not heliox_enabled():
                    continue
                if pos >= len(wrapper_objs):
                    continue
                mirror = wrapper_objs[pos]
                if mirror is not None:
                    mirror.lr_start = window_open
                    mirror.lr_end = window_close

        # Weight synapses are stored as one list per neuron.
        wrapper_rows = getattr(self, 'weight_synapse_wrappers', [])
        for row, native_row in enumerate(self.weight_synapses):
            wrapper_row = wrapper_rows[row] if row < len(wrapper_rows) else []
            apply_window(native_row, wrapper_row)

        # Bias synapses and gradient aggregators are flat lists and may be absent.
        flat_groups = (
            ('bias_synapses', 'bias_synapse_wrappers'),
            ('gradient_aggregators', 'gradient_aggregator_wrappers'),
        )
        for native_attr, wrapper_attr in flat_groups:
            if hasattr(self, native_attr):
                wrappers = getattr(self, wrapper_attr, [])
                apply_window(getattr(self, native_attr), wrappers)
    
    def load_weights(self, weight_matrix, bias_vector):
        """
        加载权重到突触。
        
        Args:
            weight_matrix: 权重矩阵，形状根据层的角色和神经元类型而定：
                - 隐藏层 with passive_hpc: (N_hidden, N_in, num_proj_dend)
                - 其他层 (point neurons): (N_current, N_prev)
            bias_vector: 偏置向量 (N_neurons,)
        """
        # Check if synapses have been created
        if not hasattr(self, 'weight_synapses'):
            print(f"Warning: {self.layer_name} has no weight_synapses to load weights into")
            return
        
        # Check if we need 3D weights (hidden layer with morphological neurons)
        if self.layer_role == 'hidden' and self.neuron_type == 'passive_hpc':
            # 隐藏层with PassiveHPC：权重矩阵是3维的
            if len(weight_matrix.shape) == 3:
                N_hidden, N_in, num_proj_dend = weight_matrix.shape
            else:
                # 如果传入的是2D，需要reshape成3D
                N_hidden, N_in = weight_matrix.shape
                num_proj_dend = 1
                weight_matrix = weight_matrix.reshape(N_hidden, N_in, num_proj_dend)
            
            # Load weights for each hidden neuron
            for hidden_idx in range(min(len(self.weight_synapses), N_hidden)):
                hidden_syns = self.weight_synapses[hidden_idx]
                wrapper_syns = []
                if (self.backend and self.backend.enable_heliox and
                        hidden_idx < len(getattr(self, 'weight_synapse_wrappers', []))):
                    wrapper_syns = self.weight_synapse_wrappers[hidden_idx]
                for j in range(N_in):
                    for k in range(num_proj_dend):
                        syn_idx = j * num_proj_dend + k
                        if syn_idx < len(hidden_syns) and hidden_syns[syn_idx]:
                            hidden_syns[syn_idx].w = weight_matrix[hidden_idx, j, k]
                        if (self.backend and self.backend.enable_heliox and
                                syn_idx < len(wrapper_syns) and wrapper_syns[syn_idx] is not None):
                            wrapper_syns[syn_idx].w = float(weight_matrix[hidden_idx, j, k])
        else:
            # Point neurons (hidden or output)：权重矩阵是2维的
            for neuron_idx in range(min(len(self.weight_synapses), weight_matrix.shape[0])):
                neuron_syns = self.weight_synapses[neuron_idx]
                wrapper_syns = []
                if (self.backend and self.backend.enable_heliox and
                        neuron_idx < len(getattr(self, 'weight_synapse_wrappers', []))):
                    wrapper_syns = self.weight_synapse_wrappers[neuron_idx]
                for prev_idx in range(min(len(neuron_syns), weight_matrix.shape[1])):
                    syn = neuron_syns[prev_idx]
                    if syn:
                        syn.w = weight_matrix[neuron_idx, prev_idx]
                    if (self.backend and self.backend.enable_heliox and
                            prev_idx < len(wrapper_syns) and wrapper_syns[prev_idx] is not None):
                        wrapper_syns[prev_idx].w = float(weight_matrix[neuron_idx, prev_idx])
        
        # 加载偏置
        if hasattr(self, 'bias_synapses'):
            wrapper_bias = getattr(self, 'bias_synapse_wrappers', [])
            for i in range(min(len(self.bias_synapses), len(bias_vector))):
                syn = self.bias_synapses[i]
                if syn:
                    syn.w = bias_vector[i]
                if (self.backend and self.backend.enable_heliox and
                        i < len(wrapper_bias) and wrapper_bias[i] is not None):
                    wrapper_bias[i].w = float(bias_vector[i])
    
    def extract_gradients(self, dt, record_time=30.0, backend: str = "auto"):
        """
        Read accumulated gradients from the HelioX wrappers.

        Each gradient is ``impedance_factor * wrapper.acc_grad / record_steps``
        with ``record_steps = record_time / dt``. The learning rate is not
        applied here.

        Args:
            dt: Time step.
            record_time: Recording time (default 30.0).
            backend: 'heliox' or 'auto' (default).

        Returns:
            Tuple ``(dw, db)``:
                - hidden layer with passive_hpc neurons: ``dw`` has shape
                  (n_neurons, input_dim, num_proj_dend), ``db`` (n_neurons,)
                - otherwise: ``dw`` has shape
                  (len(weight_synapses), len(weight_synapses[0])) and ``db``
                  (len(weight_synapses),)
            ``(None, None)`` when, in the second case, ``weight_synapses`` is
            empty or its first row is empty.

        Raises:
            ValueError: For an unsupported ``backend`` value, or when no
                HelioX-enabled backend is attached.
        """
        n_steps = record_time / dt
        mode = (backend or "auto").lower()
        heliox_on = False
        if getattr(self, "backend", None) is not None:
            heliox_on = self.backend.enable_heliox

        if mode not in ("auto", "heliox"):
            raise ValueError(f"Unsupported backend mode '{backend}'")
        if not heliox_on:
            raise ValueError("HelioX backend not available for gradient extraction.")

        def scaled(factor, wrapper):
            # Time-averaged gradient with impedance correction.
            return factor * wrapper.acc_grad / n_steps

        def bias_gradients(count):
            """Bias gradient vector of the given length; missing entries stay 0."""
            grads = np.zeros(count)
            for pos, wrapper in enumerate(getattr(self, "bias_synapse_wrappers", [])):
                if wrapper is None or pos >= len(self.bias_impedance_factors):
                    continue
                grads[pos] = scaled(self.bias_impedance_factors[pos], wrapper)
            return grads

        if self.layer_role == 'hidden' and self.neuron_type == 'passive_hpc':
            # Morphological hidden layer: dimensions come from attributes
            # saved on the layer.
            n_hidden = self.n_neurons
            n_inputs = self.input_dim
            n_dend = self.num_proj_dend

            dw = np.zeros((n_hidden, n_inputs, n_dend))
            for row, wrapper_row in enumerate(getattr(self, "weight_synapse_wrappers", [])):
                if row >= len(self.weight_impedance_factors):
                    continue
                factor_row = self.weight_impedance_factors[row]
                for flat, wrapper in enumerate(wrapper_row):
                    if wrapper is None or flat >= len(factor_row):
                        continue
                    factor = factor_row[flat]
                    # Flat synapse index -> (input, projection).
                    src, dend = divmod(flat, n_dend)
                    dw[row, src, dend] = scaled(factor, wrapper)
            return dw, bias_gradients(n_hidden)

        # All other layers: 2D gradients sized from the native synapse lists.
        n_rows = len(self.weight_synapses)
        if not (self.weight_synapses and self.weight_synapses[0]):
            return None, None
        n_cols = len(self.weight_synapses[0])

        dw = np.zeros((n_rows, n_cols))
        for row, wrapper_row in enumerate(getattr(self, "weight_synapse_wrappers", [])):
            if row >= len(self.weight_impedance_factors):
                continue
            factor_row = self.weight_impedance_factors[row]
            for col, wrapper in enumerate(wrapper_row):
                if wrapper is None or col >= len(factor_row):
                    continue
                dw[row, col] = scaled(factor_row[col], wrapper)
        return dw, bias_gradients(n_rows)
    
    def compute_backend_differences(self) -> Dict[str, Dict[str, float]]:
        """
        Compare native synapse state with the HelioX wrapper copies.

        For weight synapses, bias synapses and gradient aggregators, selected
        attributes are read from the native object and from its wrapper and
        the absolute differences are summarised. A value enters a comparison
        only when BOTH members of a pair expose the attribute, so the two
        value lists always stay aligned pair by pair.

        Returns:
            Dict mapping metric names ("weight.<attr>", "bias.<attr>",
            "aggregator.aggregated_grad", "aggregator.grad_from_output") to
            ``{"max_abs": ..., "mean_abs": ...}``. Metrics without any
            comparable pair are omitted. Returns an empty dict when no
            backend is attached or HelioX is disabled.
        """
        if self.backend is None or not self.backend.enable_heliox:
            return {}

        differences: Dict[str, Dict[str, float]] = {}

        def record_metric(name: str, neuron_values: List[float], heliox_values: List[float]):
            """Store max/mean absolute error; skip empty or unequal-length input."""
            if not neuron_values or not heliox_values:
                return
            if len(neuron_values) != len(heliox_values):
                return
            delta = np.abs(
                np.asarray(neuron_values, dtype=float)
                - np.asarray(heliox_values, dtype=float)
            )
            differences[name] = {
                "max_abs": float(delta.max()),
                "mean_abs": float(delta.mean())
            }

        def paired_scalars(pairs, attr):
            """Read ``attr`` from both sides of each pair that has it on both."""
            neuron_vals: List[float] = []
            heliox_vals: List[float] = []
            for native, wrapper in pairs:
                if hasattr(native, attr) and hasattr(wrapper, attr):
                    neuron_vals.append(float(getattr(native, attr)))
                    heliox_vals.append(float(getattr(wrapper, attr)))
            return neuron_vals, heliox_vals

        def extract_array(obj: Any, attr: str, length: int) -> Optional[List[float]]:
            """Return the first ``length`` values of an array-like attribute, or None."""
            if obj is None or length <= 0:
                return None
            target = getattr(obj, attr, None)
            if target is None:
                return None
            # Prefer bulk conversion through to_list / to_numpy when offered
            # and long enough.
            if hasattr(target, "to_list"):
                data = target.to_list()
                if len(data) >= length:
                    return [float(data[i]) for i in range(length)]
            if hasattr(target, "to_numpy"):
                data = target.to_numpy()
                if len(data) >= length:
                    return [float(data[i]) for i in range(length)]
            values: List[float] = []
            for idx in range(length):
                # Best-effort probing: the attribute may be indexable, may be
                # exposed as separate "<attr>_<idx>" fields, or may be a
                # scalar. The exception types raised by the underlying
                # objects are not known here, hence the broad handlers.
                try:
                    values.append(float(target[idx]))
                    continue
                except Exception:
                    pass
                alt_name = f"{attr}_{idx}"
                if hasattr(obj, alt_name):
                    try:
                        values.append(float(getattr(obj, alt_name)))
                        continue
                    except Exception:
                        return None
                try:
                    scalar = float(getattr(obj, attr))
                except Exception:
                    return None
                # A scalar can only stand in for a length-1 array.
                if length == 1:
                    values.append(scalar)
                    return values
                return None
            return values

        # --- Weight synapses --------------------------------------------
        weight_pairs = []
        weight_wrappers = getattr(self, "weight_synapse_wrappers", [])
        for row_idx, syns in enumerate(getattr(self, "weight_synapses", [])):
            wrappers = weight_wrappers[row_idx] if row_idx < len(weight_wrappers) else []
            for col_idx, syn in enumerate(syns):
                if syn is None or col_idx >= len(wrappers):
                    continue
                wrapper = wrappers[col_idx]
                if wrapper is not None:
                    weight_pairs.append((syn, wrapper))

        for attr in ("v_gap", "grad_from_next", "grad_to_prev", "acc_grad", "w"):
            record_metric(f"weight.{attr}", *paired_scalars(weight_pairs, attr))

        # --- Bias synapses ----------------------------------------------
        bias_wrappers = getattr(self, "bias_synapse_wrappers", [])
        bias_pairs = []
        for idx, syn in enumerate(getattr(self, "bias_synapses", [])):
            if syn is None or idx >= len(bias_wrappers):
                continue
            wrapper = bias_wrappers[idx]
            if wrapper is not None:
                bias_pairs.append((syn, wrapper))

        for attr in ("grad_from_next", "grad_to_prev", "acc_grad", "w"):
            record_metric(f"bias.{attr}", *paired_scalars(bias_pairs, attr))

        # --- Gradient aggregators (hidden layers) -----------------------
        aggregator_pairs = []
        aggregator_wrappers = getattr(self, "gradient_aggregator_wrappers", [])
        for idx, aggregator in enumerate(getattr(self, "gradient_aggregators", [])):
            if aggregator is None or idx >= len(aggregator_wrappers):
                continue
            wrapper = aggregator_wrappers[idx]
            if wrapper is not None:
                aggregator_pairs.append((aggregator, wrapper))

        if aggregator_pairs:
            # Both sides must expose the attribute for a pair to count;
            # collecting the two sides independently would compare values
            # belonging to different aggregators.
            record_metric(
                "aggregator.aggregated_grad",
                *paired_scalars(aggregator_pairs, "aggregated_grad")
            )

            grad_neuron_all: List[float] = []
            grad_heliox_all: List[float] = []
            for aggregator, wrapper in aggregator_pairs:
                n_outputs = int(getattr(aggregator, "n_outputs", 0))
                neuron_vals = extract_array(aggregator, "grad_from_output", n_outputs)
                heliox_vals = extract_array(wrapper, "grad_from_output", n_outputs) if wrapper else None
                if neuron_vals and heliox_vals:
                    grad_neuron_all.extend(neuron_vals)
                    grad_heliox_all.extend(heliox_vals)
            record_metric("aggregator.grad_from_output", grad_neuron_all, grad_heliox_all)

        return differences
    
    def _setup_hpc(self, model, morph):
        """
        Instantiate a HOC cell template and import a morphology into it.

        Args:
            model: Name of the template looked up on ``h`` (callers in this
                class pass "PassiveHPC").
            morph: Morphology file handed to the Neurolucida3 Import3d reader.

        Returns:
            The cell object after import and after the template's setup
            procedures have been run.
        """
        cell = getattr(h, model)()

        # Read the morphology file.
        reader = h.Import3d_Neurolucida3()
        reader.quiet = 1
        reader.input(morph)

        # Build the sections inside the cell from the parsed morphology.
        importer = h.Import3d_GUI(reader, 0)
        importer.instantiate(cell)
        cell.indexSections(importer)

        # Template procedures, executed strictly in this order.
        setup_steps = (
            "geom_nsec",
            "geom_nseg",
            "delete_axon",
            "insertChannel",
            "init_rc",
            "biophys",
        )
        for step in setup_steps:
            getattr(cell, step)()
        return cell
    
    def get_neurons(self):
        """
        Give access to the neuron objects created for this layer.

        Returns:
            The ``self.neurons`` list itself (not a copy).
        """
        return self.neurons
    
    def get_neuron_count(self):
        """
        Report the configured size of the layer.

        Returns:
            The ``n_neurons`` value given at construction.
        """
        return self.n_neurons
    
    def get_gids(self):
        """
        Give access to the GIDs registered for this layer's neurons.

        Returns:
            ``self.neuron_gids``; this is None until
            ``phase1_create_neurons`` has run.
        """
        return self.neuron_gids
    
    def get_voltage_sids(self):
        """
        Give access to the SIDs under which soma voltages are published.

        Returns:
            ``self.voltage_sids``; this is None until
            ``phase1_create_neurons`` has run.
        """
        return self.voltage_sids
    
    def get_gradient_source_count(self):
        """
        Tell a preceding layer how many gradient sources this layer exposes.

        A DenseLayer contributes one gradient per neuron. ``connect_layers``
        queries this on the next layer to decide whether an aggregator is
        required.

        Returns:
            int: ``self.n_neurons``.
        """
        return self.n_neurons
    
    def provides_loss_gradient(self):
        """
        Tell whether this layer is itself the source of the loss gradient.

        A DenseLayer never is; ``connect_layers`` checks this flag on the
        next layer and skips the aggregator when it is True.

        Returns:
            bool: Always False.
        """
        return False
    
    def calculate_impedance_matrix(self, conn_matrix, conn_loc_matrix):
        """
        Compute the transfer impedance for every synaptic connection.

        A temporary PassiveHPC cell is built from ``self.morph_file`` with
        ``self.cm`` applied to all sections; the impedance tool is anchored
        at the middle of the first soma section and each connection site is
        queried against it. Per the original notes, the result is used to
        tune synaptic conductance and learning rate.

        Args:
            conn_matrix: Dendrite index per connection,
                shape (N_in, N_hidden, num_proj_dend).
            conn_loc_matrix: Location along the dendrite per connection,
                same shape as ``conn_matrix``.

        Returns:
            Tuple ``(impedance_matrix, bias_impedance)``: an array shaped
            like ``conn_matrix`` holding one transfer value per connection,
            and the transfer value at the soma location itself.
        """
        # Throw-away cell used only for the impedance query.
        probe_cell = self._setup_hpc("PassiveHPC", self.morph_file)
        for section in probe_cell.all:
            section.cm = self.cm

        soma = probe_cell.soma[0]
        impedance = h.Impedance()
        impedance.loc(0.5, sec=soma)
        impedance.compute(0)

        # The bias is evaluated at the soma midpoint.
        bias_impedance = impedance.transfer(0.5, sec=soma)

        n_in, n_hidden, n_dend = conn_matrix.shape
        impedance_matrix = np.ones((n_in, n_hidden, n_dend))

        # Visit connections neuron-major, then input, then projection.
        for hid, src, dend in np.ndindex(n_hidden, n_in, n_dend):
            dend_id = conn_matrix[src, hid, dend]
            loc = conn_loc_matrix[src, hid, dend]
            impedance_matrix[src, hid, dend] = impedance.transfer(loc, sec=probe_cell.dend[dend_id])

        return impedance_matrix, bias_impedance
    
    def connect_layers(self, prev_layer, pc, id_manager,
                      next_layer_info=None, next_layer=None,
                      connection_pattern=None, connection_locs=None):
        """
        Generic layer-to-layer connection entry point.

        Decides whether a gradient aggregator is needed and delegates to
        ``_connect_with_aggregator`` or ``_connect_without_aggregator``.

        Decision rule:
            * ``next_layer`` given and exposing ``get_gradient_source_count``:
              an aggregator is used when that count exceeds 1, unless the
              next layer reports ``provides_loss_gradient()``.
            * otherwise, ``next_layer_info`` given (legacy path): an
              aggregator is used when it contains both 'grad_sids' and
              'aggregated_grad_sids'.
            * otherwise: no aggregator.

        Args:
            prev_layer: Previous layer object.
            pc: ParallelContext.
            id_manager: NetworkIDManager.
            next_layer_info: Dict describing the next layer:
                'n_neurons' (neuron count of the next layer),
                'grad_sids' (gradient SIDs from the next layer, 2D list
                indexed [next_neuron][current_neuron]) and, optionally,
                'aggregated_grad_sids' (1D list of aggregated gradient SIDs).
            next_layer: Next layer object (optional, queried for its
                gradient source count).
            connection_pattern: Connection pattern matrix, passed through
                to the selected connector.
            connection_locs: Connection location matrix, passed through to
                the selected connector.

        Returns:
            Whatever the selected connector returns; callers in this class
            unpack it as (weight_synapses, bias_synapses, aggregators).
        """
        if next_layer is not None and hasattr(next_layer, 'get_gradient_source_count'):
            # Preferred path: ask the next layer directly.
            multiple_sources = next_layer.get_gradient_source_count() > 1
            # Loss layers hand out one gradient per output themselves, so
            # nothing has to be aggregated for them.
            direct_loss = (hasattr(next_layer, 'provides_loss_gradient')
                           and next_layer.provides_loss_gradient())
            use_aggregator = multiple_sources and not direct_loss
        elif next_layer_info is not None:
            # Backward-compatible path based on the info dict's keys.
            use_aggregator = ('grad_sids' in next_layer_info
                              and 'aggregated_grad_sids' in next_layer_info)
        else:
            use_aggregator = False

        if use_aggregator:
            connector = self._connect_with_aggregator
        else:
            connector = self._connect_without_aggregator
        return connector(prev_layer, pc, id_manager,
                         next_layer_info, connection_pattern, connection_locs)
    
    def connect_from_input(self, input_layer, pc, id_manager, 
                          in2hd_conn, in2hd_conn_loc,
                          output2hidden_grad_sids, hidden_aggregated_grad_sids):
        """
        Create connections from the input layer to this hidden layer.

        Thin wrapper around ``connect_layers``: it packs the gradient SIDs
        into a next-layer description containing both 'grad_sids' and
        'aggregated_grad_sids', which selects the aggregator-based
        connection path.

        Args:
            input_layer: Layer feeding this one.
            pc: ParallelContext.
            id_manager: NetworkIDManager.
            in2hd_conn: Connection pattern matrix for input -> hidden.
            in2hd_conn_loc: Connection location matrix for input -> hidden.
            output2hidden_grad_sids: Gradient SIDs indexed
                [output_neuron][hidden_neuron]; its length is taken as the
                output layer size.
            hidden_aggregated_grad_sids: Aggregated gradient SIDs indexed
                [hidden_neuron].

        Returns:
            synlist_in2hd: List of lists containing weight synapses
            synlist_in2hd_bias: List containing bias synapses
            synlist_aggregators: List containing gradient aggregators
        """
        # The output layer size is derived from the SID table rather than
        # hard-coded.
        downstream = {
            'n_neurons': len(output2hidden_grad_sids),
            'grad_sids': output2hidden_grad_sids,
            'aggregated_grad_sids': hidden_aggregated_grad_sids,
        }

        connected = self.connect_layers(
            prev_layer=input_layer,
            pc=pc,
            id_manager=id_manager,
            next_layer_info=downstream,
            connection_pattern=in2hd_conn,
            connection_locs=in2hd_conn_loc,
        )
        synlist_in2hd, synlist_in2hd_bias, synlist_aggregators = connected
        return synlist_in2hd, synlist_in2hd_bias, synlist_aggregators
    
    def create_backward_synapses(self, prev_layer, pc, id_manager, 
                                 softmax_grad_sids=None, output2hidden_grad_sids=None):
        """
        Create synapses from the previous layer to this layer.

        Thin wrapper around ``connect_layers``. The info dict built here
        carries only 'softmax_grad_sids' and 'output2hidden_grad_sids', so
        the connection path without an aggregator is selected; the third
        element returned by ``connect_layers`` is discarded.

        Args:
            prev_layer: Previous layer object (hidden layer)
            pc: ParallelContext
            id_manager: NetworkIDManager
            softmax_grad_sids: Gradient SIDs from SoftMax layer
            output2hidden_grad_sids: Gradient SIDs for output to hidden connections

        Returns:
            Tuple of (weight_synapses, bias_synapses)
        """
        gradient_wiring = dict(
            softmax_grad_sids=softmax_grad_sids,
            output2hidden_grad_sids=output2hidden_grad_sids,
        )

        connected = self.connect_layers(
            prev_layer=prev_layer,
            pc=pc,
            id_manager=id_manager,
            next_layer_info=gradient_wiring,
        )
        # Output-layer mode: the aggregator slot is not used.
        weight_synapses, bias_synapses, _unused = connected
        return weight_synapses, bias_synapses
    
    def _connect_with_aggregator(self, prev_layer, pc, id_manager,
                                next_layer_info, connection_pattern, connection_locs):
        """
        Connect this layer to ``prev_layer`` using gradient aggregators (hidden-layer mode).

        For every neuron of this layer whose GID exists on the current process,
        a ``BP_Syn_Aggregator`` is placed on the soma, one
        ``BP_Syn_FullyConnected`` weight synapse is created per
        (previous neuron, projection) pair, and one bias synapse is placed on
        the soma. Neurons whose GID is not on this process get ``None``
        placeholders (and an empty weight list) so list positions keep matching
        neuron indices.

        Args:
            prev_layer: Previous layer; must provide ``get_voltage_sids()`` and
                ``n_neurons``.
            pc: ParallelContext used for GID lookup and source_var/target_var wiring.
            id_manager: NetworkIDManager (not used by this method).
            next_layer_info: Dict with ``'n_neurons'``, ``'grad_sids'`` (indexed
                ``[next_neuron][current_neuron]``) and ``'aggregated_grad_sids'``
                (indexed ``[current_neuron]``).
            connection_pattern: Array indexed ``[prev, current, projection]``
                holding dendrite indices, or ``None`` to generate one with a
                fixed seed.
            connection_locs: Array with the same indexing holding the position
                along the dendrite; generated together with the pattern when
                ``connection_pattern`` is ``None``.

        Returns:
            Tuple ``(weight_synapses, bias_synapses, aggregators)``.
        """
        weight_synapses = []
        bias_synapses = []
        aggregators = []
        weight_factor_list = []
        bias_factor_list = []
        weight_wrapper_matrix = []
        bias_wrapper_list = []
        aggregator_wrappers = []

        def wrap(obj):
            # Mirror the object into HelioX only when that backend is enabled.
            if self.backend and self.backend.enable_heliox:
                return self.backend.wrap_obj(obj)
            return None

        is_hpc = self.neuron_type == 'passive_hpc'

        # Gather the IDs and sizes needed for wiring.
        prev_voltage_sids = prev_layer.get_voltage_sids()
        current_gids = self.get_gids()
        N_prev = prev_layer.n_neurons
        N_next = next_layer_info['n_neurons']
        grad_sids = next_layer_info['grad_sids']  # [next_neuron][current_neuron]
        aggregated_grad_sids = next_layer_info['aggregated_grad_sids']  # [current_neuron]

        if connection_pattern is None:
            # Default: one projection per (prev, current) pair onto a random dendrite.
            num_proj_dend = 1
            total_dend = self.get_total_dendrites() if is_hpc else 1

            # Fixed seed so the generated pattern is reproducible.
            rng = np.random.default_rng(1234)
            connection_pattern = rng.integers(0, max(1, total_dend), (N_prev, self.n_neurons, num_proj_dend))
            connection_locs = rng.random((N_prev, self.n_neurons, num_proj_dend))

        n_proj = connection_pattern.shape[2]

        if is_hpc:
            impedance_matrix, bias_impedance_r = self.calculate_impedance_matrix(connection_pattern, connection_locs)
            # Geometric mean of the extreme impedances. It is the same for every
            # synapse, so compute it once instead of once per synapse.
            r_mean = (
                np.sqrt(np.max(impedance_matrix) * np.min(impedance_matrix))
                if np.size(impedance_matrix) else None
            )
        else:
            # Point neurons use a fixed default impedance.
            bias_impedance_r = 100

        for i, current_gid in enumerate(current_gids):
            neuron_weight_syns = []
            neuron_weight_factors = []
            neuron_weight_wrappers = []

            if pc.gid_exists(current_gid):
                target = pc.gid2cell(current_gid)
                # PassiveHPC cells expose soma as a list of sections, point
                # neurons expose a single section.
                soma = target.soma[0] if is_hpc else target.soma

                # 1. Gradient aggregator on the soma.
                aggregator = h.BP_Syn_Aggregator(soma(0.5))
                aggregator.n_outputs = N_next
                aggregator_wrapper = wrap(aggregator)

                # Publish the aggregated gradient under this neuron's SID.
                pc.source_var(aggregator._ref_aggregated_grad, aggregated_grad_sids[i], sec=soma)

                # Subscribe to the gradient coming from every next-layer neuron.
                for next_idx in range(N_next):
                    if next_idx < len(grad_sids) and i < len(grad_sids[next_idx]):
                        pc.target_var(aggregator, aggregator._ref_grad_from_output[next_idx], grad_sids[next_idx][i])

                aggregators.append(aggregator)
                aggregator_wrappers.append(aggregator_wrapper)

                # 2. Weight synapses.
                for j in range(N_prev):
                    for k in range(n_proj):
                        if is_hpc:
                            dend_id = connection_pattern[j, i, k]
                            loc = connection_locs[j, i, k]
                            syn = h.BP_Syn_FullyConnected(target.dend[dend_id](loc))

                            r = impedance_matrix[j, i, k]
                            syn.g = 1 / r_mean
                            # Only the impedance correction factor is stored.
                            neuron_weight_factors.append(r / r_mean)
                        else:
                            syn = h.BP_Syn_FullyConnected(soma(0.5))
                            syn.g = 1 / bias_impedance_r
                            neuron_weight_factors.append(1.0)  # no correction for point neurons
                        syn_wrapper = wrap(syn)

                        # Forward signal from the previous layer, backward
                        # gradient from this neuron's aggregator.
                        pc.target_var(syn, syn._ref_v_gap, prev_voltage_sids[j])
                        pc.target_var(syn, syn._ref_grad_from_next, aggregated_grad_sids[i])

                        neuron_weight_syns.append(syn)
                        neuron_weight_wrappers.append(syn_wrapper)

                # 3. Bias synapse on the soma with a constant input of 1.0.
                bias_syn = h.BP_Syn_FullyConnected(soma(0.5))
                bias_syn.g = 1 / bias_impedance_r
                bias_syn.v_gap = 1.0
                bias_syn_wrapper = wrap(bias_syn)

                pc.target_var(bias_syn, bias_syn._ref_grad_from_next, aggregated_grad_sids[i])
                bias_synapses.append(bias_syn)
                bias_factor_list.append(1.0)  # bias has no impedance correction
                bias_wrapper_list.append(bias_syn_wrapper)
            else:
                # Placeholders keep list indices aligned with neuron indices.
                aggregators.append(None)
                aggregator_wrappers.append(None)
                bias_synapses.append(None)
                bias_factor_list.append(None)
                bias_wrapper_list.append(None)

            weight_synapses.append(neuron_weight_syns)
            weight_factor_list.append(neuron_weight_factors)
            weight_wrapper_matrix.append(neuron_weight_wrappers)

        # Store everything on the layer.
        self.weight_synapses = weight_synapses
        self.bias_synapses = bias_synapses
        self.gradient_aggregators = aggregators
        self.weight_synapse_wrappers = weight_wrapper_matrix
        self.bias_synapse_wrappers = bias_wrapper_list
        self.gradient_aggregator_wrappers = aggregator_wrappers
        self.weight_impedance_factors = weight_factor_list
        self.bias_impedance_factors = bias_factor_list

        # Dimension info kept for later use (extract_gradients compatibility).
        self.input_dim = N_prev
        self.num_proj_dend = n_proj

        self._register_heliox_optimizer_params()

        return weight_synapses, bias_synapses, aggregators
    
    def _connect_without_aggregator(self, prev_layer, pc, id_manager,
                                   next_layer_info, connection_pattern, connection_locs):
        """
        Connect this layer to ``prev_layer`` without aggregators (output-layer mode).

        Each synapse takes its gradient directly from the SoftMax gradient SID
        of its neuron when such SIDs are supplied. Supports both point neurons
        (synapses on the soma) and PassiveHPC neurons (weight synapses on
        dendrites, bias synapse on the soma). Neurons whose GID is not on this
        process get a ``None`` bias entry and an empty weight list.

        Args:
            prev_layer: Previous layer; must provide ``get_voltage_sids()`` and
                ``n_neurons``.
            pc: ParallelContext used for GID lookup and source_var/target_var wiring.
            id_manager: NetworkIDManager (not used by this method).
            next_layer_info: Optional dict; the keys ``'softmax_grad_sids'`` and
                ``'output2hidden_grad_sids'`` are read when present.
            connection_pattern: For PassiveHPC neurons, array indexed
                ``[prev, current, projection]`` of dendrite indices, or ``None``
                to generate one with a fixed seed. Not used for point neurons.
            connection_locs: Positions along the dendrite, same indexing as
                ``connection_pattern``.

        Returns:
            Tuple ``(weight_synapses, bias_synapses, [])``; the third element is
            always an empty list because this mode creates no aggregators.
        """
        weight_synapses = []
        bias_synapses = []
        weight_factor_list = []
        bias_factor_list = []
        weight_wrapper_matrix = []
        bias_wrapper_list = []

        def wrap(obj):
            # Mirror the object into HelioX only when that backend is enabled.
            if self.backend and self.backend.enable_heliox:
                return self.backend.wrap_obj(obj)
            return None

        is_hpc = self.neuron_type == 'passive_hpc'

        # Gather the IDs and sizes needed for wiring.
        prev_voltage_sids = prev_layer.get_voltage_sids()
        current_gids = self.get_gids()
        N_prev = prev_layer.n_neurons

        # Gradient wiring information is optional.
        softmax_grad_sids = None
        output2hidden_grad_sids = None
        if next_layer_info:
            softmax_grad_sids = next_layer_info.get('softmax_grad_sids')
            output2hidden_grad_sids = next_layer_info.get('output2hidden_grad_sids')

        if is_hpc:
            if connection_pattern is None:
                # Default: one projection per pair onto a random dendrite.
                num_proj_dend = 1
                total_dend = self.get_total_dendrites() if self.neurons else 64

                # Fixed seed so the generated pattern is reproducible.
                rng = np.random.default_rng(1234)
                connection_pattern = rng.integers(0, max(1, total_dend), (N_prev, self.n_neurons, num_proj_dend))
                connection_locs = rng.random((N_prev, self.n_neurons, num_proj_dend))

            impedance_matrix, bias_impedance_r = self.calculate_impedance_matrix(connection_pattern, connection_locs)
            # Geometric mean of the extreme impedances. It is the same for every
            # synapse, so compute it once instead of once per synapse.
            r_mean = (
                np.sqrt(np.max(impedance_matrix) * np.min(impedance_matrix))
                if np.size(impedance_matrix) else None
            )
            n_proj = connection_pattern.shape[2]
        else:
            # Point neurons use a fixed standard impedance.
            point_r = 100.0
            bias_impedance_r = point_r

        for i, current_gid in enumerate(current_gids):
            neuron_weight_syns = []
            neuron_weight_factors = []
            neuron_weight_wrappers = []

            if pc.gid_exists(current_gid):
                target = pc.gid2cell(current_gid)
                has_softmax_grad = softmax_grad_sids and i < len(softmax_grad_sids)
                has_out2hid = output2hidden_grad_sids and i < len(output2hidden_grad_sids)

                if is_hpc:
                    # PassiveHPC neurons: weight synapses sit on dendrites.
                    for j in range(N_prev):
                        for k in range(n_proj):
                            dend_id = connection_pattern[j, i, k]
                            loc = connection_locs[j, i, k]
                            syn = h.BP_Syn_FullyConnected(target.dend[dend_id](loc))

                            r = impedance_matrix[j, i, k]
                            syn.g = 1 / r_mean
                            impedance_factor = r / r_mean

                            # Forward signal from the previous layer.
                            pc.target_var(syn, syn._ref_v_gap, prev_voltage_sids[j])

                            # Gradient from SoftMax, when provided.
                            if has_softmax_grad:
                                pc.target_var(syn, syn._ref_grad_from_next, softmax_grad_sids[i])

                            # Publish the gradient towards the hidden layer,
                            # when provided. Several synapses may exist per
                            # neuron pair, so the SID list is indexed by the
                            # flattened (prev, projection) position.
                            if has_out2hid:
                                syn_idx = j * n_proj + k
                                if syn_idx < len(output2hidden_grad_sids[i]):
                                    pc.source_var(syn._ref_grad_to_prev, output2hidden_grad_sids[i][syn_idx],
                                                  sec=target.dend[dend_id])

                            neuron_weight_syns.append(syn)
                            neuron_weight_factors.append(impedance_factor)
                            neuron_weight_wrappers.append(wrap(syn))
                else:
                    # Point neurons: weight synapses sit on the soma.
                    for j in range(N_prev):
                        syn = h.BP_Syn_FullyConnected(target.soma(0.5))
                        syn.g = 1 / point_r

                        # Forward signal from the previous layer.
                        pc.target_var(syn, syn._ref_v_gap, prev_voltage_sids[j])

                        # Gradient from SoftMax, when provided.
                        if has_softmax_grad:
                            pc.target_var(syn, syn._ref_grad_from_next, softmax_grad_sids[i])

                        # Publish the gradient towards the hidden layer, when provided.
                        if has_out2hid:
                            pc.source_var(syn._ref_grad_to_prev, output2hidden_grad_sids[i][j], sec=target.soma)

                        neuron_weight_syns.append(syn)
                        neuron_weight_factors.append(1.0)  # no correction for point neurons
                        neuron_weight_wrappers.append(wrap(syn))

                # Bias synapse on the soma with a constant input of 1.0.
                soma = target.soma[0] if is_hpc else target.soma
                bias_syn = h.BP_Syn_FullyConnected(soma(0.5))
                bias_syn.g = 1 / bias_impedance_r
                bias_syn.v_gap = 1.0
                bias_wrapper = wrap(bias_syn)

                # Bias gradient from SoftMax, when provided.
                if has_softmax_grad:
                    pc.target_var(bias_syn, bias_syn._ref_grad_from_next, softmax_grad_sids[i])

                bias_synapses.append(bias_syn)
                bias_factor_list.append(1.0)  # bias has no impedance correction
                bias_wrapper_list.append(bias_wrapper)
            else:
                # Placeholders keep list indices aligned with neuron indices.
                bias_synapses.append(None)
                bias_factor_list.append(None)
                bias_wrapper_list.append(None)

            weight_synapses.append(neuron_weight_syns)
            weight_factor_list.append(neuron_weight_factors)
            weight_wrapper_matrix.append(neuron_weight_wrappers)

        # Store everything on the layer (this mode has no gradient aggregators).
        self.weight_synapses = weight_synapses
        self.bias_synapses = bias_synapses
        self.weight_impedance_factors = weight_factor_list
        self.bias_impedance_factors = bias_factor_list
        self.weight_synapse_wrappers = weight_wrapper_matrix
        self.bias_synapse_wrappers = bias_wrapper_list
        self.gradient_aggregator_wrappers = []

        # Dimension info kept for later use (extract_gradients compatibility).
        # Set before registering optimizer params, the same order
        # _connect_with_aggregator uses, so the registration step never sees
        # missing or stale dimensions in case it reads them.
        self.input_dim = N_prev
        self.num_proj_dend = n_proj if is_hpc else 1  # point neurons have a single connection

        self._register_heliox_optimizer_params()

        return weight_synapses, bias_synapses, []
    
    def is_trainable(self):
        """Return the layer's ``trainable`` flag."""
        trainable_flag = self.trainable
        return trainable_flag
    
    def get_layer_role(self):
        """Return the role/type of this layer (the stored ``layer_role`` value)."""
        role = self.layer_role
        return role
    
    def get_total_dendrites(self):
        """
        Return the dendrite count of the first neuron in this layer.

        Returns:
            ``len(self.neurons[0].dend)`` for a ``'passive_hpc'`` layer that
            holds at least one neuron; ``0`` in every other case.
        """
        if self.neuron_type != 'passive_hpc' or not self.neurons:
            return 0
        first_cell = self.neurons[0]
        return len(first_cell.dend)


class OutputLayer(DenseLayer):
    """
    Output layer of the network.

    A DenseLayer preconfigured with point neurons and the 'output' role; it is
    the layer that connects to SoftMax.
    """

    def __init__(self, n_neurons=10):
        """
        Build the output layer.

        Args:
            n_neurons: Number of output neurons (default 10 for digits 0-9)
        """
        layer_config = dict(
            n_neurons=n_neurons,
            neuron_type='point',
            layer_name='output',
            layer_role='output',
        )
        super().__init__(**layer_config)

        # The output layer works without gradient aggregators.
        self.needs_aggregator = False

        # Filled in by connect_from_previous().
        self.weight_synapses = []
        self.bias_synapses = []

    def connect_from_previous(self, prev_layer, pc, id_manager, softmax_grad_sids=None, output2hidden_grad_sids=None):
        """
        Wire this layer to the previous one via ``create_backward_synapses``.

        The two values returned by ``create_backward_synapses`` are stored on
        the layer as ``weight_synapses`` and ``bias_synapses``.

        Args:
            prev_layer: Previous layer (hidden layer)
            pc: ParallelContext
            id_manager: NetworkIDManager
            softmax_grad_sids: Gradient SIDs from SoftMax
            output2hidden_grad_sids: Gradient SIDs for output to hidden

        Returns:
            Tuple of (weight_synapses, bias_synapses)
        """
        created = self.create_backward_synapses(
            prev_layer, pc, id_manager, softmax_grad_sids, output2hidden_grad_sids
        )
        self.weight_synapses, self.bias_synapses = created
        return self.weight_synapses, self.bias_synapses


class SoftMaxLayer:
    """
    SoftMax stage of the network.

    Wraps a single mirror point neuron that hosts a BP_Syn_SoftMax mechanism
    and keeps the GID/SIDs registered for it.
    """

    def __init__(self, n_classes=10):
        """
        Create an empty SoftMax layer; NEURON objects are built by the phase methods.

        Args:
            n_classes: Number of output classes (default 10 for digits 0-9)
        """
        self.n_classes = n_classes
        self.layer_name = 'softmax'
        self.trainable = False  # nothing to train in the SoftMax stage itself
        self.backend = None

        # Populated by phase1/phase2.
        self.mirror_neuron = None
        self.nclist = []
        self.softmax_syn = None  # the BP_Syn_SoftMax point process
        self.mirror_neuron_wrapper = None
        self.softmax_wrapper = None

        # IDs, assigned during network construction.
        self.neuron_gid = None
        self.input_sids = self.output_sids = self.grad_sids = None

    def set_backend(self, backend):
        """Store the hybrid backend controller used for HelioX integration."""
        self.backend = backend

    def phase1_create_neuron(self, pc, id_manager, ihost):
        """
        Phase 1: register the mirror neuron's GID and instantiate the neuron.

        Args:
            pc: ParallelContext
            id_manager: NetworkIDManager for GID management
            ihost: Host ID for parallel execution
        """
        # The PointNeuron template lives in a HOC file.
        h.load_file("PointNeuron.hoc")

        gid = id_manager.register_gid('mirror')
        self.neuron_gid = gid

        pc.set_gid2node(gid, ihost)
        cell = h.PointNeuron()
        self.mirror_neuron = cell

        # Target-less NetCon on the soma voltage, associated with the GID.
        recorder = h.NetCon(cell.soma(0.5)._ref_v, None, sec=cell.soma)
        self.nclist.append(recorder)
        pc.cell(gid, recorder)

    def phase2_setup_mechanism(self, pc, id_manager):
        """
        Phase 2: register the SoftMax SIDs and create the BP_Syn_SoftMax mechanism.

        SIDs are always registered; the mechanism and its source_var entries
        are only created on the process that owns the mirror neuron.

        Args:
            pc: ParallelContext
            id_manager: NetworkIDManager for SID management
        """
        self.input_sids = id_manager.register_sids_batch('softmax_input', self.n_classes)
        self.output_sids = id_manager.register_sids_batch('softmax_output', self.n_classes)
        self.grad_sids = id_manager.register_sids_batch('softmax_grad', self.n_classes)

        if not pc.gid_exists(self.neuron_gid):
            return

        soma = self.mirror_neuron.soma
        self.softmax_syn = h.BP_Syn_SoftMax(soma(0.5))
        if self.backend and self.backend.enable_heliox:
            self.softmax_wrapper = self.backend.wrap_obj(self.softmax_syn)

        # Publish, per class, the SoftMax output and the gradient output.
        for cls_idx in range(self.n_classes):
            pc.source_var(self.softmax_syn._ref_s[cls_idx], self.output_sids[cls_idx],
                          sec=soma)
            pc.source_var(self.softmax_syn._ref_grad_to_prev[cls_idx], self.grad_sids[cls_idx],
                          sec=soma)

    def phase3_connect_inputs(self, pc, output_voltage_sids):
        """
        Phase 3: feed the output layer's voltages into the SoftMax inputs.

        Args:
            pc: ParallelContext
            output_voltage_sids: SIDs from output layer voltages
        """
        if not (pc.gid_exists(self.neuron_gid) and self.softmax_syn):
            return
        syn = self.softmax_syn
        for cls_idx in range(self.n_classes):
            pc.target_var(syn, syn._ref_u[cls_idx], output_voltage_sids[cls_idx])

    def set_target(self, target_class):
        """
        Write a one-hot training target into the mechanism (and its HelioX wrapper).

        Args:
            target_class: Index of the target class (0-9 for digits)
        """
        def write_one_hot(dest):
            for cls_idx in range(self.n_classes):
                dest.tgt[cls_idx] = 1.0 if cls_idx == target_class else 0.0

        if self.softmax_syn:
            write_one_hot(self.softmax_syn)
        if self.backend and self.backend.enable_heliox and self.softmax_wrapper:
            write_one_hot(self.softmax_wrapper)

    def set_training_mode(self, is_training=True):
        """
        Switch the SoftMax layer between training and test mode.

        Sets ``lr_start`` to 20 for training or 1e9 otherwise, and ``lr_end``
        to 1e9, on the mechanism and on its HelioX wrapper when present.
        """
        lr_end = 1e9
        lr_start = 20 if is_training else 1e9

        def apply(dest):
            dest.lr_start = lr_start
            dest.lr_end = lr_end

        # The SoftMax layer owns a single synapse.
        if self.softmax_syn:
            apply(self.softmax_syn)
        if self.backend and self.backend.enable_heliox and self.softmax_wrapper:
            apply(self.softmax_wrapper)

    def get_softmax_outputs(self, backend: str = "auto"):
        """
        Read the softmax outputs from the HelioX wrapper.

        Args:
            backend: ``"auto"`` or ``"heliox"`` (case-insensitive); a falsy
                value is treated as ``"auto"``.

        Returns:
            List with the ``n_classes`` values of the wrapper's ``s`` array.

        Raises:
            ValueError: If ``backend`` names anything else, or if the HelioX
                backend or the SoftMax wrapper is not available.
        """
        requested = (backend or "auto").lower()
        heliox_ready = (
            self.backend is not None
            and self.backend.enable_heliox
            and self.softmax_wrapper is not None
        )

        if requested not in ("auto", "heliox"):
            raise ValueError(f"Unsupported backend '{backend}'.")
        if not heliox_ready:
            raise ValueError("HelioX backend not available for SoftMax outputs.")

        wrapper = self.softmax_wrapper
        return [wrapper.s[cls_idx] for cls_idx in range(self.n_classes)]

    def compute_backend_differences(self) -> Dict[str, Dict[str, float]]:
        """
        Compare the key arrays of the SoftMax mechanism with its HelioX wrapper.

        Returns:
            Dict mapping each comparable attribute among ``u``, ``s``,
            ``grad_to_prev`` and ``tgt`` to ``{"max_abs", "mean_abs"}`` of the
            element-wise absolute difference. Empty when HelioX is disabled or
            either object is missing; attributes that cannot be read from both
            sides are left out.
        """
        if self.backend is None or not self.backend.enable_heliox:
            return {}
        if self.softmax_syn is None or self.softmax_wrapper is None:
            return {}

        report: Dict[str, Dict[str, float]] = {}

        def read_values(obj: Any, attr: str, length: int) -> Optional[List[float]]:
            # Best-effort read of `length` floats from obj.<attr>.
            if obj is None or length <= 0:
                return None
            target = getattr(obj, attr, None)
            if target is None:
                return None

            # Bulk converters first, when the object offers them.
            for converter in ("to_list", "to_numpy"):
                if hasattr(target, converter):
                    data = getattr(target, converter)()
                    if len(data) >= length:
                        return [float(data[pos]) for pos in range(length)]

            collected: List[float] = []
            for pos in range(length):
                # Plain indexing.
                try:
                    collected.append(float(target[pos]))
                except Exception:
                    pass
                else:
                    continue

                # Per-element attribute named "<attr>_<index>".
                fallback_attr = f"{attr}_{pos}"
                if hasattr(obj, fallback_attr):
                    try:
                        collected.append(float(getattr(obj, fallback_attr)))
                    except Exception:
                        return None
                    continue

                # Last resort: a scalar attribute, accepted only for length 1.
                try:
                    scalar = float(getattr(obj, attr))
                except Exception:
                    return None
                if length != 1:
                    return None
                collected.append(scalar)
                return collected
            return collected

        def store_metric(name: str, neuron_values: Optional[List[float]], heliox_values: Optional[List[float]]):
            if not neuron_values or not heliox_values:
                return
            if len(neuron_values) != len(heliox_values):
                return
            lhs = np.asarray(neuron_values, dtype=float)
            rhs = np.asarray(heliox_values, dtype=float)
            if lhs.size == 0:
                return
            gap = np.abs(lhs - rhs)
            report[name] = {
                "max_abs": float(gap.max()),
                "mean_abs": float(gap.mean())
            }

        for attr in ("u", "s", "grad_to_prev", "tgt"):
            from_neuron = read_values(self.softmax_syn, attr, self.n_classes)
            from_heliox = read_values(self.softmax_wrapper, attr, self.n_classes)
            store_metric(attr, from_neuron, from_heliox)

        return report

    def get_gid(self):
        """Return the GID of the mirror neuron."""
        return self.neuron_gid

    def get_input_sids(self):
        """Return the input SIDs."""
        return self.input_sids

    def get_output_sids(self):
        """Return the output SIDs."""
        return self.output_sids

    def get_grad_sids(self):
        """Return the gradient SIDs."""
        return self.grad_sids

    def get_gradient_source_count(self):
        """
        Return how many gradient sources this layer produces.

        Returns:
            int: ``n_classes`` (one gradient per class)
        """
        return self.n_classes

    def provides_loss_gradient(self):
        """
        Tell whether this layer provides the loss gradient.

        Returns:
            bool: Always True for SoftMaxLayer
        """
        return True


class InputLayer:
    """
    Input layer for the neural network.
    Encapsulates point neurons with stimulators (VecStim or NetStim).
    """
    
    def __init__(self, n_neurons=784, use_vecstim=True):
        """
        Initialize the input layer.
        
        Args:
            n_neurons: Number of input neurons (default 784 for MNIST)
            use_vecstim: If True, use VecStim; if False, use NetStim
        """
        self.n_neurons = n_neurons
        self.use_vecstim = use_vecstim
        self.layer_name = 'input'
        self.trainable = False  # Input layer has no trainable parameters
        self.backend = None
        
        # These will be created during initialization phases
        self.neurons = []
        self.nclist = []  # NetCon list for voltage recording
        self.stim_synapses = []  # ExpSyn for receiving stimulation
        self.stimulators = []  # VecStim or NetStim objects
        self.ncstim_list = []  # NetCon list for stimulators
        self.stim_wrappers = []  # ObjWrappers for stimulators (HelioX)
        self._stim_batch_id = None
        self._stim_batch_mode = None
        
        # GID/SID management (will be set during network construction)
        self.neuron_gids = None
        self.stim_gids = None
        self.voltage_sids = None  # Source IDs for voltage output

    def set_backend(self, backend):
        """Store the hybrid backend controller on this layer."""
        controller = backend
        self.backend = controller

    def post_backend_init(self):
        """
        Hook for backend-specific finalization of HelioX wrappers.

        Intentionally does nothing: the HelioX manager is responsible for
        wrapper initialization.
        """
        return None

    def phase1_create_neurons(self, pc, id_manager, ihost):
        """
        Phase 1: build the input neurons and publish their voltages.

        Registers one GID and one voltage SID per neuron, creates a PointNeuron
        for each, associates it with its GID, registers the soma voltage as a
        source variable and attaches an ExpSyn that later receives the
        stimulation.

        Args:
            pc: ParallelContext
            id_manager: NetworkIDManager for GID management
            ihost: Host ID for parallel execution
        """
        # The PointNeuron template lives in a HOC file.
        h.load_file("PointNeuron.hoc")

        self.neuron_gids = id_manager.register_gids_batch('input', self.n_neurons)
        self.voltage_sids = id_manager.register_sids_batch('input_voltage', self.n_neurons)

        # Section parameters applied to every section of every input neuron.
        section_params = (('Ra', 100), ('cm', 1), ('g_pas', 1e-4))

        for gid, sid in zip(self.neuron_gids, self.voltage_sids):
            pc.set_gid2node(gid, ihost)

            cell = h.PointNeuron()
            for section in cell.all:
                for param, value in section_params:
                    setattr(section, param, value)
            self.neurons.append(cell)

            # Target-less NetCon on the soma voltage, associated with the GID.
            recorder = h.NetCon(cell.soma(0.5)._ref_v, None, sec=cell.soma)
            self.nclist.append(recorder)
            pc.cell(gid, recorder)

            # The input layer registers its own voltage source_var.
            if pc.gid_exists(gid):
                pc.source_var(cell.soma(0.5)._ref_v, sid, sec=cell.soma)

            # ExpSyn that receives the stimulation.
            stim_syn = h.ExpSyn(cell.soma(0.5))
            stim_syn.tau = 0.5
            stim_syn.e = 1
            self.stim_synapses.append(stim_syn)
    
    def phase2_create_stimulators(self, pc, id_manager, ihost):
        """
        Phase 2: create one stimulator per input neuron and wire it to its ExpSyn.

        When HelioX is enabled, each stimulator is wrapped and the wrappers are
        registered with the HelioX manager (if it offers
        ``register_input_stimulators``) so inputs can be updated in batch.

        Args:
            pc: ParallelContext
            id_manager: NetworkIDManager for GID management
            ihost: Host ID for parallel execution
        """
        self.stim_gids = id_manager.register_gids_batch('input_stim', self.n_neurons)

        for idx, (neuron_gid, stim_gid) in enumerate(zip(self.neuron_gids, self.stim_gids)):
            pc.set_gid2node(stim_gid, ihost)

            if self.use_vecstim:
                stim = h.VecStim()
                # An empty vector gives wrappers a valid handle during export.
                stim.play(h.Vector())
            else:
                stim = h.NetStim()
                stim.number = 100
            self.stimulators.append(stim)

            heliox_on = self.backend and self.backend.enable_heliox
            self.stim_wrappers.append(self.backend.wrap_obj(stim) if heliox_on else None)

            # NetCon from the stimulator to the ExpSyn; None keeps indices
            # aligned when the neuron is not on this process.
            if not pc.gid_exists(neuron_gid):
                self.ncstim_list.append(None)
                continue
            link = h.NetCon(stim, self.stim_synapses[idx])
            link.delay = 1
            link.weight[0] = 0.05
            self.ncstim_list.append(link)

        # Register the stimulators with HelioX for fast batch updates (GPU-side).
        if not (self.backend and self.backend.enable_heliox):
            return
        manager = getattr(self.backend, "get_manager", lambda: None)()
        if manager is None or not hasattr(manager, "register_input_stimulators"):
            return
        mode = "vecstim" if self.use_vecstim else "netstim"
        batch_id = manager.register_input_stimulators(self.stim_wrappers, mode=mode, owner=self)
        if batch_id is not None:
            self._stim_batch_id = batch_id
            self._stim_batch_mode = mode
    
    def set_stim(self, img_data, pc=None):
        """
        Program the input stimulators from one input vector.

        When HelioX is enabled and its manager accepts the whole vector via
        ``set_input_stimulus``, nothing else is done. Otherwise each stimulator
        (and its wrapper, when present) is configured from its own pixel value.

        Args:
            img_data: Input data (e.g. the pixel values of an MNIST image)
            pc: ParallelContext; when given, only neurons whose GID exists on
                this process are updated
        """
        # Fast path: hand the whole input vector to HelioX in one call.
        backend = self.backend
        if backend and getattr(backend, "enable_heliox", False):
            manager = getattr(backend, "get_manager", lambda: None)()
            if manager is not None and hasattr(manager, "set_input_stimulus"):
                mode = "vecstim" if self.use_vecstim else "netstim"
                if manager.set_input_stimulus(img_data, mode=mode, batch_id=self._stim_batch_id):
                    return

        # Fallback: configure every stimulator individually.
        for i, neuron_gid in enumerate(self.neuron_gids):
            if pc is not None and not pc.gid_exists(neuron_gid):
                continue
            if i >= len(img_data) or i >= len(self.stimulators):
                continue
            stim = self.stimulators[i]
            if stim is None:
                continue
            wrapper = self.stim_wrappers[i] if i < len(self.stim_wrappers) else None
            pixel = img_data[i]

            if self.use_vecstim:
                # VecStim: the spike train is given as an explicit vector.
                spk_train = 9. + 5. / (pixel + 0.01) * np.arange(1, 20)
                stim.play(h.Vector(spk_train))
                if wrapper is not None:
                    wrapper.play(spk_train.tolist())
                continue

            # NetStim: set the interval, start time and spike count.
            stim.interval = 5. / (pixel + 0.01)
            stim.start = 9. + stim.interval
            stim.number = 100
            if wrapper is not None:
                wrapper.interval = stim.interval
                wrapper.start = stim.start
                wrapper.number = stim.number
    
    def set_input_spikes(self, spike_times, input_indices):
        """
        Set spike times for VecStim stimulators.

        ``spike_times`` and ``input_indices`` are read pairwise: entry ``k``
        schedules ``spike_times[k]`` on stimulator ``input_indices[k]``.
        All stimulators are cleared first. An index may appear several times;
        its spike times are collected, sorted and played as one vector, since
        a later ``play`` call would otherwise replace the earlier one.

        Args:
            spike_times: List of spike times
            input_indices: List of input neuron indices; indices at or beyond
                the number of stimulators are ignored

        Raises:
            ValueError: If the layer is not configured to use VecStim.
        """
        if not self.use_vecstim:
            raise ValueError("set_input_spikes only works with VecStim")

        stimulators = self.stimulators

        # Clear existing spike times
        for stim in stimulators:
            if hasattr(stim, 'play'):
                stim.play(h.Vector())

        # Collect every spike time that belongs to the same stimulator.
        n_stims = len(stimulators)
        times_by_index = {}
        for spike_time, idx in zip(spike_times, input_indices):
            if idx < n_stims:
                times_by_index.setdefault(idx, []).append(spike_time)

        # Play one vector per stimulator so no spike is overwritten.
        for idx, times in times_by_index.items():
            stim = stimulators[idx]
            if stim is None:
                # Empty slot (set_stim tolerates these as well).
                continue
            stim.play(h.Vector(sorted(times)))
    
    def set_netstim_params(self, start_time, interval=1.0, number=1):
        """
        Set parameters for NetStim stimulators.

        The same values are written to every stimulator of the layer. Empty
        (``None``) slots in the stimulator list are skipped.

        Args:
            start_time: When to start stimulation
            interval: Interval between spikes (ms)
            number: Number of spikes

        Raises:
            ValueError: If the layer is configured to use VecStim.
        """
        if self.use_vecstim:
            raise ValueError("set_netstim_params only works with NetStim")

        # NOTE(review): unlike set_stim, the entries of self.stim_wrappers are
        # not updated here - confirm whether they should mirror these values.
        for stim in self.stimulators:
            if stim is None:
                # set_stim tolerates empty slots; do the same here.
                continue
            stim.start = start_time
            stim.interval = interval
            stim.number = number
    
    def get_voltages(self):
        """
        Read the current voltage of every neuron in the layer.

        For each neuron the value of ``soma(0.5).v`` is taken.

        Returns:
            List of voltages, in the same order as ``self.neurons``
        """
        voltages = []
        for cell in self.neurons:
            soma_segment = cell.soma(0.5)
            voltages.append(soma_segment.v)
        return voltages
    
    def get_neuron_count(self):
        """
        Report the size of this layer.

        Returns:
            The value stored in ``self.n_neurons``
        """
        neuron_count = self.n_neurons
        return neuron_count
    
    def get_neurons(self):
        """
        Give access to the neurons of this layer.

        Returns:
            The ``self.neurons`` list itself (not a copy)
        """
        layer_neurons = self.neurons
        return layer_neurons
    
    def get_gids(self):
        """
        Give access to the GIDs assigned to this layer's neurons.

        Returns:
            The value stored in ``self.neuron_gids``
        """
        gids = self.neuron_gids
        return gids
    
    def set_training_mode(self, is_training):
        """
        Switch the layer between training and test mode (no-op here).

        The synapses of InputLayer are of ExpSyn type and carry no learning
        parameters, so there is nothing to toggle; the argument is ignored.

        Args:
            is_training: True for training mode, False for test mode
        """
        # Nothing to update: ExpSyn synapses do not participate in learning.
        return None
    
    def get_voltage_sids(self):
        """
        Give access to the SIDs used for voltage output.

        Returns:
            The value stored in ``self.voltage_sids``
        """
        sids = self.voltage_sids
        return sids
    
    def get_gradient_source_count(self):
        """
        Report how many gradient sources this layer produces.

        InputLayer is the first layer and produces no gradients, so the
        answer is constant.

        Returns:
            int: always 0 for InputLayer
        """
        gradient_source_count = 0
        return gradient_source_count
    
    def provides_loss_gradient(self):
        """
        Tell whether this layer supplies the loss gradient.

        InputLayer never does, so the answer is constant.

        Returns:
            bool: always False for InputLayer
        """
        supplies_loss_gradient = False
        return supplies_loss_gradient
    def set_backend(self, backend):
        """
        Attach the hybrid backend controller to this layer.

        Args:
            backend: Backend object to store in ``self.backend``; other
                methods of the layer consult it (e.g. its ``enable_heliox``
                flag and ``get_manager`` accessor)
        """
        setattr(self, "backend", backend)
