import torch
from collections import defaultdict

def tree():
    """Create an auto-vivifying nested dictionary.

    Looking up a missing key inserts and returns another such dictionary, so
    arbitrarily deep paths like ``t["a"]["b"]["c"]`` are created on demand.
    """
    subtree = defaultdict(tree)
    return subtree

def insert(t, key_path):
    """Add the path *key_path* to the nested tree *t*.

    Every segment except the last becomes a subtree. The last segment is
    recorded as a leaf (value ``None``) unless a subtree already exists there.
    If an intermediate segment was previously stored as a leaf, it is promoted
    to a subtree. As a result, overlapping paths such as ``["a"]`` and
    ``["a", "b"]`` give the same tree regardless of insertion order.

    Args:
        t: Tree node; must be a ``defaultdict`` whose ``default_factory``
            builds new subtree nodes (as ``tree()`` does).
        key_path: Non-empty sequence of path segments.

    Raises:
        IndexError: if *key_path* is empty.
    """
    for k in key_path[:-1]:
        if t.get(k) is None:
            # Missing, or previously recorded as a leaf: (re)create a subtree
            # instead of descending into None.
            t[k] = t.default_factory()
        t = t[k]
    # setdefault never replaces an existing subtree with a leaf marker.
    t.setdefault(key_path[-1], None)

def _natural_key(name):
    """Sort key that puts decimal segments first, in numeric order ("2" before "10")."""
    text = str(name)
    if text.isdecimal():
        return (0, int(text), text)
    return (1, 0, text)


def main(state_dict, file):
    """Write the hierarchical key structure of *state_dict* to *file*.

    Each key is split on ``"."``, and every path segment becomes one line,
    indented two spaces per nesting level. Sibling segments are sorted with
    decimal segments first and in numeric order (so ``layer.2`` precedes
    ``layer.10``), followed by the remaining segments in alphabetical order.

    Args:
        state_dict: Mapping whose keys are dot-separated strings (e.g. a
            PyTorch state dict); only the keys are used.
        file: Path of the output text file; it is overwritten.
    """
    structure = tree()
    for k in state_dict.keys():
        insert(structure, k.split("."))

    lines = []

    def collect(node, indent):
        # Depth-first walk: subtrees are dicts, leaves are None.
        for k in sorted(node, key=_natural_key):
            lines.append("  " * indent + str(k) + "\n")
            child = node[k]
            if isinstance(child, dict):
                collect(child, indent + 1)

    collect(structure, 0)

    # Explicit UTF-8 so the output does not depend on the platform locale.
    with open(file, "w", encoding="utf-8") as f:
        f.writelines(lines)

    print(f"模型树状结构已保存到: {file}")


if __name__ == '__main__':
    checkpoint = 'models/resnet18_cifar10.bin'
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects. Only load trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location='cpu')
    # Unwrap common training-checkpoint containers, but only when the wrapped
    # value is itself a mapping or a module (not e.g. a model-name string).
    if isinstance(state_dict, dict):
        for wrapper_key in ("model", "state_dict"):
            wrapped = state_dict.get(wrapper_key)
            if isinstance(wrapped, (dict, torch.nn.Module)):
                state_dict = wrapped
                break
    # A pickled nn.Module is not a mapping; take its parameters/buffers instead.
    if isinstance(state_dict, torch.nn.Module):
        state_dict = state_dict.state_dict()
    file = 'model_info.txt'
    main(state_dict, file)