#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""ApproxNDCG: a listwise learning-to-rank loss based on a smooth approximation of nDCG.

Tao Qin, Tie-Yan Liu, and Hang Li. 2010.
A general approximation framework for direct optimization of information retrieval measures.
Journal of Information Retrieval 13, 4 (2010), 375–397.
"""

from logging.handlers import DEFAULT_SOAP_LOGGING_PORT
import torch

from ptranking.data.data_utils import LABEL_TYPE
from ptranking.base.ranker import NeuralRanker
from ptranking.ltr_adhoc.eval.parameter import ModelParameter
from ptranking.metric.adhoc_metric import torch_dcg_at_k
from ptranking.base.neural_utils import robust_sigmoid


# def get_approx_ranks(input, alpha=10, gpu=False):
#     ''' get approximated rank positions: Equation-11 in the paper'''
#     batch_pred_diffs = torch.unsqueeze(input, dim=2) - torch.unsqueeze(input, dim=1)  # computing pairwise differences, i.e., Sij or Sxy

#     batch_indicators = robust_sigmoid(torch.transpose(batch_pred_diffs, dim0=1, dim1=2), alpha, gpu) # using {-1.0*} may lead to a poor performance when compared with the above way;

#     batch_hat_pis = torch.sum(batch_indicators, dim=2) + 0.5  # get approximated rank positions, i.e., hat_pi(x)

#     return batch_hat_pis

def get_approx_ranks(input, alpha, mask=None):
    '''
    Smoothly approximate each document's rank position (Equation 11 in Qin et al., 2010).

    hat_pi[b, j] = 0.5 + sum_i sigmoid(alpha * (input[b, i] - input[b, j])), where the sum runs
    over documents i that are valid together with j. Because the i == j term contributes
    sigmoid(0) = 0.5, a valid document gets 1 + sum_{i != j} sigmoid(alpha * (s_i - s_j)),
    which approaches its 1-based rank as alpha grows (for distinct scores).

    :param input: [batch, ranking_size] predicted scores
    :param alpha: sigmoid temperature; larger values give a sharper approximation
    :param mask: [batch, ranking_size] validity mask; a pair (i, j) is ignored when
                 mask[b, i] * mask[b, j] < 0.5. None treats every document as valid.
    :return: [batch, ranking_size] approximate ranks; positions excluded by the mask get 0.5
    '''
    if mask is None:
        # No padding information: every document takes part in every comparison.
        mask = torch.ones_like(input)

    # square_mask[b, i, j] = mask[b, i] * mask[b, j]
    square_mask = torch.unsqueeze(mask, dim=1) * torch.unsqueeze(mask, dim=2)

    # batch_pred_diffs[b, i, j] = s_i - s_j
    batch_pred_diffs = torch.unsqueeze(input, dim=2) - torch.unsqueeze(input, dim=1)
    batch_indicators = torch.sigmoid(alpha * batch_pred_diffs)

    # Drop any comparison that involves a masked-out document.
    batch_indicators = batch_indicators.masked_fill(square_mask < 0.5, 0)

    # Summing over i (dim=1) softly counts the documents scored above document j.
    batch_hat_pis = torch.sum(batch_indicators, dim=1) + 0.5
    return batch_hat_pis

def approxNDCG_loss(batch_preds=None, batch_stds=None, mask=None, alpha=10, label_type=None, gpu=False):
    '''
    Negative approximate nDCG, summed over the queries of a batch (Qin et al., 2010).

    DCG uses the smooth rank positions from get_approx_ranks. The ideal DCG uses the gains
    sorted in descending order at the discrete positions 1..ranking_size.

    :param batch_preds: [batch, ranking_size] predicted scores
    :param batch_stds: [batch, ranking_size] relevance labels; gains are 2 ** label - 1
    :param mask: [batch, ranking_size] validity mask; entries < 0.5 mark padded documents,
                 which get zero gain. None treats every document as valid.
    :param alpha: sigmoid temperature used for the rank approximation
    :param label_type: must be LABEL_TYPE.MultiLabel
    :param gpu: unused; kept for interface compatibility
    :return: scalar tensor equal to -(sum of approximate nDCG over the queries whose ideal DCG
             is > 0), or a scalar zero (no gradient) when no such query exists
    '''
    # Validate before doing any O(ranking_size^2) work.
    assert LABEL_TYPE.MultiLabel == label_type

    if mask is None:
        mask = torch.ones_like(batch_preds)

    batch_hat_pis = get_approx_ranks(batch_preds, alpha=alpha, mask=mask)

    # Padded documents get zero gain, so they add nothing to either DCG or ideal DCG.
    batch_stds = batch_stds.masked_fill(mask < 0.5, 0)
    batch_gains = torch.pow(2.0, batch_stds) - 1.0
    batch_dcg = torch.sum(torch.div(batch_gains, torch.log2(batch_hat_pis + 1)), dim=1)

    # Ideal DCG: best possible ordering of the same gains at positions 1..ranking_size.
    batch_ideal_gains = torch.sort(batch_gains, dim=1, descending=True)[0]
    positions = torch.arange(batch_ideal_gains.size(1), device=batch_gains.device)
    batch_idcgs = torch.sum(torch.div(batch_ideal_gains, torch.log2(positions + 2.0)), dim=1)

    # Queries with no relevant document have ideal DCG 0 and are skipped (nDCG is undefined).
    positive_groups = batch_idcgs > 0.0
    if torch.any(positive_groups):
        batch_approx_nDCG = torch.div(batch_dcg[positive_groups], batch_idcgs[positive_groups])
        return -torch.sum(batch_approx_nDCG)

    # Scalar zero with the same shape, dtype and device as the regular branch.
    return batch_idcgs.new_zeros(())




class ApproxNDCG(NeuralRanker):
    '''
    ApproxNDCG ranker.

    Tao Qin, Tie-Yan Liu, and Hang Li. 2010.
    A general approximation framework for direct optimization of information retrieval measures.
    Journal of Information Retrieval 13, 4 (2010), 375–397.

    The training objective is mix_alpha * loss(ground truth) + (1 - mix_alpha) * loss(teacher),
    where both terms are approxNDCG_loss and the teacher term uses sigmoid(teacher_pred) as labels.
    '''

    def __init__(self, sf_para_dict=None, model_para_dict=None, gpu=False, device=None, lr=None):
        super().__init__(id='ApproxNDCG', sf_para_dict=sf_para_dict, gpu=gpu, device=device, lr=lr)
        # Sigmoid temperature passed to approxNDCG_loss.
        self.alpha = model_para_dict['alpha']

    def inner_train(self, batch_preds, batch_stds, mask, teacher_pred=None, **kwargs):
        '''
        Compute the mixed loss for one batch and, if it is non-zero, take one optimizer step.

        :param batch_preds: [batch, ranking_size] each row represents the relevance predictions for documents within a ltr_adhoc
        :param batch_stds: [batch, ranking_size] each row represents the standard relevance grades for documents within a ltr_adhoc
        :param mask: validity mask forwarded to approxNDCG_loss (entries < 0.5 are padding)
        :param teacher_pred: optional teacher scores; sigmoid(teacher_pred) serves as soft labels
        :param kwargs: must provide 'pri_dict' (its json_dict["mix_alpha"][0] weights the
                       ground-truth term) and 'label_type' (must be LABEL_TYPE.MultiLabel)
        :return: (total loss, ground-truth loss, teacher loss) as Python numbers; the teacher
                 loss is reported as 0 when no teacher_pred is given
        '''
        weight = kwargs["pri_dict"].json_dict["mix_alpha"][0]

        label_type = kwargs['label_type']
        assert label_type == LABEL_TYPE.MultiLabel

        gt_loss = approxNDCG_loss(batch_preds, batch_stds, mask, self.alpha, label_type=label_type, gpu=self.gpu)
        gt_loss_value = gt_loss.item()

        if teacher_pred is None:
            # Without a teacher, both mixture terms are the ground-truth loss.
            kd_loss = gt_loss
            kd_loss_value = 0
        else:
            soft_labels = torch.sigmoid(teacher_pred)
            kd_loss = approxNDCG_loss(batch_preds, soft_labels, mask, self.alpha, label_type=label_type, gpu=self.gpu)
            kd_loss_value = kd_loss.item()

        total = weight * gt_loss + (1 - weight) * kd_loss
        total_value = total.item()

        # A zero loss (e.g. no query with relevant documents) carries no gradient; skip the step.
        if total != 0:
            self.optimizer.zero_grad()
            total.backward()
            self.optimizer.step()

        return total_value, gt_loss_value, kd_loss_value

#-------
def get_apxndcg_paras_str(model_para_dict, log=False):
    """Build the ApproxNDCG parameter identifier, e.g. 'Alpha_10.0' ('Alpha:10.0' when log is set).

    :param model_para_dict: dict providing the 'alpha' entry
    :param log: use ':' as the separator (for log output) instead of '_'
    :return: the identifier string
    """
    if log:
        separator = ':'
    else:
        separator = '_'
    return 'Alpha' + separator + str(model_para_dict['alpha'])

###### Parameter of ApproxNDCG ######

class ApproxNDCGParameter(ModelParameter):
    ''' Parameter class for ApproxNDCG: default setting, string identifier and grid search over alpha. '''
    def __init__(self, debug=False, para_json=None):
        super(ApproxNDCGParameter, self).__init__(model_id='ApproxNDCG', para_json=para_json)
        self.debug = debug

    def default_para_dict(self):
        """
        Default parameter setting for ApproxNDCG
        :return: dict with 'model_id' and the sigmoid temperature 'alpha' (10.0); also stored
                 as self.apxNDCG_para_dict
        """
        self.apxNDCG_para_dict = dict(model_id=self.model_id, alpha=10.)
        return self.apxNDCG_para_dict

    def to_para_string(self, log=False, given_para_dict=None):
        """
        String identifier of parameters, e.g. 'Alpha_10.0' (or 'Alpha:10.0' when log is True)
        :param log: use ':' instead of '_' as the separator
        :param given_para_dict: a given dict, which is used for maximum setting w.r.t. grid-search;
                                when None, the last dict set by default_para_dict/grid_search is used
        :return: the identifier string
        """
        # using specified para-dict or inner para-dict
        apxNDCG_para_dict = given_para_dict if given_para_dict is not None else self.apxNDCG_para_dict
        # Reuse the module-level formatter so both identifiers can never drift apart.
        return get_apxndcg_paras_str(apxNDCG_para_dict, log=log)

    def grid_search(self):
        """
        Iterator of parameter settings for ApproxNDCG.
        Each yielded dict is also stored as self.apxNDCG_para_dict.
        """
        if self.use_json:
            choice_alpha = self.json_dict['alpha']
        else:
            # Same single value regardless of self.debug; other values previously listed as
            # candidates: 1.0, 10.0, 50.0, 100.0
            choice_alpha = [10.0]

        for alpha in choice_alpha:
            self.apxNDCG_para_dict = dict(model_id=self.model_id, alpha=alpha)
            yield self.apxNDCG_para_dict
