#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Evaluation helper for neural rankers.
Given a neural ranker, compute its mean nDCG@k values over a test data loader.
"""

import torch

from tqdm import tqdm

from ptranking.data.data_utils import LABEL_TYPE
from ptranking.metric.adhoc_metric import torch_nDCG_at_ks


def ndcg_at_ks(ranker=None, test_loader=None, ks=None, label_type=LABEL_TYPE.MultiLabel, gpu=False, device=None, eval_dict=None):
    """Compute the mean nDCG@k of ``ranker`` over the queries in ``test_loader``.

    Queries (rows of a batch) whose maximum label is not above 0.5 are dropped
    before scoring, so they contribute neither to the nDCG sum nor to the count.

    Args:
        ranker: model exposing ``num_features`` and ``predict(features, train=False)``.
        test_loader: iterable yielding ``(features, labels, mask)`` batches.
            Presumably ``features`` is (batch, list_size, n_features) and
            ``labels`` / ``mask`` are (batch, list_size) -- TODO confirm with the loader.
        ks: cutoffs at which nDCG is computed; defaults to ``[1, 5, 10]``.
        label_type: forwarded unchanged to ``torch_nDCG_at_ks``.
        gpu: if True, each batch is moved to ``device`` before prediction.
        device: target device used when ``gpu`` is True.
        eval_dict: evaluation options. Keys read here:
            ``'thresholding'`` (truthy -> binarise labels),
            ``'thresholding_val'`` (its first element is the threshold), and
            ``'mode'`` (``'LUPI_teacher'`` selects the trailing feature columns).
            ``None`` is treated as an empty dict (no thresholding, leading columns).

    Returns:
        1-D CPU tensor of length ``len(ks)``: the summed per-batch nDCG values
        divided by the summed counts reported by ``torch_nDCG_at_ks``. Entries
        are NaN if nothing was counted (0 / 0).
    """
    # Avoid a shared mutable default; behaviour matches the old [1, 5, 10].
    if ks is None:
        ks = [1, 5, 10]
    # The old default (None) crashed on the first subscript below.
    if eval_dict is None:
        eval_dict = {}

    sum_ndcg_at_ks = torch.zeros(len(ks))
    sum_count = torch.zeros(1)

    # Pure evaluation: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for batch_features, batch_labels, batch_mask in tqdm(test_loader, desc="Testing", leave=False):
            if gpu:
                batch_features = batch_features.to(device)
                batch_labels = batch_labels.to(device)
                batch_mask = batch_mask.to(device)

            if eval_dict.get('thresholding', False):
                # Binarise labels: strictly above the threshold -> 1.0, otherwise 0.0.
                batch_labels = (batch_labels > eval_dict['thresholding_val'][0]).float()

            # Pre-filter queries without any positive label (max label <= 0.5).
            sample_with_positive = torch.max(batch_labels, dim=1)[0] > 0.5
            batch_labels = batch_labels[sample_with_positive]
            batch_mask = batch_mask[sample_with_positive]
            batch_features = batch_features[sample_with_positive]

            if batch_labels.shape[0] == 0:
                # Every query in this batch was filtered out; nothing to score.
                continue

            # Feed the ranker only as many feature columns as it expects.
            if batch_features.shape[-1] > ranker.num_features:
                if eval_dict.get('mode') == 'LUPI_teacher':
                    batch_features = batch_features[:, :, -ranker.num_features:]
                else:
                    batch_features = batch_features[:, :, :ranker.num_features]

            batch_rele_preds = ranker.predict(batch_features, train=False)
            # Very important: give padded positions a score of -inf so they sort
            # to the end of the system ranking.
            batch_rele_preds = batch_rele_preds.masked_fill(batch_mask < 0.5, -float('inf'))

            _, batch_sorted_inds = torch.sort(batch_rele_preds, dim=1, descending=True)

            # Labels in the order the ranker placed the documents.
            batch_sys_sorted_labels = torch.gather(batch_labels, dim=1, index=batch_sorted_inds)
            # Labels in the best possible order (ideal ranking).
            batch_ideal_sorted_labels, _ = torch.sort(batch_labels, dim=1, descending=True)

            batch_ndcg_at_ks, count = torch_nDCG_at_ks(batch_sys_sorted_labels=batch_sys_sorted_labels,
                                                       batch_ideal_sorted_labels=batch_ideal_sorted_labels,
                                                       ks=ks, label_type=label_type)
            sum_ndcg_at_ks += batch_ndcg_at_ks.cpu()
            if torch.is_tensor(count):
                # Keep the accumulator on CPU even when evaluation runs on a GPU.
                count = count.cpu()
            sum_count += count
    return sum_ndcg_at_ks / sum_count
