#ifndef RADIXSORT_H
#define RADIXSORT_H

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <vector>
#include <numeric>

namespace radix_sort {

// Returns the 8-bit digit of `u` at position `byte`, where position 0 is the
// least-significant byte. The result is always in the range [0, 255].
inline uint64_t extract_byte(uint64_t u, int byte) {
  const int shift = byte * 8;          // bit offset of the requested byte
  const uint64_t shifted = u >> shift; // move that byte into the low 8 bits
  return shifted & 0xFF;               // discard everything above it
}

// Converts the raw bit pattern of a double into a sort key.
//
// If the sign bit (bit 63) is set, every bit is inverted; otherwise only the
// sign bit is toggled. radix_sort applies this to the bits of each double so
// that comparing the resulting keys as unsigned integers orders negative
// values before non-negative ones.
static inline uint64_t floatFlip(uint64_t u) {
  const uint64_t sign_bit = 1ULL << 63;
  const bool negative = (u & sign_bit) != 0;
  return negative ? ~u : (u ^ sign_bit);
}

// Undoes floatFlip: turns a sort key back into the original bit pattern.
//
// A key with bit 63 set came from a value whose sign bit was clear, so only
// that bit is toggled back. A key with bit 63 clear came from a value whose
// sign bit was set, so every bit is inverted again.
static inline uint64_t inv_floatFlip(uint64_t u) {
  const uint64_t sign_bit = 1ULL << 63;
  const bool was_non_negative = (u & sign_bit) != 0;
  return was_non_negative ? (u ^ sign_bit) : ~u;
}

// Stable counting sort of arr[0..n) keyed on one byte of each element.
//
// Parameters:
//   arr  - array of 64-bit keys; reordered in place.
//   n    - number of elements in arr; values <= 0 leave arr untouched.
//   byte - which byte of each key to sort on (0 = least significant).
//
// Elements that share the same value in the selected byte keep their relative
// order, which is what lets radix_sort chain one pass per byte.
//
// Declared inline because this definition lives in a header: without it,
// including the header from more than one translation unit produces
// multiple-definition link errors.
inline void counting_sort_by_byte(uint64_t *arr, long n, int byte) {
  if (n <= 0) {
    // Nothing to sort. This also keeps a negative n from being converted to a
    // huge unsigned size when the scratch buffer is allocated below.
    return;
  }

  // Histogram of the 256 possible byte values. The counters are long, like n,
  // so they cannot overflow when many elements share a byte value.
  std::vector<long> count(256, 0);
  for (long i = 0; i < n; ++i) {
    ++count[extract_byte(arr[i], byte)];
  }

  // Inclusive prefix sum: count[b] becomes one past the last output slot
  // reserved for byte value b.
  std::partial_sum(count.begin(), count.end(), count.begin());

  // Fill the output from the back of the input so that equal byte values keep
  // their original relative order.
  std::vector<uint64_t> output(static_cast<std::size_t>(n));
  for (long i = n - 1; i >= 0; --i) {
    output[--count[extract_byte(arr[i], byte)]] = arr[i];
  }

  // Copy the sorted elements back into the caller's array.
  std::copy(output.begin(), output.end(), arr);
}

// Sorts the doubles in [begin, end) into ascending order with an LSD radix sort.
//
// The 64-bit pattern of every value is copied into a scratch buffer and mapped
// through floatFlip(). The keys are then sorted with eight counting-sort
// passes, least-significant byte first, and finally mapped back through
// inv_floatFlip() and written over the input range.
//
// Parameters:
//   begin, end - iterators delimiting a range inside a std::vector<double>.
//
// Notes:
//   - Allocates a temporary buffer of one 64-bit key per element.
//   - Values are ordered by their transformed bit patterns, so -0.0 is placed
//     before +0.0 and NaNs are positioned according to their bits rather than
//     being rejected.
//
// Declared inline because this definition lives in a header: without it,
// including the header from more than one translation unit produces
// multiple-definition link errors.
inline void radix_sort(std::vector<double>::iterator begin, std::vector<double>::iterator end) {
  static_assert(sizeof(double) == sizeof(uint64_t),
                "radix_sort requires double to be exactly 64 bits wide");

  const long n = static_cast<long>(std::distance(begin, end));
  if (n < 2) {
    return;  // Empty and single-element ranges are already sorted.
  }

  // The scratch keys live in a std::vector so they are released on every exit
  // path, including when a counting pass throws.
  std::vector<uint64_t> keys(static_cast<std::size_t>(n));
  for (long i = 0; i < n; ++i) {
    // memcpy reads the bit pattern of the double without violating the
    // strict-aliasing rules, unlike a pointer cast.
    uint64_t bits;
    std::memcpy(&bits, &begin[i], sizeof bits);
    keys[i] = floatFlip(bits);
  }

  // Sorting by each byte. A double has 8 bytes.
  for (int byte = 0; byte < 8; ++byte) {
    counting_sort_by_byte(keys.data(), n, byte);
  }

  // Undo the key transformation and write the sorted values back.
  for (long i = 0; i < n; ++i) {
    const uint64_t bits = inv_floatFlip(keys[i]);
    double value;
    std::memcpy(&value, &bits, sizeof value);
    begin[i] = value;
  }
}

}  // namespace radix_sort

#endif  // RADIXSORT_H
