"""
Compare the performance of the retro-fallback and retro-star algorithms
using the retro* reaction model and inventory, several independent
feasibility models, and constant, optimistic value functions.

Because the value functions are constant, this is essentially pure search
with no heuristic guidance.
"""


from __future__ import annotations

import argparse
import logging
import json
import math
import sys
from typing import Any
import numpy as np
from tqdm import tqdm

from syntheseus.search.chem import Molecule, BackwardReaction
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch

from retro_fallback.feasibility_model import (
    FeasibilityModel,
    IndependentFeasibilityModel,
)
from retro_fallback.rfb_sample import SampleRetroFallback, BinaryPurchasability
from retro_fallback.success_analysis import estimate_synthesis_success

from retro_star_task import (
    RetroStarReactionModel,
    RetroStarInventory,
)


class FeasibilityModelCost(NoCacheNodeEvaluator):
    """
    Node evaluator that turns reaction feasibility into an additive cost.

    Each node's cost is ``-log(p)``, where ``p`` is the marginal feasibility
    probability that ``feasibility_model`` assigns to the node's reaction.
    For independent feasibilities, summing these costs along a route equals
    ``-log`` of the product of the reaction feasibilities.
    """

    def __init__(
        self,
        feasibility_model: FeasibilityModel,
        **kwargs,
    ):
        # Remaining keyword arguments are forwarded to the base evaluator.
        super().__init__(**kwargs)
        self.feasibility_model = feasibility_model

    def _evaluate_nodes(self, nodes, graph=None) -> list[float]:
        """
        Return ``-log(p)`` for each node's reaction, in the order of ``nodes``.

        Every node must expose a ``.reaction`` attribute. A feasibility of
        exactly zero is mapped to ``math.inf`` (the limit of ``-log(p)``)
        instead of letting ``math.log`` raise ``ValueError``. Negative values
        are not valid probabilities and still raise.
        """
        feasibilities = self.feasibility_model.marginal_probability(
            [n.reaction for n in nodes]
        )
        return [math.inf if f == 0 else -math.log(f) for f in feasibilities]


def compare_algorithms(
    smiles_list: list[str],
    rxn_model: BackwardReactionModel,
    inventory: BaseMolInventory,
    feasibility_model: FeasibilityModel,
    limit_rxn_model_calls: int,
    analysis_times: list[float],
    use_tqdm: bool = False,
) -> dict[str, dict[str, dict[str, Any]]]:
    """
    Run retro-fallback and retro-star on every SMILES and collect summary metrics.

    Both algorithms share the same reaction model, inventory and reaction-model
    call budget. They use constant value functions: 1.0 for retro-fallback
    and 0.0 for retro-star. Retro-star's AND-node costs are derived from
    ``feasibility_model`` via ``FeasibilityModelCost``.

    Args:
        smiles_list: Target molecules (as SMILES) to search for.
        rxn_model: Backward reaction model used by both algorithms.
        inventory: Molecule inventory used by both algorithms.
        feasibility_model: Used by retro-fallback, to derive retro-star's
            costs, and to estimate success probabilities after each search.
        limit_rxn_model_calls: Maximum number of reaction model calls per search.
        analysis_times: Times, measured as the number of reaction model calls,
            at which to estimate the probability of synthesis success.
        use_tqdm: Whether to show a progress bar over ``smiles_list``.

    Returns:
        Nested dict ``output[algorithm_name][smiles]``. Each entry has the keys
        ``"solution_time"``, ``"num_nodes"`` and ``"success_probabilities"``.
        The last holds one value per element of ``analysis_times``.
    """
    logger = logging.getLogger("alg_comparison")

    # Initialize graph versions of algorithms
    high_integer = int(1e10)  # effectively "no limit"
    common_alg_kwargs = dict(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        limit_iterations=high_integer,
        max_expansion_depth=high_integer,
        prevent_repeat_mol_in_trees=False,
        unique_nodes=True,
        stop_on_first_solution=False,
    )
    retro_fallback_alg = SampleRetroFallback(
        feasibility_model=feasibility_model,
        value_function=ConstantNodeEvaluator(1.0),
        **common_alg_kwargs,
    )
    retro_star_alg = RetroStarSearch(
        and_node_cost_fn=FeasibilityModelCost(feasibility_model=feasibility_model),
        value_function=ConstantNodeEvaluator(0.0),
        **common_alg_kwargs,
    )
    algs = [("retro_fallback", retro_fallback_alg), ("retro_star", retro_star_alg)]

    # Set up progress bar
    if use_tqdm:
        smiles_iter = tqdm(
            smiles_list,
            dynamic_ncols=True,  # avoid issues open tmux on different screens
            smoothing=0.0,  # average speed, needed because searches vary a lot in length
        )
    else:
        smiles_iter = smiles_list

    # Run searches
    output: dict[str, dict[str, dict[str, Any]]] = {name: dict() for name, _ in algs}

    for i, smiles in enumerate(smiles_iter):
        logger.debug(f"Start search {i}/{len(smiles_list)}. SMILES: {smiles}")
        for name, alg in algs:
            # Run search (fully resetting)
            alg.reset()
            output_graph, _ = alg.run_from_mol(Molecule(smiles))

            # Get solution time, measured in number of reaction model calls
            for node in output_graph.nodes():
                node.data["analysis_time"] = node.data["num_calls_rxn_model"]
                # Drop the loop variable so it does not keep a node of this
                # graph referenced after the loop (e.g. during the next search).
                del node
            soln_time = get_first_solution_time(output_graph)

            # Get success probability over time
            purchasability_model_for_analysis = BinaryPurchasability(
                num_samples=feasibility_model.num_samples
            )
            success_probabilities_over_time = [
                estimate_synthesis_success(
                    output_graph,
                    feasibility_model,
                    purchasability_model_for_analysis,
                    max_time=t,
                )
                for t in analysis_times
            ]

            # Log results
            output[name][smiles] = {
                "solution_time": soln_time,
                "num_nodes": len(output_graph),
                "success_probabilities": success_probabilities_over_time,
            }

            # Guard against an empty ``analysis_times``, which would otherwise
            # make ``[-1]`` raise IndexError after results were already stored.
            final_success_prob = (
                success_probabilities_over_time[-1]
                if success_probabilities_over_time
                else float("nan")
            )
            logger.debug(
                f"Done {name+':':<30s} nodes={len(output_graph):>8d}, solution time = {soln_time:>8.3g}, final success prob = {final_success_prob:>8.3f}."
            )

    return output


def get_parser():
    """
    Build the command-line parser for this comparison script.

    ``--feasibility_model`` is restricted to the model names that the script's
    ``__main__`` block knows how to construct. This lets argparse reject typos
    up front and list the options in ``--help``.
    """
    feasibility_model_choices = ("constant", "softmax", "reciprocal_rank")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smiles_file",
        type=str,
        required=True,
        help="Text file with one target SMILES per line.",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        required=True,
        help="Path of the JSON file to write results to.",
    )
    parser.add_argument(
        "--limit_rxn_model_calls",
        type=int,
        default=25,
        help="Allowed number of calls to reaction model.",
    )
    parser.add_argument(
        "--feasibility_model",
        type=str,
        default="constant",
        choices=feasibility_model_choices,
        help="Type of feasibility model to use.",
    )
    parser.add_argument(
        "--feasibility_a",
        type=float,
        default=1.0,
        help="Parameter 'a' for feasibility model (meaning varies by model).",
    )
    parser.add_argument(
        "--feasibility_k",
        type=float,
        default=1.0,
        help="Parameter 'k' for feasibility model (meaning varies by model).",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10_000,
        help="Number of samples to use for feasibility/purchasability model.",
    )
    return parser


class RetroStarModelConstantFeasibilityModel(IndependentFeasibilityModel):
    """
    Independent feasibility model assigning one fixed probability to all reactions.

    The reactions themselves are never inspected; every reaction receives
    ``constant`` as its marginal feasibility.
    """

    def __init__(
        self,
        constant: float,
        **kwargs,
    ):
        # Everything except ``constant`` (e.g. ``num_samples``) belongs to the base class.
        super().__init__(**kwargs)
        self.constant = constant

    def marginal_probability(self, reactions: list[BackwardReaction]) -> list[float]:
        """Return ``self.constant`` once per reaction, preserving input order."""
        fixed_value = self.constant
        return [fixed_value for _unused_rxn in reactions]


class RetroStarModelSoftmaxFeasibilityModel(IndependentFeasibilityModel):
    """
    Independent feasibility model based on the reaction model's score.

    A reaction's feasibility is ``a * (p ** k)``, where ``p`` is read from
    ``rxn.metadata["score"]`` (the reaction model's softmax probability).
    """

    def __init__(
        self,
        a: float,
        k: float,
        **kwargs,
    ):
        # Remaining keyword arguments (e.g. ``num_samples``) go to the base class.
        super().__init__(**kwargs)
        self._a = a
        self._k = k

    def marginal_probability(self, reactions: list[BackwardReaction]) -> list[float]:
        """Return ``a * score**k`` for each reaction, preserving input order."""
        scale, exponent = self._a, self._k
        probabilities = []
        for reaction in reactions:
            score = reaction.metadata["score"]
            probabilities.append(scale * (score**exponent))
        return probabilities


class RetroStarModelReciprocalRankFeasibilityModel(IndependentFeasibilityModel):
    """
    Independent feasibility model based on the reaction's rank.

    A reaction's feasibility is ``a / (k + rank)``, where ``rank`` is read from
    ``rxn.metadata["rank"]`` (the best-ranked reaction has rank 1).
    """

    def __init__(
        self,
        a: float,
        k: float,
        **kwargs,
    ):
        # Remaining keyword arguments (e.g. ``num_samples``) go to the base class.
        super().__init__(**kwargs)
        self._a = a
        self._k = k

    def marginal_probability(self, reactions: list[BackwardReaction]) -> list[float]:
        """Return ``a / (rank + k)`` for each reaction, preserving input order."""
        numerator, offset = self._a, self._k
        probabilities = []
        for reaction in reactions:
            denominator = reaction.metadata["rank"] + offset
            probabilities.append(numerator / denominator)
        return probabilities


if __name__ == "__main__":
    args = get_parser().parse_args()

    # Logging: send every record (DEBUG and above) to stdout.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )
    logging.getLogger().info(args)

    # Load all SMILES to test (one per line). Blank lines are skipped so
    # they do not become empty-string search targets.
    with open(args.smiles_file, "r", encoding="utf-8") as f:
        stripped_lines = [line.strip() for line in f]
    test_smiles = [smiles for smiles in stripped_lines if smiles]

    # Make reaction model, inventory, feasibility model
    rxn_model = RetroStarReactionModel()
    inventory = RetroStarInventory()
    if args.feasibility_model == "constant":
        feasibility_model = RetroStarModelConstantFeasibilityModel(
            constant=args.feasibility_a, num_samples=args.num_samples
        )
    elif args.feasibility_model == "softmax":
        feasibility_model = RetroStarModelSoftmaxFeasibilityModel(
            a=args.feasibility_a, k=args.feasibility_k, num_samples=args.num_samples
        )
    elif args.feasibility_model == "reciprocal_rank":
        feasibility_model = RetroStarModelReciprocalRankFeasibilityModel(
            a=args.feasibility_a, k=args.feasibility_k, num_samples=args.num_samples
        )
    else:
        raise NotImplementedError(args.feasibility_model)

    # Run both algorithms on every SMILES, analysing success probability at
    # 10 evenly spaced reaction-model-call budgets from 0 to the limit.
    analysis_times = np.linspace(0, args.limit_rxn_model_calls, 10)
    overall_results = dict(
        args=args.__dict__,
        analysis_times=analysis_times.tolist(),
    )
    overall_results["results"] = compare_algorithms(
        smiles_list=test_smiles,
        rxn_model=rxn_model,
        inventory=inventory,
        use_tqdm=True,
        feasibility_model=feasibility_model,
        limit_rxn_model_calls=args.limit_rxn_model_calls,
        analysis_times=analysis_times,
    )

    # Save results
    with open(args.output_json, "w") as f:
        json.dump(overall_results, f, indent=2)
