from collections import Counter
import math

import numpy as np
import pytest

from syntheseus.search.chem import Molecule, BackwardReaction
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.tests.search.algorithms.test_base import BaseAlgorithmTest
from syntheseus.tests.search.conftest import *  # noqa[F403]
from syntheseus.tests.search.algorithms.test_best_first import DictRxnCost

from retro_fallback.rfb_sample import SampleRetroFallback, BinaryPurchasability
from retro_fallback.feasibility_model import IndependentFeasibilityModel
from retro_fallback.success_analysis import estimate_synthesis_success
from test_rfb_independent import mol_value_fn, BY_HAND_STEP1_RXNS  # noqa[F401]


class DefaultFeasibilityModel(IndependentFeasibilityModel):
    """Independent feasibility model where all reactions share one marginal probability.

    Args:
        constant_prob: Marginal probability reported for every reaction.
        num_samples: Forwarded to ``IndependentFeasibilityModel.__init__``.
        **kwargs: Any further keyword arguments, forwarded unchanged to the parent.
    """

    def __init__(self, constant_prob: float, num_samples: int, **kwargs):
        super().__init__(num_samples=num_samples, **kwargs)
        # Unseeded generator: a fresh random stream per instance.
        # NOTE(review): presumably consumed by the parent class when drawing
        # samples -- it is not used anywhere in this class itself.
        self.rng = np.random.default_rng()
        self.constant_prob = constant_prob

    def marginal_probability(self, reactions: list[BackwardReaction]) -> list[float]:
        """Return ``constant_prob`` once for each entry of ``reactions``."""
        shared_prob = self.constant_prob
        return [shared_prob for _reaction in reactions]


@pytest.fixture
def rxn_feas_fn() -> IndependentFeasibilityModel:
    """Provide the feasibility model used by the hand-worked example tests.

    Every reaction is given marginal probability 0.5 and 10,000 samples are requested.
    """
    feasibility_model = DefaultFeasibilityModel(constant_prob=0.5, num_samples=10_000)
    return feasibility_model


class TestSampleRetroFallback(BaseAlgorithmTest):
    """Tests for ``SampleRetroFallback``.

    Re-uses the generic checks inherited from ``BaseAlgorithmTest``, overrides
    the ones whose expected outcome differs for this algorithm, and adds three
    hand-worked examples (``test_by_hand_step1`` to ``test_by_hand_step3``).
    """

    time_limit_multiplier = 3.0  # generally slower than most algorithms

    def setup_algorithm(self, **kwargs):
        """Build a ``SampleRetroFallback``, filling in defaults for missing arguments.

        ``feasibility_model`` defaults to a ``DefaultFeasibilityModel`` with
        probability 0.5 and 10,000 samples; ``value_function`` defaults to
        ``ConstantNodeEvaluator(1.0)``. Explicitly passed values always win.
        """
        kwargs.setdefault(
            "feasibility_model",
            DefaultFeasibilityModel(constant_prob=0.5, num_samples=10_000),
        )
        kwargs.setdefault("value_function", ConstantNodeEvaluator(1.0))
        return SampleRetroFallback(**kwargs)

    @staticmethod
    def _mean_max_success_tree(output_graph, smiles_str: str, *, mask=None):
        """Return the mean of ``rfb_max_success_tree`` for the node of ``smiles_str``.

        Args:
            output_graph: Graph returned by the algorithm run.
            smiles_str: SMILES string of the molecule whose node is inspected.
            mask: Optional boolean array; entries where it is True are excluded
                from the mean (via ``np.ma.masked_array``).
        """
        samples = output_graph._mol_to_node[Molecule(smiles_str)].data[
            "rfb_max_success_tree"
        ]
        if mask is not None:
            samples = np.ma.masked_array(samples, mask=mask)
        return samples.mean()

    @staticmethod
    def _estimated_success(output_graph, feasibility_model):
        """Return ``estimate_synthesis_success`` for the graph under binary purchasability."""
        return estimate_synthesis_success(
            output_graph,
            feasibility_model,
            BinaryPurchasability(feasibility_model.num_samples),
        )

    @pytest.mark.parametrize("prevent", [False, True])
    def test_prevent_repeat_mol_in_trees(
        self, prevent: bool, retrosynthesis_task5: RetrosynthesisTask  # noqa[F405]
    ) -> None:
        """
        Check the outcome of the inherited repeat-molecule test.

        With ``prevent=True`` the inherited test is expected to emit a
        ``UserWarning`` and fail with an ``AssertionError``. With
        ``prevent=False`` the test is skipped.
        """
        if not prevent:
            # The test here only makes sense for trees, so report it as skipped
            # rather than letting it pass silently.
            pytest.skip("test only makes sense for trees")

        with pytest.warns(UserWarning):
            with pytest.raises(AssertionError):
                super().test_prevent_repeat_mol_in_trees(
                    prevent, retrosynthesis_task5
                )

    @pytest.mark.parametrize("expand_purchasable_mols", [False, True])
    def test_expand_purchasable_mols(
        self,
        retrosynthesis_task1: RetrosynthesisTask,  # noqa[F405]
        expand_purchasable_mols: bool,
    ) -> None:
        """
        Run the inherited test for ``expand_purchasable_mols``.

        ``expand_purchasable_mols=True`` is expected to raise a ``ValueError``;
        with ``False`` the inherited test must pass unchanged.
        """
        if expand_purchasable_mols:
            with pytest.raises(ValueError):
                super().test_expand_purchasable_mols(
                    retrosynthesis_task1, expand_purchasable_mols
                )
        else:
            super().test_expand_purchasable_mols(
                retrosynthesis_task1, expand_purchasable_mols
            )

    @pytest.mark.parametrize("set_depth", [True, False])
    def test_set_depth(
        self, set_depth: bool, retrosynthesis_task4: RetrosynthesisTask  # noqa[F405]
    ) -> None:
        """
        Test the 'set_depth' argument, which toggles whether the 'depth'
        attribute is set during node updates.

        The test is run on a small finite tree for simplicity.
        ``set_depth=False`` is expected to raise a ``ValueError``.
        """
        if set_depth:
            super().test_set_depth(set_depth, retrosynthesis_task4)
        else:
            with pytest.raises(ValueError):
                super().test_set_depth(set_depth, retrosynthesis_task4)

    def test_by_hand_step1(
        self,
        retrosynthesis_task2: RetrosynthesisTask,  # noqa[F405]
        rxn_feas_fn: IndependentFeasibilityModel,
        mol_value_fn,  # noqa[F811]
    ) -> None:
        r"""
        Test sample retro-fallback on similar example to the one in test_rfb_independent.py.

        In the first step, the algorithm should expand the root node,
        and the tree should have the following structure:

               ------------ COCS -----------------
              /      /      |     \    \     \     \
            C+OCS  CO+CS  COC+S  OOCS  SOCS  COCO  COCC
        """
        output_graph = self.run_alg_for_n_iterations(
            retrosynthesis_task2,
            1,
            feasibility_model=rxn_feas_fn,
            value_function=mol_value_fn,
        )
        assert output_graph.reaction_smiles_counter() == Counter(BY_HAND_STEP1_RXNS)  # type: ignore  # unsure about rxn_counter
        assert len(output_graph) == 18
        assert not output_graph.root_node.has_solution
        assert get_first_solution_time(output_graph) == math.inf
        assert np.allclose(output_graph.root_node.data["rfb_success_outcomes"], 0.0)

        # Check some priority values
        for smiles_str, expected_avg_success_tree in [
            # this magic number is 0.95 estimated success from CS
            # and 0.5 that CO.CS>>COCS succeeds
            ("CO", 0.95 * 0.5),
            ("CS", 0.95 * 0.5),
            ("C", 0.1 * 0.1 * 0.5),
            ("COCO", 0.8 * 0.5),
            ("COCC", 0.1 * 0.5),
        ]:
            assert math.isclose(
                self._mean_max_success_tree(output_graph, smiles_str),
                expected_avg_success_tree,
                abs_tol=0.01,
            )

        # Check final success probability
        assert math.isclose(self._estimated_success(output_graph, rxn_feas_fn), 0.0)

    def test_by_hand_step2(
        self,
        retrosynthesis_task2: RetrosynthesisTask,  # noqa[F405]
        rxn_feas_fn: IndependentFeasibilityModel,
        mol_value_fn,  # noqa[F811]
    ) -> None:
        r"""
        Continuation of test above.

        Should expand "CS", yielding 2 routes (from purchasable CC and CO)

               ------------ COCS -------------------------
              /           /      |         \    \     \     \
            C[1]+OCS  CO[2]+CS  COC+S[3]  OOCS  SOCS  COCO  COCC
                            |
                   -----------------
                  /         / \  \  \
                C[1]+S[3]  CC OS SS CO[2]

        Note: [X] indicate pairs of nodes which are the same. This is a graph, not a tree!
        """
        output_graph = self.run_alg_for_n_iterations(
            retrosynthesis_task2,
            2,
            feasibility_model=rxn_feas_fn,
            value_function=mol_value_fn,
        )
        assert (
            len(output_graph) == 26
        )  # 3 less than in test_rfb_independent.py because it runs on a graph
        assert output_graph.root_node.has_solution
        assert get_first_solution_time(output_graph) == 2

        # Average success outcome should be the same as original algorithm
        expected_synthesis_probability = 0.375
        assert math.isclose(
            output_graph.root_node.data["rfb_success_outcomes"].mean(),
            expected_synthesis_probability,
            abs_tol=0.01,
        )

        # Check some priority values, averaging only over the samples
        # in which the root is not already solved.
        success_mask = np.isclose(
            output_graph.root_node.data["rfb_success_outcomes"], 1.0
        )
        for smiles_str, expected_avg_success_tree in [
            # success of COCO is independent of existing routes
            ("COCO", 0.5 * 0.8),
            ("C", 0.5 * 0.1 * 0.1),
        ]:
            assert math.isclose(
                self._mean_max_success_tree(
                    output_graph, smiles_str, mask=success_mask
                ),
                expected_avg_success_tree,
                abs_tol=0.01,
            )

        # Check final success probability
        assert math.isclose(
            self._estimated_success(output_graph, rxn_feas_fn),
            expected_synthesis_probability,
            abs_tol=0.01,
        )

    def test_by_hand_step3(
        self,
        retrosynthesis_task2: RetrosynthesisTask,  # noqa[F405]
        rxn_feas_fn: IndependentFeasibilityModel,
        mol_value_fn,  # noqa[F811]
    ) -> None:
        r"""
        Continuation of test above.

        Should expand "COCO", yielding 1 additional route.

               ------------ COCS ------------------------------
              /           /      |            \    \     \     \
            C[1]+OCS  CO[2]+CS  COC[4]+S[3]  OOCS  SOCS  COCO  COCC[5]
                            |                             |
                   -----------------                      |
                  /         / \  \  \                     |
                C[1]+S[3]  CC OS SS CO[2]                 |
                                                          |
                            -----------------------------------
                            /       /     /          \     \   \
                        C[1]+OCO  CO[2]  COC[4]+O  OOCO  SOCO  COCC[5]

        Note: [X] indicate pairs of nodes which are the same. This is a graph, not a tree!
        """
        output_graph = self.run_alg_for_n_iterations(
            retrosynthesis_task2,
            3,
            feasibility_model=rxn_feas_fn,
            value_function=mol_value_fn,
        )
        assert len(output_graph) == 36
        assert output_graph.root_node.has_solution
        assert get_first_solution_time(output_graph) == 2

        # Average success outcome should be the same as original algorithm
        expected_synthesis_probability = 1 - (1 - 0.375) * (1 - 0.25)
        assert math.isclose(
            output_graph.root_node.data["rfb_success_outcomes"].mean(),
            expected_synthesis_probability,
            abs_tol=0.01,
        )

        # Check some priority values, averaging only over the samples
        # in which the root is not already solved.
        success_mask = np.isclose(
            output_graph.root_node.data["rfb_success_outcomes"], 1.0
        )
        for smiles_str, expected_avg_success_tree in [
            # NOTE: this number ignores contribution from below COCO
            ("COCC", 0.5 * 0.1),
            ("SCOS", 0.5 * 0.7),
        ]:
            assert math.isclose(
                self._mean_max_success_tree(
                    output_graph, smiles_str, mask=success_mask
                ),
                expected_avg_success_tree,
                abs_tol=0.01,
            )

        # Check final success probability
        assert math.isclose(
            self._estimated_success(output_graph, rxn_feas_fn),
            expected_synthesis_probability,
            abs_tol=0.01,
        )
