#include "kde.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <future>
#include <iostream>
#include <numeric>
#include <random>
#include <thread>
#include <unordered_map>
#include <vector>

#include "fgt.hpp"

// Mathematical constants. PI is spelled out to full double precision: a
// truncated literal would perturb INV_SQRT_2PI and every bin width derived
// from PI (e.g. ExponentialLSH).
const double PI = 3.14159265358979323846;
const double INV_SQRT_2PI = 1.0 / std::sqrt(2 * PI);

//--------------------------------------------------------
// Improved Fast Gauss Transform KDE
//--------------------------------------------------------
// Default constructor; nothing is initialised here.
IFGT::IFGT() = default;

/**
 * Evaluate the Gauss transform of `data` at every row of `query` by delegating
 * to fgt::ifgt (Improved Fast Gauss Transform).
 * @param data      source points, one per row
 * @param query     target points, one per row
 * @param bandwidth forwarded unchanged to fgt::ifgt
 * @param debug     unused by this implementation
 * @return one value per query row, as produced by fgt::ifgt
 *
 * NOTE(review): the original author noted that the bandwidth must be at most 1;
 * this is not checked here.
 * NOTE(review): the fourth fgt::ifgt argument (1.0) is presumably its error
 * tolerance; confirm against the fgt API, since 1.0 looks loose.
 */
Eigen::VectorXd IFGT::query(const Eigen::MatrixXd& data,
                            const Eigen::MatrixXd& query,
                            double bandwidth,
                            bool debug) {
  (void) debug;  // not used by this estimator
  const double ifgt_last_argument = 1.0;
  return fgt::ifgt(data, query, bandwidth, ifgt_last_argument);
}

//-----------------------------------------------------------------
// Laplace KDE for high dimensions. Largely converted
// from the Python implementation of Backurs et al., NeurIPS 2019
//-----------------------------------------------------------------

/**
 * Laplacian kernel exp(-||x - y||_1 / bandwidth).
 * @param x first point
 * @param y second point; assumed to have the same size as x
 * @param bandwidth kernel bandwidth (expected to be positive)
 * @return the kernel value
 */
double laplacian_kernel(Eigen::VectorXd& x, Eigen::VectorXd& y, double bandwidth) {
  const double l1_distance = (x - y).lpNorm<1>();
  return std::exp(-l1_distance / bandwidth);
}

/**
 * Exponential kernel exp(-||x - y||_2 / bandwidth), i.e. the Euclidean
 * counterpart of laplacian_kernel.
 * @param x first point
 * @param y second point; assumed to have the same size as x
 * @param bandwidth kernel bandwidth (expected to be positive)
 * @return the kernel value
 */
double exponential_kernel(Eigen::VectorXd& x, Eigen::VectorXd& y, double bandwidth) {
  const double l2_distance = (x - y).lpNorm<2>();
  return std::exp(-l2_distance / bandwidth);
}

/**
 * Common interface of the locality-sensitive hash families below: each maps a
 * point to a bucket id. get_estimates is templated on the concrete type, so
 * this base mainly documents the contract. It is polymorphic, so it gets a
 * virtual destructor, and compute_hash is public so it can be called through
 * the base type.
 */
class LSH {
public:
  virtual ~LSH() = default;

  /** Return the bucket id of point x. */
  virtual size_t compute_hash(Eigen::VectorXd& x) = 0;
};

/**
 * Random-threshold LSH for the L1 distance.
 * This assumes that the data is in the range (0, 1) in every dimension.
 *
 * The number of thresholds is drawn from Poisson(dimension / bandwidth). Each
 * threshold picks a uniformly random axis and a uniform value in [0, 1), and
 * contributes one bit (coordinate below the threshold or not) to the hash.
 */
class LaplaceLSH : public LSH {
public:
  LaplaceLSH(long dimension, double bandwidth) {
    // With no dimensions there is nothing to threshold. Both distributions
    // below would also receive invalid parameters (mean 0, empty axis range).
    if (dimension <= 0) {
      return;
    }

    std::random_device rd;
    std::mt19937 gen(rd());

    // Get the number of hashes to take
    double poisson_param = dimension * 1. / bandwidth;
    std::poisson_distribution<> poissDist(poisson_param);
    reps_ = poissDist(gen);

    // Sample the axes and thresholds for the compute_hash functions
    std::uniform_int_distribution<long> axesDist(0, dimension - 1);
    std::uniform_real_distribution<double> threshDist(0, 1);
    axes_.reserve(static_cast<size_t>(reps_));
    thresholds_.reserve(static_cast<size_t>(reps_));
    for (long i = 0; i < reps_; i++) {
      axes_.push_back(axesDist(gen));
      thresholds_.push_back(threshDist(gen));
    }
  }

  size_t compute_hash(Eigen::VectorXd& x) override {
    // Fold one bit per threshold into the hash with a hash-combine step.
    // A positional encoding (hash += bit * 2^i) overflows size_t once
    // i >= 64, so every later bit would be silently dropped. reps_ is about
    // dimension / bandwidth, which regularly exceeds 64 in high dimensions.
    size_t hash = 0;
    for (long i = 0; i < reps_; i++) {
      const size_t bit = (x[axes_.at(i)] < thresholds_.at(i)) ? 1 : 0;
      hash = combine(hash, bit);
    }
    return hash;
  }

protected:
  long reps_ = 0;
  std::vector<long> axes_;
  std::vector<double> thresholds_;

private:
  // Boost-style hash_combine using the 64-bit golden-ratio constant.
  static size_t combine(size_t seed, size_t value) {
    return seed ^ (value + static_cast<size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2));
  }
};

/**
 * Random-binning LSH for the L1 distance: one randomly shifted grid per dimension.
 * Again, we assume that the data is in the range (0, 1) in every dimension.
 *
 * For each dimension d, a bin length is drawn from Gamma(shape 2, scale
 * bandwidth) and a shift from U(0, bin length). Coordinate x[d] falls in bin
 * ceil((x[d] - shift) / length), or in bin 0 when x[d] <= shift.
 */
class BinningLaplaceLSH : public LSH {
public:
  BinningLaplaceLSH(long dimension, double bandwidth) {
    dimension_ = dimension;
    std::random_device rd;
    std::mt19937 gen(rd());

    // The bin lengths are chosen from the gamma distribution
    std::gamma_distribution<double> deltaDist(2, bandwidth);

    // Choose all the parameters
    for (long i = 0; i < dimension; i++) {
      double bin_length = deltaDist(gen);

      std::uniform_real_distribution<double> shiftDist(0, bin_length);
      double shift = shiftDist(gen);
      shifts_.push_back(shift);
      bin_lengths_.push_back(bin_length);
    }
  }

  size_t compute_hash(Eigen::VectorXd& x) override {
    // Combine the per-dimension bin indices with a hash-combine step.
    // A mixed-radix encoding with radix ceil(1 / length) is not injective:
    // bin indices range over 0..ceil((1 - shift) / length), which can equal
    // the radix; when length >= 1 the radix is 1 and bins are simply added.
    // The radix product also overflows size_t after a few dimensions.
    size_t hash = 0;
    for (long d = 0; d < dimension_; d++) {
      long bin_d = 0;
      const double shifted = x[d] - shifts_.at(d);
      if (shifted > 0) {
        bin_d = static_cast<long>(std::ceil(shifted / bin_lengths_.at(d)));
      }
      hash = combine(hash, static_cast<size_t>(bin_d));
    }
    return hash;
  }

protected:
  long dimension_ = 0;
  std::vector<double> shifts_;
  std::vector<double> bin_lengths_;

private:
  // Boost-style hash_combine using the 64-bit golden-ratio constant.
  static size_t combine(size_t seed, size_t value) {
    return seed ^ (value + static_cast<size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2));
  }
};

/**
 * Hashing-based kernel density estimates for every query row.
 *
 * Builds L hash tables. Each table uses a fresh HashType(dimension, bandwidth)
 * and holds a random subsample of the data rows: each row is kept
 * independently with probability min(1, true_L / n). For each query and each
 * table, the estimator is |bucket| * laplacian_kernel(query, p, 2 * bandwidth),
 * where p is a uniformly random member of the query's bucket (0 if the bucket
 * is empty). The returned value is the mean over the L tables.
 *
 * @tparam HashType hash family constructible as HashType(long, double) that
 *                  provides size_t compute_hash(Eigen::VectorXd&)
 * @param data      data points, one per row (n rows)
 * @param query     query points, one per row (m rows)
 * @param bandwidth passed to every hasher; the estimator kernel uses 2 * bandwidth
 * @param L         number of hash tables built by this call
 * @param true_L    total number of tables across all callers; sets the sampling rate
 * @return m estimates; all zero when there is no data or no table
 *
 * NOTE(review): the estimator always uses laplacian_kernel, even for
 * ExponentialLSH (an L2-distance hash). Confirm this is intended.
 * NOTE(review): tables hold only a min(1, true_L / n) fraction of the points
 * and the estimate is not rescaled by that probability. Confirm the intended
 * normalisation.
 */
template <typename HashType>
std::vector<double> get_estimates(const Eigen::MatrixXd& data,
                                  const Eigen::MatrixXd& query,
                                  double bandwidth,
                                  long L,
                                  long true_L) {
  const long n = data.rows();
  const long m = query.rows();
  const long dimension = data.cols();

  // Without data or tables the sampling rate and the final mean would be 0/0.
  if (n <= 0 || L <= 0) {
    return std::vector<double>(static_cast<size_t>(m), 0.0);
  }

  // Each point enters each table with probability true_L / n. For small n this
  // exceeds 1, which std::bernoulli_distribution does not accept, so clamp it.
  const double probability =
      std::max(0.0, std::min(1.0, static_cast<double>(true_L) / static_cast<double>(n)));
  std::random_device rd;
  std::mt19937 gen(rd());
  std::bernoulli_distribution sampleDist(probability);

  // One table per hash function: hash value -> row indices of the sampled data
  // points in that bucket. Indices avoid storing a copy of every sampled point.
  std::vector<HashType> hashers;
  std::vector<std::unordered_map<size_t, std::vector<long>>> all_buckets;
  hashers.reserve(static_cast<size_t>(L));
  all_buckets.reserve(static_cast<size_t>(L));
  for (long i = 0; i < L; i++) {
    hashers.emplace_back(dimension, bandwidth);
    HashType& hasher = hashers.back();
    all_buckets.emplace_back();
    auto& buckets = all_buckets.back();

    for (long j = 0; j < n; j++) {
      if (sampleDist(gen)) {
        Eigen::VectorXd this_point = data.row(j).transpose();
        buckets[hasher.compute_hash(this_point)].push_back(j);
      }
    }
  }

  // Query step: average the per-table estimators.
  std::vector<double> results;
  results.reserve(static_cast<size_t>(m));
  for (long j = 0; j < m; j++) {
    Eigen::VectorXd query_point = query.row(j).transpose();
    double estimator_sum = 0.0;

    for (long i = 0; i < L; i++) {
      const size_t queryHash = hashers[i].compute_hash(query_point);

      // find() instead of operator[]: a miss must not insert an empty bucket.
      const auto& buckets = all_buckets[i];
      const auto it = buckets.find(queryHash);
      if (it == buckets.end()) {
        continue;  // empty bucket: this table's estimator is 0
      }

      // Pick a random data point from the bucket as this table's estimator.
      const std::vector<long>& members = it->second;
      std::uniform_int_distribution<size_t> bucketDist(0, members.size() - 1);
      Eigen::VectorXd sampled_point = data.row(members[bucketDist(gen)]).transpose();
      estimator_sum += members.size() * laplacian_kernel(query_point, sampled_point, 2 * bandwidth);
    }

    // Our final answer for this query point is the mean of the estimators.
    results.push_back(estimator_sum / static_cast<double>(L));
  }

  return results;
}

/**
 * Run get_estimates with the Laplace hash family that matches the bandwidth:
 * BinningLaplaceLSH for bandwidths below 1, LaplaceLSH otherwise.
 * All arguments are forwarded unchanged.
 */
std::vector<double> get_estimates_proxy(const Eigen::MatrixXd& data,
                                        const Eigen::MatrixXd& query,
                                        double bandwidth,
                                        long L,
                                        long true_L) {
  const bool use_binning = bandwidth < 1;
  return use_binning
             ? get_estimates<BinningLaplaceLSH>(data, query, bandwidth, L, true_L)
             : get_estimates<LaplaceLSH>(data, query, bandwidth, L, true_L);
}

// Default constructor; nothing is initialised here.
LaplaceKDE::LaplaceKDE() = default;

/**
 * Laplace-kernel density estimates at every query row, via get_estimates_proxy.
 *
 * Uses true_L = 10 * floor(sqrt(n)) hash tables. For n <= 100 they are built
 * in one call. Otherwise 10 asynchronous tasks each build true_L / 10 tables,
 * and the final estimate is the mean of the 10 per-task estimates.
 *
 * @param data      data points, one per row; bandwidth must be <= 1 (asserted)
 * @param query     query points, one per row
 * @param bandwidth Laplace kernel bandwidth
 * @param debug     if true, print each estimate next to the exact sum of
 *                  laplacian_kernel(query, data_j, bandwidth) over the data
 * @return one estimate per query row (all zero when data is empty)
 */
Eigen::VectorXd LaplaceKDE::query(const Eigen::MatrixXd& data,
                                  const Eigen::MatrixXd& query,
                                  double bandwidth,
                                  bool debug) {
  // For simplicity, we do not support bandwidths greater than 1
  assert(bandwidth <= 1);
  const long n = data.rows();
  const long m = query.rows();
  const long true_L = 10 * ((long) std::sqrt(n));

  std::vector<double> results;
  if (n == 0) {
    // No data: every density is 0 (and there would be no hash tables).
    results.assign(static_cast<size_t>(m), 0.0);
  } else if (n <= 100) {
    // If n is small enough, don't bother splitting into threads, just compute the results
    results = get_estimates_proxy(data, query, bandwidth, true_L, true_L);
  } else {
    const long num_threads = 10;
    const long L = true_L / num_threads;

    // Capturing data/query by reference avoids copying both matrices into
    // every task. This is safe because each future is consumed below (or
    // destroyed, which blocks for std::async futures) before this function returns.
    std::vector<std::future<std::vector<double>>> futures;
    futures.reserve(num_threads);
    for (long t = 0; t < num_threads; t++) {
      futures.push_back(std::async(std::launch::async, [&data, &query, bandwidth, L, true_L]() {
        return get_estimates_proxy(data, query, bandwidth, L, true_L);
      }));
    }

    // Each true estimate is the average of the sub-estimates, so every
    // task's result (including the first) is divided by num_threads.
    results.assign(static_cast<size_t>(m), 0.0);
    for (auto& future : futures) {
      const std::vector<double> partial = future.get();
      for (long j = 0; j < m; j++) {
        results[j] += partial.at(j) / num_threads;
      }
    }
  }

  // DEBUG: compare with the true densities
  if (debug) {
    std::cout << std::endl;
    std::cout << "START OF DEBUG with " << n << " data points and " << m << " query points" << std::endl;
    for (long i = 0; i < m; i++) {
      Eigen::VectorXd this_query_point = query.row(i).transpose();

      // Compute the true density at this point
      double true_density = 0;
      for (long j = 0; j < n; j++) {
        Eigen::VectorXd this_data_point = data.row(j).transpose();
        true_density += laplacian_kernel(this_query_point, this_data_point, bandwidth);
      }

      // Display the density comparison
      std::cout << "Estimate: " << results.at(i) << ", True: " << true_density << std::endl;
    }
  }

  return Eigen::Map<Eigen::VectorXd>(results.data(), results.size());
}

//---------------------------------------------------------
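// Exponential kernel LSH KDE, kernel exp(-||x - y||_2 / bandwidth). Originally written with bandwidth 1 in mind, but the bandwidth argument is used as given.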
//---------------------------------------------------------
/**
 * Random-projection LSH for the Euclidean distance, used by ExponentialKDE.
 *
 * A point x is hashed to ceil((<g, x> + beta) / w). Here g has i.i.d. N(0, 1)
 * entries, beta ~ U(0, w), and the bin width is w = sqrt(2 / PI) * bandwidth.
 * A single projection is used per hash function. The bandwidth is expected
 * to be positive.
 */
class ExponentialLSH : public LSH {
public:
  ExponentialLSH(long dimension, double bandwidth) {
    dimension_ = dimension;
    std::random_device rd;
    std::mt19937 gen(rd());

    // The bin width scales linearly with the bandwidth.
    // NOTE(review): the sqrt(2 / PI) factor presumably makes the collision
    // probability of nearby points track exp(-||x - y||_2 / bandwidth); confirm.
    w_ = std::sqrt(2 / PI) * bandwidth;

    // Sample the random projection vector g
    std::normal_distribution<double> gDist(0, 1);
    for (long i = 0; i < dimension; i++) {
      g_.push_back(gDist(gen));
    }

    // Sample the random shift beta
    std::uniform_real_distribution<double> betaDist(0, w_);
    beta_ = betaDist(gen);
  }

  size_t compute_hash(Eigen::VectorXd& x) override {
    // Get the inner product with g
    double inner = 0;
    for (long i = 0; i < dimension_; i++) {
      inner += x[i] * g_.at(i);
    }

    // Bucket index of the shifted, normalised projection. It is negative
    // whenever <g, x> + beta < 0, and converting a negative double straight to
    // size_t is undefined behaviour. Go through a signed integer first; the
    // signed -> unsigned conversion is well defined (modular).
    const long long bucket = static_cast<long long>(std::ceil((inner + beta_) / w_));
    return static_cast<size_t>(bucket);
  }

protected:
  long dimension_ = 0;
  std::vector<double> g_;
  double beta_ = 0;
  double w_ = 0;
};


// Default constructor; nothing is initialised here.
ExponentialKDE::ExponentialKDE() = default;

/**
 * Exponential-kernel density estimates at every query row, via
 * get_estimates<ExponentialLSH>.
 *
 * Uses true_L = 10 * floor(sqrt(n)) hash tables. For n <= 100 they are built
 * in one call. Otherwise 10 asynchronous tasks each build true_L / 10 tables,
 * and the final estimate is the mean of the 10 per-task estimates.
 *
 * @param data      data points, one per row
 * @param query     query points, one per row
 * @param bandwidth kernel bandwidth, also used to size the ExponentialLSH bins
 * @param debug     if true, print each estimate next to the exact sum of
 *                  exponential_kernel(query, data_j, bandwidth) over the data
 * @return one estimate per query row (all zero when data is empty)
 *
 * NOTE(review): get_estimates scores sampled points with laplacian_kernel
 * (L1), while the debug comparison below uses exponential_kernel (L2).
 * Confirm the estimator kernel is intended.
 */
Eigen::VectorXd ExponentialKDE::query(const Eigen::MatrixXd& data,
                                  const Eigen::MatrixXd& query,
                                  double bandwidth,
                                  bool debug) {
  const long n = data.rows();
  const long m = query.rows();
  const long true_L = 10 * ((long) std::sqrt(n));

  std::vector<double> results;
  if (n == 0) {
    // No data: every density is 0 (and there would be no hash tables).
    results.assign(static_cast<size_t>(m), 0.0);
  } else if (n <= 100) {
    // If n is small enough, don't bother splitting into threads, just compute the results
    results = get_estimates<ExponentialLSH>(data, query, bandwidth, true_L, true_L);
  } else {
    const long num_threads = 10;
    const long L = true_L / num_threads;

    // Capturing data/query by reference avoids copying both matrices into
    // every task. This is safe because each future is consumed below (or
    // destroyed, which blocks for std::async futures) before this function returns.
    std::vector<std::future<std::vector<double>>> futures;
    futures.reserve(num_threads);
    for (long t = 0; t < num_threads; t++) {
      futures.push_back(std::async(std::launch::async, [&data, &query, bandwidth, L, true_L]() {
        return get_estimates<ExponentialLSH>(data, query, bandwidth, L, true_L);
      }));
    }

    // Each true estimate is the average of the sub-estimates, so every
    // task's result (including the first) is divided by num_threads.
    results.assign(static_cast<size_t>(m), 0.0);
    for (auto& future : futures) {
      const std::vector<double> partial = future.get();
      for (long j = 0; j < m; j++) {
        results[j] += partial.at(j) / num_threads;
      }
    }
  }

  // DEBUG: compare with the true densities
  if (debug) {
    std::cout << std::endl;
    std::cout << "START OF DEBUG with " << n << " data points and " << m << " query points" << std::endl;
    for (long i = 0; i < m; i++) {
      Eigen::VectorXd this_query_point = query.row(i).transpose();

      // Compute the true density at this point
      double true_density = 0;
      for (long j = 0; j < n; j++) {
        Eigen::VectorXd this_data_point = data.row(j).transpose();
        true_density += exponential_kernel(this_query_point, this_data_point, bandwidth);
      }

      // Display the density comparison
      std::cout << "Estimate: " << results.at(i) << ", True: " << true_density << std::endl;
    }
  }

  return Eigen::Map<Eigen::VectorXd>(results.data(), results.size());
}

