#ifndef LI_H
#define LI_H

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <execution>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <vector>

#include "btree/btree_multimap.h"
#include "pgm/pgm_index.hpp"
#include "rmi/rmi.hpp"

namespace learned_sort_using_learned_index {

/// Abstract query interface implemented by every index class in this file.
///
/// T is the key type. Both queries take a key and report a position as a
/// size_t; how that position is obtained is left to the implementing class.
template <class T>
class Learned_Index {
public:
  /// Position query. The implementations in this file return the index of the
  /// first stored key that is not less than `q`.
  virtual size_t lower_bound(const T& q) const = 0;

  /// Position estimate for `q`. Implementations are free to return a cheaper,
  /// less precise answer than lower_bound().
  virtual size_t approx_pos(const T& q) const = 0;

  /// Virtual so that concrete indexes can be destroyed through a base pointer.
  virtual ~Learned_Index() = default;
};

/// Baseline index that answers every query with a plain binary search.
///
/// The constructor copies [begin, end) into the object. The range is treated
/// as sorted in ascending order: its first element is used as the minimum key,
/// its last element as the maximum key, and std::lower_bound runs over the copy.
template <typename T>
class Learned_Index_BinarySearch : public Learned_Index<T> {
public:
  /// Builds the index over a private copy of [begin, end).
  /// @param begin  iterator to the first key of the range
  /// @param end    iterator one past the last key of the range
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_BinarySearch(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end)
    : data_(checked_begin(begin, end), end), n_(data_.size()), x_min_(data_.front()), x_max_(data_.back()) {}

  /// @return index of the first stored key that is >= q, or the number of
  ///         keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override { return locate(q); }

  /// Same answer as lower_bound(); this class has no cheaper approximation.
  size_t approx_pos(const T& q) const override { return locate(q); }

private:
  using Iter = typename std::vector<T>::iterator;

  // Rejects an empty range *before* any member dereferences an iterator.
  // (An assert in the constructor body would run too late and vanishes
  // under NDEBUG.)
  static Iter checked_begin(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_BinarySearch: key range must not be empty");
    }
    return begin;
  }

  // Shared implementation of both queries.
  size_t locate(const T& q) const {
    // Fast paths for keys outside the stored range.
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return std::lower_bound(data_.begin(), data_.end(), q) - data_.begin();
  }

  const std::vector<T> data_;
  const size_t n_;
  const T x_min_, x_max_;
};

/// Index backed by an stx::btree_multimap that maps each key to its position
/// in the input range.
///
/// The input range is treated as sorted in ascending order: its first element
/// is used as the minimum key and its last element as the maximum key.
template <typename T>
class Learned_Index_BTree : public Learned_Index<T> {
public:
  /// Bulk-loads the tree with (key, position) pairs taken from [begin, end).
  /// The keys are not retained outside the tree.
  /// @param begin  iterator to the first key of the range
  /// @param end    iterator one past the last key of the range
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_BTree(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end)
    : n_(checked_size(begin, end)), x_min_(*begin), x_max_(*(end - 1)) {
    std::vector<std::pair<T, size_t>> keyed;
    keyed.reserve(n_);
    size_t position = 0;
    for (auto it = begin; it != end; ++it) {
      keyed.emplace_back(*it, position++);
    }
    btree_.bulk_load(keyed.begin(), keyed.end());
  }

  /// @return position stored with the first tree entry whose key is >= q, or
  ///         the number of keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override { return locate(q); }

  /// Same answer as lower_bound(); this class has no cheaper approximation.
  size_t approx_pos(const T& q) const override { return locate(q); }

private:
  using Iter = typename std::vector<T>::iterator;

  // Validates the range before any iterator is dereferenced and returns its
  // length. (An assert in the constructor body would run too late and
  // vanishes under NDEBUG.)
  static size_t checked_size(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_BTree: key range must not be empty");
    }
    return static_cast<size_t>(std::distance(begin, end));
  }

  // Shared implementation of both queries.
  size_t locate(const T& q) const {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    // Here x_min_ < q <= x_max_, so the tree lookup is not expected to hit end().
    return btree_.lower_bound(q)->second;
  }

  const size_t n_;
  const T x_min_, x_max_;
  stx::btree_multimap<T, size_t> btree_;
};

/// Index that maps a key to one of n equal-width buckets spanning
/// [x_min, x_max] and keeps cumulative bucket counts.
///
/// The constructor copies [begin, end) into the object. The range is treated
/// as sorted in ascending order: its first element is used as the minimum key
/// and its last element as the maximum key. After construction, counts_[b] is
/// the number of stored keys whose bucket is <= b.
template <typename T>
class Learned_Index_ESPC : public Learned_Index<T> {
public:
  /// Copies the keys, derives the key-to-bucket scale and builds the
  /// cumulative bucket counts.
  /// @param begin  iterator to the first key of the range
  /// @param end    iterator one past the last key of the range
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_ESPC(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end)
    : data_(checked_begin(begin, end), end), n_(data_.size()), x_min_(data_.front()), x_max_(data_.back()),
      scale_factor_(compute_scale(n_, x_min_, x_max_)), counts_(n_, 0) {
    for (const T& key : data_) {
      ++counts_[bucket_of(key)];
    }
    std::partial_sum(counts_.begin(), counts_.end(), counts_.begin());
  }

  /// @return index of the first stored key that is >= q, or the number of
  ///         keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    const size_t bucket = bucket_of(q);
    // Keys in earlier buckets are all < q and keys in later buckets are all
    // > q, so only the slice belonging to this bucket has to be searched.
    const size_t first = (bucket == 0) ? 0 : counts_[bucket - 1];
    const size_t last = counts_[bucket];
    return std::lower_bound(data_.begin() + first, data_.begin() + last, q) - data_.begin();
  }

  /// @return the cumulative count of q's bucket, i.e. the position just past
  ///         the last stored key that falls into that bucket.
  size_t approx_pos(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return counts_[bucket_of(q)];
  }

private:
  using Iter = typename std::vector<T>::iterator;

  // Rejects an empty range *before* any member dereferences an iterator.
  // (An assert in the constructor body would run too late and vanishes
  // under NDEBUG.)
  static Iter checked_begin(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_ESPC: key range must not be empty");
    }
    return begin;
  }

  // Buckets per unit of key space. The span is computed in long double so a
  // wide signed key range cannot overflow T. A zero span (all keys equal,
  // including the single-key case) yields 0 instead of dividing by zero,
  // which sends every key to bucket 0.
  static long double compute_scale(size_t n, const T& lo, const T& hi) {
    const long double span = static_cast<long double>(hi) - static_cast<long double>(lo);
    return (span > 0) ? static_cast<long double>(n - 1) / span : 0.0L;
  }

  // Maps a key to its bucket in [0, n_ - 1]. Non-positive (or NaN) products
  // go to bucket 0 and oversized ones to the last bucket, so the
  // float-to-integer conversion is always in range.
  size_t bucket_of(const T& key) const {
    const long double offset = static_cast<long double>(key) - static_cast<long double>(x_min_);
    const long double scaled = offset * scale_factor_;
    if (!(scaled > 0)) return 0;
    if (scaled >= static_cast<long double>(n_ - 1)) return n_ - 1;
    return static_cast<size_t>(scaled);
  }

  const std::vector<T> data_;
  const size_t n_;
  const T x_min_, x_max_;
  const long double scale_factor_;
  std::vector<size_t> counts_;
};

/// Index backed by a pgm::PGMIndex<T> built with the library's default
/// template arguments.
///
/// The constructor copies [begin, end) into the object. The range is treated
/// as sorted in ascending order: its first element is used as the minimum key
/// and its last element as the maximum key.
template <typename T>
class Learned_Index_PGM : public Learned_Index<T> {
public:
  /// Copies the keys and builds the PGM index over the same range.
  /// @param begin  iterator to the first key of the range
  /// @param end    iterator one past the last key of the range
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_PGM(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end)
    : data_(checked_begin(begin, end), end), n_(data_.size()), x_min_(data_.front()), x_max_(data_.back()),
      pgm_index_(begin, end) {}

  /// @return index of the first stored key that is >= q, found by a binary
  ///         search restricted to the [lo, hi) window reported by the PGM
  ///         index; the number of keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    const pgm::ApproxPos window = pgm_index_.search(q);
    return std::lower_bound(data_.begin() + window.lo, data_.begin() + window.hi, q) - data_.begin();
  }

  /// @return the lower end (`lo`) of the window the PGM index reports for q.
  size_t approx_pos(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return pgm_index_.search(q).lo;
  }

private:
  using Iter = typename std::vector<T>::iterator;

  // Rejects an empty range *before* any member dereferences an iterator.
  // (An assert in the constructor body would run too late and vanishes
  // under NDEBUG.)
  static Iter checked_begin(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_PGM: key range must not be empty");
    }
    return begin;
  }

  const std::vector<T> data_;
  const size_t n_;
  const T x_min_, x_max_;
  const pgm::PGMIndex<T> pgm_index_;
};

/// PGM-backed index whose epsilon template argument is picked at run time.
///
/// One pgm::PGMIndex member exists per supported epsilon (the powers of two
/// from 1 to 8192); only the member matching the requested epsilon is built,
/// the others stay default-constructed. The constructor copies [begin, end)
/// into the object. The range is treated as sorted in ascending order: its
/// first element is used as the minimum key and its last as the maximum key.
template <typename T>
class Learned_Index_PGM_Epsilon : public Learned_Index<T> {
public:
  /// Copies the keys and builds the PGM index for the requested epsilon.
  /// @param begin    iterator to the first key of the range
  /// @param end      iterator one past the last key of the range
  /// @param epsilon  one of 1, 2, 4, ..., 8192; any other value prints
  ///                 "Unknown epsilon: <value>" to std::cerr and exits with
  ///                 status 1
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_PGM_Epsilon(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end, size_t epsilon)
    : data_(checked_begin(begin, end), end), n_(data_.size()), x_min_(data_.front()), x_max_(data_.back()),
      epsilon_(epsilon) {
    switch (epsilon_) {
      case 1:    pgm_index_1    = pgm::PGMIndex<T, 1>(begin, end);    break;
      case 2:    pgm_index_2    = pgm::PGMIndex<T, 2>(begin, end);    break;
      case 4:    pgm_index_4    = pgm::PGMIndex<T, 4>(begin, end);    break;
      case 8:    pgm_index_8    = pgm::PGMIndex<T, 8>(begin, end);    break;
      case 16:   pgm_index_16   = pgm::PGMIndex<T, 16>(begin, end);   break;
      case 32:   pgm_index_32   = pgm::PGMIndex<T, 32>(begin, end);   break;
      case 64:   pgm_index_64   = pgm::PGMIndex<T, 64>(begin, end);   break;
      case 128:  pgm_index_128  = pgm::PGMIndex<T, 128>(begin, end);  break;
      case 256:  pgm_index_256  = pgm::PGMIndex<T, 256>(begin, end);  break;
      case 512:  pgm_index_512  = pgm::PGMIndex<T, 512>(begin, end);  break;
      case 1024: pgm_index_1024 = pgm::PGMIndex<T, 1024>(begin, end); break;
      case 2048: pgm_index_2048 = pgm::PGMIndex<T, 2048>(begin, end); break;
      case 4096: pgm_index_4096 = pgm::PGMIndex<T, 4096>(begin, end); break;
      case 8192: pgm_index_8192 = pgm::PGMIndex<T, 8192>(begin, end); break;
      default:   fail_unknown_epsilon(epsilon_);
    }
  }

  /// @return index of the first stored key that is >= q, found by a binary
  ///         search restricted to the [lo, hi) window reported by the PGM
  ///         index; the number of keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    const pgm::ApproxPos window = search(q);
    return std::lower_bound(data_.begin() + window.lo, data_.begin() + window.hi, q) - data_.begin();
  }

  /// @return the midpoint of the [lo, hi) window the PGM index reports for q.
  size_t approx_pos(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    const pgm::ApproxPos window = search(q);
    return (window.lo + window.hi) / 2;
  }

private:
  using Iter = typename std::vector<T>::iterator;

  // Rejects an empty range *before* any member dereferences an iterator.
  // (An assert in the constructor body would run too late and vanishes
  // under NDEBUG.)
  static Iter checked_begin(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_PGM_Epsilon: key range must not be empty");
    }
    return begin;
  }

  // Reports an unsupported epsilon and terminates the process.
  [[noreturn]] static void fail_unknown_epsilon(size_t epsilon) {
    std::cerr << "Unknown epsilon: " << epsilon << std::endl;
    std::exit(1);
  }

  // Forwards the query to the one PGM index that was built for epsilon_.
  pgm::ApproxPos search(const T& q) const {
    switch (epsilon_) {
      case 1:    return pgm_index_1.search(q);
      case 2:    return pgm_index_2.search(q);
      case 4:    return pgm_index_4.search(q);
      case 8:    return pgm_index_8.search(q);
      case 16:   return pgm_index_16.search(q);
      case 32:   return pgm_index_32.search(q);
      case 64:   return pgm_index_64.search(q);
      case 128:  return pgm_index_128.search(q);
      case 256:  return pgm_index_256.search(q);
      case 512:  return pgm_index_512.search(q);
      case 1024: return pgm_index_1024.search(q);
      case 2048: return pgm_index_2048.search(q);
      case 4096: return pgm_index_4096.search(q);
      case 8192: return pgm_index_8192.search(q);
      // The constructor already rejects other values; kept as a safety net.
      default:   fail_unknown_epsilon(epsilon_);
    }
  }

  const std::vector<T> data_;
  const size_t n_;
  const T x_min_, x_max_;
  size_t epsilon_;
  pgm::PGMIndex<T, 1> pgm_index_1;
  pgm::PGMIndex<T, 2> pgm_index_2;
  pgm::PGMIndex<T, 4> pgm_index_4;
  pgm::PGMIndex<T, 8> pgm_index_8;
  pgm::PGMIndex<T, 16> pgm_index_16;
  pgm::PGMIndex<T, 32> pgm_index_32;
  pgm::PGMIndex<T, 64> pgm_index_64;
  pgm::PGMIndex<T, 128> pgm_index_128;
  pgm::PGMIndex<T, 256> pgm_index_256;
  pgm::PGMIndex<T, 512> pgm_index_512;
  pgm::PGMIndex<T, 1024> pgm_index_1024;
  pgm::PGMIndex<T, 2048> pgm_index_2048;
  pgm::PGMIndex<T, 4096> pgm_index_4096;
  pgm::PGMIndex<T, 8192> pgm_index_8192;
};

/// Index backed by an rmi::Rmi that uses rmi::LinearRegression for both of
/// its model template arguments.
///
/// The constructor copies [begin, end) into the object. The range is treated
/// as sorted in ascending order: its first element is used as the minimum key
/// and its last element as the maximum key.
template <typename T>
class Learned_Index_RMI : public Learned_Index<T> {
public:
  /// Copies the keys and builds the RMI over the same range.
  /// @param begin        iterator to the first key of the range
  /// @param end          iterator one past the last key of the range
  /// @param layer2_size  forwarded unchanged to the rmi::Rmi constructor
  /// @throws std::invalid_argument if the range is empty
  Learned_Index_RMI(typename std::vector<T>::iterator begin, typename std::vector<T>::iterator end, size_t layer2_size)
    : data_(checked_begin(begin, end), end), n_(data_.size()), x_min_(data_.front()), x_max_(data_.back()),
      rmi_index_(begin, end, layer2_size) {}

  /// @return index of the first stored key that is >= q, found by a binary
  ///         search restricted to the [lo, hi) window reported by the RMI;
  ///         the number of keys when q is greater than the last key.
  size_t lower_bound(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    const rmi::Approx window = rmi_index_.search(q);
    return std::lower_bound(data_.begin() + window.lo, data_.begin() + window.hi, q) - data_.begin();
  }

  /// @return the `pos` field of the RMI's search result for q.
  size_t approx_pos(const T& q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return rmi_index_.search(q).pos;
  }

private:
  using Iter = typename std::vector<T>::iterator;

  // Rejects an empty range *before* any member dereferences an iterator.
  // (An assert in the constructor body would run too late and vanishes
  // under NDEBUG.)
  static Iter checked_begin(Iter begin, Iter end) {
    if (begin == end) {
      throw std::invalid_argument("Learned_Index_RMI: key range must not be empty");
    }
    return begin;
  }

  const std::vector<T> data_;
  const size_t n_;
  const T x_min_, x_max_;
  const rmi::Rmi<T, rmi::LinearRegression, rmi::LinearRegression> rmi_index_;
};

}  // namespace learned_sort_using_learned_index

#endif  // LI_H
