from sentence_transformers import SentenceTransformer, util
import numpy as np

class Aligner:
    """
    Embedding-based retrieval for recall only.

    Encodes atoms and chunks with a SentenceTransformer model, computes the
    cosine similarity between every atom and every chunk, and returns the
    indices of the top-k most similar chunks for each atom.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cpu",
        batch_size: int = 32,
        max_seq_length: int = 512,
    ):
        """
        Load the embedding model.

        Args:
            model_name: Name or path handed to ``SentenceTransformer``.
            device: Device string handed to ``SentenceTransformer``.
            batch_size: Batch size used by :meth:`embed` when encoding.
            max_seq_length: Value assigned to the model's ``max_seq_length``.
        """
        self.model = SentenceTransformer(model_name, device=device)
        self.model.max_seq_length = max_seq_length
        self.batch_size = batch_size

    def embed(self, texts) -> np.ndarray:
        """
        Encode ``texts`` into normalized embeddings.

        Args:
            texts: Sized collection of strings (list, tuple, numpy array, ...).
                ``None`` is treated like an empty collection.

        Returns:
            Array with one row per text. For empty input an empty float32
            array of shape ``(0, dim)`` is returned without calling the
            encoder, so callers can always rely on ``.shape``.
        """
        # An explicit length check (rather than truthiness) also works for
        # numpy arrays, whose truth value is ambiguous.
        if texts is None or len(texts) == 0:
            dim = self.model.get_sentence_embedding_dimension()
            return np.zeros((0, dim), dtype="float32")
        return self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            normalize_embeddings=True,
        )

    def topk_indices(self, atoms, chunks, k: int):
        """
        Return, for each atom, the indices of its ``k`` most similar chunks.

        Args:
            atoms: Sized collection of query strings.
            chunks: Sized collection of candidate strings.
            k: Maximum number of chunk indices per atom. It is capped at
                ``len(chunks)``; values ``<= 0`` yield no indices.

        Returns:
            A list with one entry per atom. Each entry is a list of ``int``
            positions into ``chunks``, ordered from most to least similar.
            Every entry is empty when there are no chunks or ``k <= 0``.
        """
        n_atoms = 0 if atoms is None else len(atoms)
        n_chunks = 0 if chunks is None else len(chunks)

        # Nothing can be retrieved: answer without running the encoder.
        # This also keeps a negative k from reaching torch.topk, which
        # would raise a RuntimeError.
        if n_atoms == 0 or n_chunks == 0 or k <= 0:
            return [[] for _ in range(n_atoms)]

        atom_emb = self.embed(atoms)    # shape [n_atoms, dim]
        chunk_emb = self.embed(chunks)  # shape [n_chunks, dim]
        sim = util.cos_sim(atom_emb, chunk_emb)  # shape [n_atoms, n_chunks]

        # One batched top-k over the chunk axis instead of a Python loop
        # over rows; tolist() converts the indices to plain Python ints.
        top = min(k, sim.size(1))
        return sim.topk(top, dim=1).indices.tolist()
