import torch
from torch import nn

"""
Feature Diversity Loss:
Usage to replicate the paper:
Call
loss_function = FeatureDiversityLoss(0.196, linear)
to initialize the loss with the linear layer of the model.
For each mini-batch, get the feature maps (output of the final convolutional layer) and add to the loss:
loss += loss_function(feature_maps, outputs)
"""


class FeatureDiversityLoss(nn.Module):
    """Loss term computed from feature maps and the classifier weights.

    For every sample, the rows of the linear layer's absolute weight matrix
    belonging to the predicted (arg-max) class are L2-normalised and used to
    weight the rescaled feature maps. The loss is the negative sum, over all
    samples and spatial positions, of the maximum weighted activation across
    features, divided by the batch size and multiplied by ``scaling_factor``.
    """

    def __init__(self, scaling_factor, linear):
        """Store the loss multiplier and the layer providing ``.weight``.

        Args:
            scaling_factor: Multiplier applied to the final loss value.
            linear: Object exposing a ``weight`` tensor that is indexed by
                class along its first dimension.
        """
        super().__init__()
        self.scaling_factor = scaling_factor
        self.linearLayer = linear

    def initialize(self, linearLayer):
        """Replace the layer whose weights are used by the loss."""
        self.linearLayer = linearLayer

    def get_weights(self, outputs):
        """Return the absolute weight row of each sample's predicted class.

        Args:
            outputs: Per-sample class scores; the arg-max along dim 1 picks
                the row of the weight matrix.

        Returns:
            Tensor with one row of absolute weights per sample.
        """
        abs_weights = self.linearLayer.weight.abs()
        predicted = outputs.argmax(dim=1)
        return abs_weights[predicted]

    def forward(self, feature_maps, outputs):
        """Compute the scaled, negated diversity term for one mini-batch.

        Args:
            feature_maps: Feature maps with batch on dim 0, features on
                dim 1 and two spatial dims after that (the rescaling helper
                averages over dims 2 and 3).
            outputs: Per-sample class scores, see ``get_weights``.

        Returns:
            Scalar tensor holding the loss contribution.
        """
        class_weights = norm_vector(self.get_weights(outputs))
        rescaled = preserve_avg_func(feature_maps).flatten(2)
        # Broadcast the per-feature weight over every spatial position.
        weighted = rescaled * class_weights.unsqueeze(-1)
        # Strongest weighted feature at each position, summed over everything.
        total = weighted.amax(dim=1).sum()
        batch_size = rescaled.size(0)
        return -total / batch_size * self.scaling_factor


def norm_vector(x):
    """Divide ``x`` by its norm along dim 1, with a small stabiliser.

    ``1e-5`` is added to the norm before dividing, so an all-zero row is
    returned as zeros instead of NaN.
    """
    lengths = torch.norm(x, dim=1) + 1e-5
    # Re-insert the reduced dimension so the division broadcasts over dim 1.
    return x / lengths.unsqueeze(1)


def preserve_avg_func(x):
    """Softmax each feature map and rescale it by its relative average.

    The mean of every feature map (over dims 2 and 3) is divided by the
    largest such mean within the same sample; the spatially softmaxed maps
    are then multiplied by that ratio.

    NOTE(review): the ratio is computed without a stabiliser, so a sample
    whose largest mean is exactly 0 produces NaN/inf values.
    """
    channel_means = x.mean(dim=[2, 3])
    largest_mean = channel_means.max(dim=1).values
    relative_scale = channel_means / largest_mean.unsqueeze(-1)
    # Broadcast the per-feature ratio over both spatial dimensions.
    return softmax_feature_maps(x) * relative_scale[..., None, None]


def softmax_feature_maps(x):
    """Apply a softmax over all positions of each feature map.

    Everything after the first two dimensions is collapsed into one axis,
    the softmax is taken along it, and the result is returned in the
    original shape of ``x``.
    """
    batch = x.size(0)
    channels = x.size(1)
    collapsed = x.reshape(batch, channels, -1)
    probabilities = torch.softmax(collapsed, dim=2)
    return probabilities.view_as(x)


if __name__ == '__main__':
    # Example: evaluate the loss once on random inputs with a fixed seed.
    from types import SimpleNamespace

    n_classes = 100
    n_batches = 3
    n_features = 2048
    torch.manual_seed(0)
    # The order of the random draws below is kept fixed so that the seeded
    # result stays reproducible.
    test_maps = torch.randn(n_batches, n_features, 4, 4)
    # Minimal stand-in for a linear layer: the loss only reads ``.weight``.
    linear = SimpleNamespace(weight=torch.randn(n_classes, n_features))
    outputs = torch.randn(n_batches, n_classes)
    submitted_loss = FeatureDiversityLoss(0.196, linear)
    submitted_output = submitted_loss(test_maps, outputs)
