from energy import init_energy, EngModule, wrap_model_with_eng
import torch
import torch.nn as nn
import numpy as np
from typing import Callable, Tuple, List, Any

class _ReLUCapture(nn.Module):
    """Drop-in ReLU that remembers the per-sample size of its latest output."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)
        # Per-sample neuron count of the most recent forward pass;
        # stays None until the module has been called at least once.
        self.num_neurons = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.relu(x)
        if out.dim() >= 2:
            # Drop the leading (batch) dimension and multiply the rest,
            # e.g. (N, C, H, W) -> C * H * W.
            self.num_neurons = int(
                np.prod(tuple(out.shape[1:]), dtype=np.int64)
            )
        else:
            # 0-d or (N,) output: counted as a single neuron per sample.
            self.num_neurons = 1
        return out


def _swap_relus_for_captures(module: nn.Module, swapped: list) -> None:
    """Recursively replace nn.ReLU children with _ReLUCapture, logging each swap."""
    # Snapshot the children first: setattr below edits the dict being iterated.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.ReLU):  # also matches subclasses of nn.ReLU
            setattr(module, name, _ReLUCapture(inplace=child.inplace))
            swapped.append((module, name, child))
        elif not isinstance(child, _ReLUCapture):
            # A wrapper can be met again when its parent is shared between
            # several places in the tree; never wrap a wrapper's inner ReLU.
            _swap_relus_for_captures(child, swapped)


def _collect_neuron_counts(module: nn.Module, counts: list) -> None:
    """Append captured neuron counts to *counts* in named_children() order."""
    for name, child in module.named_children():
        if isinstance(child, _ReLUCapture):
            if child.num_neurons is None:
                raise RuntimeError(
                    f"ReLUCapture layer '{name}' didn't record num_neurons. "
                    "Check that dummy_input passed through this layer."
                )
            counts.append(int(child.num_neurons))
        else:
            _collect_neuron_counts(child, counts)


def compute_MASFR_from_LASFR(
    LASFR: np.ndarray,
    model: nn.Module,
    dummy_input: "torch.Tensor | Tuple[torch.Tensor, ...] | List[torch.Tensor]",
    device,
) -> np.ndarray:
    """Compute the model-level average spike firing rate (MASFR) per timestep.

    Each row of ``LASFR`` is weighted by the number of neurons of the
    corresponding ReLU layer, where a layer's neuron count is the number of
    elements of its output for a single sample (output shape without the
    leading batch dimension, e.g. ``(N, C, H, W) -> C*H*W``). The counts are
    measured by running ``dummy_input`` through ``model`` once.

    Args:
        LASFR: Array-like of shape ``(K, T)``: one row per layer, one column
            per timestep.
        model: Network whose ``nn.ReLU`` submodules define the layers. Rows of
            ``LASFR`` are matched to ReLU modules in recursive
            ``named_children()`` (module definition) order, which is not
            necessarily the order in which they execute.
        dummy_input: A tensor, or a tuple/list of tensors that is unpacked as
            positional arguments to ``model``.
        device: ``torch.device`` or device string used for the forward pass.

    Returns:
        ``np.ndarray`` of shape ``(T,)``: the neuron-count-weighted mean of
        the used ``LASFR`` rows at every timestep.

    Raises:
        ValueError: If ``LASFR`` is not 2-D, if the number of ReLU layers
            ``L`` is neither ``K`` nor ``K - 1``, or if the neuron counts sum
            to zero (e.g. the model has no ReLU layers).
        RuntimeError: If a ReLU layer was never reached by ``dummy_input``.

    Notes:
        * If ``L == K - 1`` the last row of ``LASFR`` is ignored.
        * ``model`` is moved to ``device`` and switched to eval mode in place;
          both changes persist after the call.
        * ReLU modules are swapped for capturing wrappers only for the
          duration of the call; the original modules are always put back,
          even if the forward pass fails.
        * A ReLU module that is invoked several times in one forward pass
          contributes a single count (that of its last invocation).
    """
    # --- argument checks ---
    LASFR = np.asarray(LASFR)
    if LASFR.ndim != 2:
        raise ValueError(f"LASFR must be 2D (K, T). Got shape {LASFR.shape}.")
    K = LASFR.shape[0]

    model = model.to(device)
    model.eval()

    # --- measure per-layer neuron counts with temporary capture wrappers ---
    swapped = []  # (parent, attribute name, original ReLU) for every swap
    num_neurons = []
    try:
        _swap_relus_for_captures(model, swapped)
        with torch.no_grad():
            if isinstance(dummy_input, (tuple, list)):
                model(*(t.to(device) for t in dummy_input))
            else:
                model(dummy_input.to(device))
        _collect_neuron_counts(model, num_neurons)
    finally:
        # Undo the surgery so the caller's model keeps its original ReLUs.
        # Assigning to an existing child name keeps the child ordering.
        for parent, name, original in swapped:
            setattr(parent, name, original)

    L = len(num_neurons)

    # --- validate K vs L; a single surplus trailing row is tolerated ---
    if not (L == K or L == K - 1):
        raise ValueError(
            f"Number of captured ReLU layers L={L} must equal K={K} or K-1. "
            "If mismatch, check model/dummy_input or your LASFR input."
        )
    LASFR_used = LASFR[:L]  # drops the last row exactly when L == K - 1

    # --- weighted average per timestep ---
    weights = np.array(num_neurons, dtype=float)
    wsum = weights.sum()
    if wsum == 0:
        raise ValueError("Sum of neuron counts is zero.")
    weights = weights / wsum  # normalized, shape (L,)

    # MASFR[t] = sum_i weights[i] * LASFR_used[i, t]
    return (weights[:, None] * LASFR_used).sum(axis=0)
