"""Main algorithm comparison script for ICLR 2024 submission."""

from __future__ import annotations

import argparse
import collections
import logging
import json
import math
from pprint import pformat
import sys
import time
from typing import Any
from tqdm import tqdm
import pickle
import numpy as np

from syntheseus.search.chem import Molecule
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import iter_routes_time_order
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.graph.and_or import AndNode, OrNode
from syntheseus.search.graph.standardization import get_unique_node_andor_graph
from syntheseus.search.algorithms.best_first.retro_star import (
    RetroStarSearch,
    reaction_number_update,
)
from syntheseus.search.algorithms.mcts.base import pucb_bound
from syntheseus.search.algorithms.mcts.molset import MolSetMCTS, MolSetGraph
from syntheseus.search.algorithms.breadth_first import (
    AndOr_BreadthFirstSearch,
    MolSet_BreadthFirstSearch,
)
from syntheseus.search.graph.message_passing import run_message_passing

from retro_fallback.feasibility_model import (
    FeasibilityModel,
    PurchasabilityModel,
)
from iclr24_expts.feasibility_models import (
    ConstantIndependentFeasibility,
    Constant_GPLVM_Feasibility,
    RankIndependentFeasibility,
    Rank_GPLVM_Feasibility,
)
from iclr24_expts.purchasability_models import (
    eMoleculesInventory,
    BinaryEMoleculesPurchasability,
)
from iclr24_expts.heuristics import (
    SAScoreProbabilityHeuristic,
    SAScoreCostHeuristic,
    mol_to_sa_succ_prob,
)
from retro_fallback.rfb_sample import SampleRetroFallback
from retro_fallback.success_analysis import (
    estimate_synthesis_success_across_time,
    iter_routes_feasibility_order,
)

from retro_star_task import RetroStarReactionModel


# Module-level logger used throughout this script (``main`` sets its level to DEBUG).
main_logger = logging.getLogger("ICLR24-main")


class RetroStarMarginalFeasibilityCost(NoCacheNodeEvaluator):
    """
    Retro* AND-node cost: negative log of each reaction's marginal feasibility.

    Reactions whose marginal feasibility is not strictly positive get an
    infinite cost.
    """

    def __init__(
        self,
        feasibility_model: FeasibilityModel,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.feasibility_model = feasibility_model

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        reactions = [node.reaction for node in nodes]
        costs = []
        for prob in self.feasibility_model.marginal_probability(reactions):
            if prob > 0:
                costs.append(-math.log(prob))
            else:
                costs.append(math.inf)
        return costs


class RetroStarMarginalPurchasabilityCost(NoCacheNodeEvaluator):
    """
    Retro* OR-node cost: negative log of each molecule's marginal purchasability.

    Molecules whose purchasability is not strictly positive get a large finite
    cost (1e4) rather than infinity.
    """

    def __init__(
        self,
        purchasability_model: PurchasabilityModel,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.purchasability_model = purchasability_model

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        # NOTE: a large finite number is used instead of inf due to convergence issues
        unpurchasable_cost = 1e4
        molecules = [node.mol for node in nodes]
        costs = []
        for prob in self.purchasability_model.marginal_probability(molecules):
            costs.append(-math.log(prob) if prob > 0 else unpurchasable_cost)
        return costs


class _MCTS_SuccessEvaluator(NoCacheNodeEvaluator):
    """
    Base class for MCTS evaluators that estimate success probability from samples.

    Stores a feasibility model and a purchasability model, and provides helpers
    to collect what a node's route depends on and to sample those outcomes.
    """

    def __init__(
        self,
        feasibility_model: FeasibilityModel,
        purchasability_model: PurchasabilityModel,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.feasibility_model = feasibility_model
        self.purchasability_model = purchasability_model

    def get_reactions_and_mols(self, node, graph):
        """
        Return ``(mols, reactions)`` for the route from the root to ``node``.

        ``mols`` is the set of molecules held by ``node`` itself (the ones still
        to be purchased). ``reactions`` is the set of reactions stored on the
        edges of the path from ``graph.root_node`` down to ``node``.

        Raises:
            AssertionError: if a node on the path does not have exactly one parent.
        """
        to_purchase = set(node.mols)
        path_reactions = set()
        child = node
        while child != graph.root_node:
            predecessors = list(graph.predecessors(child))
            assert len(predecessors) == 1, "Not a tree?"
            (parent,) = predecessors
            path_reactions.add(graph._graph.edges[parent, child]["reaction"])
            child = parent
        return to_purchase, path_reactions

    def get_feas_and_pur_samples(self, node, graph):
        """
        Draw prior samples for everything the route to ``node`` depends on.

        Returns:
            ``(purchase_samples, feas_samples)``, as returned by the purchasability
            model's and the feasibility model's ``prior_sample`` respectively.
        """
        molecules, reactions = self.get_reactions_and_mols(node, graph)
        # Tuple elements evaluate left to right: purchasability is sampled first.
        return (
            self.purchasability_model.prior_sample(molecules),
            self.feasibility_model.prior_sample(reactions),
        )


class MCTS_SuccessReward(_MCTS_SuccessEvaluator):
    """
    MCTS reward: sampled success estimate for the route ending at a node.

    For each node, the reward is the average over samples of the minimum outcome
    across the node's molecules (purchasability) and the reactions on its path
    (feasibility). For 0/1 samples this is the fraction of samples in which
    everything succeeds. Once a node has ``num_visit >= max_visit``, its reward
    is 0.0 and no samples are drawn.
    """

    def __init__(self, *args, max_visit: int = 100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_visit = max_visit

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        assert graph is not None
        rewards = []
        for node in nodes:
            if node.num_visit >= self.max_visit:
                rewards.append(0.0)
                # Only log when the cap is hit exactly, not on every later visit.
                if node.num_visit == self.max_visit:
                    main_logger.debug("Hit max visit count, not evaluating further.")
                continue

            purchase_samples, feas_samples = self.get_feas_and_pur_samples(
                node, graph
            )
            outcome_rows = [*purchase_samples.values(), *feas_samples.values()]
            route_outcome = np.min(np.stack(outcome_rows), axis=0)
            rewards.append(float(np.average(route_outcome)))

        return rewards


class MCTS_Optimistic_ValueFunction(_MCTS_SuccessEvaluator):
    """
    Optimistic MCTS value function: molecules are assumed to be purchasable.

    The value of a node is the average over feasibility samples of the minimum
    outcome across the reactions on its path from the root. It is 1.0 when
    there are no reaction samples (e.g. at the root).
    """

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        assert graph is not None
        output = []
        for node in nodes:
            # Only reactions matter for this value function. Sample feasibility
            # directly rather than via ``get_feas_and_pur_samples``, which would
            # also draw purchasability samples for every molecule just to
            # discard them.
            _, rxns = self.get_reactions_and_mols(node, graph)
            feas_samples = self.feasibility_model.prior_sample(rxns)
            if len(feas_samples) == 0:
                output.append(1.0)
            else:
                all_outcomes = np.stack(list(feas_samples.values()))
                output.append(float(np.average(np.min(all_outcomes, axis=0))))

        return output


class MCTS_SAscore_ValueFunction(_MCTS_SuccessEvaluator):
    """
    MCTS value function that uses an SA-score heuristic for molecules.

    Each molecule's purchasability sample is raised to at least
    ``mol_to_sa_succ_prob(mol)``. Per sample, these values and the feasibility
    samples of the path's reactions are multiplied together (product, not
    minimum). The result is then averaged over samples.
    """

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        assert graph is not None
        values = []
        for node in nodes:
            purchase_samples, feas_samples = self.get_feas_and_pur_samples(node, graph)
            rows = list(feas_samples.values())
            for mol, samples in purchase_samples.items():
                # A failed purchase sample is replaced by the SA-score-derived value.
                rows.append(np.maximum(samples, mol_to_sa_succ_prob(mol)))
            per_sample_value = np.prod(np.stack(rows), axis=0)
            values.append(float(np.average(per_sample_value)))

        return values


class MCTS_Marginal_Feasibility_Policy(NoCacheNodeEvaluator):
    """
    MCTS policy: a node's prior is the marginal feasibility of the reaction leading to it.

    The reaction is read from the edge between the node and its (first) parent,
    so every evaluated node must have a parent and ``graph`` must be provided.
    """

    def __init__(
        self,
        feasibility_model: FeasibilityModel,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.feasibility_model = feasibility_model

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        # Same precondition as the other MCTS evaluators: the graph is needed to
        # find the incoming reaction of each node.
        assert graph is not None
        rxns = []
        for node in nodes:
            # Assumes ``node`` is not the root (it has a predecessor).
            parent = list(graph.predecessors(node))[0]
            rxns.append(graph._graph.edges[parent, node]["reaction"])
        # Convert to plain floats so the declared ``list[float]`` return type
        # holds even if the model returns e.g. a numpy array.
        return [float(p) for p in self.feasibility_model.marginal_probability(rxns)]


def get_parser():
    """
    Build the command-line argument parser for this script.

    Returns:
        An ``argparse.ArgumentParser``. ``--smiles_file``, ``--output_json`` and
        ``--algorithm`` are required; all other options have defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smiles_file",
        type=str,
        required=True,
        help="Text file with one target SMILES per line.",
    )
    parser.add_argument(
        "--smiles_start_idx",
        type=int,
        default=None,
        help="Index of first SMILES to use.",
    )
    parser.add_argument(
        "--smiles_end_idx",
        type=int,
        default=None,
        help="Index of last SMILES to use (exclusive).",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        required=True,
        help="Path of the JSON file to which results are written.",
    )
    parser.add_argument(
        "--limit_rxn_model_calls",
        type=int,
        default=500,
        help="Allowed number of calls to reaction model.",
    )
    parser.add_argument(
        "--marginal_feasibility",
        type=str,
        default="constant",
        help="Marginal feasibility of reactions: 'constant' or 'rank'.",
    )
    parser.add_argument(
        "--corr_feasibility",
        type=str,
        default="independent",
        help="Correlation structure of reaction feasibility: 'independent' or 'gplvm'.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10_000,
        help="Number of samples to use for feasibility/purchasability model (and success calculation).",
    )
    parser.add_argument(
        "--num_samples_alg",
        type=int,
        default=None,
        help="Number of samples to use during search (if different).",
    )
    parser.add_argument(
        "--algorithm",
        type=str,
        required=True,
        help=(
            "Algorithm to use: 'retro-fallback', 'retro-star', 'mcts', "
            "'breadth-first' or 'ORbreadth-first'."
        ),
    )
    parser.add_argument(
        "--heuristic",
        type=str,
        default="optimistic",
        help="Heuristic for retro-fallback / retro-star / mcts: 'optimistic' or 'sascore'.",
    )
    parser.add_argument(
        "--shuffle_reactions",
        action="store_true",
        help="Flag to shuffle reactions from reaction model.",
    )
    parser.add_argument(
        "--andor_graph_analysis",
        action="store_true",
        help="For OR graphs, convert to an AND/OR graph when performing analysis.",
    )
    parser.add_argument(
        "--skip_most_feasible_route",
        action="store_true",
        help="Flag to skip calculating the most feasible route (sometimes expensive).",
    )
    parser.add_argument(
        "--save_output_graphs",
        action="store_true",
        help="Flag to save output graphs.",
    )
    return parser


def get_nodes_in_route(graph) -> set:
    """
    Return every node that belongs to at least one solved synthesis route.

    The graph is traversed breadth-first from the root, descending only through
    nodes whose ``has_solution`` flag is true. The result is therefore exactly
    the set of solved nodes reachable from the root via solved nodes.
    """
    in_route = set()
    # Frontier holds nodes whose parent was found to be part of a route.
    frontier = collections.deque((graph.root_node,))
    while frontier:
        current = frontier.popleft()
        if current in in_route or not current.has_solution:
            continue
        in_route.add(current)
        frontier.extend(graph.successors(current))
    return in_route


def analyze_output_graph(
    output_graph,
    feasibility_model,
    purchasability_model,
    analysis_times,
    calculate_most_feasible_route=True,
    max_num_routes: int = 100,
) -> dict[str, Any]:
    """
    Get various analysis results for the output graph.

    In ``main`` this is used for every output graph that is not a
    ``MolSetGraph`` (i.e. AND/OR graphs, including OR graphs converted with
    ``--andor_graph_analysis``). "Time" here is the number of reaction-model
    calls, read from ``node.data["num_calls_rxn_model"]``. Every returned
    ``*_over_time`` list is aligned with ``analysis_times``.

    Side effects: writes ``node.data`` entries on the graph's nodes
    (``analysis_time``; when a solution exists by an analysis time also
    ``reaction_number``, ``reaction_number_estimate``, ``retro_star_rxn_cost``,
    ``retro_star_mol_cost``; and, when ``calculate_most_feasible_route`` is
    set, ``route_feas_samples``).

    Args:
        output_graph: graph returned by the search algorithm.
        feasibility_model: passed to ``estimate_synthesis_success_across_time``.
        purchasability_model: passed to ``estimate_synthesis_success_across_time``.
        analysis_times: times (numbers of reaction-model calls) at which to report.
        calculate_most_feasible_route: if False, the most-feasible-route metric is NaN.
        max_num_routes: currently unused (only referenced by the disabled
            route-count metric); kept for interface compatibility.

    Returns:
        Dict with the first solution time, success probabilities over time,
        graph size, wallclock search duration, shortest-route length and
        most-feasible-route probability over time, plus the number of nodes
        that are part of a solved route.
    """
    main_logger.debug("Starting analysis.")
    t0 = time.monotonic()

    # Insert analysis times (here always equal to num calls to rxn model)
    for node in output_graph.nodes():
        node.data["analysis_time"] = node.data["num_calls_rxn_model"]

    # Metric 1: solution time
    soln_time = get_first_solution_time(output_graph)
    t1 = time.monotonic()
    main_logger.debug(f"Solution time calculated in {t1 - t0:.2f} s.")

    # Metric 2: success probability over time
    success_probabilities_over_time = estimate_synthesis_success_across_time(
        output_graph,
        feasibility_model,
        purchasability_model,
        max_times=analysis_times,
    )
    t2 = time.monotonic()
    main_logger.debug(f"Success probabilities calculated in {t2 - t1:.2f} s.")

    # Metric 3: number of routes over time (capped at M).
    # Finding the first M routes is enough to calculate this.
    # NOTE: this metric is currently disabled (code below kept for reference),
    # which is why `max_num_routes` is unused and `num_routes_over_time` is not
    # returned.
    # routes_time_order = list(
    #     iter_routes_time_order(output_graph, max_routes=max_num_routes)
    # )
    # routes_time_order = []
    # route_times = [max(n.data["analysis_time"] for n in r) for r in routes_time_order]
    # num_routes_over_time = [sum(t <= at for t in route_times) for at in analysis_times]
    # del routes_time_order, route_times
    # t3 = time.monotonic()
    # main_logger.debug(f"Number of routes calculated in {t3 - t2:.2f} s.")

    # A bunch of time-dependent metrics which cannot be computed simultaneously for different times
    num_nodes_over_time = []
    search_duration_over_time = []
    shortest_route_over_time = []
    most_feasible_route_over_time = []
    for at in analysis_times:
        t0 = time.monotonic()  # per-iteration timer (re-uses the name t0)

        # Two collections of nodes important for analysis
        # (both include nodes whose analysis time equals `at`)
        nodes_before_at = [
            n for n in output_graph.nodes() if n.data["analysis_time"] <= at
        ]
        nodes_at_or_after_at = [
            n for n in output_graph.nodes() if n.data["analysis_time"] >= at
        ]

        # Metric 4: graph size over time
        num_nodes_over_time.append(len(nodes_before_at))

        # Metric 5: wallclock time taken to reach this number of nodes.
        # Because search can terminate early, it is not clear if this number of nodes would even be reached.
        # Therefore, we take the minimum time of any node at or after this time.
        # (NaN if no node reached this time.)
        if len(nodes_at_or_after_at) > 0:
            search_duration_over_time.append(
                (
                    min(n.creation_time for n in nodes_at_or_after_at)
                    - output_graph.root_node.creation_time
                ).total_seconds()
            )
        else:
            search_duration_over_time.append(math.nan)
        t5 = time.monotonic()
        main_logger.debug(f"\tt={at}: Graph size/time calculated in {t5 - t0:.2f} s.")

        # Metric 6: shortest route over time.
        # Re-use retro* updates to calculate this: every reaction costs 1,
        # purchasable molecules cost 0, non-purchasable ones cost inf, and any
        # node created after `at` costs inf so it cannot be used. The root's
        # resulting "reaction_number" is reported as the shortest route length.
        if soln_time <= at:
            # Set reaction/mol costs appropriately
            for node in output_graph.nodes():
                node.data["reaction_number"] = math.inf  # initialize to overestimate
                node.data[
                    "reaction_number_estimate"
                ] = math.inf  # Effectively no estimate
                if node.data["analysis_time"] > at:
                    node.data["retro_star_rxn_cost"] = math.inf
                    node.data["retro_star_mol_cost"] = math.inf
                else:
                    if isinstance(node, OrNode):
                        node.data["retro_star_mol_cost"] = (
                            0.0 if node.mol.metadata["is_purchasable"] else math.inf
                        )
                    elif isinstance(node, AndNode):
                        node.data["retro_star_rxn_cost"] = 1.0
                    else:
                        raise ValueError(f"Invalid node type: {type(node)}")

            # Run updates (propagating from all nodes towards their predecessors)
            run_message_passing(
                output_graph,
                nodes=output_graph.nodes(),
                update_fns=[reaction_number_update],
                update_predecessors=True,
                update_successors=False,
            )
            shortest_route_over_time.append(
                output_graph.root_node.data["reaction_number"]
            )

        else:
            # No solution yet at this time
            shortest_route_over_time.append(math.inf)
        t6 = time.monotonic()
        main_logger.debug(f"\tt={at}: Shortest route calculated in {t6 - t5:.2f} s.")

        # Metric 7: most feasible route over time
        if not calculate_most_feasible_route:
            most_feasible_route_over_time.append(math.nan)
        elif soln_time <= at:
            # Set feasibilities for this time
            for node in output_graph.nodes():
                # NOTE: this is re-using node attributes from earlier success probability calculation
                # (presumably written by estimate_synthesis_success_across_time above)
                if isinstance(node, OrNode):
                    node.data["route_feas_samples"] = node.data["rfb_purchase_samples"]
                else:
                    node.data["route_feas_samples"] = node.data["rfb_feasible_samples"]

                # Zero samples if analysis time exceeded
                if node.data["analysis_time"] > at:
                    node.data["route_feas_samples"] = np.zeros_like(
                        node.data["route_feas_samples"]
                    )

            # Extract best solution.
            # NOTE(review): the first element of each yielded item is presumably
            # -log(route success probability), hence exp(-cost) below; confirm
            # against iter_routes_feasibility_order.
            most_feasible_solutions = list(
                iter_routes_feasibility_order(output_graph, max_routes=1)
            )
            assert len(most_feasible_solutions) == 1
            most_feasible_route_over_time.append(
                math.exp(-most_feasible_solutions[0][0])
            )
            del most_feasible_solutions
        else:
            # There are no solutions, just directly output 0
            most_feasible_route_over_time.append(0.0)
        t7 = time.monotonic()
        main_logger.debug(
            f"\tt={at}: Most feasible route calculated in {t7 - t6:.2f} s."
        )

        del at

    return dict(
        solution_time=soln_time,
        success_probabilities=success_probabilities_over_time,
        # num_routes_over_time=num_routes_over_time,
        # cap_num_routes=max_num_routes,  # so it is clear what the max number of routes is
        num_nodes_over_time=num_nodes_over_time,
        search_duration_over_time=search_duration_over_time,
        shortest_route_over_time=shortest_route_over_time,
        most_feasible_route_over_time=most_feasible_route_over_time,
        num_nodes_part_of_a_route=len(get_nodes_in_route(output_graph)),
    )


def analyze_output_graph_molset(
    output_graph,
    feasibility_model,
    purchasability_model,
    analysis_times,
    max_num_routes=1_000_000,
) -> dict[str, Any]:
    """
    Alternative analysis procedure for MolSetGraphs. Works by extracting all routes.

    All routes (up to ``max_num_routes``) are extracted, and one shared set of
    prior samples is drawn for every reaction and molecule in any route. A
    route's per-sample outcome is the minimum over its reactions' feasibility
    samples and its leaf molecules' purchasability samples. Every time-dependent
    metric is then computed from those per-route outcomes, using only routes
    complete by the analysis time.

    Side effect: sets ``node.data["analysis_time"]`` on every node.

    Args:
        output_graph: MolSetGraph returned by the search.
        feasibility_model: model whose ``prior_sample`` gives reaction outcomes.
        purchasability_model: model whose ``prior_sample`` gives molecule outcomes.
        analysis_times: times (numbers of reaction-model calls) at which to report.
        max_num_routes: cap on route extraction. Reaching it fails an assertion,
            because the metrics would then be incomplete.

    Returns:
        Dict of metrics; every ``*_over_time`` list is aligned with ``analysis_times``.
    """

    main_logger.debug("Starting analysis.")
    t0 = time.monotonic()

    # Insert analysis times (here always equal to num calls to rxn model)
    for node in output_graph.nodes():
        node.data["analysis_time"] = node.data["num_calls_rxn_model"]

    # Metric 1: solution time
    soln_time = get_first_solution_time(output_graph)
    t1 = time.monotonic()
    main_logger.debug(f"Solution time calculated in {t1 - t0:.2f} s.")

    # Remaining metrics require routes: extract them all
    routes = list(iter_routes_time_order(output_graph, max_routes=max_num_routes))
    main_logger.debug(f"Extracted {len(routes)} routes.")
    assert len(routes) < max_num_routes, "Some routes likely not extracted."

    # Figure out which reactions and molecules are in each route: this will be used to draw samples.
    # The molecules to purchase are those held by the deepest node of the route.
    reactions_in_routes = [
        set(output_graph.to_synthesis_graph(r).nodes()) for r in routes
    ]
    mols_in_routes = [set(max(r, key=lambda n: n.depth).mols) for r in routes]

    # Draw one shared set of samples for everything appearing in any route
    all_mols = {m for li in mols_in_routes for m in li}
    all_rxns = {r for li in reactions_in_routes for r in li}
    main_logger.debug(
        f"Drawing samples for {len(all_mols)} mols and {len(all_rxns)} rxns."
    )
    rxn_samples = feasibility_model.prior_sample(all_rxns)
    mol_samples = purchasability_model.prior_sample(all_mols)
    main_logger.debug("Done drawing samples.")

    # Per-route quantities that do not depend on the analysis time, computed once:
    # completion time (latest node), length (``len(route) - 1``), and the
    # per-sample success outcome (min over reactions and leaf molecules).
    route_times = np.asarray([max(n.data["analysis_time"] for n in r) for r in routes])
    route_lengths = np.asarray([len(r) - 1 for r in routes])
    success_outcomes = np.asarray(
        [
            np.min(
                np.stack(
                    [rxn_samples[rxn] for rxn in rxns]
                    + [mol_samples[mol] for mol in mols]
                ),
                axis=0,
            )
            for rxns, mols in zip(reactions_in_routes, mols_in_routes)
        ]
    )

    # Node attributes used by the graph-size / duration metrics
    all_nodes = list(output_graph.nodes())
    root_creation_time = output_graph.root_node.creation_time

    # Compute various metrics over time
    success_probabilities_over_time = []
    num_nodes_over_time = []
    search_duration_over_time = []
    shortest_route_over_time = []
    most_feasible_route_over_time = []
    for t in analysis_times:
        # Graph size: nodes with analysis time <= t
        num_nodes_over_time.append(
            sum(n.data["analysis_time"] <= t for n in all_nodes)
        )

        # Wallclock time to reach t: earliest creation among nodes at or after t.
        # NaN if the search stopped before reaching t.
        later_creation_times = [
            n.creation_time for n in all_nodes if n.data["analysis_time"] >= t
        ]
        if later_creation_times:
            search_duration_over_time.append(
                (min(later_creation_times) - root_creation_time).total_seconds()
            )
        else:
            search_duration_over_time.append(math.nan)

        # Route-based metrics, restricted to routes complete by time t
        routes_done = route_times <= t
        if routes_done.any():
            outcomes_done = success_outcomes[routes_done]
            # Probability that at least one route succeeds
            success_probabilities_over_time.append(
                float(outcomes_done.max(axis=0).mean())
            )
            # int() keeps the value a plain Python int (JSON-serializable)
            shortest_route_over_time.append(int(route_lengths[routes_done].min()))
            # Best success probability of any single route
            most_feasible_route_over_time.append(
                float(outcomes_done.mean(axis=1).max())
            )
        else:
            success_probabilities_over_time.append(0.0)
            shortest_route_over_time.append(math.inf)
            most_feasible_route_over_time.append(0.0)

    return dict(
        solution_time=soln_time,
        success_probabilities=success_probabilities_over_time,
        final_num_routes=len(routes),
        num_nodes_over_time=num_nodes_over_time,
        search_duration_over_time=search_duration_over_time,
        shortest_route_over_time=shortest_route_over_time,
        most_feasible_route_over_time=most_feasible_route_over_time,
        num_nodes_part_of_a_route=len(get_nodes_in_route(output_graph)),
    )


def save_output_graph(output_graph, filename) -> None:
    """
    Pickle ``output_graph`` to ``filename`` after dropping numpy-array node data.

    NOTE: the stripping happens in place. After this call, every node of the
    caller's graph has lost any ``node.data`` value that is an ``np.ndarray``.
    """
    for node in output_graph.nodes():
        kept = {}
        for key, value in node.data.items():
            if isinstance(value, np.ndarray):
                continue
            kept[key] = value
        node.data = kept

    with open(filename, "wb") as handle:
        pickle.dump(output_graph, handle)


def main():
    """
    Run one search per test SMILES, analyze each output graph, and save results.

    Steps:
    - Parse command-line arguments (see ``get_parser``).
    - Load the SMILES file (one SMILES per line), optionally sliced by
      ``--smiles_start_idx`` / ``--smiles_end_idx``.
    - Build the reaction, feasibility and purchasability models and the chosen
      search algorithm.
    - Search and analyze each SMILES.
    - Write the arguments, analysis times and per-SMILES results to
      ``--output_json``.

    Raises:
        ValueError / NotImplementedError: for unsupported option values.
    """
    main_logger.setLevel(logging.DEBUG)
    main_logger.info("START OF SCRIPT")

    # Get arguments
    args = get_parser().parse_args()
    main_logger.info(args)

    # Load SMILES to test.
    # Slice unconditionally: None bounds are no-ops, so each of
    # --smiles_start_idx / --smiles_end_idx is honoured on its own
    # (an end index alone must not be ignored).
    with open(args.smiles_file, "r") as f:
        test_smiles = [line.strip() for line in f.readlines()]
    test_smiles = test_smiles[args.smiles_start_idx : args.smiles_end_idx]
    main_logger.info(f"Loaded {len(test_smiles)} SMILES to test.")

    # Make reaction model, and feasibility model
    rxn_model = RetroStarReactionModel(use_cache=True, shuffle=args.shuffle_reactions)
    _use_cache_covar = True  # common arg to all GP-LVM models
    if args.marginal_feasibility == "constant":
        _kwargs = dict(constant=0.5, num_samples=args.num_samples)
        if args.corr_feasibility == "independent":
            feasibility_model = ConstantIndependentFeasibility(**_kwargs)
        elif args.corr_feasibility == "gplvm":
            feasibility_model = Constant_GPLVM_Feasibility(
                use_cached_covar_matrix=_use_cache_covar, **_kwargs
            )
        else:
            raise ValueError(f"Invalid corr_feasibility: {args.corr_feasibility}")
        del _kwargs
    elif args.marginal_feasibility == "rank":
        _kwargs = dict(constant=0.75, rank_div=10, num_samples=args.num_samples)
        if args.corr_feasibility == "independent":
            feasibility_model = RankIndependentFeasibility(**_kwargs)
        elif args.corr_feasibility == "gplvm":
            feasibility_model = Rank_GPLVM_Feasibility(
                use_cached_covar_matrix=_use_cache_covar, **_kwargs
            )
        else:
            raise ValueError(f"Invalid corr_feasibility: {args.corr_feasibility}")
        del _kwargs
    else:
        raise NotImplementedError(
            f"Invalid marginal_feasibility: {args.marginal_feasibility}"
        )

    # Make inventory and purchasability model
    inventory = eMoleculesInventory(max_tier=2)  # for now
    purchasability_model = BinaryEMoleculesPurchasability(num_samples=args.num_samples)

    # Make the algorithm
    high_integer = int(1e10)
    common_alg_kwargs = dict(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=args.limit_rxn_model_calls,
        stop_on_first_solution=False,
    )
    retro_star_like_common_alg_kwargs = dict(
        limit_iterations=high_integer,
        max_expansion_depth=30,
        prevent_repeat_mol_in_trees=False,
        unique_nodes=True,
    )
    if args.algorithm == "retro-fallback":
        if args.heuristic == "optimistic":
            value_function = ConstantNodeEvaluator(1.0)
        elif args.heuristic == "sascore":
            value_function = SAScoreProbabilityHeuristic()
        else:
            raise ValueError(f"Invalid heuristic: {args.heuristic}")
        alg = SampleRetroFallback(
            feasibility_model=feasibility_model,
            purchasability_model=purchasability_model,
            value_function=value_function,
            early_stopping_success_threshold=0.99,
            incremental_updates=True,
            **common_alg_kwargs,
            **retro_star_like_common_alg_kwargs,
        )
    elif args.algorithm == "retro-star":
        if args.heuristic == "optimistic":
            value_function = ConstantNodeEvaluator(0.0)
        elif args.heuristic == "sascore":
            value_function = SAScoreCostHeuristic()
        else:
            raise ValueError(f"Invalid heuristic: {args.heuristic}")
        alg = RetroStarSearch(
            and_node_cost_fn=RetroStarMarginalFeasibilityCost(
                feasibility_model=feasibility_model
            ),
            or_node_cost_fn=RetroStarMarginalPurchasabilityCost(
                purchasability_model=purchasability_model
            ),
            value_function=value_function,
            **common_alg_kwargs,
            **retro_star_like_common_alg_kwargs,
        )
    elif args.algorithm == "mcts":
        if args.heuristic == "optimistic":
            value_function = MCTS_Optimistic_ValueFunction(
                feasibility_model, purchasability_model
            )
        elif args.heuristic == "sascore":
            value_function = MCTS_SAscore_ValueFunction(
                feasibility_model, purchasability_model
            )
        else:
            raise ValueError(f"Invalid heuristic: {args.heuristic}")
        alg = MolSetMCTS(
            value_function=value_function,
            reward_function=MCTS_SuccessReward(feasibility_model, purchasability_model),
            policy=MCTS_Marginal_Feasibility_Policy(feasibility_model),
            limit_iterations=high_integer,
            time_limit_s=10_000,  # super high limit, should not be reached
            max_expansion_depth=30,
            prevent_repeat_mol_in_trees=True,
            unique_nodes=False,
            bound_constant=0.01,
            bound_function=pucb_bound,
            min_num_visit_to_expand=10,  # gather a lot of information before expanding
            **common_alg_kwargs,
        )
        alg.init_mcts_value = 1.0  # optimistic start so that all nodes are visited
    elif args.algorithm == "breadth-first":
        alg = AndOr_BreadthFirstSearch(
            unique_nodes=True,
            limit_iterations=high_integer,
            **common_alg_kwargs,
        )
    elif args.algorithm == "ORbreadth-first":
        alg = MolSet_BreadthFirstSearch(
            unique_nodes=False,  # OR graphs not fully supported
            limit_iterations=high_integer,
            **common_alg_kwargs,
        )
    else:
        raise ValueError(f"Invalid algorithm: {args.algorithm}")

    # Run the search
    analysis_times = [
        t
        for t in [1, 5, 10, 25, 50, 100, 150, 200, 250, 500, 1000]
        if t <= args.limit_rxn_model_calls
    ]
    # Number of samples used *during* search (may differ from the evaluation count)
    num_samples_for_search = args.num_samples_alg or args.num_samples
    search_results = {}
    for i, smiles in enumerate(tqdm(test_smiles, dynamic_ncols=True, smoothing=0.0)):
        main_logger.info(f"Start search {i}/{len(test_smiles)}. SMILES: {smiles}")

        # Reset sample count
        feasibility_model.num_samples = num_samples_for_search
        purchasability_model.num_samples = num_samples_for_search

        # Reset and run search
        t0 = time.monotonic()
        alg.reset()
        output_graph, _ = alg.run_from_mol(Molecule(smiles))
        t1 = time.monotonic()
        main_logger.debug(f"Search done in {t1 - t0:.2f} s. Starting analysis.")

        # Reset feasibility models for eval
        feasibility_model.num_samples = args.num_samples
        purchasability_model.num_samples = args.num_samples

        # Potential convert to AND/OR graph for analysis
        if isinstance(output_graph, MolSetGraph) and args.andor_graph_analysis:
            output_graph = get_unique_node_andor_graph(output_graph)
            main_logger.debug(
                f"Converted OR graph to AND/OR graph for analysis. Resulting graph has {len(output_graph)} nodes."
            )

        # Analyze output graph
        if isinstance(output_graph, MolSetGraph):
            analysis_results = analyze_output_graph_molset(
                output_graph, feasibility_model, purchasability_model, analysis_times
            )
        else:
            analysis_results = analyze_output_graph(
                output_graph,
                feasibility_model,
                purchasability_model,
                analysis_times,
                calculate_most_feasible_route=not args.skip_most_feasible_route,
            )
        t2 = time.monotonic()

        # Store search results
        search_results[smiles] = dict(
            final_num_nodes=len(output_graph),
            total_search_time=t1 - t0,
            total_analysis_time=t2 - t1,
            **analysis_results,
        )
        main_logger.info(
            f"Done search + analysis in {t2 - t0:.2f} s. Results:\n{pformat(search_results[smiles])}"
        )

        # Potentially save graph, then delete to free up memory
        if args.save_output_graphs:
            save_output_graph(output_graph, f"output_graph{i}.pkl")
        del output_graph

    # Save json output
    main_logger.info("Saving json...")
    with open(args.output_json, "w") as f:
        json.dump(
            dict(
                args=args.__dict__,
                analysis_times=list(analysis_times),
                search_results=search_results,
            ),
            f,
            indent=2,
        )

    main_logger.info("END OF SCRIPT.")


if __name__ == "__main__":
    # Logging: send all records to stdout at INFO level on the root logger.
    # ``filemode`` is intentionally not passed: ``logging.basicConfig`` only uses
    # it together with ``filename``, so it has no effect with ``stream``.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )
    main()
