import torch
from torch.utils.data import DataLoader

from .BP_SA_critic_dataset import spMolDataset, spcollate_fn


def gnn_bp_sa_scorer(gnn_bp_model, device, protein_dir, query_protein, valid_smiles):
    """Score each sample for ``query_protein`` with ``gnn_bp_model``.

    A ``spMolDataset`` is built from ``protein_dir``, ``query_protein`` and
    ``valid_smiles`` and iterated one sample at a time, in dataset order,
    through a ``DataLoader`` that uses ``spcollate_fn``.

    Args:
        gnn_bp_model: Model called as ``gnn_bp_model((H1, H2, A1, A2, V))``.
            It is switched to eval mode by this function.
        device: Device the five model inputs are moved to before the
            forward pass.
        protein_dir: Forwarded unchanged to ``spMolDataset``.
        query_protein: Forwarded to ``spMolDataset`` and used as the first
            field of every result key; must be a ``str``.
        valid_smiles: Forwarded unchanged to ``spMolDataset``.

    Returns:
        dict mapping ``"<query_protein>,<key[0]>,<sa[0]>\\n"`` to
        ``str(bp[0])``, where ``bp`` is the model output as a NumPy array.
        Samples whose ``key`` compares equal to ``0`` are not run through
        the model; they are recorded with ``0.0`` for all three values, so
        every such sample collapses into the same single dict entry.
    """
    dataset = spMolDataset(protein_dir, query_protein, valid_smiles)
    dataloader = DataLoader(
        dataset,
        1,
        shuffle=False,
        num_workers=7,
        collate_fn=spcollate_fn,
    )

    gnn_bp_model.eval()
    outputs = {}
    # Pure inference: no backward pass follows, so skip autograd bookkeeping.
    with torch.no_grad():
        for sample in dataloader:
            # Kept from the original implementation so that gradients already
            # stored on the model's parameters are still cleared for the caller.
            gnn_bp_model.zero_grad()
            H1, H2, A1, A2, V, sa_score, key = sample
            if key == 0:
                # NOTE(review): presumably spcollate_fn emits key == 0 for
                # samples it could not build -- confirm against the dataset code.
                key = [0.0]
                sa = [0.0]
                bp = [0.0]
            else:
                H1, H2, A1, A2, V = (
                    H1.to(device),
                    H2.to(device),
                    A1.to(device),
                    A2.to(device),
                    V.to(device),
                )
                output = gnn_bp_model((H1, H2, A1, A2, V))
                # Tensor.numpy() only works on CPU tensors; copy back first so
                # this also works when ``device`` is a GPU.
                bp = output.detach().cpu().numpy()
                sa = sa_score
            # Explicit str() is kept deliberately: formatting a 0-dim tensor in
            # an f-string yields different text than str(tensor).
            record = query_protein + "," + str(key[0]) + "," + str(sa[0]) + "\n"
            outputs[record] = str(bp[0])
    return outputs


