'''
start from 'C'


'''

### 1. import
import os
import numpy as np 
from time import time
from tqdm import tqdm 
from matplotlib import pyplot as plt
import pickle 
from random import shuffle 
import torch
import torch.nn as nn
import torch.nn.functional as F
from tdc import Oracle
torch.manual_seed(1)
np.random.seed(2)
from tdc import Evaluator
from chemutils import * 

## 2. data and oracle 
# Property oracles, instantiated in the order QED, LogP, JNK3, GSK3B.
# Each one is later called with a SMILES string to score a molecule.
qed, logp, jnk, gsk = (
	Oracle(name=oracle_name)
	for oracle_name in ('qed', 'logp', 'JNK3', 'GSK3B')
)
def foracle(smiles):
	"""Score ``smiles`` with all four oracles and fuse them into one value.

	The JNK3 and GSK3B scores are floored at 0.02 before fusion; QED and
	LogP are passed through unchanged. The four values are combined by
	``qed_logp_jnk_gsk_fusion`` (not defined in this file; presumably
	provided by the ``chemutils`` star import -- confirm).
	"""
	qed_score = qed(smiles)
	logp_score = logp(smiles)
	jnk_score = max(0.02, jnk(smiles))
	gsk_score = max(0.02, gsk(smiles))
	return qed_logp_jnk_gsk_fusion(qed_score, logp_score, jnk_score, gsk_score)







## 3. evaluate the optimization results
# Every non-blank line of the result file must hold four whitespace-separated
# fields:  input_smiles  input_score  output_smiles  output_score
file = "result/modify_qedlogpjnkgsk.txt"
with open(file, 'r') as fin:
	lines = fin.readlines() 
from tdc import Oracle
oracle_dic = {'QED':qed, 'LogP':logp, 'JNK3':jnk, 'GSK3B':gsk}
# Minimum value the OUTPUT molecule must reach on each property for the
# pair to be counted as a success.
thre_dic = {'QED':0.6, 'LogP':0, 'JNK3':0.2, 'GSK3B':0.2}
f_improve_lst = []	# output_score - input_score for every kept pair
sim_lst = []	# similarity(input, output) for every kept pair
prop2improvelst = {prop: [] for prop in oracle_dic}	# per-property (output - input)
prop2value = {prop: [] for prop in oracle_dic}	# per-property output value

success_num, total_num = 0,0
# Only the last 200 lines of the file are evaluated.
for line in lines[-200:]:
	fields = line.split()
	if not fields:
		# Blank line (e.g. trailing newline at end of file): nothing to parse.
		continue
	if len(fields) != 4:
		raise ValueError(
			"expected 4 whitespace-separated fields "
			"(input_smiles input_score output_smiles output_score), "
			"got %d in line %r" % (len(fields), line)
		)
	input_smiles, input_score, output_smiles, output_score = fields
	f_improve = float(output_score) - float(input_score)
	# Pairs whose score improved by less than 0.1 are excluded from all statistics.
	if f_improve < 0.1:
		continue 
	f_improve_lst.append(f_improve)
	sim = similarity(input_smiles, output_smiles)
	sim_lst.append(sim)
	# Scores are shown truncated to 4 characters to keep the report compact.
	print("input", input_smiles, input_score[:4])
	print("output", output_smiles, output_score[:4])
	print("similarity", str(sim)[:5])
	success = True
	for name, oracle in oracle_dic.items():
		v1,v2 = oracle(input_smiles), oracle(output_smiles)
		print('\t'+name, str(v1)[:4], str(v2)[:4])
		prop2improvelst[name].append(v2-v1)
		prop2value[name].append(v2)
		# A single property below its threshold makes the whole pair a failure.
		if v2 < thre_dic[name]:
			success = False
	total_num += 1
	if success:
		success_num += 1


if total_num == 0:
	# Without this guard the success rate divides by zero and the numpy
	# statistics of empty lists are NaN (with RuntimeWarnings).
	print('no molecule pair passed the f-improvement filter; nothing to summarize')
else:
	print('========= f improve =========', str(np.mean(f_improve_lst))[:4], str(np.std(f_improve_lst))[:4])
	print('success rate', success_num/total_num)
	for prop,improve_lst in prop2improvelst.items():
		value_lst = prop2value[prop]
		print(prop, 'improve' ,str(np.mean(improve_lst))[:4], str(np.std(improve_lst))[:4])
		print(prop ,str(np.mean(value_lst))[:4], str(np.std(value_lst))[:4])








'''
Sample output (per-pair report printed by the evaluation loop above):
input CC[C@@](C)(NC(=O)[C@@H]1CCS(=O)(=O)C1)C(=O)[O-] 0.28
output CCC(C)(NC(=O)C1CCS(=O)(=O)C1)C(=O)N1CC=CCC1 0.35
similarity 0.588
	QED 0.66 0.76
	LogP -3.5 -2.0
	JNK3 0.02 0.04
	GSK3B 0.01 0.05


input [NH3+][C@H]1CCCC[C@@H]1NC(=O)Cc1ccsc1 0.35
output O=C(Cc1ccsc1)NC1CCC(=C2C=CCCN2)CC1 0.43
similarity 0.438
	QED 0.81 0.90
	LogP -1.6 0.62
	JNK3 0.02 0.05
	GSK3B 0.0 0.05


'''


