# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import utils

from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion


@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy plus a weighted attention-alignment loss.

    The alignment term is scaled by ``args.alignment_lambda`` and added to the
    loss returned by the inherited ``compute_loss``.
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # Scale factor applied to the alignment loss inside forward().
        self.alignment_lambda = args.alignment_lambda

    @staticmethod
    def add_args(parser):
        """Register the parent criterion's options plus ``--alignment-lambda``."""
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument(
            '--alignment-lambda',
            default=0.05,
            type=float,
            metavar='D',
            help='weight for the alignment loss',
        )

    def forward(self, model, sample, reduce=True):
        """Run the model on ``sample`` and compute the combined loss.

        Returns a 3-tuple:
        1) the loss (with the weighted alignment loss added when available)
        2) the sample size, used as the denominator for the gradient
        3) a dict of logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        if self.args.sentence_avg:
            sample_size = sample['target'].size(0)
        else:
            sample_size = sample['ntokens']

        # Logged values are captured here, before the alignment term is added.
        # NOTE(review): when reduce is False the logged `loss.data` presumably
        # shares storage with `loss`, so the in-place add below may also be
        # visible through the logged tensor -- confirm before changing `+=`.
        if reduce:
            logged_loss = utils.item(loss.data)
            logged_nll_loss = utils.item(nll_loss.data)
        else:
            logged_loss = loss.data
            logged_nll_loss = nll_loss.data

        logging_output = {
            'loss': logged_loss,
            'nll_loss': logged_nll_loss,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        # Per the original author's note: the alignment loss is computed only
        # for the training set and for non dummy batches, i.e. when the sample
        # actually carries alignments.
        has_alignments = 'alignments' in sample and sample['alignments'] is not None
        alignment_loss = (
            self.compute_alignment_loss(sample, net_output) if has_alignments else None
        )

        if alignment_loss is not None:
            logging_output['alignment_loss'] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        """Return the weighted negative log attention over aligned pairs.

        Reads the attention from ``net_output[1]['attn']`` and the alignment
        data from ``sample['alignments']`` / ``sample['align_weights']``.
        Returns ``None`` when ``sample['alignments']`` is empty.
        """
        attention = net_output[1]['attn']
        batch, tgt_len, src_len = attention.shape
        # Collapse (batch, target) into a single row axis so one row index
        # addresses one target position.
        flat_attention = attention.view(batch * tgt_len, src_len)

        pairs = sample['alignments']
        weights = sample['align_weights'].float()

        if len(pairs) == 0:
            return None

        # Per the original comment: `pairs` (shape [:, 2]) holds the src-tgt
        # index pairs of the alignments, and `weights` (shape [:]) holds
        # 1 / frequency of a tgt index, used for normalizing.
        tgt_index = pairs[:, 1][:, None]
        src_index = pairs[:, 0][:, None]
        aligned_log_prob = flat_attention[tgt_index, src_index].log()
        weighted = aligned_log_prob * weights[:, None]
        return -(weighted.sum())

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Combine per-worker logging outputs from data parallel training."""

        def total(key):
            # Missing keys count as 0.
            return sum(log.get(key, 0) for log in logging_outputs)

        def normalized(key, denominator):
            # Average over the denominator and divide by ln(2); 0. when the
            # denominator is not positive.
            if denominator > 0:
                return total(key) / denominator / math.log(2)
            return 0.

        ntokens = total('ntokens')
        nsentences = total('nsentences')
        sample_size = total('sample_size')

        aggregated = {}
        aggregated['loss'] = normalized('loss', sample_size)
        aggregated['nll_loss'] = normalized('nll_loss', ntokens)
        aggregated['alignment_loss'] = normalized('alignment_loss', sample_size)
        aggregated['ntokens'] = ntokens
        aggregated['nsentences'] = nsentences
        aggregated['sample_size'] = sample_size
        return aggregated
