#ifndef LS_USING_LI_H
#define LS_USING_LI_H

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <numeric>
#include <vector>
#include <iostream>
#include <cassert>
#include <memory>
#include <execution>
#include <set>
#include "learned_index.h"
#include "../counter.h"
#include "../merge_sort/merge_sort.h"

namespace learned_sort_using_learned_index {

// Rearranges [begin, end) by scattering every element into the bucket named
// by bucketIds, sorting each bucket on its own, and copying the concatenated
// buckets back over the input.
//
//   begin, end  range to reorder
//   bucket_num  number of buckets; every entry of bucketIds read here is used
//               as an index into a table of this size
//   bucketIds   bucketIds[i] is the bucket of *(begin + i); one entry is read
//               for each element of the range
//
// After the per-bucket sorts the whole buffer is checked with std::is_sorted;
// if that check fails an error is written to std::cerr and the process ends
// via exit(1).
template <typename RandomIt>
void bucket_sort(
  RandomIt begin, RandomIt end,
  size_t bucket_num, const std::vector<size_t>& bucketIds
) {
  using Value = typename std::iterator_traits<RandomIt>::value_type;
  const size_t count = std::distance(begin, end);

  // How many elements land in each bucket.
  std::vector<size_t> population(bucket_num, 0);
  for (size_t k = 0; k < count; ++k) {
    population[bucketIds[k]] += 1;
  }

  // Running totals stored two slots to the right, so that cursor[b + 1] is
  // the first free slot of bucket b.
  // e.g.) bucket_num = 4, population = {2, 3, 2, 1} -> cursor = {0, 0, 2, 5, 7, 8}
  std::vector<size_t> cursor(bucket_num + 2, 0);
  size_t running = 0;
  for (size_t b = 0; b < bucket_num; ++b) {
    running += population[b];
    cursor[b + 2] = running;
  }

  // Scatter into a scratch buffer, advancing the owning bucket's cursor.
  std::vector<Value> scratch(count);
  RandomIt src = begin;
  for (size_t k = 0; k < count; ++k, ++src) {
    size_t& slot = cursor[bucketIds[k] + 1];
    scratch[slot] = *src;
    ++slot;
  }

  // Every cursor has now advanced to the end of its bucket, so bucket b
  // occupies [cursor[b], cursor[b + 1]).
  // e.g.) cursor = {0, 2, 5, 7, 8, 8}
  for (size_t b = 0; b < bucket_num; ++b) {
    merge_sort::merge_sort(scratch.begin() + cursor[b], scratch.begin() + cursor[b + 1]);
  }

  // Sanity check: the concatenation of the sorted buckets must be sorted.
  const bool ordered = std::is_sorted(scratch.begin(), scratch.end());
  if (!ordered) {
    std::cerr << "Error: The bins are not sorted." << std::endl;
    exit(1);
  }

  // // Insertion sort
  // for (size_t i = 1; i < m2; ++i) {
  //   if (counter::counter.increment("insertion sort") && bins[i - 1] > bins[i]) {
  //     T tmp = bins[i];
  //     size_t j = i;
  //     do {
  //       bins[j] = bins[j - 1];
  //     } while (counter::counter.increment("insertion sort") && --j > 0 && bins[j - 1] > tmp);
  //     bins[j] = tmp;
  //   }
  // }

  std::copy(std::execution::unseq, scratch.begin(), scratch.end(), begin);
}

// Lower-bound search seeded with a position hint.
//
// Starting at index `pos`, the probe distance is doubled to the right (when
// the element at `pos` is < val) or to the left (otherwise) until `val` is
// bracketed, and the bracket is then narrowed by binary search. The result is
// the index of the first element of [begin, end) that is not < val, or the
// range length if there is none; this relies on the range being ordered
// consistently with operator<.
//
//   begin, end  range to search
//   val         value to locate
//   pos         starting index; must be smaller than the range length unless
//               the range is empty (reported on std::cerr and asserted)
//
// Returns 0 for an empty range. Every element comparison is preceded by
// counter::counter.increment("exponential search").
template <typename RandomIt, typename T>
size_t exponential_search(RandomIt begin, RandomIt end, const T& val, size_t pos) {
  const size_t n = std::distance(begin, end);
  if (n == 0) return 0;
  if (pos >= n) {
    std::cerr << "Error: pos is out of range; n = " << n << ", pos = " << pos << std::endl;
  }
  assert(0 <= pos && pos < n);

  // Counted comparisons. `>=` is kept as its own predicate rather than being
  // derived from `<`, so types where the two are not complementary behave as
  // written.
  auto below = [&](size_t idx) {
    return counter::counter.increment("exponential search") && *(begin + idx) < val;
  };
  auto at_or_above = [&](size_t idx) {
    return counter::counter.increment("exponential search") && *(begin + idx) >= val;
  };

  // Binary search for the first index in [first, last) whose element is not
  // < val; yields `last` when every probed element is < val.
  auto narrow = [&](size_t first, size_t last) {
    while (first < last) {
      const size_t mid = first + (last - first) / 2;
      if (below(mid)) {
        first = mid + 1;
      } else {
        last = mid;
      }
    }
    return first;
  };

  if (below(pos)) {
    // Answer lies to the right of the hint: gallop towards the end.
    size_t stride = 1;
    while (pos + stride < n && below(pos + stride)) {
      stride *= 2;
    }
    return narrow(pos + stride / 2, std::min(pos + stride, n));
  }

  // Answer is at or to the left of the hint: gallop towards the front.
  size_t stride = 1;
  while (pos >= stride && at_or_above(pos - stride)) {
    stride *= 2;
  }
  const size_t first = (pos < stride) ? 0 : pos - stride;
  return narrow(first, pos - stride / 2);
}

// Sorts [begin, end) by recursively sorting the first half, building an index
// over it, using that index to bucket the second half, and merging the halves.
//
//   begin, end  range to sort
//   index_type  which index to build over the sorted first half; one of
//               "binary_search", "btree", "espc", "pgm", "btree_approx",
//               "pgm_approx", "rmi"
//   epsilon     only read for the two "*_approx" types, where it must be > 0
//               (asserted): for "btree_approx" every epsilon-th key of the
//               first half is indexed; for "pgm_approx" it is forwarded to
//               Learned_Index_PGM_Epsilon
//
// Ranges of at most 128 elements are handed to std::sort and `index_type` is
// not consulted. For larger ranges an unrecognised `index_type` is reported on
// std::cerr and the function returns with the range unmodified.
template <typename RandomIt>
void learned_sort_using_learned_index(RandomIt begin, RandomIt end, std::string index_type, int epsilon = -1) {
  using T = typename std::iterator_traits<RandomIt>::value_type;
  size_t n = std::distance(begin, end);
  if (n <= 1) return;  // No need to sort if the vector has 0 or 1 element
  if (n <= 128) {
    std::sort(std::execution::unseq, begin, end);
    return;
  }

  // Validate the index type before touching the data. If this were only
  // detected after the recursive call, the range would be left half-sorted
  // and the enclosing recursion level would then fail its is_sorted check
  // with an unrelated error message.
  static constexpr const char* kIndexTypes[] = {
    "binary_search", "btree", "espc", "pgm", "btree_approx", "pgm_approx", "rmi"
  };
  const bool known_type = std::any_of(
    std::begin(kIndexTypes), std::end(kIndexTypes),
    [&](const char* name) { return index_type == name; });
  if (!known_type) {
    std::cerr << "Unknown index type: " << index_type << std::endl;
    return;
  }

  const size_t m1 = n / 2;      // size of the first half
  const size_t m2 = n - m1;     // size of the second half
  const RandomIt mid = begin + m1;

  // recursively sort the first half of the input vector
  learned_sort_using_learned_index(begin, mid, index_type, epsilon);
  if (!std::is_sorted(begin, mid)) {
    std::cerr << "Error: The first half of the input vector is not sorted." << std::endl;
    exit(1);
  }

  // bucketIds[i] first receives the lower-bound rank (0..m1) of the i-th
  // element of the second half within the sorted first half.
  std::vector<size_t> bucketIds(m2);

  // Exact indexes: the rank is whatever lower_bound() returns.
  auto fill_from_index = [&](auto& index) {
    for (size_t i = 0; i < m2; ++i) bucketIds[i] = index.lower_bound(*(mid + i));
  };
  // Approximate indexes: the index only supplies a starting position, which
  // exponential_search turns into the rank.
  auto fill_from_hints = [&](const std::vector<size_t>& hints) {
    for (size_t i = 0; i < m2; ++i) {
      bucketIds[i] = exponential_search(begin, mid, *(mid + i), hints[i]);
    }
  };

  // construct learned index using the first half of the input vector
  if (index_type == "binary_search") {
    learned_sort_using_learned_index::Learned_Index_BinarySearch<T> learned_index(begin, mid);
    fill_from_index(learned_index);
  } else if (index_type == "btree") {
    learned_sort_using_learned_index::Learned_Index_BTree<T> learned_index(begin, mid);
    fill_from_index(learned_index);
  } else if (index_type == "espc") {
    learned_sort_using_learned_index::Learned_Index_ESPC<T> learned_index(begin, mid);
    fill_from_index(learned_index);
  } else if (index_type == "pgm") {
    learned_sort_using_learned_index::Learned_Index_PGM<T> learned_index(begin, mid);
    fill_from_index(learned_index);
  } else if (index_type == "btree_approx") {
    assert(epsilon > 0);
    // Index only every epsilon-th key; a hit at sample j maps back to
    // position j * epsilon, clamped to the last valid index.
    std::vector<T> keys;
    for (size_t i = 0; i < m1; i += epsilon) keys.push_back(*(begin + i));
    learned_sort_using_learned_index::Learned_Index_BTree<T> learned_index(keys.begin(), keys.end());
    std::vector<size_t> approx_pos(m2);
    for (size_t i = 0; i < m2; ++i) approx_pos[i] = std::min(learned_index.lower_bound(*(mid + i)) * epsilon, m1 - 1);
    fill_from_hints(approx_pos);
  } else if (index_type == "pgm_approx") {
    assert(epsilon > 0);
    learned_sort_using_learned_index::Learned_Index_PGM_Epsilon<T> learned_index(begin, mid, epsilon);
    std::vector<size_t> approx_pos(m2);
    for (size_t i = 0; i < m2; ++i) approx_pos[i] = std::min(learned_index.approx_pos(*(mid + i)), m1 - 1);
    fill_from_hints(approx_pos);
  } else if (index_type == "rmi") {
    learned_sort_using_learned_index::Learned_Index_RMI<T> learned_index(begin, mid, 128);
    fill_from_index(learned_index);
  } else {
    // Defensive only: the list checked above mirrors this chain.
    std::cerr << "Unknown index type: " << index_type << std::endl;
    return;
  }

  // Turn each rank r into a bucket id: 2r + 1 ("point bucket") when the
  // element equals the first-half key at r, 2r ("range bucket") otherwise.
  // r == m1 has no key to compare against and is always a range bucket.
  for (size_t i = 0; i < m2; ++i) {
    const size_t rank = bucketIds[i];
    const bool is_point = rank != m1 &&
                          counter::counter.increment("point bucket") &&
                          *(begin + rank) == *(mid + i);
    bucketIds[i] = 2 * rank + (is_point ? 1 : 0);
  }
  bucket_sort(mid, end, 2 * m1 + 1, bucketIds);

  // Merge the two sorted halves
  std::vector<T> tmp(n);
  std::merge(std::execution::unseq, begin, mid, mid, end, tmp.begin());
  std::copy(std::execution::unseq, tmp.begin(), tmp.end(), begin);
}

}  // namespace learned_sort_using_learned_index

#endif  // LS_USING_LI_H
