#include "enumeration.cuh"

#include <signal.h>

#include <algorithm>
#include <chrono>
#include <climits>
#include <cmath>
#include <iostream>

#include <helper_cuda.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cooperative_groups.h>

using namespace std::chrono;
using namespace cooperative_groups;

// Evaluates a CUDA runtime call (or any expression convertible to cudaError_t)
// exactly once; on failure prints the file, line and error name, then raises
// SIGSEGV. Written as do { } while (0) so it is standard C++ and behaves as a
// single statement; the argument is parenthesized so compound expressions are
// cast as a whole.
#define CHECK_CUDA(err)                                                         \
  do {                                                                          \
    const cudaError_t check_cuda_status_ = (cudaError_t)(err);                  \
    if (check_cuda_status_ != cudaSuccess) {                                    \
      printf("CUDA error in %s:%d: %s\n", __FILE__, __LINE__,                   \
             _cudaGetErrorEnum(check_cuda_status_));                            \
      raise(SIGSEGV);                                                           \
    }                                                                           \
  } while (0)

// Integer ceiling division: smallest q with q * b >= a (for a >= 0, b > 0).
// Uses quotient + remainder test rather than (a + b - 1) / b so that it cannot
// overflow. Note: both arguments are evaluated more than once.
#define DIV_UP(a, b) (((a) / (b) + ((a) % (b) == 0 ? 0 : 1)))

// Cooperative-groups tile of 32 threads; this is the type the kernel below
// obtains from tiled_partition<32>(this_thread_block()).
using warp_t = thread_block_tile<32>;

// Note: not finished, there are bugs in this code. 

/**
 * Brute-force 0/1 knapsack evaluation over a range of candidate solutions.
 *
 * A candidate is a 64-bit index whose bit i tells whether item i is selected.
 * Each thread evaluates complete candidates on its own (total weight and total
 * objective), keeps the best feasible objective it has seen, and the 32 lanes
 * are then combined with a single max-reduction.
 *
 * Launch layout: 1-D grid with exactly 32 threads (one warp) per block. Block b
 * covers the indices [start_ix + b * ixs_per_warp,
 * start_ix + (b + 1) * ixs_per_warp), clipped to end_ix. No shared memory.
 *
 * @param start_ix      first candidate index of the whole launch (inclusive)
 * @param end_ix        one past the last candidate index of the whole launch
 * @param ixs_per_warp  number of consecutive candidates assigned to each block
 * @param n             number of items; only the first 64 can ever be selected
 *                      because a candidate index has 64 bits
 * @param capacity      maximum admissible total weight
 * @param weights       device-accessible array of n item weights
 * @param direction     device-accessible array of n objective coefficients
 * @param opt_vals      output, at least gridDim.x entries; entry b receives the
 *                      best feasible objective found by block b, or -INFINITY
 *                      if none of its candidates is feasible
 */
__global__
void
k_knapsack(
  const uint64_t start_ix,
  const uint64_t end_ix,
  const uint64_t ixs_per_warp,
  const int n,
  const int capacity,
  const int* weights,
  const double* direction,
  double * opt_vals)
{
  thread_block tb = this_thread_block();
  warp_t warp = tiled_partition<32>(tb);

  const uint64_t warp_ix = tb.group_index().x;
  const uint64_t lane = warp.thread_rank();

  // Half-open range owned by this warp, clipped so that it neither runs into
  // the next warp's range nor past the end of the launch.
  const uint64_t warp_start_ix = start_ix + warp_ix * ixs_per_warp;
  const uint64_t warp_limit_ix = warp_start_ix + ixs_per_warp;
  const uint64_t warp_end_ix = warp_limit_ix < end_ix ? warp_limit_ix : end_ix;

  // Bits 64 and above do not exist in a candidate index; clamping also keeps
  // every shift amount below the width of uint64_t.
  const int num_items = n < 64 ? n : 64;

  double best = -INFINITY;
  for (uint64_t ix = warp_start_ix + lane; ix < warp_end_ix; ix += 32)
  {
    // 64-bit accumulator: the sum of up to 64 int weights cannot overflow it.
    long long weight_sum = 0;
    double obj_sum = 0.0;

    for (int item = 0; item < num_items; ++item)
    {
      if ((ix >> item) & 1ull)
      {
        weight_sum += weights[item];
        obj_sum += direction[item];
      }
    }

    if (weight_sum <= capacity && obj_sum > best)
    {
      best = obj_sum;
    }
  }

  // Lanes may have run a different number of iterations above, so the
  // collective shuffle is issued only here, after every lane left the loop.
  #pragma unroll
  for (int offset = 16; offset > 0; offset /= 2)
  {
    const double other = warp.shfl_down(best, offset);
    best = other > best ? other : best;
  }

  if (lane == 0)
  {
    opt_vals[warp_ix] = best;
  }
}

// Converts a 64-bit value to int, clamping to [INT_MIN, INT_MAX] instead of
// silently truncating the high bits.
static int saturate_to_int(const long long value)
{
    const long long lo = INT_MIN;
    const long long hi = INT_MAX;
    return static_cast<int>(std::max(lo, std::min(hi, value)));
}

/**
 * Solves a 0/1 knapsack instance by enumerating all 2^n item subsets on the GPU
 * and stores the best feasible objective value in *objval.
 *
 * The subsets are split evenly over at most 4096 one-warp blocks of k_knapsack;
 * the per-block maxima are then reduced on the host.
 *
 * @param n          number of items; must satisfy 0 <= n < 64 because subsets
 *                   are encoded as bits of a 64-bit index
 * @param weights    n item weights (values outside the int range are clamped,
 *                   since the kernel works on int weights)
 * @param capacity   maximum admissible total weight (clamped to the int range)
 * @param direction  n objective coefficients
 * @param objval     output: best feasible objective, or -INFINITY if no subset
 *                   is feasible
 * @param vertex     accepted for interface compatibility; this implementation
 *                   does NOT write the maximizing subset to it
 */
void inner_call_oracle_gpu(
   const int            n,
   const long long int* weights,
   const long long int  capacity,
   const double*        direction,
   double*              objval,
   double*              vertex
   )
{
    (void) vertex; // see note in the header comment: not populated

    if (n < 0 || n >= 64)
    {
        // 1 << n would be undefined; fail the same way CHECK_CUDA does.
        printf("inner_call_oracle_gpu in %s:%d: n = %d is outside [0, 64)\n", __FILE__, __LINE__, n);
        raise(SIGSEGV);
        return;
    }

    if (n == 0)
    {
        // Only the empty subset exists (weight 0, objective 0). Handled on the
        // host because a zero-byte cudaMallocManaged request is an error.
        *objval = capacity >= 0 ? 0.0 : -INFINITY;
        return;
    }

    const uint64_t resident_warps = 4096;
    const uint64_t num_sols = static_cast<uint64_t>(1) << n;
    const uint64_t ixs_per_warp = DIV_UP(num_sols, resident_warps);
    const uint64_t num_blocks = std::min(num_sols, resident_warps);

    int* d_weights = nullptr;
    double* d_direction = nullptr;
    double* d_opt_vals = nullptr;

    CHECK_CUDA(cudaMallocManaged(&d_weights, n * sizeof(int)));
    CHECK_CUDA(cudaMallocManaged(&d_direction, n * sizeof(double)));
    CHECK_CUDA(cudaMallocManaged(&d_opt_vals, num_blocks * sizeof(double)));

    for (int i = 0; i < n; i++)
    {
        d_weights[i] = saturate_to_int(weights[i]);
        d_direction[i] = direction[i];
    }

    // 64-bit conversions: num_sols no longer fits in 32 bits once n >= 32.
    printf("num blocks: %llu, threads: 32 num sols: %llu, ix_per_warp: %llu\n",
           static_cast<unsigned long long>(num_blocks),
           static_cast<unsigned long long>(num_sols),
           static_cast<unsigned long long>(ixs_per_warp));

    // execute enumeration
    auto start = high_resolution_clock::now();
    k_knapsack<<<static_cast<unsigned int>(num_blocks), 32>>>(
      static_cast<uint64_t>(0),
      num_sols,
      ixs_per_warp,
      n,
      saturate_to_int(capacity),
      d_weights,
      d_direction,
      d_opt_vals
    );

    CHECK_CUDA(cudaPeekAtLastError());
    // The synchronize is also what makes d_opt_vals safe to read on the host.
    CHECK_CUDA(cudaDeviceSynchronize());
    auto stop = high_resolution_clock::now();
    auto duration_ref = duration_cast<microseconds>(stop - start);
    std::cout << " time kernel: " << duration_ref.count() << std::endl;

    *objval = *std::max_element(d_opt_vals, d_opt_vals + num_blocks);

    // Release the managed buffers; they were previously leaked on every call.
    CHECK_CUDA(cudaFree(d_weights));
    CHECK_CUDA(cudaFree(d_direction));
    CHECK_CUDA(cudaFree(d_opt_vals));
}