#!/usr/bin/env python
# -*- coding: utf-8 -*-

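"""Description
Viewing the prediction of relevance as a binary classification problem: predicted logits are trained with
binary cross-entropy against the relevance labels, optionally mixed with distillation from a teacher's logits.
"""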

import torch
import torch.nn.functional as F

from ptranking.base.ranker import NeuralRanker



def rankBCE_loss_function(batch_pred=None, batch_label=None, mask=None, TL_AF=None, learn_from_data=True, qg_mask_enable=True):
	'''
	Ranking loss based on binary cross-entropy with logits, summed over all (weighted) entries.

	:param batch_pred: predicted logits; must have the same shape as batch_label
		(presumably [batch, ranking_size] -- TODO confirm against callers)
	:param batch_label: if learn_from_data is True, the target labels (row-wise maxima are taken over dim 1);
		if learn_from_data is False, the teacher's logits
	:param mask: optional elementwise weights broadcastable to batch_pred (presumably 1 for real documents and
		0 for padding -- verify against callers); None means every entry is weighted by 1
	:param TL_AF: identifier of the scorer's top-layer activation (assumption based on the name); when it is
		'S' or 'ST' and learning from data, predictions are first multiplied by the batch's maximum label
	:param learn_from_data: True -> BCE against the labels; False -> distillation against the teacher logits
	:param qg_mask_enable: when learning from data, weight each row by its maximum label so that rows whose
		maximum label is 0 contribute nothing
	:return: scalar tensor holding the summed loss
	'''
	if not learn_from_data:
		# Distillation: per entry, cross-entropy(sigmoid(teacher), student) minus the teacher's own entropy,
		# i.e. a Bernoulli KL divergence. The teacher is detached so that it receives no gradient.
		teacher_logits = batch_label.detach()
		soft_target = torch.sigmoid(teacher_logits)
		cross_entropy = F.binary_cross_entropy_with_logits(batch_pred, soft_target, weight=mask, reduction='sum')
		teacher_entropy = F.binary_cross_entropy_with_logits(teacher_logits, soft_target, weight=mask, reduction='sum')
		return cross_entropy - teacher_entropy

	if TL_AF in ('S', 'ST'):  # map to the same relevance level
		batch_pred = batch_pred * torch.max(batch_label)

	if qg_mask_enable:
		# Weight every entry of a row by that row's maximum label: rows without any positive label get weight 0
		# (graded labels additionally scale a row by its highest grade).
		qg_mask = torch.max(batch_label, dim=1)[0].unsqueeze(-1)
		weight = qg_mask if mask is None else qg_mask * mask
	else:
		weight = mask

	return F.binary_cross_entropy_with_logits(batch_pred, batch_label, weight=weight, reduction='sum')


class RankBCE(NeuralRanker):
	'''
	Neural ranker trained with the binary cross-entropy loss of rankBCE_loss_function(), optionally mixed
	with a distillation term that pulls the predictions towards a teacher's logits.
	'''

	# Slack for the non-negativity sanity check in inner_train(): the teacher term is computed as the
	# difference of two float sums (a KL divergence for non-negative weights), so rounding can make it a
	# tiny negative number even though it is >= 0 in exact arithmetic.
	NEG_LOSS_TOLERANCE = 1e-4

	def __init__(self, sf_para_dict=None, gpu=False, device=None, lr=None):
		# Identify this ranker by its own class name.
		super(RankBCE, self).__init__(id='RankBCE', sf_para_dict=sf_para_dict, gpu=gpu, device=device, lr=lr)
		self.TL_AF = self.get_tl_af()

	def inner_train(self, batch_pred, batch_label, mask, teacher_pred=None, **kwargs):
		'''
		Compute the mixed loss for one batch and, when appropriate, take one optimizer step.

		:param batch_pred: predicted logits (presumably [batch, ranking_size] -- TODO confirm)
		:param batch_label: relevance labels with the same shape as batch_pred
		:param mask: elementwise weights forwarded to rankBCE_loss_function
		:param teacher_pred: optional teacher logits; when given, a distillation term is mixed into the loss
		:param kwargs: must contain 'pri_dict', whose json_dict provides 'mix_alpha' (its first element is
			used as the weight of the data loss) and 'qg_mask' (enables query-group masking)
		:return: (batch loss, data loss, teacher loss) as Python numbers; teacher loss is 0 without a teacher
		'''
		pri_settings = kwargs["pri_dict"].json_dict
		mix_alpha = pri_settings["mix_alpha"][0]
		qg_mask_enable = pri_settings["qg_mask"]

		data_loss = rankBCE_loss_function(batch_pred, batch_label, mask=mask, TL_AF=self.TL_AF,
										  qg_mask_enable=qg_mask_enable)
		_rec_data_loss = data_loss.item()

		if teacher_pred is not None:
			teacher_batch_loss = rankBCE_loss_function(batch_pred, teacher_pred, mask=mask, TL_AF=self.TL_AF,
													   learn_from_data=False, qg_mask_enable=qg_mask_enable)
			_rec_teacher_loss = teacher_batch_loss.item()
		else:
			# Without a teacher the mix collapses to the data loss: alpha * d + (1 - alpha) * d == d.
			teacher_batch_loss = data_loss
			_rec_teacher_loss = 0

		batch_loss = mix_alpha * data_loss + (1 - mix_alpha) * teacher_batch_loss
		_rec_batch_loss = batch_loss.item()

		# Sanity check: both terms are non-negative in theory; allow for floating-point rounding.
		assert _rec_batch_loss >= -self.NEG_LOSS_TOLERANCE, batch_loss
		# With query-group masking enabled, only update on a strictly positive loss (e.g. a batch whose
		# groups all have maximum label 0 yields zero data loss).
		if _rec_batch_loss > 0 or not qg_mask_enable:
			self.optimizer.zero_grad()
			batch_loss.backward()
			self.optimizer.step()

		return (_rec_batch_loss, _rec_data_loss, _rec_teacher_loss)
