# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
    This module contains collection of classes which implement
    collate functionalities for various tasks.

    Collaters should know what data to expect for each sample
    and they should pack / collate them into batches
"""


from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np

import torch
from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
    """
        Implements collate function mainly for seq2seq tasks
        This expects each sample to contain feature (src_tokens) and
        targets.
        This collator is also used for aligned training task.
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor
        Args:
            frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
                length of i-th frame and f_dim is static dimension of features
        Returns:
            3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v

        return res

    def collate(self, samples):
        """
        utility function to collate samples into batch for speech recognition.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()
            elif isinstance(target, list):
                target = torch.LongTensor(target)

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        id = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        id = id.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": id,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
