import torch

# Load the saved linear-probe checkpoint and print the shape of its 'weight' entry.
# NOTE(review): the file is assumed to hold a dict with a 'weight' key whose value
# has a `.shape` (presumably a torch.Tensor) — confirm against the code that saved it.
# torch.load unpickles the file, so only point this at trusted checkpoints.
checkpoint = torch.load(
    "data/PPG_DM/linear_prob_weight_mae.pt",
    # Remap tensors saved from a CUDA device onto the CPU so this also works on
    # machines without a GPU; only the shape is inspected, so device does not matter.
    map_location="cpu",
)
weight = checkpoint['weight']
print(weight.shape)