/* Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved. */

/* Helper methods for fast index mapping builds */

#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <memory>
#include <random>
#include <stdexcept>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

// Short alias for the pybind11 namespace used throughout this file.
namespace py = pybind11;
// NOTE(review): file-scope using-directive; the functions below rely on it
// for unqualified names such as cout, endl and swap.
using namespace std;

// Size threshold above which a sentence counts as "long". The mapping
// builders (build_mapping_impl / build_blocks_mapping_impl) compare each
// entry of their `sizes` array against it and drop any document in which
// they detect a sentence larger than this value.
const int32_t LONG_SENTENCE_LEN = 512;

void build_blending_indices(py::array_t<int16_t> &dataset_index,
                            py::array_t<int64_t> &dataset_sample_index,
                            const py::array_t<double> &weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose)
{
  /* Given multiple datasets and a weighting array, build samples
     such that the blend follows those weights.

     For every output position the dataset whose sample count lags its
     weighted target (weight * position) the most is selected; ties go to
     the lowest dataset id. Results are written in place:
       dataset_index[i]        - id of the dataset chosen for position i
       dataset_sample_index[i] - number of samples already drawn from that
                                 dataset before position i
     Only the first `size` entries of the outputs and the first
     `num_datasets` entries of `weights` are used.

     Throws std::invalid_argument when num_datasets is not in [1, 32768]
     (ids are stored as int16_t) or when an array is too short for the
     requested extents. */

  if (verbose)
  {
    std::cout << "> building indices for blended datasets ..." << std::endl;
  }

  // Get the pointer access without the checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // The accessors above do no bounds checking, so validate the extents once
  // here instead of reading or writing out of range in the loops below.
  if ((num_datasets <= 0) ||
      (num_datasets - 1 > std::numeric_limits<int16_t>::max()))
  {
    throw std::invalid_argument(
        "build_blending_indices: num_datasets must be in [1, 32768]");
  }
  if (weights_ptr.shape(0) < num_datasets)
  {
    throw std::invalid_argument(
        "build_blending_indices: weights has fewer than num_datasets entries");
  }
  if ((dataset_index_ptr.shape(0) < size) ||
      (dataset_sample_index_ptr.shape(0) < size))
  {
    throw std::invalid_argument(
        "build_blending_indices: output arrays are shorter than size");
  }

  // Number of samples drawn so far from each dataset. A std::vector is used
  // rather than a variable-length array, which is not standard C++ and lives
  // on the stack.
  std::vector<int64_t> current_samples(static_cast<std::size_t>(num_datasets),
                                       int64_t{0});

  // For each sample:
  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx)
  {

    // Determine where the max error in sampling is happening.
    // The position is clamped to 1 so that the very first pick is driven by
    // the weights rather than by an all-zero error.
    const double sample_idx_double =
        std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
                       static_cast<double>(current_samples[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx)
    {
      const double error = weights_ptr[dataset_idx] * sample_idx_double -
                           static_cast<double>(current_samples[dataset_idx]);
      // Strict comparison: on ties the earlier dataset wins.
      if (error > max_error)
      {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Populate the indices.
    dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

    // Update the total samples.
    current_samples[max_error_index] += 1;
  }

  // print info
  if (verbose)
  {
    std::cout << " > sample ratios:" << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx)
    {
      const double ratio = static_cast<double>(current_samples[dataset_idx]) /
                           static_cast<double>(size);
      std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
    }
  }
}

py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
                           const py::array_t<int32_t> &doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch)
{
  /* Sample index (sample_idx) is used for gpt2 like dataset for which
     the documents are flattened and the samples are built based on this
     1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
     where [..., 0] contains the index into `doc_idx` and [..., 1] is the
     starting offset in that document.

     sizes_            - size of each document, indexed by document id.
     doc_idx_          - sequence of document ids to walk through.
     seq_length        - sample length; must be > 1.
     num_epochs        - must be > 0.
     tokens_per_epoch  - must be > 1.

     Returns an int32 array of shape [num_samples + 1, 2] whose buffer is
     released when the array is garbage collected.

     Throws std::invalid_argument for out-of-range scalar arguments and
     std::out_of_range when the walk would read past `doc_idx_` or when a
     document id is not a valid index into `sizes_`. */

  // Consistency checks. These are exceptions rather than asserts so that bad
  // input is rejected in every build type and reaches the caller as a Python
  // exception instead of aborting the process.
  if (seq_length <= 1)
  {
    throw std::invalid_argument(
        "build_sample_idx: seq_length must be greater than 1");
  }
  if (num_epochs <= 0)
  {
    throw std::invalid_argument(
        "build_sample_idx: num_epochs must be greater than 0");
  }
  if (tokens_per_epoch <= 1)
  {
    throw std::invalid_argument(
        "build_sample_idx: tokens_per_epoch must be greater than 1");
  }

  // Remove bound checks.
  auto sizes = sizes_.unchecked<1>();
  auto doc_idx = doc_idx_.unchecked<1>();
  const int64_t num_doc_entries = doc_idx.shape(0);
  const int64_t num_sizes = sizes.shape(0);

  // Mapping and its length (1D).
  const int64_t num_samples =
      (static_cast<int64_t>(num_epochs) * tokens_per_epoch - 1) / seq_length;

  // Owned by the unique_ptr until the capsule below takes over, so an
  // exception thrown while filling the buffer does not leak it.
  std::unique_ptr<int32_t[]> buffer(new int32_t[2 * (num_samples + 1)]);
  int32_t *sample_idx = buffer.get();

  // Index into sample_idx.
  int64_t sample_index = 0;
  // Index into doc_idx.
  int64_t doc_idx_index = 0;
  // Beginning offset for each document.
  int32_t doc_offset = 0;
  // Start with first document and no offset.
  // NOTE(review): positions are stored as int32, so doc_idx indices beyond
  // the int32 range cannot be represented in the returned array.
  sample_idx[2 * sample_index] = static_cast<int32_t>(doc_idx_index);
  sample_idx[2 * sample_index + 1] = doc_offset;
  ++sample_index;

  while (sample_index <= num_samples)
  {
    // Start with a fresh sequence of seq_length + 1 tokens.
    int32_t remaining_seq_length = seq_length + 1;
    while (remaining_seq_length != 0)
    {
      // The accessors are unchecked, so guard the reads explicitly.
      if (doc_idx_index >= num_doc_entries)
      {
        throw std::out_of_range(
            "build_sample_idx: ran past the end of doc_idx; it is too short "
            "for num_epochs * tokens_per_epoch");
      }
      // Get the document length.
      const auto doc_id = doc_idx[doc_idx_index];
      if ((doc_id < 0) || (doc_id >= num_sizes))
      {
        throw std::out_of_range(
            "build_sample_idx: doc_idx entry is not a valid index into sizes");
      }
      const auto doc_length = sizes[doc_id] - doc_offset;
      // And add it to the current sequence.
      remaining_seq_length -= doc_length;
      // If we have more than a full sequence, adjust offset and set
      // remaining length to zero so we return from the while loop.
      // The -1 makes the next sample start on the last token consumed by
      // this one (each sample takes seq_length + 1 tokens).
      if (remaining_seq_length <= 0)
      {
        doc_offset += (remaining_seq_length + doc_length - 1);
        remaining_seq_length = 0;
      }
      else
      {
        // Otherwise, start from the beginning of the next document.
        ++doc_idx_index;
        doc_offset = 0;
      }
    }
    // Record the sequence.
    sample_idx[2 * sample_index] = static_cast<int32_t>(doc_idx_index);
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
  }

  // Method to deallocate memory.
  py::capsule free_when_done(sample_idx, [](void *mem_)
                             { delete[] static_cast<int32_t *>(mem_); });
  // The capsule now owns the buffer.
  buffer.release();

  // Return the numpy array.
  const auto byte_size = sizeof(int32_t);
  return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
                   {2 * byte_size, byte_size},               // C-style contiguous strides
                   sample_idx,                               // the data pointer
                   free_when_done);                          // numpy array references
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937 &rand32_gen)
{
  /* Pick the target length for the next training sample.
     With a ratio of zero the generator is not consulted and max_length is
     returned. Otherwise exactly one number is drawn; when it is a multiple
     of short_seq_ratio the same draw selects a length in [2, max_length],
     else max_length is used. */
  if (short_seq_ratio != 0)
  {
    const auto draw = rand32_gen();
    const bool pick_short = (draw % short_seq_ratio) == 0;
    if (pick_short)
    {
      return 2 + draw % (max_length - 1);
    }
  }
  return max_length;
}

/* Build the shuffled sample mapping returned by build_mapping().
 *
 * Template parameter:
 *   DocIdx           element type of the returned array.
 *
 * Parameters:
 *   docs_            docs_[d] and docs_[d + 1] delimit the half-open range of
 *                    sentence indices of document d, so the array holds one
 *                    more entry than there are documents.
 *   sizes_           size of each sentence, indexed by sentence index.
 *   num_epochs       number of passes over all documents.
 *   max_num_samples  sample budget; it is only tested at the start of each
 *                    epoch, so the result can hold more rows than this.
 *   max_seq_length   largest target sequence length.
 *   short_seq_prob   probability of drawing a shorter target length; it is
 *                    converted to the integer ratio round(1 / prob), and 0
 *                    disables short targets.
 *   seed             seeds the 32-bit generator used for target lengths;
 *                    seed + 1 seeds the 64-bit generator used for shuffling.
 *   verbose          print configuration and statistics to stdout.
 *   min_num_sent     minimum number of sentences a document must have to be
 *                    used, and a sample must have before it can be closed
 *                    ahead of the end of its document.
 *
 * Returns an array of shape (num_samples, 3) whose rows are
 *   (first sentence index, one-past-last sentence index, target length),
 * in shuffled order. The buffer is freed by a capsule attached to the array.
 *
 * Throws std::overflow_error if the flat map index exceeds the int64 range.
 */
template <typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t> &docs_,
                             const py::array_t<int32_t> &sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose,
                             const int32_t min_num_sent)
{
  /* Build a mapping of (start-index, end-index, sequence-length) where
     start and end index are the indices of the sentences in the sample
     and sequence-length is the target sequence length.
  */

  // Consistency checks (only enforced when assertions are compiled in).
  assert(num_epochs > 0);
  assert(max_seq_length > 1);
  assert(short_seq_prob >= 0.0);
  assert(short_seq_prob <= 1.0);
  assert(seed > 0);

  // Remove bound checks.
  auto docs = docs_.unchecked<1>();
  auto sizes = sizes_.unchecked<1>();

  // For efficiency, convert probability to ratio: get_target_sample_len()
  // tests "draw % ratio == 0" on an integer draw. A ratio of 0 means
  // short sequences are disabled.
  int32_t short_seq_ratio = 0;
  if (short_seq_prob > 0)
  {
    short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
  }

  if (verbose)
  {
    // First and last entries of docs_ bound the overall sentence range.
    const auto sent_start_index = docs[0];
    const auto sent_end_index = docs[docs_.shape(0) - 1];
    const auto num_sentences = sent_end_index - sent_start_index;
    cout << "    using:" << endl
         << std::flush;
    cout << "     number of documents:            " << docs_.shape(0) - 1 << endl
         << std::flush;
    cout << "     sentences range:                [" << sent_start_index << ", " << sent_end_index << ")" << endl
         << std::flush;
    cout << "     total number of sentences:      " << num_sentences << endl
         << std::flush;
    cout << "     number of epochs:               " << num_epochs << endl
         << std::flush;
    cout << "     maximum number of samples:      " << max_num_samples << endl
         << std::flush;
    cout << "     maximum sequence length:        " << max_seq_length << endl
         << std::flush;
    cout << "     short sequence probability:     " << short_seq_prob << endl
         << std::flush;
    cout << "     short sequence ration (1/prob): " << short_seq_ratio << endl
         << std::flush;
    cout << "     seed:                           " << seed << endl
         << std::flush;
  }

  // Flat mapping buffer (3 entries per sample) and the number of samples.
  // Both are set at the end of the first iteration below.
  int64_t num_samples = -1;
  DocIdx *maps = NULL;

  // Perform two iterations, in the first iteration get the size
  // and allocate memory and in the second iteration populate the map.
  bool second = false;
  for (int32_t iteration = 0; iteration < 2; ++iteration)
  {

    // Re-seed so both iterations draw the same target lengths and therefore
    // produce the same sample boundaries.
    std::mt19937 rand32_gen(seed);

    // Set the flag on second iteration.
    second = (iteration == 1);

    // Counters (only updated during epoch 0 of the first iteration):
    uint64_t empty_docs = 0;
    uint64_t one_sent_docs = 0;
    uint64_t long_sent_docs = 0;

    // Current map index, i.e. number of samples emitted so far.
    uint64_t map_index = 0;

    // For each epoch:
    for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
    {
      // The sample budget is only checked here, between epochs.
      if (map_index >= max_num_samples)
      {
        if (verbose && (!second))
        {
          cout << "    reached " << max_num_samples << " samples after "
               << epoch << " epochs ..." << endl
               << std::flush;
        }
        break;
      }
      // For each document:
      for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
      {

        // Document sentences are in [sent_index_first, sent_index_last)
        const auto sent_index_first = docs[doc];
        const auto sent_index_last = docs[doc + 1];

        // At the beginning of the document previous index is the
        // start index.
        auto prev_start_index = sent_index_first;

        // Number of sentences of this document not yet consumed.
        auto num_remain_sent = sent_index_last - sent_index_first;

        // Some bookkeeping
        if ((epoch == 0) && (!second))
        {
          if (num_remain_sent == 0)
          {
            ++empty_docs;
          }
          if (num_remain_sent == 1)
          {
            ++one_sent_docs;
          }
        }

        // Detect documents with long sentences. Note that the scan is only
        // performed for documents with more than one sentence.
        bool contains_long_sentence = false;
        if (num_remain_sent > 1)
        {
          for (auto sent_index = sent_index_first;
               sent_index < sent_index_last; ++sent_index)
          {
            if (sizes[sent_index] > LONG_SENTENCE_LEN)
            {
              if ((epoch == 0) && (!second))
              {
                ++long_sent_docs;
              }
              contains_long_sentence = true;
              break;
            }
          }
        }

        // Use the document only if it has at least min_num_sent sentences
        // and no long sentence was detected.
        if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
        {

          // Running size / sentence count of the sample being assembled,
          // and the target length drawn for it.
          auto seq_len = int32_t{0};
          auto num_sent = int32_t{0};
          auto target_seq_len = get_target_sample_len(short_seq_ratio,
                                                      max_seq_length,
                                                      rand32_gen);

          // Loop through sentences.
          for (auto sent_index = sent_index_first;
               sent_index < sent_index_last; ++sent_index)
          {

            // Add the size and number of sentences.
            seq_len += sizes[sent_index];
            ++num_sent;
            --num_remain_sent;

            // Close the current sample when either
            //   - the target length has been reached, more than one sentence
            //     is left in the document, and the sample holds at least
            //     min_num_sent sentences; or
            //   - the end of the document has been reached (in which case
            //     the sample may hold fewer than min_num_sent sentences).
            if (((seq_len >= target_seq_len) &&
                 (num_remain_sent > 1) &&
                 (num_sent >= min_num_sent)) ||
                (num_remain_sent == 0))
            {

              // Check for overflow of the flat index of the last entry.
              if ((3 * map_index + 2) >
                  std::numeric_limits<int64_t>::max())
              {
                cout << "number of samples exceeded maximum "
                     << "allowed by type int64: "
                     << std::numeric_limits<int64_t>::max()
                     << endl;
                throw std::overflow_error("Number of samples");
              }

              // Populate the map (second iteration only; the first
              // iteration just counts).
              if (second)
              {
                const auto map_index_0 = 3 * map_index;
                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
              }

              // Update indices / counters and draw the target length for
              // the next sample.
              ++map_index;
              prev_start_index = sent_index + 1;
              target_seq_len = get_target_sample_len(short_seq_ratio,
                                                     max_seq_length,
                                                     rand32_gen);
              seq_len = 0;
              num_sent = 0;
            }

          } // for (auto sent_index=sent_index_first; ...
        }   // if ((num_remain_sent >= min_num_sent) && ...
      }     // for (int doc=0; doc < num_docs; ++doc) {
    }       // for (int epoch=0; epoch < num_epochs; ++epoch) {

    if (!second)
    {
      if (verbose)
      {
        cout << "   number of empty documents: " << empty_docs << endl
             << std::flush;
        cout << "   number of documents with one sentence: " << one_sent_docs << endl
             << std::flush;
        cout << "   number of documents with long sentences: " << long_sent_docs << endl
             << std::flush;
        cout << "   will create mapping for " << map_index << " samples" << endl
             << std::flush;
      }
      // Allocate the buffer now that the sample count is known.
      assert(maps == NULL);
      assert(num_samples < 0);
      maps = new DocIdx[3 * map_index];
      num_samples = static_cast<int64_t>(map_index);
    }

  } // for (int iteration=0; iteration < 2; ++iteration) {

  // Shuffle the rows in place (Fisher-Yates, walking from the last row down).
  // We need a 64 bit random number generator as we might have more
  // than 2 billion samples.
  std::mt19937_64 rand64_gen(seed + 1);
  for (auto i = (num_samples - 1); i > 0; --i)
  {
    const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
    const auto i0 = 3 * i;
    const auto j0 = 3 * j;
    // Swap the three entries of rows i and j.
    swap(maps[i0], maps[j0]);
    swap(maps[i0 + 1], maps[j0 + 1]);
    swap(maps[i0 + 2], maps[j0 + 2]);
  }

  // Capsule whose destructor frees the buffer once the returned array no
  // longer references it.
  py::capsule free_when_done(maps, [](void *mem_)
                             {
            DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
	    delete[] mem; });

  // Return the numpy array.
  const auto byte_size = sizeof(DocIdx);
  return py::array(std::vector<int64_t>{num_samples, 3}, // shape
                   {3 * byte_size, byte_size},           // C-style contiguous strides
                   maps,                                 // the data pointer
                   free_when_done);                      // numpy array references
}

py::array build_mapping(const py::array_t<int64_t> &docs_,
                        const py::array_t<int> &sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
                        const bool verbose,
                        const int32_t min_num_sent)
{
  /* Entry point exposed to Python: selects the element width of the
     mapping and forwards every argument to build_mapping_impl. */

  // 64-bit entries are used only when the number of sentence sizes does
  // not fit in 32 bits.
  const bool needs_wide_entries =
      sizes_.size() > std::numeric_limits<uint32_t>::max();

  if (!needs_wide_entries)
  {
    if (verbose)
    {
      std::cout << "    using uint32 for data mapping..." << std::endl
                << std::flush;
    }
    return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                        max_num_samples, max_seq_length,
                                        short_seq_prob, seed, verbose,
                                        min_num_sent);
  }

  if (verbose)
  {
    std::cout << "    using uint64 for data mapping..." << std::endl
              << std::flush;
  }
  return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                      max_num_samples, max_seq_length,
                                      short_seq_prob, seed, verbose,
                                      min_num_sent);
}

/* Build the shuffled block mapping returned by build_blocks_mapping().
 *
 * Template parameter:
 *   DocIdx               element type of the returned array.
 *
 * Parameters:
 *   docs_                docs_[d] and docs_[d + 1] delimit the half-open
 *                        range of sentence indices of document d, so the
 *                        array holds one more entry than there are documents.
 *   sizes_               size of each sentence, indexed by sentence index.
 *   titles_sizes_        per-document value subtracted from max_seq_length
 *                        to obtain that document's target block length.
 *   num_epochs           number of passes over all documents.
 *   max_num_samples      sample budget; it is only tested at the start of
 *                        each epoch, so the result can hold more rows.
 *   max_seq_length       maximum sequence length before the title size is
 *                        subtracted.
 *   seed                 seed + 1 seeds the generator used for shuffling.
 *   verbose              print configuration and statistics to stdout.
 *   use_one_sent_blocks  when true, blocks and documents of a single
 *                        sentence are accepted (minimum 1 instead of 2).
 *
 * Returns an array of shape (num_samples, 4) whose rows are
 *   (first sentence index, one-past-last sentence index, document index,
 *    block id),
 * in shuffled order. The buffer is freed by a capsule attached to the array.
 */
template <typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t> &docs_,
                                    const py::array_t<int32_t> &sizes_,
                                    const py::array_t<int32_t> &titles_sizes_,
                                    const int32_t num_epochs,
                                    const uint64_t max_num_samples,
                                    const int32_t max_seq_length,
                                    const int32_t seed,
                                    const bool verbose,
                                    const bool use_one_sent_blocks)
{
  /* Build a mapping of (start-index, end-index, document-index, block-id)
     where start and end index are the indices of the sentences in the
     block, document-index identifies the source document and block-id is
     a counter assigned to the block.
  */

  // Consistency checks (only enforced when assertions are compiled in).
  assert(num_epochs > 0);
  assert(max_seq_length > 1);
  assert(seed > 0);

  // Remove bound checks.
  auto docs = docs_.unchecked<1>();
  auto sizes = sizes_.unchecked<1>();
  auto titles_sizes = titles_sizes_.unchecked<1>();

  if (verbose)
  {
    // First and last entries of docs_ bound the overall sentence range.
    const auto sent_start_index = docs[0];
    const auto sent_end_index = docs[docs_.shape(0) - 1];
    const auto num_sentences = sent_end_index - sent_start_index;
    cout << "    using:" << endl
         << std::flush;
    cout << "     number of documents:            " << docs_.shape(0) - 1 << endl
         << std::flush;
    cout << "     sentences range:                [" << sent_start_index << ", " << sent_end_index << ")" << endl
         << std::flush;
    cout << "     total number of sentences:      " << num_sentences << endl
         << std::flush;
    cout << "     number of epochs:               " << num_epochs << endl
         << std::flush;
    cout << "     maximum number of samples:      " << max_num_samples << endl
         << std::flush;
    cout << "     maximum sequence length:        " << max_seq_length << endl
         << std::flush;
    cout << "     seed:                           " << seed << endl
         << std::flush;
  }

  // Flat mapping buffer (4 entries per sample) and the number of samples.
  // Both are set at the end of the first iteration below.
  int64_t num_samples = -1;
  DocIdx *maps = NULL;

  // Acceptable number of sentences per block.
  int min_num_sent = 2;
  if (use_one_sent_blocks)
  {
    min_num_sent = 1;
  }

  // Perform two iterations, in the first iteration get the size
  // and allocate memory and in the second iteration populate the map.
  bool second = false;
  for (int32_t iteration = 0; iteration < 2; ++iteration)
  {

    // Set the flag on second iteration.
    second = (iteration == 1);

    // Current map index, i.e. number of blocks emitted so far.
    uint64_t map_index = 0;

    // Statistics (only updated during epoch 0 of the first iteration).
    uint64_t empty_docs = 0;
    uint64_t one_sent_docs = 0;
    uint64_t long_sent_docs = 0;
    // For each epoch:
    for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
    {
      // Block id counter. It restarts at 0 for every epoch, so ids are
      // distinct within an epoch but repeat across epochs.
      int32_t block_id = 0;

      // The sample budget is only checked here, between epochs.
      if (map_index >= max_num_samples)
      {
        if (verbose && (!second))
        {
          cout << "    reached " << max_num_samples << " samples after "
               << epoch << " epochs ..." << endl
               << std::flush;
        }
        break;
      }
      // For each document:
      for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
      {

        // Document sentences are in [sent_index_first, sent_index_last)
        const auto sent_index_first = docs[doc];
        const auto sent_index_last = docs[doc + 1];
        // The title size is taken out of the length available to the block.
        const auto target_seq_len = max_seq_length - titles_sizes[doc];

        // At the beginning of the document previous index is the
        // start index.
        auto prev_start_index = sent_index_first;

        // Number of sentences of this document not yet consumed.
        auto num_remain_sent = sent_index_last - sent_index_first;

        // Some bookkeeping
        if ((epoch == 0) && (!second))
        {
          if (num_remain_sent == 0)
          {
            ++empty_docs;
          }
          if (num_remain_sent == 1)
          {
            ++one_sent_docs;
          }
        }
        // Detect documents with long sentences.
        bool contains_long_sentence = false;
        if (num_remain_sent >= min_num_sent)
        {
          for (auto sent_index = sent_index_first;
               sent_index < sent_index_last; ++sent_index)
          {
            if (sizes[sent_index] > LONG_SENTENCE_LEN)
            {
              if ((epoch == 0) && (!second))
              {
                ++long_sent_docs;
              }
              contains_long_sentence = true;
              break;
            }
          }
        }
        // If we have enough sentences and no long sentences.
        if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
        {

          // Running size / sentence count of the block being assembled.
          auto seq_len = int32_t{0};
          auto num_sent = int32_t{0};

          // Loop through sentences.
          for (auto sent_index = sent_index_first;
               sent_index < sent_index_last; ++sent_index)
          {

            // Add the size and number of sentences.
            seq_len += sizes[sent_index];
            ++num_sent;
            --num_remain_sent;

            // Close the current block when either
            //   - the target length has been reached, at least min_num_sent
            //     sentences are left in the document, and the block holds at
            //     least min_num_sent sentences; or
            //   - the end of the document has been reached.
            if (((seq_len >= target_seq_len) &&
                 (num_remain_sent >= min_num_sent) &&
                 (num_sent >= min_num_sent)) ||
                (num_remain_sent == 0))
            {

              // Populate the map (second iteration only; the first
              // iteration just counts).
              if (second)
              {
                const auto map_index_0 = 4 * map_index;
                // Each sample has 4 items: the starting sentence index, ending sentence index,
                // the index of the document from which the block comes (used for fetching titles)
                // and the id of the block (used for creating block indexes)

                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
              }

              // Update indices / counters.
              ++map_index;
              ++block_id;
              prev_start_index = sent_index + 1;
              seq_len = 0;
              num_sent = 0;
            }
          } // for (auto sent_index=sent_index_first; ...
        }   // if ((num_remain_sent >= min_num_sent) && ...
      }     // for (int doc=0; doc < num_docs; ++doc) {
    }       // for (int epoch=0; epoch < num_epochs; ++epoch) {

    if (!second)
    {
      if (verbose)
      {
        cout << "   number of empty documents: " << empty_docs << endl
             << std::flush;
        cout << "   number of documents with one sentence: " << one_sent_docs << endl
             << std::flush;
        cout << "   number of documents with long sentences: " << long_sent_docs << endl
             << std::flush;
        cout << "   will create mapping for " << map_index << " samples" << endl
             << std::flush;
      }
      // Allocate the buffer now that the sample count is known.
      assert(maps == NULL);
      assert(num_samples < 0);
      maps = new DocIdx[4 * map_index];
      num_samples = static_cast<int64_t>(map_index);
    }

  } // for (int iteration=0; iteration < 2; ++iteration) {

  // Shuffle the rows in place (Fisher-Yates, walking from the last row down).
  // We need a 64 bit random number generator as we might have more
  // than 2 billion samples.
  std::mt19937_64 rand64_gen(seed + 1);
  for (auto i = (num_samples - 1); i > 0; --i)
  {
    const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
    const auto i0 = 4 * i;
    const auto j0 = 4 * j;
    // Swap the four entries of rows i and j.
    swap(maps[i0], maps[j0]);
    swap(maps[i0 + 1], maps[j0 + 1]);
    swap(maps[i0 + 2], maps[j0 + 2]);
    swap(maps[i0 + 3], maps[j0 + 3]);
  }

  // Capsule whose destructor frees the buffer once the returned array no
  // longer references it.
  py::capsule free_when_done(maps, [](void *mem_)
                             {
            DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
	    delete[] mem; });

  // Return the numpy array.
  const auto byte_size = sizeof(DocIdx);
  return py::array(std::vector<int64_t>{num_samples, 4}, // shape
                   {4 * byte_size, byte_size},           // C-style contiguous strides
                   maps,                                 // the data pointer
                   free_when_done);                      // numpy array references
}

py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
                               const py::array_t<int> &sizes_,
                               const py::array_t<int> &titles_sizes_,
                               const int num_epochs,
                               const uint64_t max_num_samples,
                               const int max_seq_length,
                               const int seed,
                               const bool verbose,
                               const bool use_one_sent_blocks)
{
  /* Entry point exposed to Python: selects the element width of the
     mapping and forwards every argument to build_blocks_mapping_impl. */

  // 64-bit entries are used only when the number of sentence sizes does
  // not fit in 32 bits.
  const bool needs_wide_entries =
      sizes_.size() > std::numeric_limits<uint32_t>::max();

  if (!needs_wide_entries)
  {
    if (verbose)
    {
      std::cout << "    using uint32 for data mapping..." << std::endl
                << std::flush;
    }
    return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
                                               num_epochs, max_num_samples,
                                               max_seq_length, seed, verbose,
                                               use_one_sent_blocks);
  }

  if (verbose)
  {
    std::cout << "    using uint64 for data mapping..." << std::endl
              << std::flush;
  }
  return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
                                             num_epochs, max_num_samples,
                                             max_seq_length, seed, verbose,
                                             use_one_sent_blocks);
}

PYBIND11_MODULE(helpers, m)
{
  // Register the four index builders on the `helpers` module under the same
  // names they have in C++.
  m.def("build_mapping", &build_mapping)
      .def("build_blocks_mapping", &build_blocks_mapping)
      .def("build_sample_idx", &build_sample_idx)
      .def("build_blending_indices", &build_blending_indices);
}
