# Reward metric calculation
import numpy as np

def hit_at_k_rank_list(rank_list, k):
    """
    Compute Hits@k
    Args:
        rank_list: List of 1-based ranks for relevant documents, e.g., [1, 3] or [1, -1].
            -1 (or any rank below 1) means the document was not retrieved.
        k: Threshold k
    Returns:
        float: 1.0 if any relevant document is ranked within the top k, else 0.0
    """
    if not rank_list:
        return 0.0

    # One relevant document inside the top k is enough to count as a hit.
    # Ranks below 1 are invalid / "not found" and never count.
    return 1.0 if any(1 <= rank <= k for rank in rank_list) else 0.0

def map_at_k_rank_list(rank_list, k):
    """
    Compute MAP@k (average precision at k for a single query)
    Args:
        rank_list: List of 1-based ranks for relevant documents, e.g., [1, 3] or [1, -1].
            -1 (or any rank below 1) means the document was not retrieved.
        k: Threshold k
    Returns:
        float: Sum of the precision at each relevant document ranked within k,
            divided by the total number of relevant documents (len(rank_list)),
            so relevant documents that were missed lower the score.
    """
    if not rank_list:
        return 0.0

    # Keep only retrieved documents inside the top k. Ranks below 1 are treated
    # as "not found"; they would otherwise divide by zero or give negative precision.
    valid_ranks = sorted(rank for rank in rank_list if 1 <= rank <= k)

    if not valid_ranks:
        return 0.0

    # The n-th relevant document found (1-based) sits at position `rank`,
    # so the precision at that cut-off is n / rank.
    ap_sum = sum(found / rank for found, rank in enumerate(valid_ranks, start=1))

    # Normalise by every relevant document, including those not retrieved.
    return ap_sum / len(rank_list)

def mrr_at_k_rank_list(rank_list, k):
    """
    Compute MRR@k (reciprocal rank of the best-ranked relevant document)
    Args:
        rank_list: List of 1-based ranks for relevant documents, e.g., [1, 3] or [1, -1].
            -1 (or any rank below 1) means the document was not retrieved.
        k: Threshold k
    Returns:
        float: 1 / (smallest rank within k), or 0.0 if no relevant document is within k
    """
    if not rank_list:
        return 0.0

    # Smallest valid rank inside the top k. Ranks below 1 are "not found";
    # they would otherwise divide by zero or yield a negative score.
    first_rank = min((rank for rank in rank_list if 1 <= rank <= k), default=None)

    if first_rank is None:
        return 0.0
    return 1.0 / first_rank

def ndcg_at_k_rank_list(rank_list, k):
    """
    Compute NDCG@k with binary relevance
    Args:
        rank_list: List of 1-based ranks for relevant documents, e.g., [1, 3] or [1, -1].
            -1 (or any rank below 1) means the document was not retrieved.
        k: Threshold k
    Returns:
        float: DCG@k / IDCG@k, where the ideal ranking places every relevant
            document in rank_list (retrieved or not) at the top positions.
    """
    if not rank_list:
        return 0.0

    # Retrieved documents inside the top k. Ranks below 1 are "not found";
    # they would otherwise produce inf/nan from log2.
    valid_ranks = [rank for rank in rank_list if 1 <= rank <= k]

    if not valid_ranks:
        return 0.0

    # Binary gain: each retrieved relevant document adds 1 / log2(rank + 1).
    dcg = sum(1.0 / np.log2(rank + 1) for rank in valid_ranks)

    # Ideal DCG: all relevant documents (every entry of rank_list, including
    # those that were missed) occupy positions 1..min(R, k). Counting only
    # the retrieved ones would let a system that misses documents score 1.0.
    total_relevant = len(rank_list)
    idcg = sum(1.0 / np.log2(pos + 1) for pos in range(1, min(total_relevant, k) + 1))

    return float(dcg / idcg) if idcg > 0 else 0.0

# Single-document variants: compute each metric directly from one rank (-1 = not found)
def hit_at_k_rank(rank, k):
    """
    Compute Hits@k for a single relevant document.
    Args:
        rank: 1-based rank of the relevant document; -1 (or any value below 1)
            means it was not retrieved.
        k: Threshold k
    Returns:
        int: 1 if the document is ranked within the top k, else 0
    """
    # Ranks below 1 are invalid / "not found" and never count as a hit.
    return 1 if 1 <= rank <= k else 0

def mrr_at_k_rank(rank, k):
    """
    Compute MRR@k (reciprocal rank) for a single relevant document.
    Args:
        rank: 1-based rank of the relevant document; -1 (or any value below 1)
            means it was not retrieved.
        k: Threshold k
    Returns:
        float: 1 / rank if the document is ranked within the top k, else 0.0
    """
    # Ranks below 1 are "not found"; they would otherwise divide by zero or
    # produce a negative score.
    if not 1 <= rank <= k:
        return 0.0
    return 1.0 / rank

def ndcg_at_k_rank(rank, k):
    """
    Compute NDCG@k for a single relevant document with binary relevance.
    Args:
        rank: 1-based rank of the relevant document; -1 (or any value below 1)
            means it was not retrieved.
        k: Threshold k
    Returns:
        float: 1 / log2(rank + 1) if the document is ranked within the top k, else 0.0
    """
    # Ranks below 1 are "not found"; they would otherwise give inf/nan from log2.
    if not 1 <= rank <= k:
        return 0.0
    # With one relevant document the ideal ranking puts it at position 1, so
    # IDCG = 1 / log2(2) = 1 and NDCG reduces to the DCG term alone.
    return float(1.0 / np.log2(rank + 1))
