
from sentence_transformers import util
import torch
import torch.nn.functional as F

def mmr(query_emb, embeddings, k=5, lambda_param=0.5):
    """
    Perform Maximal Marginal Relevance (MMR) to select top-k diverse documents.

    Documents are picked greedily: at every step the remaining document with
    the highest score

        lambda_param * cos(query, doc) - (1 - lambda_param) * max_sim(doc, selected)

    is chosen, where ``max_sim`` is the largest cosine similarity between the
    document and any document picked so far (0.0 while nothing is picked).
    Ties are resolved in favour of the lowest index among the remaining
    documents.

    Args:
        query_emb (torch.Tensor): Embedding of the query. A 1-D tensor is
            promoted to shape (1, dim).
        embeddings (list of torch.Tensor): List of document embeddings. 1-D
            tensors are promoted to shape (1, dim). Each entry must describe
            a single vector, because every score is reduced to a Python
            float with ``.item()``.
        k (int): Number of top documents to select.
        lambda_param (float): Trade-off parameter between relevance and diversity.

    Returns:
        list: Indices of the selected documents, in selection order. Holds
        ``min(k, len(embeddings))`` entries (none when ``k <= 0``).
    """
    # Promote 1-D vectors to 2-D so cosine_similarity (default dim=1)
    # compares along the feature axis.
    if query_emb.ndimension() == 1:
        query_emb = query_emb.unsqueeze(0)
    embeddings = [e.unsqueeze(0) if e.ndimension() == 1 else e for e in embeddings]

    # Relevance of every document to the query; constant across iterations.
    rel_scores = [F.cosine_similarity(query_emb, e) for e in embeddings]

    selected = []
    remaining = list(range(len(embeddings)))

    # Running maximum similarity of each remaining candidate to the selected
    # set. A missing key means "nothing selected yet" (penalty 0.0). Keeping
    # this incrementally avoids recomputing similarities against every
    # selected document on every iteration.
    max_sim_to_selected = {}

    while len(selected) < k and remaining:
        best_idx = None
        best_score = None
        for i in remaining:
            diversity_penalty = max_sim_to_selected.get(i, 0.0)
            mmr_score = lambda_param * rel_scores[i] - (1 - lambda_param) * diversity_penalty
            score = mmr_score.item()
            # Strict '>' keeps the first of several equally scored candidates.
            if best_score is None or score > best_score:
                best_idx, best_score = i, score

        selected.append(best_idx)
        remaining.remove(best_idx)

        if len(selected) >= k:
            # No further selection round, so the penalties are not needed.
            break

        # Fold the newly selected document into each candidate's penalty.
        newest = embeddings[best_idx]
        for i in remaining:
            sim = F.cosine_similarity(embeddings[i], newest)
            if i not in max_sim_to_selected or sim > max_sim_to_selected[i]:
                max_sim_to_selected[i] = sim

    return selected

class MMR:
	"""Thin wrapper that runs the module-level ``mmr`` selection on raw embeddings."""

	def __init__(self):
		pass

	def attribute_PC(self, query_emb, embeddings):
		"""
		Select the top-1 and top-2 documents for a query using MMR.

		Args:
			query_emb: Query embedding; anything ``torch.as_tensor`` accepts
				(tensor, NumPy array, nested sequence of numbers).
			embeddings: Iterable of document embeddings, each convertible
				with ``torch.as_tensor``.

		Returns:
			tuple: ``([], top1, top2)``. ``top1`` and ``top2`` are lists of
			document indices chosen by ``mmr`` with ``k=1`` and ``k=2``
			(``lambda_param=0.5``). The first element is always an empty
			list. NOTE(review): it looks like a placeholder; confirm with
			the callers what it is meant to carry.
		"""
		# as_tensor avoids the copy (and the copy-construct UserWarning) that
		# torch.tensor() produces for inputs that already are tensors;
		# detach() keeps the result out of autograd, as torch.tensor() did.
		query_emb = torch.as_tensor(query_emb).detach()
		embeddings = [torch.as_tensor(e).detach() for e in embeddings]

		# mmr() selects greedily, so its k=1 result is exactly the first
		# element of its k=2 result; one call serves both.
		mmr_results2 = mmr(query_emb, embeddings, k=2, lambda_param=0.5)
		mmr_results1 = mmr_results2[:1]
		return [], mmr_results1, mmr_results2
