import math

def calculate_energy():
    """
    Estimate and print the energy consumption of several BERT-style architectures.

    Five variants are modeled:
      * full-precision (FP32) BERT,
      * quantized (INT4) BERT,
      * Sorbet (INT4 SNN),
      * a TTFS SNN whose operand reads are charged at ``E_read_analog``,
      * a "traditional" TTFS SNN that charges INT4 MACs and memory reads instead.

    All model dimensions and per-operation energy constants (in pJ) are defined
    inside this function; edit section 2 to run experiments.

    For every model the printed table reports, in millijoules:
      * ``fc``: one FC layer over ``B * S * C_o`` outputs,
      * ``qkv``: one attention product over ``B * h * S * S`` outputs,
      * ``total``: ``3 * fc + 2 * qkv + 2 * (K/V write energy)``.

    The 3 FC terms presumably correspond to the Q/K/V projections, and the
    2 attention terms to QK^T and scores x V (TODO confirm).

    Returns:
        dict: ``{model_label: {"fc": mJ, "qkv": mJ, "total": mJ}}`` holding the
        same values that are printed, in the same (table) order.
    """

    # --- 1. Model and architecture parameters ---
    B = 64          # Batch size
    S = 128         # Sequence length
    C_o = 768       # Output channels
    C_i = 768       # Input channels
    h = 12          # Number of attention heads
    d_k = C_i // h  # Dimension per head (768 // 12 = 64)

    # --- 2. Sparsity, SNN, and experiment parameters ---
    T_typical = 4        # Time steps used by the Sorbet model
    T = 15               # Time steps used by the TTFS models (also feeds log2(T + 1) in quantized BERT)
    s_r_sorbet = 0.33    # Spike rate for the Sorbet model
    s_r_ttfs = 0.0514    # Spike rate for the TTFS SNN models

    # Activation density applied to the non-spiking BERT baselines
    # (1.0 = dense, < 1.0 = sparse). It is tied to the TTFS spike activity:
    # s_r_ttfs * T (= 0.771 with the defaults above).
    gamma = s_r_ttfs * T

    # --- 3. Base energy costs (picojoules, pJ) ---
    E_leakage = 0.0154218  # Leakage unit cost; scaled by fan-in x bits / time steps below

    # FP32 operations
    E_MAC_fp32 = 4.6
    E_clamp_fp32 = 0.9

    # INT4 / INT1 operations
    E_MAC_int4 = 0.066338
    E_ACC_int4 = 0.05021
    E_ACC_int1 = 0.04292
    # NOTE: No INT4 clamp energy was provided; assume the same value as E_ACC_int4.
    E_clamp_int4 = 0.05021

    # SNN-specific operations (assumed values; not provided)
    E_CMP = 0.05021  # One comparison
    E_SUB = 0.05021  # One subtraction
    E_read_analog1 = 8.75 * 10 ** -3 + 1.77 * 10 ** -6  # Powering up the FTF and sampling
    E_read_analog3 = 5.3 * 10 ** -3 + 0.010505          # ADC + lookup table
    E_read_analog = E_read_analog1 + E_read_analog3

    E_weight_act = 0.098454        # Per-bit memory access cost (read or write)
    E_sparse_move_per_bit = 0.18   # Per-bit sparse data movement

    # --- 4. Derived energy costs: per-bit cost x bit width ---
    E_read_weight_fp32 = E_weight_act * 32
    E_read_kv_fp32 = E_weight_act * 32

    # --- 5. Per-model energy models ---
    # Each closure returns (total, fc, qkv) in pJ. Float evaluation order in
    # every expression is deliberate; keep it when editing so results are stable.

    def calculate_full_bert_energy():
        """Return (total, fc, qkv) energy in pJ for full-precision (FP32) BERT."""
        # FC layer, per output neuron:
        #   * gamma * C_i MACs, each with a 32-bit weight read and a 32-bit sparse move,
        #   * plus two clamps and one 32-bit write.
        energy_per_output_neuron_fc = (gamma * C_i * (E_MAC_fp32 + E_read_weight_fp32 + 32 * E_sparse_move_per_bit) + 2 * E_clamp_fp32 + 32 * E_weight_act)
        energy_leakage = C_i * 32 * E_leakage
        E_FC = B * S * C_o * (energy_per_output_neuron_fc + energy_leakage)

        # Attention product: d_k-long dot products with 32-bit K/V reads.
        energy_per_element_qkv = (d_k * gamma * (E_read_kv_fp32 + E_MAC_fp32 + 32 * E_sparse_move_per_bit) + 2 * E_clamp_fp32)
        energy_leakage_qkv = d_k * 32 * E_leakage
        E_FC_qkv = B * h * S * S * (energy_per_element_qkv + energy_leakage_qkv)

        # 32-bit write per element, counted twice in the total (presumably K and V).
        energy_write_kv = B * S * C_o * 32 * E_weight_act

        total_energy = 3 * E_FC + 2 * E_FC_qkv + 2 * energy_write_kv
        return total_energy, E_FC, E_FC_qkv

    def calculate_quantized_bert_energy():
        """Return (total, fc, qkv) energy in pJ for quantized (INT4) BERT."""
        # Bit width used for sparse moves and leakage. log2(T + 1) == 4.0 for the
        # default T = 15 (i.e. INT4), but it changes whenever T is changed.
        # NOTE(review): this couples the INT4 baseline to the TTFS time-step
        # count; confirm that is intended.
        bits = math.log2(T + 1)

        # NOTE(review): the following are charged at E_weight_act (one bit's cost):
        #   * weight reads below,
        #   * the K/V write.
        # The FP32 model charges 32 * E_weight_act; the 4-bit analogue would be
        # 4 * E_weight_act. Confirm which is intended.
        energy_per_output_neuron_fc = (gamma * C_i * (E_MAC_int4 + E_weight_act + bits * E_sparse_move_per_bit) + 2 * E_clamp_int4)
        energy_leakage = C_i * bits * E_leakage
        E_FC_q = B * S * C_o * (energy_per_output_neuron_fc + energy_leakage)

        # Attention product.
        energy_per_element_qkv = (d_k * gamma * (E_weight_act + E_MAC_int4 + bits * E_sparse_move_per_bit) + 2 * E_clamp_int4)
        energy_leakage_qkv = d_k * bits * E_leakage
        E_FC_qkv = B * h * S * S * (energy_per_element_qkv + energy_leakage_qkv)

        energy_write_kv = B * S * C_o * E_weight_act

        total_energy = 3 * E_FC_q + 2 * E_FC_qkv + 2 * energy_write_kv
        return total_energy, E_FC_q, E_FC_qkv

    def calculate_sorbet_energy():
        """Return (total, fc, qkv) energy in pJ for Sorbet (INT4 SNN), using T_typical steps."""
        # Sparse moves are charged at the INT4 bit width.
        sparse_move_cost_int4 = E_sparse_move_per_bit * 4

        # FC layer: spike-driven accumulates, plus per-step threshold compare and
        # (with probability s_r_sorbet) a reset subtraction.
        mac_equivalent_fc = C_i * s_r_sorbet * T_typical * (E_ACC_int1 + E_weight_act + sparse_move_cost_int4)
        thresholding_fc = T_typical * (E_CMP + s_r_sorbet * E_SUB)
        # NOTE(review): the leakage scale differs between the two layers:
        #   * the FC layer scales leakage by (T_typical + 1),
        #   * the attention layer scales it by T_typical.
        # Confirm the asymmetry is intended.
        leakage = C_i * (T_typical + 1) * E_leakage
        E_Sorbet_FC = B * S * C_o * (mac_equivalent_fc + thresholding_fc + leakage)

        # Attention product.
        mac_equivalent_qkv = d_k * s_r_sorbet * T_typical * (E_weight_act + E_ACC_int4 + sparse_move_cost_int4)
        thresholding_qkv = T_typical * (E_CMP + s_r_sorbet * E_SUB)
        leakage_qkv = d_k * T_typical * E_leakage
        E_Sorbet_qkv = B * h * S * S * (mac_equivalent_qkv + thresholding_qkv + leakage_qkv)

        energy_write_kv = B * S * C_o * E_weight_act

        total_energy = 3 * E_Sorbet_FC + 2 * E_Sorbet_qkv + 2 * energy_write_kv
        return total_energy, E_Sorbet_FC, E_Sorbet_qkv

    def calculate_ttfs_snn_energy():
        """Return (total, fc, qkv) energy in pJ for the TTFS SNN with analog operand reads."""
        # Sparse moves are charged once per spike (E_sparse_move_per_bit, not
        # scaled by a bit width), presumably because TTFS spikes are binary.

        # FC layer: the operand read is charged at E_read_analog. Each time step
        # adds one comparison plus a 4-bit memory access.
        mac_equivalent_fc = C_i * T * s_r_ttfs * (E_ACC_int4 + E_read_analog + E_sparse_move_per_bit)
        thresholding_fc = T * (E_CMP + 4 * E_weight_act)
        leakage = C_i * T * E_leakage
        E_Opo_FC = B * S * C_o * (mac_equivalent_fc + thresholding_fc + leakage)

        # Attention product: one analog read plus one digital read per operation.
        mac_equivalent_qkv = d_k * s_r_ttfs * T * (E_read_analog + E_ACC_int4 + E_sparse_move_per_bit + E_weight_act)
        thresholding_qkv = T * (E_CMP + 4 * E_weight_act)
        leakage_qkv = d_k * T * E_leakage
        E_Opo_qkv = B * h * S * S * (mac_equivalent_qkv + thresholding_qkv + leakage_qkv)

        energy_write_kv = B * S * C_o * E_weight_act

        total_energy = 3 * E_Opo_FC + 2 * E_Opo_qkv + 2 * energy_write_kv
        return total_energy, E_Opo_FC, E_Opo_qkv

    def calculate_ttfs_traditional_snn_energy():
        """Return (total, fc, qkv) energy in pJ for a TTFS SNN that uses INT4 MACs and
        memory reads instead of E_read_analog."""
        # NOTE(review): undocumented per-operation constant (pJ); record its source.
        E_traditional_op = 0.0163

        # FC layer: same structure as the analog TTFS model, with E_read_analog
        # replaced by E_traditional_op + E_MAC_int4 + E_weight_act.
        mac_equivalent_fc = C_i * T * s_r_ttfs * (E_traditional_op + E_MAC_int4 + E_weight_act + E_sparse_move_per_bit)
        thresholding_fc = T * (E_CMP + 4 * E_weight_act)
        leakage = C_i * T * E_leakage
        E_Opo_FC = B * S * C_o * (mac_equivalent_fc + thresholding_fc + leakage)

        # Attention product.
        # NOTE(review): E_weight_act appears twice here (presumably one read per
        # operand); confirm.
        mac_equivalent_qkv = d_k * s_r_ttfs * T * (E_weight_act + E_MAC_int4 + E_traditional_op + E_sparse_move_per_bit + E_weight_act)
        thresholding_qkv = T * (E_CMP + 4 * E_weight_act)
        leakage_qkv = d_k * T * E_leakage
        E_Opo_qkv = B * h * S * S * (mac_equivalent_qkv + thresholding_qkv + leakage_qkv)

        energy_write_kv = B * S * C_o * E_weight_act

        total_energy = 3 * E_Opo_FC + 2 * E_Opo_qkv + 2 * energy_write_kv
        return total_energy, E_Opo_FC, E_Opo_qkv

    # --- 6. Main execution and output ---
    PJ_TO_MJ = 1e-9  # 1 pJ = 1e-12 J = 1e-9 mJ

    models = (
        ("Full BERT (FP32)", calculate_full_bert_energy),
        ("Quantized BERT (INT4)", calculate_quantized_bert_energy),
        ("Sorbet (INT4 SNN)", calculate_sorbet_energy),
        ("TTFS SNN (INT4)", calculate_ttfs_snn_energy),
        ("TTFS Traditional SNN (INT4)", calculate_ttfs_traditional_snn_energy),
    )

    results = {}
    for label, model_fn in models:
        total, fc, qkv = model_fn()
        results[label] = {"fc": fc * PJ_TO_MJ, "qkv": qkv * PJ_TO_MJ, "total": total * PJ_TO_MJ}

    # Header and rows share the same column widths so the table lines up; the
    # name column is sized to the longest label rather than a fixed 25.
    name_width = max(len(label) for label in ("Model Architecture", *results))
    header = (
        f"{'Model Architecture':<{name_width}} | {'FC Energy (mJ)':>18} | "
        f"{'QKV Energy (mJ)':>18} | {'Total Energy (mJ)':>20}"
    )
    rule = "-" * len(header)

    print(rule)
    print(header)
    print(rule)
    for label, energy in results.items():
        print(
            f"{label:<{name_width}} | {energy['fc']:>18.2f} | "
            f"{energy['qkv']:>18.2f} | {energy['total']:>20.2f}"
        )
    print(rule)

    return results


def _run_as_script():
    """Script entry point: compute the energy estimates and print the table."""
    calculate_energy()


# Only run when executed directly, so importing this module stays side-effect free.
if __name__ == "__main__":
    _run_as_script()
