import torch
import torch.nn as nn
import torch.nn.functional as F
from ._base import Distiller

def cosine_similarity(a, b, eps=1e-8):
    """Return the cosine similarity of ``a`` and ``b`` taken along dim 1.

    Args:
        a: Tensor; reduced along dim 1.
        b: Tensor broadcast-compatible with ``a``; reduced along dim 1.
        eps: Added to the product of the two norms so that all-zero rows
            give 0 instead of a division by zero.

    Returns:
        Tensor with dim 1 reduced away, holding one similarity per slice.
    """
    dot = torch.sum(a * b, dim=1)
    norm_product = torch.norm(a, dim=1) * torch.norm(b, dim=1)
    return dot / (norm_product + eps)


def pearson_correlation(a, b, eps=1e-8):
    """Return the Pearson correlation of ``a`` and ``b`` taken along dim 1.

    Each input is centred by subtracting its mean over dim 1. The cosine
    similarity of the centred tensors is then returned; ``eps`` is forwarded
    to ``cosine_similarity`` to guard against zero-variance slices.
    """
    a_centered = a - a.mean(dim=1, keepdim=True)
    b_centered = b - b.mean(dim=1, keepdim=True)
    return cosine_similarity(a_centered, b_centered, eps)


def inter_class_relation(y_s, y_t):
    """Return ``1 - mean(pearson_correlation(y_s, y_t))``.

    The correlation is computed along dim 1 of each slice and then averaged.
    The result is 0 when every slice is perfectly correlated and grows as
    correlation drops.

    NOTE(review): callers in this file pass softmax probabilities that are
    presumably shaped (batch, classes). In that layout this measures, per
    sample, how well the student's ranking over classes matches the teacher's.
    """
    mean_corr = pearson_correlation(y_s, y_t).mean()
    return 1 - mean_corr


def intra_class_relation(y_s, y_t):
    """Apply ``inter_class_relation`` after swapping dims 0 and 1 of both inputs.

    The swap makes the correlation run along the original dim 0.
    NOTE(review): assuming (batch, classes) inputs, this correlates each
    class's scores across the samples in the batch.
    """
    s_swapped = y_s.transpose(0, 1)
    t_swapped = y_t.transpose(0, 1)
    return inter_class_relation(s_swapped, t_swapped)


class DIST(Distiller):
    """Distiller combining cross-entropy with the DIST relation losses.

    The KD term is the weighted sum ``beta * inter + gamma * intra``. Both
    relation losses are computed on temperature-softened softmax outputs and
    scaled by ``tau ** 2``.
    """

    def __init__(self, student, teacher, cfg, wrap_student_in_ddp=False, local_rank=None):
        super().__init__(student, teacher, wrap_student_in_ddp=wrap_student_in_ddp, local_rank=local_rank)
        dist_cfg = cfg.DIST
        self.beta = dist_cfg.beta  # weight of the inter-class relation loss
        self.gamma = dist_cfg.gamma  # weight of the intra-class relation loss
        self.tau = dist_cfg.tau  # softmax temperature
        self.ce_loss_weight = dist_cfg.CE_WEIGHT
        self.kd_loss_weight = dist_cfg.KD_WEIGHT

    def forward_train(self, image, perturbedInput, target, **kwargs):
        """Run one training step and return the logits and the loss terms.

        The student sees ``image``; the teacher sees ``perturbedInput`` under
        ``torch.no_grad``. Both networks are unpacked as ``(logits, _)``.

        Returns:
            ``(logits_student, logits_teacher, losses_dict)``, where
            ``losses_dict`` has keys ``"loss_ce"`` and ``"loss_kd"``. Each
            value already includes its configured weight.
        """
        logits_student, _ = self.student(image)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(perturbedInput)

        ce_term = F.cross_entropy(logits_student, target)
        loss_ce = self.ce_loss_weight * ce_term

        temperature = self.tau
        prob_student = F.softmax(logits_student / temperature, dim=1)
        prob_teacher = F.softmax(logits_teacher / temperature, dim=1)
        # tau**2 rescales the softened losses (as in standard KD scaling).
        scale = temperature**2
        inter_loss = scale * inter_class_relation(prob_student, prob_teacher)
        intra_loss = scale * intra_class_relation(prob_student, prob_teacher)
        loss_kd = self.beta * inter_loss + self.gamma * intra_loss

        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": self.kd_loss_weight * loss_kd,
        }
        return logits_student, logits_teacher, losses_dict