import math
from pathlib import Path

import pandas as pd
from syntheseus.search.chem import Molecule
from syntheseus.search.mol_inventory import BaseMolInventory

from retro_fallback.feasibility_model import IndependentPurchasabilityModel

# Default inventory CSV: ``eMolecules/emolecules_inventory.csv`` located in the
# same directory as this module.
INVENTORY_CSV = Path(__file__).parent.joinpath("eMolecules", "emolecules_inventory.csv")


class eMoleculesInventory(BaseMolInventory):
    """Molecule inventory backed by an eMolecules CSV file.

    A molecule is purchasable if its SMILES string appears in the CSV with a
    tier ``<= max_tier``. Lookup is an exact string match on ``Molecule.smiles``.
    NOTE(review): this presumably requires the CSV SMILES to use the same
    canonicalization as ``Molecule.smiles`` — confirm.
    """

    def __init__(
        self,
        max_tier: int,
        eMolecules_file: "str | Path" = INVENTORY_CSV,
        **kwargs,
    ):
        """
        Args:
            max_tier: Largest tier (inclusive) still counted as purchasable.
            eMolecules_file: Path to a CSV file containing at least the columns
                ``smiles`` and ``tier``.
            **kwargs: Forwarded unchanged to ``BaseMolInventory.__init__``.
        """
        super().__init__(**kwargs)
        self.max_tier = max_tier

        # Only the two columns used below are loaded, to save memory on large files.
        df = pd.read_csv(eMolecules_file, usecols=["smiles", "tier"])

        # Build the SMILES -> tier map. If a SMILES occurs on several rows, keep its
        # lowest tier, so the result does not depend on row order and a molecule
        # counts as purchasable if any of its rows qualifies.
        smiles_to_tier: dict[str, int] = {}
        for smiles, raw_tier in zip(df["smiles"].to_list(), df["tier"].to_list()):
            tier = int(raw_tier)
            if tier < smiles_to_tier.get(smiles, math.inf):
                smiles_to_tier[smiles] = tier
        self._smiles_to_tier = smiles_to_tier

    def is_purchasable(self, mol: Molecule) -> bool:
        """Return whether ``mol`` is in the inventory with tier ``<= max_tier``.

        SMILES that are not in the inventory default to an infinite tier, so they
        are not purchasable for any finite ``max_tier``.
        """
        return self._smiles_to_tier.get(mol.smiles, math.inf) <= self.max_tier

    def fill_metadata(self, mol: Molecule) -> None:
        """Fill ``mol.metadata`` and record the eMolecules tier when known.

        The base-class call fills ``is_purchasable``. If the molecule's SMILES is in
        the inventory, its tier is also stored under ``"emols_tier"``.
        Molecules that are not in the inventory get no ``"emols_tier"`` key.
        """
        super().fill_metadata(mol)  # will fill is purchasable
        tier = self._smiles_to_tier.get(mol.smiles, None)
        if tier is not None:
            mol.metadata["emols_tier"] = tier


class BinaryEMoleculesPurchasability(IndependentPurchasabilityModel):
    """Deterministic purchasability model: each molecule is purchasable with probability 1 or 0.

    The decision comes from ``mol.metadata["is_purchasable"]``. That flag is
    presumably filled by an inventory such as ``eMoleculesInventory``, which marks
    molecules at or below its ``max_tier`` as purchasable. This model applies no
    tier threshold itself.
    """

    def marginal_probability(self, molecules: list[Molecule]) -> list[float]:
        """Return the purchasability probability of each molecule, in input order.

        Each entry is ``1.0`` if the molecule's ``"is_purchasable"`` metadata is
        truthy and ``0.0`` otherwise. Raises ``KeyError`` for a molecule whose
        metadata has no ``"is_purchasable"`` entry.
        """
        probabilities: list[float] = []
        for molecule in molecules:
            purchasable = molecule.metadata["is_purchasable"]
            if purchasable:
                probabilities.append(1.0)
            else:
                probabilities.append(0.0)
        return probabilities
