import torch, random
from arguments import parse_arguments
from typing import List
import os.path as path

from model.mol_graph import MolGraph
from model.CQVAE import CQ_VAE
from model.benchmarks import QuickBenchGenerator, QuickBenchmark

from torch.utils.tensorboard import SummaryWriter

from guacamol.assess_distribution_learning import assess_distribution_learning
from guacamol.distribution_matching_generator import DistributionMatchingGenerator
from guacamol.distribution_learning_benchmark import ValidityBenchmark, UniquenessBenchmark, NoveltyBenchmark, KLDivBenchmark
from guacamol.frechet_benchmark import FrechetBenchmark
from guacamol.utils.chemistry import is_valid
import torch.multiprocessing as mp
from datetime import datetime

from model.mydataclass import ModelParams, PathTool
from typing import Optional

def benchmark(
    model_params: ModelParams,
    pathtool: PathTool,
    num_samples: int,
    num_workers: int,
    return_results: bool=False
):
    """Sample molecules from a trained CQ_VAE generator and benchmark them.

    The generator is loaded with ``CQ_VAE.load_generator`` and wrapped in a
    ``QuickBenchGenerator``. Its ``molecules`` are written to
    ``pathtool.generate_path``, one per line, before any benchmark runs.

    Args:
        model_params: Model configuration passed to ``CQ_VAE.load_generator``.
        pathtool: Supplies ``generate_path`` (where sampled molecules are
            written), ``train_file`` (training set, read one entry per line)
            and ``json_output_path`` (passed to GuacaMol as
            ``json_output_file``).
        num_samples: Passed to ``QuickBenchGenerator`` as ``number_samples``.
        num_workers: Accepted for interface compatibility; not used here.
        return_results: If True, run ``QuickBenchmark`` against the training
            set and return its ``assess_model`` result. Otherwise run
            GuacaMol's ``assess_distribution_learning`` and return None.

    Returns:
        The result of ``QuickBenchmark.assess_model`` when ``return_results``
        is True, otherwise None.
    """
    print(f"[{datetime.now()}] Benchmarking...")

    generator = CQ_VAE.load_generator(model_params, pathtool)
    generator = QuickBenchGenerator(generator, number_samples=num_samples)

    # Write the sampled molecules to disk, one per line.
    with open(pathtool.generate_path, "w") as f:
        f.writelines(f"{smi}\n" for smi in generator.molecules)

    if return_results:
        # Read the training set inside a context manager so the file handle
        # is closed deterministically instead of waiting for garbage collection.
        with open(pathtool.train_file) as f:
            training_set = [line.strip("\n") for line in f]
        # NOTE(review): the quick benchmark uses a fixed 10000 samples,
        # independent of ``num_samples`` — confirm this is intentional.
        benchmarks = QuickBenchmark(training_set=training_set, num_samples=10000)
        return benchmarks.assess_model(generator)

    # Full GuacaMol distribution-learning suite; results go to the JSON file.
    assess_distribution_learning(
        generator,
        chembl_training_file=pathtool.train_file,
        json_output_file=pathtool.json_output_path,
    )
    return None

if __name__ == '__main__':
    # Derive model and path configuration from the command-line arguments.
    args = parse_arguments()
    model_params, pathtool = (
        ModelParams.from_arguments(args),
        PathTool.from_arguments(args),
    )

    # Select the "spawn" start method for torch.multiprocessing, then seed
    # torch's RNG followed by Python's RNG with the same CLI-provided seed.
    mp.set_start_method("spawn")
    for seed_fn in (torch.manual_seed, random.seed):
        seed_fn(args.seed)

    benchmark(
        model_params=model_params,
        pathtool=pathtool,
        num_samples=args.num_samples,
        num_workers=args.num_workers,
    )