"""
Time-Series 모델 에너지 측정 모듈

이 모듈은 Time-Series Forecasting 모델(iSpikeformer, iTransformer)의
에너지 소비량을 정확히 측정하기 위한 함수들을 제공합니다.

기존 hook 시스템(Linear, Conv, LIF, BN 등)으로 측정되지 않는
matmul과 einsum 연산을 수동으로 계산하여 완전한 에너지 측정을 수행합니다.

Copyright (C) 2025 - Adapted for Time-Series Neuromorphic Models
"""

import numpy as np


def get_ssa_matmul_ops(model, config):
    """
    iSpikeformer SSA (Spiking Self-Attention) matmul 연산량 계산

    SSA의 Q×K^T와 A×V matmul 연산은 torch.matmul로 구현되어 있어
    hook 시스템에서 측정되지 않습니다. 이 함수는 해당 연산량을
    수동으로 계산하고 발화율을 적용하여 AC operations를 산출합니다.

    Parameters
    ----------
    model : iSpikeformer
        측정할 모델 (hook 적용 완료 후)
        model.blocks[i].attn.q_lif.__syops__[3]에 발화율이 저장됨
    config : dict
        모델 설정
        - 'depth' (int): SSA 블록 수 (기본값: 2)
        - 'heads' (int): Attention head 수 (기본값: 8)
        - 'dim' (int): 모델 차원 (기본값: 64)
        - 'seq_len' (int): Sequence length (기본값: 96)
        - 'T' (int): SNN timesteps (기본값: 4)

    Returns
    -------
    dict
        {
            'total_ACs': float
                총 AC operations (Q×K^T + A×V 합계)
            'qk_ACs': float
                Q×K^T matmul AC operations
            'av_ACs': float
                A×V matmul AC operations
            'layer_details': list of dict
                레이어별 상세 정보
        }

    Notes
    -----
    - Q×K^T: sparse(Q) × sparse(K^T) → min(q_rate, k_rate) 적용
    - A×V: dense(softmax(A)) × sparse(V) → v_rate만 적용
    - FLOPs = T × heads × seq_len × seq_len × (dim / heads)

    Examples
    --------
    >>> config = {'depth': 2, 'heads': 8, 'dim': 64, 'seq_len': 96, 'T': 4}
    >>> result = get_ssa_matmul_ops(model, config)
    >>> print(f"Total ACs: {result['total_ACs'] / 1e9:.3f} G")
    """
    depth = config.get('depth', 2)
    heads = config.get('heads', 8)
    dim = config.get('dim', 64)
    seq_len = config.get('seq_len', 96)
    T = config.get('T', 4)

    dim_per_head = dim // heads

    # SSA matmul 기본 연산량 (발화율 적용 전)
    # Q×K^T: [T, B, heads, seq_len, dim_per_head] @ [T, B, heads, dim_per_head, seq_len]
    #        → [T, B, heads, seq_len, seq_len]
    # FLOPs per batch = T × heads × seq_len × seq_len × dim_per_head
    qk_base = T * heads * seq_len * seq_len * dim_per_head

    # A×V: [T, B, heads, seq_len, seq_len] @ [T, B, heads, seq_len, dim_per_head]
    #      → [T, B, heads, seq_len, dim_per_head]
    # FLOPs per batch = T × heads × seq_len × seq_len × dim_per_head
    av_base = T * heads * seq_len * seq_len * dim_per_head

    total_qk_ACs = 0
    total_av_ACs = 0
    layer_details = []

    # 각 SSA 블록 순회
    for i in range(depth):
        ssa = model.blocks[i].attn  # SSA module

        # 발화율 가져오기 (__syops__[3]에 백분율로 저장됨)
        q_rate = ssa.q_lif.__syops__[3] / 100.0
        k_rate = ssa.k_lif.__syops__[3] / 100.0
        v_rate = ssa.v_lif.__syops__[3] / 100.0

        # Q×K^T: sparse × sparse → 두 발화율 중 작은 값 적용
        # AC operations = base_ops × min(q_rate, k_rate)
        qk_ACs = qk_base * min(q_rate, k_rate)

        # A×V: dense(softmax) × sparse
        # Softmax 출력은 dense (발화율 1.0)이므로 V의 발화율만 적용
        # AC operations = base_ops × v_rate
        av_ACs = av_base * v_rate

        total_qk_ACs += qk_ACs
        total_av_ACs += av_ACs

        layer_details.append({
            'block': i,
            'q_rate': float(q_rate),
            'k_rate': float(k_rate),
            'v_rate': float(v_rate),
            'qk_base_ops': float(qk_base),
            'av_base_ops': float(av_base),
            'qk_ACs': float(qk_ACs),
            'av_ACs': float(av_ACs),
            'total_ACs': float(qk_ACs + av_ACs)
        })

    return {
        'total_ACs': float(total_qk_ACs + total_av_ACs),
        'qk_ACs': float(total_qk_ACs),
        'av_ACs': float(total_av_ACs),
        'layer_details': layer_details
    }


def get_attention_einsum_ops(model, config):
    """
    iTransformer FullAttention einsum 연산량 계산

    iTransformer의 FullAttention은 torch.einsum으로 구현되어 있어
    hook 시스템에서 측정되지 않습니다. 이 함수는 Q×K^T, Softmax, A×V
    einsum 연산량을 수동으로 계산합니다.

    Parameters
    ----------
    model : iTransformer
        측정할 모델
    config : dict
        모델 설정
        - 'depth' (int): Encoder 레이어 수 (기본값: 2)
        - 'heads' (int): Attention head 수 (기본값: 8)
        - 'dim' (int): 모델 차원 (기본값: 64)
        - 'seq_len' (int): Sequence length (기본값: 96)

    Returns
    -------
    dict
        {
            'total_MACs': float
                총 MAC operations (Q×K^T + Softmax + A×V 합계)
            'qk_MACs': float
                Q×K^T einsum MAC operations
            'softmax_MACs': float
                Softmax MAC operations
            'av_MACs': float
                A×V einsum MAC operations
            'layer_details': list of dict
                레이어별 상세 정보
        }

    Notes
    -----
    - Q×K^T einsum: "blhe,bshe->bhls"
      FLOPs = B × heads × seq_len × seq_len × dim_per_head
    - Softmax: exp + sum + div + scale
      FLOPs = B × heads × seq_len × seq_len × 4
    - A×V einsum: "bhls,bshd->blhd"
      FLOPs = B × heads × seq_len × seq_len × dim_per_head

    Examples
    --------
    >>> config = {'depth': 2, 'heads': 8, 'dim': 64, 'seq_len': 96}
    >>> result = get_attention_einsum_ops(model, config)
    >>> print(f"Total MACs: {result['total_MACs'] / 1e9:.3f} G")
    """
    depth = config.get('depth', 2)
    heads = config.get('heads', 8)
    dim = config.get('dim', 64)
    seq_len = config.get('seq_len', 96)

    dim_per_head = dim // heads

    # Q×K^T einsum: "blhe,bshe->bhls"
    # [B, seq_len, heads, dim_per_head] × [B, seq_len, heads, dim_per_head]
    # → [B, heads, seq_len, seq_len]
    # FLOPs per batch = heads × seq_len × seq_len × dim_per_head
    qk_MACs_base = heads * seq_len * seq_len * dim_per_head

    # Softmax: exp, sum, div, scale
    # Input: [B, heads, seq_len, seq_len]
    # Operations: exp(x), sum(exp), exp/sum, scale
    # FLOPs per batch = heads × seq_len × seq_len × 4
    softmax_MACs_base = heads * seq_len * seq_len * 4

    # A×V einsum: "bhls,bshd->blhd"
    # [B, heads, seq_len, seq_len] × [B, seq_len, heads, dim_per_head]
    # → [B, seq_len, heads, dim_per_head]
    # FLOPs per batch = heads × seq_len × seq_len × dim_per_head
    av_MACs_base = heads * seq_len * seq_len * dim_per_head

    total_qk_MACs = 0
    total_softmax_MACs = 0
    total_av_MACs = 0
    layer_details = []

    # Encoder의 각 레이어
    for i in range(depth):
        # iTransformer는 ANN → 발화율 없음 (항상 1.0)
        total_qk_MACs += qk_MACs_base
        total_softmax_MACs += softmax_MACs_base
        total_av_MACs += av_MACs_base

        layer_details.append({
            'layer': i,
            'qk_MACs': float(qk_MACs_base),
            'softmax_MACs': float(softmax_MACs_base),
            'av_MACs': float(av_MACs_base),
            'total_MACs': float(qk_MACs_base + softmax_MACs_base + av_MACs_base)
        })

    return {
        'total_MACs': float(total_qk_MACs + total_softmax_MACs + total_av_MACs),
        'qk_MACs': float(total_qk_MACs),
        'softmax_MACs': float(total_softmax_MACs),
        'av_MACs': float(total_av_MACs),
        'layer_details': layer_details
    }


def calculate_timeseries_energy(
    model,
    model_name,
    config,
    E_AC=0.9e-12,    # 0.9 pJ per AC operation (45nm CMOS)
    E_MAC=4.6e-12    # 4.6 pJ per MAC operation (45nm CMOS)
):
    """
    Time-Series 모델의 전체 에너지 소비량 계산

    기존 hook 시스템으로 측정된 연산량(Linear, Conv, BN, LIF, QAP MHA 등)과
    수동으로 계산한 연산량(SSA matmul, Attention einsum)을 합산하여
    정확한 에너지 소비량을 계산합니다.

    Parameters
    ----------
    model : nn.Module
        측정할 모델 (add_syops_counting_methods 적용 후, forward pass 완료 후)
    model_name : str
        모델 이름 ('ispikeformer' or 'itransformer')
    config : dict
        모델 설정
        - 'depth' (int): 레이어/블록 수
        - 'heads' (int): Attention head 수
        - 'dim' (int): 모델 차원
        - 'seq_len' (int): Sequence length
        - 'T' (int): SNN timesteps (iSpikeformer만)
    E_AC : float, optional
        AC operation당 에너지 (Joules), 기본값: 0.9e-12 (0.9 pJ)
    E_MAC : float, optional
        MAC operation당 에너지 (Joules), 기본값: 4.6e-12 (4.6 pJ)

    Returns
    -------
    dict
        {
            'hook_ACs': float
                Hook으로 측정된 AC operations
            'hook_MACs': float
                Hook으로 측정된 MAC operations
            'manual_ACs': float
                수동 계산 AC operations (SSA matmul)
            'manual_MACs': float
                수동 계산 MAC operations (Attention einsum)
            'total_ACs': float
                총 AC operations
            'total_MACs': float
                총 MAC operations
            'total_ops': float
                총 operations (ACs + MACs)
            'AC_energy_mJ': float
                AC 에너지 (mJ)
            'MAC_energy_mJ': float
                MAC 에너지 (mJ)
            'total_energy_mJ': float
                총 에너지 (mJ)
            'breakdown': dict
                에너지 구성 비율
            'hook_details': list
                Hook으로 측정된 레이어 상세 정보
            'manual_details': list
                수동 계산 레이어 상세 정보
        }

    Notes
    -----
    - Hook 시스템은 Linear, Conv, BN, LIF, QAP MultiheadAttention 등을 측정
    - 수동 계산은 torch.matmul, torch.einsum 등 hook 미등록 연산을 측정
    - 발화율은 hook 적용 후 forward pass 시 자동으로 측정됨

    Examples
    --------
    >>> from energy_consumption_calculation.engine import add_syops_counting_methods
    >>> model_with_hooks = add_syops_counting_methods(model)
    >>> model_with_hooks.start_syops_count(ost=sys.stdout, verbose=False, ignore_list=[])
    >>>
    >>> # Forward pass
    >>> for batch in test_loader:
    >>>     output = model_with_hooks(x)
    >>>     functional.reset_net(model_with_hooks)
    >>>
    >>> # 에너지 계산
    >>> config = {'depth': 2, 'heads': 8, 'dim': 64, 'seq_len': 96, 'T': 4}
    >>> result = calculate_timeseries_energy(model_with_hooks, 'ispikeformer', config)
    >>> print(f"Total energy: {result['total_energy_mJ']:.3f} mJ")
    >>> print(f"Hook: {result['breakdown']['hook_percentage']:.1f}%")
    >>> print(f"Manual: {result['breakdown']['manual_percentage']:.1f}%")
    """
    # 1. Hook 측정 결과 수집
    hook_ACs = 0
    hook_MACs = 0
    hook_details = []

    for name, module in model.named_modules():
        if hasattr(module, '__syops__'):
            syops = module.__syops__
            # syops[0]: oriMACs, syops[1]: ACs, syops[2]: MACs, syops[3]: spike_rate
            ACs = syops[1]
            MACs = syops[2]

            hook_ACs += ACs
            hook_MACs += MACs

            if ACs > 0 or MACs > 0:
                hook_details.append({
                    'name': name,
                    'type': type(module).__name__,
                    'ACs': float(ACs),
                    'MACs': float(MACs),
                    'spike_rate': float(syops[3] / 100.0) if syops[3] > 0 else 0.0
                })

    # 2. 수동 계산 (누락된 연산)
    manual_ACs = 0
    manual_MACs = 0
    manual_details = []

    if model_name.lower() == 'ispikeformer':
        # SSA matmul 계산
        ssa_result = get_ssa_matmul_ops(model, config)
        manual_ACs = ssa_result['total_ACs']
        manual_details = ssa_result['layer_details']

    elif model_name.lower() == 'itransformer':
        # Attention einsum 계산
        attn_result = get_attention_einsum_ops(model, config)
        manual_MACs = attn_result['total_MACs']
        manual_details = attn_result['layer_details']

    # 3. 총합
    total_ACs = hook_ACs + manual_ACs
    total_MACs = hook_MACs + manual_MACs
    total_ops = total_ACs + total_MACs

    # 4. 에너지 계산
    AC_energy = total_ACs * E_AC * 1e3  # mJ
    MAC_energy = total_MACs * E_MAC * 1e3  # mJ
    total_energy = AC_energy + MAC_energy

    # 5. 구성 비율
    breakdown = {
        'hook_ops': float(hook_ACs + hook_MACs),
        'manual_ops': float(manual_ACs + manual_MACs),
        'hook_percentage': float((hook_ACs + hook_MACs) / total_ops * 100) if total_ops > 0 else 0.0,
        'manual_percentage': float((manual_ACs + manual_MACs) / total_ops * 100) if total_ops > 0 else 0.0,
        'AC_energy_percentage': float(AC_energy / total_energy * 100) if total_energy > 0 else 0.0,
        'MAC_energy_percentage': float(MAC_energy / total_energy * 100) if total_energy > 0 else 0.0
    }

    return {
        'hook_ACs': float(hook_ACs),
        'hook_MACs': float(hook_MACs),
        'manual_ACs': float(manual_ACs),
        'manual_MACs': float(manual_MACs),
        'total_ACs': float(total_ACs),
        'total_MACs': float(total_MACs),
        'total_ops': float(total_ops),
        'AC_energy_mJ': float(AC_energy),
        'MAC_energy_mJ': float(MAC_energy),
        'total_energy_mJ': float(total_energy),
        'breakdown': breakdown,
        'hook_details': hook_details,
        'manual_details': manual_details
    }
