from functools import partial
from arguments import parse_arguments
import torch.multiprocessing as mp
import torch
from tqdm import tqdm
from model.mol_graph import MolGraph
from datetime import datetime, timedelta
import rdkit.Chem as Chem
import os.path as path
import pickle
from model.mydataclass import PathTool
from typing import List, Union
import random

def _num_atoms(item: Union[str, "MolGraph"]) -> int:
    """Return RDKit's ``GetNumAtoms()`` for a SMILES string or a MolGraph's ``.mol``."""
    if isinstance(item, str):
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail with a
            # message naming the offending SMILES instead of an AttributeError.
            raise ValueError(f"Could not parse SMILES: {item!r}")
        return mol.GetNumAtoms()
    return item.mol.GetNumAtoms()


def process(
    raw_data: Union[List[str], List["MolGraph"]],
    batch_size: int,
    save_path: str,
    num_workers: int,
    dev: bool = False,
):
    """Batch molecules, preprocess each batch in a worker pool, and save the results.

    Molecules are shuffled, stably sorted by atom count, then dealt round-robin
    into ``ceil(len(raw_data) / batch_size)`` batches, so every batch holds at
    most ``batch_size`` molecules spread across the size range. Each batch is
    passed to ``MolGraph.preprocess`` in a process pool. The list of per-batch
    results is written to ``save_path`` with ``torch.save``.

    Args:
        raw_data: SMILES strings or MolGraph objects. The caller's list is not
            modified.
        batch_size: Maximum number of molecules per batch (must be positive).
        save_path: Output file for the list of preprocessed batches.
        num_workers: Number of processes in the pool.
        dev: Forwarded unchanged to ``MolGraph.preprocess``.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    if not raw_data:
        # Nothing to preprocess: write an empty batch list instead of crashing
        # on an index lookup into an empty list.
        with open(save_path, "wb") as f:
            torch.save([], f)
        return

    # Work on a copy so the caller's list is not reordered as a side effect.
    molecules = list(raw_data)
    # Shuffle first so that molecules with equal atom counts land in random
    # relative order after the stable sort below.
    random.shuffle(molecules)
    molecules.sort(key=_num_atoms)

    batch_num = len(range(0, len(molecules), batch_size))
    batches = [[] for _ in range(batch_num)]
    # Deal the size-sorted molecules round-robin, so each batch receives a
    # spread of small to large molecules rather than one contiguous size band.
    for i, molecule in enumerate(molecules):
        batches[i % batch_num].append(molecule)

    func = partial(MolGraph.preprocess, dev=dev)
    with mp.Pool(num_workers) as pool:
        all_data = list(
            tqdm(pool.imap(func, batches), total=len(batches), desc="Batches")
        )

    with open(save_path, "wb") as f:
        torch.save(all_data, f)

def generate_training_data(
    mols_pkl_path: str,
    valid_file: str,
    vocab_path: str,
    batch_size: int,
    train_processed_path: str,
    valid_processed_path: str,
    vocab_processed_path: str,
    num_workers: int,
):
    """Preprocess the training set, the optional validation set, and the motif vocabulary.

    Args:
        mols_pkl_path: Pickle file holding the training molecules, passed to
            ``process`` as ``raw_data``.
        valid_file: Text file with one SMILES per line. It is skipped when the
            path does not exist; blank lines are ignored.
        vocab_path: Only used in the progress message here. The vocabulary
            itself comes from ``MolGraph.preprocess_vocab()``; presumably it was
            loaded earlier via ``MolGraph.load_vocab``. Confirm against the caller.
        batch_size: Maximum molecules per batch, forwarded to ``process``.
        train_processed_path: Output file for the processed training batches.
        valid_processed_path: Output file for the processed validation batches.
        vocab_processed_path: Output file for the processed vocabulary
            (written with ``torch.save``).
        num_workers: Size of the worker pool used by ``process``.
    """
    print(f"[{datetime.now()}] Preprocessing training data.")
    print(f"Number of workers: {num_workers}. Total number of CPUs: {mp.cpu_count()}.\n")

    print(f"[{datetime.now()}] Preprocessing training set from {mols_pkl_path}.\n")
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated pickle files.
    with open(mols_pkl_path, "rb") as f:
        train_data = pickle.load(f)
    process(
        raw_data=train_data,
        batch_size=batch_size,
        save_path=train_processed_path,
        num_workers=num_workers,
    )

    if path.exists(valid_file):
        print(f"[{datetime.now()}] Preprocessing valid set from {valid_file}.\n")
        with open(valid_file) as f:
            # strip() also drops '\r' from CRLF files; skip blank lines so they
            # do not turn into empty molecules.
            valid_smiles = [smi for smi in (line.strip() for line in f) if smi]
        process(
            raw_data=valid_smiles,
            batch_size=batch_size,
            save_path=valid_processed_path,
            num_workers=num_workers,
            dev=True,
        )

    print(f"[{datetime.now()}] Preprocessing motif vocabulary from {vocab_path}.\n")
    vocab_data = MolGraph.preprocess_vocab()
    with open(vocab_processed_path, "wb") as f:
        torch.save(vocab_data, f)

    print(f"[{datetime.now()}] Preprocessing finished.\n\n")

if __name__ == "__main__":
    # Read command-line options and derive every input/output path from them.
    cli_args = parse_arguments()
    paths = PathTool.from_arguments(cli_args)

    # mp is torch.multiprocessing: share tensors between worker processes
    # through the file system.
    mp.set_sharing_strategy("file_system")

    # Load the operation set and the motif vocabulary into MolGraph (called on
    # the class itself) before any preprocessing starts.
    MolGraph.load_operations(paths.operation_path, cli_args.num_operations)
    MolGraph.load_vocab(paths.vocab_path)

    generate_training_data(
        mols_pkl_path=paths.mols_pkl_path,
        valid_file=paths.valid_file,
        vocab_path=paths.vocab_path,
        batch_size=cli_args.batch_size,
        train_processed_path=paths.train_processed_path,
        valid_processed_path=paths.valid_processed_path,
        vocab_processed_path=paths.vocab_processed_path,
        num_workers=cli_args.num_workers,
    )
