#include "candidates.h"

#include <algorithm>
#include <atomic>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

// Builds a k-nearest-neighbour candidate list for every node of `inst`.
//
// For each node i, every other node j is ranked by dist.distance_sq(i, j); the
// k smallest are written to row i of the result in ascending distance order,
// together with dist.inv_dist_beta(i, j, beta) narrowed to float.
//
// Rows are handed out to worker threads through a shared atomic counter, so
// each row is computed by exactly one worker and workers never write to the
// same slot.
//
// Parameters:
//   inst    - problem instance; only inst.n is read here.
//   dist    - distance provider; its const methods are invoked from several
//             threads at once (NOTE(review): confirm it is safe for that).
//   k       - number of neighbours requested per node.
//   beta    - forwarded unchanged to dist.inv_dist_beta.
//   threads - requested worker count, clamped to the range [1, inst.n].
//
// Returns a list with inst.n * k slots. When a node has fewer than k other
// nodes available (k > inst.n - 1), the trailing slots of its row keep their
// initial values (-1 for the node, 0.0f for the weight).
CandidateList build_nn_candidates(const Instance &inst,
                                  const DistanceProvider &dist,
                                  int k,
                                  double beta,
                                  int threads) {
    CandidateList cand;
    cand.n = inst.n;
    cand.k = k;
    cand.nodes.resize(static_cast<std::size_t>(inst.n) * k, -1);
    cand.inv_dist_beta.resize(static_cast<std::size_t>(inst.n) * k, 0.0f);

    threads = std::max(1, threads);
    threads = std::min(threads, inst.n);
    std::atomic<int> next{0};
    std::vector<std::thread> pool;
    pool.reserve(threads);

    // Orders (squared distance, node) pairs by distance only.
    const auto by_distance = [](const std::pair<double, int> &a,
                                const std::pair<double, int> &b) {
        return a.first < b.first;
    };

    for (int t = 0; t < threads; ++t) {
        pool.emplace_back([&]() {
            // Per-thread scratch buffer, reused for every row this worker claims.
            std::vector<std::pair<double, int>> buf;
            buf.reserve(inst.n - 1);
            while (true) {
                const int i = next.fetch_add(1);
                if (i >= inst.n) {
                    break;
                }
                buf.clear();
                for (int j = 0; j < inst.n; ++j) {
                    if (i == j) {
                        continue;
                    }
                    buf.emplace_back(dist.distance_sq(i, j), j);
                }
                // Partition so the k closest come first, drop the rest, then
                // order only the survivors.
                if (k < static_cast<int>(buf.size())) {
                    std::nth_element(buf.begin(), buf.begin() + k, buf.end(), by_distance);
                    buf.resize(k);
                }
                std::sort(buf.begin(), buf.end(), by_distance);

                // Never read past the neighbours that actually exist: with
                // k > n - 1 the buffer holds fewer than k entries.
                const int filled = std::min(k, static_cast<int>(buf.size()));
                // Index in size_t so that i * k cannot overflow int.
                const std::size_t row = static_cast<std::size_t>(i) * static_cast<std::size_t>(k);
                for (int c = 0; c < filled; ++c) {
                    const int node = buf[c].second;
                    const std::size_t slot = row + static_cast<std::size_t>(c);
                    cand.nodes[slot] = node;
                    cand.inv_dist_beta[slot] = static_cast<float>(dist.inv_dist_beta(i, node, beta));
                }
            }
        });
    }
    for (auto &th : pool) {
        th.join();
    }

    return cand;
}

// Minimal whitespace-delimited token reader used for heatmap loading.
//
// Input is pulled from a C FILE stream in kBufferSize-byte chunks via
// std::fread and split into tokens on std::isspace boundaries.
//
// The read buffer is heap-allocated so that declaring a FastScanner as a local
// variable does not put the whole 1 MiB buffer on the caller's stack. The
// scanner owns its FILE handle, so copying is disabled to rule out a double
// std::fclose.
class FastScanner {
public:
    // Opens `path` for binary reading.
    // Throws std::runtime_error when the file cannot be opened.
    explicit FastScanner(const std::string &path) {
        file_ = std::fopen(path.c_str(), "rb");
        if (!file_) {
            throw std::runtime_error("Failed to open file: " + path);
        }
    }

    FastScanner(const FastScanner &) = delete;
    FastScanner &operator=(const FastScanner &) = delete;

    ~FastScanner() {
        if (file_) {
            std::fclose(file_);
        }
    }

    // Reads the next token and converts it with std::strtod.
    // Returns false only when no further token exists. Text that strtod cannot
    // parse is not reported as an error; `val` then receives strtod's result
    // for that text.
    bool next_double(double &val) {
        char tok[kTokenMax];
        if (!next_token(tok, sizeof(tok))) {
            return false;
        }
        val = std::strtod(tok, nullptr);
        return true;
    }

    // Reads the next token and converts it with std::strtol (base 10),
    // narrowing the result to int.
    // Returns false only when no further token exists; malformed text is not
    // reported as an error.
    bool next_int(int &val) {
        char tok[kTokenMax];
        if (!next_token(tok, sizeof(tok))) {
            return false;
        }
        val = static_cast<int>(std::strtol(tok, nullptr, 10));
        return true;
    }

private:
    static constexpr std::size_t kBufferSize = 1 << 20;  // bytes per fread chunk
    static constexpr std::size_t kTokenMax = 256;        // token storage incl. '\0'

    std::FILE *file_ = nullptr;
    // Parentheses on purpose: braces would select the initializer_list
    // constructor and create a one-element vector.
    std::vector<char> buffer_ = std::vector<char>(kBufferSize);
    std::size_t len_ = 0;  // number of valid bytes currently in buffer_
    std::size_t pos_ = 0;  // index of the next unread byte in buffer_

    // Loads the next chunk from the file. Returns false when nothing was read.
    bool refill() {
        len_ = std::fread(buffer_.data(), 1, buffer_.size(), file_);
        pos_ = 0;
        return len_ > 0;
    }

    // Advances to the next non-whitespace byte, refilling as needed.
    // Returns false when the input ends first.
    bool skip_space() {
        while (true) {
            if (pos_ >= len_ && !refill()) {
                return false;
            }
            const unsigned char c = static_cast<unsigned char>(buffer_[pos_]);
            if (!std::isspace(c)) {
                return true;
            }
            ++pos_;
        }
    }

    // Copies the next token into `out` (capacity `max`, always NUL-terminated).
    // Characters beyond max - 1 are consumed but discarded. A token may span a
    // buffer refill. Returns false when no token is left.
    bool next_token(char *out, std::size_t max) {
        if (!skip_space()) {
            return false;
        }
        std::size_t idx = 0;
        while (true) {
            if (pos_ >= len_ && !refill()) {
                break;  // end of input terminates the token
            }
            const unsigned char c = static_cast<unsigned char>(buffer_[pos_]);
            if (std::isspace(c)) {
                break;
            }
            if (idx + 1 < max) {
                out[idx++] = static_cast<char>(c);
            }
            ++pos_;
        }
        out[idx] = '\0';
        return true;
    }
};

// Builds a per-node candidate list from a heatmap file, topped up with
// nearest neighbours where the heatmap offers too few edges.
//
// File layout as consumed here: one integer N (must equal inst.n) followed by
// N * N whitespace-separated numbers, read as row i = 0..N-1, column
// j = 0..N-1. The diagonal entry is read and ignored.
//
// For each node i:
//   1. every j != i whose heatmap value is strictly greater than `threshold`
//      becomes a candidate, remembering the value as its confidence;
//   2. if that yields fewer than k candidates, entries of nn's row i are
//      appended in order (skipping duplicates) with confidence 0;
//   3. if more than k candidates exist, the k with the smallest
//      dist.distance_sq(i, j) are kept;
//   4. the survivors are stored in ascending distance order with weight
//      dist.inv_dist_beta(i, j, beta), narrowed to float.
//
// When `confidence_mode` is set the weight is additionally multiplied by
// conf^confidence_gamma, where a non-positive conf is replaced by 1e-9
// (this is the case for every candidate that came from the fill step).
//
// Throws std::runtime_error when the file cannot be opened, the header is
// missing or does not match inst.n, the matrix ends early, or fewer than k
// candidates can be assembled for some node.
CandidateList build_heatmap_candidates(const Instance &inst,
                                       const DistanceProvider &dist,
                                       const CandidateList &nn,
                                       int k,
                                       double beta,
                                       const std::string &heatmap_path,
                                       double threshold,
                                       bool confidence_mode,
                                       double confidence_gamma) {
    FastScanner scanner(heatmap_path);
    int file_n = 0;
    if (!scanner.next_int(file_n)) {
        throw std::runtime_error("Failed to read heatmap header");
    }
    if (file_n != inst.n) {
        throw std::runtime_error("Heatmap N mismatch");
    }

    CandidateList cand;
    cand.n = inst.n;
    cand.k = k;
    cand.nodes.resize(static_cast<std::size_t>(inst.n) * k, -1);
    cand.inv_dist_beta.resize(static_cast<std::size_t>(inst.n) * k, 0.0f);

    struct HeatmapEdge {
        double dist2 = 0.0;
        int node = -1;
        double conf = 0.0;
    };
    const auto by_distance = [](const HeatmapEdge &a, const HeatmapEdge &b) {
        return a.dist2 < b.dist2;
    };

    // marks[j] == mark_id means j is already a candidate of the current row.
    // Bumping mark_id per row avoids clearing the whole vector each time.
    std::vector<int> marks(inst.n, -1);
    int mark_id = 0;
    std::vector<HeatmapEdge> pool;
    pool.reserve(static_cast<std::size_t>(k) * 2);

    for (int i = 0; i < inst.n; ++i) {
        pool.clear();
        mark_id++;
        for (int j = 0; j < inst.n; ++j) {
            double val = 0.0;
            if (!scanner.next_double(val)) {
                throw std::runtime_error("Unexpected EOF in heatmap");
            }
            if (j == i) {
                continue;
            }
            if (val > threshold) {
                const double d2 = dist.distance_sq(i, j);
                pool.push_back({d2, j, val});
                marks[j] = mark_id;
            }
        }

        // If fewer than K, fill with nearest neighbors by true distance (deduplicated)
        const int *nn_row = nn.row(i);
        const int nn_k = nn.k;
        for (int c = 0; static_cast<int>(pool.size()) < k && c < nn_k; ++c) {
            const int node = nn_row[c];
            // Ignore unfilled slots (-1 sentinel), out-of-range ids and the
            // node itself; indexing marks[] with them would be out of bounds
            // or would make i its own candidate.
            if (node < 0 || node >= inst.n || node == i) {
                continue;
            }
            if (marks[node] == mark_id) {
                continue;
            }
            marks[node] = mark_id;
            const double d2 = dist.distance_sq(i, node);
            pool.push_back({d2, node, 0.0});
        }

        if (static_cast<int>(pool.size()) < k) {
            throw std::runtime_error("Failed to fill candidate list");
        }
        if (static_cast<int>(pool.size()) > k) {
            std::nth_element(pool.begin(), pool.begin() + k, pool.end(), by_distance);
            pool.resize(k);
        }
        std::sort(pool.begin(), pool.end(), by_distance);

        // Index in size_t so that i * k cannot overflow int.
        const std::size_t row = static_cast<std::size_t>(i) * static_cast<std::size_t>(k);
        for (int c = 0; c < k; ++c) {
            const int node = pool[c].node;
            const std::size_t slot = row + static_cast<std::size_t>(c);
            cand.nodes[slot] = node;
            double inv = dist.inv_dist_beta(i, node, beta);
            if (confidence_mode) {
                double conf = pool[c].conf;
                if (conf <= 0.0) {
                    conf = 1e-9;
                }
                // Exponents 1 and 2 are special-cased to avoid calling std::pow.
                double factor = conf;
                if (confidence_gamma == 2.0) {
                    factor = conf * conf;
                } else if (confidence_gamma != 1.0) {
                    factor = std::pow(conf, confidence_gamma);
                }
                inv *= factor;
            }
            cand.inv_dist_beta[slot] = static_cast<float>(inv);
        }
    }

    return cand;
}
