import json
import argparse
from functools import reduce
import operator

def get_by_path(data, path):
    """Follow a dot-separated key path into nested JSON data.

    Example: ``path='struct.layer1'`` yields ``data['struct']['layer1']``.

    Args:
        data: Root object to start from (in this script, the result of
            ``json.load``).
        path: Keys joined with ``"."``. Each segment is applied in turn with
            the subscript operator, as a string.

    Returns:
        The object reached after applying every segment.

    Raises:
        Whatever the subscript raises for a bad segment propagates unchanged,
        e.g. ``KeyError`` when a dict has no such key.
    """
    # Fold the path left-to-right: current = current[segment].
    return reduce(operator.getitem, path.split("."), data)

def compute_cluster_dim(cluster_dict):
    """Sum, over all layers of one cluster, the product of each shape's trailing dims.

    Every value in ``cluster_dict`` is treated as a layer shape:
      * values that are not lists, or are empty lists, are ignored;
      * otherwise the first dimension is dropped and the remaining
        dimensions are multiplied together, so a 1-D shape contributes 1.
    The per-layer products are added up and returned (0 if nothing counts).
    """
    def _trailing_product(dims):
        # Product of every dim after the first; the initial value 1 makes a
        # single-dimension shape count as 1.
        return reduce(operator.mul, dims[1:], 1)

    return sum(
        _trailing_product(shape)
        for shape in cluster_dict.values()
        if isinstance(shape, list) and shape
    )

def get_max_cluster_dim(target_dict, suffix="_cluster"):
    """Find the cluster entry with the largest dimension in ``target_dict``.

    Only entries whose key ends with ``suffix`` are considered. Each one is
    scored with ``compute_cluster_dim``. On a tie, the first such entry in
    iteration order is kept.

    Args:
        target_dict: Mapping from entry name to cluster dict. In this
            script's usage it is the ``struct`` section of the JSON file.
        suffix: Key suffix that marks an entry as a cluster. Defaults to
            ``"_cluster"``, matching the previous hard-coded behavior.

    Returns:
        ``(cluster_name, cluster_dim)`` for the largest cluster, or
        ``(None, 0)`` if no key ends with ``suffix``.
    """
    best_name = None
    best_dim = 0

    for cluster_name, cluster_dict in target_dict.items():
        if not cluster_name.endswith(suffix):
            continue

        cluster_dim = compute_cluster_dim(cluster_dict)

        # Accept the first matching cluster unconditionally. Otherwise a
        # cluster whose dimension is 0 could never be reported, and the
        # function would claim there is no cluster at all.
        if best_name is None or cluster_dim > best_dim:
            best_name, best_dim = cluster_name, cluster_dim

    return best_name, best_dim


if __name__ == "__main__":
    # argparse was imported but never used. The formerly hard-coded values
    # are now CLI options with identical defaults, so a bare run behaves as
    # before.
    parser = argparse.ArgumentParser(
        description=(
            "Report the *_cluster entry with the largest summed dimension "
            "in a param_info JSON file."
        )
    )
    parser.add_argument(
        "--json-path",
        default="/nfs196/wjx/projects/trash/vgg19_bn/vgg19_bn_cifar10_badnet_ata/param_info.json",
        help="Path to the param_info JSON file.",
    )
    parser.add_argument(
        "--path",
        default="struct",
        help="Dot-separated key path to the dict holding the clusters "
             "(default: %(default)s).",
    )
    args = parser.parse_args()

    with open(args.json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    target = get_by_path(data, args.path)
    name, max_dim = get_max_cluster_dim(target)

    print(f"Max cluster dim = {max_dim}")
    print(f"Cluster name    = {name}")
