from collections import OrderedDict

import torch
from tqdm import tqdm

# Checkpoint to read, and where to write the MLP-only copy.
SRC_PATH = "../trained_models/clip_huge_frozen_mlp_multilabel.pth"
DST_PATH = "../trained_models/clip_huge_frozen_mlp_multilabel_2.pth"


def filter_state_dict(state_dict, prefix="mlp."):
    """Return a new OrderedDict holding only the entries whose key starts with *prefix*.

    Entries keep their original relative order. Values are not copied; the
    returned mapping references the same objects as *state_dict*.

    Args:
        state_dict: Any mapping with string keys (e.g. a model ``state_dict``).
        prefix: Key prefix to keep. Defaults to ``"mlp."``.

    Returns:
        An ``OrderedDict`` with the matching entries (empty if none match).
    """
    return OrderedDict(
        (key, value) for key, value in state_dict.items() if key.startswith(prefix)
    )


def main(src=SRC_PATH, dst=DST_PATH, prefix="mlp.", map_location="cpu"):
    """Load the checkpoint at *src*, keep only *prefix* entries, and save them to *dst*.

    Args:
        src: Path of the input checkpoint.
        dst: Path the filtered checkpoint is written to.
        prefix: Key prefix to keep. Defaults to ``"mlp."``.
        map_location: Forwarded to ``torch.load``. ``"cpu"`` lets a checkpoint
            saved from GPU be loaded on a CPU-only machine. Tensors in the
            output are then on CPU.

    Raises:
        ValueError: If no key in the checkpoint starts with *prefix*. Raising
            avoids silently writing an empty checkpoint.
    """
    # NOTE(review): assumes the file holds a flat state dict (name -> tensor),
    # not a wrapper like {"state_dict": ..., "optimizer": ...} — confirm.
    state_dict = torch.load(src, map_location=map_location)
    filtered = filter_state_dict(state_dict, prefix)
    if not filtered:
        raise ValueError(
            f"no keys starting with {prefix!r} found in {src!r}; "
            f"first keys: {list(state_dict)[:5]}"
        )
    torch.save(filtered, dst)


if __name__ == "__main__":
    main()