from __future__ import print_function

import torch.nn as nn


class HintLoss(nn.Module):
    """Hint (feature regression) loss from "FitNets: Hints for Thin Deep Nets", ICLR 2015.

    Computes the mean squared error between two feature tensors, after
    casting both to ``torch.float32``, and multiplies it by a constant
    scaling factor.

    Args:
        weight: Multiplier applied to the MSE value. Defaults to 1.
    """

    def __init__(self, weight=1):
        super(HintLoss, self).__init__()
        # Scaling factor applied to the raw MSE in ``forward``.
        self.weight = weight
        # Default ``nn.MSELoss`` averages the squared error over all elements.
        self.crit = nn.MSELoss()

    def forward(self, f_s, f_t):
        """Return the weighted MSE between ``f_s`` and ``f_t``.

        Args:
            f_s: First feature tensor (presumably the student's features,
                going by the name -- confirm against the caller).
            f_t: Second feature tensor (presumably the teacher's features --
                confirm against the caller).

        Returns:
            The mean squared error of the float32-cast inputs, multiplied
            by ``self.weight``.
        """
        # Cast both sides to float32 so inputs of differing dtypes can be
        # compared by the criterion.
        student_feat = f_s.float()
        teacher_feat = f_t.float()
        mse = self.crit(student_feat, teacher_feat)
        return mse * self.weight
