"""Retro-fallback algorithm for graphs, not assuming independent reaction feasibility."""

from __future__ import annotations

from collections.abc import Collection, Sequence
import logging
from typing import Optional
import warnings

import numpy as np
from syntheseus.search.chem import Molecule, BackwardReaction
from syntheseus.search.graph.and_or import AndNode, OrNode, ANDOR_NODE, AndOrGraph
from syntheseus.search.algorithms.base import AndOrSearchAlgorithm
from syntheseus.search.algorithms.mixins import ValueFunctionMixin
from syntheseus.search.node_evaluation.base import BaseNodeEvaluator
from syntheseus.search.graph.message_passing import run_message_passing

from retro_fallback.feasibility_model import FeasibilityModel, PurchasabilityModel


# Module logger. SampleRetroFallback logs per-step search progress at level
# ``logging.DEBUG - 1``, so that output is hidden unless that level is enabled.
logger = logging.getLogger(__name__)


class BinaryPurchasability(PurchasabilityModel):
    """
    Deterministic purchasability model driven by molecule metadata.

    Each molecule's value is its ``is_purchasable`` metadata entry converted to
    ``float`` (1.0 / 0.0 for a boolean flag). It is repeated across all samples,
    so molecules are treated independently and observed samples are ignored.

    NOTE: a molecule with no ``is_purchasable`` entry makes ``float(None)``
    raise ``TypeError``.
    """

    @staticmethod
    def _is_purchasable_float(molecule: Molecule) -> float:
        """Return the molecule's ``is_purchasable`` metadata entry as a float."""
        return float(molecule.metadata.get("is_purchasable"))

    def posterior_sample(
        self, molecules: set[Molecule], observed_samples: dict[Molecule, np.ndarray]
    ) -> dict[Molecule, np.ndarray]:
        # Purchasability is deterministic per molecule, so ``observed_samples``
        # carries no information here and is deliberately unused.
        samples: dict[Molecule, np.ndarray] = {}
        for molecule in molecules:
            samples[molecule] = np.full(
                self.num_samples, self._is_purchasable_float(molecule)
            )
        return samples

    def marginal_probability(self, molecules: list[Molecule]) -> list[float]:
        # A deterministic outcome's marginal probability is the outcome itself.
        return list(map(self._is_purchasable_float, molecules))


class SampleRetroFallback(
    AndOrSearchAlgorithm[int],
    ValueFunctionMixin[OrNode],
):
    """
    Sample-based retro-fallback search on an AND/OR graph.

    Molecule purchasability and reaction feasibility are represented by
    ``feasibility_model.num_samples`` joint samples, drawn from
    ``purchasability_model`` and ``feasibility_model``. Reaction feasibilities
    therefore need not be independent. Each node stores per-sample numpy arrays
    in ``node.data``:

    - ``rfb_purchase_samples`` (OR nodes) / ``rfb_feasible_samples`` (AND nodes):
      the sampled outcomes for the node itself.
    - ``rfb_success_outcomes``: 0/1 indicator of synthesis success per sample.
    - ``rfb_expand_success_estimate``: success estimate assuming unexpanded
      leaves reach their ``rfb_success_estimate`` (from the value function).
    - ``rfb_max_success_tree``: best estimated success of any tree that contains
      both this node and the root.

    Each iteration expands the expandable node with the highest mean
    ``rfb_max_success_tree`` over samples in which the root is not yet solved.
    """

    # Per-sample arrays maintained by the retro-fallback message-passing updates.
    _RFB_ARRAY_KEYS = (
        "rfb_success_outcomes",
        "rfb_expand_success_estimate",
        "rfb_max_success_tree",
    )

    def __init__(
        self,
        *args,
        feasibility_model: FeasibilityModel,
        purchasability_model: Optional[PurchasabilityModel] = None,
        early_stopping_success_threshold: float = 1.0,
        incremental_updates: bool = True,
        **kwargs,
    ):
        """
        Args:
            *args, **kwargs: forwarded to the base search algorithm. Unless given
                explicitly, ``prevent_repeat_mol_in_trees`` defaults to False and
                ``unique_nodes`` to True, so the search runs on a graph.
            feasibility_model: source of joint reaction-feasibility samples. Its
                ``num_samples`` fixes the length of every per-node sample array.
            purchasability_model: source of purchasability samples. Defaults to
                ``BinaryPurchasability`` with the same ``num_samples``.
            early_stopping_success_threshold: the search stops once the root's
                mean success over samples reaches this value.
            incremental_updates: if True, each update is seeded only with the new
                nodes and the expanded node. If False, all node arrays are reset
                and recomputed after every expansion.

        Raises:
            ValueError: if ``expand_purchasable_mols`` is enabled, or if
                ``set_depth`` is disabled (updates order nodes by depth).
        """
        # Change defaults so the search runs on a graph rather than a tree
        kwargs.setdefault("prevent_repeat_mol_in_trees", False)
        kwargs.setdefault("unique_nodes", True)

        super().__init__(*args, **kwargs)
        self.feasibility_model = feasibility_model
        self.purchasability_model = purchasability_model or BinaryPurchasability(
            num_samples=self.feasibility_model.num_samples
        )
        self.early_stopping_success_threshold = early_stopping_success_threshold
        self.incremental_updates = incremental_updates

        # Settings this implementation does not handle
        if self.expand_purchasable_mols:
            raise ValueError("Currently not supported.")
        if not self.set_depth:
            raise ValueError("Currently not supported.")

    def reset(self):
        """Reset the base algorithm and both sampling models."""
        super().reset()
        self.feasibility_model.reset()
        self.purchasability_model.reset()

    @property
    def requires_tree(self) -> bool:
        return False

    @property
    def mol_success_estimator(self) -> BaseNodeEvaluator[OrNode]:
        """Alias for value function."""
        return self.value_function

    def setup(self, graph: AndOrGraph) -> None:
        """Prepare ``graph`` for search, then defer to the base ``setup``."""
        # For a single-node graph, give the root an arbitrary success estimate.
        # This saves a call to the value function.
        # NOTE: this could be wrong if the root molecule is purchasable.
        if len(graph) == 1:
            graph.root_node.data.setdefault("rfb_success_estimate", 1.0)

        # Drop any existing samples so the whole graph is resampled jointly,
        # i.e. with one consistent set of samples.
        for node in graph.nodes():
            node.data.pop("rfb_purchase_samples", None)
            node.data.pop("rfb_feasible_samples", None)

        return super().setup(graph)

    @staticmethod
    def _expansion_priority(
        node: ANDOR_NODE, root_solved_mask: np.ndarray
    ) -> tuple[float, float]:
        """
        Sort key for choosing which node to expand (larger is better).

        The primary key is the mean ``rfb_max_success_tree`` over samples where
        the root is not yet solved. The secondary key is the negated creation
        timestamp, so older nodes win ties.
        """
        mean_improvement = np.ma.array(
            node.data["rfb_max_success_tree"], mask=root_solved_mask
        ).mean()
        if mean_improvement is np.ma.masked:
            # The root is solved in every sample, so nothing can improve. The
            # masked constant does not compare meaningfully, so use 0.0 and keep
            # creation-time tie-breaking.
            mean_improvement = 0.0
        return (mean_improvement, -node.creation_time.timestamp())

    def _run_from_graph_after_setup(self, graph: AndOrGraph) -> int:
        """
        Run the main retro-fallback loop on an already set-up graph.

        Each step first checks the termination conditions: the root's mean
        success reaching ``early_stopping_success_threshold``, no expandable
        nodes remaining, or the base class's ``should_stop_search``. It then
        expands the highest-priority node and updates node values.

        Returns:
            The step counter when the loop exits: the step at which termination
            was detected, or ``limit_iterations - 1`` if the iteration limit was
            reached (0 if ``limit_iterations`` is 0).
        """
        # Progress is logged just below DEBUG, so it is off by default
        log_level = logging.DEBUG - 1
        logger_active = logger.isEnabledFor(log_level)

        step = 0  # defined up front so it can be returned if the loop never runs
        terminate = False
        for step in range(self.limit_iterations):
            # If success is >= early stopping threshold (default 1.0), terminate
            root_mean_success = np.mean(graph.root_node.data["rfb_success_outcomes"])
            if root_mean_success >= self.early_stopping_success_threshold:
                logger.log(
                    log_level,
                    f"{100*root_mean_success:.2f}% success reached, will terminate.",
                )
                terminate = True

            eligible_nodes = [
                n for n in graph.nodes() if self.can_expand_node(n, graph)
            ]
            if not eligible_nodes:
                logger.log(log_level, "No nodes to expand, will terminate.")
                terminate = True

            if self.should_stop_search(graph) or terminate:
                logger.log(log_level, f"Terminate condition reached at step {step}")
                break

            # Choose a node to expand, ignoring samples where the root is already
            # solved (expansion cannot improve those).
            # NOTE: linear scan over all eligible nodes; could be made faster.
            root_solved_mask = np.isclose(
                graph.root_node.data["rfb_success_outcomes"], 1.0
            )
            chosen_leaf = max(
                eligible_nodes,
                key=lambda n: self._expansion_priority(n, root_solved_mask),
            )

            new_nodes = self.expand_node(chosen_leaf, graph)

            if self.incremental_updates:
                # Seed updates with only the changed nodes; message passing
                # propagates the changes to the rest of the graph.
                nodes_to_update = [*new_nodes, chosen_leaf]
            else:
                # Recompute everything from scratch, starting from zeroed arrays
                nodes_to_update = list(graph.nodes())
                self._init_sample_arrays(
                    nodes_to_update, self._RFB_ARRAY_KEYS, overwrite=True
                )

            # Run updates so that next node can be chosen
            updated_nodes = self.set_node_values(nodes_to_update, graph)

            if logger_active:
                with np.printoptions(threshold=10, precision=3):
                    logger.log(
                        level=log_level,
                        msg=f"Step {step}:\tnode={chosen_leaf}, expanded: {len(new_nodes)} new nodes created, "
                        f"{len(updated_nodes)} nodes updated. "
                        f"Graph size = {len(graph)}, root success = {graph.root_node.data['rfb_success_outcomes'].mean():.3f}",
                    )

        return step

    def set_node_values(  # type: ignore[override]
        self, nodes: Collection[ANDOR_NODE], graph: AndOrGraph
    ) -> Collection[ANDOR_NODE]:
        """
        Update node values after the graph has changed.

        Calls the base implementation, which updates depths. It then:

        - samples purchasability/feasibility for any node in the graph that has
          no samples yet;
        - queries the value function for expandable OR nodes returned by the base
          call that have no ``rfb_success_estimate`` yet;
        - propagates the retro-fallback values.

        Returns:
            All nodes whose retro-fallback values were seeded or updated.
        """
        output_nodes = super().set_node_values(nodes, graph)

        self._set_purchasabilities_feasibilities(graph)
        self._set_success_estimate(  # only for unexpanded leaf nodes
            or_nodes=[
                node
                for node in output_nodes
                if isinstance(node, OrNode)
                and "rfb_success_estimate" not in node.data
                and self.can_expand_node(node, graph)
            ],
            graph=graph,
        )

        return self._run_retro_fallback_updates(output_nodes, graph)

    def _set_purchasabilities_feasibilities(self, graph: AndOrGraph) -> None:
        """
        Set purchasability and feasibility samples for nodes that lack them.

        The samples come from a joint distribution and are in general dependent.
        Existing samples in the graph are therefore passed to the models as
        observations, and the whole graph is processed at once.
        """
        or_nodes_without_samples: set[OrNode] = set()
        and_nodes_without_samples: set[AndNode] = set()
        mol_to_samples: dict[Molecule, np.ndarray] = {}
        rxn_to_samples: dict[BackwardReaction, np.ndarray] = {}
        for node in graph.nodes():
            # NOTE: this does not handle a molecule/reaction that occurs several
            # times in the graph with *different* samples (the last one seen
            # wins). Samples are reset in ``setup``, so this should not happen.
            if isinstance(node, OrNode):
                if "rfb_purchase_samples" in node.data:
                    mol_to_samples[node.mol] = node.data["rfb_purchase_samples"]
                else:
                    or_nodes_without_samples.add(node)
            elif isinstance(node, AndNode):
                if "rfb_feasible_samples" in node.data:
                    rxn_to_samples[node.reaction] = node.data["rfb_feasible_samples"]
                else:
                    and_nodes_without_samples.add(node)

        # Update 1: purchasability samples for new molecules.
        # Molecules, not nodes: the same molecule may already be sampled in
        # another node, and then it reuses those samples.
        mols_without_samples = {
            node.mol for node in or_nodes_without_samples
        }.difference(mol_to_samples)
        if mols_without_samples:
            mol_to_samples.update(
                self.purchasability_model.posterior_sample(
                    molecules=mols_without_samples, observed_samples=mol_to_samples
                )
            )
        for node in or_nodes_without_samples:
            node.data["rfb_purchase_samples"] = mol_to_samples[node.mol]

        # Update 2: feasibility samples for new reactions (mirrors update 1)
        rxns_without_samples = {
            node.reaction for node in and_nodes_without_samples
        }.difference(rxn_to_samples)
        if rxns_without_samples:
            rxn_to_samples.update(
                self.feasibility_model.posterior_sample(
                    reactions=rxns_without_samples, observed_samples=rxn_to_samples
                )
            )
        for node in and_nodes_without_samples:
            node.data["rfb_feasible_samples"] = rxn_to_samples[node.reaction]

    def _set_success_estimate(
        self, or_nodes: Sequence[OrNode], graph: AndOrGraph
    ) -> None:
        """Store the value function's output as ``rfb_success_estimate`` on each node."""
        if not or_nodes:
            return  # avoid calling the value function with an empty batch
        values = self.value_function(or_nodes, graph=graph)
        assert len(values) == len(or_nodes)
        for node, v in zip(or_nodes, values):
            node.data["rfb_success_estimate"] = v

    def _init_sample_arrays(
        self, nodes: Collection[ANDOR_NODE], keys: Sequence[str], overwrite: bool
    ) -> None:
        """
        Set each of ``keys`` in every node's data to a fresh zero array.

        The arrays have length ``feasibility_model.num_samples``. If
        ``overwrite`` is False, keys that already exist are left unchanged.
        """
        for node in nodes:
            for k in keys:
                if overwrite or k not in node.data:
                    node.data[k] = np.zeros(self.feasibility_model.num_samples)

    def _run_retro_fallback_updates(
        self, nodes: Collection[ANDOR_NODE], graph: AndOrGraph
    ) -> Collection[ANDOR_NODE]:
        """
        Propagate retro-fallback values through the graph, starting from ``nodes``.

        Three message-passing passes run in order:

        1. ``rfb_success_outcomes``: bottom-up, deepest nodes first.
        2. ``rfb_expand_success_estimate``: bottom-up.
        3. ``rfb_max_success_tree``: top-down, shallowest nodes first.

        Passes 2 and 3 are capped at ``10 * len(graph)`` iterations. If
        ``run_message_passing`` raises ``RuntimeError`` (taken as
        non-convergence), their arrays are reset on every node and both passes
        are rerun over the whole graph without an explicit cap.

        Returns:
            The set of nodes that were seeded or updated (all nodes on fallback).
        """
        nodes_to_update = set(nodes)
        self._init_sample_arrays(nodes_to_update, self._RFB_ARRAY_KEYS, overwrite=False)

        # First, run success outcome updates
        nodes_to_update.update(
            run_message_passing(
                graph=graph,
                nodes=sorted(nodes_to_update, key=lambda n: n.depth, reverse=True),
                update_fns=[rfb_success_outcomes_update],
                update_predecessors=True,
                update_successors=False,
            )
        )

        MAX_ITER = 10 * len(graph)  # hard-coded cap; could be made configurable
        try:
            # Second, run success estimate updates
            nodes_to_update.update(
                run_message_passing(
                    graph=graph,
                    nodes=sorted(nodes_to_update, key=lambda n: n.depth, reverse=True),
                    update_fns=[rfb_expand_success_estimate_update],
                    update_predecessors=True,
                    update_successors=False,
                    max_iterations=MAX_ITER,
                )
            )

            # Third, run success tree updates (in reverse order)
            nodes_to_update.update(
                run_message_passing(
                    graph=graph,
                    nodes=sorted(nodes_to_update, key=lambda n: n.depth, reverse=False),
                    update_fns=[rfb_max_success_tree_update],
                    update_predecessors=False,
                    update_successors=True,
                    max_iterations=MAX_ITER,
                )
            )
        except RuntimeError:
            if not self.incremental_updates:
                warnings.warn(
                    "Incremental updates are off, but it still took a large number of iterations to converge."
                )
            logger.warning(
                f"Message passing did not converge after {MAX_ITER} iterations. "
                "Will perform updates on all nodes."
            )

            # Recompute passes 2 and 3 for the whole graph from zeroed arrays.
            # Success outcomes (pass 1) completed above and are kept.
            nodes_to_update = set(graph.nodes())
            self._init_sample_arrays(
                nodes_to_update,
                ("rfb_expand_success_estimate", "rfb_max_success_tree"),
                overwrite=True,
            )

            run_message_passing(
                graph=graph,
                nodes=sorted(nodes_to_update, key=lambda n: n.depth, reverse=True),
                update_fns=[rfb_expand_success_estimate_update],
                update_predecessors=True,
                update_successors=False,
            )
            run_message_passing(
                graph=graph,
                nodes=sorted(nodes_to_update, key=lambda n: n.depth, reverse=False),
                update_fns=[rfb_max_success_tree_update],
                update_predecessors=False,
                update_successors=True,
            )

        return nodes_to_update


def rfb_success_outcomes_update(node: ANDOR_NODE, graph: AndOrGraph) -> bool:
    """
    Recompute ``node.data["rfb_success_outcomes"]``, the per-sample 0/1 indicators
    of whether the node is successfully synthesized.

    - OR node: elementwise max of its purchase samples and each child's outcomes
      (buy it, or make it through any child reaction).
    - AND node: elementwise product of its feasibility samples and each child's
      outcomes (the reaction works and every reactant is obtained).

    Only children are read, so propagation need only go to predecessors.

    Returns:
        True if the stored value was absent or is no longer ``np.allclose`` to
        the new one.

    Raises:
        TypeError: if ``node`` is neither an ``OrNode`` nor an ``AndNode``.
    """
    # Choose the node's own samples and how child outcomes are folded into them
    if isinstance(node, OrNode):
        own_samples_key, combine = "rfb_purchase_samples", np.maximum
    elif isinstance(node, AndNode):
        own_samples_key, combine = "rfb_feasible_samples", np.multiply
    else:
        raise TypeError("Only AND/OR nodes supported.")

    outcomes = node.data[own_samples_key]
    for child in graph.successors(node):
        outcomes = combine(outcomes, child.data["rfb_success_outcomes"])

    # Outcomes must remain binary
    assert set(np.unique(outcomes)) <= {0, 1}
    previous: Optional[np.ndarray] = node.data.get("rfb_success_outcomes", None)
    node.data["rfb_success_outcomes"] = outcomes
    return previous is None or not np.allclose(previous, outcomes)


def rfb_expand_success_estimate_update(node: ANDOR_NODE, graph: AndOrGraph) -> bool:
    """
    Recompute ``node.data["rfb_expand_success_estimate"]``: an estimate of the
    node's per-sample success probability if a single route under it is expanded.

    - OR node: elementwise max of its purchase samples and each child's estimate.
      Unexpanded OR nodes also include their ``rfb_success_estimate``, if set.
    - AND node: elementwise product of its feasibility samples and each child's
      estimate.

    Only the node and its children are read, so propagation need only go to
    predecessors.

    Returns:
        True if the stored value was absent or is no longer ``np.allclose`` to
        the new one.

    Raises:
        TypeError: if ``node`` is neither an ``OrNode`` nor an ``AndNode``.
    """
    # Choose the node's own samples and how child estimates are folded into them
    if isinstance(node, OrNode):
        own_samples_key, combine = "rfb_purchase_samples", np.maximum
    elif isinstance(node, AndNode):
        own_samples_key, combine = "rfb_feasible_samples", np.multiply
    else:
        raise TypeError("Only AND/OR nodes supported.")

    estimate = node.data[own_samples_key]
    for child in graph.successors(node):
        estimate = combine(estimate, child.data["rfb_expand_success_estimate"])

    # Unexpanded OR leaves may also succeed via their heuristic estimate.
    # NOTE: "not expanded" differs from "cannot be expanded"; this relies on
    # non-expandable nodes never being given a success estimate.
    if (
        isinstance(node, OrNode)
        and not node.is_expanded
        and "rfb_success_estimate" in node.data
    ):
        estimate = np.maximum(estimate, node.data["rfb_success_estimate"])

    # Estimates are probabilities
    assert np.all(estimate <= 1) and np.all(estimate >= 0)
    previous: Optional[np.ndarray] = node.data.get("rfb_expand_success_estimate", None)
    node.data["rfb_expand_success_estimate"] = estimate
    return previous is None or not np.allclose(previous, estimate)


def rfb_max_success_tree_update(node: ANDOR_NODE, graph: AndOrGraph) -> bool:
    """
    Recompute ``node.data["rfb_max_success_tree"]``: the maximum per-sample
    success probability of any tree containing this node and the root. Other
    nodes are assumed to reach their heuristic estimates.

    - Root (OR node with no parents): its own ``rfb_expand_success_estimate``.
    - Other OR nodes: elementwise max of the parents' ``rfb_max_success_tree``.
    - AND nodes (exactly one parent):
      ``parent.max_success_tree / parent.expand_success_estimate
      * node.expand_success_estimate``. In log space this is the retro* update.
      A 0/0 ratio is treated as 0. A resulting infinity becomes NaN, which
      trips the range assertion.

    Only the node and its parents are read, so propagation need only go to
    successors.

    Returns:
        True if the stored value was absent or is no longer ``np.allclose`` to
        the new one.

    Raises:
        TypeError: if ``node`` is neither an ``OrNode`` nor an ``AndNode``.
    """
    parents = list(graph.predecessors(node))

    if isinstance(node, OrNode):
        if not parents:
            # Root node: the best tree is the root's own estimate
            new_value = node.data["rfb_expand_success_estimate"]
        else:
            first_parent, *other_parents = parents
            new_value = first_parent.data["rfb_max_success_tree"]
            for other in other_parents:
                new_value = np.maximum(new_value, other.data["rfb_max_success_tree"])
    elif isinstance(node, AndNode):
        assert len(parents) == 1
        (parent,) = parents

        # Divide with 0/0 -> NaN silenced, then map NaN to 0 below.
        # 1/0 is not expected to occur.
        with np.errstate(divide="ignore", invalid="ignore"):
            parent_ratio = (
                parent.data["rfb_max_success_tree"]
                / parent.data["rfb_expand_success_estimate"]
            )
            new_value = parent_ratio * node.data["rfb_expand_success_estimate"]
        new_value = np.nan_to_num(new_value, nan=0.0, posinf=np.nan, neginf=np.nan)
    else:
        raise TypeError("Only AND/OR nodes supported.")

    # Values are probabilities (NaN fails both comparisons)
    assert np.all(new_value <= 1) and np.all(new_value >= 0)
    previous: Optional[np.ndarray] = node.data.get("rfb_max_success_tree", None)
    node.data["rfb_max_success_tree"] = new_value
    return previous is None or not np.allclose(previous, new_value)
