import json
import argparse

def get_by_path(data, path):
    """Return the value nested inside ``data`` at a dot-separated ``path``.

    Example: ``path='struct.layer1'`` returns ``data['struct']['layer1']``.
    When the current node is a list and the segment parses as an integer,
    the segment is used as a list index, so ``'layers.0.weight'`` returns
    ``data['layers'][0]['weight']``.

    Args:
        data: Parsed JSON (nested dicts / lists).
        path: Dot-separated sequence of keys / indices.

    Returns:
        The value found at the end of the path.

    Raises:
        KeyError: A dict segment is missing.
        IndexError: A list index is out of range.
        TypeError: A segment tries to descend into a scalar, or a
            non-integer segment is applied to a list.
    """
    for key in path.split("."):
        if isinstance(data, list):
            # JSON arrays are addressed by position; a raw str key would
            # always raise TypeError, so convert numeric segments first.
            try:
                key = int(key)
            except ValueError:
                pass  # keep the original failure mode (list[str] -> TypeError)
        data = data[key]
    return data

def sum_first_layer_out_channels(target_dict):
    """Sum the first shape dimension of the first layer of every ``*_cluster``.

    For each entry of ``target_dict`` whose key ends with ``"_cluster"``, take
    the first value stored in that cluster dict (its first layer, in insertion
    order). If that value is a non-empty list, add its element 0 to the total.
    That element is treated as the layer's output-channel count.

    Args:
        target_dict: The dict found at some JSON path, e.g. ``json['struct']``.

    Returns:
        The sum of the collected first dimensions (0 if nothing qualifies).
        Empty or non-dict cluster entries and non-list / empty shapes are
        skipped.
    """
    total = 0
    for cluster_name, cluster_dict in target_dict.items():
        if not cluster_name.endswith("_cluster"):
            continue

        # An empty (or non-dict) cluster has no first layer; without this
        # guard next(iter({})) raised StopIteration and aborted the sum.
        if not isinstance(cluster_dict, dict) or not cluster_dict:
            continue

        # dicts preserve insertion order (Python 3.7+), and json.load keeps
        # the file's key order, so the first value is the first layer listed.
        shape = next(iter(cluster_dict.values()))

        if not isinstance(shape, list) or not shape:
            continue

        total += shape[0]

    return total


if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the original
    # hard-coded run, so invoking the script with no arguments is unchanged.
    parser = argparse.ArgumentParser(
        description=(
            "Sum the first-layer output channels of every *_cluster entry "
            "found at a dot-separated path inside a JSON file."
        )
    )
    parser.add_argument(
        "--json-path",
        default="/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data/tgt/vit_b_16_Imagenet/param_info.json",
        help="Path to the JSON file to read.",
    )
    parser.add_argument(
        "--path",
        default="struct",
        help="Dot-separated path to the dict holding the *_cluster entries.",
    )
    args = parser.parse_args()

    # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
    with open(args.json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    target = get_by_path(data, args.path)
    total_out = sum_first_layer_out_channels(target)

    print(f"Total output channels sum = {total_out}")
