# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import logging.handlers
import os
import pickle
import queue
import random
import socket
import socketserver
import struct
import threading
import time
import traceback
from collections import defaultdict

import numpy as np
import torch.utils.data
from torch import multiprocessing

from fairseq.data import data_utils
from fairseq import utils, distributed_utils


logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for objects that want to be told when a new epoch starts."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Tell whether one :class:`fairseq.data.EpochBatchIterator` may serve
        this dataset for several epochs in a row.

        Return ``False`` when sample sizes may differ from one epoch to the
        next, since batches may then have to be rebuilt at every epoch.
        Datasets whose content depends on ``set_epoch`` should consider
        returning ``False`` as well.
        """
        return True

    def set_epoch(self, epoch):
        """Hook called with the new epoch number when an epoch begins; no-op by default."""
        return None


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """Map-style dataset base class exposing the hooks used to build batches."""

    def __getitem__(self, index):
        """Return the sample at *index*; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples; must be implemented by subclasses."""
        raise NotImplementedError

    def collater(self, samples):
        """Combine several samples into one mini-batch.

        Args:
            samples (List[dict]): the samples to merge

        Returns:
            dict: a mini-batch that can be forwarded through a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return how many tokens the sample at *index* holds. Batching relies
        on this value to enforce ``--max-tokens``."""
        raise NotImplementedError

    def size(self, index):
        """Return the size of an example, as a float or a tuple. Filtering with
        ``--max-positions`` relies on this value."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return the sample indices in the order in which batches should be
        built (by default: ``0 .. len(self) - 1``)."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether prefetching is available for this dataset."""
        return False

    def attr(self, attr: str, index: int):
        """Look up attribute *attr* on the dataset, or ``None`` when absent.

        *index* is accepted but not used by this default implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Load ahead of time the data needed for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return the list of allowed batch shapes, for instance::

            [(8, 512), (16, 256), (32, 128)]

        In every tuple the first entry is the batch size; it may be ``None``,
        in which case the largest batch size allowed by ``--max-tokens`` is
        inferred. The second entry is the maximum supported length, as
        measured by :func:`fairseq.data.FairseqDataset.num_tokens`.

        :func:`fairseq.data.FairseqDataset.batch_by_size` uses the result to
        restrict batch shapes, which helps on TPUs by limiting the number of
        dynamic shapes (and thus recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Split an ordered collection of indices into batches that honour
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        shapes = self.get_batch_shapes()
        if shapes is not None:

            def resolve_bsz(bsz, length):
                # A missing batch size is derived from the token budget.
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // length
                if max_sentences is not None:
                    return min(bsz, max_sentences)
                if (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    # Round down to the closest multiple.
                    return bsz - bsz % required_batch_size_multiple
                return bsz

            shapes = np.array(
                [[resolve_bsz(bsz, length), length] for bsz, length in shapes]
            )

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Drop from *indices* the samples that exceed *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """

        def split(size_table):
            # Vectorised path: compare the precomputed sizes with the limit.
            removed = indices[size_table[indices] > max_sizes].tolist()
            return indices[size_table[indices] <= max_sizes], removed

        if isinstance(max_sizes, (float, int)) and hasattr(self, "sizes"):
            if isinstance(self.sizes, np.ndarray):
                return split(self.sizes)
            if isinstance(self.sizes, list) and len(self.sizes) == 1:
                return split(self.sizes[0])

        # Generic path: query self.size for each index.
        kept, removed = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return kept, removed

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether samples may be fetched outside of the dataloader workers."""
        return True


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that have to be consumed sequentially, typically
    because the data is streamed or cannot otherwise be manipulated on a
    single machine.
    """

    def __iter__(self):
        """Iterate over the dataset; must be implemented by subclasses."""
        raise NotImplementedError


class DynamicDataset(FairseqIterableDataset):
    """
    Lazy dataset that reads batches from a TCP socket. The batches are usually sent by DynamicDatasetServer.
    Each task can use this DynamicDataset with its own task-dependent "collate_fn"

    Args:
        port (int): local TCP port on which the batch server listens
        device_id (int): identifier sent with every request (the server uses it to pick a batch queue)
        collate_fn (callable): applied by :func:`collater` to the batch received from the server
    """
    # Upper bound on the number of bytes requested from the socket in one call.
    _RECV_CHUNK_SIZE = 65536

    def __init__(self, port, device_id, collate_fn):
        self.port = port
        self.device_id = device_id
        self.epoch = 0
        self._len = None  # cached answer of the size request, see get_size()
        self.collate_fn = collate_fn

    def set_epoch(self, epoch):
        """Remember the current epoch; it is sent along with every batch request."""
        self.epoch = epoch

    def get_size(self):
        """Return the size reported by the server (requested once with epoch ``-1``, then cached)."""
        # not using __len__, as DataLoader would rely on it, and it is only approximate
        if self._len is None:
            self._len = self._read({'epoch': -1, 'device_id': self.device_id})
        return self._len

    @classmethod
    def _recv_exact(cls, sock, size):
        """
        Read exactly *size* bytes from *sock*.

        ``socket.recv`` may return fewer bytes than requested, and returns ``b''`` once the peer
        has closed the connection, hence the loop and the explicit end-of-stream check.

        Raises:
            ConnectionResetError: if the connection is closed before *size* bytes were received
        """
        data = bytearray()
        while len(data) < size:
            chunk = sock.recv(min(cls._RECV_CHUNK_SIZE, size - len(data)))
            if not chunk:
                raise ConnectionResetError(
                    f'connection closed after {len(data)} of {size} expected bytes'
                )
            data.extend(chunk)
        return data

    def _read(self, input_):
        """
        Send *input_* as one JSON line to the server and return the un-pickled answer.

        The answer is framed as an 8-byte length (``struct`` format ``'Q'``) followed by
        that many bytes of pickled payload.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.connect(('127.0.0.1', self.port))
            sock.sendall(bytes(json.dumps(input_) + '\n', 'utf-8'))
            size, = struct.unpack('Q', self._recv_exact(sock, 8))
            payload = self._recv_exact(sock, size)
        # NOTE(security): pickle.loads executes arbitrary code from the payload; this is only
        # acceptable because the peer is expected to be the trusted local server.
        return pickle.loads(payload)

    def __iter__(self):
        """
        Yield the batches of the current epoch until the server answers ``None``.

        Connection errors are retried (after a 30 seconds pause); the error is re-raised at the
        third consecutive failure.
        """
        attempts = 0
        while True:
            try:
                batch = self._read({'epoch': self.epoch, 'device_id': self.device_id})
            except (ConnectionResetError, ConnectionRefusedError):
                attempts += 1
                if attempts >= 3:
                    raise
                logger.info(traceback.format_exc())
                time.sleep(30)
                continue

            attempts = 0  # only consecutive failures count

            if batch is None:
                break

            # recursively convert every numpy array in this dict to torch tensors
            yield utils.apply_to_sample(
                lambda x: torch.tensor(x).long(),
                batch,
                numpy=True
            )

    def collater(self, samples):
        """Apply ``collate_fn`` to the first element of *samples* (the other elements are ignored)."""
        return self.collate_fn(samples[0])


class MonolingualReader:
    """
    Creates a line generator over several corpora. Picks a corpus at random (with some probability
    that depends on each corpus size), then reads the next line from this corpus. When reaching the end of a corpus,
    start from the beginning. Stop iterating once N lines have been produced, where N is the sum of all corpus
    sizes.

    Warning: because it is sequential, it is vital that the corpora be shuffled before using this reader.

    Args:
        args: currently unused by this class
        paths (List[str]): list of corpus paths
        queues: currently unused by this class

    How to use:
        ```
        reader = MonolingualReader(args, paths)
        # do one epoch over reader
        for line, meta in reader:
            # do stuff
        ```
    """
    def __init__(self, args, paths, queues=None):
        assert all(os.path.exists(path) for path in paths), 'error: some training files do not exist or are empty'

        self.paths = paths
        self.corpus_ids = np.arange(len(self.paths))
        # binary mode: lines are decoded one at a time in read_line()
        self.files = [open(path, 'rb') for path in self.paths]
        self.sizes = self.get_sizes().astype(np.int64)
        self.total_lines = int(self.sizes.sum())
        # each corpus is sampled proportionally to its number of lines
        self.probs = self.sizes / self.total_lines

        class_name = self.__class__.__name__
        logger.info(f'{class_name} | total lines {self.total_lines}')

    def get_sizes(self):
        """
        Count lines in each corpus: necessary to compute their sampling probability.

        Every file is rewound afterwards, so that reading can start from its first line.
        """
        sizes = []
        for file in self.files:
            sizes.append(sum(1 for _ in file))
            file.seek(0)
        return np.array(sizes)

    def get_random_corpus(self):
        """Sample a corpus according to ``self.probs``; return its id and its open file."""
        corpus_id = np.random.choice(self.corpus_ids, p=self.probs)
        return corpus_id, self.files[corpus_id]

    def __len__(self):
        return self.total_lines

    def __iter__(self):
        """
        Returns an iterator that yields ``(line, meta)`` tuples for approximately one 'epoch',
        where ``meta`` is a dict holding the position of the line in this epoch (``'line_id'``).
        """
        for index in range(len(self)):  # stop when N lines have been produced, N is the total number of lines.
            # pick a corpus at random
            corpus_id, _ = self.get_random_corpus()
            file = self.files[corpus_id]
            while True:
                try:
                    # get next line from this corpus
                    line = self.read_line(file)
                    break
                except StopIteration:   # reached the end of this particular corpus, start from beginning
                    file.seek(0)
            yield line, {'line_id': index}

    def read_line(self, file):
        """Return the next line of *file*, stripped and decoded (raises StopIteration at end of file)."""
        return next(file).strip().decode()

    def __del__(self):
        # 'files' is missing when __init__ failed before opening the corpora
        for f in getattr(self, 'files', ()):
            f.close()


class DynamicDatasetServer:
    """
    Manages a number of processes that read lines from corpora, pre-process them and store them
    in batches.

    1. One reader process reads lines from the corpora and stores them in the reader queues.
    2. N worker processes (N = args.dynamic_dataset_workers) get line pairs from these queues,
    pre-process (e.g., tokenization and BPE), filter them by length and binarize them, and put them in the worker queues.
    3. One batcher process gets samples from the worker queues, by groups of K (K = args.dynamic_dataset_buffer),
    sorts them by length and puts them in batches (while respecting the batch size constraints: args.max_tokens and
    args.max_sentences). Then it puts these batches in the batcher queues (one queue per GPU).
    4. One server process gets batches from these batcher queues and serves them on a TCP socket to DynamicDataset.

    All queues have a maximum size, so as to not saturate memory.

    Args:
        args: options; this class reads ``distributed_world_size``, ``distributed_num_procs``,
            ``distributed_rank``, ``max_epoch`` and ``seed``
        task: object providing the ``reader``, ``worker`` and ``batcher`` callables
        port (int): TCP port on which batches are served
        num_workers (int): number of worker processes (at least 1 is used)
        buffer_size (int): number of samples gathered before batching (at least 4 is used)
    """
    def __init__(
        self, args, task, port, num_workers, buffer_size, **kwargs
    ):
        self.args = args
        self._reader = task.reader        # provides an iterator over lines
        self._worker = task.worker        # pre-processes and binarizes its text input (a line or tuple of lines)
        self._batcher = task.batcher      # sorts samples in a buffer list by length and batches them

        self.port = port
        self.num_workers = max(1, num_workers)
        self.buffer_size = max(4, buffer_size)

        self.num_nodes = args.distributed_world_size // args.distributed_num_procs
        self.node_id = args.distributed_rank // args.distributed_num_procs

    def start(self):
        """
        Create the queues and synchronization primitives, then start every process and the logging thread.

        Returns:
            list: the started reader, worker, batcher and server processes, followed by the logging thread
        """
        self.logging_queue = multiprocessing.Queue()
        self.reader_queues = [multiprocessing.Queue(100000) for _ in range(self.num_workers)]
        self.worker_queues = [multiprocessing.Queue(100000) for _ in range(self.num_workers)]

        # batches take a lot of memory, max size should be chosen carefully
        self.batcher_queues = [multiprocessing.Queue(maxsize=10000) for _ in range(self.args.distributed_num_procs)]
        self.starting_epoch = multiprocessing.Value('i', 0)
        self.epoch = [multiprocessing.Value('i', 0) for _ in range(self.args.distributed_num_procs)]

        self.start_training_event = multiprocessing.Event()
        self.end_of_training_event = multiprocessing.Event()
        # the workers and the batcher all meet at this barrier at the end of each epoch
        self.worker_barrier = multiprocessing.Barrier(self.num_workers + 1)

        # 64-bit counter: a 32-bit 'i' value silently wraps around for corpora above 2**31 - 1 lines
        self.total_lines = multiprocessing.Value('q', 0)

        workers = []
        reader = multiprocessing.Process(target=self.reader)
        workers.append(reader)

        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(target=self.worker, args=(worker_id,))
            workers.append(worker)

        batcher = multiprocessing.Process(target=self.batcher)
        workers.append(batcher)

        server = multiprocessing.Process(target=self.server)
        workers.append(server)

        # the logging thread monitors the processes created so far (hence the copy of the list)
        logger_ = threading.Thread(target=self.logger, args=(list(workers),))
        workers.append(logger_)

        for worker in workers:
            worker.daemon = True
            worker.start()

        return workers

    def reset_logger(self):
        """Replace the handlers of the root logger by one that writes records to the logging queue."""
        # Each process has its own logger that writes to a queue
        # The DynamicDatasetServer.logger thread reads from this queue and writes to the main process' logger
        handler = logging.handlers.QueueHandler(self.logging_queue)
        logging.getLogger().handlers = [handler]

    def reader(self):
        """
        Reads line pairs from the provided corpora and stores them in the reader queues.
        """
        self.reset_logger()
        line_reader = self._reader()  # can take some time as it reads the corpora to count their lines
        self.total_lines.value = len(line_reader)

        self.start_training_event.wait()
        epoch = self.starting_epoch.value

        while (self.args.max_epoch <= 0 or epoch <= self.args.max_epoch) and not self.end_of_training_event.is_set():
            np.random.seed(self.args.seed + epoch)
            worker_id = 0

            for line, meta in line_reader:
                # TODO: more efficient iteration method when doing distributed training
                # each node only keeps the lines whose id maps to its own node id
                if meta['line_id'] % self.num_nodes != self.node_id:
                    continue

                # lines are distributed to the workers in a round-robin fashion
                self.reader_queues[worker_id].put((line, meta))
                worker_id = (worker_id + 1) % self.num_workers

            # notify the workers that this epoch is finished
            for worker_id in range(self.num_workers):
                self.reader_queues[worker_id].put(None)
            epoch += 1

        self.end_of_training_event.wait()
        for reader_queue in self.reader_queues:
            reader_queue.cancel_join_thread()
            reader_queue.close()

    def worker(self, worker_id):
        """
        Gets lines from its reader queue, pre-processes them with the task's worker function and puts
        the resulting samples (tagged with their epoch) in its worker queue.

        Args:
            worker_id (int): index of the reader and worker queues used by this process
        """
        self.reset_logger()

        self.start_training_event.wait()
        epoch = self.starting_epoch.value

        np.random.seed(self.args.seed + worker_id)
        reader_queue = self.reader_queues[worker_id]
        worker_queue = self.worker_queues[worker_id]

        while True:
            data = reader_queue.get()

            if data is None:  # this means that this epoch is finished
                # notify the batcher
                worker_queue.put((epoch, None))
                epoch += 1
                # wait for the other workers
                self.worker_barrier.wait()

                if epoch > self.args.max_epoch > 0 or self.end_of_training_event.is_set():
                    break
                else:
                    continue

            line, meta = data
            # pre-process the line (e.g., with tokenization + BPE + binarization)
            sample, ok = self._worker(line, meta, worker_id=worker_id)

            if not ok:   # e.g., sample is too long
                continue

            worker_queue.put((epoch, sample))

        self.end_of_training_event.wait()
        worker_queue.cancel_join_thread()
        worker_queue.close()

    def _flush_log_records(self):
        """Forward the records waiting in the logging queue to this process' loggers."""
        while True:
            try:
                record = self.logging_queue.get_nowait()
            except queue.Empty:
                return
            if record is None:
                return
            logging.getLogger(record.name).handle(record)

    def logger(self, workers):
        """
        Periodically logs statistics about the processes and their queues.

        For example:

        "workers alive 1/1/1/1/1 | lines 98319/97602 | samples 1666/0 | batches 9804/9788"
        means that all processes are alive (1 reader, 2 workers, 1 batcher and 1 server), there are
        98319 and 97602 elements in the reader queues, 1666 and 0 elements in the worker queues,
        and 9804 and 9788 elements in the batcher queue.

        This information can be used to debug DynamicDatasetServer and find hyper-parameters to improve its speed.
        For instance, if 'samples' and 'batches' are close to zero, it probably means that the workers are slower
        than Fairseq's training code, and that they are a bottleneck. Increasing '--dynamic-dataset-workers' might
        be necessary, or even optimizing the pre-processing code (e.g., BPE implementation).

        Args:
            workers (list): the processes whose liveness is reported
        """
        self.end_of_training_event.wait(10)
        while not self.end_of_training_event.is_set():
            self._flush_log_records()

            alive = '/'.join(str(int(worker.is_alive())) for worker in workers)
            lines = '/'.join(str(q.qsize()) for q in self.reader_queues)
            samples = '/'.join(str(q.qsize()) for q in self.worker_queues)
            batches = '/'.join(str(q.qsize()) for q in self.batcher_queues)
            epoch = '/'.join(str(epoch.value) for epoch in self.epoch)
            logger.info(f'workers alive {alive} | lines {lines} | samples {samples} | batches {batches} | epoch {epoch}')
            self.end_of_training_event.wait(60)

        # forward the records that were queued right before the end of training
        self._flush_log_records()

    def batcher(self):
        """
        Single process that gets binarized samples from the worker queues by groups of '--dynamic-dataset-buffer' size,
        sorts them by length, stores them in batches and writes these batches to the batcher queues (one queue per GPU).
        """
        self.reset_logger()

        self.start_training_event.wait()
        epoch = self.starting_epoch.value

        np.random.seed(self.args.seed + epoch)
        worker_id = 0
        batcher_id = 0
        workers_finished = set()
        buffer, next_buffer = [], []  # to store data from the current and next epoch

        last_key = None

        def get_key(batch):
            # optional 'key' stored in the meta information of the first sample of the batch
            return batch[0].get('meta', {}).get('key')

        while True:
            if len(buffer) >= self.buffer_size or len(workers_finished) >= self.num_workers:

                batches = self._batcher(buffer)
                np.random.shuffle(batches)

                while batches:
                    # If the 'key' info is present, GPUs should receive batches with the same key when possible,
                    # for efficiency reasons.
                    offset = 0  # position counted from the end of the list; 0 means "take the last batch"
                    if batcher_id != 0 and last_key is not None:
                        offset = next(
                            (k for k, candidate in enumerate(reversed(batches)) if get_key(candidate) == last_key),
                            0
                        )
                    batch = batches.pop(len(batches) - 1 - offset)
                    last_key = get_key(batch)

                    # then put these batches in the batcher queues
                    self.batcher_queues[batcher_id].put((epoch, batch))
                    batcher_id = (batcher_id + 1) % len(self.batcher_queues)

                buffer = []
                if len(workers_finished) >= self.num_workers:
                    # this epoch is finished
                    buffer = next_buffer
                    next_buffer = []
                    for batcher_queue in self.batcher_queues:
                        batcher_queue.put((epoch, None))  # signal the next epoch to the server
                    self.worker_barrier.wait()
                    epoch += 1
                    batcher_id = 0
                    np.random.seed(self.args.seed + epoch)
                    if epoch > self.args.max_epoch > 0 or self.end_of_training_event.is_set():
                        break
                    else:
                        workers_finished = set()
                        continue
            elif worker_id in workers_finished:
                worker_id = (worker_id + 1) % self.num_workers
            else:
                worker_epoch, sample = self.worker_queues[worker_id].get()

                if sample is None:  # this particular worker has finished its epoch
                    workers_finished.add(worker_id)
                elif worker_epoch != epoch:
                    # this sample is from the next epoch
                    next_buffer.append(sample)
                else:
                    buffer.append(sample)

                worker_id = (worker_id + 1) % self.num_workers

        self.end_of_training_event.wait()
        for batcher_queue in self.batcher_queues:
            batcher_queue.cancel_join_thread()
            batcher_queue.close()

    def server(self):
        """
        Gets batches from the batcher queue and sends them to DynamicDataset through a TCP socket.

        Every answer is framed as an 8-byte length (``struct`` format ``'Q'``) followed by the pickled object.
        """
        class RequestHandler(socketserver.StreamRequestHandler):
            # 'self_' is the request handler, 'self' is the enclosing DynamicDatasetServer
            def handle(self_):
                # the client (DynamicDataset) sends a JSON file with the current epoch and device id
                data = self_.rfile.readline().decode().strip()
                if not data:
                    return
                data = json.loads(data)

                client_epoch = data['epoch']
                device_id = data['device_id']  # ID of the queue from which to read batches

                if client_epoch > 0 and self.starting_epoch.value == 0:
                    # The fairseq client gives DynamicDataset the current epoch so that
                    # it can start reading lines and pre-process them for this epoch.
                    # The random seeds depend on this epoch number, so we have to wait
                    # for this information.
                    # We don't know it at start because DynamicDataset is started before
                    # checkpoint loading (for efficiency reasons: forks seem to be
                    # much slower once the model has been loaded)
                    self.starting_epoch.value = client_epoch
                    self.start_training_event.set()

                server_epoch = self.epoch[device_id]

                if client_epoch == -1:  # Client is asking for the number of lines
                    batch = self.total_lines.value
                elif client_epoch < server_epoch.value:
                    # There are no more batches for this epoch
                    # When receiving None, DynamicDataset will end the current epoch
                    batch = None
                else:
                    if client_epoch > server_epoch.value:
                        # Client has moved to the next epoch
                        server_epoch.value = client_epoch

                    # Get the next batch for this client
                    _, batch = self.batcher_queues[device_id].get()

                    if batch is None:
                        # There are no more batches for this epoch
                        server_epoch.value += 1

                data = pickle.dumps(batch)
                self_.wfile.write(struct.pack('Q', len(data)))
                self_.wfile.write(data)
                del data, batch

        class ReusableTCPServer(socketserver.TCPServer):
            # set on a subclass rather than on socketserver.TCPServer itself, which would
            # change the behavior of every other TCPServer created in this process
            allow_reuse_address = True

        self.reset_logger()
        # NOTE(review): '' binds to all network interfaces although DynamicDataset connects
        # through 127.0.0.1 -- consider binding to the loopback address only.
        with ReusableTCPServer(('', self.port), RequestHandler) as server_:
            while not self.end_of_training_event.is_set():
                server_.handle_request()
