import os
import numpy as np
import logging
import time
from .mol_tree import MolTree


def molstar(target_mol, target_mol_id, starting_mols, expand_fn, value_fn,
            iterations, viz=False, viz_dir=None):
    """Search for a synthesis route to ``target_mol`` with a best-first loop.

    A ``MolTree`` is built for the target. While the tree is not solved, each
    iteration picks the open molecule node with the smallest ``v_target()``,
    expands it with ``expand_fn`` and feeds the proposals back into the tree.
    The loop stops after ``iterations`` expansions, when no open node is
    left, or when the root's ``succ_value`` is no larger than the priority of
    the node that was just expanded.

    Args:
        target_mol: Target molecule, forwarded to ``MolTree``.
        target_mol_id: Integer id of the target; used in visualization file
            names (formatted with ``%d``) and as an index into ``routes``.
        starting_mols: Forwarded to ``MolTree`` as ``known_mols``.
        expand_fn: Callable invoked as ``expand_fn(node.mol)``. It returns
            either ``None`` or a mapping with the keys ``'reactants'``
            (strings whose components are separated by ``'.'``),
            ``'scores'`` and ``'templates'`` (``'template'`` is accepted as
            a fallback key).
        value_fn: Forwarded to ``MolTree``.
        iterations: Maximum number of expansions to perform.
        viz: When true, write visualization files into ``viz_dir``.
        viz_dir: Output directory for visualizations; required if ``viz``.

    Returns:
        ``(succ, (best_route, possible, n_iter, routes))`` where ``succ`` is
        ``mol_tree.succ``, ``best_route`` is ``mol_tree.get_best_route()`` on
        success and ``None`` otherwise, ``possible`` is ``False`` only when
        the search-progress visualization reports that no solution can be
        found, ``n_iter`` is the number of iterations started, and
        ``routes`` is currently always an empty list.

    Raises:
        ValueError: If ``viz`` is true but ``viz_dir`` is ``None``.
    """
    # Fail before the (potentially long) search instead of after it.
    if viz and viz_dir is None:
        raise ValueError('viz_dir must be provided when viz=True')

    mol_tree = MolTree(
        target_mol=target_mol,
        known_mols=starting_mols,
        value_fn=value_fn
    )

    # Stays at -1 when the loop never runs, so the reported count is 0.
    i = -1

    if not mol_tree.succ:
        for i in range(iterations):
            # Closed nodes get +inf so that argmin can only select an open
            # node; v_target() is evaluated for open nodes only.
            node_scores = np.array([
                m.v_target() if m.open else np.inf
                for m in mol_tree.mol_nodes
            ])

            best_idx = int(np.argmin(node_scores))
            if node_scores[best_idx] == np.inf:
                logging.info('No open nodes!')
                break

            mol_tree.search_status = node_scores[best_idx]
            m_next = mol_tree.mol_nodes[best_idx]
            assert m_next.open

            t = time.time()
            result = expand_fn(m_next.mol)
            mol_tree.expand_fn_time += (time.time() - t)

            if result is None or len(result['scores']) == 0:
                # Nothing proposed: let the tree record the failed expansion.
                mol_tree.expand(m_next, None, None, None)
                continue

            reactants = result['reactants']
            model_scores = result['scores']
            # Negative log of the score, clipped to [1e-3, 1] so that the
            # cost stays finite and non-negative.
            costs = 0.0 - np.log(np.clip(np.array(model_scores), 1e-3, 1.0))
            if 'templates' in result:
                templates = result['templates']
            else:
                templates = result['template']

            # One de-duplicated reactant list per proposal.
            reactant_lists = [
                list(set(reactants[j].split('.')))
                for j in range(len(model_scores))
            ]

            mol_tree.expand(m_next, reactant_lists, costs, templates)

            # Found optimal route: the root's success value is no worse than
            # the priority of the node that was just expanded.
            if mol_tree.root.succ_value <= mol_tree.search_status:
                break

    best_route = None
    routes = []
    if mol_tree.succ:
        best_route = mol_tree.get_best_route()
        assert best_route is not None

    possible = True
    if viz:
        os.makedirs(viz_dir, exist_ok=True)

        if mol_tree.succ:
            if best_route.optimal:
                f = '%s/mol_%d_route_optimal' % (viz_dir, target_mol_id)
            else:
                f = '%s/mol_%d_route' % (viz_dir, target_mol_id)
            best_route.viz_route(f)

        else:
            # `routes` is never populated in this function, so an
            # unconditional routes[target_mol_id] always raised IndexError.
            # Only draw the search progress when a route is available.
            try:
                route = routes[target_mol_id]
            except IndexError:
                route = None

            if route is None:
                logging.info('No route available for mol %d; skipping '
                             'search progress visualization', target_mol_id)
            else:
                f = '%s/mol_%d_search_progress.pdf' % (viz_dir, target_mol_id)
                unable_to_find = mol_tree.viz_search_progress(route, f)
                if unable_to_find:
                    possible = False
                    logging.info('Unable to find the solution with the current one step model')

        f = '%s/mol_%d_search_tree' % (viz_dir, target_mol_id)
        mol_tree.viz_search_tree(f)

    return mol_tree.succ, (best_route, possible, i+1, routes)
