/*
 coding=utf-8
 Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

/* Helper methods for fast index mapping builds */

#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <algorithm>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>

namespace py = pybind11;
using namespace std;

const int32_t LONG_SENTENCE_LEN = 512;

void build_blending_indices(py::array_t<uint8_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size,
                            const bool verbose)
{
    /* Given multiple datasets and a weighting array, build samples
     such that it follows those weights.*/

    if (verbose) { std::cout << "> building indices for blendable datasets ..." << std::endl; }

    // Get the pointer access without the checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Initialize buffer for number of samples used for each dataset.
    int64_t current_samples[num_datasets];
    for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; }

    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
        // Determine where the max error in sampling is happening.
        double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error =
            weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
            double error = weights_ptr[dataset_idx] * sample_idx_double -
                           static_cast<double>(current_samples[dataset_idx]);
            if (error > max_error) {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;
    }

    // print info
    if (verbose) {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
            auto ratio =
                static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);
            std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx]
                      << ", achieved: " << ratio << std::endl;
        }
    }
}

py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch)
{
    /* Sample index (sample_idx) is used for gpt2 like dataset for which
       the documents are flattened and the samples are built based on this
       1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
       where [..., 0] contains the index into `doc_idx` and [..., 1] is the
       starting offset in that document.*/

    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bound checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();

    // Mapping and it's length (1D).
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t* sample_idx = new int32_t[2 * (num_samples + 1)];

    cout << "    using:" << endl << std::flush;
    cout << "     number of documents:       " << doc_idx_.shape(0) / num_epochs << endl
         << std::flush;
    cout << "     number of epochs:          " << num_epochs << endl << std::flush;
    cout << "     sequence length:           " << seq_length << endl << std::flush;
    cout << "     total number of samples:   " << num_samples << endl << std::flush;

    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Beginning offset for each document.
    int32_t doc_offset = 0;
    // Start with first document and no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;

    while (sample_index <= num_samples) {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0) {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust offset and set
            // remaining length to zero so we return from the while loop.
            // Note that -1 here is for the same reason we have -1 in
            // `_num_epochs` calculations.
            if (remaining_seq_length <= 0) {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            } else {
                // Otherwise, start from the beginning of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }

    // Method to deallocate memory.
    py::capsule free_when_done(sample_idx, [](void* mem_) {
        int32_t* mem = reinterpret_cast<int32_t*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector<int64_t>{num_samples + 1, 2},  // shape
                     {2 * byte_size, byte_size},                // C-style contiguous strides
                     sample_idx,                                // the data pointer
                     free_when_done);                           // numpy array references
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen)
{
    /* Training sample length. */
    const auto random_number = rand32_gen();
    if ((random_number % short_seq_ratio) == 0) { return 2 + random_number % (max_length - 1); }
    return max_length;
}

template <typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const py::array_t<int32_t>& sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose)
{
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.
    */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(short_seq_prob > 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert probability to ratio. Note: rand() generates int.
    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));

    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << "    using:" << endl << std::flush;
        cout << "     number of documents:            " << docs_.shape(0) - 1 << endl << std::flush;
        cout << "     sentences range:                [" << sent_start_index << ", "
             << sent_end_index << ")" << endl
             << std::flush;
        cout << "     total number of sentences:      " << num_sentences << endl << std::flush;
        cout << "     number of epochs:               " << num_epochs << endl << std::flush;
        cout << "     maximum number of samples:      " << max_num_samples << endl << std::flush;
        cout << "     maximum sequence length:        " << max_seq_length << endl << std::flush;
        cout << "     short sequence probability:     " << short_seq_prob << endl << std::flush;
        cout << "     short sequence ration (1/prob): " << short_seq_ratio << endl << std::flush;
        cout << "     seed:                           " << seed << endl << std::flush;
    }

    // Mapping and it's length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration) {
        // Set the seed so both iterations produce the same results.
        std::mt19937 rand32_gen(seed);

        // Set the flag on second iteration.
        second = (iteration == 1);

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << "    reached " << max_num_samples << " samples after " << epoch
                         << " epochs ..." << endl
                         << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {
                // Document sentences are in [sent_index_first, sent_index_last)
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining documents.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) { ++empty_docs; }
                    if (num_remain_sent == 1) { ++one_sent_docs; }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent > 1) {
                    for (auto sent_index = sent_index_first; sent_index < sent_index_last;
                         ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN) {
                            if ((epoch == 0) && (!second)) { ++long_sent_docs; }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }

                // If we have more than two sentences.
                if ((num_remain_sent > 1) && (!contains_long_sentence)) {
                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};
                    auto target_seq_len =
                        get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen);

                    // Loop through sentences.
                    for (auto sent_index = sent_index_first; sent_index < sent_index_last;
                         ++sent_index) {
                        // Add the size and number of sentences.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // If we have reached the target length.
                        // and if not only one sentence is left in the document.
                        // and if we have at least two sentneces.
                        // and if we have reached end of the document.
                        if (((seq_len >= target_seq_len) && (num_remain_sent > 1) &&
                             (num_sent > 1)) ||
                            (num_remain_sent == 0)) {
                            // Check for overflow.
                            if ((3 * map_index + 2) > std::numeric_limits<int64_t>::max()) {
                                cout << "number of samples exceeded maximum "
                                     << "allowed by type int64: "
                                     << std::numeric_limits<int64_t>::max() << endl;
                                throw std::overflow_error("Number of samples");
                            }

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 3 * map_index;
                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
                            }

                            // Update indices / counters.
                            ++map_index;
                            prev_start_index = sent_index + 1;
                            target_seq_len =
                                get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen);
                            seq_len = 0;
                            num_sent = 0;
                        }

                    }  // for (auto sent_index=sent_index_first; ...
                }      // if (num_remain_sent > 1) {
            }          // for (int doc=0; doc < num_docs; ++doc) {
        }              // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second) {
            if (verbose) {
                cout << "   number of empty documents: " << empty_docs << endl << std::flush;
                cout << "   number of documents with one sentence: " << one_sent_docs << endl
                     << std::flush;
                cout << "   number of documents with long sentences: " << long_sent_docs << endl
                     << std::flush;
                cout << "   will create mapping for " << map_index << " samples" << endl
                     << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[3 * map_index];
            num_samples = static_cast<int64_t>(map_index);
        }

    }  // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle.
    // We need a 64 bit random number generator as we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i = (num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 3 * i;
        const auto j0 = 3 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void* mem_) {
        DocIdx* mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 3},  // shape
                     {3 * byte_size, byte_size},            // C-style contiguous strides
                     maps,                                  // the data pointer
                     free_when_done);                       // numpy array references
}

py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const py::array_t<int>& sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
                        const bool verbose)
{
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) { cout << "    using uint64 for data mapping..." << endl << std::flush; }
        return build_mapping_impl<uint64_t>(docs_,
                                            sizes_,
                                            num_epochs,
                                            max_num_samples,
                                            max_seq_length,
                                            short_seq_prob,
                                            seed,
                                            verbose);
    } else {
        if (verbose) { cout << "    using uint32 for data mapping..." << endl << std::flush; }
        return build_mapping_impl<uint32_t>(docs_,
                                            sizes_,
                                            num_epochs,
                                            max_num_samples,
                                            max_seq_length,
                                            short_seq_prob,
                                            seed,
                                            verbose);
    }
}

template <typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                                    const py::array_t<int32_t>& sizes_,
                                    const py::array_t<int32_t>& titles_sizes_,
                                    const int32_t num_epochs,
                                    const uint64_t max_num_samples,
                                    const int32_t max_seq_length,
                                    const int32_t seed,
                                    const bool verbose,
                                    const bool use_one_sent_blocks)
{
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.
    */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();
    auto titles_sizes = titles_sizes_.unchecked<1>();

    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << "    using:" << endl << std::flush;
        cout << "     number of documents:            " << docs_.shape(0) - 1 << endl << std::flush;
        cout << "     sentences range:                [" << sent_start_index << ", "
             << sent_end_index << ")" << endl
             << std::flush;
        cout << "     total number of sentences:      " << num_sentences << endl << std::flush;
        cout << "     number of epochs:               " << num_epochs << endl << std::flush;
        cout << "     maximum number of samples:      " << max_num_samples << endl << std::flush;
        cout << "     maximum sequence length:        " << max_seq_length << endl << std::flush;
        cout << "     seed:                           " << seed << endl << std::flush;
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Acceptable number of sentences per block.
    int min_num_sent = 2;
    if (use_one_sent_blocks) { min_num_sent = 1; }

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration) {
        // Set the flag on second iteration.
        second = (iteration == 1);

        // Current map index.
        uint64_t map_index = 0;

        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;
        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
            // assign every block a unique id
            int32_t block_id = 0;

            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << "    reached " << max_num_samples << " samples after " << epoch
                         << " epochs ..." << endl
                         << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {
                // Document sentences are in [sent_index_first, sent_index_last)
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];
                const auto target_seq_len = max_seq_length - titles_sizes[doc];

                // At the beginning of the document previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining documents.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) { ++empty_docs; }
                    if (num_remain_sent == 1) { ++one_sent_docs; }
                }
                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent >= min_num_sent) {
                    for (auto sent_index = sent_index_first; sent_index < sent_index_last;
                         ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN) {
                            if ((epoch == 0) && (!second)) { ++long_sent_docs; }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }
                // If we have enough sentences and no long sentences.
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};

                    // Loop through sentences.
                    for (auto sent_index = sent_index_first; sent_index < sent_index_last;
                         ++sent_index) {
                        // Add the size and number of sentences.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // If we have reached the target length.
                        // and there are an acceptable number of sentences left
                        // and if we have at least the minimum number of sentences.
                        // or if we have reached end of the document.
                        if (((seq_len >= target_seq_len) && (num_remain_sent >= min_num_sent) &&
                             (num_sent >= min_num_sent)) ||
                            (num_remain_sent == 0)) {
                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 4 * map_index;
                                // Each sample has 4 items: the starting sentence index, ending
                                // sentence index, the index of the document from which the block
                                // comes (used for fetching titles) and the unique id of the block
                                // (used for creating block indexes)

                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
                                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                            }

                            // Update indices / counters.
                            ++map_index;
                            ++block_id;
                            prev_start_index = sent_index + 1;
                            seq_len = 0;
                            num_sent = 0;
                        }
                    }  // for (auto sent_index=sent_index_first; ...
                }      // if (num_remain_sent > 1) {
            }          // for (int doc=0; doc < num_docs; ++doc) {
        }              // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second) {
            if (verbose) {
                cout << "   number of empty documents: " << empty_docs << endl << std::flush;
                cout << "   number of documents with one sentence: " << one_sent_docs << endl
                     << std::flush;
                cout << "   number of documents with long sentences: " << long_sent_docs << endl
                     << std::flush;
                cout << "   will create mapping for " << map_index << " samples" << endl
                     << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[4 * map_index];
            num_samples = static_cast<int64_t>(map_index);
        }

    }  // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle.
    // We need a 64 bit random number generator as we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i = (num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 4 * i;
        const auto j0 = 4 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
        swap(maps[i0 + 3], maps[j0 + 3]);
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void* mem_) {
        DocIdx* mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 4},  // shape
                     {4 * byte_size, byte_size},            // C-style contiguous strides
                     maps,                                  // the data pointer
                     free_when_done);                       // numpy array references
}

py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
                               const py::array_t<int>& sizes_,
                               const py::array_t<int>& titles_sizes_,
                               const int num_epochs,
                               const uint64_t max_num_samples,
                               const int max_seq_length,
                               const int seed,
                               const bool verbose,
                               const bool use_one_sent_blocks)
{
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) { cout << "    using uint64 for data mapping..." << endl << std::flush; }
        return build_blocks_mapping_impl<uint64_t>(docs_,
                                                   sizes_,
                                                   titles_sizes_,
                                                   num_epochs,
                                                   max_num_samples,
                                                   max_seq_length,
                                                   seed,
                                                   verbose,
                                                   use_one_sent_blocks);
    } else {
        if (verbose) { cout << "    using uint32 for data mapping..." << endl << std::flush; }
        return build_blocks_mapping_impl<uint32_t>(docs_,
                                                   sizes_,
                                                   titles_sizes_,
                                                   num_epochs,
                                                   max_num_samples,
                                                   max_seq_length,
                                                   seed,
                                                   verbose,
                                                   use_one_sent_blocks);
    }
}

PYBIND11_MODULE(helpers, m)
{
    m.def("build_mapping", &build_mapping);
    m.def("build_blocks_mapping", &build_blocks_mapping);
    m.def("build_sample_idx", &build_sample_idx);
    m.def("build_blending_indices", &build_blending_indices);
}
