#ifndef LS_USING_LI_H
#define LS_USING_LI_H

#include <algorithm>
#include <cassert>
#include <chrono> // for Timer
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <execution>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map> // for Timer
#include <vector>

#include "learned_index.h"
#include "splay_tree/SplayTree.h"

namespace learned_sort_using_learned_index {

/// Accumulates elapsed wall-clock time (in seconds, measured with
/// std::chrono::steady_clock) and the number of completed start/stop pairs
/// per named section. The maps are not synchronized, so use one instance from
/// one thread at a time.
class Timer {
public:
    /// Records the start time of section `name`, overwriting any pending start.
    void start(const std::string& name) {
        start_times[name] = std::chrono::steady_clock::now();
    }

    /// Ends section `name`: adds the elapsed time to its running total and
    /// increments its count. Prints an error to stderr (and records nothing)
    /// if `name` has no pending start.
    void stop(const std::string& name) {
        const auto now = std::chrono::steady_clock::now();
        const auto it = start_times.find(name);  // single lookup, reused for erase
        if (it == start_times.end()) {
            std::cerr << "Error: Timer for " << name << " was not started." << std::endl;
            return;
        }
        // duration<double> (seconds) keeps sub-microsecond precision instead of
        // truncating each interval to whole microseconds.
        const std::chrono::duration<double> elapsed = now - it->second;
        total_times[name] += elapsed.count();
        ++total_counts[name];
        start_times.erase(it);
    }

    /// Returns a copy of the accumulated seconds per section name.
    std::unordered_map<std::string, double> get_total_times() const {
        return total_times;
    }

    /// Returns a copy of the number of completed start/stop pairs per section name.
    std::unordered_map<std::string, size_t> get_total_counts() const {
        return total_counts;
    }

    /// Discards pending starts, totals and counts.
    void clear() {
        start_times.clear();
        total_times.clear();
        total_counts.clear();
    }

private:
    std::unordered_map<std::string, std::chrono::steady_clock::time_point> start_times;
    std::unordered_map<std::string, double> total_times;   // seconds
    std::unordered_map<std::string, size_t> total_counts;
};

// Process-wide timer that the sorting routines below accumulate phase timings
// into. Declared `inline` (C++17) so that including this header from several
// translation units yields a single definition instead of an ODR violation /
// multiple-definition link error. Not synchronized.
inline Timer timer;

void bucket_sort(
  std::vector<double>::iterator begin, std::vector<double>::iterator end,
  size_t bucket_num, const std::vector<size_t>& bucketIds
) {
  timer.start("calc bucketSizes");
  size_t m2 = std::distance(begin, end);

  std::vector<size_t> bucketSizes(bucket_num, 0);

  for (size_t i = 0; i < m2; ++i) {
    ++bucketSizes[bucketIds[i]];
  }
  timer.stop("calc bucketSizes");

  timer.start("calc bucketOffsets");
  // Compute bucket offsets (prefix sums)
  // e.g.) bucket_num = 4, bucketSizes = {2, 3, 2, 1} -> bucketOffsets = {0, 0, 2, 5, 7, 8}
  std::vector<size_t> bucketOffsets(bucket_num + 2);
  bucketOffsets[0] = 0;
  bucketOffsets[1] = 0;
  std::partial_sum(bucketSizes.begin(), bucketSizes.end(), bucketOffsets.begin() + 2);
  timer.stop("calc bucketOffsets");

  std::vector<double> bins(m2);
  timer.start("bucketing");
  for (size_t i = 0; i < m2; ++i) {
    size_t pos = bucketOffsets[bucketIds[i]+ 1]++;
    bins[pos] = *(begin + i);
  }
  timer.stop("bucketing");

  timer.start("sort each bucket");
  // Sort each bucket
  // e.g.) bucketOffsets = {0, 2, 5, 7, 8, 8}
  std::vector<size_t> bucket_indices(bucket_num);
  std::iota(bucket_indices.begin(), bucket_indices.end(), 0);
  std::for_each(
    std::execution::unseq,
    bucket_indices.begin(),
    bucket_indices.end(),
    [&](size_t i) {
      if (bucketSizes[i] > 1) {
        std::sort(
          std::execution::unseq,
          bins.begin() + bucketOffsets[i],
          bins.begin() + bucketOffsets[i + 1]
        );
      }
    }
  );
  timer.stop("sort each bucket");

  timer.start("insertion sort");
  // Insertion sort (with the upper bound of moves)
  int k = m2 * 32;
  for (size_t i = 1; i < m2; ++i) {
    if (bins[i - 1] > bins[i]) {
      double tmp = bins[i];
      int j = i;
      do {
        bins[j] = bins[j - 1];
      } while (--j > 0 && --k >= 0 && bins[j - 1] > tmp);
      bins[j] = tmp;

      if (k < 0) break;
    }
  }
  timer.stop("insertion sort");

  if (!std::is_sorted(bins.begin(), bins.end())) {
    timer.start("splay tree sort");
    FastSplayTree *sTree = new FastSplayTree();
    for (size_t i = 0; i < m2; ++i) {
      sTree -> insert(bins[i]);
    }
    std::vector<double> sorted_keys = sTree -> getSortedKeys();
    timer.stop("splay tree sort");

    if(sorted_keys.size() != m2) {
      std::cerr << "Error: The number of sorted keys is not correct." << std::endl;
      exit(1);
    } else if(!std::is_sorted(sorted_keys.begin(), sorted_keys.end())) {
      std::cerr << "Error: The sorted keys are not sorted." << std::endl;
      exit(1);
    }

    timer.start("copy back");
    std::copy(std::execution::unseq, sorted_keys.begin(), sorted_keys.end(), begin);
    timer.stop("copy back");
  } else {
    timer.start("copy back");
    std::copy(std::execution::unseq, bins.begin(), bins.end(), begin);
    timer.stop("copy back");
  }
}

/// Sorts [begin, end) in ascending order, using a learned index trained on a
/// small sorted prefix as the bucketing model.
///
/// Algorithm:
///  1. Recursively sort the first m1 = n / 32 elements (the training prefix).
///  2. Build an index of kind `index_type` over that prefix and map each of the
///     remaining n - m1 elements to a bucket id.
///  3. bucket_sort the remaining elements by those ids.
///  4. Merge the sorted prefix and the sorted remainder back into [begin, end).
///
/// @param begin, end  Range to sort in place.
/// @param index_type  One of "binary_search", "btree", "espc", "pgm", "rmi".
///                    Any other value prints an error and returns early, leaving
///                    the range NOT fully sorted.
///
/// Ranges of at most 128 elements are sorted directly with std::sort. Per-phase
/// timings are accumulated into the global `timer`. Terminates the process with
/// exit(1) if the recursively sorted prefix is found unsorted.
inline void learned_sort_using_learned_index(std::vector<double>::iterator begin, std::vector<double>::iterator end, std::string index_type) {
  const size_t n = std::distance(begin, end);
  if (n <= 1) return;  // No need to sort if the vector has 0 or 1 element
  if (n <= 128) {
    timer.start("std::sort");
    std::sort(std::execution::unseq, begin, end);
    timer.stop("std::sort");
    return;
  }

  const size_t m1 = n / 32;  // size of the sorted prefix the index is built on
  const size_t m2 = n - m1;  // number of elements routed through bucket_sort

  // Recursively sort the first m1 elements of the input.
  learned_sort_using_learned_index(begin, begin + m1, index_type);
  if (!std::is_sorted(begin, begin + m1)) {
    std::cerr << "Error: The first half of the input vector is not sorted." << std::endl;
    exit(1);
  }

  // Compile-time switch. true: bucket by exact lower_bound positions and split
  // each position into a point bucket and a range bucket (see below).
  // false: use the cheaper approximate positions reported by the index.
  constexpr bool exact_lower_bound = false;

  // Build a learned index over the sorted prefix and assign a bucket id to every
  // remaining element. Approximate ids are clamped to m1 so they always index a
  // valid bucket; their accuracy only affects speed, because bucket_sort repairs
  // any misordering (bounded insertion sort, then splay-tree fallback).
  std::vector<size_t> bucketIds(m2);
  if (index_type == "binary_search") {
    timer.start("construct learned index");
    learned_sort_using_learned_index::Learned_Index_BinarySearch learned_index(begin, begin + m1);
    timer.stop("construct learned index");
    timer.start("query learned index");
    for (size_t i = 0; i < m2; ++i) bucketIds[i] = learned_index.lower_bound(*(begin + m1 + i));
    timer.stop("query learned index");
  } else if (index_type == "btree") {
    if (exact_lower_bound) {
      timer.start("construct learned index");
      learned_sort_using_learned_index::Learned_Index_BTree learned_index(begin, begin + m1);
      timer.stop("construct learned index");
      timer.start("query learned index");
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = learned_index.lower_bound(*(begin + m1 + i));
      timer.stop("query learned index");
    } else {
      // Index only every epsilon-th prefix key; a result p on the sample maps
      // back to prefix position p * epsilon.
      constexpr int epsilon = 64;
      std::vector<double> keys;
      keys.reserve(m1 / epsilon + 1);
      for (size_t i = 0; i < m1; i += epsilon) keys.push_back(*(begin + i));

      timer.start("construct learned index");
      learned_sort_using_learned_index::Learned_Index_BTree learned_index(keys.begin(), keys.end());
      timer.stop("construct learned index");

      std::vector<size_t> approx_pos(m2);
      timer.start("query learned index");
      for (size_t i = 0; i < m2; ++i) approx_pos[i] = learned_index.lower_bound(*(begin + m1 + i));
      timer.stop("query learned index");

      for (size_t i = 0; i < m2; ++i) bucketIds[i] = std::min(approx_pos[i] * epsilon, m1 - 1);
    }
  } else if (index_type == "espc") {
    timer.start("construct learned index");
    learned_sort_using_learned_index::Learned_Index_ESPC learned_index(begin, begin + m1);
    timer.stop("construct learned index");
    timer.start("query learned index");
    if (exact_lower_bound) {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = learned_index.lower_bound(*(begin + m1 + i));
    } else {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = std::min<size_t>(learned_index.approx_pos(*(begin + m1 + i)), m1);
    }
    timer.stop("query learned index");
  } else if (index_type == "pgm") {
    timer.start("construct learned index");
    learned_sort_using_learned_index::Learned_Index_PGM learned_index(begin, begin + m1);
    timer.stop("construct learned index");
    timer.start("query learned index");
    if (exact_lower_bound) {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = learned_index.lower_bound(*(begin + m1 + i));
    } else {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = std::min<size_t>(learned_index.approx_pos(*(begin + m1 + i)), m1);
    }
    timer.stop("query learned index");
  } else if (index_type == "rmi") {
    timer.start("construct learned index");
    learned_sort_using_learned_index::Learned_Index_RMI learned_index(begin, begin + m1, 1024);
    timer.stop("construct learned index");
    timer.start("query learned index");
    if (exact_lower_bound) {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = learned_index.lower_bound(*(begin + m1 + i));
    } else {
      for (size_t i = 0; i < m2; ++i) bucketIds[i] = std::min<size_t>(learned_index.approx_pos(*(begin + m1 + i)), m1);
    }
    timer.stop("query learned index");
  } else {
    // Unknown index type: report and bail out; the range is left partially sorted.
    std::cerr << "Unknown index type: " << index_type << std::endl;
    return;
  }

  // test
  // {
  //   Learned_Index_BinarySearch binary_search(begin, begin + m1);
  //   Learned_Index_BTree btree(begin, begin + m1);
  //   Learned_Index_ESPC espc(begin, begin + m1);
  //   Learned_Index_PGM pgm(begin, begin + m1);
  //   Learned_Index_RMI rmi(begin, begin + m1, 128);
  //   std::vector<double> queries(begin + m1, end);
  //   for (double q : queries) {
  //     size_t bucketId_binary_search = binary_search.lower_bound(q);
  //     size_t bucketId_btree = btree.lower_bound(q);
  //     size_t bucketId_espc = espc.lower_bound(q);
  //     size_t bucketId_pgm = pgm.lower_bound(q);
  //     size_t bucketId_rmi = rmi.lower_bound(q);

  //     size_t ground_truth = std::lower_bound(begin, begin + m1, q) - begin;

  //     if (
  //       bucketId_binary_search != ground_truth ||
  //       bucketId_btree != ground_truth ||
  //       bucketId_espc != ground_truth || 
  //       bucketId_pgm != ground_truth || 
  //       bucketId_rmi != ground_truth
  //     ) {
  //       std::cerr << "[Error] for q = " << q;
  //       std::cerr << ", bucketId_binary_search = " << bucketId_binary_search << ", bucketId_btree = " << bucketId_btree;
  //       std::cerr << ", bucketId_espc = " << bucketId_espc << ", bucketId_pgm = " << bucketId_pgm << ", bucketId_rmi = " << bucketId_rmi;
  //       std::cerr << ", bucketId_rmi = " << bucketId_rmi << ", ground_truth = " << ground_truth << std::endl;
  //       // debug = true;
  //       // espc.lower_bound(q);
  //       // pgm.lower_bound(q);
  //       exit(1);
  //     }
  //   }
  // }

  if (exact_lower_bound) {
    // For a lower_bound position p (prefix[p - 1] < x <= prefix[p]):
    //   x == prefix[p] -> point bucket 2p + 1 (all keys equal)
    //   otherwise      -> range bucket 2p     (strictly between prefix[p - 1] and prefix[p])
    // p == m1 (x greater than every prefix key) -> range bucket 2 * m1.
    for (size_t i = 0; i < m2; ++i) {
      if (bucketIds[i] == m1) {
        bucketIds[i] = 2 * bucketIds[i];  // range bucket
        continue;
      }
      if (*(begin + bucketIds[i]) == *(begin + m1 + i)) {
        bucketIds[i] = 2 * bucketIds[i] + 1;  // point bucket
      } else {
        bucketIds[i] = 2 * bucketIds[i];  // range bucket
      }
    }
    bucket_sort(begin + m1, end, 2 * m1 + 1, bucketIds);
  } else {
    bucket_sort(begin + m1, end, m1 + 1, bucketIds);
  }

  timer.start("merge");
  // Merge the sorted prefix [begin, begin + m1) with the sorted remainder [begin + m1, end).
  std::vector<double> tmp(n);
  std::merge(std::execution::unseq, begin, begin + m1, begin + m1, end, tmp.begin());
  std::copy(std::execution::unseq, tmp.begin(), tmp.end(), begin);
  timer.stop("merge");
}

}  // namespace learned_sort_using_learned_index

#endif  // LS_USING_LI_H
