import torch
import os
import glob

# 模型所在目录（统一使用原始反斜杠或原始字符串避免转义问题）
model_dir = r'F:\OT_Score\feature_extractor\adapted_models'

# 查找所有 .pth 文件
model_paths = glob.glob(os.path.join(model_dir, '*.pth'))

# 遍历处理每个模型文件
for model_path in model_paths:
    print(f"Processing: {model_path}")
    try:
        adapted_net = torch.load(model_path, map_location='cpu')

        # 如果已经是 state_dict，可以跳过保存
        if isinstance(adapted_net, dict) and all(isinstance(k, str) for k in adapted_net.keys()):
            print("Already a state_dict, skipping...")
            continue

        # 否则保存为 state_dict 格式
        torch.save(adapted_net.state_dict(), model_path)
        print(f"State dict saved to: {model_path}")

    except Exception as e:
        print(f"Failed to process {model_path}: {e}")
