from __future__ import annotations

import abc
import numpy as np

from syntheseus.search.chem import Molecule, BackwardReaction


class FeasibilityModel(abc.ABC):
    """
    Abstract model of reaction feasibility.

    Concrete implementations report marginal feasibility probabilities for
    reactions and draw sampled feasibility outcomes for them, optionally
    conditioned on outcomes already observed for other reactions.
    """

    def __init__(self, num_samples: int, **kwargs):
        # Pass leftover keyword arguments on to the next class in the MRO
        # (cooperative multiple inheritance), then record the sample count.
        super().__init__(**kwargs)
        self.num_samples = num_samples

    @abc.abstractmethod
    def posterior_sample(
        self,
        reactions: set[BackwardReaction],
        observed_samples: dict[BackwardReaction, np.ndarray],
    ) -> dict[BackwardReaction, np.ndarray]:
        """
        Draw feasibility outcomes for ``reactions`` conditioned on observations.

        Args:
            reactions: reactions whose outcomes should be sampled.
            observed_samples: already-observed outcome arrays for *other*
                reactions.

        Returns:
            A mapping from reaction to its array of sampled outcomes
            (presumably ``num_samples`` long; see concrete implementations).
        """

    def prior_sample(
        self, reactions: set[BackwardReaction]
    ) -> dict[BackwardReaction, np.ndarray]:
        """Draw unconditioned feasibility outcomes, i.e. a posterior with no observations."""
        no_observations: dict[BackwardReaction, np.ndarray] = {}
        return self.posterior_sample(reactions, no_observations)

    @abc.abstractmethod
    def marginal_probability(self, reactions: list[BackwardReaction]) -> list[float]:
        """Return the marginal feasibility probability of each reaction, in the order given."""

    def reset(self):
        """Clear any caches or stored state. The base class keeps none, so this does nothing."""


class IndependentFeasibilityModel(FeasibilityModel):
    """
    Feasibility model that treats every reaction as independent of all others.

    Because there are no dependencies between reactions, sampling is entirely
    determined by the marginal probabilities returned by
    ``marginal_probability``: each outcome is an independent Bernoulli draw.
    """

    def posterior_sample(
        self,
        reactions: set[BackwardReaction],
        observed_samples: dict[BackwardReaction, np.ndarray],
    ) -> dict[BackwardReaction, np.ndarray]:
        """
        Sample feasibility outcomes for a set of reactions.

        Under the independence assumption, outcomes observed for other
        reactions carry no information, so ``observed_samples`` is
        intentionally ignored.

        Args:
            reactions: reactions whose outcomes should be sampled.
            observed_samples: observed outcomes for other reactions (unused).

        Returns:
            A dict mapping each reaction to an integer array of
            ``self.num_samples`` 0/1 outcomes, drawn with that reaction's
            marginal probability.

        Raises:
            ValueError: if ``marginal_probability`` returns a different number
                of probabilities than reactions requested. Without this check,
                ``zip`` would silently drop reactions or ignore surplus values.
                numpy also raises ``ValueError`` for probabilities outside [0, 1].
        """
        rxn_list = list(reactions)
        # Materialise so the length check also works if a subclass returns an
        # iterator or array rather than a list.
        probabilities = list(self.marginal_probability(rxn_list))
        if len(probabilities) != len(rxn_list):
            raise ValueError(
                f"marginal_probability returned {len(probabilities)} values "
                f"for {len(rxn_list)} reactions"
            )
        return {
            rxn: np.random.binomial(1, p, self.num_samples)
            for rxn, p in zip(rxn_list, probabilities)
        }


class PurchasabilityModel(abc.ABC):
    """
    Abstract model of molecule purchasability.

    Concrete implementations report marginal purchasability probabilities for
    molecules and draw sampled purchasability outcomes for them, optionally
    conditioned on outcomes already observed for other molecules.
    """

    def __init__(self, num_samples: int, **kwargs):
        # Pass leftover keyword arguments on to the next class in the MRO
        # (cooperative multiple inheritance), then record the sample count.
        super().__init__(**kwargs)
        self.num_samples = num_samples

    @abc.abstractmethod
    def posterior_sample(
        self, molecules: set[Molecule], observed_samples: dict[Molecule, np.ndarray]
    ) -> dict[Molecule, np.ndarray]:
        """
        Draw purchasability outcomes for ``molecules`` conditioned on observations.

        Args:
            molecules: molecules whose outcomes should be sampled.
            observed_samples: already-observed outcome arrays for *other*
                molecules.

        Returns:
            A mapping from molecule to its array of sampled outcomes
            (presumably ``num_samples`` long; see concrete implementations).
        """

    def prior_sample(self, molecules: set[Molecule]) -> dict[Molecule, np.ndarray]:
        """Draw unconditioned purchasability outcomes, i.e. a posterior with no observations."""
        no_observations: dict[Molecule, np.ndarray] = {}
        return self.posterior_sample(molecules, no_observations)

    @abc.abstractmethod
    def marginal_probability(self, molecules: list[Molecule]) -> list[float]:
        """Return the marginal purchasability probability of each molecule, in the order given."""

    def reset(self):
        """Clear any caches or stored state. The base class keeps none, so this does nothing."""


class IndependentPurchasabilityModel(PurchasabilityModel):
    """
    Purchasability model that treats every molecule as independent of all others.

    Because there are no dependencies between molecules, sampling is entirely
    determined by the marginal probabilities returned by
    ``marginal_probability``: each outcome is an independent Bernoulli draw.
    """

    def posterior_sample(
        self,
        molecules: set[Molecule],
        observed_samples: dict[Molecule, np.ndarray],
    ) -> dict[Molecule, np.ndarray]:
        """
        Sample purchasability outcomes for a set of molecules.

        Under the independence assumption, outcomes observed for other
        molecules carry no information, so ``observed_samples`` is
        intentionally ignored.

        Args:
            molecules: molecules whose outcomes should be sampled.
            observed_samples: observed outcomes for other molecules (unused).

        Returns:
            A dict mapping each molecule to an integer array of
            ``self.num_samples`` 0/1 outcomes, drawn with that molecule's
            marginal probability.

        Raises:
            ValueError: if ``marginal_probability`` returns a different number
                of probabilities than molecules requested. Without this check,
                ``zip`` would silently drop molecules or ignore surplus values.
                numpy also raises ``ValueError`` for probabilities outside [0, 1].
        """
        mol_list = list(molecules)
        # Materialise so the length check also works if a subclass returns an
        # iterator or array rather than a list.
        probabilities = list(self.marginal_probability(mol_list))
        if len(probabilities) != len(mol_list):
            raise ValueError(
                f"marginal_probability returned {len(probabilities)} values "
                f"for {len(mol_list)} molecules"
            )
        return {
            mol: np.random.binomial(1, p, self.num_samples)
            for mol, p in zip(mol_list, probabilities)
        }
