# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

from config.get_config import get_config
from config.config_template import ConfigTemplate

def _shared_parameter_counts(config):
    """Return ``(base, attention)``, the terms every feed-forward variant shares.

    ``attention`` is the per-block attention term. ``base`` is the
    pre-processing (``emb_size * vocab_size``) and post-processing
    (``emb_size * num_class``) terms plus two blocks counted as attention
    plus the fixed ``c_stem`` term.
    """
    c_pre_processing = config.emb_size * config.vocab_size
    c_post_processing = config.emb_size * config.num_class
    c_attention = 4 * config.emb_size * config.attn_num_head * config.attn_head_size
    # NOTE(review): presumably two blocks of the model use a dense feed-forward
    # of width 4 * emb_size instead of the configured one -- confirm against
    # the model definition.
    c_stem = 2 * config.emb_size * (4 * config.emb_size)
    base = c_pre_processing + c_post_processing + 2 * (c_attention + c_stem)
    return base, c_attention


def _mlp_ffwd_counts(config):
    """Return per-block ``(total, active)`` feed-forward counts for "MLP"."""
    c_mlp = 2 * config.emb_size * config.ffwd_hid_size
    # No routing term here, so total and active are the same number.
    return c_mlp, c_mlp


def _moe_ffwd_counts(config):
    """Return per-block ``(total, active)`` counts for "MoE" and "MoEEP"."""
    c_router = config.emb_size * config.ffwd_num_expert
    c_per_expert = 2 * config.emb_size * config.ffwd_expert_size
    c_expert_total = c_per_expert * config.ffwd_num_expert
    c_expert_active = c_per_expert * config.ffwd_num_expert_active
    return c_router + c_expert_total, c_router + c_expert_active


def _mhmoe_ffwd_counts(config):
    """Return per-block ``(total, active)`` counts for "MHMoEHPNRT"."""
    c_head = config.ffwd_num_head * config.ffwd_head_size
    c_latent = 3 * config.emb_size * c_head
    c_router = c_head * config.ffwd_num_expert
    c_expert_total = 2 * c_head * config.ffwd_num_expert * config.ffwd_expert_size
    c_expert_active = 2 * c_head * config.ffwd_num_expert_active * config.ffwd_expert_size
    return (
        c_latent + c_router + c_expert_total,
        c_latent + c_router + c_expert_active,
    )


def _latent_moe_ffwd_counts(config):
    """Return per-block ``(total, active)`` counts for "LatentMoE".

    Raises:
        AssertionError: if ``config.emb_size`` is not divisible by the fixed
            reduction ratio of 4.
    """
    latent_moe_reduction_ratio = 4
    if config.emb_size % latent_moe_reduction_ratio != 0:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the type stays AssertionError for existing callers.
        raise AssertionError(
            f"config.emb_size must be divisible by {latent_moe_reduction_ratio}, "
            f"got {config.emb_size!r}"
        )
    emb_size_reduced = config.emb_size // latent_moe_reduction_ratio

    c_router = config.emb_size * config.ffwd_num_expert
    c_per_expert = 2 * emb_size_reduced * config.ffwd_expert_size
    c_expert_total = c_per_expert * config.ffwd_num_expert
    c_expert_active = c_per_expert * config.ffwd_num_expert_active
    c_latent = 2 * config.emb_size * emb_size_reduced
    return (
        c_router + c_expert_total + c_latent,
        c_router + c_expert_active + c_latent,
    )


# Maps each supported ``config.ffwd_name`` to the helper that counts its
# per-block feed-forward parameters.
_FFWD_COUNTERS = {
    "MLP": _mlp_ffwd_counts,
    "MoE": _moe_ffwd_counts,
    "MoEEP": _moe_ffwd_counts,
    "MHMoEHPNRT": _mhmoe_ffwd_counts,
    "LatentMoE": _latent_moe_ffwd_counts,
}


def calculate_parameter_count(config: ConfigTemplate):
    """Compute the active and total parameter counts described by *config*.

    The count is the shared base (pre/post-processing plus two blocks using
    the fixed stem term) plus ``num_block - 2`` blocks of attention and the
    feed-forward variant selected by ``config.ffwd_name``. For the expert
    variants, "total" counts ``ffwd_num_expert`` experts per block while
    "active" counts only ``ffwd_num_expert_active`` of them.

    Args:
        config: Object exposing ``ffwd_name``, ``num_block``, ``emb_size``,
            ``vocab_size``, ``num_class``, ``attn_num_head``,
            ``attn_head_size`` and the ``ffwd_*`` fields read by the selected
            variant.

    Returns:
        A dict ``{"active": ..., "total": ...}`` of parameter counts.

    Raises:
        ValueError: if ``config.ffwd_name`` is not a supported variant.
        AssertionError: if ``config.num_block < 2``, or if the "LatentMoE"
            variant is selected and ``emb_size`` is not divisible by 4.
    """
    count_ffwd = _FFWD_COUNTERS.get(config.ffwd_name)
    if count_ffwd is None:
        raise ValueError(f"Unexpected config.ffwd_name: {config.ffwd_name!r}")

    if config.num_block < 2:
        # Two blocks are always counted in the base, so fewer than two would
        # give a negative multiplier below. Raised explicitly so the check
        # is not stripped by ``python -O``.
        raise AssertionError(
            f"config.num_block must be >= 2, got {config.num_block!r}"
        )

    ffwd_total, ffwd_active = count_ffwd(config)
    base, c_attention = _shared_parameter_counts(config)
    num_ffwd_block = config.num_block - 2

    return {
        "active": base + num_ffwd_block * (c_attention + ffwd_active),
        "total": base + num_ffwd_block * (c_attention + ffwd_total),
    }

# Print the active and total parameter counts, in billions rounded to one
# decimal place, for each listed configuration file.
for name in (
    "X0-00",
    "X1-01",
    "X1-01-G",
    "X1-02",
    "X1-03",
    "X2-01",
    "X2-01-G",
    "X2-02",
    "X2-03",
):
    config = get_config(f"./config_files/{name}.yaml")
    counts = calculate_parameter_count(config)
    # Convert both raw counts to billions in one pass.
    active_b, total_b = (
        round(counts[key] / 1_000_000_000, 1) for key in ("active", "total")
    )
    print(f"{name}: active = {active_b} billion, total = {total_b} billion")
    # Together with print's own line ending this leaves two blank lines
    # after each entry.
    print("\n")
