import os
import torch


def extract_backbone(
    src_ckpt: str,
    dst_ckpt: str,
    prefix: str = "backbone.",
    *,
    verbose: bool = True,
) -> None:
    """Extract the weights under ``prefix`` from a checkpoint into a standalone file.

    Loads ``src_ckpt`` onto the CPU, unwraps a top-level ``"state_dict"`` entry if
    present, keeps only entries whose key starts with ``prefix``, strips that
    prefix from their keys, and saves the resulting dict to ``dst_ckpt``.

    Args:
        src_ckpt: Path of the checkpoint to read.
        dst_ckpt: Path of the output file. Its parent directory is created if
            needed; a bare filename is written to the current directory.
        prefix: Key prefix selecting the entries to keep (removed in the output).
        verbose: If True, print the input and the extracted key lists.

    Raises:
        ValueError: If no key starts with ``prefix``; nothing is written then.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Recent PyTorch versions default to weights_only=True, which may reject
    # checkpoints that pickle extra metadata objects -- confirm for your version.
    checkpoint = torch.load(src_ckpt, map_location="cpu")
    # Some checkpoints wrap the weights in a 'state_dict' entry; unwrap it once.
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if verbose:
        print(state_dict.keys())

    # Keep only the entries under `prefix`, with the prefix removed from each key.
    backbone_dict = {
        key[len(prefix):]: val
        for key, val in state_dict.items()
        if key.startswith(prefix)
    }
    if verbose:
        print(backbone_dict.keys())
    if not backbone_dict:
        # An empty output file would later load as "nothing" without any error.
        raise ValueError(f"No keys starting with {prefix!r} found in {src_ckpt}")

    # os.path.dirname returns "" for a bare filename, and os.makedirs("") raises.
    dst_dir = os.path.dirname(dst_ckpt)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    torch.save(backbone_dict, dst_ckpt)
    print(f"Extracted {len(backbone_dict)} {prefix} parameters to {dst_ckpt}")


if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded Sapiens-0.3B pose conversion,
    # so running the script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Extract prefixed (e.g. backbone) weights from a checkpoint."
    )
    parser.add_argument(
        "--src",
        default="checkpoints/sapiens-pose-0.3b/sapiens_0.3b_goliath_best_goliath_AP_573.pth",
        help="input checkpoint path",
    )
    parser.add_argument(
        "--dst",
        default="checkpoints/sapiens-pose-0.3b/backbone.pth",
        help="output checkpoint path",
    )
    parser.add_argument(
        "--prefix",
        default="backbone.",
        help="key prefix selecting the weights to extract",
    )
    args = parser.parse_args()
    extract_backbone(args.src, args.dst, args.prefix)
