# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import functools
import re
import os
import collections

from fairseq.data import FairseqDataset, data_utils, DynamicDataset, indexed_dataset
from fairseq import distributed_utils, utils


logger = logging.getLogger(__name__)


def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Merge a list of samples into a single mini-batch dictionary.

    Token sequences are padded with ``data_utils.collate_tokens`` and every
    per-sample entry of the batch is reordered by descending source length.

    Args:
        samples (List[dict]): samples to collate. Each sample provides "id"
            and "source" and may provide "target", "prev_output_tokens",
            "alignment", "constraints" and "meta". Optional keys are detected
            on the first sample only.
        pad_idx (int): index of the padding symbol
        eos_idx (int): index of the end-of-sentence symbol, forwarded to
            ``data_utils.collate_tokens``
        left_pad_source (bool, optional): pad sources on the left
            (default: True).
        left_pad_target (bool, optional): pad targets on the left
            (default: False).
        input_feeding (bool, optional): when the samples carry no
            "prev_output_tokens", build them from the targets by moving eos
            to the beginning (default: True).
        pad_to_length (dict, optional): dictionary with "source" (and
            "target" when targets are present) entries giving the length to
            pad to.
        pad_to_multiple (int, optional): forwarded to
            ``data_utils.collate_tokens`` (default: 1).

    Returns:
        dict: an empty dictionary if *samples* is empty, otherwise a batch
        with the keys "id", "nsentences", "ntokens", "net_input" (holding
        "src_tokens", "src_lengths", "meta" and, when available,
        "prev_output_tokens") and "target" (``None`` without targets), plus
        "alignments"/"align_weights" and "constraints" when the samples
        provide them.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        # pad the `key` entry of every sample into a single 2D tensor
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # an alignment is only kept if it is non-empty and all of its indices
        # stay below the last position of the source and target sentences
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        # the inverse indices map each alignment to the count of its target
        align_weights = align_tgt_c[align_tgt_i]
        return 1.0 / align_weights.float()

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        # only read pad_to_length["target"] when targets actually exist
        tgt_pad_to_length = (
            pad_to_length["target"] if pad_to_length is not None else None
        )
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=tgt_pad_to_length,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=tgt_pad_to_length,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }

    # assumes that all samples in this batch have the same src_lang and tgt_lang
    # (achieved by setting --batch-by-lang-pair or --batch-by-target-lang)
    first_meta = samples[0].get('meta', {})
    batch['net_input']['meta'] = {
        'src_lang': first_meta.get('src_lang'),
        'tgt_lang': first_meta.get('tgt_lang'),
        'corpus_tag': first_meta.get('corpus_tag'),
    }

    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        _, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # per-sentence offsets added to the (source, target) alignment indices:
        # the target column is shifted by the sentence's row in the batch and
        # both columns are shifted by the amount of left padding, if any
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        # the rows must follow the same order as the sorted source sentences
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch


class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): end-of-sentence index handed to the collater
            (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
             of the samples.
        pad_to_multiple (int, optional): forwarded to the collater
            (default: 1).
        src_lang (optional): source language, stored under 'src_lang' in the
            'meta' dictionary of every sample.
        tgt_lang (optional): target language, stored under 'tgt_lang' in the
            'meta' dictionary of every sample.
        corpus_tag (optional): corpus tag, stored under 'corpus_tag' in the
            'meta' dictionary of every sample.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
        src_lang=None,
        tgt_lang=None,
        corpus_tag=None,
    ):
        if tgt_dict is not None:
            # a single pad/eos/unk index is used for both sides when collating
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.corpus_tag = corpus_tag
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            # NOTE: np.int64 is used because the np.long alias does not exist
            # in NumPy 1.24-1.26
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        """Return the list of ``(None, num_tokens)`` buckets, or ``None`` when
        bucketing is disabled."""
        return self.buckets

    def __getitem__(self, index):
        """Return the sample at *index* as a dictionary with the keys "id",
        "source", "target" and "meta", plus "alignment" and "constraints"
        when the corresponding data was provided."""
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa.
        # Each step works on the current src_item/tgt_item so that the
        # options can be combined without undoing one another.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item is not None and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
            "meta": {"src_lang": self.src_lang, "tgt_lang": self.tgt_lang, "corpus_tag": self.corpus_tag},
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `meta` (dict): the `src_lang`, `tgt_lang` and `corpus_tag`
                    of the first sample of the batch
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            # one language ID per sentence, as a (bsz, 1) tensor matching src_tokens
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length; mergesort is stable,
            # so ties keep the order established by the previous step
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        """Whether the source dataset and the target dataset (if any) both
        report ``supports_prefetch``."""
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        """Forward the prefetch request to the source, target and alignment
        datasets that are present."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

    @classmethod
    def remove_empty_lines(cls, src_dataset, tgt_dataset):
        """Delete, in place, every position at which the source line or the
        target line is empty.

        Both arguments must be ``indexed_dataset.IndexedRawTextDataset``
        instances.
        """
        assert isinstance(src_dataset, indexed_dataset.IndexedRawTextDataset) and \
               isinstance(tgt_dataset, indexed_dataset.IndexedRawTextDataset)

        # walk backwards so that deleting an item does not shift the
        # positions that are still to be visited
        for i in range(len(src_dataset) - 1, -1, -1):
            if not src_dataset.lines[i] or not tgt_dataset.lines[i]:
                src_dataset.delete_item(i)
                tgt_dataset.delete_item(i)


class DynamicLanguagePairDataset(DynamicDataset):
    """
    DynamicDataset specialization for TranslationTask (described by the
    original author as reading batches dynamically from a TCP socket).

    Positional arguments are passed through to ``DynamicDataset``; keyword
    arguments are bound to :func:`collate`, and the resulting callable is
    appended as the last positional argument of ``DynamicDataset``.
    """

    def __init__(self, *args, **collater_options):
        # pre-bind the collate options so that the parent only has to supply the samples
        super().__init__(*args, functools.partial(collate, **collater_options))


class ParallelReader:
    """
    Creates a parallel line generator over several parallel corpora. Picks a corpus at random (with some probability
    that depends on each corpus size), then reads the next line from this corpus. When reaching the end of a corpus,
    start from the beginning. Stop iterating once N line pairs have been produced, where N is the epoch length: by
    default the sum of all (multiplier-adjusted) corpus sizes.

    Warning: because it is sequential, it is vital that the corpora be shuffled before using this reader.

    Args:
        args: options object; the attributes read here are `dynamic_dataset_block_size`, `skip_empty_lines`,
            `dynamic_dataset_verbose`, `lines_per_epoch`, `scale_epoch_length` and `lang_temperature`
        corpora (List[ParallelCorpus]): list of parallel corpora; each one must provide `exists()`, `src_path`,
            `tgt_path`, `src_lang`, `tgt_lang`, `multiplier`, `corpus_id` and `meta`

    How to use:
        ```
        reader = ParallelReader(args, corpora)
        # do one epoch over reader
        for (src_line, tgt_line), meta in reader:
            # do stuff
        ```
    """
    def __init__(self, args, corpora):
        self.args = args
        self.corpora = corpora
        self.corpus_ids = np.arange(len(self.corpora))

        assert all(corpus.exists() for corpus in self.corpora), 'error: some training files do not exist or are empty'

        class_name = self.__class__.__name__

        cache = {}   # avoid reading the same files over and over to find the line positions
        self.files = []
        for corpus in self.corpora:
            # resolved paths are used as cache keys, so that a file reached through
            # different paths is only scanned once
            src_path = os.path.realpath(corpus.src_path)
            tgt_path = os.path.realpath(corpus.tgt_path)

            if args.dynamic_dataset_block_size < 1:
                # no block shuffling: read the corpus sequentially
                file_ = InfiniteParallelFileIterator(
                    corpus.src_path,
                    corpus.tgt_path,
                    args.skip_empty_lines,
                    src_positions=cache.get(src_path),
                    tgt_positions=cache.get(tgt_path),
                )
            else:
                file_ = RandomInfiniteParallelFileIterator(
                    corpus.src_path,
                    corpus.tgt_path,
                    args.skip_empty_lines,
                    args.dynamic_dataset_block_size,
                    src_positions=cache.get(src_path),
                    tgt_positions=cache.get(tgt_path),
                )

            cache[src_path] = file_.src_positions
            cache[tgt_path] = file_.tgt_positions

            self.files.append(file_)
            if args.dynamic_dataset_verbose:
                logger.info(f'{class_name} | {corpus.corpus_id} | lines {file_.size}')

        self.sizes = self.get_sizes()  # real corpus size in number of lines
        # adjust each corpus "size" depending on its multiplier
        self.sizes = (
            self.sizes * np.array(
                [1 if corpus.multiplier is None else corpus.multiplier for corpus in self.corpora]
            )
        ).astype(np.int64)
        self.total_lines = self.sizes.sum()

        # by default the length of one epoch is the total number of lines in all corpora (adjusted with sampling multipliers),
        # but one can manually set the epoch length as a fixed number of line pairs (--lines-per-epoch), or as a multiple of
        # the true epoch length (--scale-epoch-length, can be real numbers lower or higher than 1)
        self.lines_per_epoch = args.lines_per_epoch or self.total_lines
        if args.scale_epoch_length:
            self.lines_per_epoch = max(1, int(self.lines_per_epoch * args.scale_epoch_length))

        logger.info(
            f'{class_name} | total lines {self.total_lines} | per epoch {self.lines_per_epoch} '
            f'({100 * self.lines_per_epoch / self.total_lines:.1f}%)')

        lines_per_lang = collections.defaultdict(int)
        for size, corpus in zip(self.sizes, self.corpora):
            lines_per_lang[(corpus.src_lang, corpus.tgt_lang)] += size

        # compute probability to sample lines from each corpus
        if args.lang_temperature and args.lang_temperature != 1:
            # the temperature parameter only applies to different language pairs, not corpora
            lang_prob_sum = np.array([
                (line_count / self.total_lines) ** (1 / args.lang_temperature)
                for line_count in lines_per_lang.values()
            ]).sum()
            lines_per_lang_ = np.array([lines_per_lang[(corpus.src_lang, corpus.tgt_lang)] for corpus in self.corpora])
            # compute the probability of each lang pair and corpus independently
            prob_per_lang = lines_per_lang_ / self.total_lines
            # re-normalize the probability of each language pair with temperature
            prob_per_lang = prob_per_lang ** (1 / args.lang_temperature) / lang_prob_sum
            # a language pair without any line (e.g., multiplier 0) would give 0 / 0 = NaN here and
            # make the sampling fail: such corpora get a probability of zero instead
            prob_per_corpus = np.divide(
                self.sizes,
                lines_per_lang_,
                out=np.zeros(len(self.sizes), dtype=np.float64),
                where=lines_per_lang_ > 0,
            )
            self.probs = prob_per_corpus * prob_per_lang
        else:
            self.probs = self.sizes / self.total_lines

        if len(lines_per_lang) > 1:
            for (src_lang, tgt_lang), size in lines_per_lang.items():
                prob = sum(
                    p for p, corpus in zip(self.probs, self.corpora)
                    if corpus.src_lang == src_lang and corpus.tgt_lang == tgt_lang
                )
                logger.info(
                    f'{class_name} | {src_lang}-{tgt_lang} | prob {prob:.3f} | lines {size} '
                    f'({100 * size / self.total_lines:.1f}%)')

        for corpus, prob, size in zip(self.corpora, self.probs, self.sizes):
            logger.info(
                f'{class_name} | {corpus.corpus_id} | prob {prob:.3f} | lines {size} '
                f'({100 * size / self.total_lines:.1f}%)')

    def get_sizes(self):
        """
        Count lines in each corpus: necessary to compute their sampling probability.
        """
        return np.array([file_.size for file_ in self.files])

    def get_random_corpus(self):
        """
        Sample a corpus according to `self.probs` and return its index along with the corpus itself.
        """
        corpus_id = np.random.choice(self.corpus_ids, p=self.probs)
        return corpus_id, self.corpora[corpus_id]

    def __len__(self):
        """
        Number of line pairs produced by one epoch.
        """
        return self.lines_per_epoch

    def __iter__(self):
        """
        Returns an iterator that yields `((src_line, tgt_line), meta)` tuples for approximately one 'epoch', where
        `meta` is the metadata of the sampled corpus extended with a 'line_id' entry.
        """
        for index in range(len(self)):  # stop when N lines have been produced, N is the epoch length
            # pick a corpus at random
            corpus_id, corpus = self.get_random_corpus()
            src_line, tgt_line = next(self.files[corpus_id])
            # FIXME: line_id DOES NOT correspond to the position of this line pair in corpus, should we also save
            # this info in the metadata?
            yield (src_line, tgt_line), {**corpus.meta, 'line_id': index}

    def read_line(self, src_file, tgt_file):
        """
        Read the next line of two binary file iterators and return both, stripped and decoded.
        """
        src_line = next(src_file).strip().decode()
        tgt_line = next(tgt_file).strip().decode()
        return src_line, tgt_line


class InfiniteParallelFileIterator:
    """
    Endless sequential iterator over the line pairs of two parallel files.

    Both files are read in lockstep; once the end is reached, reading starts over from the beginning. Lines are
    decoded and stripped of surrounding whitespace.

    Args:
        src_path: path to the source file
        tgt_path: path to the target file
        skip_empty_lines (bool, optional): skip the pairs whose source or target line is empty once stripped
        src_positions: accepted for signature compatibility with RandomInfiniteParallelFileIterator, unused
        tgt_positions: accepted for signature compatibility with RandomInfiniteParallelFileIterator, unused

    Attributes:
        size (int): number of line pairs produced by one pass over the files
        src_positions, tgt_positions: always None, this iterator does not index line positions
    """
    def __init__(self, src_path, tgt_path, skip_empty_lines=False, src_positions=None, tgt_positions=None):
        self.skip_empty_lines = skip_empty_lines
        self.src_positions = self.tgt_positions = None  # TODO
        self.src_file = open(src_path, 'rb')
        try:
            self.tgt_file = open(tgt_path, 'rb')
        except Exception:
            # do not leave the source file open when the target cannot be opened
            self.close()
            raise

        # first pass over the files to count the line pairs that will be produced
        size = 0
        for src_line, tgt_line in zip(self.src_file, self.tgt_file):
            src_line = src_line.decode().strip()
            tgt_line = tgt_line.decode().strip()
            if not self.skip_empty_lines or (src_line and tgt_line):
                size += 1
        self.src_file.seek(0)
        self.tgt_file.seek(0)
        self.size = size

    def close(self):
        """Close the files that were opened (safe on partially initialized instances)."""
        for name in ('src_file', 'tgt_file'):
            file_ = getattr(self, name, None)
            if file_ is not None:
                file_.close()

    def __next__(self):
        """
        Return the next `(src_line, tgt_line)` pair, restarting from the beginning of the files when needed.

        Raises:
            StopIteration: if the files do not contain any usable line pair
        """
        if self.size == 0:
            # nothing will ever be produced: stop instead of searching forever
            raise StopIteration

        # a loop (rather than recursion) keeps long runs of skipped lines from
        # exhausting the recursion limit
        while True:
            try:
                src_line = next(self.src_file).decode().strip()
                tgt_line = next(self.tgt_file).decode().strip()
            except StopIteration:
                # one of the files is exhausted: start over
                self.src_file.seek(0)
                self.tgt_file.seek(0)
                continue
            if not self.skip_empty_lines or (src_line and tgt_line):
                return src_line, tgt_line

    def __iter__(self):
        if self.size == 0:
            return
        while True:
            yield next(self)

    def __del__(self):
        self.close()


class RandomInfiniteParallelFileIterator:
    """
    Endless iterator over the line pairs of two parallel files, read block by block in random order.

    The usable line pairs are grouped into consecutive blocks of `block_size` pairs. At each pass the blocks are
    visited in a new random order, while the pairs inside a block keep their file order. Lines are stripped of
    surrounding whitespace and decoded.

    Args:
        src_path: path to the source file
        tgt_path: path to the target file
        skip_empty_lines (bool, optional): skip the pairs whose source or target line is empty once stripped
        block_size (int, optional): number of line pairs per block, at least 1 (default: 256)
        src_positions (optional): line positions of the source file as returned by `get_positions`, computed
            when missing
        tgt_positions (optional): line positions of the target file as returned by `get_positions`, computed
            when missing

    Attributes:
        size (int): number of line pairs produced by one pass over the files
        blocks (np.ndarray): (source offset, target offset) of the first pair of each block

    Raises:
        ValueError: if `block_size` is lower than 1
    """
    def __init__(self, src_path, tgt_path, skip_empty_lines=False, block_size=256, src_positions=None, tgt_positions=None):
        if block_size < 1:
            # with an empty block, iterating would loop forever without producing anything
            raise ValueError(f'block_size must be at least 1, got {block_size}')

        self.skip_empty_lines = skip_empty_lines
        self.block_size = block_size
        self.src_file = open(src_path, 'rb')
        try:
            self.tgt_file = open(tgt_path, 'rb')
        except Exception:
            # do not leave the source file open when the target cannot be opened
            self.close()
            raise

        self.src_positions = self.get_positions(self.src_file) if src_positions is None else src_positions
        self.tgt_positions = self.get_positions(self.tgt_file) if tgt_positions is None else tgt_positions

        blocks = []
        size = 0
        for src_pos, tgt_pos in zip(self.src_positions, self.tgt_positions):
            if src_pos == -1 or tgt_pos == -1:  # skipped line pair
                continue
            if size % block_size == 0:
                # this pair opens a new block: remember where it starts in both files
                blocks.append((src_pos, tgt_pos))
            size += 1

        self.size = size
        self.blocks = np.array(blocks, dtype=np.int64).reshape(-1, 2)
        self._iter = iter(self)

    def get_positions(self, file_):
        """
        Return the byte offset of each line of `file_`, with -1 for the lines that have to be skipped (empty lines
        when `skip_empty_lines` is set). The file is rewound afterwards.
        """
        file_.seek(0)
        pos = 0
        positions = []
        for line in file_:
            if not self.skip_empty_lines or line.strip():
                positions.append(pos)
            else:
                positions.append(-1)
            pos = file_.tell()
        file_.seek(0)
        return np.array(positions, dtype=np.int64)

    def close(self):
        """Close the files that were opened (safe on partially initialized instances)."""
        for name in ('src_file', 'tgt_file'):
            file_ = getattr(self, name, None)
            if file_ is not None:
                file_.close()

    def __len__(self):
        return self.size

    def __next__(self):
        return next(self._iter)

    def __iter__(self):
        if len(self.blocks) == 0:
            # nothing to read: stop instead of spinning forever
            return

        while True:
            blocks = np.random.permutation(self.blocks)

            for src_pos, tgt_pos in blocks:
                self.src_file.seek(src_pos)
                self.tgt_file.seek(tgt_pos)
                block_size = 0
                while block_size < self.block_size:
                    try:
                        src_line = next(self.src_file).strip()
                        tgt_line = next(self.tgt_file).strip()
                    except StopIteration:  # last block may be shorter
                        break

                    if not self.skip_empty_lines or (src_line and tgt_line):
                        yield src_line.decode(), tgt_line.decode()
                        block_size += 1

    def __del__(self):
        self.close()
