import os
import numpy as np
from copy import deepcopy
import logging
from tqdm import tqdm
from .mol_tree import MolTree


def value_fn_wrapper(known_mol_vals, value_fn):
    """Build a value function that prefers stored values for known molecules.

    Args:
        known_mol_vals: Mapping from known molecules to their values.
        value_fn: Heuristic used for every molecule not in ``known_mol_vals``;
            when ``None``, a constant 0 (the lower bound of reaction cost) is
            used instead.

    Returns:
        A one-argument callable returning ``known_mol_vals[x]`` for known
        ``x`` and ``value_fn(x)`` otherwise.
    """
    if value_fn is not None:
        fallback = value_fn
    else:
        def fallback(_mol):
            return 0    # the lower-bound of reaction cost

    def wrapped(mol):
        if mol in known_mol_vals.keys():
            return known_mol_vals[mol]
        return fallback(mol)

    return wrapped


def exhaustive_search(target_mol, target_mol_id, known_mol_vals, expand_fn,
                      iterations, value_fn=None, viz=False, viz_dir=None,
                      routes=None, use_gln=False):
    """Best-first search for a synthesis route of ``target_mol``.

    Each iteration picks the open molecule node with the lowest
    ``v_target()`` and expands it with ``expand_fn``. One-step scores are
    converted to reaction costs ``-log(clip(score, 1e-3, 1))`` and handed to
    ``MolTree.expand``.

    The search stops early when any of these happens:

    * an expansion reports success;
    * the root's success value drops to or below the current search status;
    * no open nodes remain.

    Args:
        target_mol: Molecule to synthesise (presumably a SMILES string).
        target_mol_id: Integer id used in visualisation file names and to
            index ``routes``.
        known_mol_vals: Mapping from known molecules to their values. Its
            keys are passed to ``MolTree`` as ``known_mols``.
        expand_fn: Callable ``mol -> result``. ``result`` is ``None`` or a
            dict with ``'reactants'`` (dot-separated reactant strings),
            ``'scores'`` and ``'templates'`` (or ``'template'``).
        iterations: Maximum number of node expansions.
        value_fn: Optional heuristic for molecules not in ``known_mol_vals``.
            Defaults to a constant 0.
        viz: If True, write a PNG into ``viz_dir``. On success this is the
            best route; on failure it is the search progress.
        viz_dir: Output directory for visualisations. Created if missing.
        routes: Indexable by ``target_mol_id``. Supplies the reference route
            for the search-progress plot. When ``None``, that plot is skipped.
        use_gln: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(new_mols, new_expansions, new_known_mol_vals, best_route,
        possible)``:

        * ``new_mols`` is always an empty list.
        * ``new_expansions`` maps each expanded molecule to a deep copy of
          its ``expand_fn`` result.
        * ``new_known_mol_vals`` maps every molecule of the best route to
          its value.
        * ``best_route`` is ``None`` when no route was found.
        * ``possible`` is False only when the search-progress plot reports
          that the reference route cannot be found.
    """
    mol_tree = MolTree(
        target_mol=target_mol,
        known_mols=known_mol_vals.keys(),
        value_fn=value_fn_wrapper(known_mol_vals, value_fn),
        zero_known_value=False
    )

    new_mols = []           # never populated; kept for the return signature
    new_expansions = {}
    new_known_mol_vals = {}

    if not mol_tree.succ:
        # Defined up front so the final log line works even if the loop body
        # never runs (iterations <= 0).
        i = 0
        for i in range(iterations):
            # Closed nodes get +inf so argmin only ever selects an open node.
            node_scores = np.array(
                [m.v_target() if m.open else np.inf for m in mol_tree.mol_nodes]
            )
            if node_scores.size == 0 or np.min(node_scores) == np.inf:
                logging.info('No open nodes!')
                break

            mol_tree.search_status = np.min(node_scores)

            m_next = mol_tree.mol_nodes[np.argmin(node_scores)]
            assert m_next.open

            result = expand_fn(m_next.mol)
            new_expansions[m_next.mol] = deepcopy(result)

            if result is not None and len(result['scores']) > 0:
                reactants = result['reactants']
                step_scores = result['scores']
                # Cost = -log(p). Clipping keeps each cost in [0, -log(1e-3)].
                costs = 0.0 - np.log(np.clip(np.array(step_scores), 1e-3, 1.0))
                if 'templates' in result:
                    templates = result['templates']
                else:
                    templates = result['template']
                # dict.fromkeys de-duplicates reactants in first-seen order,
                # unlike set(), whose order depends on string hash seeding
                # and so varies between runs.
                reactant_lists = [
                    list(dict.fromkeys(reactants[j].split('.')))
                    for j in range(len(step_scores))
                ]

                succ = mol_tree.expand(m_next, reactant_lists, costs, templates)
                if succ:
                    break

                # found optimal route
                if mol_tree.root.succ_value <= mol_tree.search_status:
                    break
            else:
                # Report the failed expansion to the tree (no reactions added).
                mol_tree.expand(m_next, None, None, None)
                logging.info('Expansion fails on %s!', m_next.mol)

        logging.info('Final search status | success value | iter: %s | %s | %d',
                     str(mol_tree.search_status), str(mol_tree.root.succ_value), i)

    best_route = None
    if mol_tree.succ:
        best_route = mol_tree.get_best_route()
        assert best_route is not None

        for mol, value in zip(best_route.mols, best_route.values):
            new_known_mol_vals[mol] = value

    possible = True
    if viz:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(viz_dir, exist_ok=True)

        if mol_tree.succ:
            if best_route.optimal:
                f = '%s/mol_%d_route_optimal.png' % (viz_dir, target_mol_id)
            else:
                f = '%s/mol_%d_route.png' % (viz_dir, target_mol_id)
            best_route.viz_route(f)
        elif routes is None:
            # Without a reference route the progress plot cannot be drawn.
            logging.warning('No reference routes given; skipping search '
                            'progress visualization for mol %d', target_mol_id)
        else:
            route = routes[target_mol_id]
            f = '%s/mol_%d_search_progress.png' % (viz_dir, target_mol_id)
            unable_to_find = mol_tree.viz_search_progress(route, f)
            if unable_to_find:
                possible = False
                logging.info('Unable to find the solution with the current one step model')

    return new_mols, new_expansions, new_known_mol_vals, best_route, possible
