from collections.abc import Sequence

import numpy as np
from gensim.models import Word2Vec
from rdkit.Chem import AllChem, Mol
from skfp.bases import BaseFingerprintTransformer
from skfp.utils import ensure_mols


class Mol2VecFingerprint(BaseFingerprintTransformer):
    """Mol2Vec molecular embeddings from a pretrained gensim Word2Vec model.

    Each molecule is turned into a "sentence" of Morgan substructure
    identifiers: for every atom, its radius-0 identifier followed by its
    radius-1 identifier, when RDKit reports one. The molecule embedding is
    the sum of the Word2Vec vectors of those identifiers. Identifiers missing
    from the model vocabulary are mapped to the model's ``"UNK"`` vector.

    Parameters
    ----------
    model_path : str, default="outputs/mol2vec.model"
        Path to a model readable by ``gensim.models.Word2Vec.load``. It is
        loaded eagerly in ``__init__`` and must contain an ``"UNK"`` key.
    n_jobs, batch_size, verbose
        Forwarded unchanged to ``BaseFingerprintTransformer``.
    """

    # Morgan radius used to build sentences. NOTE(review): presumably the
    # loaded model was trained on radius-1 sentences -- confirm before
    # changing, since identifiers at other radii would mostly map to "UNK".
    _MORGAN_RADIUS = 1

    def __init__(
        self,
        model_path: str = "outputs/mol2vec.model",
        n_jobs: int | None = None,
        batch_size: int | None = None,
        verbose: int | dict = 0,
    ):
        self.model_path = model_path
        self.model = Word2Vec.load(model_path)
        # Fallback vector for identifiers absent from the vocabulary; this
        # lookup fails at construction time if the model has no "UNK" key.
        self._unk_vec = self.model.wv.get_vector("UNK")

        super().__init__(
            n_features_out=self.model.vector_size,
            n_jobs=n_jobs,
            batch_size=batch_size,
            verbose=verbose,
        )

    def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray:
        """Compute Mol2Vec embeddings for a batch of molecules.

        Parameters
        ----------
        X : Sequence[str | Mol]
            Molecules, passed through ``ensure_mols`` before embedding.

        Returns
        -------
        np.ndarray
            Array of shape ``(len(X), vector_size)``. An empty batch yields a
            ``(0, vector_size)`` array rather than failing in ``np.vstack``.
        """
        from rdkit import rdBase

        # turn off unnecessary RDKit warnings; note this is process-wide and
        # is not re-enabled afterwards
        rdBase.DisableLog("rdApp.*")

        X = ensure_mols(X)
        embeddings = [self._get_mol_embedding(mol) for mol in X]
        if not embeddings:
            return np.empty((0, self.model.vector_size), dtype=self._unk_vec.dtype)
        return np.vstack(embeddings)

    @classmethod
    def _get_mol_sentence(cls, mol: Mol) -> list[str]:
        """Build the mol2vec "alternating sentence" of Morgan identifiers.

        Atoms are visited in index order; for each atom the identifiers for
        radius 0..``_MORGAN_RADIUS`` are emitted in increasing radius. A
        (atom, radius) pair that RDKit does not report in ``bitInfo`` is
        skipped rather than padded, so no placeholder tokens are produced.
        """
        radius = cls._MORGAN_RADIUS
        # info: subgraph identifier -> ((atom_idx, radius), ...)
        info: dict = {}
        AllChem.GetMorganFingerprint(mol, radius, bitInfo=info)

        # Plain Python ints: Morgan identifiers are unsigned 32-bit values
        # and could overflow a 32-bit C long in a NumPy int array.
        per_atom: list[dict[int, int]] = [{} for _ in range(mol.GetNumAtoms())]
        for identifier, occurrences in info.items():
            for atom_idx, radius_at in occurrences:
                per_atom[atom_idx][radius_at] = identifier

        return [
            str(identifiers[r])
            for identifiers in per_atom
            for r in range(radius + 1)
            if r in identifiers
        ]

    def _get_mol_embedding(self, mol: Mol) -> np.ndarray:
        """Return the summed Word2Vec vector of a molecule's sentence.

        Unknown identifiers contribute ``self._unk_vec``. A molecule with an
        empty sentence (e.g. zero atoms) gets an all-zero vector of length
        ``vector_size``, so the output shape is always consistent.
        """
        sentence = self._get_mol_sentence(mol)
        if not sentence:
            return np.zeros(self.model.vector_size, dtype=self._unk_vec.dtype)

        wv = self.model.wv
        key_to_index = wv.key_to_index
        token_embeddings = [
            wv.get_vector(word) if word in key_to_index else self._unk_vec
            for word in sentence
        ]
        return np.sum(token_embeddings, axis=0)
