# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

from config.config_template import ConfigTemplate


def _expert_param_counts(config):
    """Return ``(total, active)`` expert parameter counts implied by ``config``.

    Both values are ``None`` for ``ffwd_name == "MLP"``. Both are ``0`` for
    ``"LatentMoE"``. The MoE variants use closed-form formulas over the config
    fields (see below), not the module's actual parameters.

    Raises:
        ValueError: if ``config.ffwd_name`` is not a recognised variant.
    """
    name = config.ffwd_name
    if name == "MLP":
        return None, None
    if name in {"MoE", "MoEEP"}:
        # Summed over all non-stem blocks (the first two blocks are stem
        # blocks): 2 * emb_size * ffwd_expert_size parameters per expert.
        per_expert = (config.num_block - 2) * 2 * config.emb_size * config.ffwd_expert_size
    elif name in {"MHMoE",   "MHMoENaive",   "MHMoETied",
                  "MHMoEHP", "MHMoENaiveHP", "MHMoETiedHP",
                  "MHMoEHPNRT"}:
        # Multi-head variant: each of ffwd_num_head heads contributes
        # 2 * ffwd_head_size * ffwd_expert_size parameters per expert, again
        # summed over the non-stem blocks.
        per_expert = ((config.num_block - 2) * config.ffwd_num_head
                      * 2 * config.ffwd_head_size * config.ffwd_expert_size)
    elif name == "LatentMoE":
        # NOTE(review): expert params are not computed for LatentMoE and are
        # reported as 0 -- confirm this is intentional, not a placeholder.
        return 0, 0
    else:
        raise ValueError(f"Unexpected config.ffwd_name: {name!r}")
    return per_expert * config.ffwd_num_expert, per_expert * config.ffwd_num_expert_active


def report_parameter_count(config: "ConfigTemplate", module, verbose=False):
    """Print parameter-count statistics for ``module`` under ``config``.

    Prints four lines:
      * ``total_trainable``: elements of all parameters with ``requires_grad``.
      * ``total_non_embed``: the trainable total, excluding parameters whose
        name contains ``"wte"`` or ``"wpe"``.
      * ``expert_param_total`` / ``expert_param_active``: analytic expert
        counts derived from ``config`` (``None`` when not applicable).

    Args:
        config: model configuration. ``ffwd_name`` selects the feed-forward
            variant; the other fields read depend on that variant.
        module: object whose ``named_parameters()`` yields ``(name, param)``
            pairs, where ``param`` has ``numel()`` and ``requires_grad``.
            Presumably a ``torch.nn.Module`` -- confirm at call sites.
        verbose: if true, also print each parameter's name, element count and
            ``requires_grad`` flag.

    Raises:
        ValueError: if ``config.ffwd_name`` is not a recognised variant. The
            trainable / non-embedding totals have already been printed by then.
    """
    total_trainable = 0
    total_non_embed = 0
    for name, p in module.named_parameters():
        if p.requires_grad:  # Only trainable params are counted
            count = p.numel()
            total_trainable += count
            # Names containing "wte"/"wpe" are treated as embedding tables.
            if "wte" not in name and "wpe" not in name:
                total_non_embed += count
        if verbose:
            print(name, p.numel(), p.requires_grad)
    print("total_trainable:", f"{total_trainable:,}")
    print("total_non_embed:", f"{total_non_embed:,}")

    expert_param_total, expert_param_active = _expert_param_counts(config)
    print("expert_param_total  (if applicable):", expert_param_total)
    print("expert_param_active (if applicable):", expert_param_active)
