import torch
from torch.utils.data import DataLoader

from .EBA_critic_dataset import spMolDataset, spcollate_fn


def gnn_eba_scorer(gnn_experimental_affinity_model, gnn_device, protein_dir,
                   query_protein, valid_mols, num_workers=7):
    """Score molecules against one protein with a GNN affinity model.

    Builds an ``spMolDataset`` over ``valid_mols`` for ``query_protein`` and
    runs the model on each collated sample, one sample at a time (batch size 1).

    Args:
        gnn_experimental_affinity_model: torch module called with the tuple
            ``(H1, H2, A1, A2, V)``. It is switched to eval mode (and left
            there) and its gradients are cleared.
        gnn_device: device the input tensors are moved to before the forward
            pass. The model is assumed to already live on this device --
            TODO confirm with callers.
        protein_dir: passed through to ``spMolDataset``.
        query_protein: protein identifier string; also used as the first
            field of every output key.
        valid_mols: molecules passed through to ``spMolDataset``.
        num_workers: number of DataLoader worker processes (default 7, the
            previously hard-coded value).

    Returns:
        dict mapping ``"<query_protein>,<key>,<sa>"`` plus a trailing newline
        to ``str`` of the model's prediction for that sample. Samples whose
        collated ``key`` equals 0 (presumably molecules the dataset could not
        featurize -- verify in ``spcollate_fn``) are recorded with key, SA
        score and prediction all set to ``0.0``.
    """
    dataset = spMolDataset(protein_dir, query_protein, valid_mols)
    dataloader = DataLoader(dataset, 1, shuffle=False,
                            num_workers=num_workers, collate_fn=spcollate_fn)

    gnn_experimental_affinity_model.eval()
    # No backward pass ever runs here, so clearing gradients once is enough.
    gnn_experimental_affinity_model.zero_grad()

    outputs = {}
    # Pure inference: no_grad avoids building an autograd graph per sample.
    with torch.no_grad():
        for sample in dataloader:
            H1, H2, A1, A2, V, sa_score, key = sample
            if key == 0:
                # Placeholder row for a sample flagged by the collate function.
                key = [0.0]
                sa = [0.0]
                ba = [0.0]
            else:
                H1, H2, A1, A2, V = (t.to(gnn_device) for t in (H1, H2, A1, A2, V))
                output = gnn_experimental_affinity_model((H1, H2, A1, A2, V))
                # .cpu() is required: Tensor.numpy() raises for CUDA tensors.
                ba = output.detach().cpu().numpy()
                sa = sa_score
            # str() (not f-string formatting) is deliberate: it keeps the
            # exact textual form of tensor/array elements used previously.
            row_key = query_protein + "," + str(key[0]) + "," + str(sa[0]) + "\n"
            outputs[row_key] = str(ba[0])
    return outputs


