# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders


@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    """Criterion registered as ``wsc``.

    For every example, the model scores one query span and a set of
    candidate spans by masking the span and averaging the log-probability of
    the original tokens at the masked positions. The example is predicted
    positive when the query score is at least as high as every candidate
    score. A loss (margin or cross entropy, see :meth:`get_loss`) is only
    accumulated for examples whose label is truthy.
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # Assign the handle *before* trying to open the file so the attribute
        # exists even when ``open`` raises.
        self.prediction_h = None
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
        self.bpe = encoders.build_bpe(args.bpe)
        self.tokenizer = encoders.build_tokenizer(args.tokenizer)

    def __del__(self):
        # ``__del__`` also runs on partially constructed instances (e.g. when
        # ``__init__`` raised before ``prediction_h`` was assigned), so the
        # attribute must not be assumed to exist.
        handle = getattr(self, "prediction_h", None)
        if handle is not None:
            handle.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        """Return a copy of ``tokens`` with the positions selected by ``mask``
        overwritten by ``self.task.mask``; ``tokens`` itself is not modified.
        """
        masked_tokens = tokens.clone()
        masked_tokens[mask] = self.task.mask
        return masked_tokens

    def get_lprobs(self, model, tokens, mask):
        """Score each row of ``tokens`` by its mean masked log-probability.

        The model is run on the masked input, and the log-probability it
        assigns to the original token is averaged over the positions selected
        by ``mask``.

        Args:
            model: callable accepting ``src_tokens=...`` and returning a
                pair whose first element holds the logits.
            tokens: 2-D tensor of token indices (one sequence per row).
            mask: tensor of the same shape as ``tokens`` selecting the
                positions to mask and score.

        Returns:
            1-D float tensor with one score per row of ``tokens``.

        Note:
            A row whose mask selects no position divides by zero and yields
            a non-finite score.
        """
        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
        # Pick the log-probability of the original token at every position.
        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(scores)
        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
        return scores

    def get_loss(self, query_lprobs, cand_lprobs):
        """Compute the training loss for one query against its candidates.

        With ``--wsc-cross-entropy`` the query and candidate scores are
        concatenated into a single row of logits and the query (index 0) is
        the target class. Otherwise the loss is
        ``-query + alpha * max(0, cand - query + beta)`` summed over all
        candidates, with ``alpha``/``beta`` taken from ``--wsc-margin-alpha``
        and ``--wsc-margin-beta``.

        Returns:
            Scalar loss tensor.
        """
        if self.args.wsc_cross_entropy:
            return F.cross_entropy(
                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                query_lprobs.new([0]).long(),
            )
        else:
            return (
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

    def forward(self, model, sample, reduce=True):
        """Compute the loss and accuracy statistics for ``sample``.

        Reads ``labels``, ``query_tokens``, ``query_masks``,
        ``candidate_tokens``, ``candidate_masks``, ``ntokens`` and
        ``nsentences`` from ``sample``; ``id`` is read only when predictions
        are being saved.

        Returns:
            Tuple ``(loss, sample_size, logging_output)`` where
            ``sample_size`` is the number of labelled examples (at least 1).
        """
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            # Positive prediction iff the query beats (or ties) every candidate.
            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

            if label:
                # only compute a loss for positive instances
                nloss += 1
                loss += self.get_loss(query_lprobs, cand_lprobs)

            if self.prediction_h is not None:
                # Only fetch the id when it is actually written out; this also
                # avoids shadowing the ``id`` builtin.
                sample_id = sample["id"][i].item()
                print(
                    "{}\t{}\t{}".format(sample_id, pred, label),
                    file=self.prediction_h,
                )

        if nloss == 0:
            # No positive example in this batch: return a zero scalar that
            # still requires grad instead of the plain float ``0.0``.
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Sums the per-worker statistics. The reported loss is the summed loss
        divided by the total sample size and by ``ln 2``; it is ``0.0`` when
        the total sample size is zero (e.g. no logging outputs at all).
        ``accuracy`` is only included when at least one query was counted.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Guard against ZeroDivisionError when nothing was aggregated.
        if sample_size > 0:
            avg_loss = loss_sum / sample_size / math.log(2)
        else:
            avg_loss = 0.0

        agg_output = {
            "loss": avg_loss,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output


@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    """Batched variant of the WSC criterion, registered as ``winogrande``.

    The whole batch is scored with two model calls (queries, then
    candidates) and the scores are compared elementwise, so every example
    counts as one query and every example contributes to the loss.
    """

    def forward(self, model, sample, reduce=True):
        """Compute the loss and accuracy statistics for ``sample``.

        Reads ``query_tokens``, ``query_masks``, ``candidate_tokens``,
        ``candidate_masks``, ``ntokens`` and ``nsentences`` from ``sample``.

        Returns:
            Tuple ``(loss, sample_size, logging_output)`` where
            ``sample_size`` is the first dimension of ``query_tokens``.
        """
        query_lprobs = self.get_lprobs(
            model,
            sample["query_tokens"],
            sample["query_masks"],
        )
        cand_lprobs = self.get_lprobs(
            model,
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )
        pred = query_lprobs >= cand_lprobs

        if self.args.wsc_cross_entropy:
            # The inherited cross-entropy path concatenates *all* scores into
            # a single row with target 0, which for a batch would train only
            # the first query against every other query and candidate.
            # Instead, treat each example as its own two-way classification
            # (query vs. its candidate) and sum over the batch, matching the
            # reduction of the margin loss. For a batch of one this equals
            # the inherited formulation.
            # NOTE(review): assumes exactly one candidate score per query.
            logits = torch.stack([query_lprobs, cand_lprobs], dim=1)
            target = logits.new_zeros(logits.size(0), dtype=torch.long)
            loss = F.cross_entropy(logits, target, reduction="sum")
        else:
            loss = self.get_loss(query_lprobs, cand_lprobs)

        sample_size = sample["query_tokens"].size(0)
        ncorrect = pred.sum().item()
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": sample_size,
        }
        return loss, sample_size, logging_output
