# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LatencyMetric(object):
    @staticmethod
    def length_from_padding_mask(padding_mask, batch_first: bool = False):
        dim = 1 if batch_first else 0
        return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)

    def prepare_latency_metric(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = False,
        start_from_zero: bool = True
    ):
        assert len(delays.size()) == 2
        assert len(src_lens.size()) == 2

        if start_from_zero:
            delays = delays + 1

        if batch_first:
            # convert to batch_last
            delays = delays.t()
            src_lens = src_lens.t()
            tgt_len, bsz = delays.size()
            _, bsz_1 = src_lens.size()

            if target_padding_mask is not None:
                target_padding_mask = target_padding_mask.t()
                tgt_len_1, bsz_2 = target_padding_mask.size()
                assert tgt_len == tgt_len_1
                assert bsz == bsz_2

        assert bsz == bsz_1

        if target_padding_mask is None:
            tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
        else:
            # 1, batch_size
            tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float()
            delays = delays.masked_fill(target_padding_mask, 0)

        return delays, src_lens, tgt_lens, target_padding_mask

    def __call__(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = False,
        start_from_zero: bool = True,
    ):
        delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
            delays,
            src_lens,
            target_padding_mask,
            batch_first,
            start_from_zero
        )
        return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)

    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        """
        Expected sizes:
        delays: tgt_len, batch_size
        src_lens: 1, batch_size
        target_padding_mask: tgt_len, batch_size
        """
        raise NotImplementedError


class AverageProportion(LatencyMetric):
    """
    Function to calculate Average Proportion from
    Can neural machine translation do simultaneous translation?
    (https://arxiv.org/abs/1606.02012)

    Delays are monotonic steps, range from 1 to src_len.
    Give src x tgt y, AP is calculated as:

    AP = 1 / (|x||y]) sum_i^|Y| deleys_i
    """
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        if target_padding_mask is not None:
            AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True)
        else:
            AP = torch.sum(delays, dim=0, keepdim=True)

        AP = AP / (src_lens * tgt_lens)
        return AP


class AverageLagging(LatencyMetric):
    """
    Function to calculate Average Lagging from
    STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework
    (https://arxiv.org/abs/1810.08398)

    Delays are monotonic steps, range from 1 to src_len.
    Give src x tgt y, AP is calculated as:

    AL = 1 / tau sum_i^tau delays_i - (i - 1) / gamma

    Where
    gamma = |y| / |x|
    tau = argmin_i(delays_i = |x|)
    """
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        # tau = argmin_i(delays_i = |x|)
        tgt_len, bsz = delays.size()
        lagging_padding_mask = delays >= src_lens
        lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :]
        gamma = tgt_lens / src_lens
        lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
        lagging.masked_fill_(lagging_padding_mask, 0)
        tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
        AL = lagging.sum(dim=0, keepdim=True) / tau

        return AL


class DifferentiableAverageLagging(LatencyMetric):
    """
    Function to calculate Differentiable Average Lagging from
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    (https://arxiv.org/abs/1906.05218)

    Delays are monotonic steps, range from 0 to src_len-1.
    (In the original paper thery are from 1 to src_len)
    Give src x tgt y, AP is calculated as:

    DAL = 1 / |Y| sum_i^|Y| delays'_i - (i - 1) / gamma

    Where
    delays'_i =
        1. delays_i if i == 1
        2. max(delays_i, delays'_{i-1} + 1 / gamma)

    """
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        tgt_len, bsz = delays.size()

        gamma = tgt_lens / src_lens
        new_delays = torch.zeros_like(delays)

        for i in range(delays.size(0)):
            if i == 0:
                new_delays[i] = delays[i]
            else:
                new_delays[i] = torch.cat(
                    [
                        new_delays[i - 1].unsqueeze(0) + 1 / gamma,
                        delays[i].unsqueeze(0)
                    ],
                    dim=0
                ).max(dim=0)[0]

        DAL = (
            new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
        )
        if target_padding_mask is not None:
            DAL = DAL.masked_fill(target_padding_mask, 0)

        DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens

        return DAL


class LatencyMetricVariance(LatencyMetric):
    def prepare_latency_metric(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = True,
        start_from_zero: bool = True
    ):
        assert batch_first
        assert len(delays.size()) == 3
        assert len(src_lens.size()) == 2

        if start_from_zero:
            delays = delays + 1

        # convert to batch_last
        bsz, num_heads_x_layers, tgt_len = delays.size()
        bsz_1, _ = src_lens.size()
        assert bsz == bsz_1

        if target_padding_mask is not None:
            bsz_2, tgt_len_1 = target_padding_mask.size()
            assert tgt_len == tgt_len_1
            assert bsz == bsz_2

        if target_padding_mask is None:
            tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
        else:
            # batch_size, 1
            tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
            delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)

        return delays, src_lens, tgt_lens, target_padding_mask


class VarianceDelay(LatencyMetricVariance):
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        """
        delays : bsz, num_heads_x_layers, tgt_len
        src_lens : bsz, 1
        target_lens : bsz, 1
        target_padding_mask: bsz, tgt_len or None
        """
        if delays.size(1) == 1:
            return delays.new_zeros([1])

        variance_delays = delays.var(dim=1)

        if target_padding_mask is not None:
            variance_delays.masked_fill_(target_padding_mask, 0)

        return variance_delays.sum(dim=1, keepdim=True) / tgt_lens


class LatencyInference(object):
    def __init__(self, start_from_zero=True):
        self.metric_calculator = {
            "differentiable_average_lagging": DifferentiableAverageLagging(),
            "average_lagging": AverageLagging(),
            "average_proportion": AverageProportion(),
        }

        self.start_from_zero = start_from_zero

    def __call__(self, monotonic_step, src_lens):
        """
        monotonic_step range from 0 to src_len. src_len means eos
        delays: bsz, tgt_len
        src_lens: bsz, 1
        """
        if not self.start_from_zero:
            monotonic_step -= 1

        src_lens = src_lens

        delays = (
            monotonic_step
            .view(monotonic_step.size(0), -1, monotonic_step.size(-1))
            .max(dim=1)[0]
        )

        delays = (
            delays.masked_fill(delays >= src_lens, 0)
            + (src_lens - 1)
            .expand_as(delays)
            .masked_fill(delays < src_lens, 0)
        )
        return_dict = {}
        for key, func in self.metric_calculator.items():
            return_dict[key] = func(
                delays.float(), src_lens.float(),
                target_padding_mask=None,
                batch_first=True,
                start_from_zero=True
            ).t()

        return return_dict


class LatencyTraining(object):
    def __init__(
        self, avg_weight, var_weight, avg_type, var_type,
        stay_on_last_token, average_method,
    ):
        self.avg_weight = avg_weight
        self.var_weight = var_weight
        self.avg_type = avg_type
        self.var_type = var_type
        self.stay_on_last_token = stay_on_last_token
        self.average_method = average_method

        self.metric_calculator = {
            "differentiable_average_lagging": DifferentiableAverageLagging(),
            "average_lagging": AverageLagging(),
            "average_proportion": AverageProportion(),
        }

        self.variance_calculator = {
            "variance_delay": VarianceDelay(),
        }

    def expected_delays_from_attention(
        self, attention, source_padding_mask=None, target_padding_mask=None
    ):
        if type(attention) == list:
            # bsz, num_heads, tgt_len, src_len
            bsz, num_heads, tgt_len, src_len = attention[0].size()
            attention = torch.cat(attention, dim=1)
            bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
            # bsz * num_heads * num_layers, tgt_len, src_len
            attention = attention.view(-1, tgt_len, src_len)
        else:
            # bsz * num_heads * num_layers, tgt_len, src_len
            bsz, tgt_len, src_len = attention.size()
            num_heads_x_layers = 1
            attention = attention.view(-1, tgt_len, src_len)

        if not self.stay_on_last_token:
            residual_attention = \
                1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
            attention = torch.cat(
                [attention[:, :, :-1], residual_attention],
                dim=2
            )

        # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
        steps = (
            torch
            .arange(1, 1 + src_len)
            .unsqueeze(0)
            .unsqueeze(1)
            .expand_as(attention)
            .type_as(attention)
        )

        if source_padding_mask is not None:
            src_offset = (
                source_padding_mask.type_as(attention)
                .sum(dim=1, keepdim=True)
                .expand(bsz, num_heads_x_layers)
                .contiguous()
                .view(-1, 1)
            )
            src_lens = src_len - src_offset
            if source_padding_mask[:, 0].any():
                # Pad left
                src_offset = src_offset.view(-1, 1, 1)
                steps = steps - src_offset
                steps = steps.masked_fill(steps <= 0, 0)
        else:
            src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
            src_lens = src_lens.view(-1, 1)

        # bsz * num_heads_num_layers, tgt_len, src_len
        expected_delays = (steps * attention).sum(dim=2).view(
            bsz, num_heads_x_layers, tgt_len
        )

        if target_padding_mask is not None:
            expected_delays.masked_fill_(
                target_padding_mask.unsqueeze(1),
                0
            )

        return expected_delays, src_lens

    def avg_loss(self, expected_delays, src_lens, target_padding_mask):

        bsz, num_heads_x_layers, tgt_len = expected_delays.size()
        target_padding_mask = (
            target_padding_mask
            .unsqueeze(1)
            .expand_as(expected_delays)
            .contiguous()
            .view(-1, tgt_len)
        )

        if self.average_method == "average":
            # bsz * tgt_len
            expected_delays = expected_delays.mean(dim=1)
        elif self.average_method == "weighted_average":
            weights = torch.nn.functional.softmax(expected_delays, dim=1)
            expected_delays = torch.sum(expected_delays * weights, dim=1)
        elif self.average_method == "max":
            # bsz * num_heads_x_num_layers, tgt_len
            expected_delays = expected_delays.max(dim=1)[0]
        else:
            raise RuntimeError(f"{self.average_method} is not supported")

        src_lens = src_lens.view(bsz, -1)[:, :1]
        target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]

        if self.avg_weight > 0.0:
            if self.avg_type in self.metric_calculator:
                average_delays = self.metric_calculator[self.avg_type](
                    expected_delays, src_lens, target_padding_mask,
                    batch_first=True, start_from_zero=False
                )
            else:
                raise RuntimeError(f"{self.avg_type} is not supported.")

            # bsz * num_heads_x_num_layers, 1
            return self.avg_weight * average_delays.sum()
        else:
            return 0.0

    def var_loss(self, expected_delays, src_lens, target_padding_mask):
        src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1]
        if self.var_weight > 0.0:
            if self.var_type in self.variance_calculator:
                variance_delays = self.variance_calculator[self.var_type](
                    expected_delays, src_lens, target_padding_mask,
                    batch_first=True, start_from_zero=False
                )
            else:
                raise RuntimeError(f"{self.var_type} is not supported.")

            return self.var_weight * variance_delays.sum()
        else:
            return 0.0

    def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
        expected_delays, src_lens = self.expected_delays_from_attention(
            attention, source_padding_mask, target_padding_mask
        )

        latency_loss = 0

        latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)

        latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)

        return latency_loss
