/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/invlists/DirectMap.h>

#include <cassert>
#include <cstdio>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/invlists/BlockInvertedLists.h>

namespace faiss {

DirectMap::DirectMap() : type(NoMap) {}

void DirectMap::set_type(
        Type new_type,
        const InvertedLists* invlists,
        size_t ntotal) {
    FAISS_THROW_IF_NOT(
            new_type == NoMap || new_type == Array || new_type == Hashtable);

    if (new_type == type) {
        // nothing to do
        return;
    }

    array.clear();
    hashtable.clear();
    type = new_type;

    if (new_type == NoMap) {
        return;
    } else if (new_type == Array) {
        array.resize(ntotal, -1);
    } else if (new_type == Hashtable) {
        hashtable.reserve(ntotal);
    }

    for (size_t key = 0; key < invlists->nlist; key++) {
        size_t list_size = invlists->list_size(key);
        InvertedLists::ScopedIds idlist(invlists, key);

        if (new_type == Array) {
            for (long ofs = 0; ofs < list_size; ofs++) {
                FAISS_THROW_IF_NOT_MSG(
                        0 <= idlist[ofs] && idlist[ofs] < ntotal,
                        "direct map supported only for seuquential ids");
                array[idlist[ofs]] = lo_build(key, ofs);
            }
        } else if (new_type == Hashtable) {
            for (long ofs = 0; ofs < list_size; ofs++) {
                hashtable[idlist[ofs]] = lo_build(key, ofs);
            }
        }
    }
}

void DirectMap::clear() {
    array.clear();
    hashtable.clear();
}

idx_t DirectMap::get(idx_t key) const {
    if (type == Array) {
        FAISS_THROW_IF_NOT_MSG(key >= 0 && key < array.size(), "invalid key");
        idx_t lo = array[key];
        FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map");
        return lo;
    } else if (type == Hashtable) {
        auto res = hashtable.find(key);
        FAISS_THROW_IF_NOT_MSG(res != hashtable.end(), "key not found");
        return res->second;
    } else {
        FAISS_THROW_MSG("direct map not initialized");
    }
}

void DirectMap::add_single_id(idx_t id, idx_t list_no, size_t offset) {
    if (type == NoMap)
        return;

    if (type == Array) {
        assert(id == array.size());
        if (list_no >= 0) {
            array.push_back(lo_build(list_no, offset));
        } else {
            array.push_back(-1);
        }
    } else if (type == Hashtable) {
        if (list_no >= 0) {
            hashtable[id] = lo_build(list_no, offset);
        }
    }
}

void DirectMap::check_can_add(const idx_t* ids) {
    if (type == Array && ids) {
        FAISS_THROW_MSG("cannot have array direct map and add with ids");
    }
}

/********************* DirectMapAdd implementation */

DirectMapAdd::DirectMapAdd(DirectMap& direct_map, size_t n, const idx_t* xids)
        : direct_map(direct_map), type(direct_map.type), n(n), xids(xids) {
    if (type == DirectMap::Array) {
        FAISS_THROW_IF_NOT(xids == nullptr);
        ntotal = direct_map.array.size();
        direct_map.array.resize(ntotal + n, -1);
    } else if (type == DirectMap::Hashtable) {
        // can't parallel update hashtable so use temp array
        all_ofs.resize(n, -1);
    }
}

void DirectMapAdd::add(size_t i, idx_t list_no, size_t ofs) {
    if (type == DirectMap::Array) {
        direct_map.array[ntotal + i] = lo_build(list_no, ofs);
    } else if (type == DirectMap::Hashtable) {
        all_ofs[i] = lo_build(list_no, ofs);
    }
}

DirectMapAdd::~DirectMapAdd() {
    if (type == DirectMap::Hashtable) {
        for (int i = 0; i < n; i++) {
            idx_t id = xids ? xids[i] : ntotal + i;
            direct_map.hashtable[id] = all_ofs[i];
        }
    }
}

/********************************************************/

using ScopedCodes = InvertedLists::ScopedCodes;
using ScopedIds = InvertedLists::ScopedIds;

size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
    size_t nlist = invlists->nlist;
    std::vector<idx_t> toremove(nlist);

    size_t nremove = 0;
    BlockInvertedLists* block_invlists =
            dynamic_cast<BlockInvertedLists*>(invlists);
    if (type == NoMap) {
        if (block_invlists != nullptr) {
            return block_invlists->remove_ids(sel);
        }
        // exhaustive scan of IVF
#pragma omp parallel for
        for (idx_t i = 0; i < nlist; i++) {
            idx_t l0 = invlists->list_size(i), l = l0, j = 0;
            ScopedIds idsi(invlists, i);
            while (j < l) {
                if (sel.is_member(idsi[j])) {
                    l--;
                    invlists->update_entry(
                            i,
                            j,
                            invlists->get_single_id(i, l),
                            ScopedCodes(invlists, i, l).get());
                } else {
                    j++;
                }
            }
            toremove[i] = l0 - l;
        }
        // this will not run well in parallel on ondisk because of
        // possible shrinks
        for (idx_t i = 0; i < nlist; i++) {
            if (toremove[i] > 0) {
                nremove += toremove[i];
                invlists->resize(i, invlists->list_size(i) - toremove[i]);
            }
        }
    } else if (type == Hashtable) {
        FAISS_THROW_IF_MSG(
                block_invlists,
                "remove with hashtable is not supported with BlockInvertedLists");
        const IDSelectorArray* sela =
                dynamic_cast<const IDSelectorArray*>(&sel);
        FAISS_THROW_IF_NOT_MSG(
                sela, "remove with hashtable works only with IDSelectorArray");

        for (idx_t i = 0; i < sela->n; i++) {
            idx_t id = sela->ids[i];
            auto res = hashtable.find(id);
            if (res != hashtable.end()) {
                size_t list_no = lo_listno(res->second);
                size_t offset = lo_offset(res->second);
                idx_t last = invlists->list_size(list_no) - 1;
                hashtable.erase(res);
                if (offset < last) {
                    idx_t last_id = invlists->get_single_id(list_no, last);
                    invlists->update_entry(
                            list_no,
                            offset,
                            last_id,
                            ScopedCodes(invlists, list_no, last).get());
                    // update hash entry for last element
                    hashtable[last_id] = lo_build(list_no, offset);
                }
                invlists->resize(list_no, last);
                nremove++;
            }
        }

    } else {
        FAISS_THROW_MSG("remove not supported with this direct_map format");
    }
    return nremove;
}

void DirectMap::update_codes(
        InvertedLists* invlists,
        int n,
        const idx_t* ids,
        const idx_t* assign,
        const uint8_t* codes) {
    FAISS_THROW_IF_NOT(type == Array);

    size_t code_size = invlists->code_size;

    for (size_t i = 0; i < n; i++) {
        idx_t id = ids[i];
        FAISS_THROW_IF_NOT_MSG(
                0 <= id && id < array.size(), "id to update out of range");
        { // remove old one
            idx_t dm = array[id];
            int64_t ofs = lo_offset(dm);
            int64_t il = lo_listno(dm);
            size_t l = invlists->list_size(il);
            if (ofs != l - 1) { // move l - 1 to ofs
                int64_t id2 = invlists->get_single_id(il, l - 1);
                array[id2] = lo_build(il, ofs);
                invlists->update_entry(
                        il, ofs, id2, invlists->get_single_code(il, l - 1));
            }
            invlists->resize(il, l - 1);
        }
        { // insert new one
            int64_t il = assign[i];
            size_t l = invlists->list_size(il);
            idx_t dm = lo_build(il, l);
            array[id] = dm;
            invlists->add_entry(il, id, codes + i * code_size);
        }
    }
}

} // namespace faiss
