import math

import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion


@register_criterion('tnf_masked_lm')
class TnfMaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.

    Besides the MLM loss, the logging output reports the L2 norm of
    ``model.decoder.sentence_encoder.embed_tokens.weight``, the L2 norm of
    ``model.decoder.tnf_embeddings.weight`` and the mean of
    ``model.decoder.tnf_words_updates_cnts``.
    """

    def __init__(self, args, task):
        super().__init__(args, task)

    def _compute_masked_logits(self, model, sample):
        """Run ``model`` on ``sample`` and select the prediction targets.

        Prediction targets are the positions of ``sample['target']`` that are
        not ``self.padding_idx``.

        Returns:
            ``(logits, targets, sample_size)``. Here ``sample_size`` is the
            number of prediction targets. When it is non-zero, the boolean
            mask is passed to the model as ``masked_tokens`` and ``targets``
            is selected with the same mask. Otherwise no mask is passed and
            ``targets`` is returned unselected (i.e. all padding).
        """
        masked_tokens = sample['target'].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum().item()

        # (Rare case) When no position is a prediction target, passing the
        # all-False mask makes the model produce an empty tensor, which gives
        # a CUDA error. Run the model without a mask instead.
        if sample_size == 0:
            masked_tokens = None

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]
        return logits, targets, sample_size

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss (negative log-likelihood summed over prediction targets)
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        logits, targets, sample_size = self._compute_masked_logits(model, sample)

        # The log-softmax is computed in float32 (presumably for numerical
        # stability under reduced-precision training). Padding targets, which
        # only appear in the no-mask fallback, are ignored.
        loss = F.nll_loss(
            F.log_softmax(
                logits.view(-1, logits.size(-1)),
                dim=-1,
                dtype=torch.float32,
            ),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        loss_value = utils.item(loss.data) if reduce else loss.data
        decoder = model.decoder
        logging_output = {
            'loss': loss_value,
            'nll_loss': loss_value,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
            'wembed_norm': decoder.sentence_encoder.embed_tokens.weight.data.norm(2).item(),
            'bembed_norm': decoder.tnf_embeddings.weight.data.norm(2).item(),
            'tnf_update_cnts': decoder.tnf_words_updates_cnts.mean().item(),
        }
        return loss, sample_size, logging_output

    def track_accuracy(self, model, sample):
        """Compute per-token-id prediction statistics for the given sample.

        Returns:
            ``(corrects, correct_counts, occur, occur_counts)``.
            ``occur`` / ``occur_counts`` are the distinct target token ids
            (sorted ascending) and how often each occurs among the prediction
            targets. ``corrects`` / ``correct_counts`` are the distinct target
            token ids that the model's argmax prediction got right, and how
            often. Padding targets are never counted.
        """
        logits, targets, _ = self._compute_masked_logits(model, sample)

        # Flatten so that the argmax is always taken over the vocabulary
        # (last) dimension and lines up with the flattened targets. This also
        # holds in the no-mask fallback, where the logits are not restricted
        # to masked positions.
        preds = logits.reshape(-1, logits.size(-1)).argmax(dim=-1)
        targets = targets.reshape(-1)

        # Drop padding targets so that the no-mask fallback does not count
        # pad positions. In the normal path every target is non-padding.
        keep = targets.ne(self.padding_idx)
        preds = preds[keep]
        targets = targets[keep]

        occur, occur_counts = torch.unique(targets, sorted=True, return_counts=True)
        correct_preds = preds[preds == targets]
        corrects, correct_counts = torch.unique(
            correct_preds, sorted=True, return_counts=True)
        return corrects, correct_counts, occur, occur_counts

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Losses are summed across workers and divided by the total number of
        prediction targets and by ln(2), i.e. reported in bits per target.
        When there are no prediction targets at all, both losses are reported
        as 0. The embedding norms and the update count are taken from the
        first logging output only. Presumably they are identical across
        workers because parameters are kept in sync.
        """
        def _total(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        loss = _total('loss')
        nll_loss = _total('nll_loss')
        ntokens = _total('ntokens')
        nsentences = _total('nsentences')
        sample_size = _total('sample_size')

        if sample_size > 0:
            loss_bits = loss / sample_size / math.log(2)
            nll_loss_bits = nll_loss / sample_size / math.log(2)
        else:
            # No prediction targets on any worker: avoid division by zero.
            loss_bits = nll_loss_bits = 0.

        first = logging_outputs[0] if logging_outputs else {}

        agg_output = {
            'loss': loss_bits,
            'nll_loss': nll_loss_bits,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
            'wembed_norm': first.get('wembed_norm', 0),
            'bembed_norm': first.get('bembed_norm', 0),
            # Reported under a different key than the per-worker
            # 'tnf_update_cnts'; kept as-is for existing log consumers.
            'bwords_update_cnts': first.get('tnf_update_cnts', 0),
        }
        return agg_output
