import math
import torch
import torch.nn.functional as F
import torch.nn as nn

class Soma(nn.Module): 
    def __init__(self):
        super().__init__()
        self.FLOAT_FORMAT_PARAMS = {
             torch.float8_e4m3fn: {"name": "E4M3FN", "exp_bits": 4, "man_bits": 3, "bias": 7, "min_norm_coded_exp": 1, "max_norm_coded_exp": 14, "total_bits": 8, "has_subnormals": False},
             torch.float8_e5m2: {"name": "E5M2", "exp_bits": 5, "man_bits": 2, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 8, "has_subnormals": False},
             torch.bfloat16: {"name": "BF16", "exp_bits": 8, "man_bits": 7, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 16, "has_subnormals": True},
             torch.float16: {"name": "FP16", "exp_bits": 5, "man_bits": 10, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 16, "has_subnormals": True},
             torch.float32: {"name": "FP32", "exp_bits": 8, "man_bits": 23, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 32, "has_subnormals": True}
        }
        # 缓存乘数向量，避免重复计算
        self._multiplier_cache = {}

    def get_float_format_params(self, dtype: torch.dtype) -> dict:
        params = self.FLOAT_FORMAT_PARAMS.get(dtype)
        if params is None: raise ValueError(f"Soma: Unsupported float dtype: {dtype}")
        return params.copy()
        
    def _get_multipliers(self, exp_bits, man_bits, device, dtype):
        """获取指数和尾数的乘数向量，使用缓存提高性能"""
        cache_key = (exp_bits, man_bits, device, dtype)
        if cache_key in self._multiplier_cache:
            return self._multiplier_cache[cache_key]
            
        # 计算指数乘数
        exp_multiplier = torch.pow(2.0, torch.arange(exp_bits-1, -1, -1, 
                                                  device=device, dtype=dtype))
        # 计算尾数乘数
        man_multiplier = torch.pow(2.0, -torch.arange(1, man_bits+1, 
                                                   device=device, dtype=dtype))
                                                   
        result = (exp_multiplier, man_multiplier)
        # 缓存结果
        if len(self._multiplier_cache) < 10:  # 限制缓存大小
            self._multiplier_cache[cache_key] = result
            
        return result

    def forward(self, bit_tensor: torch.Tensor, target_dtype: torch.dtype, device = None) -> torch.Tensor:
        """处理输入的位张量，转换为数值张量，确保梯度正确流动
        
        参数:
            bit_tensor: 输入位张量
            target_dtype: 目标数据类型
            device: 输出设备，可以是字符串或torch.device对象
        """
        # 获取输入张量的设备
        input_device = bit_tensor.device
        
        # 确定输出设备
        if device is None:
            output_device = input_device
        elif isinstance(device, str):
            output_device = torch.device(device)
        else:
            # 假设device已经是torch.device对象
            output_device = device
        
        # 获取数据类型参数
        float_params = self.get_float_format_params(target_dtype)
        n_bits = float_params["total_bits"]
        
        # 检查输入形状
        if bit_tensor.shape[0] != n_bits:
            raise ValueError(f"Input bit_tensor's first dim must be {n_bits}, got {bit_tensor.shape[0]}")
            
        # 保存原始形状，并重塑为2D张量以简化处理
        original_shape = bit_tensor.shape[1:]
        bit_tensor = bit_tensor.reshape(n_bits, -1)
        
        # 处理空张量情况
        if bit_tensor.shape[1] == 0:
            return torch.empty(original_shape, dtype=target_dtype, device=output_device)

        # 使用float32进行中间计算以保持精度，同时保持梯度流
        compute_dtype = torch.float32
        
        # 特别处理符号位 - 直接使用位0作为符号位
        sign_bits = bit_tensor[0].to(compute_dtype)  # 转为计算类型
            
        # 提取指数位和尾数位，保持梯度流
        exp_bits = bit_tensor[1:1+float_params["exp_bits"]].to(compute_dtype)
        man_bits = bit_tensor[1+float_params["exp_bits"]:].to(compute_dtype)
        
        # 获取指数和尾数乘数向量
        exp_multiplier, man_multiplier = self._get_multipliers(
            float_params["exp_bits"], 
            float_params["man_bits"], 
            input_device,
            compute_dtype
        )
        
        # 计算指数值，保持梯度流
        coded_exp = torch.matmul(exp_bits.t(), exp_multiplier.unsqueeze(1)).squeeze(-1)
        
        # 计算尾数值，保持梯度流
        mantissa = torch.matmul(man_bits.t(), man_multiplier.unsqueeze(1)).squeeze(-1)
        
        # 创建输出张量，使用计算数据类型
        output = torch.zeros(bit_tensor.shape[1], dtype=compute_dtype, device=input_device)
        
        # 处理特殊情况
        zero_mask = (coded_exp == 0) & (mantissa == 0)
        inf_mask = (coded_exp == ((1 << float_params["exp_bits"]) - 1)) & (mantissa == 0)
        nan_mask = (coded_exp == ((1 << float_params["exp_bits"]) - 1)) & (mantissa != 0)
        normal_mask = ~(zero_mask | inf_mask | nan_mask)
        
        # 特殊情况下的符号位处理 - 保留符号位的负零，确保梯度流动
        # 零值处理
        if torch.any(zero_mask):
            # 设置符号位为零，但保留符号位的梯度流
            # 这里使用零乘以符号位，而不是直接设零，以保持梯度流
            output[zero_mask] = 0.0 * sign_bits[zero_mask]
        
        # 设置其他特殊值
        output[inf_mask] = float('inf')
        output[nan_mask] = float('nan')
        
        # 处理正常数值，确保梯度流
        if torch.any(normal_mask):
            normal_indices = torch.nonzero(normal_mask).squeeze(-1)
            actual_exp = coded_exp[normal_indices] - float_params["bias"]
            computed_values = (1.0 + mantissa[normal_indices]) * torch.pow(2.0, actual_exp)
            output[normal_indices] = computed_values
            
        # 应用符号位，确保梯度流
        # 使用符号位值本身决定符号，而不是二次加工
        sign_multiplier = torch.where(sign_bits > 0, 
                                    torch.tensor(-1.0, device=input_device, dtype=compute_dtype), 
                                    torch.tensor(1.0, device=input_device, dtype=compute_dtype))
        output = output * sign_multiplier
        
        # 转换到目标数据类型和设备
        if output_device != input_device or compute_dtype != target_dtype:
            output = output.to(dtype=target_dtype, device=output_device)
        
        # 重塑回原始形状
        return output.reshape(original_shape)