# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    """KL regulariser pulling one language's layer-selection samples toward a prior.

    ``args`` must provide ``prior`` (``"uniform"`` or ``"agged_posterior"``),
    ``sparsity_weight``, ``soft_update`` and ``anneal_updates``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        """Return the annealed KL term for language ``lang_idx``.

        Args:
            layer_samples: sequence of tensors indexed by language. The KL is
                summed over the last dimension and divided by
                ``layer_samples[0].size(0)`` (the number of layers).
                Presumably each entry holds per-layer selection probabilities
                in [0, 1] -- TODO confirm against the caller.
            lang_idx: index into ``layer_samples`` of the language to regularise.
            update_num: current training update; drives the weight schedule.
            sample_size: multiplier applied to the weighted loss.

        Returns:
            Tensor with the weighted KL loss. The weight is 0 until
            ``soft_update`` updates have passed, then ramps linearly up to
            ``sparsity_weight`` over ``anneal_updates`` updates.

        Raises:
            NotImplementedError: if ``args.prior`` is not a supported prior.
        """
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7  # keeps log() finite for samples that are exactly 0
        if prior == "uniform":
            # uniform prior: every log-prior term is log(0.5)
            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior: samples of all languages summed and
            # normalised; detached so the prior acts as a fixed target
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (
                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
            ).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalized by number of layers
        kl_loss = kl_loss / layer_samples[0].size()[0]

        # Linear warm-up of the KL weight, clamped below at 0. Without the
        # clamp, updates before ``soft_update`` received a *negative* weight,
        # which rewards diverging from the prior. ``anneal_updates <= 0``
        # means "no warm-up" (and avoids a ZeroDivisionError).
        max_weight = self.args.sparsity_weight
        elapsed = update_num - self.args.soft_update
        if elapsed < 0:
            kl_weight = 0.0
        elif self.args.anneal_updates <= 0:
            kl_weight = max_weight
        else:
            kl_weight = min(
                max_weight, elapsed * max_weight / self.args.anneal_updates
            )
        return kl_loss * (kl_weight * sample_size)


class LatentLayersSparsityLoss(_Loss):
    """Auxiliary losses on layer utilisation aggregated over all languages.

    Both terms are active only once ``update_num`` exceeds
    ``soft_update + anneal_updates``. They are then linearly warmed up from
    0 to full strength over a further ``anneal_updates`` updates.

    * sharing term (weight ``args.share_weight``, enabled when > 0):
      ``-sum(u * log(u))`` over the strictly positive per-layer utilisations.
    * global sparsity term (weight ``args.sparsity_weight``, enabled when
      ``args.target_layers > 0``): squared difference between the summed
      utilisation (expected number of selected layers) and
      ``args.target_layers``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        """Return True when at least one term is enabled and warm-up has begun."""
        enabled = self.args.target_layers > 0 or self.args.share_weight > 0
        start = self.args.soft_update + self.args.anneal_updates
        return enabled and update_num > start

    def _anneal_factor(self, update_num):
        """Linear ramp from 0 to 1 over the ``anneal_updates`` after activation."""
        if self.args.anneal_updates <= 0:
            return 1.0
        start = self.args.soft_update + self.args.anneal_updates
        return min(1.0, (update_num - start) / self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        """Return the weighted sharing + sparsity loss, or ``0`` when inactive.

        Args:
            layer_samples_list: sequence of same-shaped tensors, one per
                language; they are stacked along a new dim 0 and averaged
                over it. Presumably 1-D per-layer selection probabilities
                -- TODO confirm against the caller.
            update_num: current training update; gates and anneals the loss.
            sample_size: multiplier applied to each term.

        Returns:
            Python ``0`` when :meth:`is_valid` is False, otherwise the summed
            (tensor) loss of the enabled terms.
        """
        if not self.is_valid(update_num):
            return 0

        weight_anneal = self._anneal_factor(update_num)
        layer_samples = torch.stack(layer_samples_list, dim=0)
        # average over languages -> per-layer utilisation ratio
        layer_utilization = layer_samples.mean(dim=0)

        batch_loss = 0
        if self.args.share_weight > 0:
            # encouraging sharing across languages. torch.log (not math.log)
            # keeps log(u) inside the autograd graph and avoids a host sync
            # per element; zero entries are skipped as 0 * log(0) := 0.
            used = layer_utilization[layer_utilization > 0]
            share_loss = -(used * torch.log(used)).sum()
            batch_loss += (
                weight_anneal * self.args.share_weight * sample_size * share_loss
            )
        if self.args.target_layers > 0:
            # expected number of layers selected
            expected_layers = layer_utilization.sum(dim=0)
            # l2 loss wrt target number of layers
            global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
            batch_loss += (
                weight_anneal
                * self.args.sparsity_weight
                * sample_size
                * global_sparsity_loss
            )
        return batch_loss
