import torch
import numpy as np
from retrieval import RetrievalEngine

class PoseRetriever:
    """Thin wrapper around ``RetrievalEngine`` that accepts torch or NumPy input.

    Embeddings are converted to NumPy before being handed to the engine, and an
    empty ``float32`` array is returned when the engine yields nothing.
    """

    # Shape of the placeholder returned when retrieval yields nothing.
    # NOTE(review): the meaning of the trailing (1, 6) axes is not visible here
    # (presumably one entity x 6 pose parameters) -- confirm against the engine.
    _EMPTY_POSE_SHAPE = (0, 1, 6)

    def __init__(self, lib):
        """Build a retriever backed by ``RetrievalEngine(lib)``.

        Args:
            lib: Library object forwarded unchanged to ``RetrievalEngine``.
        """
        self.lib = lib
        self.retriever = RetrievalEngine(self.lib)

    @staticmethod
    def _is_empty(segments):
        """Return True when *segments* is None or holds no elements.

        Plain truthiness (``not segments``) is unsafe here: for a NumPy array or
        torch tensor with more than one element it raises, and for a
        single-element array it tests the element's value, not emptiness.
        """
        if segments is None:
            return True
        if isinstance(segments, torch.Tensor):
            return segments.numel() == 0
        if isinstance(segments, np.ndarray):
            return segments.size == 0
        # Lists / tuples / other containers: keep the original truthiness test.
        return not segments

    def retrieve_pose(self, embeddings):
        """Retrieve pose segments for a single embedding sequence.

        Args:
            embeddings: ``torch.Tensor`` or array-like. A tensor is detached,
                moved to CPU and converted to a NumPy array first.
                NOTE(review): expected shape is defined by
                ``RetrievalEngine.retrieve_sequence`` -- confirm there.

        Returns:
            Whatever ``retrieve_sequence`` returns, unchanged, when it is
            non-empty; otherwise an empty ``float32`` array of shape
            ``(0, 1, 6)``.
        """
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        pose_segments = self.retriever.retrieve_sequence(embeddings)

        if self._is_empty(pose_segments):
            return np.zeros(self._EMPTY_POSE_SHAPE, dtype=np.float32)

        return pose_segments

    def batch_retrieve(self, batch_embeddings):
        """Run :meth:`retrieve_pose` on each item of a batch.

        Args:
            batch_embeddings: ``torch.Tensor`` or iterable of per-item
                embeddings. A tensor is converted to NumPy once up front so
                each item is not transferred separately.

        Returns:
            A list with one ``retrieve_pose`` result per batch item, in order.
        """
        if isinstance(batch_embeddings, torch.Tensor):
            batch_embeddings = batch_embeddings.detach().cpu().numpy()

        return [self.retrieve_pose(embeddings) for embeddings in batch_embeddings]
