import torch
import torch.nn as nn

class YieldPredLayer(nn.Module):
    """Two-layer linear head that maps feature vectors to yield predictions.

    Args:
        input_size: Size of the last dimension of the input features.
        hidden_size: Width of the intermediate linear layer.
        output_size: Number of predicted values per input (default 1).
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        # NOTE(review): ``act`` is constructed but never applied in ``forward``,
        # so ``predictor`` is two stacked affine maps (mathematically equivalent
        # to a single linear layer). It is intentionally left out of
        # ``predictor``. Inserting it into the Sequential would shift the
        # submodule indices (``predictor.1`` -> ``predictor.2``) and break
        # loading of existing checkpoints. Applying it would change the
        # outputs of weights trained without it. Confirm the intended design.
        self.act = nn.SiLU()
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            # Honour ``output_size``; this was previously hard-coded to 1.
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return predictions of shape ``(*x.shape[:-1], output_size)``.

        Args:
            x: Tensor whose last dimension equals ``input_size``.
        """
        return self.predictor(x)

# Location of the serialized state_dict for a trained YieldPredLayer head.
yield_predictor_path = "/mnt/shared-storage-user/caipengxiang/H200-ai4chem/ChemBOMAS_results/share/llama-3.1-8B/clustered/suzuki_60_new_prompt/predictor.pt"

# NOTE(review): input_size=4096 presumably matches the hidden size of the
# Llama-3.1-8B features the head was trained on — TODO confirm.
predictor = YieldPredLayer(4096, 1024, 1)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# machines without that device. load_state_dict copies the values into the
# (CPU) module parameters regardless of where ckpt's tensors live.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from
# trusted storage, and consider weights_only=True where the torch version
# supports it.
ckpt = torch.load(yield_predictor_path, map_location="cpu")
predictor.load_state_dict(ckpt)