# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion,LegacyFairseqCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
from torch.cuda import LongTensor, FloatTensor
import torch
from torch import nn
from collections import Counter
from torch.autograd import Variable
import random

@dataclass
class EISLLossCriterionConfig(FairseqDataclass):
    """Structured (dataclass) configuration for the EISL criterion.

    Declares a single option, ``sentence_avg``, whose value is interpolated
    from ``optimization.sentence_avg`` through omegaconf's ``II``.

    NOTE(review): ``EISLCriterion`` in this file is registered without a
    dataclass and reads its options from legacy ``args``, so this config is
    not consumed by it here — confirm whether it is used elsewhere.
    """

    # Interpolated at config-composition time from optimization.sentence_avg.
    sentence_avg: bool = II("optimization.sentence_avg")

@register_criterion("eislloss")
class EISLCriterion(LegacyFairseqCriterion):
    """Edit-Invariant Sequence Loss (EISL) criterion.

    The training objective is the n-gram matching loss computed by
    :meth:`batch_log_eisl`, scaled by the number of target tokens. A detached
    token-level NLL is computed alongside it purely for logging.

    Criterion-specific option:
        --ngram: comma-separated n-gram orders to match (default ``"1,2"``).
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # Hard-coded: the gradient denominator is always the token count.
        self.sentence_avg = False
        self.bestbleu = 0  # NOTE(review): not read anywhere in this file
        # "1,2" -> [1, 2]; empty entries (e.g. a trailing comma) are ignored.
        self.ngram = [int(i) for i in args.ngram.split(',') if i.strip()]
        if not self.ngram:
            raise ValueError("--ngram must list at least one n-gram order")

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--ngram', default='1,2', type=str)
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, loss_nll = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": loss.data,
            "loss_nll": loss_nll.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(eisl_loss, nll_loss)`` for one batch.

        ``eisl_loss`` is the differentiable training objective scaled by
        ``sample["ntokens"]``. ``nll_loss`` is detached and only logged; it is
        summed when ``reduce`` is true, otherwise returned per token.
        """
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        loss_nll = F.nll_loss(
            lprobs.reshape(-1, lprobs.size(-1)),
            target.reshape(-1),
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        ).detach()
        # Mask padding with the same index the NLL ignores, rather than relying
        # on the ``pad=1`` default of batch_log_eisl.
        loss = sample["ntokens"] * self.batch_log_eisl(
            lprobs, target, pad=self.padding_idx
        )
        return loss, loss_nll

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Logged scalars (converted from base e to base 2):
            bleu_loss: summed EISL loss / sample_size
            loss:      summed token NLL / sample_size
            nll_loss:  summed token NLL / ntokens (only when they differ)
        ``ppl`` is always derived from an NLL-based meter.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        loss_nll_sum = sum(log.get("loss_nll", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "bleu_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        # The "loss" meter reports the token NLL, not the EISL objective.
        metrics.log_scalar(
            "loss", loss_nll_sum / sample_size / math.log(2), sample_size, round=3
        )

        if sample_size != ntokens:
            # Use the NLL sum (not the EISL sum) so ``ppl`` here agrees with the
            # else-branch, which derives it from the NLL-based "loss" meter.
            metrics.log_scalar(
                "nll_loss", loss_nll_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return False

    def batch_log_eisl(self, decoder_outputs, target_idx, trans_len=None, pad=1,
                       weight_list=None):
        """Compute the EISL n-gram matching loss for a batch.

        Args:
            decoder_outputs: log-probabilities, ``[batch, output_len, vocab]``.
            target_idx: reference token ids, ``[batch, target_len]``.
            trans_len: unused; kept only so existing positional callers work.
            pad: index of the padding token; reference positions holding it
                contribute zero.
            weight_list: per-n-gram weights aligned with ``self.ngram``;
                defaults to uniform weights.

        For each n-gram order n, every (output window i, reference window j)
        pair is scored by the mean log-probability of the n reference tokens
        along the diagonal. A Gumbel-softmax over output windows weights the
        scores, which are summed over i, averaged over j, and summed over the
        batch.

        Returns:
            A 1-element tensor: the negated weighted sum over n-gram orders,
            divided by the batch size.

        NOTE: the caller is expected to pass ``output_len == target_len``.
        """
        batch_size, output_len, _ = decoder_outputs.size()
        _, tgt_len = target_idx.size()
        ngram_list = self.ngram
        # NOTE(review): ``BleuLog`` does not appear to exist in upstream fairseq's
        # ``metrics`` module (presumably a fork addition); guard the write so the
        # criterion does not raise AttributeError on stock fairseq.
        bleu_log = getattr(metrics, "BleuLog", None)
        if bleu_log is not None:
            bleu_log.tf_ratio = 0.0
        if weight_list is None:
            weight_list = [1. / len(ngram_list)] * len(ngram_list)
        # Floor log-probabilities at -20 (same as clamp(min=-20) in value).
        decoder_outputs = torch.relu(decoder_outputs + 20) - 20

        # cost[b, i, j] = log p(target[b, j] | output position i)
        # -> [batch, output_len, target_len]
        index = target_idx.unsqueeze(1).expand(-1, output_len, tgt_len)
        cost_nll = decoder_outputs.gather(dim=2, index=index)

        # Add a channel dimension for conv2d: [batch, 1, output_len, target_len]
        out = cost_nll.unsqueeze(1)

        # Zero every column whose reference token is padding. The 0-dim zero
        # shares out's dtype/device, so this works on CPU and under fp16.
        target_expand = target_idx.view(batch_size, 1, 1, -1).expand(-1, -1, output_len, -1)
        out = torch.where(target_expand == pad, out.new_zeros(()), out)

        # Accumulate in float32 on the input's device instead of hard-coded CUDA.
        sum_gram = torch.zeros(1, device=decoder_outputs.device)
        for cnt, ngram in enumerate(ngram_list):
            # No valid n-gram window exists if either sequence is too short.
            if ngram > output_len or ngram > tgt_len:
                continue
            eye_filter = torch.eye(ngram, dtype=out.dtype, device=out.device)
            eye_filter = eye_filter.view(1, 1, ngram, ngram)
            # The identity kernel sums log-probs along diagonals (aligned
            # n-grams); dividing by n normalises for the n-gram length.
            # -> [batch, 1, output_len - n + 1, target_len - n + 1]
            term = F.conv2d(out, eye_filter) / ngram
            if ngram < output_len:
                term = term.squeeze(1)
                # Soft, stochastic selection over output windows (dim=1).
                gum_tmp = F.gumbel_softmax(term, tau=1, dim=1)
                term = term.mul(gum_tmp).sum(1).mean(1)
            sum_gram += weight_list[cnt] * term.sum()
        loss = - sum_gram / batch_size
        return loss
    