import argparse
from arguments import parse_arguments
import torch
from model.mol_graph import MolGraph
import os.path as path
import multiprocessing as mp
from model.mydataclass import PathTool
from merging_operation_learning import merging_operation_learning
from motif_vocab_construction import motif_vocab_construction
from generate_training_data import generate_training_data

def preprocess(
    path_tool: PathTool,
    num_operations: int,
    batch_size: int,
    merging_operation_learning_num_iters: int = 3000,
    merging_operation_learning_min_frequecy: int = 10,
    merging_operation_learning_mp_threashold: int = 100_000,
    num_workers: int = mp.cpu_count(),
):
    """Run the preprocessing pipeline: operations -> motif vocab -> training data.

    Stages 1 and 2 are skipped when their output file already exists, so a
    re-run reuses previously produced artifacts. Stage 3
    (``generate_training_data``) runs on every call.

    Args:
        path_tool: Supplies every input, output and log path used below.
        num_operations: Forwarded to ``MolGraph.load_operations`` and
            ``motif_vocab_construction``.
        batch_size: Forwarded to ``generate_training_data``.
        merging_operation_learning_num_iters: ``num_iters`` for
            ``merging_operation_learning``.
        merging_operation_learning_min_frequecy: ``min_frequency`` for
            ``merging_operation_learning``. The misspelled keyword is part of
            the public interface and is kept for compatibility.
        merging_operation_learning_mp_threashold: ``mp_threshold`` for
            ``merging_operation_learning``. The misspelled keyword is kept for
            compatibility. The default is an int to match the annotation.
        num_workers: Worker count forwarded to all three stages. The default
            is evaluated once, when this module is imported.
    """
    # Stage 1: learn merging operations, unless they are already on disk.
    if not path.exists(path_tool.operation_path):
        merging_operation_learning(
            train_path=path_tool.train_file,
            operation_path=path_tool.operation_path,
            num_iters=merging_operation_learning_num_iters,
            min_frequency=merging_operation_learning_min_frequecy,
            num_workers=num_workers,
            mp_threshold=merging_operation_learning_mp_threashold,
            log_path=path_tool.operation_learning_log_path,
        )

    MolGraph.load_operations(path_tool.operation_path, num_operations)

    # Stage 2: build the motif vocabulary, unless it is already on disk.
    if not path.exists(path_tool.vocab_path):
        motif_vocab_construction(
            train_path=path_tool.train_file,
            vocab_path=path_tool.vocab_path,
            operation_path=path_tool.operation_path,
            num_operations=num_operations,
            num_workers=num_workers,
            mols_pkl_path=path_tool.mols_pkl_path,
            log_path=path_tool.vocab_construct_log_path,
        )

    MolGraph.load_vocab(path_tool.vocab_path)

    # NOTE(review): presumably switches to the "file_system" strategy so that
    # sharing many tensors between worker processes does not exhaust file
    # descriptors — confirm.
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Stage 3: always (re)generate the processed train/valid/vocab data.
    # NOTE(review): this reads mols_pkl_path, which is only passed to
    # motif_vocab_construction above; if vocab_path exists but
    # mols_pkl_path was removed, confirm this stage still works.
    generate_training_data(
        mols_pkl_path=path_tool.mols_pkl_path,
        valid_file=path_tool.valid_file,
        vocab_path=path_tool.vocab_path,
        batch_size=batch_size,
        train_processed_path=path_tool.train_processed_path,
        valid_processed_path=path_tool.valid_processed_path,
        vocab_processed_path=path_tool.vocab_processed_path,
        num_workers=num_workers,
    )

if __name__ == "__main__":
    # Parse the command line once, then map each CLI setting onto the
    # corresponding keyword of the preprocessing pipeline.
    cli_args = parse_arguments()
    preprocess(
        path_tool=PathTool.from_arguments(cli_args),
        num_operations=cli_args.num_operations,
        batch_size=cli_args.batch_size,
        merging_operation_learning_num_iters=cli_args.num_iters,
        merging_operation_learning_min_frequecy=cli_args.min_frequency,
        merging_operation_learning_mp_threashold=cli_args.mp_thd,
        num_workers=cli_args.num_workers,
    )