from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

'''
AT with sum of absolute values with power p
code from: https://github.com/AberHu/Knowledge-Distillation-Zoo
'''
class AT(nn.Layer):
	'''
	Attention Transfer (AT) distillation loss.

	Paying More Attention to Attention: Improving the Performance of
	Convolutional Neural Networks via Attention Transfer
	https://arxiv.org/pdf/1612.03928.pdf

	Every feature map is reduced to a single-channel attention map
	(sum over axis 1 of |fm| ** p, divided by its norm over axes 2 and 3)
	and the loss is the MSE between the student map and the teacher map.
	'''
	def __init__(self, p):
		'''
		Args:
			p: exponent applied to the absolute feature values before
				they are summed over axis 1.
		'''
		super(AT, self).__init__()
		self.p = p

	def forward(self, fm_s, fm_t):
		'''
		Return F.mse_loss (called with its default arguments) between the
		attention maps of fm_s and fm_t.

		Both inputs are reduced on axes 1, 2 and 3, so they need at least
		four dimensions.
		NOTE(review): presumably (N, C, H, W) student / teacher feature
		maps whose attention maps have compatible shapes - confirm
		against the callers.
		'''
		student_am = self.attention_map(fm_s)
		teacher_am = self.attention_map(fm_t)
		return F.mse_loss(student_am, teacher_am)

	def attention_map(self, fm, eps=1e-6):
		'''
		Build the normalized attention map of a feature tensor.

		Args:
			fm: feature tensor; axes 1, 2 and 3 must exist.
			eps: small constant added to the norm so the division never
				sees an exact zero denominator.

		Returns:
			Tensor with axis 1 collapsed to size 1 (keepdim), scaled by
			the paddle.norm of that map taken over axes 2 and 3.
		'''
		powered = paddle.pow(paddle.abs(fm), self.p)
		# keepdim=True keeps the reduced axis so axes 2 and 3 stay in place
		# for the norm below.
		energy = paddle.sum(powered, axis=1, keepdim=True)
		# One scale factor per sample; keepdim=True lets it broadcast back
		# over axes 2 and 3 in the division.
		scale = paddle.norm(energy, axis=(2,3), keepdim=True) + eps
		return paddle.divide(energy, scale)