#include <faiss/IndexIVFPQPanorama.h>

#include <omp.h>

#include <algorithm>
#include <atomic>
#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>

#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/CodePacker.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>

namespace faiss {

// Process-wide counters behind the "avg_level" diagnostic of
// search_preassigned:
//  - total_active: sum, over every level scan, of the number of points still
//    active when that scan started;
//  - total_points: sum of the list sizes handed to those scans.
// They are atomics because search() runs several search_preassigned calls
// concurrently (one per slice of queries) and all of them update these.
static std::atomic<uint64_t> total_active{0};
static std::atomic<uint64_t> total_points{0};

/**
 * Builds an IVFPQ index whose M PQ sub-quantizers are split into `n_levels`
 * equal groups ("levels") that search_preassigned scans one after another.
 *
 * @param quantizer     coarse quantizer, forwarded to IndexIVFPQ
 * @param d             vector dimension
 * @param nlist         number of inverted lists
 * @param M             number of PQ sub-quantizers; must be a multiple of
 *                      n_levels
 * @param nbits_per_idx bits per PQ code; only 8 is supported
 * @param n_levels      number of levels; must be > 0
 * @param epsilon       forwarded to the scanner's scan_codes_panorama
 * @param metric        only METRIC_L2 is supported
 * @param own_invlists  forwarded to IndexIVFPQ
 * @throws FaissException if any of the constraints above is violated
 */
IndexIVFPQPanorama::IndexIVFPQPanorama(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits_per_idx,
        int n_levels,
        float epsilon,
        MetricType metric,
        bool own_invlists)
        : IndexIVFPQ(
                  quantizer,
                  d,
                  nlist,
                  M,
                  nbits_per_idx,
                  metric,
                  own_invlists),
          n_levels(n_levels),
          // Buffers are only allocated by add(); start them null so they are
          // never read uninitialized.
          column_storage(nullptr),
          column_storage_offsets(nullptr),
          column_offsets(nullptr),
          cum_sums(nullptr),
          cum_sum_offsets(nullptr),
          init_exact_distances(nullptr),
          init_exact_distances_offsets(nullptr),
          // The divisions are guarded so that an invalid n_levels is reported
          // by the checks in the body instead of dividing by zero here.
          chunk_size(n_levels > 0 ? code_size / n_levels : 0),
          levels_size(n_levels > 0 ? d / n_levels : 0),
          added(false),
          num_points(0),
          batch_size(16 * 1024), // 16KB scan batch buffer
          nbits_per_idx(nbits_per_idx),
          m_level_width(n_levels > 0 ? M / n_levels : 0),
          epsilon(epsilon) {
    FAISS_THROW_IF_NOT_MSG(n_levels > 0, "n_levels must be positive");
    FAISS_THROW_IF_NOT_MSG(
            M % static_cast<size_t>(n_levels) == 0,
            "M must be a multiple of n_levels");
    FAISS_THROW_IF_NOT_MSG(m_level_width > 0, "M must be at least n_levels");
    FAISS_THROW_IF_NOT_MSG(
            nbits_per_idx == 8, "only nbits_per_idx = 8 is supported");
    // With 8-bit codes every sub-quantizer index occupies exactly one byte.
    FAISS_THROW_IF_NOT_MSG(M == code_size, "expected one byte per PQ code");
    FAISS_THROW_IF_NOT_MSG(metric == METRIC_L2, "only METRIC_L2 is supported");

    // Internal invariant: the batch buffer is split into 4 equal parts.
    FAISS_ASSERT(batch_size > 0 && batch_size % 64 == 0);
}

/**
 * Adds `n` vectors through IndexIVFPQ::add, then builds the Panorama side
 * structures read by search_preassigned:
 *  - column_storage: for every list, its PQ codes transposed so that code m
 *    of all points is contiguous; column_storage_offsets[list_no * pq.M + m]
 *    points at that column (list_size bytes, one byte per code);
 *  - cum_sums: for every list, n_levels + 1 rows of list_size floats; row l
 *    holds sqrt of the squared norm of the decoded vector from dimension
 *    l * levels_size onwards (0 past the end), and the last row is 0;
 *  - init_exact_distances: per point, sum_j (v_j^2 + 2 v_j c_j) where v is
 *    the decoded vector and c the list centroid, i.e. ||v + c||^2 - ||c||^2.
 *
 * Can be called only once per index.
 *
 * @param n number of vectors
 * @param x n * d floats
 * @throws FaissException if called twice, or if the inverted lists end up
 *         holding more than `n` entries (the side buffers are sized for n)
 */
void IndexIVFPQPanorama::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            !added, "IndexIVFPQPanorama::add can only be called once");
    added = true;

    num_points = n;
    IndexIVFPQ::add(n, x);

    // Entries present in the lists before this call (e.g. added through
    // another path) would overflow the n-sized buffers allocated below.
    size_t total_entries = 0;
    for (size_t list_no = 0; list_no < nlist; list_no++) {
        total_entries += invlists->list_size(list_no);
    }
    FAISS_THROW_IF_NOT_FMT(
            total_entries <= static_cast<size_t>(n),
            "IndexIVFPQPanorama::add: inverted lists hold %zu entries but "
            "only %" PRId64 " vectors were added",
            total_entries,
            n);

    // Start offset of every list inside the three flat buffers.
    column_offsets = new size_t[nlist];
    cum_sum_offsets = new size_t[nlist];
    init_exact_distances_offsets = new size_t[nlist];
    size_t entries_before = 0;
    for (size_t list_no = 0; list_no < nlist; list_no++) {
        column_offsets[list_no] = entries_before * code_size;
        cum_sum_offsets[list_no] = entries_before * (n_levels + 1);
        init_exact_distances_offsets[list_no] = entries_before;
        entries_before += invlists->list_size(list_no);
    }

    column_storage = new uint8_t[code_size * n];
    column_storage_offsets = new uint8_t*[nlist * pq.M];
    cum_sums = new float[(n_levels + 1) * n];
    init_exact_distances = new float[n];

    // Transpose every list's row-major codes into per-sub-quantizer columns.
    for (size_t list_no = 0; list_no < nlist; list_no++) {
        const size_t list_size = invlists->list_size(list_no);
        // Fetch the codes once per list; ScopedCodes also releases them for
        // inverted-list implementations that require it.
        InvertedLists::ScopedCodes codes(invlists, list_no);
        const uint8_t* src = codes.get();
        uint8_t* list_columns = column_storage + column_offsets[list_no];
        for (size_t m = 0; m < pq.M; m++) {
            uint8_t* column = list_columns + m * list_size; // 1 byte per code
            column_storage_offsets[list_no * pq.M + m] = column;
            for (size_t point = 0; point < list_size; point++) {
                column[point] = src[point * code_size + m];
            }
        }
    }

    // Scratch buffers reused for every point.
    std::vector<float> centroid(d);
    std::vector<float> decoded(d);
    std::vector<float> suffix_sums(d + 1);

    for (size_t list_no = 0; list_no < nlist; list_no++) {
        const size_t list_size = invlists->list_size(list_no);
        if (list_size == 0) {
            continue;
        }

        quantizer->reconstruct(list_no, centroid.data());
        float* list_cum_sums = cum_sums + cum_sum_offsets[list_no];
        float* list_init_distances =
                init_exact_distances + init_exact_distances_offsets[list_no];

        for (size_t point = 0; point < list_size; point++) {
            // We have agreed that if by_residual is true, we are calculating
            // distance between residuals so reconstruct_from_offset is wrong
            // because it'll approximate the original vector and not the
            // residual.
            decode_from_offset(list_no, point, decoded.data());

            float init_exact_distance = 0.0f;
            suffix_sums[d] = 0.0f;
            for (int j = d - 1; j >= 0; j--) {
                const float v = decoded[j];
                init_exact_distance += v * v + 2 * v * centroid[j];
                suffix_sums[j] = suffix_sums[j + 1] + v * v;
            }

            // Row `level` holds the norm of the tail starting at that level;
            // rows are laid out column-major (one row of list_size per level).
            for (int level = 0; level < n_levels; level++) {
                const size_t start = static_cast<size_t>(level) * levels_size;
                list_cum_sums[level * list_size + point] =
                        start < static_cast<size_t>(d)
                        ? std::sqrt(suffix_sums[start])
                        : 0.0f;
            }
            list_cum_sums[n_levels * list_size + point] = 0.0f;

            list_init_distances[point] = init_exact_distance;
        }
    }
}

// Overrides IndexIVF::search with the same query slicing and exception
// forwarding, so that the wall-clock search time and the time spent in
// scan_codes_panorama (accumulated into verification_time by
// search_preassigned) can be reported for benchmarks. The report is printed
// only when `verbose` is set.
void IndexIVFPQPanorama::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params_in) const {
    const double search_start = omp_get_wtime();
    verification_time = 0.0;

    FAISS_THROW_IF_NOT(k > 0);
    const IVFSearchParameters* params = nullptr;
    if (params_in) {
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
    }
    const size_t nprobe =
            std::min(nlist, params ? params->nprobe : this->nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    // Coarse-quantizes a subset of the queries, then scans their nprobe
    // closest lists.
    auto sub_search_func = [this, k, nprobe, params](
                                   idx_t n,
                                   const float* x,
                                   float* distances,
                                   idx_t* labels,
                                   IndexIVFStats* ivf_stats) {
        std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        quantizer->search(
                n,
                x,
                nprobe,
                coarse_dis.get(),
                idx.get(),
                params ? params->quantizer_params : nullptr);

        invlists->prefetch_lists(idx.get(), n * nprobe);

        search_preassigned(
                n,
                x,
                k,
                idx.get(),
                coarse_dis.get(),
                distances,
                labels,
                false,
                params,
                ivf_stats);
    };

    if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
        // One contiguous slice of queries per thread. The min is taken in
        // idx_t so that a very large n is not truncated by a cast to int.
        const int nt =
                static_cast<int>(std::min<idx_t>(omp_get_max_threads(), n));
        std::vector<IndexIVFStats> stats(nt);
        std::mutex exception_mutex;
        std::string exception_string;

#pragma omp parallel for if (nt > 1)
        for (idx_t slice = 0; slice < nt; slice++) {
            const idx_t i0 = n * slice / nt;
            const idx_t i1 = n * (slice + 1) / nt;
            if (i1 > i0) {
                try {
                    sub_search_func(
                            i1 - i0,
                            x + i0 * d,
                            distances + i0 * k,
                            labels + i0 * k,
                            &stats[slice]);
                } catch (const std::exception& e) {
                    // Exceptions must not leave the OpenMP region; re-thrown
                    // below once all slices are done.
                    std::lock_guard<std::mutex> lock(exception_mutex);
                    exception_string = e.what();
                }
            }
        }

        if (!exception_string.empty()) {
            FAISS_THROW_MSG(exception_string.c_str());
        }

        // collect stats
        for (idx_t slice = 0; slice < nt; slice++) {
            indexIVF_stats.add(stats[slice]);
        }
    } else {
        // handle parallelization at level below (or don't run in parallel at
        // all)
        sub_search_func(n, x, distances, labels, &indexIVF_stats);
    }

    if (verbose) {
        const double search_end = omp_get_wtime();
        printf("Search time = %f ms\n", (search_end - search_start) * 1000);
        printf("Verification time = %f ms\n", verification_time * 1000);
    }
}

/**
 * Panorama counterpart of IndexIVF::search_preassigned.
 *
 * For each query, every probed inverted list is scanned level by level.
 * Level `l` covers the sub-quantizer columns [l * chunk_size, (l + 1) *
 * chunk_size) of `column_storage`, and is given row l + 1 of the list's
 * `cum_sums` together with the query's suffix norm at level l + 1. The
 * candidate state (`bitset`, `indices`, `exact_distances`, `active_num`) is
 * reset from `init_exact_distances` at the start of every list and passed to
 * every level's scan_codes_panorama call, which presumably narrows the
 * active set from one level to the next.
 *
 * Only parallel_mode == 0 is supported and queries are processed one after
 * another inside a call: the scratch buffers below are shared by all queries
 * of the call. search() provides parallelism by running one call per slice
 * of queries.
 *
 * @param n           number of queries
 * @param x           queries, n * d floats
 * @param k           number of results per query
 * @param keys        n * nprobe list ids; negative ids are skipped
 * @param coarse_dis  n * nprobe coarse distances, forwarded to the scanner
 * @param distances   output, n * k
 * @param labels      output, n * k
 * @param store_pairs forwarded to the scanner; list ids are not fetched when
 *                    set
 * @param params      optional IVF search parameters
 * @param ivf_stats   stats to update (the global indexIVF_stats if null)
 * @throws FaissException on invalid parameters or list ids, when a scan
 *         throws, or when the search is interrupted
 */
void IndexIVFPQPanorama::search_preassigned(
        idx_t n,
        const float* x,
        idx_t k,
        const idx_t* keys, // which clusters
        const float* coarse_dis,
        float* distances,
        idx_t* labels,
        bool store_pairs,
        const IVFSearchParameters* params,
        IndexIVFStats* ivf_stats) const {
    FAISS_THROW_IF_NOT(k > 0);

    idx_t nprobe = params ? params->nprobe : this->nprobe;
    nprobe = std::min((idx_t)nlist, nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    idx_t max_codes = params ? params->max_codes : this->max_codes;
    IDSelector* sel = params ? params->sel : nullptr;
    const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
    if (selr) {
        if (selr->assume_sorted) {
            sel = nullptr; // use special IDSelectorRange processing
        } else {
            selr = nullptr; // use generic processing
        }
    }

    // The sorted-range path needs the list ids, which store_pairs skips.
    FAISS_THROW_IF_NOT_MSG(
            !((sel || selr) && store_pairs),
            "selector and store_pairs cannot be combined");
    FAISS_THROW_IF_NOT_MSG(
            !invlists->use_iterator,
            "IndexIVFPQPanorama does not support iterable inverted lists");

    const int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
    const bool do_heap_init =
            !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
    FAISS_THROW_IF_NOT_MSG(
            pmode == 0, "IndexIVFPQPanorama only supports parallel_mode = 0");

    if (max_codes == 0) {
        max_codes = std::numeric_limits<idx_t>::max();
    }

    void* inverted_list_context =
            params ? params->inverted_list_context : nullptr;

    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;

    // ---- scratch buffers shared by all queries of this call ----------------

    // One sim-table slot (and one dis0) per probed list.
    const size_t sim_table_size = pq.ksub * pq.M;
    std::vector<float> sim_table_cache(nprobe * sim_table_size);
    std::vector<float> dis0s_cache(nprobe);

    // Batch buffer split into 4 equal parts for the scanner.
    std::vector<uint8_t> batch_storage(batch_size);
    std::vector<uint8_t*> batch_offsets(4);
    for (size_t part = 0; part < 4; part++) {
        batch_offsets[part] = batch_storage.data() + part * batch_size / 4;
    }

    size_t max_list_size = 0;
    for (size_t list_no = 0; list_no < nlist; list_no++) {
        max_list_size = std::max(max_list_size, invlists->list_size(list_no));
    }

    std::vector<float> suffix_sums(d + 1);
    std::vector<float> query_cum_norms(n_levels + 1);
    std::vector<float> exact_distances(max_list_size);
    std::vector<uint8_t> bitset(max_list_size);
    std::vector<uint32_t> indices(max_list_size);

    std::unique_ptr<InvertedListScanner> scanner(
            get_InvertedListScanner(store_pairs, sel, params));

    // ---- statistics, published once at the end -----------------------------

    size_t nlistv = 0, ndis = 0, nheap = 0;
    uint64_t active_visits = 0; // active points at the start of each scan
    uint64_t point_visits = 0;  // list sizes handed to the scans
    double scan_seconds = 0.0;  // time spent in scan_codes_panorama

    bool interrupt = false;
    std::string exception_string;

    // ---- result heap helpers -------------------------------------------------

    auto init_result = [&](float* simi, idx_t* idxi) {
        if (!do_heap_init) {
            return;
        }
        if (metric_type == METRIC_INNER_PRODUCT) {
            heap_heapify<HeapForIP>(k, simi, idxi);
        } else {
            heap_heapify<HeapForL2>(k, simi, idxi);
        }
    };

    auto reorder_result = [&](float* simi, idx_t* idxi) {
        if (!do_heap_init) {
            return;
        }
        if (metric_type == METRIC_INNER_PRODUCT) {
            heap_reorder<HeapForIP>(k, simi, idxi);
        } else {
            heap_reorder<HeapForL2>(k, simi, idxi);
        }
    };

    // Scans one level of one inverted list with scan_codes_panorama and
    // returns the number of entries handed to it (0 when the list is empty,
    // filtered out entirely, or the scan threw; a thrown exception is
    // recorded and sets `interrupt`).
    auto scan_level = [&](size_t list_no,
                          idx_t cluster_id,
                          float coarse_dis_i,
                          size_t level,
                          const uint8_t* storage,
                          uint8_t* const* storage_offset,
                          const float* level_cum_sums,
                          float query_cum_norm,
                          float* simi,
                          idx_t* idxi,
                          idx_t list_size_max,
                          size_t* num_active) -> size_t {
        if (invlists->is_empty(cluster_id, inverted_list_context)) {
            return 0;
        }

        // Level 0 is flagged with `true`; later levels pass `false` and
        // additionally point the scanner at their slice of the list's table.
        float* sim_table = sim_table_cache.data() +
                list_no * sim_table_size + level * pq.ksub * chunk_size;
        float* dis0 = dis0s_cache.data() + list_no;
        const bool first_level = level == 0;
        scanner->set_list_panorama(
                cluster_id, coarse_dis_i, sim_table, dis0, first_level);
        if (!first_level) {
            scanner->set_sim_table(sim_table, *dis0);
        }

        nlistv++;

        try {
            size_t list_size = invlists->list_size(cluster_id);
            if (static_cast<idx_t>(list_size) > list_size_max) {
                list_size = static_cast<size_t>(list_size_max);
            }

            std::unique_ptr<InvertedLists::ScopedIds> sids;
            const idx_t* ids = nullptr;
            if (!store_pairs) {
                sids = std::make_unique<InvertedLists::ScopedIds>(
                        invlists, cluster_id);
                ids = sids->get();
            }

            if (selr) { // IDSelectorRange
                // restrict search to a section of the inverted list
                // NOTE(review): only `ids` is advanced by jmin; the code
                // columns, cum_sums and candidate state still start at entry
                // 0, so results look misaligned whenever jmin > 0 — confirm.
                size_t jmin, jmax;
                selr->find_sorted_ids_bounds(list_size, ids, &jmin, &jmax);
                list_size = jmax - jmin;
                if (list_size == 0) {
                    return 0;
                }
                ids += jmin;
            }

            active_visits += *num_active;
            point_visits += list_size;

            const double scan_start = omp_get_wtime();
            nheap += scanner->scan_codes_panorama(
                    list_size,
                    m_level_width,
                    epsilon,
                    storage,
                    storage_offset,
                    bitset.data(),
                    exact_distances.data(),
                    ids,
                    simi,
                    idxi,
                    k,
                    level_cum_sums,
                    query_cum_norm,
                    batch_storage.data(),
                    batch_offsets.data(),
                    indices.data(),
                    num_active,
                    batch_size);
            scan_seconds += omp_get_wtime() - scan_start;

            return list_size;
        } catch (const std::exception& e) {
            exception_string =
                    demangle_cpp_symbol(typeid(e).name()) + "  " + e.what();
            interrupt = true;
            return 0;
        }
    };

    // ---- per-query loop ------------------------------------------------------

    for (idx_t i = 0; i < n && !interrupt; i++) {
        const float* query = x + i * d;
        scanner->set_query(query);

        // query_cum_norms[l] = norm of the query from dimension
        // l * levels_size onwards (0 past the end and for the last slot).
        suffix_sums[d] = 0.0f;
        for (int j = d - 1; j >= 0; --j) {
            suffix_sums[j] = suffix_sums[j + 1] + query[j] * query[j];
        }
        for (int level = 0; level < n_levels; level++) {
            const size_t start = static_cast<size_t>(level) * levels_size;
            query_cum_norms[level] = start < static_cast<size_t>(d)
                    ? std::sqrt(suffix_sums[start])
                    : 0.0f;
        }
        query_cum_norms[n_levels] = 0.0f;

        float* simi = distances + i * k;
        idx_t* idxi = labels + i * k;
        init_result(simi, idxi);

        idx_t nscan = 0;
        for (size_t list_no = 0;
             list_no < static_cast<size_t>(nprobe) && !interrupt;
             list_no++) {
            const idx_t cluster_id = keys[i * nprobe + list_no];
            if (cluster_id < 0) {
                // The coarse quantizer returned fewer than nprobe lists; the
                // id must not be used to index the per-list tables below.
                continue;
            }
            FAISS_THROW_IF_NOT_FMT(
                    cluster_id < (idx_t)nlist,
                    "Invalid key=%" PRId64 " nlist=%zd\n",
                    cluster_id,
                    nlist);

            // Reset the candidate state: every point of the list starts
            // active with its precomputed initial distance term.
            const size_t list_size = invlists->list_size(cluster_id);
            size_t active_num = list_size;
            std::iota(
                    indices.begin(), indices.begin() + list_size, uint32_t{0});
            std::fill(bitset.begin(), bitset.begin() + list_size, uint8_t{1});
            const float* list_init_distances =
                    init_exact_distances + init_exact_distances_offsets[cluster_id];
            std::copy(
                    list_init_distances,
                    list_init_distances + list_size,
                    exact_distances.begin());

            const uint8_t* list_columns =
                    column_storage + column_offsets[cluster_id];
            uint8_t* const* list_column_ptrs =
                    column_storage_offsets + cluster_id * pq.M;
            const float* list_cum_sums = cum_sums + cum_sum_offsets[cluster_id];

            for (size_t level = 0; level < static_cast<size_t>(n_levels);
                 level++) {
                nscan += scan_level(
                        list_no,
                        cluster_id,
                        coarse_dis[i * nprobe + list_no],
                        level,
                        list_columns + level * list_size * chunk_size,
                        list_column_ptrs + level * chunk_size,
                        list_cum_sums + (level + 1) * list_size,
                        query_cum_norms[level + 1],
                        simi,
                        idxi,
                        max_codes - nscan,
                        &active_num);
                if (interrupt || nscan >= max_codes) {
                    break;
                }
            }
            // Stop probing once the budget is spent; continuing would hand a
            // non-positive budget to the next list.
            if (nscan >= max_codes) {
                break;
            }
        }

        ndis += nscan;
        reorder_result(simi, idxi);

        if (InterruptCallback::is_interrupted()) {
            interrupt = true;
        }
    }

    // search() runs several of these calls concurrently, so the shared
    // benchmark counters are updated (and read) under one named critical
    // section instead of from inside the scan loop.
    uint64_t active_total_snapshot = 0;
    uint64_t point_total_snapshot = 0;
#pragma omp critical(ivfpq_panorama_stats)
    {
        total_active += active_visits;
        total_points += point_visits;
        verification_time += scan_seconds;
        active_total_snapshot = total_active;
        point_total_snapshot = total_points;
    }

    if (interrupt) {
        if (!exception_string.empty()) {
            FAISS_THROW_FMT(
                    "search interrupted with: %s", exception_string.c_str());
        } else {
            FAISS_THROW_MSG("computation interrupted");
        }
    }

    if (ivf_stats == nullptr) {
        ivf_stats = &indexIVF_stats;
    }
    ivf_stats->nq += n;
    ivf_stats->nlist += nlistv;
    ivf_stats->ndis += ndis;
    ivf_stats->nheap_updates += nheap;

    if (verbose) {
        printf("avg_level: %f\n",
               point_total_snapshot == 0
                       ? 0.0
                       : static_cast<double>(active_total_snapshot) /
                               static_cast<double>(point_total_snapshot));
    }
}

IndexIVFPQPanorama::IndexIVFPQPanorama()
        : IndexIVFPQ(),
        n_levels(0),
        column_storage(nullptr),
        column_storage_offsets(nullptr),
        column_offsets(nullptr),
        cum_sums(nullptr),
        cum_sum_offsets(nullptr),
        init_exact_distances(nullptr),
        init_exact_distances_offsets(nullptr),
        chunk_size(0),
        levels_size(0),
        added(false),
        num_points(0),
        batch_size(0),
        nbits_per_idx(0),
        m_level_width(0),
        epsilon(1.0) {
    // Default constructor for serialization - fields will be filled in by read_index
    is_trained = true;
}

void IndexIVFPQPanorama::set_nlevels(int new_levels) {
    // Note: This is a placeholder implementation
    // The actual implementation would need to:
    // 1. Recalculate derived fields like chunk_size, levels_size, m_level_width
    // 2. Reallocate and recompute cum_sums arrays
    // 3. Restructure column storage if needed
    // For now, we'll just leave it empty as requested
    printf("IndexIVFPQPanorama::set_nlevels(%d) called - placeholder implementation\n", new_levels);
}

} // namespace faiss