import numpy as np
import random
import torch
import logging
import json
import time
import multiprocessing
from tqdm import tqdm
from common import args, prepare_mlp, prepare_test_polymers, \
    prepare_starting_molecules, prepare_molstar_planner
from utils import setup_logger
from polymer_lib.polymerization_search import one_step_poly_retro
from functools import partial


def planwrapper(inp):
    """Plan a route for a single ``(index, candidate)`` work item.

    ``inp`` is one element produced by ``enumerate`` over the candidate list.
    The index is forwarded to the planner and echoed back so that results
    coming out of an unordered pool can be matched to their candidate.

    Returns:
        ``(succ, msg, idx)`` -- the planner's two return values followed by
        the original index.
    """
    position, molecule = inp
    outcome, payload = args.plan_handle(molecule, position)
    return outcome, payload, position

def _load_monomer_candidates(dataset_dir, intervals):
    """Return the de-duplicated monomer candidates named in the JSON shards.

    Reads ``<dataset_dir>/monomer_dict_<start>_<end>.json`` for every
    ``(start, end)`` pair in *intervals* and merges them with ``dict.update``
    (later shards win on duplicate keys). It then prints the merged entry
    count and collects the first two '.'-separated components of every
    record's ``'reactants'`` string.

    The returned list keeps first-seen order. ``list(set(...))`` would order
    the strings by their (per-process randomized) hash, so a candidate's
    index -- which names its output file -- would differ between runs.

    Raises:
        ValueError: if a ``'reactants'`` string has fewer than two components.
    """
    merged = {}
    for start, end in intervals:
        path = '%s/monomer_dict_%d_%d.json' % (dataset_dir, start, end)
        with open(path, 'r', encoding='utf-8') as f:
            merged.update(json.load(f))
    print('total', len(merged))

    found = []
    for records in merged.values():
        for record in records:
            reactants = record['reactants']
            parts = reactants.split('.')
            if len(parts) < 2:
                raise ValueError(
                    'expected at least two reactants, got %r' % reactants)
            # Only the first two components are used as monomer candidates.
            found.extend(parts[:2])

    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(found))


def run_cand(intervals=((0, 60),)):
    """Run retrosynthesis planning on every monomer candidate and save hits.

    Args:
        intervals: ``(start, end)`` pairs selecting which
            ``monomer_dict_<start>_<end>.json`` shards under
            ``args.dataset_dir`` to load. The default reproduces the
            previously hard-coded single shard ``(0, 60)``.

    For each candidate the planner reports as successful, writes
    ``<args.monomer_result_dir>/result_<idx + 10000>.json`` containing the
    candidate, the serialized first route, and ``exp(-succ_value)`` of that
    route as a string. Timing summaries are printed to stdout.
    """
    t0 = time.time()

    # ===================== produce candidate monomer =========================
    monomer_candidate_list = _load_monomer_candidates(args.dataset_dir,
                                                      intervals)

    # ========================= retrosynthesis ================================
    starting_mols = prepare_starting_molecules(args.starting_molecules)
    device_id = 0 if args.gpu >= 0 else -1
    one_step = prepare_mlp(args.mlp_templates, args.mlp_model_dump, device_id)
    value_fn = lambda x: 0.
    # planwrapper reads the planner from the module-global ``args``, so pool
    # workers see it only if they inherit this process's memory (fork).
    args.plan_handle = prepare_molstar_planner(
        one_step=one_step,
        value_fn=value_fn,
        use_gln=False,
        starting_mols=starting_mols,
        expansion_beam=args.expansion_beam,
        expansion_topk=args.expansion_topk,
        iterations=500,
        viz=False,
        viz_dir=None
    )
    t1 = time.time()
    print('preparation: %d' % (t1 - t0))

    # NOTE(review): if args.gpu >= 0 the model is loaded before workers are
    # forked; confirm the workers do not need to (re)initialize CUDA.
    # The context manager terminates the workers if anything below raises.
    with multiprocessing.Pool(args.n_cores) as pool:
        outputs = pool.imap_unordered(planwrapper,
                                      enumerate(monomer_candidate_list))
        for succ, msg, idx in tqdm(outputs,
                                   total=len(monomer_candidate_list)):
            if not succ:
                continue
            # Assumes msg[0] is a route object exposing serialize() and
            # succ_value -- TODO confirm against prepare_molstar_planner.
            route = msg[0]
            result = {
                'cand': monomer_candidate_list[idx],
                'route': route.serialize(),
                'route_prob': str(np.exp(-route.succ_value))
            }
            # The +10000 offset is kept for compatibility with existing
            # result file names.
            out_path = '%s/result_%d.json' % (args.monomer_result_dir,
                                              idx + 10000)
            with open(out_path, 'w') as f:
                json.dump(result, f)
        pool.close()
        pool.join()

    t2 = time.time()
    print('retro step x %d time: %d' % (len(monomer_candidate_list), t2 - t1))


if __name__ == '__main__':
    # Seed NumPy, PyTorch and the stdlib RNG (in that order) from args.seed.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(args.seed)
    setup_logger('run_cand.log')

    # run_cand() stores the planner on the module-global ``args`` and the
    # pool workers read it from there, so the workers must inherit this
    # process's memory; the 'fork' start method (POSIX-only) provides that.
    multiprocessing.set_start_method('fork')

    run_cand()
