/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include <cstring>
#include <istream>
#include <ostream>
#include <random>
#include <vector>

#include "real.h"
#include "vector.h"

namespace fasttext {

/**
 * Product quantizer over dim_-dimensional real vectors.
 *
 * Per-sub-vector codes are uint8_t, which matches ksub_ = 2^nbits_ = 256
 * centroids per codebook.
 *
 * NOTE(review): the behavior of the methods below is inferred from their
 * names and signatures only. The definitions live outside this header;
 * confirm against the implementation.
 */
class ProductQuantizer {
 protected:
  // Bits per code; ksub_ = 1 << nbits_ centroids per codebook (256 here,
  // which is exactly the range of the uint8_t codes used in the API).
  const int32_t nbits_ = 8;
  const int32_t ksub_ = 1 << nbits_;
  // max_points_ = max_points_per_cluster_ * ksub_ (= 65536).
  // Presumably caps the number of training points used; confirm in train().
  const int32_t max_points_per_cluster_ = 256;
  const int32_t max_points_ = max_points_per_cluster_ * ksub_;
  // Presumably the seed for `rng` (reproducible training); confirm in .cpp.
  const int32_t seed_ = 1234;
  // Presumably the number of k-means iterations; confirm in kmeans().
  const int32_t niter_ = 25;
  const real eps_ = 1e-7;

  // Default to zero so that a default-constructed quantizer never holds
  // indeterminate sizes. Any mem-initializer list in the two-argument
  // constructor still takes precedence over these defaults.
  int32_t dim_ = 0;
  int32_t nsubq_ = 0;
  int32_t dsub_ = 0;
  int32_t lastdsub_ = 0;

  // Flat storage for all codebooks (layout defined by the implementation).
  std::vector<real> centroids_;

  std::minstd_rand rng;

 public:
  // Leaves the quantizer with zero sizes and no centroids; presumably it is
  // meant to be populated later (e.g. via load()) — confirm at call sites.
  ProductQuantizer() = default;
  // NOTE(review): argument meaning (likely dim, dsub) is defined in the .cpp.
  ProductQuantizer(int32_t, int32_t);

  // Access to the centroid storage for a (sub-quantizer index, code) pair;
  // presumably returns a pointer into centroids_ — confirm in the .cpp.
  real* get_centroids(int32_t, uint8_t);
  const real* get_centroids(int32_t, uint8_t) const;

  // k-means building blocks (assignment / E-step / M-step / driver) and
  // training entry point; see the implementation for exact semantics.
  real assign_centroid(const real*, const real*, uint8_t*, int32_t) const;
  void Estep(const real*, const real*, uint8_t*, int32_t, int32_t) const;
  void MStep(const real*, real*, const uint8_t*, int32_t, int32_t);
  void kmeans(const real*, real*, int32_t, int32_t);
  void train(int, const real*);

  // Operations on encoded vectors: combine a code with a Vector
  // (dot-product-like mulcode, accumulate-like addcode), and encode raw
  // vectors into uint8_t codes. Names only; confirm semantics in the .cpp.
  real mulcode(const Vector&, const uint8_t*, int32_t, real) const;
  void addcode(Vector&, const uint8_t*, int32_t, real) const;
  void compute_code(const real*, uint8_t*) const;
  void compute_codes(const real*, uint8_t*, int32_t) const;

  // Binary (de)serialization of the quantizer state.
  void save(std::ostream&) const;
  void load(std::istream&);
};

} // namespace fasttext
