# Reward metric calculation
import numpy as np

def get_rank(retrieved, target, k):
    """
    Find the 1-indexed rank of each target document within the top k of retrieved.

    Args:
        retrieved: Ranked sequence of retrieved documents (best first).
        target: Sequence of documents whose ranks are wanted.
        k: Cutoff; only the first ``k`` entries of ``retrieved`` are searched.
            Values larger than ``len(retrieved)`` are clamped, and non-positive
            values search nothing.

    Returns:
        If ``target`` has exactly one element, that element's rank as a plain
        int (not a list). Otherwise, a list of ranks aligned with ``target``.
        A document that is not in the top k gets rank -1. When a document occurs
        several times in ``retrieved``, the rank of its first occurrence is used.
    """
    # Clamp k to [0, len(retrieved)]: a negative k would otherwise slice from
    # the end (e.g. retrieved[:-1]) and search almost every document.
    k = max(0, min(k, len(retrieved)))
    retrieved_top_k = retrieved[:k]

    ranks = []
    for doc in target:
        # Equality scan (rather than a dict lookup) keeps support for
        # unhashable documents; the first matching position wins.
        rank = next(
            (pos for pos, item in enumerate(retrieved_top_k, 1) if doc == item),
            -1,
        )
        ranks.append(rank)

    # Single-target callers receive a bare int (see Returns).
    if len(ranks) == 1:
        return ranks[0]
    return ranks


# def get_rank(retrieved, target, k):
#     """
#     Check if documents in target are in the top k positions of retrieved, return rank if found, otherwise return -1.
#     """
#     # Ensure k does not exceed the length of retrieved
#     k = min(k, len(retrieved))
    
#     # Get the top k positions of retrieved
#     retrieved_top_k = retrieved[:k]
    
#     # Get the single document in target (assumes len(target) == 1)
#     doc = target[0]
    
#     # Check if the document is in the top k positions of retrieved
#     for i, item in enumerate(retrieved_top_k, 1):
#         if doc == item:
#             return i
#     return -1

# Calculate various metrics directly based on rank
def hit_at_k_rank(rank, k):
    """
    Return Hit@k (1 or 0) for a relevant document found at ``rank``.

    Args:
        rank: 1-indexed position of the relevant document, or -1 (as returned
            by ``get_rank``) when it was not retrieved.
        k: Cutoff.

    Returns:
        1 if ``1 <= rank <= k``, else 0. Any rank below 1 (the -1 sentinel or
        an invalid 0/negative value) counts as a miss.
    """
    # Treat every rank < 1 as "not found". Checking only for -1 let rank 0 or
    # -2 pass the ``rank <= k`` test and be scored as a hit.
    if rank < 1 or rank > k:
        return 0
    return 1

def mrr_at_k_rank(rank, k):
    """
    Return the reciprocal rank at cutoff k for a document found at ``rank``.

    Args:
        rank: 1-indexed position of the relevant document, or -1 (as returned
            by ``get_rank``) when it was not retrieved.
        k: Cutoff.

    Returns:
        ``1.0 / rank`` if ``1 <= rank <= k``, else 0.
    """
    # Any rank below 1 is a miss. Previously rank 0 raised ZeroDivisionError
    # and other negative ranks produced negative scores.
    if rank < 1 or rank > k:
        return 0
    return 1.0 / rank

def ndcg_at_k_rank(rank, k):
    """
    Return NDCG@k for a single relevant document found at ``rank``.

    With one relevant document (gain 1), the ideal DCG is 1 / log2(1 + 1) = 1.
    NDCG therefore equals DCG, which is 1 / log2(rank + 1).

    Args:
        rank: 1-indexed position of the relevant document, or -1 (as returned
            by ``get_rank``) when it was not retrieved.
        k: Cutoff.

    Returns:
        The NDCG score, or 0 for a miss (rank < 1 or rank > k).
    """
    # Rank 0 would divide by log2(1) = 0, and ranks <= -2 would take log2 of
    # a non-positive number. Treat every rank below 1 as "not found".
    if rank < 1 or rank > k:
        return 0
    return 1.0 / np.log2(rank + 1)





# def get_predicted_ranking(predicted_list, target_list, true_scores):
#     target_score_dict = {item: score for item, score in zip(target_list, true_scores)}
#     return [target_score_dict.get(item, 0) for item in predicted_list]
def hit_at_k(retrieved, target, k, rel_scores=None):
    """
    Compute Hit@k: 1.0 if any relevant document appears in the top k, else 0.0.

    Args:
        retrieved: Ranked sequence of retrieved documents (best first).
        target: The relevant document(s). A list, tuple, set or NumPy array is
            treated as a collection of relevant documents. Any other value,
            including a string, is treated as a single relevant document.
        k: Cutoff; only the first ``k`` retrieved documents are checked.
            Non-positive values check nothing.
        rel_scores: Unused; accepted for signature parity with ``dcg_at_k`` and
            ``ndcg_at_k``.

    Returns:
        1.0 on a hit, 0.0 otherwise.
    """
    # Wrap scalar targets. Otherwise ``item in "doc12"`` does substring
    # matching on a string id, and ``item in 42`` raises TypeError.
    if isinstance(target, (list, tuple, set, frozenset, np.ndarray)):
        targets = target
    else:
        targets = [target]
    # max(k, 0): a negative k would slice from the end of the ranking.
    top_k = retrieved[:max(k, 0)]
    return 1.0 if any(item in targets for item in top_k) else 0.0

def dcg_at_k(retrieved, target, k, rel_scores=None):
    """
    Compute DCG@k (Discounted Cumulative Gain).

    Args:
        retrieved: Ranked sequence of retrieved documents (best first).
        target: The relevant document(s). A list, tuple, set or NumPy array is
            treated as a collection of relevant documents (by convention ordered
            from highest to lowest relevance). Any other value, including a
            string, is treated as a single relevant document.
        k: Cutoff; only the first ``k`` retrieved documents contribute.
            Non-positive values give 0.
        rel_scores: Optional graded relevance, aligned with ``target``. When
            omitted, relevance is binary: a retrieved document scores 1 if it is
            one of the targets and 0 otherwise.

    Returns:
        sum_i gain_i / log2(i + 1) over positions i = 1..k.

    Raises:
        ValueError: If ``rel_scores`` is given and its length differs from the
            number of target documents.
    """
    # Normalize so that a scalar id (int or str) is one document, not an
    # iterable of characters / a TypeError.
    if isinstance(target, (list, tuple, set, frozenset, np.ndarray)):
        targets = list(target)
    else:
        targets = [target]

    # max(k, 0): a negative k would slice from the end of the ranking.
    retrieved = retrieved[:max(k, 0)]
    if rel_scores is None:
        # Binary relevance by membership. ``np.array(retrieved) == target``
        # compared positions pairwise when target was a list.
        gains = np.array([1.0 if doc in targets else 0.0 for doc in retrieved])
    else:
        if len(targets) != len(rel_scores):
            raise ValueError(
                f"rel_scores has {len(rel_scores)} entries but there are "
                f"{len(targets)} target documents"
            )
        rel_scores_dict = dict(zip(targets, rel_scores))
        gains = np.array(
            [rel_scores_dict.get(doc, 0) for doc in retrieved], dtype=float
        )
    # Position i (1-indexed) is discounted by log2(i + 1).
    discounts = np.log2(np.arange(2, len(gains) + 2))
    return np.sum(gains / discounts)

def ndcg_at_k(retrieved, target, k, rel_scores=None):
    """
    Compute NDCG@k: DCG@k of ``retrieved`` divided by the best achievable DCG@k.

    Args:
        retrieved: Ranked sequence of retrieved documents (best first).
        target: The relevant document(s). A list, tuple, set or NumPy array is
            treated as a collection of relevant documents. Any other value,
            including a string, is treated as a single relevant document.
        k: Cutoff; only the first ``k`` positions count.
        rel_scores: Optional graded relevance aligned with ``target``. Defaults
            to 1 for every target document.

    Returns:
        A value in [0, 1] for non-negative relevance scores. Returns 0.0 when
        no positive gain is achievable (e.g. no targets or ``k <= 0``).
    """
    # Normalize so len()/iteration mean "number of target documents". A bare
    # string id would otherwise be split into characters, and an int id would
    # raise TypeError.
    if isinstance(target, (list, tuple, set, frozenset, np.ndarray)):
        targets = list(target)
    else:
        targets = [target]

    # ``is None`` rather than ``== None``: the latter is element-wise (and
    # ambiguous inside ``if``) when rel_scores is a NumPy array.
    if rel_scores is None:
        rel_scores = [1] * len(targets)

    dcg = dcg_at_k(retrieved, targets, k, rel_scores)

    # Ideal DCG: the highest positive relevance scores placed at the top ranks.
    # Sorting keeps this correct even when target is not ordered by relevance,
    # so the ratio cannot exceed 1. dict(zip(...)) deduplicates targets the
    # same way dcg_at_k's lookup does (the last score for a document wins).
    best_gains = sorted(
        (g for g in dict(zip(targets, rel_scores)).values() if g > 0),
        reverse=True,
    )[:max(k, 0)]
    ideal_dcg = sum(g / np.log2(pos + 1) for pos, g in enumerate(best_gains, 1))

    return dcg / ideal_dcg if ideal_dcg > 0 else 0.0


def dcg_at_k_rank(rank, k):
    """
    Compute DCG@k for a single relevant document at a specific rank.

    With one relevant document (gain 1), the DCG sum has a single non-zero
    term, so the result is simply 1 / log2(rank + 1).

    Args:
        rank (int): The rank of the first occurrence of the answer (1-indexed),
            or -1 (as returned by ``get_rank``) when it was not retrieved.
        k (int): Number of documents to consider.

    Returns:
        float: DCG score. Returns 0.0 when the answer is missing (rank < 1) or
        ranked beyond the cutoff (rank > k).
    """
    # rank < 1 must be a miss. The previous array version wrote to
    # relevance[rank - 1], so rank -1 wrapped around to the second-to-last
    # position (a spurious score, or IndexError when k == 1).
    if rank < 1 or rank > k:
        return 0.0
    return 1.0 / np.log2(rank + 1)


def ideal_dcg_at_k(k):
    """
    Return the best achievable DCG@k for a single relevant document.

    The optimal ranking puts that document first, so this is the rank-1 DCG
    evaluated with the same cutoff ``k``.

    Args:
        k (int): Number of documents to consider.

    Returns:
        float: DCG of the ideal ranking.
    """
    best_rank = 1  # ranks are 1-indexed; 1 is the top position
    return dcg_at_k_rank(best_rank, k)


def ndcg_for_rank(rank, k):
    """
    Compute NDCG@k for a single relevant document at a specific rank.

    Args:
        rank (int): The rank of the first occurrence of the answer (1-indexed),
            or -1 (as returned by ``get_rank``) when it was not retrieved.
        k (int): Number of documents to consider.

    Returns:
        float: NDCG score between 0 and 1. Returns 0.0 for a miss (rank < 1 or
        rank > k).
    """
    # Handle misses up front, so the -1 "not found" sentinel (or any other
    # rank outside 1..k) yields 0 regardless of how the DCG helper treats it.
    if rank < 1 or rank > k:
        return 0.0
    dcg = dcg_at_k_rank(rank, k)
    ideal_dcg = ideal_dcg_at_k(k)
    return dcg / ideal_dcg if ideal_dcg > 0 else 0.0



if __name__ == '__main__':
    # Smoke check: the two relevant documents (1 and 2) occupy the top two
    # ranks, so NDCG@10 is perfect and this prints 1.0.
    demo_target = [1, 2]
    demo_scores = [1] * len(demo_target)
    print(
        ndcg_at_k(
            [1, 2, 3, 4, 5, 6, 7],
            demo_target,
            10,
            demo_scores,
        )
    )