import numpy as np
import random
import torch
import logging
import json
import time
import multiprocessing
from tqdm import tqdm
from common import args, prepare_mlp, prepare_test_polymers, \
    prepare_starting_molecules, prepare_molstar_planner
from utils import setup_logger
from polymer_lib.polymerization_search import one_step_poly_retro
from functools import partial


def planwrapper(inp):
    """Plan a retrosynthetic route for one enumerated candidate monomer.

    Intended as the ``Pool.imap_unordered`` worker: it receives one
    ``(index, candidate)`` pair from ``enumerate(result['candidates'])`` and
    hands ``candidate['monomer']`` to the planner stored on ``args.plan_handle``.

    Args:
        inp: ``(idx, candidate)`` tuple; ``candidate`` must have a
            ``'monomer'`` key.

    Returns:
        ``(succ, msg, idx)`` -- the planner's two outputs plus the candidate
        index, so results arriving out of order can be matched back.
    """
    index, cand = inp
    success, message = args.plan_handle(cand['monomer'], index)
    return success, message, index

def _record_route(candidate, route, gt_monomer):
    """Annotate one candidate dict in place with a successfully planned route.

    Args:
        candidate: one entry of ``result['candidates']``; mutated in place
            with ``'recovered'``, ``'route'`` and ``'route_prob'`` keys.
        route: first element of the planner's ``msg`` on success; must
            provide ``contains_gt``, ``serialize`` and ``succ_value``.
        gt_monomer: ground-truth monomer SMILES ('.'-separated), or 'N/A'
            when no ground truth is available.
    """
    candidate['recovered'] = 0
    if gt_monomer != 'N/A' and route.contains_gt(gt_monomer.split('.')):
        logging.info('Successfully recovered monomers: %s', gt_monomer)
        candidate['recovered'] = 1

    candidate['route'] = route.serialize()
    # NOTE(review): succ_value looks like a negative log-probability (cost),
    # so exp(-cost) yields a probability -- confirm against the planner.
    candidate['route_prob'] = str(np.exp(-route.succ_value))


def retro_polymer():
    """Run polymer retrosynthesis for polymers with index in [args.start, args.end).

    For each selected repeat unit this does three things:
      1. It proposes candidate monomers with the polymer one-step model.
      2. It plans a route for every candidate in parallel, using a process pool.
      3. It writes the annotated result dict to
         ``<args.test_output_dir>/result_<polyidx>.json``.
    """
    # ===================== produce candidate monomer =========================
    repeat_units, monomers, _, _ = \
        prepare_test_polymers(args.test_real_polymers)
    device_id = 0 if args.gpu >= 0 else -1
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              device_id, polymer=True)
    r0_expand_fn = partial(r0_one_step.run,
                           topk=args.expansion_topk, n_cores=args.n_cores)
    t0 = time.time()

    # ========================= retrosynthesis ================================
    starting_mols = prepare_starting_molecules(args.starting_molecules)
    one_step = prepare_mlp(args.mlp_templates, args.mlp_model_dump, device_id)
    # Workers reach the planner through the module-level ``args`` object (see
    # planwrapper), so it must be attached before any Pool is created.
    args.plan_handle = prepare_molstar_planner(
        one_step=one_step,
        value_fn=lambda x: 0.,  # constant value estimate for every molecule
        use_gln=False,
        starting_mols=starting_mols,
        expansion_beam=args.expansion_beam,
        expansion_topk=args.expansion_topk,
        iterations=args.iterations,
        viz=False,
        viz_dir=None
    )
    t1 = time.time()
    print('preparation: %d' % (t1 - t0))

    for polyidx, repeat_unit in enumerate(repeat_units):
        # polymer id = start, start+1, ..., end-1
        if polyidx < args.start or polyidx >= args.end:
            continue

        t0 = time.time()
        candidate_monomers, result = one_step_poly_retro(
            repeat_unit, r0_expand_fn, n_cores=args.n_cores)
        t1 = time.time()
        print('first step time: %d' % (t1 - t0))
        print(candidate_monomers, result)

        gt_monomer = monomers[polyidx]
        result['gt_monomer'] = gt_monomer
        candidates = result['candidates']

        # The context manager terminates the workers if anything below raises.
        # On success, close + join lets them shut down cleanly.
        # To debug serially, swap pool.imap_unordered for the builtin map.
        with multiprocessing.Pool(args.n_cores) as pool:
            outputs = pool.imap_unordered(planwrapper, enumerate(candidates))
            for succ, msg, idx in tqdm(outputs, total=len(candidates)):
                if succ:
                    _record_route(candidates[idx], msg[0], gt_monomer)
            pool.close()
            pool.join()

        t2 = time.time()
        print('retro step x %d time: %d' % (len(candidates), t2 - t1))

        print(result)
        with open('%s/result_%d.json' % (args.test_output_dir, polyidx), 'w') as f:
            json.dump(result, f)


if __name__ == '__main__':
    # Seed every RNG source in use (NumPy, PyTorch, stdlib `random`) from the
    # single configured seed, in that order.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(args.seed)
    setup_logger('retro_polymer.log')

    # planwrapper reads args.plan_handle, which retro_polymer assigns at
    # runtime; 'fork' lets pool workers inherit it from the parent process.
    multiprocessing.set_start_method('fork')

    retro_polymer()
