import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class RetrievalEngine:
    """Retrieve pose clips from a pose library by audio-embedding similarity.

    The library object must provide ``get_all_embeddings()``,
    ``get_pose_by_index(idx)`` and a ``window_size`` attribute.
    """

    def __init__(self, pose_library):
        """Cache the library's embedding matrix and window size.

        pose_library: object exposing the library interface described on
            the class.
        """
        self.lib = pose_library
        self.lib_embeds = self.lib.get_all_embeddings()  # expected (N, 512)
        self.window_size = self.lib.window_size

    def retrieve_sequence(self, input_embeddings):
        """Retrieve a continuous pose sequence for an audio embedding sequence.

        The input is cut into non-overlapping windows of ``window_size``
        frames. Each window is mean-pooled, matched against the library by
        cosine similarity, and the best-matching clip supplies the output
        frames for that window.

        input_embeddings: sequence of per-frame audio embeddings, expected
            shape [T, 512].

        Returns a list with one pose frame per input frame (length T), in
        time order. Each entry is one row of a library clip (presumably
        shape (6,) -- confirm against the pose library). An empty list is
        returned for empty input and for input shorter than half a window,
        which is considered too little context to retrieve on.
        """
        total_frames = len(input_embeddings)
        if total_frames < 1:
            return []

        continuous_sequence = []
        for ctx_start, ctx_end, emit_count in self._plan_windows(total_frames):
            query = np.mean(input_embeddings[ctx_start:ctx_end], axis=0)
            sims = cosine_similarity([query], self.lib_embeds)[0]  # (N,)
            best = np.argmax(sims)

            clip = self.lib.get_pose_by_index(best)  # expected (window_size, 6)
            continuous_sequence.extend(
                self._select_frames(clip, ctx_end - ctx_start, emit_count)
            )

        return continuous_sequence

    def _plan_windows(self, total_frames):
        """Return ``(ctx_start, ctx_end, emit_count)`` for every window.

        ``[ctx_start, ctx_end)`` is the span of input frames pooled into the
        query. ``emit_count`` is how many output frames the window
        contributes, i.e. the input frames not covered by earlier windows.
        """
        size = self.window_size
        stride = size
        num_windows = (total_frames + stride - 1) // stride

        plan = []
        for i in range(num_windows):
            first_new = i * stride
            ctx_end = min(first_new + size, total_frames)
            # Only the frames from first_new onward are new; count them
            # before the context start is moved.
            emit_count = ctx_end - first_new

            ctx_start = first_new
            if emit_count < size and first_new > 0:
                # Short tail window: widen the query context backwards to a
                # full window. The overlap with the previous window is used
                # for matching only and must not be emitted a second time.
                ctx_start = max(0, ctx_end - size)

            if ctx_end - ctx_start < size // 2:
                # Too little context to retrieve on (only reachable when the
                # whole input is shorter than half a window).
                continue

            plan.append((ctx_start, ctx_end, emit_count))
        return plan

    @staticmethod
    def _select_frames(poses, context_len, emit_count):
        """Pick ``emit_count`` frames from a retrieved clip."""
        if context_len > emit_count:
            # Context was widened backwards, so the new frames line up with
            # the end of the clip; drop the overlapping head.
            return poses[-emit_count:]
        if len(poses) > emit_count:
            # Clip is longer than the window it fills: subsample it evenly.
            indices = np.linspace(0, len(poses) - 1, emit_count, dtype=int)
            return [poses[idx] for idx in indices]
        return poses[:emit_count]
