import math

import torch
from torch import nn
from torch.nn.functional import adaptive_avg_pool2d, adaptive_max_pool2d, normalize


class FSPLoss(nn.Module):
    """
    Flow of Solution Procedure (FSP) loss from
    "A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning".

    For each pair of feature maps, an FSP matrix is built from the spatially averaged inner
    products between the channels of the first map and the channels of the second map. The
    loss compares the FSP matrices computed from ``sub_features`` against those computed from
    ``victim_features``.
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def compute_fsp_matrix(first_feature_map, second_feature_map):
        """
        Compute the FSP matrix between two feature maps.

        Args:
            first_feature_map: 4-D tensor, expected as (batch, C1, H1, W1); the spatial size
                is read from ``shape[2:4]``.
            second_feature_map: 4-D tensor, expected as (batch, C2, H2, W2).

        Returns:
            Tensor of shape (batch, C1, C2): channel-by-channel inner products over the
            (possibly pooled) spatial positions, divided by the number of positions.
        """
        first_h, first_w = first_feature_map.shape[2:4]
        second_h, second_w = second_feature_map.shape[2:4]
        target_h, target_w = min(first_h, second_h), min(first_w, second_w)
        # Max-pool whichever map is larger (per spatial dim) down to the common size so both
        # flattened maps have the same number of spatial positions for the matmul below.
        if first_h > target_h or first_w > target_w:
            first_feature_map = adaptive_max_pool2d(first_feature_map, (target_h, target_w))

        if second_h > target_h or second_w > target_w:
            second_feature_map = adaptive_max_pool2d(second_feature_map, (target_h, target_w))

        # (batch, C, H, W) -> (batch, C, H*W)
        first_feature_map = first_feature_map.flatten(2)
        second_feature_map = second_feature_map.flatten(2)
        hw = first_feature_map.shape[2]
        # (batch, C1, HW) @ (batch, HW, C2) -> (batch, C1, C2), averaged over positions.
        return torch.matmul(first_feature_map, second_feature_map.transpose(1, 2)) / hw

    def forward(self, sub_features, victim_features, *args, **kwargs):
        """
        Compute the FSP loss divided by the batch size.

        Args:
            sub_features: sequence of feature-map pairs; element ``i`` is indexed as
                ``sub_features[i][0]`` (first map) and ``sub_features[i][1]`` (second map).
            victim_features: sequence of feature-map pairs in the same layout and order.
            *args, **kwargs: unused.

        Returns:
            Scalar tensor: for every pair, the L2 norm of the FSP-matrix difference taken
            along dim 1, summed over all remaining entries and over all pairs, then divided
            by the batch size of the first pair's FSP matrices.

        Raises:
            ValueError: if the two sequences differ in length or contain no pairs.
        """
        num_pairs = len(sub_features)
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if num_pairs != len(victim_features):
            raise ValueError(
                'sub_features and victim_features must contain the same number of pairs, '
                f'got {num_pairs} and {len(victim_features)}'
            )
        if num_pairs == 0:
            # Without at least one pair there is no batch size to normalize by.
            raise ValueError('at least one feature-map pair is required')

        fsp_loss = 0.0
        batch_size = None
        for sub_pair, victim_pair in zip(sub_features, victim_features):
            sub_fsp_matrices = self.compute_fsp_matrix(sub_pair[0], sub_pair[1])
            victim_fsp_matrices = self.compute_fsp_matrix(victim_pair[0], victim_pair[1])
            fsp_loss += (sub_fsp_matrices - victim_fsp_matrices).norm(dim=1).sum()
            if batch_size is None:
                batch_size = sub_fsp_matrices.shape[0]
        return fsp_loss / batch_size
