import os
import numpy as np 
from time import time
from tqdm import tqdm 
from matplotlib import pyplot as plt
import pickle 
from random import shuffle 
import torch
import torch.nn as nn
import torch.nn.functional as F
from tdc import Oracle
torch.manual_seed(1)
np.random.seed(2)
import random 
from chemutils import * 
'''
chemutils 
	smiles2differentiable_graph
	differentiable_graph2smiles
	qed_logp_jnk_gsk_fusion
'''
from tdc import Evaluator


data_file = "data/clean_zinc.txt"
with open(data_file, 'r') as fin:
	lines = fin.readlines()
# One SMILES per line; blank lines are skipped so the oracle never sees ''.
smiles_lst = [line.strip() for line in lines if line.strip()]
# Shuffle with a dedicated, seeded RNG. The global `random` module is only
# seeded below, so shuffling with it made the selected subset differ on
# every run.
random.Random(1).shuffle(smiles_lst)
# Drop the first 1200 shuffled molecules (purpose not visible in this file;
# presumably reserved for another split -- TODO confirm).
smiles_lst = smiles_lst[1200:]
# Seed the global RNG exactly as before so any later `random` use sees the
# same state it did originally.
random.seed(1)


# TDC property oracles, constructed in order: QED, LogP, JNK3, GSK3B.
# Only `jnk` and `gsk` feed the objective in `oracle` below; `qed` and
# `logp` appear only in that function's commented-out variant.
qed, logp, jnk, gsk = [Oracle(name = oracle_name) for oracle_name in ('qed', 'logp', 'JNK3', 'GSK3B')]

def logp_modifier(logp_score):
	'''Map a raw LogP value linearly onto [0, 1].

	LogP = -4 maps to 0.0 and LogP = 4 maps to 1.0; anything outside
	that window is clipped to the nearest bound.
	'''
	scaled = 0.125 * (logp_score + 4)
	capped_above = min(1.0, scaled)
	return max(0.0, capped_above)

def oracle(smiles):
	'''Score a molecule as the mean of its JNK3 and GSK3B oracle scores.

	Uses the module-level `jnk` and `gsk` oracles (JNK3 is queried first)
	and returns the result of `np.mean`. An earlier variant averaged
	qed, logp, and floored (max 0.02) jnk/gsk scores instead.
	'''
	jnk_score = jnk(smiles)
	gsk_score = gsk(smiles)
	return np.mean([jnk_score, gsk_score])

# Inference device. The original author noted "cpu is better" here; use
# 'cuda' if torch.cuda.is_available() else 'cpu' to try a GPU run instead.
device = 'cpu'
prop = 'jnkgsk'  # tag used in the result file name
model_ckpt = "save_model/jnkgsk_epoch_2_iter_68000_validloss_1329.ckpt"

# The checkpoint holds a whole pickled model object (it has `switch_device`),
# not a state_dict, so it needs full unpickling -- only load trusted files.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# newer torch defaults to weights_only=True, which would reject this object.
try:
	gnn = torch.load(model_ckpt, map_location = device, weights_only = False)
except TypeError:
	# torch < 1.13 has no `weights_only` keyword and always unpickles fully.
	gnn = torch.load(model_ckpt, map_location = device)
gnn.switch_device(device)
from inference_utils import * 


def distribution_learning(input_smiles, gnn, oracle, generations, population_size, lamb, topk, epsilon, result_file):
	'''Iteratively edit one molecule with the GNN and record the best result.

	Each generation, every SMILES in the current population is expanded into
	candidate edits, the union is scored via `oracle_screening`, and the next
	population is chosen with `dpp`. When all generations finish, one
	tab-separated line is appended to `result_file`:
	input_smiles, input_score, best_smiles, best_score (scores to 3 decimals).

	Args:
		input_smiles: starting molecule (SMILES string).
		gnn: model forwarded to the optimize_single_molecule_one_iterate* helpers.
		oracle: callable mapping a SMILES string to a numeric score.
		generations: number of expand / score / select rounds.
		population_size: number of molecules `dpp` returns each round.
		lamb: forwarded unchanged to `dpp`.
		topk, epsilon: forwarded to `optimize_single_molecule_one_iterate_v3`.
		result_file: path of the text file the result line is appended to;
			its parent directory is created if missing.
	'''
	# Create the output directory before the (slow) search so a missing
	# folder cannot discard the work at the final write.
	result_dir = os.path.dirname(result_file)
	if result_dir:
		os.makedirs(result_dir, exist_ok = True)

	input_score = oracle(input_smiles)
	print("input", input_smiles, input_score)
	best_score, best_smiles = input_score, input_smiles
	current_set = set([input_smiles])
	for i_gen in tqdm(range(generations)):
		next_set = set()
		for smiles in current_set:
			if substr_num(smiles) < 3:  # "short" SMILES: basic one-step search
				smiles_set = optimize_single_molecule_one_iterate(smiles, gnn)
			else:
				smiles_set = optimize_single_molecule_one_iterate_v3(smiles, gnn, topk = topk, epsilon = epsilon)
			next_set.update(smiles_set)
		# Expected to be (smiles, score) pairs sorted best-first, per the
		# original note -- TODO confirm against inference_utils.
		smiles_score_lst = oracle_screening(next_set, oracle)
		if not smiles_score_lst:
			# Nothing proposed: keep the current population for the next round.
			continue
		print(smiles_score_lst[:3])
		smi, score = smiles_score_lst[0]
		if score > best_score:
			best_smiles, best_score = smi, score
		# Option II: select the next population with `dpp`. The alternative
		# (Option I) was plain top-k: [s for s, _ in smiles_score_lst[:population_size]].
		current_set, _, _ = dpp(smiles_score_lst = smiles_score_lst, num_return = population_size, lamb = lamb)

	# Format scores numerically: slicing str(score)[:5] mangled values printed
	# in scientific notation (e.g. 1.234e-05 was written as "1.234").
	with open(result_file, 'a') as fout:
		fout.write(f"{input_smiles}\t{input_score:.3f}\t{best_smiles}\t{best_score:.3f}\n")


if __name__ == "__main__":
	# Search settings shared by every starting molecule.
	generations = 1
	population_size = 10
	result_file = "result/modify_" + prop + ".txt"
	search_kwargs = dict(
		generations = generations,
		population_size = population_size,
		lamb = 2,
		topk = 10,
		epsilon = 0.7,
		result_file = result_file,
	)
	# Optimize each loaded molecule independently; results are appended to
	# `result_file` one line per molecule.
	for smiles in smiles_lst:
		distribution_learning(smiles, gnn, oracle, **search_kwargs)










