import torch, os

# Attach train/validation split indices to a saved prediction file and write the
# combined object next to the original as "<name>_merged.pt".
pred_path = "/mnt/shared-storage-user/caipengxiang/H200-ai4chem/Exp_AB_results/eval_results/arylation/arylation_4.0_yields.pt"

# Strip only a trailing ".pt" extension. str.split(".pt")[0] would instead cut at
# the first ".pt" anywhere in the name (e.g. "run.ptq_v2.pt" -> "run").
pred_name = os.path.basename(pred_path)
if pred_name.endswith(".pt"):
    pred_name = pred_name[: -len(".pt")]
basename = pred_name + "_merged.pt"
save_path = os.path.dirname(pred_path)
data = torch.load(pred_path, map_location='cpu')

# Relative path: resolved against the current working directory.
split_path = "split_idx.pt"
# Load onto CPU like the predictions above, so indices saved from a GPU run still
# load on CPU-only hosts and are not written back with a CUDA device.
split_data = torch.load(split_path, map_location='cpu')

data['train_idx'] = split_data['train']
# NOTE(review): the 'test' split is stored under 'val_idx' — confirm this is intended.
data['val_idx'] = split_data['test']
torch.save(data, os.path.join(save_path, basename))
