import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# from spikingjelly.activation_based.neuron import IFNode # Replaced by LIFNeuron

# --- Custom LIF Neuron Node --- (unchanged from the previous version of this code)
class LIFNeuron(nn.Module):
    def __init__(self, tau_mem: float = 20.0, v_threshold: float = 1.0, v_reset: float | None = 0.0, **kwargs):
        super().__init__()
        self.v_threshold: torch.Tensor | None = None 
        if v_reset is None: self.v_reset_val = 0.0
        else: self.v_reset_val = v_reset
        self.register_buffer('v', torch.tensor(0.0, dtype=torch.float32)) 
        self.tau_mem = float(tau_mem)
        if self.tau_mem <= 0: self.decay_factor = 0.0 
        else: self.decay_factor = math.exp(-1.0 / self.tau_mem)

    def reset(self):
        if isinstance(self.v, torch.Tensor) and self.v.numel() > 0:
            reset_val = self.v_reset_val
            if isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.numel() == 1:
                reset_val = self.v_reset_val.item()
            if isinstance(reset_val, torch.Tensor) and reset_val.numel() > 1 and self.v.shape == reset_val.shape:
                 self.v = reset_val.clone().detach().to(self.v.device)
            else: self.v.fill_(float(reset_val))
        elif isinstance(self.v, torch.Tensor) and self.v.numel() == 0: pass
        else: self.v = torch.tensor(float(self.v_reset_val), dtype=torch.float32, device=self.v.device if isinstance(self.v, torch.Tensor) else torch.device("cpu"))

    def forward(self, x_in: torch.Tensor): 
        if not isinstance(self.v, torch.Tensor) or self.v.shape != x_in.shape:
            current_device = x_in.device; reset_val_for_fill = self.v_reset_val
            if isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.numel() == 1:
                reset_val_for_fill = self.v_reset_val.item()
            if isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.ndim == 1 and self.v_reset_val.numel() == x_in.shape[1]:
                 self.v = self.v_reset_val.clone().detach().to(current_device).unsqueeze(0).expand_as(x_in)
            elif isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.shape == x_in.shape:
                 self.v = self.v_reset_val.clone().detach().to(current_device)
            else: self.v = torch.full_like(x_in, float(reset_val_for_fill), device=current_device)
        
        potential_for_spike_check = self.v + x_in 
        if self.v_threshold is None: raise ValueError("LIFNeuron: v_threshold has not been set.")
        if not isinstance(self.v_threshold, torch.Tensor): raise TypeError("LIFNeuron: v_threshold must be a tensor.")
        current_v_threshold = self.v_threshold 
        if self.v_threshold.ndim == 1 and x_in.ndim == 2 and self.v_threshold.shape[0] == x_in.shape[1]:
            current_v_threshold = self.v_threshold.unsqueeze(0) 
        elif self.v_threshold.shape != x_in.shape : 
             if not (current_v_threshold.numel() == x_in.shape[-1] and x_in.ndim >= current_v_threshold.ndim and current_v_threshold.ndim >0):
                 if not (current_v_threshold.ndim == 1 and current_v_threshold.numel() == x_in.shape[-1]): 
                    raise ValueError(f"LIFNeuron: v_threshold shape {self.v_threshold.shape} not directly broadcastable to input shape {x_in.shape}.")
        spikes = (potential_for_spike_check > current_v_threshold).to(x_in.dtype)
        reset_values_tensor = torch.full_like(potential_for_spike_check, float(self.v_reset_val))
        if isinstance(self.v_reset_val, torch.Tensor):
            if self.v_reset_val.numel() == 1: reset_values_tensor.fill_(self.v_reset_val.item())
            elif self.v_reset_val.ndim == 1 and self.v_reset_val.numel() == potential_for_spike_check.shape[1]:
                reset_values_tensor = self.v_reset_val.unsqueeze(0).expand_as(potential_for_spike_check)
            elif self.v_reset_val.shape == potential_for_spike_check.shape: reset_values_tensor = self.v_reset_val
        self.v = torch.where(spikes.bool(), reset_values_tensor, potential_for_spike_check)
        return spikes

# --- Synapse and Soma classes (unchanged from the previous version of this code) ---
class Synapse(nn.Module):
    """Encode floating-point values into their per-bit spike representation.

    For each input value a row of ``total_bits`` drive values is built
    (column 0 carries ``-x``, the remaining columns carry the 0/1 exponent and
    mantissa bits) and passed through a :class:`LIFNeuron` whose thresholds
    are 0.0 for the sign column and 0.5 for every other column. The spikes
    are therefore the sign / exponent / mantissa bits described by
    ``FLOAT_FORMAT_PARAMS`` for the input dtype.

    Values the format table cannot express are replaced by fixed patterns:
    signed zero, signed infinity, or the NaN pattern for everything else.
    """

    def __init__(self):
        super().__init__()
        self.if_neurons = LIFNeuron(tau_mem=20.0, v_reset=0.0)
        # Threshold vectors keyed by format name, built lazily.
        self.g_n_tensor_cache = {}
        self.FLOAT_FORMAT_PARAMS = {
            torch.float8_e4m3fn: {"name": "E4M3FN", "exp_bits": 4, "man_bits": 3, "bias": 7, "min_norm_coded_exp": 1, "max_norm_coded_exp": 14, "total_bits": 8, "has_subnormals": False, "epsilon": 1e-9},
            torch.float8_e5m2: {"name": "E5M2", "exp_bits": 5, "man_bits": 2, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 8, "has_subnormals": False, "epsilon": 1e-9},
            torch.bfloat16: {"name": "BF16", "exp_bits": 8, "man_bits": 7, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 16, "has_subnormals": True, "epsilon": 1e-9},
            torch.float16: {"name": "FP16", "exp_bits": 5, "man_bits": 10, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 16, "has_subnormals": True, "epsilon": 1e-9},
            torch.float32: {"name": "FP32", "exp_bits": 8, "man_bits": 23, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 32, "has_subnormals": True, "epsilon": 1e-9}
        }

    def get_float_format_params(self, dtype: torch.dtype) -> dict:
        """Return a copy of the format description for *dtype*.

        Raises:
            ValueError: if *dtype* has no entry in ``FLOAT_FORMAT_PARAMS``.
        """
        params = self.FLOAT_FORMAT_PARAMS.get(dtype)
        if params is None:
            raise ValueError(f"Synapse: Unsupported float dtype: {dtype}")
        return params.copy()

    def get_float_special_pattern(self, float_params: dict, device: torch.device, pattern_type: str = "nan", sign_bit: int = 0) -> torch.Tensor:
        """Build the int8 bit pattern used for NaN or infinity.

        Both patterns have an all-ones exponent; "nan" sets only the leading
        mantissa bit, "inf" leaves the mantissa all zero.

        Raises:
            ValueError: on a sign bit other than 0/1 or an unknown pattern type.
        """
        if sign_bit not in [0, 1]:
            raise ValueError("Sign bit must be 0 or 1.")
        bits = [sign_bit] + [1] * float_params["exp_bits"]
        kind = pattern_type.lower()
        if kind == "nan":
            if float_params["man_bits"] > 0:
                bits.extend([1] + [0] * (float_params["man_bits"] - 1))
        elif kind == "inf":
            bits.extend([0] * float_params["man_bits"])
        else:
            raise ValueError(f"Unknown special pattern type: {pattern_type}.")
        total_bits = float_params["total_bits"]
        if len(bits) != total_bits:
            # Pad with zeros / truncate so the pattern always has total_bits entries.
            bits = (bits + [0] * total_bits)[:total_bits]
        return torch.tensor(bits, dtype=torch.int8, device=device)

    def _get_g_n_tensor_for_dtype(self, float_params: dict, device: torch.device) -> torch.Tensor:
        """Return the per-bit threshold vector: 0.0 for the sign, 0.5 for all other bits."""
        format_name = float_params["name"]
        if format_name in self.g_n_tensor_cache:
            return self.g_n_tensor_cache[format_name].to(device)
        g_n_list = [0.0] + [0.5] * (float_params["total_bits"] - 1)
        if len(g_n_list) != float_params["total_bits"]:
            raise ValueError(f"G(N) tensor length incorrect for {format_name}")
        g_n_tensor = torch.tensor(g_n_list, dtype=torch.float32, device=device)
        self.g_n_tensor_cache[format_name] = g_n_tensor
        return g_n_tensor

    def _calculate_f_x_n_batch(self, x_input_flat: torch.Tensor, float_params: dict) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the neuron drive values for a flat batch of numbers.

        Args:
            x_input_flat: values to encode (flattened internally).
            float_params: format description as returned by
                :meth:`get_float_format_params`.

        Returns:
            ``(f_x_n, representable)`` where ``f_x_n`` is a float32 tensor of
            shape ``(batch, total_bits)`` and ``representable`` is a bool mask
            that is True for finite, non-zero values whose exponent fits the
            format (normal range, or subnormal when the format allows it).
            Rows that are not representable carry no exponent/mantissa bits.
        """
        total_bits = float_params["total_bits"]
        exp_bits = float_params["exp_bits"]
        man_bits = float_params["man_bits"]
        bias = float_params["bias"]
        device = x_input_flat.device
        if x_input_flat.numel() == 0:
            return (torch.empty((0, total_bits), device=device),
                    torch.empty((0,), dtype=torch.bool, device=device))

        x_f32 = x_input_flat.reshape(-1).to(torch.float32)
        batch_size = x_f32.numel()
        f_x_n = torch.zeros((batch_size, total_bits), dtype=torch.float32, device=device)
        representable = torch.zeros(batch_size, dtype=torch.bool, device=device)

        candidate_idx = torch.where(torch.isfinite(x_f32) & (x_f32 != 0.0))[0]
        if candidate_idx.numel() == 0:
            return f_x_n, representable

        values = x_f32[candidate_idx]
        magnitudes = values.abs()
        # Sign column: -x exceeds the 0.0 threshold exactly when x is negative.
        f_x_n[candidate_idx, 0] = -values

        # frexp yields |x| = m * 2**e with m in [0.5, 1) exactly, whereas
        # floor(log2(|x|)) rounds up for values just below a power of two.
        frexp_mantissa, frexp_exponent = torch.frexp(magnitudes)
        coded_exp = frexp_exponent.to(torch.int64) - 1 + bias

        is_normal = (coded_exp >= float_params["min_norm_coded_exp"]) & (coded_exp <= float_params["max_norm_coded_exp"])
        if float_params.get("has_subnormals", False):
            # Every magnitude below the smallest normal is subnormal, not only
            # those whose unclamped coded exponent happens to be exactly 0.
            is_subnormal = coded_exp <= 0
        else:
            is_subnormal = torch.zeros_like(is_normal)
        encodable = is_normal | is_subnormal

        rows = candidate_idx[encodable]
        representable[rows] = True
        if rows.numel() == 0:
            return f_x_n, representable

        normal_sel = is_normal[encodable]
        coded_sel = coded_exp[encodable]
        # Subnormals are stored with an all-zero exponent field.
        exp_field = torch.where(normal_sel, coded_sel, torch.zeros_like(coded_sel))
        exp_shifts = torch.arange(exp_bits - 1, -1, -1, device=device)  # MSB first
        f_x_n[rows, 1:1 + exp_bits] = ((exp_field.unsqueeze(1) >> exp_shifts) & 1).to(torch.float32)

        if man_bits > 0:
            magnitude_f64 = magnitudes[encodable].to(torch.float64)
            # Normal: significand 2*m lies in [1, 2), the stored fraction is 2*m - 1.
            normal_frac = 2.0 * frexp_mantissa[encodable].to(torch.float64) - 1.0
            # Subnormal: fraction is |x| / 2**(1 - bias).
            subnormal_frac = magnitude_f64 * (2.0 ** (bias - 1))
            fraction = torch.where(normal_sel, normal_frac, subnormal_frac)
            mantissa_int = torch.floor(fraction * (2.0 ** man_bits)).to(torch.int64)
            mantissa_int = mantissa_int.clamp(min=0, max=(1 << man_bits) - 1)
            man_shifts = torch.arange(man_bits - 1, -1, -1, device=device)  # MSB first
            man_start = 1 + exp_bits
            f_x_n[rows, man_start:man_start + man_bits] = ((mantissa_int.unsqueeze(1) >> man_shifts) & 1).to(torch.float32)

        return f_x_n, representable

    def forward(self, x_input_tensor: torch.Tensor) -> torch.Tensor:
        """Encode *x_input_tensor* into an int8 bit tensor.

        Args:
            x_input_tensor: tensor whose dtype has an entry in
                ``FLOAT_FORMAT_PARAMS``.

        Returns:
            int8 tensor of shape ``(total_bits, *x_input_tensor.shape)`` with
            the sign bit first, then exponent bits, then mantissa bits.

        Raises:
            ValueError: if the input dtype is not supported.
        """
        original_shape = tuple(x_input_tensor.shape)
        device = x_input_tensor.device
        float_params = self.get_float_format_params(x_input_tensor.dtype)
        n_bits = float_params["total_bits"]
        if x_input_tensor.numel() == 0:
            return torch.empty((n_bits,) + original_shape, dtype=torch.int8, device=device)

        flat_values = x_input_tensor.contiguous().view(-1)
        self.if_neurons.reset()
        f_x_n_batch, representable_mask = self._calculate_f_x_n_batch(flat_values, float_params)
        self.if_neurons.v_threshold = self._get_g_n_tensor_for_dtype(float_params, device)
        bits = self.if_neurons(f_x_n_batch).to(torch.int8)  # (num_elements, n_bits)

        special = ~representable_mask
        if torch.any(special):
            flat_f32 = flat_values.to(torch.float32)
            is_zero = special & (flat_f32 == 0.0)
            is_inf = special & torch.isinf(flat_f32)
            # NaN inputs and finite values outside the table's exponent range.
            is_nan_coded = special & ~is_zero & ~is_inf
            bits[special] = 0
            if torch.any(is_inf):
                bits[is_inf] = self.get_float_special_pattern(float_params, device, pattern_type="inf", sign_bit=0)
            if torch.any(is_nan_coded):
                bits[is_nan_coded] = self.get_float_special_pattern(float_params, device, pattern_type="nan", sign_bit=0)
            # The sign bit of every special value follows the input's sign bit
            # (this is what distinguishes -0.0 from 0.0 and -inf from +inf).
            bits[special, 0] = torch.signbit(flat_f32[special]).to(torch.int8)

        return bits.t().reshape(n_bits, *original_shape)

class Soma(nn.Module): 
    def __init__(self):
        super().__init__()
        self.FLOAT_FORMAT_PARAMS = {
             torch.float8_e4m3fn: {"name": "E4M3FN", "exp_bits": 4, "man_bits": 3, "bias": 7, "min_norm_coded_exp": 1, "max_norm_coded_exp": 14, "total_bits": 8, "has_subnormals": False},
             torch.float8_e5m2: {"name": "E5M2", "exp_bits": 5, "man_bits": 2, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 8, "has_subnormals": False},
             torch.bfloat16: {"name": "BF16", "exp_bits": 8, "man_bits": 7, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 16, "has_subnormals": True},
             torch.float16: {"name": "FP16", "exp_bits": 5, "man_bits": 10, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 16, "has_subnormals": True},
             torch.float32: {"name": "FP32", "exp_bits": 8, "man_bits": 23, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 32, "has_subnormals": True}
        }
    def get_float_format_params(self, dtype: torch.dtype) -> dict:
        params = self.FLOAT_FORMAT_PARAMS.get(dtype)
        if params is None: raise ValueError(f"Soma: Unsupported float dtype: {dtype}")
        return params.copy()
    def get_float_special_pattern(self, float_params: dict, device: torch.device, pattern_type: str = "nan", sign_bit: int = 0) -> torch.Tensor:
        if sign_bit not in [0, 1]: raise ValueError("Sign bit must be 0 or 1.")
        bits = [sign_bit]; bits.extend([1] * float_params["exp_bits"]) 
        if pattern_type.lower() == "nan":
            if float_params["man_bits"] > 0: bits.extend([1] + [0] * (float_params["man_bits"] - 1))
        elif pattern_type.lower() == "inf": bits.extend([0] * float_params["man_bits"])
        else: raise ValueError(f"Unknown special pattern type: {pattern_type}.")
        if len(bits) != float_params["total_bits"]: bits = (bits + [0]*float_params["total_bits"])[:float_params["total_bits"]]
        return torch.tensor(bits, dtype=torch.int8, device=device)
    def forward(self, bit_tensor: torch.Tensor, target_dtype: torch.dtype, device: str = "cpu") -> torch.Tensor:
        output_device = torch.device(device); float_params = self.get_float_format_params(target_dtype)
        n_bits = float_params["total_bits"]
        if bit_tensor.shape[0] != n_bits: raise ValueError(f"Input bit_tensor's first dim must be {n_bits}, got {bit_tensor.shape[0]}")
        original_numerical_shape = bit_tensor.shape[1:]
        bit_tensor_flat_cols = bit_tensor.reshape(n_bits, -1); num_values = bit_tensor_flat_cols.shape[1]
        if num_values == 0: return torch.empty(original_numerical_shape, dtype=target_dtype, device=output_device)
        reconstructed_floats_list = []
        nan_pattern_s0 = self.get_float_special_pattern(float_params, bit_tensor.device, "nan", 0)
        nan_pattern_s1 = self.get_float_special_pattern(float_params, bit_tensor.device, "nan", 1)
        pos_inf_pattern = self.get_float_special_pattern(float_params, bit_tensor.device, "inf", 0)
        neg_inf_pattern = self.get_float_special_pattern(float_params, bit_tensor.device, "inf", 1)
        pos_zero_pattern = torch.zeros(n_bits, dtype=torch.int8, device=bit_tensor.device)
        neg_zero_pattern = pos_zero_pattern.clone(); 
        if n_bits > 0: neg_zero_pattern[0] = 1
        exp_all_ones_coded = (1 << float_params["exp_bits"]) - 1
        for i in range(num_values):
            bits_for_val_i_tensor = bit_tensor_flat_cols[:, i].int(); val = 0.0
            if torch.equal(bits_for_val_i_tensor, nan_pattern_s0) or torch.equal(bits_for_val_i_tensor, nan_pattern_s1): val = float('nan')
            elif torch.equal(bits_for_val_i_tensor, pos_inf_pattern): val = float('inf')
            elif torch.equal(bits_for_val_i_tensor, neg_inf_pattern): val = float('-inf')
            elif torch.equal(bits_for_val_i_tensor, pos_zero_pattern): val = 0.0
            elif torch.equal(bits_for_val_i_tensor, neg_zero_pattern): val = -0.0
            else:
                binary_str = "".join(map(str, bits_for_val_i_tensor.tolist())); sign_val = -1.0 if binary_str[0] == '1' else 1.0
                exp_start_idx = 1; exp_end_idx = 1 + float_params["exp_bits"]; coded_exp_str = binary_str[exp_start_idx:exp_end_idx]
                try: coded_exp_val = int(coded_exp_str, 2)
                except ValueError: val = float('nan')
                else:
                    man_start_idx = exp_end_idx; mantissa_str = binary_str[man_start_idx:]; mantissa_val_frac = 0.0
                    for idx, bit_char in enumerate(mantissa_str):
                        if bit_char == '1': mantissa_val_frac += 2**(-(idx + 1))
                    if coded_exp_val == 0: 
                        if float_params.get("has_subnormals", False) and mantissa_val_frac != 0.0 : 
                            min_normal_exp_actual = 1 - float_params["bias"] 
                            val = sign_val * mantissa_val_frac * (2**min_normal_exp_actual)
                        elif mantissa_val_frac == 0.0: val = sign_val * 0.0 
                        else: val = sign_val * 0.0 
                    elif float_params["min_norm_coded_exp"] <= coded_exp_val <= float_params["max_norm_coded_exp"]: 
                        actual_exp = coded_exp_val - float_params["bias"]
                        val = sign_val * (1.0 + mantissa_val_frac) * (2**actual_exp)
                    elif coded_exp_val == exp_all_ones_coded and mantissa_val_frac != 0.0 : val = float('nan')
                    elif coded_exp_val == exp_all_ones_coded and mantissa_val_frac == 0.0 : val = float('inf') if sign_val == 1.0 else float('-inf') # Should be caught by pattern
                    else: val = float('nan') 
            try: reconstructed_floats_list.append(val)
            except OverflowError: reconstructed_floats_list.append(float('inf') if val > 0 else float('-inf'))
        f32_tensor_flat = torch.tensor(reconstructed_floats_list, dtype=torch.float32, device=output_device)
        output_tensor_flat = f32_tensor_flat.to(target_dtype)
        return output_tensor_flat.view(original_numerical_shape)

# --- SNNLayerBase and SNN Layers (Simplified to use new Synapse/Soma) ---
class SNNLayerBase(nn.Module):
    """Common state for the SNN layer wrappers.

    Holds the numerical dtype, the target device, a ``Soma`` (bits -> values)
    and a ``Synapse`` (values -> bits), plus the bit width of the dtype as
    reported by the Soma's format table.
    """

    def __init__(self, float_dtype: torch.dtype, device_str: str = "cpu"):
        super().__init__()
        self.float_dtype = float_dtype
        self.device = torch.device(device_str)
        self.soma = Soma()
        self.synapse = Synapse()
        # Looking the dtype up here also rejects unsupported dtypes early.
        self.num_bits_in_float = self.soma.get_float_format_params(float_dtype)["total_bits"]

class LinearLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around a standard ``torch.nn.Linear`` layer."""

    def __init__(self, pytorch_linear_layer: torch.nn.Linear, float_dtype: torch.dtype, device: str = "cpu"):
        super().__init__(float_dtype, device)
        self.pytorch_layer = pytorch_linear_layer.to(device=self.device)
        self.in_features = pytorch_linear_layer.in_features
        self.out_features = pytorch_linear_layer.out_features

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Decode the bits, apply the linear layer, and re-encode the result.

        Args:
            x_bit_tensor_input: bit tensor of shape
                ``(num_bits, *batch_dims, in_features)``.

        Returns:
            Bit tensor of shape ``(num_bits, *batch_dims, out_features)``.

        Raises:
            ValueError: if the first dimension is not the bit width of the
                layer's dtype or the last dimension is not ``in_features``.
        """
        bit_shape = x_bit_tensor_input.shape
        if bit_shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")
        if bit_shape[-1] != self.in_features:
            raise ValueError(f"Input bit tensor's last effective numerical feature dimension "
                             f"({x_bit_tensor_input.shape[-1]}) must match layer in_features ({self.in_features}).")

        # Bits -> numbers, shape (*batch_dims, in_features).
        decoded = self.soma.forward(x_bit_tensor_input, self.float_dtype, str(self.device))

        # nn.Linear runs on a 2-D view in the dtype of its own weights.
        rows = decoded.reshape(-1, self.in_features)
        linear_out = self.pytorch_layer(rows.to(self.pytorch_layer.weight.dtype))

        # Quantize back to the layer dtype and restore the leading batch dims.
        batch_dims = tuple(bit_shape[1:-1])
        quantized = linear_out.to(self.float_dtype).view(batch_dims + (self.out_features,))

        # Numbers -> bits, shape (num_bits, *batch_dims, out_features).
        return self.synapse.forward(quantized)

class ConvLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around a standard ``torch.nn.Conv2d`` layer."""

    def __init__(self, pytorch_conv_layer: torch.nn.Conv2d, float_dtype: torch.dtype, device: str = "cpu"):
        super().__init__(float_dtype, device)
        self.pytorch_layer = pytorch_conv_layer.to(device=self.device)
        # The output channel count is not stored; the Synapse takes the shape
        # of whatever the convolution returns.
        self.in_channels = pytorch_conv_layer.in_channels

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Decode the bits, apply the convolution, and re-encode the result.

        Args:
            x_bit_tensor_input: bit tensor laid out as
                ``(num_bits, N, C_in, H_in, W_in)``.

        Returns:
            Bit tensor of shape ``(num_bits, *conv_output_shape)``.

        Raises:
            ValueError: if dim 0 is not the bit width of the layer's dtype or
                dim 2 is not ``in_channels``.
        """
        bit_shape = x_bit_tensor_input.shape
        if bit_shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")
        if bit_shape[2] != self.in_channels:  # channel axis sits right after bits and batch
            raise ValueError("Input bit tensor's C_in dim (index 2) must match layer in_channels.")

        # Bits -> numbers, shape (N, C_in, H_in, W_in).
        decoded = self.soma.forward(x_bit_tensor_input, self.float_dtype, str(self.device))

        # Run the wrapped convolution in the dtype of its weights.
        conv_out = self.pytorch_layer(decoded.to(self.pytorch_layer.weight.dtype))

        # Quantize to the layer dtype, then numbers -> bits.
        return self.synapse.forward(conv_out.to(self.float_dtype))

class ActivationLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around an arbitrary activation module."""

    def __init__(self, activation_module_instance: torch.nn.Module,
                 float_dtype: torch.dtype, device: str = "cpu"):
        super().__init__(float_dtype, device)
        self.activation_module = activation_module_instance.to(self.device)

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Decode the bits, apply the activation in float32, and re-encode.

        Args:
            x_bit_tensor_input: bit tensor of shape ``(num_bits, *numerical_dims)``.

        Returns:
            Bit tensor of shape ``(num_bits, *activation_output_shape)``.

        Raises:
            ValueError: if dim 0 is not the bit width of the layer's dtype.
        """
        if x_bit_tensor_input.shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")

        # Bits -> numbers, shape (*numerical_dims).
        decoded = self.soma.forward(x_bit_tensor_input, self.float_dtype, str(self.device))

        # The activation has no weight dtype to follow, so it is evaluated in
        # float32 and the result is quantized back to the layer dtype.
        activated = self.activation_module(decoded.to(torch.float32))
        quantized = activated.to(self.float_dtype)

        # Numbers -> bits.
        return self.synapse.forward(quantized)

class EmbeddingLayerSNN(SNNLayerBase):
    """Wrapper around ``torch.nn.Embedding`` that emits a bit tensor.

    Unlike the other wrappers the input is a tensor of integer indices, not a
    bit tensor; only the output goes through the Synapse.
    """

    def __init__(self, pytorch_embedding_layer: torch.nn.Embedding,
                 float_dtype: torch.dtype, device: str = "cpu"):
        super().__init__(float_dtype, device)
        layer = pytorch_embedding_layer.to(device=self.device)
        # Cast the embedding table to the layer dtype so lookups already come
        # out in that precision. This rewrites the weights of the module that
        # was passed in.
        layer.weight.data = layer.weight.data.to(float_dtype)
        self.pytorch_layer = layer
        self.embedding_dim = pytorch_embedding_layer.embedding_dim

    def forward(self, x_indices: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for *x_indices* and encode them as bits.

        Args:
            x_indices: integer index tensor; a warning is printed for any
                dtype other than int32 / int64.

        Returns:
            Bit tensor of shape ``(num_bits, *x_indices.shape, embedding_dim)``.
        """
        expected_index_dtypes = (torch.int32, torch.int64)
        if x_indices.dtype not in expected_index_dtypes:
            print(f"Warning: EmbeddingLayerSNN input x_indices type {x_indices.dtype}, typical is Long or Int.")

        # The lookup result has the dtype of the (already cast) weights.
        looked_up = self.pytorch_layer(x_indices)

        # Numbers -> bits.
        return self.synapse.forward(looked_up)
    
if __name__ == "__main__":
    print(f"PyTorch version: {torch.__version__}")
    
    # 检查设备和浮点类型支持
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"Using device: {device}")
    
    # 确定可测试的浮点类型
    float_dtypes_to_test = []
    if hasattr(torch, 'bfloat16'): float_dtypes_to_test.append(torch.bfloat16)
    if hasattr(torch, 'float16'): float_dtypes_to_test.append(torch.float16)
    if hasattr(torch, 'float32'): float_dtypes_to_test.append(torch.float32)
    
    if not float_dtypes_to_test:
        print("没有可测试的浮点类型 (BF16, FP16, FP32)。测试终止。")
        exit()

    # 初始化测试组件
    test_data_synapse = Synapse() 
    test_data_soma = Soma()      

    # 定义不同精度的容差
    TOLERANCES = {
        torch.float32: {"atol": 1e-6, "rtol": 1e-5}, 
        torch.bfloat16: {"atol": 3e-2, "rtol": 3e-1},
        torch.float16: {"atol": 1e-3, "rtol": 1e-2},
    }
    DEFAULT_TOLERANCE = {"atol": 1e-5, "rtol": 1e-3}
    overall_summary_passed = True 

    # --- Test LinearLayerSNN ---
    print("\n--- 开始测试 LinearLayerSNN (包装标准 nn.Linear) ---")
    for float_dtype in float_dtypes_to_test:
        float_params_l = test_data_soma.get_float_format_params(float_dtype) 
        N_BITS_L = float_params_l["total_bits"]
        current_tolerances = TOLERANCES.get(float_dtype, DEFAULT_TOLERANCE)
        print(f"  --- 测试格式: {float_params_l['name']} (dtype: {float_dtype}, Bits: {N_BITS_L}, atol={current_tolerances['atol']:.1e}, rtol={current_tolerances['rtol']:.1e}) ---")

        in_features_test, out_features_test = 4, 2
        pytorch_linear_orig = torch.nn.Linear(in_features_test, out_features_test, device=device)
        with torch.no_grad():
            pytorch_linear_orig.weight.data.uniform_(-0.5, 0.5) 
            if pytorch_linear_orig.bias is not None:
                pytorch_linear_orig.bias.data.uniform_(-0.5, 0.5)
        snn_linear_layer = LinearLayerSNN(pytorch_linear_orig, float_dtype, device_name)

        for conceptual_shape_no_bits in [(in_features_test,), (3,in_features_test), (2,1,in_features_test)]:
            # 准备SNN层的比特输入
            source_values_tensor = (torch.randn(conceptual_shape_no_bits, device=device)*2).to(float_dtype)
            layer_input_bit_tensor = test_data_synapse.forward(source_values_tensor)
            
            print(f"    测试 LinearLayerSNN - 输入数值形状 {conceptual_shape_no_bits} (比特输入形状: {layer_input_bit_tensor.shape})")
            
            output_bit_tensor = snn_linear_layer.forward(layer_input_bit_tensor)
            reconstructed_snn_output_numerical = test_data_soma.forward(output_bit_tensor, float_dtype, device_name)
            
            # 预期输出
            x_for_pytorch_linear = source_values_tensor.reshape(-1, in_features_test) 
            y_expected_f32 = pytorch_linear_orig(x_for_pytorch_linear.to(torch.float32))
            y_expected_quantized = y_expected_f32.to(float_dtype)
            expected_numerical_output_shape = conceptual_shape_no_bits[:-1] + (out_features_test,)
            y_expected_structured = y_expected_quantized.view(expected_numerical_output_shape)

            # 比较结果
            if reconstructed_snn_output_numerical.shape == y_expected_structured.shape:
                num_elements = y_expected_structured.numel()
                if num_elements > 0:
                    expected_is_nan = torch.isnan(y_expected_structured)
                    recon_is_nan = torch.isnan(reconstructed_snn_output_numerical)
                    nan_match = (expected_is_nan == recon_is_nan)
                    finite_comparison = torch.isclose(
                        reconstructed_snn_output_numerical.to(torch.float32), 
                        y_expected_structured.to(torch.float32),
                        atol=current_tolerances["atol"], 
                        rtol=current_tolerances["rtol"], 
                        equal_nan=False
                    )
                    correct_elements = (finite_comparison & ~expected_is_nan & ~recon_is_nan) | nan_match
                    num_correct = torch.sum(correct_elements).item()
                    accuracy = (num_correct / num_elements) * 100
                else:
                    accuracy = 100.0
                
                test_desc = f"LinearLayerSNN - {float_params_l['name']} - 输入 {conceptual_shape_no_bits}"
                if accuracy >= 99.999: 
                    print(f"      {test_desc}: 无误差或在容差内 (准确率: {accuracy:.2f}%)")
                else:
                    overall_summary_passed = False
                    print(f"    {test_desc}: 比较结果: 总元素: {num_elements}, 匹配数: {num_correct}, 准确率: {accuracy:.2f}%")
            else: 
                overall_summary_passed = False
                print(f"    错误 (LinearLayerSNN): 输出形状 ({reconstructed_snn_output_numerical.shape}) 与预期 ({y_expected_structured.shape}) 不匹配")
            print("-" * 30)
        print("-" * 40)

    # --- Test ConvLayerSNN ---
    # For every float format under test: build a small nn.Conv2d with random
    # weights, hand it to ConvLayerSNN, run the converted input through the SNN
    # layer, convert the output back to numbers and count how many elements agree
    # with the PyTorch reference within the per-dtype tolerances.
    print("\n--- 开始测试 ConvLayerSNN (包装标准 nn.Conv2d) ---")
    for float_dtype in float_dtypes_to_test:
        float_params_c = test_data_synapse.get_float_format_params(float_dtype)
        N_BITS_C = float_params_c["total_bits"]
        current_tolerances = TOLERANCES.get(float_dtype, DEFAULT_TOLERANCE)
        print(f"  --- 测试格式: {float_params_c['name']} (dtype: {float_dtype}, Bits: {N_BITS_C}, atol={current_tolerances['atol']:.1e}, rtol={current_tolerances['rtol']:.1e}) ---")

        batch_test_conv, cin_test_conv, cout_test_conv = 2, 3, 2
        h_test, w_test, k_test = 5, 5, 3
        # padding = k // 2 keeps the spatial size unchanged for the odd kernel used here.
        pytorch_conv_orig = torch.nn.Conv2d(cin_test_conv, cout_test_conv, kernel_size=k_test, padding=(k_test // 2), device=device)

        # Small uniform weights/bias so the reference outputs stay in a modest range.
        with torch.no_grad():
            pytorch_conv_orig.weight.data.uniform_(-0.2, 0.2)
            if pytorch_conv_orig.bias is not None:
                pytorch_conv_orig.bias.data.uniform_(-0.2, 0.2)

        snn_conv_layer = ConvLayerSNN(pytorch_conv_orig, float_dtype, device_name)

        conceptual_numerical_shape_conv = (batch_test_conv, cin_test_conv, h_test, w_test)
        print(f"    测试 ConvLayerSNN 输入数值形状 {conceptual_numerical_shape_conv}")

        source_values_tensor = (torch.randn(conceptual_numerical_shape_conv, device=device) * 0.5).to(float_dtype)
        # presumably converts numeric values into their bit representation -- confirm
        # against the definition of test_data_synapse.
        layer_input_bit_tensor = test_data_synapse.forward(source_values_tensor)

        output_bit_tensor = snn_conv_layer.forward(layer_input_bit_tensor)
        reconstructed_snn_output_numerical = test_data_soma.forward(output_bit_tensor, float_dtype, device_name)

        # Reference result: run the original conv in float32, then cast to the tested dtype.
        y_expected_f32 = pytorch_conv_orig(source_values_tensor.to(torch.float32))
        y_expected_quantized = y_expected_f32.to(float_dtype)

        if reconstructed_snn_output_numerical.shape == y_expected_quantized.shape:
            num_elements = y_expected_quantized.numel()
            if num_elements > 0:
                # An element is correct when both values agree within tolerance, or
                # when both are NaN (equal_nan=True). Comparing the two NaN masks for
                # equality instead would also flag every finite pair as a match
                # (False == False) and make the accuracy figure meaningless.
                correct_elements = torch.isclose(
                    reconstructed_snn_output_numerical.to(torch.float32),
                    y_expected_quantized.to(torch.float32),
                    atol=current_tolerances["atol"],
                    rtol=current_tolerances["rtol"],
                    equal_nan=True,
                )
                num_correct = torch.sum(correct_elements).item()
                accuracy = (num_correct / num_elements) * 100
            else:
                # Nothing to compare: treat an empty output as trivially matching.
                num_correct = 0
                accuracy = 100.0

            test_desc = f"ConvLayerSNN - {float_params_c['name']} - 输入 {conceptual_numerical_shape_conv}"
            if accuracy >= 99.999:
                print(f"      {test_desc}: 无误差或在容差内 (准确率: {accuracy:.2f}%)")
            else:
                overall_summary_passed = False
                print(f"    {test_desc}: 比较结果: 总元素: {num_elements}, 匹配数: {num_correct}, 准确率: {accuracy:.2f}%")
        else:
            # A shape mismatch is reported as a failure without any element comparison.
            overall_summary_passed = False
            print(f"    错误 (ConvLayerSNN): 输出形状 ({reconstructed_snn_output_numerical.shape}) 与预期 ({y_expected_quantized.shape}) 不匹配")
        print("-" * 40)
    print("\n")

    # --- Test ActivationLayerSNN ---
    # For every float format and every activation module listed below: wrap the
    # module in ActivationLayerSNN, run the converted input through it, convert the
    # output back to numbers and count how many elements agree with the result of
    # applying the plain PyTorch module, within the per-dtype tolerances.
    print("\n--- 开始测试 ActivationLayerSNN ---")
    activation_modules_to_test = {
        "ReLU": torch.nn.ReLU(),
        "Sigmoid": torch.nn.Sigmoid(),
        "SiLU": torch.nn.SiLU(),
        "Tanh": torch.nn.Tanh(),
    }

    for float_dtype in float_dtypes_to_test:
        float_params_a = test_data_soma.get_float_format_params(float_dtype)
        N_BITS_A = float_params_a["total_bits"]
        current_tolerances_act = TOLERANCES.get(float_dtype, DEFAULT_TOLERANCE)
        print(f"  --- 测试格式: {float_params_a['name']} (dtype: {float_dtype}, Bits: {N_BITS_A}, atol={current_tolerances_act['atol']:.1e}, rtol={current_tolerances_act['rtol']:.1e}) ---")

        for act_name, act_module_instance in activation_modules_to_test.items():
            snn_activation_layer = ActivationLayerSNN(act_module_instance, float_dtype, device_name)
            features_act_test = 6
            batch_act_test = 2
            conceptual_numerical_shape_act = (batch_act_test, features_act_test)

            print(f"    测试 ActivationLayerSNN ({act_name}) 输入数值形状 {conceptual_numerical_shape_act}")
            # Scaled and shifted normal samples, so inputs cover both signs.
            source_values_tensor_act = (torch.randn(conceptual_numerical_shape_act, device=device) * 3 - 1).to(float_dtype)
            layer_input_bit_tensor_act = test_data_synapse.forward(source_values_tensor_act)

            output_bit_tensor_act = snn_activation_layer.forward(layer_input_bit_tensor_act)
            # The layer's output is required to have the same shape as its input;
            # otherwise this case is marked failed and skipped.
            if output_bit_tensor_act.shape != layer_input_bit_tensor_act.shape:
                overall_summary_passed = False
                print(f"      错误: ActivationLayerSNN ({act_name}) - 输出比特形状 {output_bit_tensor_act.shape} 与输入比特形状 {layer_input_bit_tensor_act.shape} 不符")
                continue

            reconstructed_snn_output_numerical_act = test_data_soma.forward(output_bit_tensor_act, float_dtype, device_name)
            # Reference result: apply the module in float32, then cast to the tested dtype.
            y_expected_f32_act = act_module_instance(source_values_tensor_act.to(torch.float32))
            y_expected_quantized_act = y_expected_f32_act.to(float_dtype)

            if reconstructed_snn_output_numerical_act.shape == y_expected_quantized_act.shape:
                num_elements_act = y_expected_quantized_act.numel()
                if num_elements_act > 0:
                    # An element is correct when both values agree within tolerance,
                    # or when both are NaN (equal_nan=True). Comparing the two NaN
                    # masks for equality instead would also flag every finite pair as
                    # a match (False == False) and make the accuracy meaningless.
                    correct_elements = torch.isclose(
                        reconstructed_snn_output_numerical_act.to(torch.float32),
                        y_expected_quantized_act.to(torch.float32),
                        atol=current_tolerances_act["atol"],
                        rtol=current_tolerances_act["rtol"],
                        equal_nan=True,
                    )
                    num_correct_act = torch.sum(correct_elements).item()
                    accuracy_act = (num_correct_act / num_elements_act) * 100
                else:
                    # Nothing to compare: treat an empty output as trivially matching.
                    num_correct_act = 0
                    accuracy_act = 100.0

                test_desc_act = f"ActivationLayerSNN ({act_name}) - {float_params_a['name']} - 输入 {conceptual_numerical_shape_act}"
                if accuracy_act >= 99.999:
                    print(f"      {test_desc_act}: 无误差或在容差内 (准确率: {accuracy_act:.2f}%)")
                else:
                    overall_summary_passed = False
                    print(f"    {test_desc_act}: 比较结果: 总元素: {num_elements_act}, 匹配数: {num_correct_act}, 准确率: {accuracy_act:.2f}%")
            else:
                # A shape mismatch is reported as a failure without any element comparison.
                overall_summary_passed = False
                print(f"    错误 (ActivationLayerSNN - {act_name}): 输出形状 ({reconstructed_snn_output_numerical_act.shape}) 与预期 ({y_expected_quantized_act.shape}) 不匹配")
            print("-" * 30)
        print("-" * 40)

    # --- Test EmbeddingLayerSNN ---
    # For every float format: build a small nn.Embedding with random weights, hand
    # it to EmbeddingLayerSNN, look up several index tensors of different shapes,
    # convert the SNN output back to numbers and count how many elements agree with
    # the plain PyTorch lookup within the per-dtype tolerances.
    print("\n--- 开始测试 EmbeddingLayerSNN (包装标准 nn.Embedding) ---")
    for float_dtype in float_dtypes_to_test:
        float_params_e = test_data_soma.get_float_format_params(float_dtype)
        N_BITS_E = float_params_e["total_bits"]
        current_tolerances_emb = TOLERANCES.get(float_dtype, DEFAULT_TOLERANCE)
        print(f"  --- 测试格式: {float_params_e['name']} (dtype: {float_dtype}, Bits: {N_BITS_E}, atol={current_tolerances_emb['atol']:.1e}, rtol={current_tolerances_emb['rtol']:.1e}) ---")

        num_embeddings_test, embedding_dim_test = 10, 4
        pytorch_embedding = torch.nn.Embedding(num_embeddings_test, embedding_dim_test, device=device)
        with torch.no_grad():
            pytorch_embedding.weight.data.uniform_(-0.1, 0.1)

        snn_embedding_layer = EmbeddingLayerSNN(pytorch_embedding, float_dtype, device_name)

        # 1-D, 2-D and "as many indices as table rows" index shapes.
        for input_indices_shape in [(5,), (2, 3), (num_embeddings_test,)]:
            print(f"    测试 EmbeddingLayerSNN 输入索引形状 {input_indices_shape}")
            x_indices_input = torch.randint(0, num_embeddings_test, input_indices_shape, device=device, dtype=torch.long)
            # The embedding layer is fed the integer indices directly.
            output_bit_tensor = snn_embedding_layer.forward(x_indices_input)
            reconstructed_snn_output_numerical = test_data_soma.forward(output_bit_tensor, float_dtype, device_name)
            # Reference result: plain PyTorch lookup, cast to the tested dtype.
            y_expected_numerical = pytorch_embedding(x_indices_input)
            y_expected_quantized = y_expected_numerical.to(float_dtype)

            if reconstructed_snn_output_numerical.shape == y_expected_quantized.shape:
                num_elements = y_expected_quantized.numel()
                if num_elements > 0:
                    # An element is correct when both values agree within tolerance,
                    # or when both are NaN (equal_nan=True). Comparing the two NaN
                    # masks for equality instead would also flag every finite pair as
                    # a match (False == False) and make the accuracy meaningless.
                    correct_elements = torch.isclose(
                        reconstructed_snn_output_numerical.to(torch.float32),
                        y_expected_quantized.to(torch.float32),
                        atol=current_tolerances_emb["atol"],
                        rtol=current_tolerances_emb["rtol"],
                        equal_nan=True,
                    )
                    num_correct = torch.sum(correct_elements).item()
                    accuracy = (num_correct / num_elements) * 100
                else:
                    # Nothing to compare: treat an empty output as trivially matching.
                    num_correct = 0
                    accuracy = 100.0

                test_desc_emb = f"EmbeddingLayerSNN - {float_params_e['name']} - 输入索引形状 {input_indices_shape}"
                if accuracy >= 99.999:
                    print(f"      {test_desc_emb}: 无误差或在容差内 (准确率: {accuracy:.2f}%)")
                else:
                    overall_summary_passed = False
                    print(f"    {test_desc_emb}: 比较结果: 总元素: {num_elements}, 匹配数: {num_correct}, 准确率: {accuracy:.2f}%")
            else:
                # A shape mismatch is reported as a failure without any element comparison.
                overall_summary_passed = False
                print(f"    错误 (EmbeddingLayerSNN): 输出形状 ({reconstructed_snn_output_numerical.shape}) 与预期 ({y_expected_quantized.shape}) 不匹配")
            print("-" * 30)
        print("-" * 40)
    print("\n")

    # Final verdict: one line stating whether any visible test section cleared the
    # overall pass flag, followed by the completion marker.
    summary_message = (
        "所有已执行的SNN层测试均通过（或在容差范围内）。"
        if overall_summary_passed
        else "部分SNN层测试存在不匹配或较大差异，请检查上述详细输出以分析原因。"
    )
    print(summary_message)

    print("所有层测试完成。")