#ifndef LI_H
#define LI_H

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <numeric>
#include <vector>
#include <iostream>
#include <cassert>
#include <memory>
#include <execution>
#include <set>
#include "btree/btree_multimap.h"
#include "pgm/pgm_index.hpp"
#include "rmi/rmi.hpp"

namespace learned_sort_using_learned_index {

// Abstract interface shared by every index in this header: a structure built
// over a sequence of double keys that maps a query key to a position in that
// sequence.
class Learned_Index {
public:
  // Virtual so that a concrete index can be destroyed through a
  // Learned_Index pointer or reference.
  virtual ~Learned_Index() = default;
  // Exact lookup. The implementations in this header return the position of
  // the first stored key that is not less than q, or the number of stored
  // keys when q is greater than the largest key.
  virtual size_t lower_bound(double q) const = 0;
  // Position estimate for q. How close it is to lower_bound(q) depends on the
  // implementation: some return the exact answer, others the raw prediction
  // of their model or the low end of its error window.
  virtual size_t approx_pos(double q) const = 0;
};

// Baseline index that answers every query with a binary search over a private
// copy of the keys.
//
// Precondition: [begin, end) is non-empty and sorted in ascending order (the
// first key is taken as the minimum and the last key as the maximum).
class Learned_Index_BinarySearch : public Learned_Index {
public:
  // Copies the keys in [begin, end). The non-empty precondition is asserted
  // before the first/last key is read, so a violation is reported in debug
  // builds instead of being preceded by an out-of-range dereference.
  Learned_Index_BinarySearch(std::vector<double>::iterator begin, std::vector<double>::iterator end)
    : data_(begin, end),
      n_(data_.size()),
      x_min_(require_non_empty(data_).front()),
      x_max_(data_.back()) {}

  // Returns the position of the first key that is not less than q, or the
  // number of keys when q is greater than every key.
  size_t lower_bound(double q) const override { return locate(q); }

  // This index has no model, so the "approximate" position is the exact one.
  size_t approx_pos(double q) const override { return locate(q); }

private:
  // Asserts the non-empty precondition and hands the vector back so the check
  // can sit inside the member-initializer list.
  static const std::vector<double>& require_non_empty(const std::vector<double>& keys) {
    assert(!keys.empty() && "Learned_Index_BinarySearch requires at least one key");
    return keys;
  }

  // Shared, non-virtual implementation of both queries.
  size_t locate(double q) const {
    if (q > x_max_) return n_;
    // Written as !(q > x_min_) so that a NaN query is answered here with 0,
    // the same result std::lower_bound would give for it.
    if (!(q > x_min_)) return 0;
    return static_cast<size_t>(std::lower_bound(data_.begin(), data_.end(), q) - data_.begin());
  }

  const std::vector<double> data_;
  const size_t n_;
  const double x_min_, x_max_;
};

// Index backed by an stx::btree_multimap that maps each key to its position
// in the input sequence.
//
// Precondition: [begin, end) is non-empty and sorted in ascending order (the
// first key is taken as the minimum and the last key as the maximum, and the
// key/position pairs are bulk-loaded in input order).
class Learned_Index_BTree : public Learned_Index {
public:
  // Bulk-loads the tree with (key, position) pairs read directly from
  // [begin, end). No copy of the keys is retained: after construction every
  // query is answered from the tree alone.
  Learned_Index_BTree(std::vector<double>::iterator begin, std::vector<double>::iterator end)
    : n_(checked_size(begin, end)), x_min_(*begin), x_max_(*(end - 1)) {
    std::vector<std::pair<double, size_t>> keyed;
    keyed.reserve(n_);
    size_t position = 0;
    for (auto it = begin; it != end; ++it, ++position) {
      keyed.emplace_back(*it, position);
    }
    btree_.bulk_load(keyed.begin(), keyed.end());
  }

  // Returns the position stored with the first tree entry whose key is not
  // less than q, or the number of keys when q is greater than every key.
  size_t lower_bound(double q) const override { return locate(q); }

  // The tree lookup is exact, so the approximate position equals lower_bound.
  size_t approx_pos(double q) const override { return locate(q); }

private:
  // Asserts the non-empty precondition before any iterator is dereferenced
  // and returns the number of keys. n_ is the first member initialised, so
  // the check runs ahead of the reads of *begin and *(end - 1).
  static size_t checked_size(std::vector<double>::iterator begin, std::vector<double>::iterator end) {
    assert(begin != end && "Learned_Index_BTree requires at least one key");
    return static_cast<size_t>(std::distance(begin, end));
  }

  // Shared, non-virtual implementation of both queries.
  size_t locate(double q) const {
    if (q > x_max_) return n_;
    // !(q > x_min_) also catches NaN, so the tree is only searched with
    // x_min_ < q <= x_max_, where a matching entry exists for sorted input.
    if (!(q > x_min_)) return 0;
    return btree_.lower_bound(q)->second;
  }

  const size_t n_;
  const double x_min_, x_max_;
  stx::btree_multimap<double, size_t> btree_;
};

// Index that splits the key range [x_min_, x_max_] into n equal-width buckets
// (n = number of keys) and stores, for every bucket, the running total of
// keys that fall into it or any earlier bucket. A query is mapped to its
// bucket arithmetically and then resolved inside that bucket's slice of keys.
//
// Precondition: [begin, end) is non-empty and sorted in ascending order (the
// first key is taken as the minimum and the last key as the maximum).
class Learned_Index_ESPC : public Learned_Index {
public:
  // Copies the keys and builds the cumulative bucket counts.
  //
  // When every key is equal (including the single-key case) the range has
  // zero width; the scale factor is then set to 0 so that all keys land in
  // bucket 0, instead of dividing by zero and converting a NaN to an index.
  Learned_Index_ESPC(std::vector<double>::iterator begin, std::vector<double>::iterator end)
    : data_(begin, end),
      n_(data_.size()),
      x_min_(require_non_empty(data_).front()),
      x_max_(data_.back()),
      scale_factor_(x_max_ > x_min_ ? static_cast<double>(n_ - 1) / (x_max_ - x_min_) : 0.0),
      counts_(n_, 0) {
    for (const double key : data_) {
      ++counts_[bucket_of(key)];
    }
    // After the prefix sum, counts_[b] is the number of keys in buckets 0..b,
    // i.e. the end position of bucket b's slice in data_.
    std::partial_sum(counts_.begin(), counts_.end(), counts_.begin());
  }

  // Returns the position of the first key that is not less than q, or the
  // number of keys when q is greater than every key.
  size_t lower_bound(double q) const override {
    if (q > x_max_) return n_;
    // !(q > x_min_) also catches NaN, which must not reach bucket_of().
    if (!(q > x_min_)) return 0;
    const size_t bucket = bucket_of(q);
    const size_t slice_begin = bucket == 0 ? 0 : counts_[bucket - 1];
    const size_t slice_end = counts_[bucket];
    return static_cast<size_t>(
        std::lower_bound(data_.begin() + slice_begin, data_.begin() + slice_end, q) - data_.begin());
  }

  // Returns the end position of the bucket q falls into, without searching
  // inside the bucket.
  size_t approx_pos(double q) const override {
    if (q > x_max_) return n_;
    if (!(q > x_min_)) return 0;
    return counts_[bucket_of(q)];
  }

private:
  // Asserts the non-empty precondition and hands the vector back so the check
  // runs inside the member-initializer list, before front()/back() are read.
  static const std::vector<double>& require_non_empty(const std::vector<double>& keys) {
    assert(!keys.empty() && "Learned_Index_ESPC requires at least one key");
    return keys;
  }

  // Maps a key to its bucket in [0, n_ - 1]. The scaled value is clamped
  // while still a double, so a NaN, a negative value, or a value pushed past
  // the last bucket by rounding is never converted to size_t out of range.
  size_t bucket_of(double key) const {
    const double scaled = (key - x_min_) * scale_factor_;
    if (!(scaled > 0.0)) return 0;
    const double last_bucket = static_cast<double>(n_ - 1);
    return scaled < last_bucket ? static_cast<size_t>(scaled) : n_ - 1;
  }

  const std::vector<double> data_;
  const size_t n_;
  const double x_min_, x_max_, scale_factor_;
  std::vector<size_t> counts_;
};

class Learned_Index_PGM : public Learned_Index {
public:
  Learned_Index_PGM(std::vector<double>::iterator begin, std::vector<double>::iterator end)
    : data_(begin, end), n_(std::distance(begin, end)), x_min_(*begin), x_max_(*(end - 1)), pgm_index_(begin, end) {
    assert(n_ > 0);
  }

  size_t lower_bound(double q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    pgm::ApproxPos pos = pgm_index_.search(q);

    return std::lower_bound(data_.begin() + pos.lo, data_.begin() + pos.hi, q) - data_.begin();
  }

  size_t approx_pos(double q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return pgm_index_.search(q).lo;
  }

private:
  const std::vector<double> data_;
  const size_t n_;
  const double x_min_, x_max_;
  const pgm::PGMIndex<double> pgm_index_;
};

class Learned_Index_RMI : public Learned_Index {
public:
  Learned_Index_RMI(std::vector<double>::iterator begin, std::vector<double>::iterator end, size_t layer2_size)
    : data_(begin, end), n_(std::distance(begin, end)), x_min_(*begin), x_max_(*(end - 1)), rmi_index_(begin, end, layer2_size) {
    assert(n_ > 0);
  }

  size_t lower_bound(double q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    rmi::Approx pos = rmi_index_.search(q);

    return std::lower_bound(data_.begin() + pos.lo, data_.begin() + pos.hi, q) - data_.begin();
  }

  size_t approx_pos(double q) const override {
    if (q > x_max_) return n_;
    if (q <= x_min_) return 0;
    return rmi_index_.search(q).pos;
  }

private:
  const std::vector<double> data_;
  const size_t n_;
  const double x_min_, x_max_;
  const rmi::Rmi<double, rmi::LinearRegression, rmi::LinearRegression> rmi_index_;
};

}  // namespace learned_sort_using_learned_index

#endif  // LI_H
