import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw, AllChem

from html_utils import HtmlPageVisualizer

import math, random, sys, os
import numpy as np
import argparse
from tqdm import tqdm

import hgraph
from hgraph import *
import rdkit

def check_validity(generated_all_smiles):
    """Filter SMILES strings down to those RDKit can parse.

    Args:
        generated_all_smiles: iterable of SMILES strings.

    Returns:
        A ``(count, valid_mols)`` tuple, where ``valid_mols`` is the list of
        input strings for which ``Chem.MolFromSmiles`` did not return ``None``
        (input order preserved) and ``count`` is the length of that list.
    """
    parsable = [
        smiles
        for smiles in generated_all_smiles
        if Chem.MolFromSmiles(smiles) is not None
    ]
    return len(parsable), parsable

def check_unique(generated_all_smiles):
    """Return how many distinct entries ``generated_all_smiles`` contains.

    Entries are compared as-is (plain equality/hash); no SMILES
    canonicalization is performed here.
    """
    distinct = set(generated_all_smiles)
    return len(distinct)

def check_novelty(generated_all_smiles, train_smiles):
    """Count generated entries that do not occur in the training set.

    Args:
        generated_all_smiles: iterable of generated SMILES strings. Duplicates
            are counted once per occurrence.
        train_smiles: collection (list, set, iterator, ...) of training SMILES
            strings. Entries must be hashable.

    Returns:
        Number of items in ``generated_all_smiles`` that are absent from
        ``train_smiles``. Comparison is by plain string equality.
    """
    # Build the lookup set once: membership tests against a list are a linear
    # scan per generated molecule, and testing against a one-shot iterator
    # would consume it and give wrong answers.
    if isinstance(train_smiles, (set, frozenset)):
        known = train_smiles
    else:
        known = set(train_smiles)
    return sum(1 for sm in generated_all_smiles if sm not in known)

# Raise RDKit's logger threshold to CRITICAL so lower-severity messages are
# not emitted while generated SMILES are parsed.
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)

# Fixed initial torch seed; torch is re-seeded from --seed after the model
# has been loaded.
torch.manual_seed(42)

parser = argparse.ArgumentParser()

# Vocabulary and checkpoint options (no explicit type conversion).
parser.add_argument('--vocab', required=True)
parser.add_argument('--atom_vocab', default=common_atom_vocab)
parser.add_argument('--model', required=True)

# Typed options as (flag, type, default), registered in declaration order.
_TYPED_OPTIONS = (
    # sampling
    ('--seed', int, 7),
    ('--nsample', int, 10000),
    # model hyper-parameters (passed to HierVAE through ``args``)
    ('--rnn_type', str, 'LSTM'),
    ('--hidden_size', int, 250),
    ('--embed_size', int, 250),
    ('--batch_size', int, 100),
    ('--latent_size', int, 32),
    ('--depthT', int, 15),
    ('--depthG', int, 15),
    ('--diterT', int, 1),
    ('--diterG', int, 3),
    ('--dropout', float, 0.0),
    # latent manipulation range and GPU index
    ('--mani_range', int, 1),
    ('--gpu', int, 0),
)
for _flag, _kind, _default in _TYPED_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)

args = parser.parse_args()

# Restrict the CUDA devices visible to this process to the index from --gpu.
os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)

# Each vocab line is split on whitespace into a list of tokens; the token
# lists are handed to PairVocab. The context manager closes the file handle
# (the previous bare open() left it to the garbage collector).
with open(args.vocab) as vocab_file:
    vocab = [x.strip("\r\n ").split() for x in vocab_file]
args.vocab = PairVocab(vocab)

# Reference SMILES (one per line) used later for the novelty count.
with open('./data/qm9/all.txt') as f:
    train_smiles = [line.strip("\r\n ") for line in f]
print (len(train_smiles))

model = HierVAE(args).cuda()

# Element 0 of the loaded checkpoint object is used as the model state dict.
model.load_state_dict(torch.load(args.model)[0])
model.eval()

# Re-seed torch and Python's RNG from --seed before sampling.
torch.manual_seed(args.seed)
random.seed(args.seed)

# base_z = torch.randn(args.batch_size, args.latent_size)

# distances = np.linspace(-args.mani_range,args.mani_range,11)[2:9]
# generated_all_smiles = []
# total_generated_all_smiles = 0
# with torch.no_grad():
#     success_rate = []
#     for prop_name, function in Descriptors.descList:
#         print (prop_name,'/',len(Descriptors.descList))
#         prop_name = prop_name.split('_')[-1]
#         if not os.path.exists('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name):
#             os.makedirs('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name)
#         direction = np.load('./boundaries/boundary_'+prop_name+'.npy')
#         vizer = HtmlPageVisualizer(num_rows=1000+1,num_cols=2)
#         header_str = 'Distance -3\t -2\t -1\t center\t +1\t +2\t +3\t'
#         vizer.set_headers(['', header_str])
#         sample = base_z
#         dist_slide = []
#         label_slide = []
#         smile_slide = []
#         generated_prop_smiles = []
#         for i in range(args.batch_size):
#             dist_sub_slide = []
#             label_sub_slide = []
#             for d in distances:
#                 sample_d = sample + d * direction
#                 sample_d = sample_d.cuda()
#                 try:
#                     smiles_list = model.sample(1, greedy=True, direction=sample_d)
#                 except:
#                     continue
#                 generated_all_smiles.extend(smiles_list)
#                 generated_prop_smiles.extend(smiles_list)
#                 dist_sub_slide.append(smiles_list[0])
#                 label_slide_one = []
#                 # for sm in smiles_list:
#                 label_slide_one = function(Chem.MolFromSmiles(smiles_list[0]))
#                 label_sub_slide.append(label_slide_one)
#             label_slide.append(label_sub_slide)
#             dist_slide.append(dist_sub_slide)
#         label_slide = np.array(label_slide)
#         dist_slide = np.array(dist_slide)
#         total_generated_all_smiles += args.batch_size * len(distances)
#         success = 0
#         for i in range(args.batch_size):
#             print (i,'th',len(label_slide), len(label_slide[:,i]))
#             sucess_checker = label_slide[i,:].tolist()
#             if len(sucess_checker) == 7 and (all(sucess_checker[idx] <= sucess_checker[idx+1] for idx in range(len(sucess_checker)-1)) or all(sucess_checker[idx] >= sucess_checker[idx+1] for idx in range(len(sucess_checker)-1))) and len(set(sucess_checker)) != 1:
#                 success += 1
#                 labels_for_success = ['{:.2f}'.format(label_score) for mol, label_score in zip(dist_slide[:,i], label_slide[:,i])]
#                 print (str(i),'success',str(success)+'/'+str(args.batch_size))
#                 smile_slide = [Chem.MolFromSmiles(sms) for sms in dist_slide[i,:]]
#                 img = Draw.MolsToGridImage(smile_slide, legends=labels_for_success, molsPerRow=7,
#                                     subImgSize=(200,200))
#                 img.save('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name+'/'+str(i)+'_'+prop_name+'.png')
#                 vizer.set_cell(i, 1, image=np.array(img))
#                 vizer.save('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name+'.html')
#         s_rate = (success/args.batch_size)*100
#         print (prop_name, 'sucess rate',s_rate)
#         validity = check_validity(generated_prop_smiles)
#         print ('validity', validity, '/', args.batch_size * len(distances))
#         novelty = check_novelty(generated_prop_smiles,train_smiles)
#         print ('novelty', novelty, '/', args.batch_size * len(distances))
#         uniqueness = check_unique(generated_prop_smiles)
#         print ('uniqueness', uniqueness, '/', args.batch_size * len(distances))
#         success_rate.append(s_rate)
#     print ('total success rate', np.mean(success_rate))
#     validity = check_validity(generated_all_smiles)
#     print ('validity', validity, '/', total_generated_all_smiles)
#     novelty = check_novelty(generated_all_smiles,train_smiles)
#     print ('novelty', novelty, '/', total_generated_all_smiles)
#     uniqueness = check_unique(generated_all_smiles)
#     print ('uniqueness', uniqueness, '/', total_generated_all_smiles)

# exit(0)

# base_z = torch.randn(20, args.latent_size)
# distances = np.linspace(-args.mani_range,args.mani_range,11)
# cell_count = 0
# with torch.no_grad():
#     for prop_name, function in Descriptors.descList:
#         prop_name = prop_name.split('_')[-1]
#         if not os.path.exists('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name):
#             os.makedirs('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name)
#         direction = np.load('./boundaries/boundary_'+prop_name+'.npy')
#         vizer = HtmlPageVisualizer(num_rows=1000,num_cols=2)
#         header_str = 'Distance -3\t -2\t -1\t center\t +1\t +2\t +3\t'
#         vizer.set_headers(['', header_str])
#         for i in tqdm(range(args.nsample // args.batch_size)):
#             sample = base_z[i]
#             dist_slide = []
#             label_slide = []
#             smile_slide = []
#             for d in distances:
#                 sample_d = sample + d * direction
#                 sample_d = sample_d.cuda()
#                 smiles_list = model.sample(args.batch_size, greedy=True, direction=sample_d)
#                 dist_slide.append(smiles_list[0])
#                 s = Chem.MolFromSmiles(smiles_list[0])
#                 label_slide.append(function(s))
#                 smile_slide.append(s)
#             print (len(dist_slide))
#             labels_for_success = [mol+'\n'+'{:.2f}\n'.format(label_score) for mol, label_score in zip(dist_slide, label_slide)]
#             img = Draw.MolsToGridImage(smile_slide, legends=labels_for_success, molsPerRow=11,
#                                 subImgSize=(200,200))
#             img.save('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name+'/'+str(i)+'_'+prop_name+'.png')
#             cell_count += 1
#             vizer.set_cell(cell_count, 1, image=np.array(img))
#             vizer.save('./boundaries_sorted_'+str(args.mani_range)+'/'+prop_name+'.html')
                # for i in range(len(smiles_list)):
                #     generated_latent.append(latent_z[i].cpu().detach().numpy())
                #     s = Chem.MolFromSmiles(smiles_list[i])
                #     props = []
                #     for descriptor, function in Descriptors.descList:
                #         props.append(function(s))
                #     generated_prop.append(props)
                # for _,smiles in enumerate(smiles_list):
                #     print(smiles)

# Sample nsample // batch_size batches from the model, recording for every
# generated SMILES its latent vector and one value per RDKit descriptor, then
# save the three aligned arrays and print validity / novelty / uniqueness.
generated_latent = []
generated_prop = []
generated_all_smiles = []
# Placeholder descriptor row for SMILES that RDKit cannot parse: keeps
# prop.npy row-aligned with smiles.npy and z.npy instead of crashing.
nan_props = [float('nan')] * len(Descriptors.descList)
with torch.no_grad():
    for _ in tqdm(range(args.nsample // args.batch_size)):
        smiles_list, latent_z = model.sample(args.batch_size, greedy=True)
        generated_all_smiles.extend(smiles_list)
        for i, smiles in enumerate(smiles_list):
            generated_latent.append(latent_z[i].cpu().detach().numpy())
            s = Chem.MolFromSmiles(smiles)
            if s is None:
                # Descriptor functions cannot be evaluated on a failed parse.
                props = list(nan_props)
            else:
                props = [function(s) for descriptor, function in Descriptors.descList]
            generated_prop.append(props)
        for smiles in smiles_list:
            print(smiles)

generated_latent = np.array(generated_latent)
generated_prop = np.array(generated_prop)
generated_all_smiles = np.array(generated_all_smiles)

# Create the output directory up front so a missing folder cannot discard the
# results of the whole sampling run.
os.makedirs('./qm9_saved_latent', exist_ok=True)
np.save('./qm9_saved_latent/smiles.npy',generated_all_smiles)
np.save('./qm9_saved_latent/z.npy',generated_latent)
np.save('./qm9_saved_latent/prop.npy',generated_prop)

# check_validity returns (count, valid_smiles); report only the count rather
# than printing the whole tuple with every valid SMILES string in it.
validity, valid_smiles = check_validity(generated_all_smiles)
print ('validity', validity, '/', len(generated_all_smiles))
novelty = check_novelty(generated_all_smiles,train_smiles)
print ('novelty', novelty, '/', len(generated_all_smiles))
uniqueness = check_unique(generated_all_smiles)
print ('uniqueness', uniqueness, '/', len(generated_all_smiles))