"""Code to analyze search success."""

from __future__ import annotations
from collections.abc import Collection, Iterator
import logging
import math
import time
from typing import Optional

import numpy as np


from syntheseus.search.graph.and_or import AndNode, OrNode, AndOrGraph, ANDOR_NODE
from syntheseus.search.graph.message_passing import run_message_passing
from syntheseus.search.analysis.route_extraction import _iter_top_routes

from retro_fallback.feasibility_model import FeasibilityModel, PurchasabilityModel
from retro_fallback.rfb_sample import rfb_success_outcomes_update

logger = logging.getLogger(__name__)


def estimate_synthesis_success_across_time(
    graph: AndOrGraph,
    feasibility_model: FeasibilityModel,
    purchasability_model: PurchasabilityModel,
    max_times: Optional[list[Optional[float]]] = None,
) -> list[float]:
    """
    Estimates synthesis success over time.
    Faster than calling single version repeatedly because it only draws samples once.

    NOTE: will overwrite existing samples on the graph
    (the "rfb_success_outcomes", "rfb_purchase_samples" and "rfb_feasible_samples"
    entries of every node's data).

    For efficiency, an initial computation of the success outcomes of each node is done
    using an outcome of 1 for any reaction or molecule with non-zero marginal
    feasibility/purchasability. Any unsuccessful reactions in this initial computation
    are not sampled, because they will not affect the success outcomes of the root node.
    This procedure saves time in large graphs where many nodes are unsolved.

    Args:
        graph: AND/OR search graph whose nodes will be annotated in place.
        feasibility_model: model providing marginal probabilities and prior samples
            for reactions; its `num_samples` sets the length of the outcome arrays.
        purchasability_model: model providing marginal probabilities and prior samples
            for molecules.
        max_times: time cutoffs to evaluate. For each cutoff, only nodes whose
            `node.data["analysis_time"]` is <= the cutoff keep their samples;
            a cutoff of `None` means no time limit. Defaults to no cutoffs
            (the samples are still drawn, but an empty list is returned).

    Returns:
        One estimated success rate of the root node per entry of `max_times`,
        in the same order.

    Raises:
        TypeError: if the graph contains a node which is neither an OrNode nor an AndNode.
    """
    # Avoid a mutable default argument: a fresh list is created per call.
    if max_times is None:
        max_times = []

    output: list[float] = []

    # These do not change during this function, so compute them only once.
    num_samples = feasibility_model.num_samples
    all_nodes = list(graph.nodes())
    # Deepest nodes first, so that updates propagate upwards to the root.
    # NOTE: assumes message passing does not change node depths or graph structure.
    nodes_deepest_first = sorted(all_nodes, key=lambda n: n.depth, reverse=True)

    # Get molecules in graph
    mols = list({node.mol for node in all_nodes if isinstance(node, OrNode)})
    rxns = list({node.reaction for node in all_nodes if isinstance(node, AndNode)})
    logger.debug(f"Found {len(mols)} molecules and {len(rxns)} reactions in graph.")

    # Set initial success outcomes
    t = time.monotonic()
    mol_to_marginal = dict(zip(mols, purchasability_model.marginal_probability(mols)))
    rxn_to_marginal = dict(zip(rxns, feasibility_model.marginal_probability(rxns)))
    for node in all_nodes:
        node.data["rfb_success_outcomes"] = np.zeros(num_samples)
        # A single optimistic "sample": 1 if the marginal probability is non-zero, else 0.
        if isinstance(node, OrNode):
            node.data["rfb_purchase_samples"] = np.ones(1) * float(
                mol_to_marginal[node.mol] > 0
            )
        elif isinstance(node, AndNode):
            node.data["rfb_feasible_samples"] = np.ones(1) * float(
                rxn_to_marginal[node.reaction] > 0
            )
        else:
            raise TypeError(f"Unexpected node type {type(node)}")
    run_message_passing(
        graph=graph,
        nodes=nodes_deepest_first,
        update_fns=[rfb_success_outcomes_update],
        update_predecessors=True,
        update_successors=False,
    )

    # Only nodes which can succeed in the optimistic pass need real samples.
    mols_to_sample = {
        node.mol
        for node in all_nodes
        if isinstance(node, OrNode) and node.data["rfb_success_outcomes"].mean() > 0
    }
    rxns_to_sample = {
        node.reaction
        for node in all_nodes
        if isinstance(node, AndNode) and node.data["rfb_success_outcomes"].mean() > 0
    }
    logger.debug(
        f"Found {len(mols_to_sample)} mols and {len(rxns_to_sample)} reactions to sample "
        f"in {time.monotonic()-t:.2f} s."
    )

    # Get samples from models
    t = time.monotonic()
    rxn_samples = feasibility_model.prior_sample(rxns_to_sample)
    t2 = time.monotonic()
    logger.debug(f"Sampled {len(rxn_samples)} reactions in {t2-t:.2f} s.")
    mol_samples = purchasability_model.prior_sample(mols_to_sample)
    t3 = time.monotonic()
    logger.debug(f"Sampled {len(mols_to_sample)} molecules in {t3-t2:.2f} s.")

    # Assign samples to nodes
    for i_t, max_time in enumerate(max_times):
        logger.debug(f"Start {i_t}th estimate of synthesis success at time {max_time}.")
        t_i = time.monotonic()
        for node in all_nodes:
            # Initialize success outcomes
            node.data["rfb_success_outcomes"] = np.zeros(num_samples)

            # Set samples: nodes outside the time budget (or not worth sampling)
            # are treated as always failing.
            in_time = (max_time is None) or (node.data["analysis_time"] <= max_time)
            if isinstance(node, OrNode):
                if node.mol in mols_to_sample and in_time:
                    samples = mol_samples[node.mol]
                else:
                    samples = np.zeros(num_samples)
                node.data["rfb_purchase_samples"] = samples
            elif isinstance(node, AndNode):
                if node.reaction in rxns_to_sample and in_time:
                    samples = rxn_samples[node.reaction]
                else:
                    samples = np.zeros(num_samples)
                node.data["rfb_feasible_samples"] = samples
            else:
                raise TypeError(f"Unexpected node type {type(node)}")

        t_i2 = time.monotonic()
        logger.debug(f"Assigned samples in {t_i2-t_i:.2f} s.")

        # Run message passing from retro-fallback
        run_message_passing(
            graph=graph,
            nodes=nodes_deepest_first,
            update_fns=[rfb_success_outcomes_update],
            update_predecessors=True,
            update_successors=False,
        )
        t_i3 = time.monotonic()
        logger.debug(f"Calculated success outcomes in {t_i3-t_i2:.2f} s.")

        # Store mean success rate of root node as output
        output.append(float(graph.root_node.data["rfb_success_outcomes"].mean()))
        logger.debug(
            f"Estimated success rate: {output[-1]:.2f} in {time.monotonic() - t_i:.2f} s."
        )

    return output


def estimate_synthesis_success(
    graph: AndOrGraph,
    feasibility_model: FeasibilityModel,
    purchasability_model: PurchasabilityModel,
    max_time: Optional[float] = None,
) -> float:
    """
    Single estimate of synthesis success.

    Delegates to `estimate_synthesis_success_across_time` with `max_time` as the
    only time cutoff and returns that one estimate.
    """
    single_cutoff = [max_time]
    estimates = estimate_synthesis_success_across_time(
        graph,
        feasibility_model,
        purchasability_model,
        max_times=single_cutoff,
    )
    return estimates[0]


def _feasibility_partial_cost(nodes, graph) -> float:
    """
    Lower bound on route cost: -log(probability that all reactions succeed).

    Only AndNodes in `nodes` are considered; `graph` is accepted but not used.
    Returns infinity if the estimated success probability is zero.
    """
    reaction_samples = [
        n.data["route_feas_samples"] for n in nodes if isinstance(n, AndNode)
    ]
    # A sample counts as a success only if every reaction succeeds in it.
    # NOTE: in the future could probably also bound using success outcomes of children?
    joint_success = np.all(np.array(reaction_samples), axis=0)
    probability = float(np.mean(joint_success))
    return -math.log(probability) if probability != 0 else math.inf


def _feasibility_cost(nodes, graph) -> float:
    """
    Cost is -log(success probability of all reactions and molecules without children).

    A node contributes its `node.data["route_feas_samples"]` if it is an AndNode,
    or if none of its successors in `graph` belong to `nodes` (i.e. it is a leaf
    of the route). A sample counts as a success only if every contributing node
    succeeds in it.

    Args:
        nodes: collection of the nodes making up the route.
        graph: graph providing `successors(node)`.

    Returns:
        -log of the fraction of successful samples, or infinity if that fraction is 0.
    """
    # Build the membership set once instead of once per node.
    route_nodes = set(nodes)
    all_samples = []
    for node in nodes:
        is_route_leaf = route_nodes.isdisjoint(graph.successors(node))
        if is_route_leaf or isinstance(node, AndNode):
            all_samples.append(node.data["route_feas_samples"])
    all_succeed = np.all(np.array(all_samples), axis=0)
    succ_prob = float(np.mean(all_succeed))
    if succ_prob == 0:
        return math.inf
    return -math.log(succ_prob)


def iter_routes_feasibility_order(
    graph: AndOrGraph,
    max_routes: int,
) -> Iterator[tuple[float, Collection[ANDOR_NODE]]]:
    """
    Iterate over routes ranked by feasibility.

    Routes are ranked using `_feasibility_cost`, i.e. -log(success probability),
    with `_feasibility_partial_cost` as the lower bound on the cost.
    NOTE(review): assuming `_iter_top_routes` yields routes in order of increasing
    cost, this corresponds to *decreasing* feasibility (most feasible route first)
    -- confirm against `_iter_top_routes`.

    Assumes that node.data["route_feas_samples"] is set
    (this represents a union of purchasability and feasibility samples).

    Args:
        graph: graph to extract routes from.
        max_routes: passed through to `_iter_top_routes` as `max_routes`.

    Yields:
        Whatever `_iter_top_routes` yields; per the return annotation these are
        (cost, route nodes) tuples.
    """
    yield from _iter_top_routes(
        graph=graph,
        max_routes=max_routes,
        cost_fn=_feasibility_cost,
        cost_lower_bound=_feasibility_partial_cost,
    )
