import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw, AllChem

from html_utils import HtmlPageVisualizer

import math, random, sys, os
import numpy as np
import argparse
from tqdm import tqdm

import hgraph
from hgraph import *
import rdkit

def check_validity(generated_all_smiles):
    """Count the SMILES strings that RDKit is able to parse.

    Args:
        generated_all_smiles: Iterable of SMILES strings. Duplicates are
            counted once per occurrence.

    Returns:
        The number of entries for which ``Chem.MolFromSmiles`` returns a
        molecule object instead of ``None``.
    """
    parsed = (Chem.MolFromSmiles(smi) for smi in generated_all_smiles)
    return sum(mol is not None for mol in parsed)

def check_unique(generated_all_smiles):
    """Return how many distinct entries the input contains.

    Entries are compared as given (plain equality/hash of each item); no
    canonicalisation of the SMILES strings is performed here.

    Args:
        generated_all_smiles: Iterable of hashable items (SMILES strings).

    Returns:
        The number of distinct items.
    """
    distinct = frozenset(generated_all_smiles)
    return len(distinct)

def check_novelty(generated_all_smiles, train_smiles):
    """Count generated SMILES that are not present in the training data.

    Every occurrence in ``generated_all_smiles`` is counted, so a novel
    string that appears twice contributes two.

    Args:
        generated_all_smiles: Iterable of SMILES strings to check.
        train_smiles: Container of known SMILES strings. A ``list`` or
            ``tuple`` is converted to a set once so each lookup is O(1)
            instead of a linear scan; any other container (set, dict,
            str, ...) is used as-is with its own ``in`` semantics.

    Returns:
        The number of entries of ``generated_all_smiles`` for which
        ``entry not in train_smiles`` holds.
    """
    known = train_smiles
    if isinstance(train_smiles, (list, tuple)):
        try:
            known = frozenset(train_smiles)
        except TypeError:
            # Unhashable entries: fall back to the plain sequence scan.
            known = train_smiles
    return sum(1 for smi in generated_all_smiles if smi not in known)

# Raise the RDKit logger threshold to CRITICAL so lower-severity messages
# are not emitted.
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)

# Seed torch's RNG with a fixed value up front. torch is seeded again from
# --seed further down, after the model has been built and loaded.
torch.manual_seed(42)

def _build_parser():
    """Create the command-line parser for this script.

    ``--vocab`` and ``--model`` are required. ``--atom_vocab`` defaults to
    ``common_atom_vocab``. All remaining options are typed and carry the
    defaults listed in the table below; they are registered in table order.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--vocab', required=True)
    cli.add_argument('--atom_vocab', default=common_atom_vocab)
    cli.add_argument('--model', required=True)

    # (flag, type, default) for every option that has a typed default.
    typed_options = (
        ('--seed', int, 7),
        ('--nsample', int, 10000),
        ('--rnn_type', str, 'LSTM'),
        ('--hidden_size', int, 250),
        ('--embed_size', int, 250),
        ('--batch_size', int, 7),
        ('--latent_size', int, 32),
        ('--depthT', int, 15),
        ('--depthG', int, 15),
        ('--diterT', int, 1),
        ('--diterG', int, 3),
        ('--dropout', float, 0.0),
        ('--mani_range', int, 1),
        ('--gpu', int, 0),
    )
    for flag, value_type, default_value in typed_options:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Expose only the GPU selected with --gpu through CUDA_VISIBLE_DEVICES.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# Each vocabulary line is stripped and split on whitespace into its fields.
# The file is read inside a context manager so the handle is closed rather
# than left open for the lifetime of the process.
with open(args.vocab) as vocab_file:
    vocab = [line.strip("\r\n ") .split() for line in vocab_file]
# From here on args.vocab holds the PairVocab object, not the file path.
args.vocab = PairVocab(vocab)

# Training-set SMILES, one per line; only their count is printed here.
with open('./data/zinc/all.txt') as f:
    train_smiles = [line.strip("\r\n ") for line in f]
print(len(train_smiles))

# Build the model from the parsed arguments and move it to the GPU.
model = HierVAE(args).cuda()

# The loaded checkpoint is indexed with [0] to obtain the state dict.
# NOTE(review): depending on the torch version / weights_only setting,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
model.load_state_dict(torch.load(args.model)[0])
model.eval()

# Re-seed torch and Python's `random` from --seed before drawing the latents.
torch.manual_seed(args.seed)
random.seed(args.seed)
# 1000 base latent vectors of length latent_size, drawn from a standard normal.
base_z = torch.randn(1000, args.latent_size)

# Offsets applied along each property direction: 11 evenly spaced values in
# [-mani_range, mani_range], of which elements 2..8 (7 values) are kept.
distances = np.linspace(-args.mani_range, args.mani_range, 11)[2:9]
prop_names = np.load('./prop_name.npy')
with torch.no_grad():
    for prop_name in tqdm(prop_names):
        print (prop_name,'/',len(Descriptors.descList))
        # Only the part after the last '_' is used for file and directory names.
        prop_name = prop_name.split('_')[-1]

        # The directory must be derived from the split name, because that is
        # the path the results are saved under; creating it from the full
        # name leaves the save target missing whenever the name contains '_'.
        out_dir = './hvae_zinc250k_manipulation_'+str(args.mani_range)+'/'+prop_name
        os.makedirs(out_dir, exist_ok=True)

        direction = np.load('./zinc_boundaries/boundary_'+prop_name+'.npy')

        skipped = 0
        for i in tqdm(range(len(base_z))):
            sample = base_z[i]
            # One shifted latent per offset, concatenated into a single batch.
            # NOTE(review): assumes `sample + d * direction` yields tensors
            # that torch.cat stacks into len(distances) rows -- confirm the
            # shape/dtype of the stored boundary arrays.
            samples = torch.cat([sample + d * direction for d in distances]).cuda()
            try:
                smiles_list = model.sample(len(distances), greedy=True, direction=samples)
            except Exception:
                # Best effort: a failed decode skips this latent, but Ctrl-C
                # and SystemExit are no longer swallowed.
                skipped += 1
                continue
            if smiles_list != []:
                # Passing the path lets numpy open and close the file itself.
                np.save(out_dir+'/smiles_'+str(i)+'.npy', np.array(smiles_list))

        if skipped:
            # Report on stderr so silent mass failures become visible without
            # changing what is written to stdout.
            print('skipped', skipped, 'samples for', prop_name, file=sys.stderr)

