// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "../large_action_space.h"
#include "vw/core/cb.h"
#include "vw/core/reductions/gd.h"
#ifdef _MSC_VER
#  include <intrin.h>
#endif

namespace VW
{
namespace cb_explore_adf
{
/**
 * One pass SVD
 * one-pass refers to the randomness we apply to the original A matrix, as opposed to two-pass randomized SVD which
 * is more commonly found in the literature and has a better reconstruction accuracy which we don't necessarily need
 *
 * Random sampling is used to identify a subspace that captures most of the action of a matrix
 * (https://arxiv.org/abs/0909.4061). By multiplying A with Omega, Omega being of dimensions Fxp, we are essentially
 * randomly sampling from A and projecting onto lower dimensions (p instead of F). We then proceed to apply SVD on
 * AOmega and truncate to d.
 *
 * A matrix with KxF dimensions (sparse action features matrix)
 * Omega matrix with Fxp dimensions (Lazy Rademacher matrix)
 * AOmega is the matrix product A * Omega (matrix of dimensions Kxp)
 * We then proceed to apply SVD on AOmega: (U, S, V) <- SVD(AOmega)
 * U <- U_d => truncate U to its first d columns
 *
 * p = min(K, d + 10)
 * one pass SVD is going to be less accurate than two pass SVD so we need to over-sample. This constant factor should be
 * enough, we need a higher probability that we get a fair coin flip in the Omega matrix
 *
 * The two basic properties we want to get from U after performing this truncated randomized SVD are:
 * 1. if an action is duplicated in A, the duplicates have the same representation in U
 * 2. if a third action is a linear combination of two other actions:
 *  a. one of the singular values is zero (or virtually close to zero)
 *  b. the third action's representation in U is the linear combination of the other two actions' representations in U
 *
 *
 * The matrices A and Omega are never materialized. The only matrix that is materialized is AOmega. We perform the
 * multiplication between A and Omega by keeping them in a lazy format: we iterate over A's rows, and for every row
 * in A we multiply it with the correct Omega column. We do that by calling foreach_feature on every example-row,
 * calculating on the fly the Omega cell that needs to be multiplied with the corresponding feature (the row's
 * cell), and adding the product to the final dot product corresponding to that example-row.
 */

/**
 * Feature visitor that accumulates a single cell of AOmega: the dot product between the features of one example
 * (a row of A) and one lazily generated Rademacher column of Omega.
 *
 * An instance is bound to one Omega column and to a caller-owned float accumulator. set() is meant to be invoked once
 * per feature and adds feature_value * (+1.f or -1.f) to that accumulator.
 */
struct AO_triplet_constructor
{
private:
  uint64_t _weights_mask;
  uint64_t _column_index;
  uint64_t _seed;
  float& _final_dot_product;

public:
  /**
   * @param weights_mask mask applied to every feature index before it is combined with the column index and seed
   * @param column_index index of the Omega column this instance multiplies with
   * @param seed offset added to the combined index before the coin flip
   * @param final_dot_product accumulator that set() adds to; it is held by reference and must outlive this object
   *
   * The second parameter (the row index) is accepted so that existing call sites keep working, but the Omega cell
   * does not depend on it and it is not stored.
   */
  AO_triplet_constructor(
      uint64_t weights_mask, uint64_t /*row_index*/, uint64_t column_index, uint64_t seed, float& final_dot_product)
      : _weights_mask(weights_mask), _column_index(column_index), _seed(seed), _final_dot_product(final_dot_product)
  {
  }

  /**
   * Adds the contribution of one feature to the accumulated dot product.
   *
   * @param feature_value value of the feature (a cell of the current row of A)
   * @param index index of the feature; together with the column index it identifies the Omega cell to multiply with
   */
  void set(float feature_value, uint64_t index)
  {
    // The masked feature index, the column index and the seed are summed into one combined index that stands for the
    // Omega cell. That combined index is then used to flip a coin and generate rademacher randomness.
    const uint64_t omega_cell = (index & _weights_mask) + _column_index + _seed;

    // The coin is flipped by counting the set bits of the combined index: an even count multiplies with -1.f, an odd
    // count with 1.f.
    // Both intrinsics take an `unsigned int`, so only the low bits of the 64-bit combined index that fit in an
    // `unsigned int` take part in the flip. The cast spells out the narrowing that would otherwise happen implicitly.
#ifdef _MSC_VER
    float val = ((__popcnt(static_cast<unsigned int>(omega_cell)) & 1) << 1) - 1.f;
#else
    float val = (__builtin_parity(static_cast<unsigned int>(omega_cell)) << 1) - 1.f;
#endif
    _final_dot_product += feature_value * val;
  }
};

void one_pass_svd_impl::generate_AOmega(const multi_ex& examples, const std::vector<float>& shrink_factors)
{
  auto num_actions = examples[0]->pred.a_s.size();
  // one pass SVD is going to be less accurate than two pass SVD so we need to over-sample
  // this constant factor should be enough, we need a higher probability that we get a fair coin flip in the Omega
  // matrix
  const uint64_t sampling_slack = 10;
  auto p = std::min(num_actions, _d + sampling_slack);
  AOmega.resize(num_actions, p);

  auto calculate_aomega_row = [](uint64_t row_index, uint64_t p, VW::workspace* _all, uint64_t _seed, VW::example* ex,
                                  Eigen::MatrixXf& AOmega, const std::vector<float>& shrink_factors) -> void {
    auto& red_features = ex->_reduction_features.template get<VW::generated_interactions::reduction_features>();

    for (uint64_t col = 0; col < p; ++col)
    {
      float final_dot_prod = 0.f;

      AO_triplet_constructor tc(_all->weights.mask(), row_index, col, _seed, final_dot_prod);

      GD::foreach_feature<AO_triplet_constructor, uint64_t, triplet_construction, dense_parameters>(
          _all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear,
          (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions),
          (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions
                                                      : *ex->extent_interactions),
          _all->permutations, *ex, tc, _all->_generate_interactions_object_cache);

      AOmega(row_index, col) = final_dot_prod * shrink_factors[row_index];
    }
  };

  uint64_t row_index = 0;
  for (auto* ex : examples)
  {
    _futures.emplace_back(_thread_pool.submit(
        calculate_aomega_row, row_index, p, _all, _seed, ex, std::ref(AOmega), std ::ref(shrink_factors)));
    row_index++;
  }

  for (auto& ft : _futures) { ft.get(); }
  _futures.clear();
}

void one_pass_svd_impl::_test_only_set_rank(uint64_t rank) { _d = rank; }

/**
 * Runs the one-pass randomized SVD: builds AOmega from the examples, computes its thin SVD and copies the first _d
 * columns of the left singular vectors into U.
 *
 * @param examples forwarded to generate_AOmega
 * @param shrink_factors forwarded to generate_AOmega
 * @param U receives the first _d columns of the left singular vectors of AOmega
 * @param _S receives the singular values, only when _set_testing_components is set
 * @param _V receives the right singular vectors, only when _set_testing_components is set
 */
void one_pass_svd_impl::run(const multi_ex& examples, const std::vector<float>& shrink_factors, Eigen::MatrixXf& U,
    Eigen::VectorXf& _S, Eigen::MatrixXf& _V)
{
  generate_AOmega(examples, shrink_factors);

  _svd.compute(AOmega, Eigen::ComputeThinU | Eigen::ComputeThinV);

  const auto& left_singular_vectors = _svd.matrixU();
  U = left_singular_vectors.leftCols(_d);

  // The remaining components are only exported for tests.
  if (!_set_testing_components) { return; }

  _S = _svd.singularValues();
  _V = _svd.matrixV();
}

/**
 * Stores the workspace pointer, the target rank and the seed, and constructs the thread pool with thread_pool_size.
 *
 * @param all workspace pointer kept in _all
 * @param d rank kept in _d
 * @param seed seed kept in _seed
 * @param thread_pool_size argument passed to the _thread_pool constructor
 *
 * The fourth (unnamed) size_t argument is accepted but not used by this constructor.
 */
one_pass_svd_impl::one_pass_svd_impl(
    VW::workspace* all, uint64_t d, uint64_t seed, size_t /*unused*/, size_t thread_pool_size)
    : _all(all)
    , _d(d)
    , _seed(seed)
    , _thread_pool(thread_pool_size)
{
}

}  // namespace cb_explore_adf
}  // namespace VW