from collections.abc import Sequence

import numpy as np
import torch
from rdkit.Chem import Mol
from skfp.bases import BaseFingerprintTransformer
from skfp.utils import ensure_smiles
from torch import nn
from transformers import AutoModelForMaskedLM, AutoTokenizer


class ChemBERTaFingerprint(BaseFingerprintTransformer):
    """
    ChemBERTa embedding fingerprint.

    Each molecule is converted to SMILES, tokenized with the checkpoint's
    tokenizer and passed through the pretrained transformer. The final-layer
    hidden state of the first token (the [CLS]/<s> token) is returned as a
    dense feature vector.

    Parameters
    ----------
    model_path : str, default="DeepChem/ChemBERTa-77M-MLM"
        Hugging Face Hub name or local path of a masked-language-model
        checkpoint, loaded with ``AutoTokenizer`` and
        ``AutoModelForMaskedLM``. The number of output features is taken
        from the model's ``config.hidden_size``.

    n_jobs : int or None, default=None
        Forwarded unchanged to ``BaseFingerprintTransformer``.

    batch_size : int or None, default=None
        Forwarded unchanged to ``BaseFingerprintTransformer``.

    verbose : int or dict, default=0
        Forwarded unchanged to ``BaseFingerprintTransformer``.

    Notes
    -----
    The tokenizer and model weights are loaded eagerly in the constructor,
    so creating an instance loads the checkpoint.
    """

    def __init__(
        self,
        model_path: str = "DeepChem/ChemBERTa-77M-MLM",
        n_jobs: int | None = None,
        batch_size: int | None = None,
        verbose: int | dict = 0,
    ):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.embedder = AutoModelForMaskedLM.from_pretrained(model_path)
        # Replace the masked-LM head with an identity so that the model's
        # ``logits`` output is the encoder's final hidden state rather than
        # vocabulary scores. Assigning an nn.Module attribute registers it in
        # the module tree, exactly like writing into ``_modules`` directly.
        # NOTE(review): assumes the checkpoint's MLM head is named ``lm_head``
        # (RoBERTa-style); other architectures would need a different name.
        self.embedder.lm_head = nn.Identity()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # eval() guarantees dropout is disabled for deterministic embeddings.
        self.embedder = self.embedder.to(self.device).eval()

        super().__init__(
            # Derive the width from the loaded model instead of hard-coding
            # 384, so non-default checkpoints report the correct size.
            n_features_out=self.embedder.config.hidden_size,
            n_jobs=n_jobs,
            batch_size=batch_size,
            verbose=verbose,
        )

    def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray:
        """
        Compute ChemBERTa embeddings for a batch of molecules.

        Parameters
        ----------
        X : sequence of str or Mol
            Molecules as SMILES strings or RDKit ``Mol`` objects; converted
            to SMILES with ``ensure_smiles``.

        Returns
        -------
        np.ndarray of shape (len(X), hidden_size)
            One embedding row per input molecule. An empty input yields an
            array of shape (0, hidden_size) instead of a 1-D empty array.
        """
        X = ensure_smiles(X)
        embeddings = []
        # inference_mode already disables autograd tracking, so a nested
        # no_grad() context would be redundant.
        with torch.inference_mode():
            for smiles in X:
                encoded_input = self.tokenizer(
                    smiles,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                ).to(self.device)

                # transform: [1, n_tokens, hidden] -> [n_tokens, hidden] -> [hidden]
                # we extract [CLS] token embeddings (0-th token)
                vec = self.embedder(**encoded_input).logits[0, 0, :]
                embeddings.append(vec.cpu().numpy())

        if not embeddings:
            # Keep the result 2-D so downstream code sees (n_samples, n_features).
            return np.empty(
                (0, self.embedder.config.hidden_size), dtype=np.float32
            )
        return np.stack(embeddings)
