from arguments import parse_arguments
import torch, random
from model.CQVAE import CQ_VAE
from model.mol_graph import MolGraph
import sys
import os.path as path
import rdkit.Chem as Chem
from guacamol.utils.chemistry import is_valid
import torchvision
from rdkit.Chem import Draw
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from model.mydataclass import ModelParams, PathTool
import torch.multiprocessing as mp
import os

def draw_mols(mols, smis, job_name, tb:SummaryWriter, mols_per_row=6, sub_img_size=(200, 200)):
    """Render molecules into one grid image and log it to TensorBoard.

    Each molecule is drawn with ``smis[i]`` as its legend and placed in the
    grid cell for position ``i`` (row-major). ``None`` entries leave their
    cell blank (transparent), so cell positions stay aligned with ``smis``.

    Args:
        mols: sequence of RDKit molecules; entries may be ``None``.
        smis: legends, indexed in parallel with ``mols``.
        job_name: TensorBoard tag the image is logged under.
        tb: writer used to log the image via ``add_image``.
        mols_per_row: number of grid columns (default 6).
        sub_img_size: ``(width, height)`` in pixels of each cell
            (default ``(200, 200)``).

    Raises:
        ValueError: if ``mols_per_row`` is smaller than 1.
    """
    if mols_per_row < 1:
        raise ValueError(f"mols_per_row must be >= 1, got {mols_per_row}")
    if not mols:
        # Nothing to draw: skip instead of building/logging a zero-height canvas.
        return

    cell_w, cell_h = sub_img_size
    n_rows = -(-len(mols) // mols_per_row)  # ceiling division

    # Fully transparent canvas, so cells of None molecules stay empty.
    canvas = Image.new("RGBA", (mols_per_row * cell_w, n_rows * cell_h), (255, 255, 255, 0))
    for i, mol in enumerate(mols):
        if mol is None:
            continue
        row, col = divmod(i, mols_per_row)
        img = Draw.MolToImage(mol, size=sub_img_size, legend=smis[i])
        canvas.paste(img, (col * cell_w, row * cell_h))

    # PIL image -> (C, H, W) float tensor expected by add_image's default format.
    grid = torchvision.transforms.ToTensor()(canvas)
    grid = torchvision.utils.make_grid(grid)
    tb.add_image(job_name, grid)

if __name__ == '__main__':
    # Script entry point: build configuration from CLI arguments, load a
    # generator via CQ_VAE.load_generator, and sample from it.
    
    args = parse_arguments()

    # Build the model configuration (ModelParams) and path helper (PathTool)
    # from the parsed arguments.
    model_params = ModelParams.from_arguments(args)
    pathtool = PathTool.from_arguments(args)

    # Seed torch and Python's `random` with args.seed, presumably for
    # reproducible sampling. NOTE(review): worker processes started with
    # "spawn" may need their own seeding — TODO confirm inside generate().
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    
    # Select the "spawn" start method for child processes. It must be set
    # before any process is started; calling it a second time raises
    # RuntimeError. NOTE(review): presumably required because generate() runs
    # args.num_workers worker processes that use torch state — confirm.
    mp.set_start_method("spawn")
    generator = CQ_VAE.load_generator(model_params, pathtool)
    # Request args.num_sample samples using args.num_workers workers.
    # NOTE(review): the return type of generate() is not visible here, and
    # `samples` is not consumed within this view.
    samples = generator.generate(args.num_sample, args.num_workers)
