import numpy as np 
import matplotlib.pyplot as plt
import pickle 
from random import shuffle 
import matplotlib.cm as cm
import torch 
from tqdm import tqdm
from tdc import Oracle
from chemutils import is_valid


# Objectives this script knows how to score. The second assignment overrides
# the first, so only 'logp' is processed by the loop below; edit the second
# line to plot other objectives.
props = ['jnk', 'gsk3', 'qed', 'jnkgsk','qedsajnkgsk']
props = ['logp']

# Only this many entries (taken in ascending index order) are drawn per curve.
MAX_PLOTTED_ITERATIONS = 50

# For each objective: load the saved run, re-score every recorded set with the
# oracle, and save a curve of the mean oracle score per index.
for prop in props:
	# --- Select the scoring callable for this objective. ---
	# The branches are mutually exclusive; an unrecognised name is rejected
	# explicitly so a stale `oracle` from a previous iteration is never reused.
	if prop in ['jnk', 'gsk3', 'qed', 'logp']:
		oracle = Oracle(name = prop)
	elif prop == 'jnkgsk':
		o1 = Oracle('jnk')
		o2 = Oracle('gsk3')
		def oracle(smiles):
			# Element-wise average of the two oracle outputs.
			return [(i+j)/2 for i,j in zip(o1(smiles), o2(smiles))]
	elif prop == 'qedsajnkgsk':
		qed = Oracle(name = 'qed')
		jnk = Oracle('jnk3')
		gsk = Oracle('gsk3b')
		from sa import sa 
		def oracle(smiles):
			# NOTE(review): np.mean without an axis collapses all four score
			# collections into a single scalar, unlike the 'jnkgsk' branch which
			# returns one value per input. The std recorded below is then 0.
			# Confirm whether a per-input mean (axis=0) was intended.
			scores = [qed(smiles), sa(smiles), jnk(smiles), gsk(smiles)]
			return np.mean(scores)
	else:
		raise ValueError("unsupported property: " + repr(prop))

	# --- Locate the saved run for this objective. ---
	if prop == 'logp':
		pkl_file = "result/denovo_from_CC_logp2.pkl"
	else:
		pkl_file = "result/denovo_"+prop+".pkl"

	# The pickle holds a pair; only the first element is used here. It maps an
	# index to a (smiles2f, current_set) pair.
	# NOTE: pickle can execute arbitrary code on load; only open trusted files.
	with open(pkl_file, 'rb') as fin:
		idx_2_smiles2f, trace_dict = pickle.load(fin)

	# --- Per-index statistics: (mean f, std f, mean oracle, std oracle). ---
	idx2stat = {}
	for idx,x in tqdm(idx_2_smiles2f.items()):
		smiles2f, current_set = x 
		current_set = list(current_set)
		# The stored-value statistics cover every value in smiles2f, not only
		# the entries belonging to current_set.
		current_f = list(smiles2f.values())

		scores = oracle(current_set)
		idx2stat[idx] = (np.mean(current_f), np.std(current_f),
						np.mean(scores), np.std(scores))

	# --- Plot the mean oracle score for the first indices in sorted order. ---
	sort_idx_lst = sorted(idx_2_smiles2f.keys())[:MAX_PLOTTED_ITERATIONS]
	sort_stats = [idx2stat[idx] for idx in sort_idx_lst]

	avg_list = [stat[2] for stat in sort_stats]
	plt.plot(list(range(len(avg_list))), avg_list)
	plt.xlabel("iteration", fontsize = 18)
	plt.ylabel("Average objective value", fontsize = 19)
	plt.savefig("figure/curve_" + prop + '.png')
	# Clear the axes so the next objective starts from an empty plot.
	plt.cla()


