/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from: https://github.com/bkj/auction-lap Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score, bool strict_sort=false) {
  int max_iterations = 100;
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids =
      worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor have_bids_ = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th
    // job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));

    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one else
      // bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0); // have_bids=high_bids>0

    if (!strict_sort && have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }
    if(strict_sort){
      torch::gt_out(have_bids_, high_bids, epsilon); // have_bids_=high_bids>eps
      if (have_bids_.all().item<bool>()){
        break;
      }
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);

    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next
      // time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }

    counter += 1;
  }

  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
      .reshape(-1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}