import argparse
import time
import torch.nn as nn
import torch
from collections import OrderedDict
from .EBA_critic_dataset import spMolDataset, spcollate_fn
from torch.utils.data import DataLoader
import logging
import numpy as np

def gnn_eba_scorer(model, protein_dir, query_protein, valid_mols, *,
                   device="cpu", num_workers=7):
    """Score ``valid_mols`` against ``query_protein`` with a GNN EBA model.

    Builds an ``spMolDataset`` from the inputs and runs it through ``model``
    in batches of ``len(valid_mols)`` samples (i.e. normally one batch), with
    gradient tracking disabled.

    Args:
        model: Module called as ``model((H1, H2, A1, A2, V))``. It is switched
            to eval mode and left there. It is NOT moved to ``device`` here;
            it must already live there.
        protein_dir: Passed through to ``spMolDataset``.
        query_protein: Passed through to ``spMolDataset``.
        valid_mols: Sized collection passed to ``spMolDataset``; its length is
            used as the DataLoader batch size.
        device: Device the input tensors are moved to. Defaults to CPU, the
            previously hard-coded value.
        num_workers: DataLoader worker-process count (previously fixed at 7).

    Returns:
        A list of numpy arrays, one entry per batch. A batch whose ``key``
        equals 0 contributes ``np.array(0)`` (presumably a "could not build
        features" marker -- TODO confirm against ``spcollate_fn``); any other
        batch contributes the model output. An empty ``valid_mols`` yields
        ``[]``.
    """
    # DataLoader rejects batch_size=0, so bail out before building anything.
    # len() (not truthiness) keeps this safe for numpy arrays as input.
    n_mols = len(valid_mols)
    if n_mols == 0:
        return []

    dataset = spMolDataset(protein_dir, query_protein, valid_mols)
    dataloader = DataLoader(dataset, n_mols, shuffle=False,
                            num_workers=num_workers, collate_fn=spcollate_fn)

    device = torch.device(device)
    model.eval()

    ebas = []
    # Inference only: no backward pass ever runs, so skip autograd tracking
    # entirely (this also makes the old per-batch zero_grad() unnecessary).
    with torch.no_grad():
        for sample in dataloader:
            H1, H2, A1, A2, V, _sa_score, key = sample
            if key == 0:
                ebas.append(np.array(0))
                continue
            inputs = tuple(t.to(device) for t in (H1, H2, A1, A2, V))
            output = model(inputs)
            # detach().cpu() makes the numpy conversion valid for any device.
            ebas.append(output.detach().cpu().numpy())
    return ebas



