import pickle
from arguments import parse_arguments
from model.mol_graph import MolGraph
import multiprocessing as mp
from tqdm import tqdm
from collections import Counter
from datetime import datetime
from typing import List, Tuple
from model.mydataclass import PathTool


def apply_operations(batch: List[str]) -> Tuple[List[MolGraph], Counter]:
    """Build a motif-level ``MolGraph`` for every SMILES string in *batch*.

    Intended to run inside a ``multiprocessing.Pool`` worker: the worker's
    index selects which terminal row its tqdm progress bar is drawn on.

    Args:
        batch: SMILES strings to process, in order.

    Returns:
        A ``(mols, vocab)`` pair: the ``MolGraph`` objects in input order, and
        a ``Counter`` accumulating the entries of each ``mol.motifs``.
    """
    mols, vocab = [], Counter()
    # Pool workers get a 1-based index in ``_identity``; the main process has
    # an empty tuple, so fall back to slot 1 instead of raising IndexError
    # when this function is called outside a pool.
    identity = mp.current_process()._identity
    pos = identity[0] if identity else 1
    with tqdm(total=len(batch), desc=f"Processing {pos}", position=pos - 1, ncols=80, leave=False) as pbar:
        for smi in batch:
            mol = MolGraph(smi, process_level="motif")
            mols.append(mol)
            # In-place update: ``vocab + Counter(...)`` would copy the whole
            # accumulated counter for every molecule.
            # NOTE(review): assumes mol.motifs is an iterable of hashable
            # motif entries (not a mapping with non-positive counts) - confirm.
            vocab.update(mol.motifs)
            pbar.update()
    return mols, vocab

def motif_vocab_construction(
    train_path: str,
    vocab_path: str,
    operation_path: str,
    num_operations: int,
    num_workers: int,
    mols_pkl_path: str,
    log_path: str,
):
    """Construct the motif vocabulary of a training set and persist it.

    Reads one SMILES string per line from *train_path*, splits the lines into
    at most *num_workers* contiguous batches, and processes them with
    ``apply_operations`` in a process pool. The merged motif counts are then
    ordered by their first component: components not in
    ``MolGraph.OPERATIONS`` come first (sorted), followed by the operations in
    their ``MolGraph.OPERATIONS`` order.

    Args:
        train_path: Text file with one SMILES string per line.
        vocab_path: Output file; one ``"x y"`` line per vocabulary entry.
        operation_path: Passed to ``MolGraph.load_operations``.
        num_operations: Passed to ``MolGraph.load_operations``.
        num_workers: Number of pool processes (and maximum number of batches).
        mols_pkl_path: Output file receiving the pickled list of ``MolGraph``.
        log_path: Output file; one ``"x y count"`` line per vocabulary entry.

    Returns:
        A ``(mols, sorted_vocab)`` pair: the ``MolGraph`` objects in input
        order and the ordered list of ``(x, y, count)`` tuples.
    """

    print(f"[{datetime.now()}] Construcing motif vocabulary from {train_path}.")
    print(f"Number of workers: {num_workers}. Total number of CPUs: {mp.cpu_count()}.")

    # Use a context manager so the training file handle is closed promptly.
    with open(train_path) as f:
        data_set = [smi.strip("\n") for smi in f]
    # Ceiling division, clamped to 1: an empty training file would otherwise
    # give a batch size of 0 and make ``range`` raise on a zero step.
    batch_size = max(1, (len(data_set) - 1) // num_workers + 1)
    batches = [data_set[i : i + batch_size] for i in range(0, len(data_set), batch_size)]
    print(f"Total: {len(data_set)} molecules.\n")

    print("Processing...")
    mols, vocab = [], Counter()
    # NOTE(review): operations are loaded in the parent process only; this
    # presumably relies on workers inheriting MolGraph state (fork start
    # method) - confirm behaviour under spawn/forkserver.
    MolGraph.load_operations(operation_path, num_operations)
    # Every worker shares one lock so their tqdm bars do not interleave.
    with mp.Pool(num_workers, initializer=tqdm.set_lock, initargs=(mp.RLock(),)) as pool:
        # imap yields results in batch order, keeping mols aligned with the input file.
        for batch_mols, batch_vocab in pool.imap(apply_operations, batches):
            mols.extend(batch_mols)
            vocab.update(batch_vocab)  # in-place merge; avoids copying the counter per batch

    # Vocabulary keys are (x, y) pairs. Rank each distinct x: non-operation
    # components first in sorted order, then the operations in their own order.
    operations = set(MolGraph.OPERATIONS)
    atom_list = sorted({x for (x, _) in vocab if x not in operations})
    whole_list = atom_list + MolGraph.OPERATIONS
    index_dict = {token: rank for rank, token in enumerate(whole_list)}

    # Every x is either in atom_list or in MolGraph.OPERATIONS by construction,
    # so the lookup below cannot miss. The sort is stable, so entries sharing
    # the same x keep their order of first appearance.
    new_vocab = [(x, y, count) for (x, y), count in vocab.items()]
    sorted_vocab = sorted(new_vocab, key=lambda entry: index_dict[entry[0]])

    with open(vocab_path, "w") as f:
        for (x, y, _) in sorted_vocab:
            f.write(f"{x} {y}\n")

    print(f"\r[{datetime.now()}] Motif vocabulary construction finished.")
    print(f"The motif vocabulary is in {vocab_path}.\n\n")

    with open(mols_pkl_path, "wb") as f:
        pickle.dump(mols, f)

    # The log mirrors the vocabulary file with the occurrence count appended.
    with open(log_path, "w") as f:
        for (x, y, z) in sorted_vocab:
            f.write(f"{x} {y} {z}\n")

    return mols, sorted_vocab

if __name__ == "__main__":
    # Script entry point: resolve every input/output location from the
    # command-line arguments, then build the motif vocabulary.
    args = parse_arguments()
    pathtool = PathTool.from_arguments(args)

    mols, vocab = motif_vocab_construction(
        train_path=pathtool.train_file,
        vocab_path=pathtool.vocab_path,
        operation_path=pathtool.operation_path,
        num_operations=args.num_operations,
        num_workers=args.num_workers,
        mols_pkl_path=pathtool.mols_pkl_path,
        log_path=pathtool.vocab_construct_log_path,
    )

    
    
    