# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq.search import Search


class NoisyChannelBeamSearch(Search):
    """Beam search that ranks candidates by a combined forward/backward score.

    On every step the cumulative ``fw_lprobs`` score is combined with
    ``bw_lprobs`` via :meth:`combine_fw_bw`. The best ``2 * beam_size``
    candidates are kept, and the cumulative forward score and the
    ``lm_lprobs`` score at the same positions are gathered alongside them.
    """

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)
        # Lazily allocated on the first call to ``step`` (see ``_init_buffers``).
        self.fw_scores_buf = None
        self.lm_scores_buf = None

    def _init_buffers(self, t):
        """Allocate empty output buffers on ``t``'s device on first use.

        The base class's ``_init_buffers`` is not invoked. Every buffer that
        ``step`` writes into or reassigns is created here instead.

        Args:
            t: tensor whose dtype/device the score buffers should match.
        """
        if self.fw_scores_buf is None:
            # Score buffers follow ``t``'s dtype/device; index buffers are int64
            # because torch.topk writes int64 indices into ``indices_buf``.
            self.scores_buf = t.new()
            self.indices_buf = torch.empty(0, dtype=torch.long, device=t.device)
            self.beams_buf = torch.empty(0, dtype=torch.long, device=t.device)
            self.fw_scores_buf = t.new()
            self.lm_scores_buf = t.new()

    def combine_fw_bw(self, combine_method, fw_cum, bw, step):
        """Combine cumulative forward scores with backward scores.

        Args:
            combine_method: ``"noisy_channel"`` divides ``fw_cum`` by
                ``step + 1`` before adding ``bw``; ``"lm_only"`` adds
                ``fw_cum`` and ``bw`` unchanged.
            fw_cum: cumulative forward log-probabilities.
            bw: scores added to ``fw_cum`` (must broadcast with it).
            step: 0-based decoding step.

        Returns:
            The combined score tensor.

        Raises:
            ValueError: if ``combine_method`` is not one of the values above.
        """
        if combine_method == "noisy_channel":
            # ``step`` is 0-based, so ``step + 1`` is the number of steps taken
            # so far; this normalizes the accumulated forward score by it.
            return bw + fw_cum.div(step + 1)
        if combine_method == "lm_only":
            return bw + fw_cum
        raise ValueError(
            f"Unknown combine_method {combine_method!r}; "
            "expected 'noisy_channel' or 'lm_only'"
        )

    def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
        """Select the next candidates for every sentence in the batch.

        Args:
            step: 0-based decoding step.
            fw_lprobs: forward log-probabilities of shape
                ``(bsz, beam_size, vocab_size)`` (unpacked below).
            scores: running hypothesis scores, indexed as
                ``scores[:, :, step - 1]`` for the previous step.
            bw_lprobs: scores combined with the cumulative forward scores;
                sliced along dim 1 like ``fw_lprobs`` at step 0.
            lm_lprobs: scores gathered at the selected positions; assumed to
                be laid out like ``fw_lprobs`` (TODO confirm with caller).
            combine_method: forwarded to :meth:`combine_fw_bw`.

        Returns:
            Tuple ``(scores, fw_scores, lm_scores, indices, beams)``.
            Each element has shape ``(bsz, k)`` with
            ``k = min(2 * beam_size, candidates - 1)``:

            - ``scores``: the combined scores of the selected candidates.
            - ``fw_scores`` / ``lm_scores``: the matching cumulative forward
              and ``lm_lprobs`` values.
            - ``indices``: token indices within the vocabulary.
            - ``beams``: the source beam of each candidate.

            ``scores`` and ``indices`` are reused internal buffers that the
            next call overwrites in place.
        """
        self._init_buffers(fw_lprobs)
        bsz, beam_size, vocab_size = fw_lprobs.size()

        if step == 0:
            # At the first step all hypotheses are equally likely, so only the
            # first beam is considered.
            fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
            bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
            # Nothing to accumulate yet.
            fw_lprobs_cum = fw_lprobs
        else:
            # Add each hypothesis's running score from the previous step.
            prev_scores = scores[:, :, step - 1].unsqueeze(-1)
            fw_lprobs_cum = fw_lprobs.add(prev_scores)

        combined_lprobs = self.combine_fw_bw(
            combine_method, fw_lprobs_cum, bw_lprobs, step
        )
        flat_combined = combined_lprobs.view(bsz, -1)

        # Take the best 2 x beam_size predictions; the caller is expected to
        # continue with the first beam_size of these that don't predict eos.
        torch.topk(
            flat_combined,
            k=min(
                beam_size * 2,
                flat_combined.size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )

        # Gather the cumulative forward and lm scores of the chosen candidates.
        # At step 0 ``lm_lprobs`` is not sliced, but every selected flat index
        # is < vocab_size there, so only its first-beam entries are read.
        self.fw_scores_buf = torch.gather(
            fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf
        )
        self.lm_scores_buf = torch.gather(
            lm_lprobs.view(bsz, -1), 1, self.indices_buf
        )

        # Project flat (beam * vocab) indices back into beam and token indices.
        # The indices come from topk and are non-negative, so floor division
        # and fmod give the beam and the in-vocabulary position.
        self.beams_buf = self.indices_buf // vocab_size
        self.indices_buf.fmod_(vocab_size)
        return (
            self.scores_buf,
            self.fw_scores_buf,
            self.lm_scores_buf,
            self.indices_buf,
            self.beams_buf,
        )
