from problem import *
import torch
import numpy as np

device_ind = 0
np.random.seed(42)
device = torch.device("cuda:" + str(device_ind) if torch.cuda.is_available() else "cpu")
problem = get_problem('molecular_discovery', device)
# problem.targets : stores embeddings for each protein target, of which there are 102
# problem.objectives : names of each of the 10 objectives

N = 10000     # batch size
a = np.random.randn(N, 32)     # compounds randomly sampled from latent space
# c = problem.targets.detach().cpu().numpy()     # randomly selected protein targets
# c1 = c[:1]
# print(c1.shape)

# c1 = np.repeat(c1, N, axis=0)
# print(c1.shape)
# print(a.shape)
# x = np.concatenate((a, c1), axis=1)       # input to the surrogate model
x = a

y = problem.predict(x)
print(x.min(), x.max())
print(y.min(), y.max())
print(x.shape)
print(y.shape)

np.save('design_baselines/diff/molecular_discovery/md_obj.npy', y)
np.save('design_baselines/diff/molecular_discovery/md_design.npy', x)