// Compile with C++20 or higher and with OpenMP enabled (e.g. -std=c++20 -fopenmp).

#include <bits/stdc++.h>

#include "graph_cc_cost_working.hpp"

using namespace std;

template <size_t ITERS, size_t... Ks, double... Bs>
auto run_experiment_simple_sampler(const vector<element>& sigma, k_seq<Ks...>, space_seq<Bs...>) {
  constexpr size_t N_K = sizeof...(Ks);
  constexpr size_t N_B = sizeof...(Bs);

  cout << "Ks(" << (sizeof...(Ks)) << "):";
  ((std::cout << " " << Ks), ...);
  cout << "\n";
  cout << "Bs(" << sizeof...(Bs) << "):";
  ((cout << " " << Bs), ...);
  cout << endl;

  const int n = sigma.size();

  vector<RankGenerator> pi;
  for (int i = 0; i < ITERS; i++) 
    pi.emplace_back(rd());

  std::array<baseline_stats, ITERS> res_pivot{};
  std::array<baseline_stats, N_K * ITERS> res_pruned_pivot{};
  std::array<exec_stats, N_K * N_B * ITERS> res_cc{}, res_node{}, res_edge{};
  
  auto sigma_ptr = &sigma;
  auto pi_ptr = &pi;
  auto res_node_ptr = &res_node;
  auto res_edge_ptr = &res_edge;
  auto res_pruned_pivot_ptr = &res_pruned_pivot;
  auto res_pivot_ptr = &res_pivot;
  auto res_cc_ptr = &res_cc;

  #pragma omp parallel
  #pragma omp single nowait
  {
    // Pivot iterations
    #pragma omp task priority(1)
    {
      #pragma omp taskloop grainsize(1)
      for (int i = 0; i < ITERS; i++) {
        (*res_pivot_ptr)[i] = pivot_stats((*sigma_ptr), (*pi_ptr)[i], n);
      }
    }

    cout << "exited pivots" << endl;

    // Pruned Pivot iterations
    [&]<size_t... I> (index_sequence<I...>) {
      cout << "entered pruned pivot builder" << endl;

      ([&]{
        cout << "entered " << I << " for pruned pivot" << endl;
        #pragma omp task priority(1)
        {
          #pragma omp taskloop grainsize(1) 
          for (int j = 0; j < ITERS; ++j) {
            (*res_pruned_pivot_ptr)[I*ITERS + j] = pruned_pivot_stats<Ks>((*sigma_ptr), (*pi_ptr)[j], n);
          }
        }
      }(), ...);
    }(make_index_sequence<sizeof...(Ks)>());
  }

  cout << "finished pivot and pruned pivot!" << endl;

  if (sigma.size() > 1'000'000) {
    omp_set_num_threads(48);
  }

  const double baseline_mem = get_memory_usage_percent();
  auto total_kb_running = std::make_shared<std::atomic<double>>(0.0);

  #pragma omp parallel
  #pragma omp single nowait
  {
    // Use OpenMP depend token for robust release: create a single release task
    // with depend(out: token), and all worker tasks depend(in: token).
    char token = 0;

    // Create worker tasks (each depends on token)
    [&]<size_t... Is, size_t... Js> (index_sequence<Is...>, index_sequence<Js...>) {
      ([&]{
        constexpr auto I = Is;
        constexpr auto K = Ks;
        ([&]{
          constexpr auto J = Js;
          constexpr auto B = Bs;
          constexpr double R_ALLOC = min(1., B * R_BUDGET);
          const int r = min((int)ceil(n * R_ALLOC), n);

          for (int i = 0; i < ITERS; ++i) {
            #pragma omp task firstprivate(i, sigma_ptr, pi_ptr, res_node_ptr, res_edge_ptr, res_cc_ptr, total_kb_running, baseline_mem) depend(in: token)
            {
              // Wait if memory usage is too high
              while (get_memory_usage_percent() > 80.0) {
                double curr_mem = get_memory_usage_percent();
                double curr_kb = *total_kb_running;
                double unit = curr_kb > 1 ? (curr_mem - baseline_mem) / curr_kb : 0.;
                double projected_mem = baseline_mem + unit * (curr_kb + K * B);
                if (curr_mem < 70.0 && projected_mem < 90.0)
                  break;                
                
                cout << "waiting for more memory! currently using " << curr_mem << " and have unit=" << unit << ", curr kb=" << curr_kb << ", my kb=" << K*B << ", projected usage=" << projected_mem << endl;

                #pragma omp taskyield
                std::this_thread::sleep_for(std::chrono::milliseconds(1000));  // Sleep 1s to avoid busy waiting
              }
              
              total_kb_running->fetch_add(K * B);

              auto& stat_node = (*res_node_ptr)[I*ITERS*N_B + J*ITERS + i];
              auto& stat_edge = (*res_edge_ptr)[I*ITERS*N_B + J*ITERS + i];
              auto& stat_cc = (*res_cc_ptr)[I*ITERS*N_B + J*ITERS + i];
              simple_sampler_node<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_node);
              simple_sampler_edge<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_edge);
              cc_cost<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_cc, r);

              total_kb_running->fetch_sub(K * B);
            }
          }

        }(), ...);
      }(), ...);
    }(make_index_sequence<N_K>(), make_index_sequence<N_B>());

    // Release task: after all worker tasks are created, this completes and releases them
    #pragma omp task depend(out: token)
    {
      // No-op: just releases the token
    }
  }

  cout << "completed!" << endl;
  // res = std::array{run_experiment_fixed_space<ITERS, K, SPACE_BUDGETs>(sigma, pi)...};
  return tuple{res_pivot, res_pruned_pivot, res_cc, res_node, res_edge};
}


/**
 * Loads a graph, runs the full space-vs-error experiment grid on it and writes
 * all collected statistics as one JSON array to a timestamped file under
 * ./results/, then copies that file to "<base>latest.txt".
 *
 * @param input_file path of the graph file passed to deserialize()
 * @param threshold  forwarded to deserialize(); when > 0 it is also appended to
 *                   the output file name
 * @throws std::runtime_error if the results folder cannot be created or the
 *                            output file cannot be opened
 */
void run_space_vs_error_full(string input_file, float threshold) {
  // Make sure the output location is usable *before* the long-running
  // experiment starts. This used to be an assert placed after the experiment,
  // which vanishes under NDEBUG and otherwise discards finished results.
  const string out_folder = "./results/";
  {
    std::error_code mkdir_ec;
    std::filesystem::create_directories(out_folder, mkdir_ec);
    if (!std::filesystem::is_directory(out_folder)) {
      cerr << "cannot create output folder " << out_folder << endl;
      throw std::runtime_error("output folder is not available");
    }
  }

  cout << "reading in graph from " << input_file << "..." << endl;

  auto in = deserialize(input_file, threshold);
  vector<element> G;
  G.reserve(in.size());
  // Each element pairs its position in the input with the deserialized entry.
  const int num_entries = static_cast<int>(in.size());
  for (int i = 0; i < num_entries; i++)
    G.push_back({i, in[i]});

  cout << "finished reading graph!" << endl;

  // main big one
  auto res = run_experiment_simple_sampler<150>(G, k_seq<5>{}, space_seq<
       0.001     , 0.00119125, 0.00141908, 0.00169049, 0.0020138 ,
       0.00239895, 0.00285775, 0.00340431, 0.00405539, 0.004831  ,
       0.00575495, 0.0068556 , 0.00816676, 0.00972868, 0.01158932,
       0.01380582, 0.01644622, 0.01959162, 0.02333858, 0.02780216,
       0.03311942, 0.03945362, 0.04699926, 0.05598803, 0.06669592,
       0.07945174, 0.09464716, 0.11274874, 0.13431232, 0.16>{});
  // Bind by reference: the result arrays are large and are only read from here on.
  auto& [res_pivot, res_pruned_pivot, res_cc, res_node, res_edge] = res;

  std::filesystem::path input_path(input_file);
  string basename = input_path.stem().string();

  if (threshold > 0) {
    std::stringstream ss;
    ss << std::fixed << std::setprecision(3) << threshold;
    // Drops the first two characters ("0." for thresholds below 1).
    // NOTE(review): thresholds >= 1 lose their integer digit here (1.5 and 0.5
    // both yield "500"); kept as-is so existing file names stay stable.
    std::string thresh_str = ss.str().substr(2);
    basename += "_threshold_" + thresh_str;
  }
  string base_file = out_folder + basename + "_space_vs_variance_";
  string latest_file = base_file + "latest.txt";
  auto output_file = timestamp_filename(base_file, ".txt");

  std::ofstream of(output_file);
  if (!of) {
    cerr << "cannot open output file " << output_file << endl;
    throw std::runtime_error("cannot open results file for writing");
  }

  of << "[\n";

  // The boolean is passed as `true` for every record except the very last one
  // written (presumably "emit a trailing separator" -- see the writer helpers).
  for (auto& x : res_pivot)
    write_baseline_stats_json(of, x, true);
  for (auto& x : res_pruned_pivot)
    write_baseline_stats_json(of, x, true);
  for (auto& x : res_node)
    write_exec_stats_json(of, x, true);
  for (auto& x : res_cc)
    write_exec_stats_json(of, x, true);
  for (size_t i = 0; i < res_edge.size(); i++)
    write_exec_stats_json(of, res_edge[i], i+1 < res_edge.size());

  of << "]" << endl;

  of.close();

  // Refresh the "latest" copy without going through a shell, so paths with
  // spaces or shell metacharacters are handled safely.
  std::error_code copy_ec;
  std::filesystem::copy_file(output_file, latest_file,
                             std::filesystem::copy_options::overwrite_existing, copy_ec);
  if (copy_ec)
    cerr << "warning: could not update " << latest_file << ": " << copy_ec.message() << endl;
}

/**
 * Entry point.
 *
 * Usage: <program> graph_file [threshold]
 *   graph_file  must end in ".bin" or ".npz"
 *   threshold   optional float; defaults to -1 when omitted
 *
 * Returns 1 on a usage error, 0 after the experiment has run.
 */
int main(int argc, char* argv[]) {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin.exceptions(cin.failbit);

  if (argc < 2) {
    cerr << "usage: " << argv[0] << " graph_file [threshold]" << endl;
    return 1;
  }

  string filename = argv[1];
  if (!filename.ends_with(".bin") && !filename.ends_with(".npz")) {
    cerr << "input graph must be in binary format (.bin) or numpy format (.npz)" << endl;
    return 1;
  }

  // std::stof throws on non-numeric or out-of-range text; report that as a
  // usage error instead of terminating with an uncaught exception.
  float threshold = -1;
  if (argc >= 3) {
    try {
      threshold = std::stof(argv[2]);
    } catch (const std::exception&) {
      cerr << "threshold must be a number, got '" << argv[2] << "'" << endl;
      return 1;
    }
  }

  run_space_vs_error_full(filename, threshold);

  return 0;
}
