import os
import numpy as np 
from time import time
from tqdm import tqdm 
from matplotlib import pyplot as plt
import pickle 
from random import shuffle 
import torch
import torch.nn as nn
import torch.nn.functional as F
from tdc import Oracle
torch.manual_seed(1)
np.random.seed(2)
import random 
random.seed(1)
from chemutils import * 
'''
chemutils 
	smiles2differentiable_graph
	differentiable_graph2smiles
	qed_logp_jnk_gsk_fusion
'''
from tdc import Evaluator


# Seed population for the de novo search: a single starting SMILES.
# (Another seed used previously: 'C1=CC=CC=C1NC2=NC=CC=N2')
start_smiles_lst = ['C1(N)=NC=CC=N1']
# TDC oracle that scores molecules for the 'GSK3B' objective.
orac = Oracle(name='GSK3B')
def oracle(smiles):
	"""Return the score that the module-level TDC oracle ``orac`` assigns to ``smiles``."""
	score = orac(smiles)
	return score

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu' ## cpu is better 
prop = 'gsk'  # tag used in the result file names
# model_ckpt = "save_model/JNK3_epoch_0_iter_93000_validloss_0.0772.ckpt"
model_ckpt = "save_model0/qed_logp_jnk_gsk_epoch_4_iter_14000_validloss_9715.ckpt"

# The checkpoint holds a whole pickled model object (it exposes
# ``switch_device``), not just a state_dict.
#  * ``map_location`` lets a checkpoint saved on GPU load on a CPU-only machine.
#  * PyTorch >= 2.6 defaults ``weights_only=True``, which refuses full-model
#    pickles, so opt out explicitly when the installed torch knows the flag
#    (older versions do not accept it).
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
import inspect
_load_kwargs = {'map_location': device}
if 'weights_only' in inspect.signature(torch.load).parameters:
	_load_kwargs['weights_only'] = False
gnn = torch.load(model_ckpt, **_load_kwargs)
gnn.switch_device(device)



## 4. inference function 
from inference_utils import * 





def distribution_learning(start_smiles_lst, gnn, oracle, generations, population_size, lamb, topk, epsilon, result_pkl):
	"""Run an iterative, GNN-guided de novo molecule search.

	Each generation does the following:
	  1. Every molecule of the current population is expanded with the
	     editing helpers from ``inference_utils``.
	  2. All proposals are scored with ``oracle_screening``, which returns
	     (smiles, score) pairs.
	  3. ``dpp`` picks the next population of ``population_size`` molecules.
	     It is presumably a determinantal-point-process selection that trades
	     score against diversity, controlled by ``lamb``.

	Progress is checkpointed after every generation:
	  * ``result_pkl`` receives ``(idx_2_smiles2f, trace_dict)``.
	    ``idx_2_smiles2f[g]`` is ``(smiles -> score dict, selected
	    population)`` for generation ``g``; ``g == -1`` is the start
	    population. ``trace_dict`` maps each proposed SMILES to the first
	    parent it was generated from.
	  * ``"result/denovo_" + prop + "_f_t.txt"`` receives one
	    ``mean<TAB>std`` line per generation for the selected population.
	    ``prop`` is the module-level global.

	Args:
		start_smiles_lst: SMILES strings of the initial population.
		gnn: trained model, passed through to the editing helpers.
		oracle: callable mapping a SMILES string to a score.
		generations: number of expansion/selection rounds.
		population_size: number of molecules kept per generation.
		lamb: forwarded to ``dpp``.
		topk, epsilon: forwarded to
			``optimize_single_molecule_one_iterate_nonleafexpand``.
		result_pkl: path of the pickle checkpoint.

	Raises:
		ValueError: if ``start_smiles_lst`` is empty.
	"""
	f_txt = "result/denovo_" + prop + "_f_t.txt"
	# Create output directories up front so a missing folder does not crash
	# the run only after the first (expensive) generation.
	for path in (result_pkl, f_txt):
		parent = os.path.dirname(path)
		if parent:
			os.makedirs(parent, exist_ok=True)

	current_set = set(start_smiles_lst)
	if not current_set:
		raise ValueError("start_smiles_lst must contain at least one SMILES")

	trace_dict = dict()
	# Score each distinct start molecule exactly once (first-seen order) and
	# reuse the scores: oracle calls are the budgeted resource.
	smiles2f_new = {smiles: oracle(smiles) for smiles in dict.fromkeys(start_smiles_lst)}
	start_scores = list(smiles2f_new.values())
	f_lst = [(np.mean(start_scores), np.std(start_scores))]
	idx_2_smiles2f = {-1: (smiles2f_new, current_set)}

	for i_gen in tqdm(range(generations)):
		# 1. Expand every member of the population into candidate molecules.
		next_set = set()
		for smiles in current_set:
			if substr_num(smiles) < 3:
				# Molecules with fewer than 3 substructures (per substr_num): plain expansion.
				smiles_set = optimize_single_molecule_one_iterate(smiles, gnn)
			else:
				smiles_set = optimize_single_molecule_one_iterate_nonleafexpand(smiles, gnn, topk = topk, epsilon = epsilon)
			for smi in smiles_set:
				trace_dict.setdefault(smi, smiles)  # keep the first parent seen
			next_set.update(smiles_set)

		# 2. Score all proposals. Molecules seen in earlier generations are
		#    deliberately allowed to reappear.
		smiles_score_lst = oracle_screening(next_set, oracle)
		print(smiles_score_lst[:12])

		# 3. Select the next population with DPP. A plain top-k alternative
		#    would be [s for s, _ in smiles_score_lst[:population_size]].
		current_set, _, _ = dpp(smiles_score_lst = smiles_score_lst, num_return = population_size, lamb = lamb)

		# 4. Checkpoint (file handle closed deterministically).
		smiles2f_new = {smiles: score for smiles, score in smiles_score_lst}
		idx_2_smiles2f[i_gen] = smiles2f_new, current_set
		with open(result_pkl, 'wb') as fout:
			pickle.dump((idx_2_smiles2f, trace_dict), fout)

		# 5. Record mean/std of the selected population's scores.
		score_lst = [smiles2f_new[smiles] for smiles in current_set]
		f_lst.append((np.mean(score_lst), np.std(score_lst)))
		str_f_lst = [str(mean_f)[:5] + '\t' + str(std_f)[:5] for mean_f, std_f in f_lst]
		with open(f_txt, 'w') as fout:
			fout.write('\n'.join(str_f_lst))





## 5. run 
if __name__ == "__main__":
	# Search hyper-parameters: `population_size` molecules are kept for each of
	# `generations` rounds; lamb / topk / epsilon are forwarded unchanged.
	run_config = dict(
		generations = 100,
		population_size = 10,
		lamb = 2,
		topk = 5,
		epsilon = 0.7,
	)
	run_config['result_pkl'] = "result/denovo_" + prop + ".pkl"
	distribution_learning(start_smiles_lst, gnn, oracle, **run_config)



"""
                 JNK3                  GSK3B                            LogP
         Nov   Div   SR    #oracle   
DST-old 97.8% 0.721 92.8%  10K+5K  98.3% 0.732 91.0%  10K+5K  100.0% 0.704 100.0% 10K+5K  	
DST-new	97.5% 0.732	94.0%  10K+5K  98.5% 0.748 92.8%  10K+5K  100.0% 0.747 100.0% 10K+5K  


   JNK3  Nov   Div   SR    #oracle   
DST-old 97.8% 0.721 92.8%  10K+5K
DST-new	97.5% 0.732	94.0%  10K+5K

  GSK3B   Nov   Div  SR    #oracle   
DST-old  98.3% 0.732 91.0%  10K+5K 
DST-new	 98.5% 0.748 92.8%  10K+5K 
          
  LogP    Nov   Div   SR   #oracle	
DST-old 100.0% 0.704 100.0% 10K+5K  	
DST-new	100.0% 0.747 100.0% 10K+5K  


               JNK3+GSK3B            QED+SA+JNK3+GSK3B  
         Nov   Div   SR  #oracle  Nov   Div   SR  #oracle   
DST-old 97.5% 0.721 90.9% 10K+5K 98.9% 0.729 78.6% 20K+5K
DST-new	97.3% 0.745 92.0% 10K+5K 96.4% 0.774 81.3% 20K+5K



JNK3+GSK3B       Nov   Div   SR  #oracle
DST-linear      97.5% 0.721 90.9% 10K+5K
DST-nonlinear	97.5% 0.745 92.0% 10K+5K

QED+SA+JNK3+GSK3B  Nov   Div   SR  #oracle   
  DST-linear      98.9% 0.729 78.6% 20K+5K
  DST-nonlinear	  98.8% 0.774 79.8% 20K+5K



"""






