"""
Question: with confidence set to softmax values, does retro-fallback do better than retro star?

It should behave identically until a solution is found, and then should hopefully behave better??
"""

from __future__ import annotations

import argparse
import logging
import math
import pickle
import sys
import numpy as np

from tqdm.auto import tqdm

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode, OrNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.search.algorithms.best_first.base import logger as retro_star_logger
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator
from syntheseus.search.graph.message_passing import run_message_passing

from retro_fallback.rfb_independent import (
    RetroFallbackTreeSearch,
    independent_success_prob_update,
    logger as rfb_logger,
)
from retro_star_task import (
    RetroStarReactionModel,
    RetroStarReactionCostFunction,
    RetroStarInventory,
    RetroStarValueMLP,
    file_names,
)


class RetroFallbackValueMLP(RetroStarValueMLP):
    """Retro* value MLP adapted to retro-fallback by turning costs into values.

    The parent class predicts a cost for each node; this wrapper reports
    ``exp(-cost * k)`` instead, so a larger predicted cost gives a smaller value.
    """

    def __init__(self, k: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        # Scale applied to each predicted cost before exponentiation.
        self._k = k

    def _evaluate_nodes(
        self,
        nodes,
        graph=None,
    ) -> list[float]:
        predicted_costs = super()._evaluate_nodes(nodes, graph)
        scale = self._k
        values: list[float] = []
        for cost in predicted_costs:
            values.append(math.exp(-cost * scale))
        return values


class RetroStarSoftmaxFeasibility(RetroStarReactionCostFunction):
    """Marginal feasibility of AND nodes derived from the reaction's score.

    Each node's feasibility is ``reaction.metadata["score"] ** k``. The score is
    presumably the single-step model's softmax probability (TODO confirm against
    ``RetroStarReactionModel``).
    """

    def __init__(self, k: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        # Exponent applied to every reaction score.
        self._k = k

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        exponent = self._k
        feasibilities: list[float] = []
        for node in nodes:
            score = node.reaction.metadata["score"]
            feasibilities.append(score**exponent)
        return feasibilities


def _count_min_cost_routes(graph, max_routes: int) -> int:
    """Return how many routes ``min_cost_routes(graph, max_routes)`` yields."""
    # Count lazily instead of materializing every route just to take len().
    return sum(1 for _ in min_cost_routes(graph, max_routes))


def _set_rfb_success_probabilities(graph, rfb_alg) -> None:
    """Annotate ``graph`` in place with retro-fallback's independent success probabilities.

    Uses ``rfb_alg``'s feasibility and purchasability models, so graphs produced by
    any algorithm (e.g. Retro*) are scored with the same models. The root node's
    ``data["rfb_ind_succ_prob"]`` is read by the caller afterwards.
    """
    rfb_alg._set_feasibilities(
        [n for n in graph.nodes() if isinstance(n, AndNode)],
        graph,
    )
    rfb_alg._set_purchasabilities(
        [n for n in graph.nodes() if isinstance(n, OrNode)], graph
    )
    # Visit the deepest nodes first and push updates towards predecessors.
    run_message_passing(
        graph=graph,
        nodes=sorted(graph.nodes(), key=lambda n: n.depth, reverse=True),
        update_fns=[
            independent_success_prob_update,
        ],
        update_predecessors=True,
        update_successors=False,
    )


def retro_star_search(
    smiles_list: list[str],
    use_value_function: bool,
    use_tqdm: bool = False,
    limit_rxn_model_calls: int = 500,
    k: float = 1.0,
    max_routes: int = 1000,
    time_limit_s: float = 1_000,
    max_expansion_depth: int = 20,
) -> list[tuple[float, float]]:
    """Run Retro* and retro-fallback on each SMILES and report first-solution times.

    Both algorithms share the same reaction model, inventory and search limits.
    After each search, the resulting graph is also scored with retro-fallback's
    independent success-probability model. Summary statistics (graph size,
    solution time, root success probability, number of min-cost routes) are
    logged at DEBUG level on the "COMPARISON" logger.

    Args:
        smiles_list: Target molecules to search for.
        use_value_function: If True, Retro* uses ``RetroStarValueMLP`` and
            retro-fallback uses ``RetroFallbackValueMLP``; otherwise Retro* uses a
            constant 0.0 evaluator and retro-fallback a constant 1.0 evaluator.
        use_tqdm: Show a progress bar over ``smiles_list``.
        limit_rxn_model_calls: Reaction-model call budget for each search.
        k: Exponent passed to both ``RetroStarSoftmaxFeasibility`` and
            ``RetroFallbackValueMLP``.
        max_routes: Maximum argument passed to ``min_cost_routes`` when counting
            routes for logging.
        time_limit_s: Time limit (seconds) for each search.
        max_expansion_depth: Maximum expansion depth; prevents overly-deep solutions.

    Returns:
        One ``(retro_star_time, retro_fallback_time)`` tuple per SMILES, in input
        order. Each time is the first-solution time as computed by
        ``get_first_solution_time``, measured in reaction-model calls.
    """
    # Initialize algorithms.
    rxn_model = RetroStarReactionModel(use_cache=False)  # no caching
    inventory = RetroStarInventory()
    if use_value_function:
        retro_star_value_fn = RetroStarValueMLP()
        retro_fallback_value_fn = RetroFallbackValueMLP(k=k)
    else:
        retro_star_value_fn = ConstantNodeEvaluator(0.0)
        retro_fallback_value_fn = ConstantNodeEvaluator(1.0)
    common_kwargs = dict(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        time_limit_s=time_limit_s,
        max_expansion_depth=max_expansion_depth,  # prevent overly-deep solutions
        prevent_repeat_mol_in_trees=True,  # original paper did this
    )
    retro_star_alg = RetroStarSearch(
        and_node_cost_fn=RetroStarReactionCostFunction(),
        value_function=retro_star_value_fn,
        **common_kwargs,
    )
    retro_fallback_alg = RetroFallbackTreeSearch(
        and_node_marginal_feasibility_fn=RetroStarSoftmaxFeasibility(k=k),
        value_function=retro_fallback_value_fn,
        **common_kwargs,
    )
    # Order matters: it defines the order of entries in each returned tuple.
    algorithms = (retro_star_alg, retro_fallback_alg)

    # Do search.
    logger = logging.getLogger("COMPARISON")
    smiles_iter = tqdm(smiles_list) if use_tqdm else smiles_list
    min_soln_times: list[tuple[float, float]] = []
    for i, smiles in enumerate(smiles_iter):
        logger.debug("Start search %d/%d. SMILES: %s", i, len(smiles_list), smiles)
        this_soln_times = []
        for alg in algorithms:
            alg.reset()
            output_graph, _ = alg.run_from_mol(Molecule(smiles))

            # Measure "time" in reaction-model calls rather than wall-clock time.
            for node in output_graph.nodes():
                node.data["analysis_time"] = node.data["num_calls_rxn_model"]
            soln_time = get_first_solution_time(output_graph)
            this_soln_times.append(soln_time)

            num_routes = _count_min_cost_routes(output_graph, max_routes)

            # Score every graph (including Retro*'s) with retro-fallback's models.
            _set_rfb_success_probabilities(output_graph, retro_fallback_alg)
            logger.debug(
                "Done: nodes=%d, solution time = %s, final success prob = %.4e"
                ", num routes = %d",
                len(output_graph),
                soln_time,
                output_graph.root_node.data["rfb_ind_succ_prob"],
                num_routes,
            )
        min_soln_times.append(tuple(this_soln_times))

    return min_soln_times


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--limit_num_smiles",
        type=int,
        default=None,
        help="Maximum number of SMILES to run.",
    )
    parser.add_argument(
        "--rxn_model_calls",
        type=int,
        default=500,
        help="Allowed number of calls to reaction model.",
    )
    parser.add_argument(
        "--prob_k",
        type=float,
        default=1.0,
    )
    args = parser.parse_args()

    # Logging
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        filemode="w",
    )
    logging.getLogger().info(args)
    retro_star_logger.setLevel(logging.DEBUG)
    rfb_logger.setLevel(logging.DEBUG)

    # Load all SMILES to test
    with open(file_names.TEST_ROUTES, "rb") as f:
        test_routes = pickle.load(f)
    test_smiles = [r[0].split(">")[0] for r in test_routes]
    if args.limit_num_smiles is not None:
        test_smiles = test_smiles[: args.limit_num_smiles]

    # Run without value function (retro*-0)
    rxn_model_budget = args.rxn_model_calls
    logging.getLogger().info("Starting search with no value function.")
    results_no_value_fn = np.asarray(
        retro_star_search(
            smiles_list=test_smiles,
            use_tqdm=True,
            limit_rxn_model_calls=rxn_model_budget,
            use_value_function=False,
            k=args.prob_k,
        )
    )
