/* Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved. */

/* Helper methods for fast index mapping builds */

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>
#include <math.h>
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <random>
#include <stdexcept>
#include <vector>

namespace py = pybind11;
using namespace std;

void build_blending_indices(py::array_t<int16_t> &dataset_index,
                            py::array_t<int64_t> &dataset_sample_index,
                            const py::array_t<double> &weights,
                            const int32_t num_datasets, const int64_t size,
                            const bool verbose) {
  /* Given multiple datasets and a weighting array, build samples
     such that they follow those weights.

     For every output position 0 <= i < size, the dataset whose achieved
     sample count lags furthest behind its target
     (weights[d] * max(i, 1) - samples drawn so far from d) is chosen.

       dataset_index[i]        <- id of the dataset sample i comes from.
       dataset_sample_index[i] <- number of samples already drawn from that
                                  dataset, i.e. the sample's index within it.

     Both output arrays are written in place.
     NOTE(review): they must already have dtype int16 / int64. Otherwise
     pybind11 converts them into temporaries and the results are lost.

     Throws std::invalid_argument (ValueError in Python) if:
       - num_datasets is negative, or is zero while size > 0;
       - weights has fewer than num_datasets entries;
       - either output array is shorter than `size`;
       - num_datasets cannot be represented in the int16 output. */

  if (verbose) {
    std::cout << "> building indices for blended datasets ..." << std::endl;
  }

  // Get the pointer access without the checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // The accessors above perform no bounds checking, so validate every size
  // once here instead of risking out-of-bounds reads/writes in the loop.
  if (num_datasets < 0 || (num_datasets == 0 && size > 0)) {
    throw std::invalid_argument(
        "num_datasets must be positive to build a non-empty blend");
  }
  if (weights_ptr.shape(0) < num_datasets) {
    throw std::invalid_argument("weights has fewer entries than num_datasets");
  }
  if (dataset_index_ptr.shape(0) < size ||
      dataset_sample_index_ptr.shape(0) < size) {
    throw std::invalid_argument(
        "output index arrays are shorter than the requested size");
  }
  // dataset_index stores dataset ids as int16; larger ids would be
  // silently truncated by the cast below.
  if (num_datasets >
      static_cast<int32_t>(std::numeric_limits<int16_t>::max()) + 1) {
    throw std::invalid_argument(
        "too many datasets for an int16 dataset_index");
  }

  // Number of samples drawn so far from each dataset. This is heap-allocated
  // (no variable-length array on the stack, which is not standard C++).
  std::vector<int64_t> current_samples(static_cast<size_t>(num_datasets), 0);

  // For each sample:
  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
    // Determine where the max error in sampling is happening. The target is
    // computed for at least one sample so the very first pick is
    // weight-driven rather than a tie at zero.
    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
                       static_cast<double>(current_samples[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      double error = weights_ptr[dataset_idx] * sample_idx_double -
                     static_cast<double>(current_samples[dataset_idx]);
      // Strict '>' keeps the lowest dataset id on ties.
      if (error > max_error) {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Populate the indices.
    dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

    // Update the total samples.
    current_samples[max_error_index] += 1;
  }

  // print info
  if (verbose) {
    std::cout << " > sample ratios:" << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
                   static_cast<double>(size);
      std::cout << "   dataset " << dataset_idx
                << ", input: " << weights_ptr[dataset_idx]
                << ", achieved: " << ratio << std::endl;
    }
  }
}

py::array build_sample_idx_int32(const py::array_t<int32_t> &sizes_,
                                 const py::array_t<int32_t> &doc_idx_,
                                 const int32_t seq_length,
                                 const int32_t num_epochs,
                                 const int64_t tokens_per_epoch) {
  /* Sample index (sample_idx) is used for gpt2 like dataset for which
     the documents are flattened and the samples are built based on this
     1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
     where [..., 0] contains the index into `doc_idx` and [..., 1] is the
     starting offset in that document.

     Each sample spans seq_length + 1 tokens. The row recorded at a sample's
     end is also the start of the next sample, so consecutive samples share
     one token. Entries are stored as int32; use build_sample_idx_int64 when
     doc_idx is too long for that.

     Returns a numpy array that takes ownership of a freshly allocated buffer.
     The buffer is freed when the array is garbage-collected.

     Throws std::invalid_argument (ValueError) if doc_idx positions do not
     fit in int32. Throws std::out_of_range (IndexError) if doc_idx runs out
     before all samples are built, or references an id outside `sizes`. */

  // Consistency checks (debug builds only; compiled out under NDEBUG).
  assert(seq_length > 1);
  assert(num_epochs > 0);
  assert(tokens_per_epoch > 1);

  // Remove bound checks.
  auto sizes = sizes_.unchecked<1>();
  auto doc_idx = doc_idx_.unchecked<1>();
  // Bounds for the explicit range checks in the loop below.
  const int64_t num_docs = doc_idx.shape(0);
  const int64_t num_sizes = sizes.shape(0);

  // Positions into doc_idx are stored as int32; refuse inputs whose
  // positions would be silently truncated.
  if (num_docs - 1 > std::numeric_limits<int32_t>::max()) {
    throw std::invalid_argument(
        "doc_idx is too long for build_sample_idx_int32; "
        "use build_sample_idx_int64 instead");
  }

  // Mapping and its length (1D).
  const int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
  // Owned by a unique_ptr until the capsule below takes over, so an
  // exception thrown in between cannot leak the buffer.
  std::unique_ptr<int32_t[]> sample_idx(new int32_t[2 * (num_samples + 1)]);

  // Index into sample_idx.
  int64_t sample_index = 0;
  // Index into doc_idx.
  int64_t doc_idx_index = 0;
  // Beginning offset for each document.
  int32_t doc_offset = 0;
  // Start with first document and no offset.
  sample_idx[2 * sample_index] = static_cast<int32_t>(doc_idx_index);
  sample_idx[2 * sample_index + 1] = doc_offset;
  ++sample_index;

  while (sample_index <= num_samples) {
    // Start with a fresh sequence.
    int32_t remaining_seq_length = seq_length + 1;
    while (remaining_seq_length != 0) {
      // Guard the unchecked reads: inconsistent inputs must not walk off
      // the end of doc_idx or index outside sizes.
      if (doc_idx_index >= num_docs) {
        throw std::out_of_range(
            "doc_idx exhausted before all samples were built; "
            "tokens_per_epoch is inconsistent with sizes/doc_idx");
      }
      auto doc_id = doc_idx[doc_idx_index];
      if (doc_id < 0 || doc_id >= num_sizes) {
        throw std::out_of_range("doc_idx contains an id outside of sizes");
      }
      // Get the document length.
      auto doc_length = sizes[doc_id] - doc_offset;
      // And add it to the current sequence.
      remaining_seq_length -= doc_length;
      // If we have more than a full sequence, adjust offset and set
      // remaining length to zero so we return from the while loop.
      // The -1 leaves the offset on the sample's last token, which also
      // starts the next sample (same reason as the -1 in the
      // `_num_epochs` calculations).
      if (remaining_seq_length <= 0) {
        doc_offset += (remaining_seq_length + doc_length - 1);
        remaining_seq_length = 0;
      } else {
        // Otherwise, start from the beginning of the next document.
        ++doc_idx_index;
        doc_offset = 0;
      }
    }
    // Record the sequence.
    sample_idx[2 * sample_index] = static_cast<int32_t>(doc_idx_index);
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
  }

  // Method to deallocate memory; the deleter type matches the allocation.
  int32_t *const data = sample_idx.get();
  py::capsule free_when_done(data, [](void *mem_) {
    delete[] reinterpret_cast<int32_t *>(mem_);
  });
  // The capsule now owns the buffer.
  static_cast<void>(sample_idx.release());

  // Return the numpy array.
  const auto byte_size = sizeof(int32_t);
  return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
                   {2 * byte_size, byte_size}, // C-style contiguous strides
                   data,                       // the data pointer
                   free_when_done);            // numpy array references
}

py::array build_sample_idx_int64(const py::array_t<int32_t> &sizes_,
                                 const py::array_t<int64_t> &doc_idx_,
                                 const int32_t seq_length,
                                 const int32_t num_epochs,
                                 const int64_t tokens_per_epoch) {
  /* Sample index (sample_idx) is used for gpt2 like dataset for which
     the documents are flattened and the samples are built based on this
     1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
     where [..., 0] contains the index into `doc_idx` and [..., 1] is the
     starting offset in that document.

     Same algorithm as build_sample_idx_int32, but doc_idx and the output use
     int64 so long flattened corpora do not overflow. Each sample spans
     seq_length + 1 tokens. The row recorded at a sample's end is also the
     start of the next sample.

     Returns a numpy array that takes ownership of a freshly allocated buffer.
     The buffer is freed when the array is garbage-collected.

     Throws std::out_of_range (IndexError) if doc_idx runs out before all
     samples are built, or references an id outside `sizes`. */

  // Consistency checks (debug builds only; compiled out under NDEBUG).
  assert(seq_length > 1);
  assert(num_epochs > 0);
  assert(tokens_per_epoch > 1);

  // Remove bound checks.
  auto sizes = sizes_.unchecked<1>();
  auto doc_idx = doc_idx_.unchecked<1>();
  // Bounds for the explicit range checks in the loop below.
  const int64_t num_docs = doc_idx.shape(0);
  const int64_t num_sizes = sizes.shape(0);

  // Mapping and its length (1D).
  const int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
  // Owned by a unique_ptr until the capsule below takes over, so an
  // exception thrown in between cannot leak the buffer.
  std::unique_ptr<int64_t[]> sample_idx(new int64_t[2 * (num_samples + 1)]);

  // Index into sample_idx.
  int64_t sample_index = 0;
  // Index into doc_idx.
  int64_t doc_idx_index = 0;
  // Beginning offset for each document.
  int64_t doc_offset = 0;
  // Start with first document and no offset.
  sample_idx[2 * sample_index] = doc_idx_index;
  sample_idx[2 * sample_index + 1] = doc_offset;
  ++sample_index;

  while (sample_index <= num_samples) {
    // Start with a fresh sequence (widen before adding to avoid int32
    // overflow).
    int64_t remaining_seq_length = static_cast<int64_t>(seq_length) + 1;
    while (remaining_seq_length != 0) {
      // Guard the unchecked reads: inconsistent inputs must not walk off
      // the end of doc_idx or index outside sizes.
      if (doc_idx_index >= num_docs) {
        throw std::out_of_range(
            "doc_idx exhausted before all samples were built; "
            "tokens_per_epoch is inconsistent with sizes/doc_idx");
      }
      const int64_t doc_id = doc_idx[doc_idx_index];
      if (doc_id < 0 || doc_id >= num_sizes) {
        throw std::out_of_range("doc_idx contains an id outside of sizes");
      }
      // Get the document length.
      const int64_t doc_length =
          static_cast<int64_t>(sizes[doc_id]) - doc_offset;
      // And add it to the current sequence.
      remaining_seq_length -= doc_length;
      // If we have more than a full sequence, adjust offset and set
      // remaining length to zero so we return from the while loop.
      // The -1 leaves the offset on the sample's last token, which also
      // starts the next sample (same reason as the -1 in the
      // `_num_epochs` calculations).
      if (remaining_seq_length <= 0) {
        doc_offset += (remaining_seq_length + doc_length - 1);
        remaining_seq_length = 0;
      } else {
        // Otherwise, start from the beginning of the next document.
        ++doc_idx_index;
        doc_offset = 0;
      }
    }
    // Record the sequence.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
  }

  // Method to deallocate memory. The buffer was allocated as int64_t[], so it
  // must be released through an int64_t* (deleting through int32_t* is UB).
  int64_t *const data = sample_idx.get();
  py::capsule free_when_done(data, [](void *mem_) {
    delete[] reinterpret_cast<int64_t *>(mem_);
  });
  // The capsule now owns the buffer.
  static_cast<void>(sample_idx.release());

  // Return the numpy array.
  const auto byte_size = sizeof(int64_t);
  return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
                   {2 * byte_size, byte_size}, // C-style contiguous strides
                   data,                       // the data pointer
                   free_when_done);            // numpy array references
}

// Python bindings. Argument names are exposed so callers may pass keywords;
// positional calls behave exactly as before.
PYBIND11_MODULE(helpers, m) {
  m.doc() = "Helper methods for fast index mapping builds";

  m.def("build_sample_idx_int32", &build_sample_idx_int32,
        "Build the (num_samples + 1, 2) int32 array of "
        "[doc_idx position, token offset] sample boundaries.",
        py::arg("sizes"), py::arg("doc_idx"), py::arg("seq_length"),
        py::arg("num_epochs"), py::arg("tokens_per_epoch"));

  m.def("build_sample_idx_int64", &build_sample_idx_int64,
        "Build the (num_samples + 1, 2) int64 array of "
        "[doc_idx position, token offset] sample boundaries.",
        py::arg("sizes"), py::arg("doc_idx"), py::arg("seq_length"),
        py::arg("num_epochs"), py::arg("tokens_per_epoch"));

  m.def("build_blending_indices", &build_blending_indices,
        "Fill dataset_index (int16) and dataset_sample_index (int64) in "
        "place so the mix of datasets follows `weights`.",
        py::arg("dataset_index"), py::arg("dataset_sample_index"),
        py::arg("weights"), py::arg("num_datasets"), py::arg("size"),
        py::arg("verbose"));
}
