from problem import *
import torch


problem = get_problem('molecular_discovery')
# problem.targets : stores embeddings for each protein target, of which there are 102
# problem.objectives : names of each of the 10 objectives

N = 10      # batch size
a = torch.randn(N, 32).cuda()       # compounds randomly sampled from latent space
c = problem.targets[torch.randperm(len(problem.targets))[:N]]       # randomly selected protein targets

x = torch.cat((a, c), dim=1)        # input to the surrogate model

print(problem.evaluate(x, 0).shape)     # output shape: (N, 10), each objective is in the range [0, 1]