import random
import numpy as np

def inplace_shuffle(*lists):
    """Shuffle one or more sequences in place using a single shared permutation.

    Every sequence receives exactly the same series of swaps, so items that
    were aligned across sequences (e.g. samples and their labels) stay aligned.

    Args:
        *lists: mutable sequences supporting item assignment; each must have
            at least ``len(lists[0])`` elements. Calling with no sequences is
            a no-op.
    """
    if not lists:
        return
    # Fisher-Yates shuffle, front-to-back variant: position i is swapped with
    # a uniformly chosen j in [0, i]. All j's are drawn up front so the same
    # swap sequence can be replayed on every sequence.
    swaps = [random.randint(0, i) for i in range(len(lists[0]))]
    for ls in lists:
        for i, j in enumerate(swaps):
            ls[i], ls[j] = ls[j], ls[i]

def batch_by_num(n_batch, *lists, n_sample=None):
    """Split sequences into ``n_batch`` contiguous batches of near-equal size.

    Batch ``b`` covers the half-open index range
    ``[int(n_sample * b / n_batch), int(n_sample * (b + 1) / n_batch))``, so for
    a non-negative ``n_sample`` consecutive batches share their boundaries.

    Args:
        n_batch: number of batches to yield (nothing is yielded if it is 0).
        *lists: sliceable sequences, all sliced with the same bounds.
        n_sample: number of leading items to split; defaults to
            ``len(lists[0])``.

    Yields:
        A list with one slice per input sequence when several sequences were
        given; otherwise the single slice itself.
    """
    total = len(lists[0]) if n_sample is None else n_sample
    for b in range(n_batch):
        lo = int(total * b / n_batch)
        hi = int(total * (b + 1) / n_batch)
        chunk = [seq[lo:hi] for seq in lists]
        yield chunk if len(chunk) > 1 else chunk[0]

def batch_by_size(batch_size, *lists, n_sample=None):
    """Split sequences into consecutive batches of at most ``batch_size`` items.

    Every batch holds ``batch_size`` items except possibly the last one, which
    holds the remainder.

    Args:
        batch_size: positive maximum number of items per batch.
        *lists: sliceable sequences, all sliced with the same bounds.
        n_sample: number of leading items to batch; defaults to
            ``len(lists[0])``.

    Yields:
        A list with one slice per input sequence when several sequences were
        given; otherwise the single slice itself.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised on the first
            ``next()``, since this is a generator). Without this check the
            start offset never advances past ``n_sample`` and iteration
            never ends.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if n_sample is None:
        n_sample = len(lists[0])

    for start in range(0, n_sample, batch_size):
        # Clamp to n_sample: it may be smaller than the sequences' lengths.
        end = min(n_sample, start + batch_size)
        ret = [ls[start:end] for ls in lists]
        yield ret if len(ret) > 1 else ret[0]
 

def calculate_recall_ndcg_at_ks(labels, scores, k_values):
    """
    Compute Recall@K and NDCG@K (binary relevance) for several cut-offs K.

    For every sample, label positions are ranked by descending predicted score
    (``np.argsort(row)[::-1]``). A position counts as relevant when its label
    equals 1.

    Args:
        labels: array-like of shape (num_samples, num_labels); ground-truth
            labels, where a value of 1 marks a relevant label.
        scores: array-like with the same shape as ``labels``; predicted scores.
        k_values: iterable of non-negative ints, the cut-offs to evaluate.
            Repeated values are evaluated once.

    Returns:
        recall_at_ks: dict mapping each k to the mean Recall@k.
        ndcg_at_ks: dict mapping each k to the mean NDCG@k.

        Means are taken over *all* samples. A sample with no positive label
        contributes 0 to every metric but is still counted in the denominator.

    Raises:
        ValueError: if ``labels`` and ``scores`` differ in shape, or any k is
            negative.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    if labels.shape != scores.shape:
        raise ValueError(
            f"labels and scores must have the same shape, "
            f"got {labels.shape} and {scores.shape}"
        )
    # De-duplicate while keeping the caller's order. Iterating a raw list with
    # repeats would add (and later divide) a sample's score once per repeat.
    ks = list(dict.fromkeys(k_values))
    if any(k < 0 for k in ks):
        raise ValueError(f"k values must be non-negative, got {ks}")

    num_samples = labels.shape[0]
    recall_at_ks = {k: 0.0 for k in ks}
    ndcg_at_ks = {k: 0.0 for k in ks}
    if not ks:
        return recall_at_ks, ndcg_at_ks

    max_k = max(ks)
    for sample_labels, sample_scores in zip(labels, scores):
        is_relevant = sample_labels == 1
        num_relevant = int(np.count_nonzero(is_relevant))
        if num_relevant == 0:  # no positive label: contributes 0 (see Returns)
            continue
        # Only the top max_k ranked positions matter for any requested k.
        top = np.argsort(sample_scores)[::-1][:max_k]
        hit_at_rank = is_relevant[top]
        # discounts[j] = 1 / log2(j + 2) is the DCG weight of rank position j.
        discounts = 1 / np.log2(np.arange(len(top)) + 2)
        # Prefix sums let every k be answered with a single lookup.
        cum_hits = np.cumsum(hit_at_rank)
        cum_dcg = np.cumsum(np.where(hit_at_rank, discounts, 0.0))
        cum_idcg = np.cumsum(discounts)

        for k in ks:
            cutoff = min(k, len(top))
            if cutoff == 0:
                continue  # empty ranking: recall and NDCG are both 0
            recall_at_ks[k] += cum_hits[cutoff - 1] / num_relevant
            # The ideal ranking places all relevant labels first.
            idcg = cum_idcg[min(cutoff, num_relevant) - 1]
            ndcg_at_ks[k] += cum_dcg[cutoff - 1] / idcg

    denom = max(num_samples, 1)
    recall_at_ks = {k: float(v / denom) for k, v in recall_at_ks.items()}
    ndcg_at_ks = {k: float(v / denom) for k, v in ndcg_at_ks.items()}
    return recall_at_ks, ndcg_at_ks