from tqdm import tqdm
import pickle
import concurrent.futures
import sys
from model.utils import ShinglingEncoder
from rdkit import RDLogger

RDLogger.DisableLog("rdApp.*")


def process_one(one, omit):
    """Encode a single record's first element with ``ShinglingEncoder``.

    Args:
        one: An indexable record; only ``one[0]`` is used (presumably a
            reaction SMILES string -- TODO confirm against the pickled data).
        omit: Forwarded unchanged as ``ShinglingEncoder.encode(omit=...)``.

    Returns:
        A ``(one[0], mapping)`` tuple, where ``mapping`` is the first element
        of what ``ShinglingEncoder.encode`` returns for the one-item batch.
    """
    reaction = one[0]
    # encode() works on batches, so wrap the single input and unwrap the result.
    encoded_batch = ShinglingEncoder.encode(
        [reaction],
        show_progress_bar=False,
        atom_index_mapping=True,
        root_central_atom=False,
        omit=omit,
    )
    return reaction, encoded_batch[0]

if __name__ == '__main__':
    # Load every reaction record. Each record is indexed with [0] below and that
    # element is what gets encoded (presumably the reaction SMILES -- TODO confirm).
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only run this on trusted data.
    with open("/amax/data/yield_data/pretraining_data/reactions_with_multiple.pkl", "rb") as f:
        reactions = pickle.load(f)
    print(f"reactions nums: {len(reactions)}")

    # Optional CLI: `script.py START_ID END_ID` selects the slice
    # [START_ID:END_ID) of the length-sorted records, so a large run can be
    # sharded across several invocations. With no arguments the defaults apply.
    cli_args = sys.argv[1:]
    if len(cli_args) == 2:
        start_id, end_id = int(cli_args[0]), int(cli_args[1])
    elif not cli_args:
        start_id, end_id = 0, 4000000
    else:
        # Any other argument count used to fall back to the defaults silently,
        # which could process (and overwrite the output of) the wrong shard.
        sys.exit(f"usage: {sys.argv[0]} [START_ID END_ID]")

    # sorted() is stable, so for the same input file a given [start:end) slice
    # always selects the same records.
    reactions = sorted(reactions, key=lambda x: len(x[0]))[start_id:end_id]
    omit = False
    rxn2shingling = {}
    # NOTE(review): threads only help here if ShinglingEncoder/RDKit release the
    # GIL while encoding; if the work is GIL-bound Python, a ProcessPoolExecutor
    # would parallelize better -- confirm before switching.
    with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
        # A plain list suffices: nothing ever looked up the input for a future.
        futures = [executor.submit(process_one, one, omit) for one in reactions]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            # Re-raises any encoding error, aborting the run before anything is saved.
            key, value = future.result()
            # Records sharing the same key collapse to one entry (last completed wins).
            rxn2shingling[key] = value

    with open(f"/amax/data/reaction/shinglings/rxn2shingling_{omit}_{start_id}_{end_id}.pkl", "wb") as f:
        pickle.dump(rxn2shingling, f)
