from typing import (
    List,
    Tuple,
)

import numpy as np


class MemoryRerankingService:
    """Deduplicate retrieved memories by id and keep the ``top_k`` best-scored ones.

    Inputs are three parallel lists (contents, ids, scores): position ``i`` in
    each list describes the same retrieved memory. Higher scores rank higher.
    """

    def __init__(self, top_k: int) -> None:
        # Maximum number of memories ``rerank`` returns; values <= 0 yield
        # empty results.
        self.top_k = top_k

    @staticmethod
    def _rank_value(score: float) -> float:
        """Return *score* as a comparable rank value, mapping NaN to -inf."""
        return float("-inf") if np.isnan(score) else score

    def _deduplicate(
        self,
        ret_memory_contents: List[str],
        ret_memory_ids: List[str],
        ret_memory_scores: List[float]
    ) -> Tuple[List[str], List[str], List[float]]:
        """Collapse entries sharing an id, keeping the highest-scored one per id.

        Ids are returned in order of first appearance. On equal scores the
        earliest entry is kept; NaN scores are compared as -inf, so a later
        real score replaces an earlier NaN one.

        Returns:
            ``(contents, ids, scores)`` parallel lists with unique ids.

        Raises:
            ValueError: if the three input lists differ in length.
        """
        n = len(ret_memory_ids)
        if len(ret_memory_contents) != n or len(ret_memory_scores) != n:
            # zip() would silently truncate to the shortest list and drop data.
            raise ValueError(
                "ret_memory_contents, ret_memory_ids and ret_memory_scores must "
                f"have equal lengths, got {len(ret_memory_contents)}, {n} and "
                f"{len(ret_memory_scores)}"
            )
        # Dicts preserve insertion order, so each id keeps the position of its
        # first occurrence even when a later duplicate replaces its entry.
        best = {}
        for content, mem_id, score in zip(ret_memory_contents, ret_memory_ids, ret_memory_scores):
            current = best.get(mem_id)
            if current is None or self._rank_value(current[2]) < self._rank_value(score):
                best[mem_id] = (content, mem_id, score)
        entries = list(best.values())
        dedup_contents = [entry[0] for entry in entries]
        dedup_ids = [entry[1] for entry in entries]
        dedup_scores = [entry[2] for entry in entries]
        return dedup_contents, dedup_ids, dedup_scores

    def rerank(
        self,
        ret_memory_contents: List[str],
        ret_memory_ids: List[str],
        ret_memory_scores: List[float],
    ) -> Tuple[List[str], List[str], List[float]]:
        """Deduplicate by id, then return at most ``top_k`` entries by descending score.

        Entries with equal scores keep their (deduplicated) input order, and
        NaN scores are ranked after every other score.

        Returns:
            ``(contents, ids, scores)`` parallel lists, best-scored first.

        Raises:
            ValueError: if the three input lists differ in length.
        """
        contents, ids, scores = self._deduplicate(
            ret_memory_contents, ret_memory_ids, ret_memory_scores
        )
        # A negative top_k would otherwise produce a meaningless slice.
        k = max(self.top_k, 0)
        # Negating turns an ascending sort into a descending one while keeping
        # NaN (whose negation is NaN) at the end; "stable" makes tie order
        # deterministic instead of depending on the sort algorithm.
        order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")[:k]
        reranked_memory_contents = [contents[i] for i in order]
        reranked_memory_ids = [ids[i] for i in order]
        reranked_memory_scores = [scores[i] for i in order]
        return reranked_memory_contents, reranked_memory_ids, reranked_memory_scores
