import logging
from tqdm import tqdm
from rdkit import Chem
from .rdkit_ext import polymerization_is_valid, form_double_units, \
    produce_candidate_monomer, produce_double_candidate_monomer
from utils import blockPrint, enablePrint
import time
import sys


# Forward reaction predictor shared by the functions below. It stays None
# until init_rp_model() has run in the current process.
rp_model = None

# Set to True by init_rp_model(). compute_prob_wrapper() checks it to
# initialise lazily (workaround for python 3.6 -- presumably because
# ProcessPoolExecutor only accepts an ``initializer`` from Python 3.7 on).
inited = False


def init_rp_model():
    """
    Load the rexgen_direct forward predictor into the module-level ``rp_model``.

    Output is suppressed with ``blockPrint()`` while TensorFlow and the model
    are imported and constructed, and re-enabled with ``enablePrint()``
    afterwards, even if loading fails. On success the module-level flag
    ``inited`` is set to True.

    :raises Exception: whatever the TensorFlow / rexgen_direct import or the
            ``ReactionPrediction`` constructor raises; ``inited`` stays
            False in that case.
    """
    global rp_model, inited

    blockPrint()
    try:
        import os
        # Must be set before TensorFlow is imported to silence its C++ logs.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf
        # Compare the numeric major version: a plain string comparison would
        # misorder e.g. '10.0' and '2.0'.
        if int(tf.__version__.split('.')[0]) >= 2:
            import tensorflow.compat.v1 as tf
        tf.logging.set_verbosity(tf.logging.ERROR)
        from rexgen_direct.api import ReactionPrediction
        rp_model = ReactionPrediction()
    finally:
        # Always restore printing, otherwise a failed load would leave this
        # process silenced for good.
        enablePrint()

    inited = True

def compute_prob_wrapper(input):
    """
    Score one candidate with the forward reaction model.

    The model is loaded on first use in the current process (workaround for
    python 3.6), so this function can be submitted to a worker pool directly.

    :param input: tuple ``(reactants, m, double_m)``; the model is asked for
            the probability of ``double_m`` given the reactant string
            ``'m.m'``.
    :return: ``(reactants, m, forward_prob)`` when the probability is
            strictly positive; ``None`` when it is not, or when the model
            raised (the exception is printed).
    """
    # Lazy, once-per-process model initialisation.
    if not inited:
        init_rp_model()

    reactants, monomer, dimer = input
    two_monomers = '%s.%s' % (monomer, monomer)
    try:
        prob = rp_model.compute_prob(two_monomers, dimer)
        if prob > 0.:
            return reactants, monomer, prob
    except Exception as err:
        print(err)

    return None

def worker_process(q_in, q_out):
    """
    Entry point of a scoring worker process.

    Loads the forward model, then repeatedly takes an item from ``q_in``,
    scores it with ``compute_prob_wrapper`` and puts the outcome (a result
    tuple or ``None``) on ``q_out``. A ``None`` item on ``q_in`` is the
    signal to stop.

    :param q_in: queue of ``(reactants, m, double_m)`` items, ended by ``None``.
    :param q_out: queue receiving one outcome per processed item.
    """
    init_rp_model()

    while True:
        task = q_in.get()
        if task is None:
            break
        q_out.put(compute_prob_wrapper(task))

    # Exit code 0 marks a normal exit; one_step_poly_retro only restarts
    # workers whose exit code is non-zero.
    sys.exit(0)

def _collect_available(q_out, sink):
    """Move every result currently readable on ``q_out`` into ``sink``.

    ``None`` outcomes (failed or zero-probability items) are counted but
    not stored.

    :return: number of items taken off the queue.
    """
    from queue import Empty

    received = 0
    while True:
        try:
            item = q_out.get_nowait()
        except Empty:
            return received
        except Exception as e:
            print(e)
            continue
        received += 1
        if item is not None:
            sink.append(item)


def _forward_probs_multiprocess(forward_data, n_cores):
    """Score ``forward_data`` with hand-managed worker processes.

    Workers that die with a non-zero exit code are restarted while data is
    still being fed. The item such a worker was processing yields no result.

    :return: list of ``(reactants, m, forward_prob)`` tuples.
    """
    from multiprocessing import Process, Queue

    # One ``None`` sentinel per worker tells it to exit.
    work_items = forward_data + [None] * n_cores

    q_in = Queue(n_cores * 2)
    q_out = Queue(len(work_items))
    worker_procs = [Process(target=worker_process, args=[q_in, q_out]) for _ in range(n_cores)]
    for p in worker_procs:
        p.start()

    results = []
    receive_cnt = 0
    data_idx = 0
    while True:
        for i in range(n_cores):
            if not worker_procs[i].is_alive() and worker_procs[i].exitcode != 0:
                print('worker %d is dead' % i)
                worker_procs[i] = Process(target=worker_process, args=[q_in, q_out])
                print('worker %d restarts' % i)
                worker_procs[i].start()

        while (not q_in.full()) and (data_idx < len(work_items)):
            q_in.put(work_items[data_idx])
            data_idx += 1

        # Drain results as they arrive so workers never block on a full pipe.
        receive_cnt += _collect_available(q_out, results)

        if q_in.empty() and data_idx == len(work_items):
            break

        # Avoid spinning at 100% CPU while the workers are busy.
        time.sleep(0.01)

    # A process that has put items on a queue does not terminate until they
    # are flushed to the pipe, so joining before draining can deadlock.
    # Keep reading while any worker is still alive, then join.
    while any(p.is_alive() for p in worker_procs):
        receive_cnt += _collect_available(q_out, results)
        time.sleep(0.01)

    for p in worker_procs:
        p.join()

    receive_cnt += _collect_available(q_out, results)

    print('receive %d/%d' % (receive_cnt, len(forward_data)))
    return results


def _forward_probs_pool(forward_data, n_cores):
    """Score ``forward_data`` with a ``ProcessPoolExecutor``.

    If any future raises (for example because a worker crashed and broke the
    pool), the items not yet completed are re-run one at a time in a
    single-worker pool that is recreated after every failure.

    :return: list of ``(reactants, m, forward_prob)`` tuples.
    """
    import concurrent.futures

    results = []
    done = set()
    failed = False
    # with concurrent.futures.ProcessPoolExecutor(max_workers=n_cores, initializer=init_rp_model) as executor:
    with concurrent.futures.ProcessPoolExecutor(max_workers=n_cores) as executor:
        future_out = {executor.submit(compute_prob_wrapper, data): data
                      for data in forward_data}
        with tqdm(concurrent.futures.as_completed(future_out),
                  total=len(forward_data)) as pbar:
            for future in pbar:
                data = future_out[future]
                try:
                    out = future.result()
                    if out is not None:
                        results.append(out)
                    done.add(data)
                except Exception as exc:
                    pbar.set_description('%s' % exc)
                    failed = True
                    break

    if failed:
        print('fall back to sequential execution')
        # executor = concurrent.futures.ProcessPoolExecutor(max_workers=1, initializer=init_rp_model)
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=1)
        for data in tqdm(forward_data):
            if data in done:
                continue
            future = executor.submit(compute_prob_wrapper, data)
            try:
                out = future.result()
                if out is not None:
                    results.append(out)
            except Exception as exc:
                print('%s' % exc)
                # Replace the executor so the remaining items can still run.
                executor.shutdown()
                executor = concurrent.futures.ProcessPoolExecutor(max_workers=1)

        executor.shutdown()

    return results


def _forward_probs_sequential(forward_data):
    """Score ``forward_data`` in the current process (debug mode).

    :return: list of ``(reactants, m, forward_prob)`` tuples.
    """
    from rexgen_direct.api import ReactionPrediction
    model = ReactionPrediction()

    results = []
    for reactants, m, double_m in tqdm(forward_data):
        try:
            forward_prob = model.compute_prob('%s.%s' % (m, m), double_m)

            if forward_prob > 0.:
                results.append((reactants, m, forward_prob))
        except Exception as e:
            # Best effort: skip the candidate, but leave a trace of why.
            logging.warning('Failing to compute forward probability for %s: %s', m, e)

    return results


def one_step_poly_retro(repeat_unit, r0_expand_fn, topk=200, n_cores=1, debug=False, mp_only=False):
    """
    One step polymerization retro-analysis by searching using r0 templates.

    :param repeat_unit: repeat unit of the target chain polymer. For example,
            c1([At])ccc(cc1)C=C([At]), with the two [At] representing the bonds
            connecting neighboring repeat units.
    :param r0_expand_fn: one step retrosyn by matching r0 templates. Called
            with the double unit; expected to return None or a mapping with
            parallel 'reactants' and 'template' sequences.
    :param topk: maximum number of candidates kept, ranked by forward
            probability (highest first).
    :param n_cores: number of worker processes used by the parallel paths.
    :param debug: when True (and ``mp_only`` is False), score candidates
            sequentially in the current process.
    :param mp_only: when True, score candidates with hand-managed
            multiprocessing workers instead of ``concurrent.futures``.
    :return: None if forming the double unit or the template search fails.
            Otherwise a tuple ``(candidate_monomers, result)``:
            candidate_monomers is a list of candidate monomers, together with
            the forward probability, for example
            [(candidate1, prob1), (candidate2, prob2), ...];
            result is a dict with keys 'repeat_unit', 'double_unit',
            'double_idxs' and 'candidates' (one dict per kept candidate with
            'reactants', 'monomer', 'forward_prob' and 'template_id').
    """
    t0 = time.time()
    try:
        double_unit, double_idxs = form_double_units(repeat_unit)
    except Exception as e:
        # logging.ERROR is the integer level constant; logging.error logs.
        logging.error('Failing to form double unit: %s', e)
        return None

    try:
        out = r0_expand_fn(double_unit)
    except Exception as e:
        logging.error('Failing to search for a matching template: %s', e)
        return None

    t05 = time.time()

    # Atom count of the repeat unit, minus 2 (presumably the two [At]
    # placeholder atoms).
    num_atoms = len(Chem.MolFromSmiles(repeat_unit).GetAtoms()) - 2
    reactants_candidates = []
    template_dict = {}
    if out is not None:
        for j in range(len(out['reactants'])):
            if polymerization_is_valid(out['reactants'][j], num_atoms):
                reactants_candidates.append(out['reactants'][j])
                template_dict[out['reactants'][j]] = out['template'][j]

    result = {
        'repeat_unit': repeat_unit,
        'double_unit': double_unit,
        'double_idxs': '%d,%d,%d,%d' % (double_idxs[0], double_idxs[1], double_idxs[2], double_idxs[3]),
        'candidates': []
    }

    t1 = time.time()
    forward_data = []
    print('producing candidate monomers for %d reactions' % len(reactants_candidates))
    for reactants in tqdm(reactants_candidates):
        m = produce_candidate_monomer(reactants, double_idxs)
        double_m = produce_double_candidate_monomer(reactants, double_idxs)

        # if m in ['CC(CC(=O)Cl)OC(=O)O',
        #          'CC(=O)OC1C(=O)OC(=O)C1OC(=O)Oc1c(Cl)cc(C(C)(C)c2cc(Cl)c(O)c(Cl)c2)cc1Cl',
        #          'CC(C)(c1cc(Cl)c(OC(=O)Cl)c(Cl)c1)c1cc(Cl)c(OC(=O)C(=O)NN)c(Cl)c1',
        #          'CC(C)(c1ccc(Oc2ccc(C(=O)O)cc2)cc1)c1ccc(Oc2ccc(S(=O)(=O)c3ccc(F)cc3)cc2)cc1',
        #          'CCCCOC(=O)C(C#N)(CCO)COCCO',
        #          'CC(C)(C)[Si](Oc1c(-c2ccccc2)cc(Cl)cc1-c1ccccc1)(c1ccccc1)c1ccccc1',
        #          'CC(CC(=O)Cl)C(=O)OC1CCCCC1',
        #          'CC1(C)OB(c2cccc(B3OC(C)(C)C(C)(C)O3)c2)OC1(C)C']:
        #     # segmentation fault
        #     continue

        if m is not None and double_m is not None:
            forward_data.append((reactants, m, double_m))

    t2 = time.time()

    print('computing forward probabilities')

    if mp_only:
        candidate_list = _forward_probs_multiprocess(forward_data, n_cores)
    elif not debug:
        candidate_list = _forward_probs_pool(forward_data, n_cores)
    else:
        candidate_list = _forward_probs_sequential(forward_data)

    # Keep the topk most probable candidates.
    candidate_list = sorted(candidate_list, key=lambda item: -item[2])[:topk]

    candidate_monomers = []
    for reactants, m, forward_prob in candidate_list:
        candidate_monomers.append((m, forward_prob))
        result['candidates'].append({
            'reactants': reactants,
            'monomer': m,
            'forward_prob': str(forward_prob),
            'template_id': template_dict[reactants]
        })

    t3 = time.time()

    print('inside first step: %f, %f, %f, %f' % (t05-t0, t1-t05, t2-t1, t3-t2))

    return candidate_monomers, result
