// compile with c++20 or higher

#include <bits/stdc++.h>

#include <algorithm>
#include <bit>
#include <cmath>
#include <cstdint>
#include <execution>
#include <iterator>
#include <random>
#include <ranges>
#include <tuple>
#include <utility>
#include <vector>

#include <cassert>
#include <iostream>
#include <set>

#include <omp.h>
#include <sys/resource.h>
#include <unistd.h>

#include "MurmurHash3.cpp"
#include "serialization.cpp"
#include "util.cpp"

// File-scope using-directive: the rest of this file names standard library
// entities (vector, pair, map, min, ...) without the std:: prefix.
using namespace std;
// Random source; the samplers below call rd() to seed their private hash
// functions (e.g. `RankGenerator h(rd())`).
std::random_device rd;
// Fixed-seed alternative, currently disabled:
// mt19937 rd(hash<string>()("please work"));

using ll = long long;
// Integer vector; element::_adj refers to an adjacency list of this type.
using vi = vector<int>;

// Split of est_EA()'s space budget between its two samples: S1 gets
// n * min(1, SPACE_BUDGET * S1_BUDGET) slots and S2 gets
// n * min(1, SPACE_BUDGET * S2_BUDGET) slots.
constexpr double S1_BUDGET = .5;
constexpr double S2_BUDGET = 1 - S1_BUDGET;

// Split of cc_cost()'s space budget: R is sized with SPACE_BUDGET * R_BUDGET,
// while est_EA() and est_EB() are each instantiated with
// SPACE_BUDGET * (1 - R_BUDGET).
constexpr double R_BUDGET = .8;
constexpr double EST_EA_BUDGET = 1 - R_BUDGET;
constexpr double EST_EB_BUDGET = 1 - R_BUDGET;

// Per-run statistics filled in by the estimators in this file and dumped by
// write_exec_stats_json(). All counters start at zero; cc_cost() asserts that
// `k` is still 0 on entry.
struct exec_stats {
  // Which estimator produced the record: "simple_sampler_node",
  // "simple_sampler_edge" or "cc-cost".
  string name;
  // Seed of the RankGenerator `pi` passed to the estimator.
  uint32_t pi_seed = 0;

  // Template parameter K of the estimator.
  ll k = 0;
  // ll passes = 0;
  // Not written by any function visible in this file; only serialized.
  ll oracle_calls = 0;

  // SPACE_BUDGET template parameter of the estimator.
  double allocated_space = 0;
  // Space fraction recorded for R by cc_cost().
  double allocated_R = 0;
  // Fractions of n used to size S1 / S2 in est_EA().
  double allocated_S1 = 0;
  double allocated_S2 = 0;
  // Fraction of n used as the space limit in est_EB().
  double allocated_N = 0;

  // cc_cost(): clamped size r of R. Simple samplers: their sample size t.
  ll r = 0;
  // est_EA(): actual sizes of S1 and S2.
  ll t1 = 0;
  ll t2 = 0;
  // est_EA(): number of nodes whose S1-based count Xj stayed within the
  // threshold and were therefore offered to S2.
  ll nl = 0;
  // est_EA(): contribution accumulated from nodes above the threshold.
  double d_A = 0;
  // est_EA(): sum of the S2-based counts Yj over all nodes.
  double D = 0;

  // est_EB(): final size of the sample S.
  ll t = 0;
  // est_EB(): running total of (degree + 1) over the sampled nodes.
  ll N_size = 0;
  // est_EB(): number of nodes for which find_pivot_R returned -1.
  ll nB = 0;
  // est_EB(): the two counters combined in its return value.
  ll n_in = 0;
  ll n_out = 0;
  // est_EB(): number of sampled nodes whose group N[i] is non-empty.
  ll surviving_groups = 0;

  // cc_cost(): est_EA() result, est_EB() result and their sum. The simple
  // samplers store their estimate in m_mis only.
  double m_A = 0, m_B = 0, m_mis = 0;
};

// Statistics record serialized by write_baseline_stats_json().
// NOTE(review): no function fully visible in this part of the file writes
// these fields (pruned_pivot_cc_cost takes a baseline_stats& but its body is
// cut off), so the per-field notes below are read off the names only —
// confirm against the writer.
struct baseline_stats {
  string name;
  // presumably the seed of the rank function `pi`, as in exec_stats
  uint32_t pi_seed = 0;
  // presumably the template parameter K
  ll k = 0;
  // presumably the size of the rank prefix R
  ll r = 0;

  // presumably node and edge counts of the input graph
  ll n = 0;
  ll edges = 0;
  
  // presumably exact counterparts of exec_stats::m_A / m_B / m_mis
  ll E_A = 0;
  ll E_B = 0;
  ll E_mis = 0;

  // presumably the number of nodes in parts A and B
  ll n_A = 0;
  ll n_B = 0;

  ll false_negatives = 0;
  ll false_positives = 0;

  ll oracle_calls = 0;
};

// Write `s` as a pretty-printed JSON object to `out` (a file stream or any
// other std::ostream).
//
// Keys appear in the declaration order of exec_stats, one per line with a
// two-space indent. Floating-point fields are written in fixed notation with
// six decimals. If `comma` is true, a ',' is appended after the closing brace
// so the object can be followed by another element of a JSON array.
//
// Guarantees of the output:
//  - `name` is escaped, so quotes, backslashes and control characters cannot
//    break the document;
//  - non-finite doubles (NaN / +-inf), which JSON cannot represent, are
//    written as `null`;
//  - the formatting flags and precision of `out` are restored before
//    returning, so the caller's stream is left as it was.
inline void write_exec_stats_json(std::ostream &out, const exec_stats &s, bool comma = false) {
  const std::ios_base::fmtflags saved_flags = out.flags();
  const std::streamsize saved_precision = out.precision();
  out << std::fixed << std::setprecision(6);

  // Returns `text` as a quoted JSON string literal.
  const auto json_string = [](const std::string &text) {
    static const char hex_digits[] = "0123456789abcdef";
    std::string quoted = "\"";
    for (const char c : text) {
      const unsigned char uc = static_cast<unsigned char>(c);
      switch (c) {
        case '"':  quoted += "\\\""; break;
        case '\\': quoted += "\\\\"; break;
        case '\n': quoted += "\\n";  break;
        case '\r': quoted += "\\r";  break;
        case '\t': quoted += "\\t";  break;
        default:
          if (uc < 0x20) {
            // Remaining control characters must use the \u00XX form.
            quoted += "\\u00";
            quoted += hex_digits[(uc >> 4) & 0xF];
            quoted += hex_digits[uc & 0xF];
          } else {
            quoted += c;
          }
      }
    }
    quoted += '"';
    return quoted;
  };

  // Emits one `"key": value,` line for an integer field.
  const auto put_int = [&out](const char *key, long long value) {
    out << "  \"" << key << "\": " << value << ",\n";
  };

  // Emits one line for a floating-point field; `last` drops the trailing comma.
  const auto put_real = [&out](const char *key, double value, bool last = false) {
    out << "  \"" << key << "\": ";
    if (std::isfinite(value))
      out << value;
    else
      out << "null";
    out << (last ? "\n" : ",\n");
  };

  out << "{\n";
  out << "  \"name\": " << json_string(s.name) << ",\n";
  put_int("pi_seed", s.pi_seed);
  put_int("k", s.k);
  put_int("oracle_calls", s.oracle_calls);
  put_real("allocated_space", s.allocated_space);
  put_real("allocated_R", s.allocated_R);
  put_real("allocated_S1", s.allocated_S1);
  put_real("allocated_S2", s.allocated_S2);
  put_real("allocated_N", s.allocated_N);
  put_int("r", s.r);
  put_int("t1", s.t1);
  put_int("t2", s.t2);
  put_int("nl", s.nl);
  put_real("d_A", s.d_A);
  put_real("D", s.D);
  put_int("t", s.t);
  put_int("N_size", s.N_size);
  put_int("nB", s.nB);
  put_int("n_in", s.n_in);
  put_int("n_out", s.n_out);
  put_int("surviving_groups", s.surviving_groups);
  put_real("m_A", s.m_A);
  put_real("m_B", s.m_B);
  put_real("m_mis", s.m_mis, /*last=*/true);
  out << "}" << (comma ? "," : "") << "\n";

  out.flags(saved_flags);
  out.precision(saved_precision);
}

// Write `b` as a pretty-printed JSON object to `out` (a file stream or any
// other std::ostream).
//
// Keys appear in the declaration order of baseline_stats, one per line with a
// two-space indent. If `comma` is true, a ',' is appended after the closing
// brace so the object can be followed by another element of a JSON array.
//
// `name` is escaped so that quotes, backslashes and control characters cannot
// break the document. Every other field is an integer, so no floating-point
// formatting is applied and the stream's formatting state is left untouched.
inline void write_baseline_stats_json(std::ostream &out, const baseline_stats &b, bool comma = false) {
  // Returns `text` as a quoted JSON string literal.
  const auto json_string = [](const std::string &text) {
    static const char hex_digits[] = "0123456789abcdef";
    std::string quoted = "\"";
    for (const char c : text) {
      const unsigned char uc = static_cast<unsigned char>(c);
      switch (c) {
        case '"':  quoted += "\\\""; break;
        case '\\': quoted += "\\\\"; break;
        case '\n': quoted += "\\n";  break;
        case '\r': quoted += "\\r";  break;
        case '\t': quoted += "\\t";  break;
        default:
          if (uc < 0x20) {
            // Remaining control characters must use the \u00XX form.
            quoted += "\\u00";
            quoted += hex_digits[(uc >> 4) & 0xF];
            quoted += hex_digits[uc & 0xF];
          } else {
            quoted += c;
          }
      }
    }
    quoted += '"';
    return quoted;
  };

  // Emits one line for an integer field; `last` drops the trailing comma.
  const auto put_int = [&out](const char *key, long long value, bool last = false) {
    out << "  \"" << key << "\": " << value << (last ? "\n" : ",\n");
  };

  out << "{\n";
  out << "  \"name\": " << json_string(b.name) << ",\n";
  put_int("pi_seed", b.pi_seed);
  put_int("k", b.k);
  put_int("r", b.r);
  put_int("n", b.n);
  put_int("edges", b.edges);
  put_int("E_A", b.E_A);
  put_int("E_B", b.E_B);
  put_int("E_mis", b.E_mis);
  put_int("n_A", b.n_A);
  put_int("n_B", b.n_B);
  put_int("false_negatives", b.false_negatives);
  put_int("false_positives", b.false_positives);
  put_int("oracle_calls", b.oracle_calls, /*last=*/true);
  out << "}" << (comma ? "," : "") << "\n";
}

// vi true_pivots, true_pivots_R, true_cc;
// int true_count_A = 0;
// double true_avg_degree = 0;

// int space_S1, space_S2, space_R, space_N;

// exec_stats stats;
// baseline_stats baseline;

struct element {
  int idx;
  reference_wrapper<const vi> _adj;

  const vi& adj() const { return _adj.get(); }
  bool edge(int v) const {
    auto it = std::lower_bound(adj().begin(), adj().end(), v);
    return it != adj().end() && *it == v;
  }
  bool edge(const element& v) const { return edge(v.idx); }

  bool operator==(const element& other) const {
    return idx == other.idx;
  }
};

// One entry of the rank-prefix sample R built by cc_cost().
struct TopSample {
  // The sampled vertex.
  element elem;
  // Counter returned by find_pivot_R_with_cc for this vertex; later lookups
  // add `1 + cc` of each lower-ranked neighbour they walk through.
  uint32_t cc;
  // True if find_pivot_R_with_cc returned the vertex itself as its pivot.
  bool is_pivot;
};

// template <typename T>
// concept Rankable = requires(const T& x, const RankGenerator& rg) {
//   { rg(x) } -> std::same_as<Rank128>;
// };

// template<typename I>
// concept adj_iterator = std::forward_iterator<I> &&
//                        std::same_as<std::iter_value_t<I>, vi>;

namespace ranks {
using Rank128 = std::array<uint64_t, 2>;

// Seeded hash that maps a vertex id to a 128-bit rank.
struct RankGenerator {
  const uint32_t seed;
  RankGenerator(uint32_t seed) : seed(seed) {}

  // Rank of a raw id: MurmurHash3_x86_128 over the bytes of `i` under `seed`.
  Rank128 operator()(uint32_t i) const {
    Rank128 digest;
    MurmurHash3_x86_128(&i, sizeof(i), seed, digest.data());
    return digest;
  }
  // Rank of an element / sample entry is the rank of its vertex id.
  Rank128 operator()(const element& u) const { return operator()(u.idx); }
  Rank128 operator()(const TopSample& u) const { return operator()(u.elem); }

  // Strict-weak-order comparator on ranks. Either argument may already be a
  // Rank128 (used as-is); anything else is ranked through operator() first.
  // The returned closure captures `this`, so it must not outlive the generator.
  auto comp_by_rank() const {
    return [this](const auto& lhs, const auto& rhs) -> bool {
      const auto as_rank = [this](const auto& value) -> Rank128 {
        if constexpr (std::is_same_v<std::decay_t<decltype(value)>, Rank128>) {
          return value;  // already a rank
        } else {
          return (*this)(value);  // id, element or TopSample
        }
      };
      return as_rank(lhs) < as_rank(rhs);
    };
  }
};
using RankComparator = decltype(std::declval<RankGenerator>().comp_by_rank());

template <class T>
bool comp_by_first(const pair<Rank128, T>& a, const pair<Rank128, T>& b) {
  return a.first < b.first;
}

// Max-heap of T keyed by rank: top() is the stored item with the largest
// rank. Popping whenever the size exceeds a limit therefore keeps the
// lowest-ranked items. Items are stored next to their precomputed rank so the
// hash is evaluated once per push.
template <class T>
struct rank_queue : priority_queue<pair<Rank128, T>,
                                   vector<pair<Rank128, T>>,
                                   decltype(&comp_by_first<T>)> {
  using entry_type = pair<Rank128, T>;
  using heap_type = priority_queue<entry_type,
                                   vector<entry_type>,
                                   decltype(&comp_by_first<T>)>;

  RankGenerator rank;

  rank_queue(RankGenerator rank) : heap_type(comp_by_first<T>), rank(rank) {}

  // Insert a copy of `x`, keyed by rank(x).
  void push(const T& x) {
    heap_type::push({rank(x), x});
  }
  // Insert `x` by move; the rank is taken before the payload is moved out.
  void push(T&& x) {
    heap_type::push({rank(x), std::move(x)});
  }
  // The stored item with the largest rank (without its rank).
  const T& top() const {
    return heap_type::top().second;
  }
};

// Find a rank v such that k of the ids i = 0..n-1 satisfy rank(i) < v.
//
// Keeps the k smallest ranks in a max-heap (O(k) memory) and returns the
// largest of them plus one, treating the rank as a 128-bit number whose high
// word is v[0]. This assumes no two ids share a rank.
//
// Edge cases:
//  - k == 0 or n == 0: returns the all-zero rank, below which nothing lies
//    (the original called top() on an empty heap);
//  - k > n: all n ids lie below the returned value;
//  - the "+ 1" carries from the low word into the high word and saturates at
//    the maximum rank instead of wrapping around to a smaller value.
Rank128 find_cutoff(size_t k, size_t n, RankGenerator rank) {
  if (k == 0 || n == 0)
    return Rank128{};

  priority_queue<Rank128> q;  // max-heap of the k smallest ranks seen so far
  for (size_t i = 0; i < n; i++) {
    q.push(rank(i));

    if (q.size() > k)
      q.pop();
  }

  auto v = q.top();
  if (++v[1] == 0 && ++v[0] == 0)
    v.fill(~uint64_t{0});  // v was already the maximum rank: saturate
  return v;
}

// Constant-memory variant of the cutoff search: builds the result bit by bit
// from the most significant bit of word 0 down to the least significant bit
// of word 1, keeping a bit whenever at most k ids in [0, n) still rank
// strictly below the trial value. Each of the 128 steps rescans the ids.
Rank128 find_cutoff_binsearch(size_t k, size_t n, RankGenerator rank) {
  // True if more than k ids rank strictly below `bound`; the scan stops as
  // soon as that is known.
  const auto more_than_k_below = [&](const Rank128& bound) {
    size_t below = 0;
    for (size_t id = 0; id < n; ++id) {
      if (below > k)
        break;
      below += rank(id) < bound;
    }
    return below > k;
  };

  Rank128 best{};
  for (int word = 0; word < 2; ++word) {
    for (uint64_t bit = 1ULL << 63; bit != 0; bit >>= 1) {
      Rank128 trial = best;
      trial[word] |= bit;

      if (!more_than_k_below(trial))
        best = trial;
    }
  }

  return best;
}
};  // namespace ranks
using namespace ranks;

// Multi-pass search state for one root vertex.
//
// `u` and `n` are parallel stacks. Within a pass, update() may replace n[i]
// by a candidate whose rank is strictly below the current n[i] and, when a
// deeper level u[i + 1] exists, strictly above the rank of u[i + 1]; only the
// last two levels are ever touched. finish_pass() then either resolves the
// search (sets `pivot`), pops levels, or pushes a new level.
// NOTE(review): this looks like a depth-first walk along lowest-ranked
// neighbours to decide which vertex is the root's pivot — confirm against the
// algorithm description.
struct PivotStack {
  vector<element> u, n;
  // Resolved pivot id, or -1 while the search is still running.
  int pivot = -1;

  // Both stacks start with the root as their only level.
  PivotStack(element root) : u({root}), n({root}) {}

  // True if the search is unresolved and vertex `v_idx` lies in the rank
  // window of level `i` (see struct comment), i.e. offering it to
  // update(v, i, pi) would change n[i].
  bool consider_edge(int v_idx, int i, RankGenerator pi) {
    return pivot == -1 && (i + 1 == u.size() || pi(u[i + 1]) < pi(v_idx)) && pi(v_idx) < pi(n[i]);
  }
  // Offer `v` to each of the last two levels it is adjacent to. Adjacency is
  // tested by binary search for u[i].idx in v's adjacency list.
  void update(const element& v, RankGenerator pi) {
    if (pivot != -1)
      return;

    for (int i = max(0, int(u.size()) - 2); i < u.size(); i++) {
      auto it = lower_bound(v.adj().begin(), v.adj().end(), u[i].idx);
      if (it == v.adj().end() || *it != u[i].idx)
        continue;

      if ((i + 1 == u.size() || pi(u[i + 1]) < pi(v)) && pi(v) < pi(n[i]))
        n[i] = v;
    }
  }

  // Offer `v` to level `i` only. No adjacency test is made here; `i` must be
  // one of the last two levels.
  void update(const element& v, int i, RankGenerator pi) {
    if (pivot != -1)
      return;

    assert(i < u.size() && u.size() - i <= 2);

    if ((i + 1 == u.size() || pi(u[i + 1]) < pi(v)) && pi(v) < pi(n[i]))
      n[i] = v;
  }

  // Close the current pass and prepare the next one.
  void finish_pass() {
    if (pivot != -1)
      return;

    assert(!u.empty());
    // Top level unchanged (n.back() is still u.back()): no candidate was
    // accepted for it during the pass.
    while (u.back().idx == n.back().idx) {
      if (u.size() <= 2) {
        // With at most two levels left the top vertex becomes the pivot and
        // the stacks are released.
        pivot = u.back().idx;
        u.clear();
        n.clear();
        u.shrink_to_fit();
        n.shrink_to_fit();
        return;
      }
      // Otherwise drop the top two levels and re-examine the new top.
      u.pop_back(), n.pop_back();
      u.pop_back(), n.pop_back();

      assert(!u.empty());
    }

    // The top level accepted a candidate: push it as a new level (with itself
    // as its own initial candidate) and reset the parent's slot to u[i].
    int i = u.size() - 1;
    u.push_back(n[i]), n.push_back(n[i]);
    n[i] = u[i];
  }
};

// Look up the pivot of `u` using only the rank-prefix sample R.
//
// R must be sorted by pi-rank and contain every vertex whose rank is at most
// the rank of R.back(); entries of rank lower than u's must already carry
// their final `cc` / `is_pivot` values.
//
// The closed neighbourhood of `u` is walked in increasing rank order:
//  - the next vertex ranks above R's last entry  -> {-1, cc}   (not in R)
//  - the next vertex is `u` itself               -> {u.idx, cc}
//  - the accumulated counter exceeds K           -> {u.idx, cc} (cc == K + 1)
//  - the next vertex is marked as a pivot        -> {its id, cc}
// where cc grows by `1 + cc of that neighbour` per step, capped at K + 1.
//
// An empty R yields {-1, 0} (the original dereferenced R.back()).
template <size_t K>
pair<int, uint32_t> find_pivot_R_with_cc(const element& u, const vector<TopSample>& R, RankGenerator pi) {
  if (R.empty())
    return {-1, 0};  // nothing can be in R

  uint32_t cc = 0;

  const auto R_cutoff = pi(R.back());

  auto Q = u.adj();  // copy: the caller's adjacency list must stay sorted by id
  Q.push_back(u.idx);
  const size_t num_to_sort = min(Q.size(), size_t(K + 1));
  partial_sort(Q.begin(), Q.begin() + num_to_sort, Q.end(), pi.comp_by_rank());

  // Only the sorted prefix is walked. Each step that does not return raises
  // cc by at least one, so the walk ends within K + 1 steps; if Q is shorter
  // than that it ends at `u` itself, which is always part of Q.
  for (size_t pos = 0; pos < num_to_sort; pos++) {
    const int v_idx = Q[pos];
    const auto pi_v = pi(v_idx);

    if (pi_v > R_cutoff)
      return {-1, cc};  // v not in R

    if (v_idx == u.idx)
      return {u.idx, cc};  // u is pivot

    auto it = lower_bound(R.begin(), R.end(), pi_v, pi.comp_by_rank());
    assert(it != R.end() && it->elem.idx == v_idx);

    cc += 1 + it->cc;
    cc = min(cc, uint32_t(K + 1));

    if (cc > K)
      return {u.idx, cc};  // cc in R exceeds K

    if (it->is_pivot)
      return {v_idx, cc};  // found pivot within cc limit
  }

  cout << "should never happen?" << endl;
  assert(0);

  if (cc > K)
    return {u.idx, cc};

  return {};
}

template <size_t K>
int find_pivot_R(const element& u,
                 const vector<TopSample>& R,
                 RankGenerator pi) {
  return find_pivot_R_with_cc<K>(u, R, pi).first;
}


template <size_t K, double SPACE_BUDGET, bool ONLY_B = false>
double simple_sampler_node(const vector<element>& sigma,
              RankGenerator pi,
              // const vector<TopSample>& R,
              exec_stats& stats) {
  const size_t n = sigma.size();

  size_t t = n * SPACE_BUDGET;
  t = min(t, n);

  RankGenerator h(rd());
  rank_queue<element> q(h);
  
  int nl = 0;

  for (const auto& u : sigma) {
    // if constexpr (ONLY_B) {
    //   if (find_pivot_R<K>(u, R, pi) != -1)
    //     continue;
    // }

    nl++;
    q.push(u);
    if (q.size() > t)
      q.pop();
  }

  vector<element> S;
  S.reserve(q.size());  
  for (; !q.empty(); q.pop())
    S.push_back(q.top());

  reverse(S.begin(), S.end()); // S sorted by h

  vector<PivotStack> stks;
  stks.reserve(S.size());
  for (const auto& u : S)
    stks.emplace_back(u);


  vector<std::array<int, 3>> loc_stks;
  for (int pass = 0; pass < K; pass++) {
    // cout << "running pass " << pass + 1 << endl;

    for (int i = 0; i < S.size(); i++)
      for (int p = max(0, int(stks[i].u.size()) - 2); p < stks[i].u.size(); p++)
        for (int v_idx : stks[i].u[p].adj())
          if (stks[i].consider_edge(v_idx, p, pi)) 
            loc_stks.push_back({v_idx, i, p});

    sort(std::execution::par_unseq, loc_stks.begin(), loc_stks.end());


    for (const auto& u : sigma) {
      auto it =
          lower_bound(loc_stks.begin(), loc_stks.end(), std::array{u.idx, 0, 0});
      for (; it != loc_stks.end() && (*it)[0] == u.idx; it++) {
        int i = (*it)[1], p = (*it)[2];
        stks[i].update(u, p, pi);
      }  
    }

    loc_stks.clear();
    loc_stks.shrink_to_fit();

    for (int i = 0; i < S.size(); i++)
      stks[i].finish_pass();
  }

  map<int, int> S_pivots;
  vector<int> pivots(S.size());
  for (int i = 0; i < S.size(); i++) {
    const auto& u = S[i];
    pivots[i] = stks[i].pivot;
    if (pivots[i] == -1)
      pivots[i] = u.idx;

    S_pivots[u.idx] = pivots[i];
  }

  ll count = 0;

  for (int i = 0; i < S.size(); i++) {
    const auto& u = S[i];
    
    for (int v_idx : u.adj()) {
      if (u.idx >= v_idx) continue;

      auto h_v = h(v_idx);
      if (h_v > h(S.back())) continue;

      auto it = S_pivots.find(v_idx);
      if (it == S_pivots.end())
        continue;

      if (pivots[i] != it->second)
        count++;
      else
        count--; 

      // auto it = lower_bound(S.begin(), S.end(), h_v, h.comp_by_rank());

      // if (it == S.end() || it->idx != v_idx)
      //   continue;

      // int j = distance(S.begin(), it);
      // if (pivots[i].first != pivots[j].first)
      //   count++;
      // else
      //   count--;
    }
  }

  sort(std::execution::par_unseq, pivots.begin(), pivots.end());

  for (auto it = pivots.begin(), nxt = it; it != pivots.end(); it = nxt) {
    nxt = find_if(it, pivots.end(), [it](const auto& x) { return x != *it; });
    int sz = distance(it, nxt);
    count += ll(sz) * (sz - 1) / 2;
  }
  double ratio = nl / double(S.size());
  double res = count * ratio * ratio;

  cout << "sampled " << S.size() << " nodes, estimated cost " << res << endl;

  stats.name = "simple_sampler_node";
  stats.pi_seed = pi.seed;
  stats.k = K;
  stats.r = t;
  stats.allocated_space = SPACE_BUDGET;
  stats.m_mis = res;

  return res;
}

// Pair-sampling cost estimator.
//
// Template parameters:
//   K             number of passes given to the per-node pivot search
//   SPACE_BUDGET  sample size as a fraction of n (capped at n, rounded down
//                 to an even number)
//   ONLY_B        currently unused (the code that used it is disabled)
//
// Steps:
//   1. Draw an even-sized sample S by keeping the t lowest ranks under a
//      freshly seeded hash h, and treat consecutive entries (S[0],S[1]),
//      (S[2],S[3]), ... as sampled pairs.
//   2. Run K passes over sigma, advancing one PivotStack per sampled node.
//      A node whose search is unresolved after K passes is its own pivot.
//   3. A pair counts when it is an edge whose endpoints have different
//      pivots, or a non-edge whose endpoints share a pivot.
//   4. Scale the counted fraction to all n * (n - 1) / 2 pairs.
//
// Writes name, pi_seed, k, r, allocated_space and m_mis into `stats`, prints a
// summary line, and returns the estimate. With no sampled pair (t < 2) the
// result is 0 instead of the NaN that 0 / 0 used to produce.
template <size_t K, double SPACE_BUDGET, bool ONLY_B = false>
double simple_sampler_edge(const vector<element>& sigma,
              RankGenerator pi,
              // const vector<TopSample>& R,
              exec_stats& stats) {
  const size_t n = sigma.size();

  size_t t = n * SPACE_BUDGET;
  t = min(t, n);
  if (t % 2) t--;

  RankGenerator h(rd());
  rank_queue<element> q(h);

  int nl = 0;

  // Sampling pass: keep the t nodes with the lowest h-rank.
  for (const auto& u : sigma) {
    nl++;
    q.push(u);
    if (q.size() > t)
      q.pop();
  }

  vector<element> S;
  S.reserve(q.size());
  for (; !q.empty(); q.pop())
    S.push_back(q.top());

  reverse(S.begin(), S.end()); // S sorted by h

  vector<PivotStack> stks;
  stks.reserve(S.size());
  for (const auto& u : S)
    stks.emplace_back(u);

  // (vertex id, sample index, stack level) requests for the current pass.
  vector<std::array<int, 3>> loc_stks;
  for (int pass = 0; pass < K; pass++) {
    // Collect every neighbour that could still improve one of the last two
    // levels of a stack.
    for (int i = 0; i < S.size(); i++)
      for (int p = max(0, int(stks[i].u.size()) - 2); p < stks[i].u.size(); p++)
        for (int v_idx : stks[i].u[p].adj())
          if (stks[i].consider_edge(v_idx, p, pi))
            loc_stks.push_back({v_idx, i, p});

    sort(std::execution::par_unseq, loc_stks.begin(), loc_stks.end());

    // One pass over the input: hand each node to the stacks that asked for it.
    for (const auto& u : sigma) {
      auto it =
          lower_bound(loc_stks.begin(), loc_stks.end(), std::array{u.idx, 0, 0});
      for (; it != loc_stks.end() && (*it)[0] == u.idx; it++) {
        int i = (*it)[1], p = (*it)[2];
        stks[i].update(u, p, pi);
      }
    }

    loc_stks.clear();
    loc_stks.shrink_to_fit();

    for (int i = 0; i < S.size(); i++)
      stks[i].finish_pass();
  }

  ll count = 0;
  // `i + 1 < S.size()` keeps the pair access in bounds even if the sample
  // size were ever odd.
  for (size_t i = 0; i + 1 < S.size(); i += 2) {
    const auto& u = S[i];
    const auto& v = S[i+1];
    int p_u = stks[i].pivot;
    if (p_u == -1) p_u = u.idx;
    int p_v = stks[i+1].pivot;
    if (p_v == -1) p_v = v.idx;

    if (find(u.adj().begin(), u.adj().end(), v.idx) != u.adj().end()) {
      count += p_u != p_v;  // edge between different pivots
    }
    else {
      count += p_u == p_v;  // non-edge inside one pivot group
    }
  }

  ll all_pairs = sigma.size() * ll(sigma.size() - 1) / 2;
  double res = 0;
  if (!S.empty())
    res = all_pairs * (count / (S.size() / 2.0));

  cout << "sampled " << S.size() / 2 << " edges, count = " << count << " estimated cost " << res << endl;

  stats.name = "simple_sampler_edge";
  stats.pi_seed = pi.seed;
  stats.k = K;
  stats.r = t;
  stats.allocated_space = SPACE_BUDGET;
  stats.m_mis = res;

  return res;
}

template <size_t K, double SPACE_BUDGET>
double est_EA(const vector<element>& sigma, RankGenerator pi, const vector<TopSample>& R, exec_stats& stats) {
  const size_t n = sigma.size();

  constexpr double S1_ALLOC = min(1., SPACE_BUDGET * S1_BUDGET);
  constexpr double S2_ALLOC = min(1., SPACE_BUDGET * S2_BUDGET);
  stats.allocated_S1 = S1_ALLOC;
  stats.allocated_S2 = S2_ALLOC;

  // size_t t1 = ceil(12.0 / EPS / EPS * pow(n, 1 - BETA) * log2(n));
  size_t t1 = ceil(n * S1_ALLOC);
  // size_t n2 = 32 / EPS / EPS * pow(n, 2*BETA) * log2(n);
  size_t t2 = ceil(n * S2_ALLOC);

  t1 = min(t1, n), t2 = min(t2, n);

  vector<element> S1, S2;

  RankGenerator h1(rd()), h2(rd());
  rank_queue<element> q1(h1), q2(h2);

  map<int, int> S1_pivots;
  vector<int> S1_pivots_seen;

  // pass 1: reservoir sample S1
  for (const auto& u : sigma) {
    q1.push(u);

    if (q1.size() > t1)
      q1.pop();
  }

  auto h1_cutoff = h1(q1.top());

  S1.reserve(q1.size());

  for (; !q1.empty(); q1.pop()) {
    auto u = q1.top();

    S1.push_back(u);
    auto p_u = find_pivot_R<K>(u, R, pi);
    S1_pivots[u.idx] = p_u;

    if (p_u != -1)
      S1_pivots_seen.push_back(p_u);
  }
  
  stats.t1 = S1.size();

  // build frequency table of pivots in S1
  vector<pair<int, int>> S1_pivots_freq;
  sort(std::execution::par_unseq, S1_pivots_seen.begin(), S1_pivots_seen.end());
  for (auto it = S1_pivots_seen.begin(), nxt = it; it != S1_pivots_seen.end();
       it = nxt) {
    nxt = find_if(it, S1_pivots_seen.end(), [it](int x) { return x != *it; });
    S1_pivots_freq.emplace_back(*it, std::distance(it, nxt));
  }

  // const double lim_Xj = log2(n) / EPS / EPS;
  const int lim_Xj = 100;

  ll d_A = 0;
  int nl = 0;

  for (const auto& u : sigma) {
    auto p_u = find_pivot_R<K>(u, R, pi);

    // Xj = |{v in S1 | u ~ v XOR (p_u == p_v && p_u != null)}|
    ll Xj = 0;

    // # Xj += |{v in S1 | p_u == p_v && p_u != null}|
    if (p_u != -1) {
      auto it = lower_bound(S1_pivots_freq.begin(), S1_pivots_freq.end(), p_u,
                            [](const auto& x, int y) { return x.first < y; });

      if (it != S1_pivots_freq.end() && it->first == p_u) {
        Xj += it->second;
        if (h1(u.idx) <= h1_cutoff)
          Xj--;  // u in S1, dont count itself
      }
    }

    // Xj += |{v in S1 | u ~ v && p_u != p_v}|
    // Xj -= |{v in S1 | u ~ v && p_u == p_v && p_u != null}|
    for (int v_idx : u.adj()) {
      if (h1(v_idx) > h1_cutoff)
        continue;

      auto it = S1_pivots.find(v_idx);
      assert(it != S1_pivots.end());

      int p_v = it->second;

      Xj += p_u != p_v ? 1 : -(p_u != -1);
    }

    if (Xj > lim_Xj) {
      if (n == S1.size())
        d_A += Xj;
      else
        d_A += Xj * double(n) / S1.size();
    } else {
      nl++;

      q2.push(u);
      if (q2.size() > t2)
        q2.pop();
    }
  }

  auto h2_cutoff = h2(q2.top());
  S2.reserve(q2.size());

  for (; !q2.empty(); q2.pop())
    S2.push_back(q2.top());

  stats.t2 = S2.size();

  map<int, int> S2_pivots;
  vector<int> S2_pivots_seen;

  for (int i = 0; i < S2.size(); i++) {
    const auto& u = S2[i];
    auto p_u = find_pivot_R<K>(u, R, pi);
    S2_pivots[u.idx] = p_u;
    if (p_u != -1)
      S2_pivots_seen.push_back(p_u);
  }

  // build frequency table of pivots in S2
  vector<pair<int, int>> S2_pivots_freq;
  sort(std::execution::par_unseq, S2_pivots_seen.begin(), S2_pivots_seen.end());
  for (auto it = S2_pivots_seen.begin(), nxt = it; it != S2_pivots_seen.end();
       it = nxt) {
    nxt = find_if(it, S2_pivots_seen.end(), [it](int x) { return x != *it; });
    S2_pivots_freq.emplace_back(*it, std::distance(it, nxt));
  }

  // D = |{(u, v) in sigma x S2 | u ~ v XOR (p_u == p_v && p_u != null)}|
  ll D = 0;
  for (const auto& u : sigma) {
    auto p_u = find_pivot_R<K>(u, R, pi);

    // Yj = |{v in S2 | u ~ v XOR (p_u == p_v && p_u != null)}|
    ll Yj = 0;

    // # Yj += |{v in S2 | p_u == p_v && p_u != null}|
    if (p_u != -1) {
      auto it = lower_bound(S2_pivots_freq.begin(), S2_pivots_freq.end(), p_u,
                            [](const auto& x, int y) { return x.first < y; });
      if (it != S2_pivots_freq.end() && it->first == p_u) {
        Yj += it->second;

        if (S2_pivots.contains(u.idx))
          Yj--;  // u in S2, dont count itself
      }
    }

    // Yj += |{v in S2 | u ~ v && p_u != p_v}|
    // Yj -= |{v in S2 | u ~ v && p_u == p_v && p_u != null}|
    for (int v_idx : u.adj()) {
      if (h2(v_idx) > h2_cutoff)
        continue;  // quick prune

      auto it = S2_pivots.find(v_idx);
      if (it == S2_pivots.end())
        continue;
      int p_v = it->second;

      Yj += p_u != p_v ? 1 : -(p_u != -1);
    }

    D += Yj;
  }

  cout << "t1:    " << t1 << endl;
  cout << "|S1|:  " << S1.size() << endl;
  cout << "t2:    " << t2 << endl;
  cout << "|S2|:  " << S2.size() << endl;
  cout << "nl:    " << nl << endl;
  cout << "d_A:   " << d_A << endl;
  cout << "D:     " << D << endl;

  stats.d_A = d_A;
  stats.D = D;
  stats.nl = nl;

  // space_S1 = S1.size();
  // space_S2 = S2.size();

  if (nl == S2.size())
    d_A += D;
  else
    d_A += D * double(nl) / S2.size();

  cout << "est: " << d_A / 2 << endl;

  return d_A / 2;
}

// Estimate the "B" part of the clustering cost, i.e. the part involving nodes
// for which find_pivot_R<K> returns -1.
//
// Pass 1     samples such nodes by lowest h_S-rank while the total of
//            (degree + 1) over the sample stays within t = n * SPACE_BUDGET.
// Pass 2     materialises, for every sampled node S[i], the list N[i] of its
//            neighbours followed by S[i] itself.
// Passes 3+  run K + 1 PivotStack passes for every member of every N[i];
//            after each pass, samples are dropped from the back of S while
//            the live stacks exceed t.
// Afterwards N[i] is reduced to the members whose pivot is S[i], and two
// counters are formed:
//   n_in   pairs inside a group N[i] that are not joined by an edge,
//   n_out  adjacency entries (group member -> node with no pivot in R) whose
//          target is outside that group.
// Result (n_in + n_out / 2) * nB / |S|, or 0 when the sample is empty.
//
// Writes allocated_N, nB, n_in, n_out, N_size, t and surviving_groups into
// `stats` and prints diagnostics (the averages print as nan when no group
// survives).
//
// Difference from the previous revision: the unused `h_S_cutoff` local is
// gone. Computing it called top() on an empty queue whenever no node was
// sampled.
template <size_t K, double SPACE_BUDGET>
double est_EB(const vector<element>& sigma, RankGenerator pi, const vector<TopSample>& R, exec_stats& stats) {
  const size_t n = sigma.size();

  // size_t t = ceil(8.0 / EPS / EPS * pow(n, 3*BETA) * log2(n));
  constexpr double S_ALLOC = min(1., SPACE_BUDGET);

  size_t t = ceil(n * S_ALLOC);
  t = min(t, n);

  size_t nB = 0;

  RankGenerator h_S(rd());
  rank_queue<element> q_S(h_S);

  ll N_size = 0;  // sum of (degree + 1) over the nodes currently sampled

  // pass 1
  for (const auto& u : sigma)
    if (find_pivot_R<K>(u, R, pi) == -1) {
      nB++;

      N_size += u.adj().size() + 1;
      q_S.push(u);

      // Evict highest-ranked samples until the neighbourhoods fit; no limit
      // applies when the budget covers all n nodes.
      while (N_size > t && t != n) {
        N_size -= q_S.top().adj().size() + 1;
        q_S.pop();
      }
    }

  vector<element> S;
  S.reserve(q_S.size());

  while (!q_S.empty()) {
    S.push_back(q_S.top());
    q_S.pop();
  }

  // (neighbour id, sample index, slot in N[sample index])
  vector<std::array<int, 3>> loc_N;

  vector<vector<element>> N(S.size());
  vector<vector<PivotStack>> stks2(S.size());
  for (int i = 0; i < S.size(); i++) {
    N[i].reserve(S[i].adj().size() + 1);
    stks2[i].reserve(S[i].adj().size() + 1);

    // Placeholder entries, overwritten with the real neighbours in pass 2.
    N[i].resize(S[i].adj().size(), S[i]);

    for (int j = 0; j < S[i].adj().size(); j++)
      loc_N.push_back({S[i].adj()[j], i, j});

    N[i].push_back(S[i]);
  }

  sort(std::execution::par_unseq, loc_N.begin(), loc_N.end());

  // pass 2
  for (const auto& u : sigma) {
    auto it = lower_bound(loc_N.begin(), loc_N.end(), std::array{u.idx, 0, 0});
    for (; it != loc_N.end() && (*it)[0] == u.idx; it++) {
      int i = (*it)[1], j = (*it)[2];
      N[i][j] = u;
    }
  }

  for (int i = 0; i < S.size(); i++) {
    for (const auto& u : N[i])
      stks2[i].emplace_back(u);
  }

  // (vertex id, sample index, member index, stack level)
  vector<std::array<int, 4>> loc_stks;
  for (int pass = 0; pass < K+1; pass++) {
    cout << "running pass " << pass << endl;

    for (int i = 0; i < N.size(); i++)
      for (int j = 0; j < N[i].size(); j++)
        if (stks2[i][j].pivot == -1)
          for (int p = max(0, int(stks2[i][j].u.size()) - 2); p < stks2[i][j].u.size(); p++)
            for (int v_idx : stks2[i][j].u[p].adj())
              if (stks2[i][j].consider_edge(v_idx, p, pi))
                loc_stks.push_back({v_idx, i, j, p});

    sort(std::execution::par_unseq, loc_stks.begin(), loc_stks.end());

    // pass 3..K+2
    for (const auto& u : sigma) {
      auto it = lower_bound(loc_stks.begin(), loc_stks.end(), std::array{u.idx, 0, 0, 0});
      for (; it != loc_stks.end() && (*it)[0] == u.idx; it++) {
        int i = (*it)[1], j = (*it)[2], p = (*it)[3];

        stks2[i][j].update(u, p, pi);
      }
    }

    loc_stks.clear();
    loc_stks.shrink_to_fit();

    // Extra levels held by the still-unresolved stacks.
    ll tot_size_stks = 0;
    for (int i = 0; i < N.size(); i++)
      for (int j = 0; j < N[i].size(); j++) {
        stks2[i][j].finish_pass();

        if (stks2[i][j].pivot == -1)
          tot_size_stks += stks2[i][j].u.size() - 1;
      }

    // Drop samples from the back until the stacks fit into the budget.
    ll old_S_size = S.size();
    while (tot_size_stks > t && S.size() > 1 && t != n) {
      N_size -= S.back().adj().size() + 1;

      for (auto& stk : stks2.back())
        if (stk.pivot == -1)
          tot_size_stks -= stk.u.size() - 1;

      S.pop_back();
      N.pop_back();
      stks2.pop_back();
    }

    if (S.size() != old_S_size) {
      cout << "Shrunk S from " << old_S_size << " -> " << S.size() << endl;
    }
  }

  loc_N.clear();
  loc_N.shrink_to_fit();

  auto comp_by_idx = [&](const auto& x, const auto& y) {
    return x.idx < y.idx;
  };

  // Keep only the members whose pivot is S[i] (an unresolved member counts as
  // its own pivot), then index the adjacency entries of the survivors.
  for (int i = 0; i < N.size(); i++) {
    vector<element> nxt;

    for (int j = 0; j < N[i].size(); j++) {
      const auto& u = N[i][j];

      int p_u = stks2[i][j].pivot;
      if (p_u == -1)
        p_u = N[i][j].idx;

      if (p_u == S[i].idx)
        nxt.push_back(u);
    }

    swap(N[i], nxt);

    sort(std::execution::par_unseq, N[i].begin(), N[i].end(), comp_by_idx);
    for (int j = 0; j < N[i].size(); j++)
      for (int v_idx : N[i][j].adj())
        loc_N.push_back({v_idx, i, j});
  }

  ll n_in = 0, n_out = 0;

  sort(std::execution::par_unseq, loc_N.begin(), loc_N.end());

  // pass k+3
  for (const auto& u : sigma) {
    int p_u = find_pivot_R<K>(u, R, pi);
    if (p_u != -1)
      continue;

    auto it = lower_bound(loc_N.begin(), loc_N.end(), std::array{u.idx, 0, 0});
    for (; it != loc_N.end() && (*it)[0] == u.idx; it++) {
      int i = (*it)[1];

      auto it2 = lower_bound(N[i].begin(), N[i].end(), u, comp_by_idx);

      if (it2 == N[i].end() || it2->idx != u.idx)
        n_out += 1;
    }
  }

  // All pairs inside a group minus the pairs joined by an edge.
  for (int i = 0; i < N.size(); i++) {
    n_in += ll(N[i].size()) * (N[i].size() - 1) / 2;  // unordered?

    set<int> have;
    for (const auto& u : N[i])
      have.insert(u.idx);

    for (const auto& u : N[i]) {
      for (int v_idx : u.adj())
        if (v_idx < u.idx) {
          n_in -= have.contains(v_idx);
        }
    }
  }

  // Diagnostics only: size statistics over the non-empty groups.
  double sum_X = 0, sum_X2 = 0;
  double cnt = 0;
  for (int i = 0; i < S.size(); i++) {
    if (N[i].size() == 0)
      continue;
    double X = N[i].size();

    sum_X += X;
    sum_X2 += X * X;
    cnt += 1;
  }

  double avg = sum_X / cnt;
  double var = sqrt((sum_X2 - sum_X * sum_X / cnt) / (cnt - 1));

  cout << "sampled groups:      " << cnt << "\n";
  cout << "average group size:  " << avg << "\n";
  cout << "std. dev. of size:   " << var << "\n";
  cout.flush();

  int max_deg = 0;

  double avg_group = 0, avg_deg = 0;
  int count = 0;

  assert(S.size() == N.size());

  for (int i = 0; i < S.size(); i++) {
    max_deg = max(max_deg, (int)S[i].adj().size());

    if (N[i].empty())
      continue;

    avg_group += N[i].size();
    avg_deg += S[i].adj().size();
    count++;
  }

  avg_group /= count;
  avg_deg /= count;

  cout << "num groups remaining:  " << count << endl;
  cout << "avg group size:        " << avg_group << endl;
  // cout << "avg deg surviving: " << avg_deg << endl;
  cout << "max deg in S:          " << max_deg << endl;

  cout << "n_in:  " << n_in << endl;
  cout << "n_out: " << n_out << endl;
  cout << "nB:    " << nB << endl;
  cout << "|S|:   " << S.size() << endl;
  cout << "t:     " << t << endl;
  cout << "N_size:" << N_size << endl;

  stats.allocated_N = S_ALLOC;
  stats.nB = nB;
  stats.n_in = n_in;
  stats.n_out = n_out;
  stats.N_size = N_size;
  stats.t = S.size();
  stats.surviving_groups = count;

  // space_N = N_size;

  if (S.empty())
    return 0;

  return (n_in + n_out / 2.) * ((double)nB / S.size());
}

// Estimate the clustering cost for the rank function `pi`.
//
// Builds R, the r lowest-ranked nodes of sigma sorted by rank, fills in each
// entry's `cc` / `is_pivot` in rank order (so every lookup only meets entries
// that are already final), and then adds the two partial estimates est_EA and
// est_EB, each instantiated with SPACE_BUDGET * (1 - R_BUDGET).
//
// Template parameters:
//   K             forwarded to the pivot searches
//   SPACE_BUDGET  overall space budget as a fraction of n
// Parameters:
//   sigma  the nodes
//   pi     rank function
//   stats  output record; must be fresh (k == 0 is asserted)
//   r      size of R; the default size_t(-1) means
//          ceil(n * min(1, SPACE_BUDGET * R_BUDGET)). Always clamped to n.
// Returns {m_A + m_B, m_A, m_B}.
//
// Differences from the previous revision:
//  - stats.allocated_R records the effective fraction R_ALLOC rather than the
//    raw constant R_BUDGET, matching how est_EA / est_EB record their
//    allocations;
//  - for an explicit r the fraction is computed after clamping r to n, and is
//    0 for empty input instead of dividing by zero;
//  - the unused `loc_R` index and counter `j` were removed.
template <size_t K, double SPACE_BUDGET>
tuple<double, double, double> cc_cost(const vector<element>& sigma, RankGenerator pi, exec_stats& stats, size_t r = -1) {
  assert(stats.k == 0);

  const size_t n = sigma.size();

  cout << "entering cc cost with n = " << n << " and r = " << r << endl;

  if (r == size_t(-1)) {
    constexpr double R_ALLOC = min(1., SPACE_BUDGET * R_BUDGET);
    stats.allocated_R = R_ALLOC;
    r = ceil(n * R_ALLOC);
    r = min(r, n);
  } else {
    r = min(r, n);
    stats.allocated_R = n ? double(r) / n : 0.;
  }

  stats.r = r;

  vector<TopSample> R;
  R.reserve(r);

  // RankGenerator pi(rd());
  Rank128 r_cutoff = find_cutoff(r, n, pi);

  for (const auto& u : sigma)
    if (pi(u) <= r_cutoff)
      R.emplace_back(u, 0, 0);

  std::sort(std::execution::par_unseq, R.begin(), R.end(), pi.comp_by_rank());

  // Resolve R in increasing rank order; each lookup only reads entries of
  // lower rank than the current one.
  for (int i = 0; i < R.size(); i++) {
    const auto& u = R[i].elem;

    auto [p_u, cc] = find_pivot_R_with_cc<K>(u, R, pi);

    R[i].cc = cc;
    R[i].is_pivot = (p_u == u.idx);
  }

  std::cout << "finished find pivot on R" << std::endl;

  auto m_A = est_EA<K, SPACE_BUDGET * EST_EA_BUDGET>(sigma, pi, R, stats);
  auto m_B = est_EB<K, SPACE_BUDGET * EST_EB_BUDGET>(sigma, pi, R, stats);

  // space_R = R.size();

  stats.m_A = m_A;
  stats.m_B = m_B;
  stats.m_mis = m_A + m_B;
  stats.r = r;
  stats.k = K;
  stats.pi_seed = pi.seed;
  stats.allocated_space = SPACE_BUDGET;
  stats.name = "cc-cost";

  // return (m_A + m_B + 3.0 / 8 * EPS * pow(n, 1 - BETA)) / (1 - EPS / 8);
  return {m_A + m_B, m_A, m_B};
}

template <size_t K>
tuple<ll, ll, ll> pruned_pivot_cc_cost(vector<element> sigma, RankGenerator pi, baseline_stats& baseline, int r) {
  const size_t n = sigma.size();

  r = min(r, (int)n);

  auto orig_sigma = sigma;

  // true_pivots_R.resize(n, -1);
  // true_pivots.resize(n, -1);
  // true_cc.resize(n, -1);

  vector<int> cc(n), pivot(n, -1), inv(n);
  vector<bool> in_A(n, 1);

  sort(std::execution::par_unseq, sigma.begin(), sigma.end(), pi.comp_by_rank());

  for (int i = 0; i < n; i++)
    inv[sigma[i].idx] = i;

  vector<element> R = {sigma.begin(), sigma.begin() + r};
  // auto old_R = R;
  // sort(std::execution::par_unseq, R.begin(), R.end(), pi.comp_by_rank());

  // for (int i = 0; i < R.size(); i++)
  //   assert(R[i].idx == old_R[i].idx);

  // for (int i = 0; i < r; i++) {
  //   const auto& u = sigma[i];
  //   assert(inv[u.idx] < r);
  //   assert(pi(u) <= pi(R.back()));
  // }

  // for (int i = r; i < n; i++) {
  //   const auto& u = sigma[i];
  //   assert(inv[u.idx] >= r);
  //   assert(pi(u) > pi(R.back()));
  // }

  for (int i = 0; i < n; i++) {
    const auto& u = sigma[i];
    assert((pi(u) <= pi(R.back())) == (inv[u.idx] < r));
    assert(is_sorted(u.adj().begin(), u.adj().end()));
  }

  auto find_pivot_in_R = [&](const element& u, int r, vector<int>& cc) -> int {
    auto Q = u.adj();
    Q.push_back(u.idx);
    if constexpr (K == 0) {
      sort(Q.begin(), Q.end(), pi.comp_by_rank());
    }
    else {
      size_t num_to_sort = min(Q.size(), size_t(K + 1));
      partial_sort(Q.begin(), Q.begin() + num_to_sort, Q.end(), pi.comp_by_rank());
    }

    assert((r == n) || ((pi(u) <= pi(R.back())) == (inv[u.idx] < r)));

    cc[u.idx] = 0;

    for (int v_idx : Q) {
      assert(r == n);
      if (inv[v_idx] >= r)
        return -1;

      if (v_idx == u.idx)
        return u.idx;

      cc[u.idx] += 1 + cc[v_idx];
      cc[u.idx] = min(cc[u.idx], int(K + 1));

      if constexpr (K) {
        if (cc[u.idx] > K)
          return u.idx;
      }

      if (pivot[v_idx] == v_idx)
        return v_idx;
    }

    cout << "should never happen" << endl;
    assert(0);

    return -2;
  };

  vector<int> fake_cc(n);
  vector<int> pivots_R(n, -1);
  
  for (auto& u : sigma) {
    pivot[u.idx] = find_pivot_in_R(u, n, fake_cc);
    pivots_R[u.idx] = find_pivot_in_R(u, r, cc);

    in_A[u.idx] = (pivots_R[u.idx] != -1);
  }

  // true_pivots = pivot;
  // true_pivots_R = pivots_R;
  // true_cc = fake_cc;
  // for (auto& u : sigma) {
  //   true_pivots[u.idx] = pivot[u.idx] = find_pivot_in_R(u, n, fake_cc);
  //   true_pivots_R[u.idx] = find_pivot_in_R(u, r, cc);

  //   in_A[u.idx] = (true_pivots_R[u.idx] != -1);
  // }

  // true_cc = cc;

  vector<vector<int>> buckets(n);

  for (int i = 0; i < n; i++)
    buckets[pivot[i]].push_back(i);

  using ll = long long;
  ll E_mis = 0, E_A = 0, E_B = 0;
  ll missing_edges = 0, extra_edges = 0;

  for (int i = 0; i < n; i++) {
    E_mis += ll(buckets[i].size()) * (buckets[i].size() - 1) / 2;

    missing_edges += ll(buckets[i].size()) * (buckets[i].size() - 1) / 2;
  }

  for (auto& u : sigma)
    for (int v_idx : u.adj())
      if (u.idx < v_idx) {
        E_mis += pivot[u.idx] != pivot[v_idx] ? 1 : -1;

        if (pivot[u.idx] != pivot[v_idx])
          extra_edges++;
        else
          missing_edges--;
      }

  for (auto& u : sigma)
    for (int v_idx : u.adj())
      if (u.idx < v_idx && pivot[u.idx] != pivot[v_idx]) {
        if (in_A[u.idx] || in_A[v_idx])
          E_A++;
        else
          E_B++;
      }

  // set<pair<int, int>> edgeset;
  // for (auto& u : sigma)
  //   for (int v_idx : u.adj())
  //     if (u.idx < v_idx)
  //       edgeset.insert(pair{u.idx, v_idx});

  for (auto& Ni : buckets) {
    for (auto u : Ni)
      for (auto v : Ni)
        if (u < v) {
          auto it = lower_bound(orig_sigma[u].adj().begin(),
                                orig_sigma[u].adj().end(), v);
          if (it == orig_sigma[u].adj().end() || *it != v) {
            if (in_A[u] || in_A[v])
              E_A++;
            else
              E_B++;
          }
        }
  }

  double sum_X = 0, sum_X2 = 0;
  double sum_X_A = 0, sum_X2_A = 0;
  double sum_X_B = 0, sum_X2_B = 0;
  double sum_Y = 0, sum_Y2 = 0;
  double sum_Y_A = 0, sum_Y2_A = 0;
  double sum_Y_B = 0, sum_Y2_B = 0;
  double cnt = 0;
  double cnt_A = 0;
  double cnt_B = 0;

  for (int i = 0; i < n; i++) {
    if (buckets[i].empty())
      continue;

    double X = buckets[i].size();
    sum_X += X;
    sum_X2 += X * X;
    cnt += 1;

    ll internal = 0;
    ll external = 0;

    for (auto& u : buckets[i])
      for (auto& v : buckets[i])
        if (u < v) {
          auto it = lower_bound(orig_sigma[u].adj().begin(),
                                orig_sigma[u].adj().end(), v);
          if (it == orig_sigma[u].adj().end() || *it != v)
            internal += 2;
        }

    for (int u_idx : buckets[i])
      for (int v_idx : orig_sigma[u_idx].adj())
        if (pivot[v_idx] != pivot[u_idx] && !in_A[v_idx])
          external += 1;

    double Y = (internal + external) / 2.0;
    sum_Y += Y;
    sum_Y2 += Y * Y;

    // cerr << X << "," << Y << " ";

    if (in_A[i]) {
      sum_X_A += X;
      sum_X2_A += X * X;

      sum_Y_A += Y;
      sum_Y2_A += Y * Y;

      cnt_A += 1;
    } else {
      sum_X_B += X;
      sum_X2_B += X * X;

      sum_Y_B += Y;
      sum_Y2_B += Y * Y;

      cnt_B += 1;
    }
  }

  // cerr << endl;

  double avg_X = sum_X / cnt;
  double var_X = sqrt((sum_X2 - sum_X * sum_X / cnt) / cnt);
  double avg_Y = sum_Y / cnt;
  double var_Y = sqrt((sum_Y2 - sum_Y * sum_Y / cnt) / cnt);

  cout << "\n";
  cout << "all clusters" << "\n";
  cout << "# of groups: " << cnt << "\n";
  cout << "avg. size:   " << avg_X << "\n";
  cout << "std. dev. :  " << var_X << "\n";
  // cout << "avg. edges:        " << avg_Y << "\n";
  // cout << "std. dev. edges:   " << var_Y << "\n";
  cout << "\n";

  double avg_X_A = sum_X_A / cnt_A;
  double var_X_A = sqrt((sum_X2_A - sum_X_A * sum_X_A / cnt_A) / cnt_A);
  double avg_Y_A = sum_Y_A / cnt_A;
  double var_Y_A = sqrt((sum_Y2_A - sum_Y_A * sum_Y_A / cnt_A) / cnt_A);

  cout << "clusters in A" << "\n";
  cout << "# of groups: " << cnt_A << "\n";
  cout << "avg. size:   " << avg_X_A << "\n";
  cout << "std. dev. :  " << var_X_A << "\n";
  // cout << "avg. edges:        " << avg_Y_A << "\n";
  // cout << "std. dev. edges:   " << var_Y_A << "\n";
  cout << "\n";

  double avg_X_B = sum_X_B / cnt_B;
  double var_X_B = sqrt((sum_X2_B - sum_X_B * sum_X_B / cnt_B) / cnt_B);
  double avg_Y_B = sum_Y_B / cnt_B;
  double var_Y_B = sqrt((sum_Y2_B - sum_Y_B * sum_Y_B / cnt_B) / cnt_B);

  cout << "clusters in B" << "\n";
  cout << "# of groups: " << cnt_B << "\n";
  cout << "avg. size:   " << avg_X_B << "\n";
  cout << "std. dev. :  " << var_X_B << "\n";

  cout << "avg. edges:        " << avg_Y_B << "\n";
  cout << "std. dev. edges:   " << var_Y_B << "\n";
  cout << "\n";

  cout.flush();

  // cout << "% E_mis check " << (sum_Y/2 - E_mis) / E_mis << "\n";
  // cout << "% E_A check " << (sum_Y_A - E_A) / E_A << "\n";
  // cout << "% E_B check " << (sum_Y_B - E_B) / E_B << "\n";
  assert(E_A + E_B == E_mis);

  cout << "extra edges:   " << extra_edges << endl;
  cout << "missing edges: " << missing_edges << endl;

  ll n_A = 0;
  ll edges = 0;
  ll unlucky = 0;
  for (int i = 0; i < n; i++) {
    n_A += in_A[i];
    // true_count_A += in_A[i];
    edges += sigma[i].adj().size();

    unlucky += K && fake_cc[i] > K;
  }
  ll n_B = n - n_A;

  baseline.k = K;
  baseline.r = r;
  baseline.pi_seed = pi.seed;
  baseline.E_A = E_A;
  baseline.E_B = E_B;
  baseline.E_mis = E_mis;
  baseline.n = n;
  baseline.false_positives = extra_edges;
  baseline.false_negatives = missing_edges;
  baseline.n_A = n_A;
  baseline.n_B = n_B;
  baseline.edges = edges;
  
  return {E_A + E_B, E_A, E_B};
}

template <size_t ITERS, size_t K, double SPACE_BUDGET>
std::array<exec_stats, ITERS> run_cc_cost_exp(const std::vector<element> &sigma, ranks::RankGenerator pi, size_t r) {
  std::array<exec_stats, ITERS> stats{};
  for (int i = 0; i < ITERS; i++) {
    stats[i] = exec_stats();
    stats[i].name = "cc_cost_" + to_string(K);
    cc_cost<K, SPACE_BUDGET>(sigma, pi, stats[i], r);
  }
  return stats;
}

/**
 * Convenience wrapper around `pruned_pivot_cc_cost<K>`: runs it and returns
 * the filled-in statistics record, named "pruned_pivot".
 *
 * @param sigma  input vertices.
 * @param pi     rank generator.
 * @param r      prefix size forwarded to pruned_pivot_cc_cost.
 */
template<size_t K>
baseline_stats pruned_pivot_stats(const std::vector<element> &sigma, ranks::RankGenerator pi, size_t r) {
  baseline_stats record{};
  record.name = "pruned_pivot";
  // the cost tuple returned by the call is not needed; `record` holds the results
  pruned_pivot_cc_cost<K>(sigma, pi, record, r);
  return record;
}

/**
 * Statistics for the unpruned pivot baseline: the K = 0 instantiation of
 * `pruned_pivot_stats`, relabelled "pivot".
 */
baseline_stats pivot_stats(const std::vector<element> &sigma, ranks::RankGenerator pi, size_t r) {
  baseline_stats unpruned = pruned_pivot_stats<0>(sigma, pi, r);
  unpruned.name = "pivot";
  return unpruned;
}

template <size_t ITERS, size_t K, double SPACE_BUDGET>
auto run_experiment_fixed_space(const vector<element>& sigma, ranks::RankGenerator pi) {
  cout << "Running experiment with K=" << K << ", SPACE = " << SPACE_BUDGET << ", ITERS = " << ITERS << endl; 
  const int n = sigma.size();

  constexpr double R_ALLOC = min(1., SPACE_BUDGET * R_BUDGET);
  int r = ceil(n * R_ALLOC);
  r = min(r, n);

  auto pivot = pivot_stats(sigma, pi, r);
  auto pruned_pivot = pruned_pivot_stats<K>(sigma, pi, r);
  auto cc_cost_res = run_cc_cost_exp<ITERS, K, SPACE_BUDGET>(sigma, pi, r);
  return tuple{pivot, pruned_pivot, cc_cost_res};
}

// Runs the full experiment for one pruning threshold K over several space
// budgets, using OpenMP tasks.
//
//   ITERS          number of independent rankings (one RankGenerator each,
//                  seeded from the global `rd`)
//   K              pruning threshold passed to pruned_pivot_stats / cc_cost
//   SPACE_BUDGETs  space budgets; each gets its own column in the result
//
// Phase 1 computes, per ranking, the pivot and pruned-pivot baselines with
// r = n. Phase 2 runs cc_cost for every (budget, ranking) pair.
//
// Returns tuple{res_pivot, res_pruned_pivot, res} where res[j][I] holds the
// cc_cost stats of ranking j under the I-th budget.
template <size_t ITERS, size_t K, double... SPACE_BUDGETs>
auto run_experiment(const std::vector<element> &sigma) {
  // RankGenerator pi(rd());
  // std::array<tuple<baseline_stats, baseline_stats, std::array<exec_stats, ITERS>>, sizeof...(SPACE_BUDGETs)> res{};
  
  const int n = sigma.size();
  // NOTE(review): `budgets` is not referenced anywhere below.
  constexpr std::array<double, sizeof...(SPACE_BUDGETs)> budgets = { SPACE_BUDGETs... };
  // std::array<RankGenerator, ITERS> pi = views::iota(0, ITERS) |
  //                                       views::transform([](int _) { return RankGenerator(rd()); }) |
  //                                       ranges::to<std::array<RankGenerator, ITERS>>();
  // one independently seeded ranking per iteration
  vector<RankGenerator> pi;
  for (int i = 0; i < ITERS; i++) 
    pi.emplace_back(rd());

  std::array<baseline_stats, ITERS> res_pivot, res_pruned_pivot;
  std::array<std::array<exec_stats, sizeof...(SPACE_BUDGETs)>, ITERS> res{};
  
  // Phase 1: a single thread creates two tasks per ranking (pivot and pruned
  // pivot); each task writes only its own slot i of the result arrays.
  #pragma omp parallel
  {
    #pragma omp single nowait
    {
      for (int i = 0; i < ITERS; i++) {
        #pragma omp task shared(res_pivot, res_pruned_pivot, sigma, pi)
        {
          res_pivot[i] = pivot_stats(sigma, pi[i], n);
        }
        #pragma omp task shared(res_pivot, res_pruned_pivot, sigma, pi)
        {
          res_pruned_pivot[i] = pruned_pivot_stats<K>(sigma, pi[i], n);
        }
      }
    }
  }

  cout << "finished precomp!" << endl;

  // For the budget SPACE at column I: creates one task that derives r from the
  // budget and then fans out one cc_cost call per ranking via a taskloop.
  auto launch = [&]<double SPACE, size_t I>() {
    #pragma omp task shared(res, sigma, pi)
    {
      constexpr double R_ALLOC = min(1., SPACE * R_BUDGET);
      const int r = min((int)ceil(n * R_ALLOC), n);

      #pragma omp taskloop grainsize(1) shared(res, sigma, pi)
      for (int j = 0; j < ITERS; ++j) {
        cc_cost<K, SPACE>(sigma, pi[j], res[j][I], r);            
      }
    }
  };

  // Phase 2: launch the per-budget tasks from a single thread.
  #pragma omp parallel
  {
    // NOTE(review): this call is made from inside the parallel region by every
    // thread of the team; it presumably does not change the size of the
    // enclosing team — confirm whether it was meant to precede the region.
    if (sigma.size() > 1'000'000) {
      omp_set_num_threads(48);
    }

    #pragma omp single nowait
    {
      // expand the budget pack together with its column indices 0..N-1
      [&]<size_t... I> (index_sequence<I...>) {
        (launch.template operator()<SPACE_BUDGETs, I>(), ...);
      }(make_index_sequence<sizeof...(SPACE_BUDGETs)>());
    }
  }

  cout << "completed!" << endl;
  // res = std::array{run_experiment_fixed_space<ITERS, K, SPACE_BUDGETs>(sigma, pi)...};
  return tuple{res_pivot, res_pruned_pivot, res};
}

/**
 * Loads a graph, runs the experiment and writes all statistics as one JSON
 * array to a timestamped file under ./results/.
 *
 * @param input_file  path of the serialized graph; its stem becomes part of
 *                    the output file name.
 * @param threshold   forwarded to deserialize().
 *
 * The output folder is checked before the (long-running) experiment starts,
 * and a failure to open the output file is reported on stderr instead of
 * being silently ignored.
 */
void run(string input_file, float threshold) {
  cout << "reading in graph from " << input_file << "..." << endl;

  // fail early: no point in running the experiment if results cannot be stored
  const string out_folder = "./results/";
  assert(folder_exists(out_folder));

  auto in = deserialize(input_file, threshold);
  vector<element> G;
  G.reserve(in.size());
  // `in` is not used afterwards, so hand each adjacency list over instead of copying it
  for (int i = 0; i < static_cast<int>(in.size()); i++)
    G.push_back({i, std::move(in[i])});

  cout << "finished reading graph!" << endl;

  // auto res = run_experiment<10, 15, .01, .02, .04, .08, .16>(G);
  auto [res_pivot, res_pruned_pivot, res_cc_cost] = run_experiment<30, 15, .01, .02, .04, .08, .16>(G);
  // auto [res_pivot, res_pruned_pivot, res_cc_cost] = run_experiment<3, 1, .01>(G);

  std::filesystem::path input_path(input_file);
  const string basename = input_path.stem().string();
  const auto output_file = out_folder + timestamp_filename(basename + "_results_", ".txt");

  std::ofstream of(output_file);
  if (!of) {
    cerr << "could not open " << output_file << " for writing; results were not saved" << endl;
    return;
  }

  of << "[\n";

  // baseline records are always followed by a comma: cc_cost records come after them
  for (size_t i = 0; i < res_pivot.size(); i++) {
    write_baseline_stats_json(of, res_pivot[i], true);
    write_baseline_stats_json(of, res_pruned_pivot[i], true);
  }

  for (size_t i = 0; i < res_cc_cost.size(); i++) {
    for (size_t j = 0; j < res_cc_cost[i].size(); j++) {
      // every record except the very last one is followed by a comma
      const bool more_follow = i + 1 < res_cc_cost.size() || j + 1 < res_cc_cost[i].size();
      write_exec_stats_json(of, res_cc_cost[i][j], more_follow);
    }
  }

  of << "]" << endl;
}

// Compile-time list of pruning thresholds: an integer_sequence over size_t.
template<std::size_t... Ks>
using k_seq = std::integer_sequence<std::size_t, Ks...>;

// Compile-time list of space budgets; an empty tag type that only carries the
// pack `Bs` (floating-point template parameters require C++20).
template<double... Bs>
struct space_seq {};

template<size_t L, size_t R, size_t... Is>
constexpr auto make_range_helper(index_sequence<Is...>) { return k_seq<(L + Is)...>{}; }

// k_range<L, R> is k_seq<L, L+1, ..., R> (both ends inclusive).
template<size_t L, size_t R>
using k_range = decltype(
    make_range_helper<L, R>(std::make_index_sequence<R - L + 1>{}));

// Prints the contents of a k_seq and a space_seq to stdout, each as
// "<name>(<count>): v1 v2 ...".
//
// The Ks pack is declared as size_t so that it can be deduced from a k_seq
// argument (k_seq is an integer_sequence over size_t; a pack of a different
// type does not deduce against it), matching run_experiment's `size_t... Ks`.
template<size_t... Ks, double... Bs>
void output_packs(k_seq<Ks...>, space_seq<Bs...>) {
  cout << "Ks(" << sizeof...(Ks) << "):";
  ((cout << " " << Ks), ...);
  cout << "\n";
  cout << "Bs(" << sizeof...(Bs) << "):";
  ((cout << " " << Bs), ...);
  cout << "\n";
}

/**
 * Resident memory of this process as a percentage of total system memory.
 *
 * Reads the "VmRSS:" field of /proc/self/status and the "MemTotal:" field of
 * /proc/meminfo and returns VmRSS / MemTotal * 100.
 *
 * @return the percentage, or 0.0 when the total cannot be determined (file
 *         missing or unparsable). Returning 0.0 instead of NaN/inf matters
 *         for callers that wait until the value drops below a threshold: a
 *         NaN would never compare below it.
 */
double get_memory_usage_percent() {
    // Value following `key` on the first line of `path` that starts with
    // `key`; 0 if the file cannot be read, the key is absent, or no number
    // follows it.
    const auto read_field = [](const char* path, std::string_view key) -> long {
        std::ifstream file(path);
        std::string line;
        while (std::getline(file, line)) {
            if (line.compare(0, key.size(), key) != 0)
                continue;

            long value = 0;
            // %ld skips the whitespace between the key and the number
            if (std::sscanf(line.c_str() + key.size(), "%ld", &value) != 1)
                return 0;
            return value;
        }
        return 0;
    };

    const long vm_rss = read_field("/proc/self/status", "VmRSS:");
    const long total_mem = read_field("/proc/meminfo", "MemTotal:");

    // avoid dividing by zero when the total is unavailable
    if (total_mem <= 0)
        return 0.0;

    return static_cast<double>(vm_rss) / total_mem * 100.0;
}

// Runs the full experiment over a list of pruning thresholds (Ks) and a list
// of space budgets (Bs), using OpenMP tasks.
//
//   ITERS         number of independent rankings (one RankGenerator each,
//                 seeded from the global `rd`)
//   SIMPLE_NODES  additionally run simple_sampler_node for every combination
//   SIMPLE_EDGES  additionally run simple_sampler_edge for every combination
//
// Phase 1 computes the pivot baseline per ranking and the pruned-pivot
// baseline per (K, ranking), all with r = n. Phase 2 runs cc_cost (and the
// optional samplers) for every (K, budget, ranking) combination, delaying a
// task while reported memory usage is high.
//
// Result layout:
//   res_pivot[i]                              ranking i
//   res_pruned_pivot[I*ITERS + j]             I-th K, ranking j
//   res_cc / res_node / res_edge
//     [I*ITERS*N_B + J*ITERS + i]             I-th K, J-th budget, ranking i
//
// Returns tuple{res_pivot, res_pruned_pivot, res_cc} followed by res_node
// and/or res_edge when the corresponding flag is set.
template <size_t ITERS, bool SIMPLE_NODES = 0, bool SIMPLE_EDGES = 0, size_t... Ks, double... Bs>
auto run_experiment(const vector<element>& sigma, k_seq<Ks...>, space_seq<Bs...>) {
  constexpr size_t N_K = sizeof...(Ks);
  constexpr size_t N_B = sizeof...(Bs);

  // echo the parameter grid
  cout << "Ks(" << (sizeof...(Ks)) << "):";
  ((std::cout << " " << Ks), ...);
  cout << "\n";
  cout << "Bs(" << sizeof...(Bs) << "):";
  ((cout << " " << Bs), ...);
  cout << endl;

  const int n = sigma.size();

  // one independently seeded ranking per iteration
  vector<RankGenerator> pi;
  for (int i = 0; i < ITERS; i++) 
    pi.emplace_back(rd());

  std::array<baseline_stats, ITERS> res_pivot{};
  std::array<baseline_stats, N_K * ITERS> res_pruned_pivot{};
  std::array<exec_stats, N_K * N_B * ITERS> res_cc{}, res_node{}, res_edge{};
  
  // The tasks below access inputs and result arrays through these pointers.
  auto sigma_ptr = &sigma;
  auto pi_ptr = &pi;
  auto res_node_ptr = &res_node;
  auto res_edge_ptr = &res_edge;
  auto res_pruned_pivot_ptr = &res_pruned_pivot;
  auto res_pivot_ptr = &res_pivot;
  auto res_cc_ptr = &res_cc;

  omp_set_num_threads(48);
  
  // Phase 1: baselines. A single thread creates the tasks.
  #pragma omp parallel
  #pragma omp single nowait
  {
    // Pivot iterations
    #pragma omp task priority(1)
    {
      #pragma omp taskloop grainsize(1)
      for (int i = 0; i < ITERS; i++) {
        (*res_pivot_ptr)[i] = pivot_stats((*sigma_ptr), (*pi_ptr)[i], n);
      }
    }

    // printed once the pivot task has been created, not once it has finished
    cout << "exited pivots" << endl;

    // Pruned Pivot iterations
    // The fold expression expands Ks together with its index I: one task per
    // K, each fanning out over the rankings.
    [&]<size_t... I> (index_sequence<I...>) {
      cout << "entered pruned pivot builder" << endl;

      ([&]{
        cout << "entered " << I << " for pruned pivot" << endl;
        #pragma omp task priority(1)
        {
          #pragma omp taskloop grainsize(1) 
          for (int j = 0; j < ITERS; ++j) {
            (*res_pruned_pivot_ptr)[I*ITERS + j] = pruned_pivot_stats<Ks>((*sigma_ptr), (*pi_ptr)[j], n);
          }
        }
      }(), ...);
    }(make_index_sequence<sizeof...(Ks)>());
  }

  cout << "finished pivot and pruned pivot!" << endl;

  // NOTE(review): the thread count was already set to 48 unconditionally
  // above, so this conditional call looks redundant — confirm intent.
  if (sigma.size() > 1'000'000) {
    omp_set_num_threads(48);
  }

  // Memory throttle state: usage before any worker started, and the sum of
  // K * B over the workers currently running (used as a relative size proxy).
  const double baseline_mem = get_memory_usage_percent();
  auto total_kb_running = std::make_shared<std::atomic<double>>(0.0);

  // Phase 2: one task per (K, budget, ranking).
  #pragma omp parallel
  #pragma omp single nowait
  {
    // Use OpenMP depend token for robust release: create a single release task
    // with depend(out: token), and all worker tasks depend(in: token).
    // NOTE(review): the worker tasks are generated BEFORE the depend(out)
    // task. Task dependences are presumably resolved only against previously
    // generated sibling tasks, in which case the workers do not wait for the
    // release task (rather, it waits for them) — confirm against the OpenMP
    // specification.
    char token = 0;

    // Create worker tasks (each depends on token)
    // Outer fold: (Is, Ks) pairs; inner fold: (Js, Bs) pairs.
    [&]<size_t... Is, size_t... Js> (index_sequence<Is...>, index_sequence<Js...>) {
      ([&]{
        constexpr auto I = Is;
        constexpr auto K = Ks;
        ([&]{
          constexpr auto J = Js;
          constexpr auto B = Bs;
          // prefix size for this budget: ceil(n * min(1, B * R_BUDGET)), at most n
          constexpr double R_ALLOC = min(1., B * R_BUDGET);
          const int r = min((int)ceil(n * R_ALLOC), n);

          for (int i = 0; i < ITERS; ++i) {
            #pragma omp task firstprivate(i, sigma_ptr, pi_ptr, res_node_ptr, res_edge_ptr, res_cc_ptr, total_kb_running, baseline_mem) depend(in: token)
            {
              // Wait if memory usage is too high
              // `unit` estimates the memory growth per unit of K * B among the
              // running workers; the task proceeds once current usage is below
              // 70% and the projection including itself is below 90%.
              // The check and the fetch_add below are separate steps, so
              // several tasks may pass the check at the same time.
              while (true) {
                double curr_mem = get_memory_usage_percent();
                double curr_kb = *total_kb_running;
                double unit = curr_kb > 1 ? (curr_mem - baseline_mem) / curr_kb : 0.;
                double projected_mem = baseline_mem + unit * (curr_kb + K * B);
                if (curr_mem < 70.0 && projected_mem < 90.0)
                  break;                
                
                cout << "waiting for more memory! currently using " << curr_mem << " and have unit=" << unit << ", curr kb=" << curr_kb << ", my kb=" << K*B << ", projected usage=" << projected_mem << endl;

                #pragma omp taskyield
                std::this_thread::sleep_for(std::chrono::milliseconds(1000));  // Sleep 1s to avoid busy waiting
              }
              
              // register this worker's weight while it runs
              total_kb_running->fetch_add(K * B);

              if constexpr (SIMPLE_NODES) {
                auto& stat_node = (*res_node_ptr)[I*ITERS*N_B + J*ITERS + i];
                simple_sampler_node<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_node);
              }
              if constexpr (SIMPLE_EDGES) {
                auto& stat_edge = (*res_edge_ptr)[I*ITERS*N_B + J*ITERS + i];
                simple_sampler_edge<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_edge);
              }
              auto& stat_cc = (*res_cc_ptr)[I*ITERS*N_B + J*ITERS + i];
              cc_cost<K, B>(*sigma_ptr, (*pi_ptr)[i], stat_cc, r);

              total_kb_running->fetch_sub(K * B);
            }
          }

        }(), ...);
      }(), ...);
    }(make_index_sequence<N_K>(), make_index_sequence<N_B>());

    // Release task: after all worker tasks are created, this completes and releases them
    // (see the NOTE(review) at `token` above regarding the actual ordering)
    #pragma omp task depend(out: token)
    {
      // No-op: just releases the token
    }
  }

  cout << "completed!" << endl;
  // res = std::array{run_experiment_fixed_space<ITERS, K, SPACE_BUDGETs>(sigma, pi)...};
  // the shape of the returned tuple depends on which optional samplers ran
  if constexpr (SIMPLE_NODES && SIMPLE_EDGES) {
    return tuple{res_pivot, res_pruned_pivot, res_cc, res_node, res_edge};
  }
  else if constexpr (SIMPLE_NODES) {
    return tuple{res_pivot, res_pruned_pivot, res_cc, res_node};
  }
  else if constexpr (SIMPLE_EDGES) {
    return tuple{res_pivot, res_pruned_pivot, res_cc, res_edge};
  }
  else {
    return tuple{res_pivot, res_pruned_pivot, res_cc};
  }
}