from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from ._base import Distiller


class WTTM(Distiller):
    """Distiller combining cross-entropy with a weighted, power-transformed KD term.

    The teacher's softmax output is raised element-wise to ``cfg.WTTM.EXPONENT``
    and renormalised along dimension 1; the student's log-softmax is matched
    against that distribution with a per-sample KL divergence, and each
    sample's KL is weighted by the normalisation mass of its transformed
    teacher distribution before averaging.
    """

    def __init__(self, student, teacher, cfg, wrap_student_in_ddp=False, local_rank=None):
        """Store the WTTM hyper-parameters read from ``cfg.WTTM``.

        Args:
            student: forwarded to the ``Distiller`` base constructor.
            teacher: forwarded to the ``Distiller`` base constructor.
            cfg: config object exposing ``WTTM.EXPONENT``, ``WTTM.CE_WEIGHT``
                and ``WTTM.KD_WEIGHT``.
            wrap_student_in_ddp: forwarded to the base constructor.
            local_rank: forwarded to the base constructor.
        """
        super(WTTM, self).__init__(
            student,
            teacher,
            wrap_student_in_ddp=wrap_student_in_ddp,
            local_rank=local_rank,
        )
        wttm_cfg = cfg.WTTM
        self.exponent = wttm_cfg.EXPONENT
        self.ce_loss_weight = wttm_cfg.CE_WEIGHT
        self.kd_loss_weight = wttm_cfg.KD_WEIGHT

    def _weighted_ttm_loss(self, logits_student, logits_teacher):
        """Return the unscaled, sample-weighted KL term between student and teacher.

        Both softmaxes are taken over dimension 1 (presumably the class
        dimension of ``(batch, num_classes)`` logits -- confirm with callers).
        """
        student_log_probs = F.log_softmax(logits_student, dim=1)

        # Power-transform the teacher probabilities, then renormalise so the
        # KL target sums to one along dimension 1 again.
        powered = logits_teacher.softmax(dim=1).pow(self.exponent)
        mass = powered.sum(dim=1, keepdim=True)
        teacher_probs = powered / mass

        pointwise_kl = F.kl_div(student_log_probs, teacher_probs, reduction='none')
        per_sample_kl = pointwise_kl.sum(dim=1)

        # Each sample's KL is weighted by the mass of its transformed teacher
        # distribution before taking the mean over samples.
        sample_weights = mass.squeeze(1)
        return (sample_weights * per_sample_kl).mean()

    def forward_train(self, image, perturbedInput, target, **kwargs):
        """Compute student/teacher logits and the weighted training losses.

        Args:
            image: input passed to the student network.
            perturbedInput: input passed to the teacher network (note that the
                teacher does not receive ``image``).
            target: labels passed to ``F.cross_entropy`` with the student logits.
            **kwargs: accepted but not used here.

        Returns:
            Tuple ``(logits_student, logits_teacher, losses_dict)`` where
            ``losses_dict`` has the keys ``"loss_ce"`` and ``"loss_kd"``.
        """
        # Both networks are unpacked as two-element results; only the first
        # element (the logits) is used.
        logits_student, _student_extra = self.student(image)
        with torch.no_grad():
            # No gradients are tracked through the teacher forward pass.
            logits_teacher, _teacher_extra = self.teacher(perturbedInput)

        hard_label_loss = F.cross_entropy(logits_student, target)
        distill_loss = self._weighted_ttm_loss(logits_student, logits_teacher)

        losses_dict = {
            "loss_ce": self.ce_loss_weight * hard_label_loss,
            "loss_kd": self.kd_loss_weight * distill_loss,
        }

        return logits_student, logits_teacher, losses_dict