
### 1. import
import numpy as np 
from tqdm import tqdm 
from matplotlib import pyplot as plt
import pickle 
from random import shuffle 
import torch
import torch.nn as nn
import torch.nn.functional as F
from tdc import Oracle
torch.manual_seed(1)
np.random.seed(2)
from chemutils import * 
'''
  chemutils 
	smiles2differentiable_graph
	differentiable_graph2smiles
'''



## 2. data and oracle
# smiles_lst = ['CC[NH+](CC)[C@](C)(CC)[C@H](O)c1cscc1Br', 
# 			  'COc1ccc(C(=O)N(C)[C@@H](C)C/C(N)=N/O)cc1O', 
# 			  'O=C(Nc1nc[nH]n1)c1cccnc1Nc1cccc(F)c1',    
# 			  'Cc1c(/C=N/c2cc(Br)ccn2)c(O)n2c(nc3ccccc32)c1C#N',
# 			  'C[C@@H]1CN(C(=O)c2cc(Br)cn2C)CC[C@H]1[NH3+]' 
# 			  ]
# Input file (each stripped line is treated as one SMILES string) and the
# text file that result lines are appended to.
data_file = "data/clean_zinc.txt"
result_file = "result/DRD2.txt"
with open(data_file, 'r') as fin:
	lines = list(fin)
# Only the lines from index 1001 onwards are used as optimization inputs.
smiles_lst = [line.strip() for line in lines[1001:]]
smiles = smiles_lst[0]
# Module-level counter of oracle evaluations; `oracle` increments it and
# `inference_single_molecule` resets it for every input molecule.
num_oracle_call = 0
# TDC oracle object created with name 'drd2'; called with a SMILES string in `oracle`.
drd2 = Oracle(name = 'drd2')
def oracle(smiles):
	"""Return `drd2(smiles)` and count the call in the module-level `num_oracle_call`."""
	global num_oracle_call
	# Count first, so the call is recorded even if the scoring below raises.
	num_oracle_call = num_oracle_call + 1
	score = drd2(smiles)
	return score




## 3. load model 
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu' ## cpu is better 
model_ckpt = "save_model/DRD2_epoch_1_iter_78000_validloss_884.ckpt"
# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load checkpoint files that come from a trusted source.
# map_location remaps stored tensors onto `device`, so a checkpoint that was
# saved from a CUDA device still loads on a machine without CUDA.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled model objects -- confirm against the installed version.
gnn = torch.load(model_ckpt, map_location = device)
gnn.switch_device(device)



## 4. inference function 
from inference_utils import optimize_single_molecule_all_generations


def calculate_results(input_smiles, input_score, result_file, best_mol_score_list, oracle, traceback_dict):
	"""Append the optimization outcome for one input molecule to `result_file`.

	On success one tab-separated line is appended with these fields:
	input score, first output score, mean of the first (up to) three output
	scores, input SMILES, first (up to) three output SMILES joined by spaces,
	the current value of the module-level `num_oracle_call`, and the
	trace-back chain joined by '-->'.

	Args:
		input_smiles: SMILES string of the molecule that was optimized.
		input_score: score of `input_smiles`; written to the file and returned.
		result_file: path of the text file that lines are appended to.
		best_mol_score_list: sequence of entries whose item 0 is a SMILES string
			and item 1 is its score (presumably ordered best-first -- confirm
			against `optimize_single_molecule_all_generations`).
		oracle: not used here; kept so existing callers keep working.
		traceback_dict: maps a SMILES string to the SMILES it is traced back to
			(presumably its parent in the previous generation -- confirm).

	Returns:
		None when `best_mol_score_list` is empty or None (a "fail to optimize"
		line is written instead); otherwise `(input_score, first output score)`.
	"""
	# Any empty sequence (or None) means nothing was produced, not just `[]`.
	if not best_mol_score_list:
		with open(result_file, 'a') as fout:
			fout.write("fail to optimize" + input_smiles + '\n')
		return None
	output_scores = [entry[1] for entry in best_mol_score_list]
	output_smiles = [entry[0] for entry in best_mol_score_list]

	#### trace back to input smiles
	trace_smiles = output_smiles[0]
	trace_record = [trace_smiles]
	# Remember visited molecules so a cycle in `traceback_dict`
	# (e.g. A -> B -> A) ends the walk instead of looping forever.
	visited = {trace_smiles}
	while trace_smiles in traceback_dict:
		trace_smiles = traceback_dict[trace_smiles]
		if trace_smiles in visited:
			break
		visited.add(trace_smiles)
		trace_record.append(trace_smiles)
	trace_record = '-->'.join(trace_record)

	with open(result_file, 'a') as fout:
		fields = [
			str(input_score),
			str(output_scores[0]),
			str(np.mean(output_scores[:3])),
			input_smiles,
			' '.join(output_smiles[:3]),
			# Module-level counter maintained by `oracle`; only read here.
			str(num_oracle_call),
			trace_record,
		]
		fout.write('\t'.join(fields) + '\n')
	return input_score, output_scores[0]


def inference_single_molecule(input_smiles, gnn, result_file, generations, population_size):
	"""Optimize one molecule and append its outcome to `result_file`.

	Resets the module-level `num_oracle_call` counter to zero, runs
	`optimize_single_molecule_all_generations` with the module-level `oracle`,
	and hands the three returned values to `calculate_results`.

	Returns:
		Whatever `calculate_results` returns: None, or a tuple of
		(input score, first output score).
	"""
	global num_oracle_call
	# Start counting oracle calls afresh for this input molecule.
	num_oracle_call = 0
	optimization_output = optimize_single_molecule_all_generations(
		input_smiles, gnn, oracle, generations, population_size)
	best_mol_score_list, input_score, traceback_dict = optimization_output
	return calculate_results(
		input_smiles, input_score, result_file,
		best_mol_score_list, oracle, traceback_dict)


def inference_molecule_set(input_smiles_lst, gnn, result_file, generations, population_size):
	"""Run `inference_single_molecule` over every valid molecule in `input_smiles_lst`.

	Molecules for which `is_valid` is falsy are skipped without adding an
	entry, so the returned list can be shorter than the input.

	Returns:
		A list with one entry per processed molecule: None when
		`inference_single_molecule` returned None, otherwise the tuple
		(input_score, output_score).
	"""
	collected = []
	for candidate in tqdm(input_smiles_lst):
		if is_valid(candidate):
			outcome = inference_single_molecule(candidate, gnn, result_file, generations, population_size)
			if outcome is not None:
				# Unpack and rebuild so each entry is a plain 2-tuple.
				input_score, output_score = outcome
				outcome = (input_score, output_score)
			collected.append(outcome)
	return collected





## 5. run 
if __name__ == "__main__":
	# Optimize every molecule in `smiles_lst`; result lines are appended to `result_file`.
	score_lst = inference_molecule_set(
		input_smiles_lst = smiles_lst,
		gnn = gnn,
		result_file = result_file,
		generations = 3,
		population_size = 10,
	)












