"""Nearest neighbor index for text data."""

import json
import os
from collections.abc import Sequence
from pile_index import load_index
from metric import Metric

import faiss
import numpy as np

class PileIndexOptimized:
    """Nearest neighbor index whose text records are read lazily from disk.

    Wraps a faiss index together with one file offset per indexed vector.
    The text belonging to a neighbor is fetched by seeking to its offset in
    ``data_path`` and parsing that line as JSON, so the data file is never
    loaded into memory as a whole.
    """

    def __init__(
        self,
        index: "faiss.IndexFlat",
        offsets: Sequence[int],
        data_path: str,
        embedding_model=None,
    ):
        """Store the index, the per-record offsets and the data file path.

        Parameters
        ----------
        index : faiss.IndexFlat
            Index to search; must expose ``ntotal``, ``d`` and
            ``search_and_reconstruct``.
        offsets : Sequence[int]
            Start position of each record's line in ``data_path``; entry
            ``i`` belongs to vector ``i`` of ``index``.
        data_path : str
            Path of the line-delimited JSON file holding the records.
        embedding_model : optional
            Callable used by ``string_query`` to embed query strings. When
            given it must have an ``embedding_dimension`` attribute.
        """
        self.index = index
        self.offsets = offsets
        self.data_path = data_path
        # One offset per indexed vector, otherwise lookups would be misaligned.
        assert len(self.offsets) == self.index.ntotal

        self.embedding_model = embedding_model
        if self.embedding_model is not None:
            assert hasattr(self.embedding_model, "embedding_dimension")

    def vector_query(self, query_vector: np.ndarray, num_neighbors: int):
        """Nearest neighbor vector query.

        Parameters
        ----------
        query_vector : np.ndarray
            Single query of shape ``(1, d)`` where ``d`` is the index
            dimension. The results are flattened to ``num_neighbors``
            entries, so more than one query row is not supported.
        num_neighbors : int
            Number of neighbors to return.

        Returns
        -------
        values : np.ndarray
            First output of ``search_and_reconstruct`` (one value per
            neighbor), shape ``(num_neighbors,)``.
        neighbors : np.ndarray
            Positions of the neighbors in the index, shape
            ``(num_neighbors,)``.
        vectors : np.ndarray
            Reconstructed neighbor vectors, shape ``(num_neighbors, -1)``.
        data_items : list of str
            ``"text"`` field of each neighbor's record.
        """

        assert self.index.d == query_vector.shape[1]

        found = self.index.search_and_reconstruct(query_vector, num_neighbors)
        values = found[0].reshape(num_neighbors)
        neighbors = found[1].reshape(num_neighbors)
        vectors = found[2].reshape(num_neighbors, -1)
        data_items = self.get_data_items(neighbors)

        return values, neighbors, vectors, data_items

    def get_data_items(self, neighbors: Sequence[int]):
        """Read the ``"text"`` field of the records at the given positions.

        Parameters
        ----------
        neighbors : Sequence[int]
            Positions into ``self.offsets``.

        Returns
        -------
        list of str
            One text per entry of ``neighbors``, in the same order.
        """
        # NOTE(review): a negative position (faiss presumably reports -1 when
        # it finds fewer results than requested - confirm) indexes
        # ``self.offsets`` from the end instead of failing.
        texts = []
        # Binary mode on purpose: seeking a text-mode file to anything other
        # than a value returned by tell() is undefined behaviour. json.loads
        # accepts the raw bytes of the line directly.
        with open(self.data_path, "rb") as f:
            for position in neighbors:
                f.seek(self.offsets[position])
                texts.append(json.loads(f.readline())["text"])

        return texts

    def string_query(self, query_str: str, num_neighbors: int):
        """Nearest neighbor string query.

        Embeds ``query_str`` with the embedding model and forwards the
        resulting vector to ``vector_query``.

        Parameters
        ----------
        query_str : str
            String to query.
        num_neighbors : int
            Number of neighbors to return.

        Returns
        -------
        tuple
            The ``(values, neighbors, vectors, data_items)`` tuple produced
            by ``vector_query``.
        """

        # Same check as in __init__: only a missing model is an error, the
        # model's own truthiness is irrelevant.
        assert self.embedding_model is not None

        # Embed query
        query_vector = self.embedding_model([query_str]).cpu().numpy()

        return self.vector_query(query_vector, num_neighbors)


def create_data_offsets(data_path: str, offsets_path: str):
    """Return the byte offset of every line in ``data_path``, cached as JSON.

    If ``offsets_path`` already exists its contents are loaded and returned
    unchanged; the file is trusted as-is, so delete it to force the offsets
    to be recomputed. Otherwise the data file is scanned once, the offsets
    are written to ``offsets_path`` and then returned.

    Parameters
    ----------
    data_path : str
        Path of the line-delimited data file.
    offsets_path : str
        Path of the JSON file used to cache the offsets.

    Returns
    -------
    list of int
        Byte position at which each line of ``data_path`` starts.
    """
    if os.path.exists(offsets_path):
        with open(offsets_path, "r") as f:
            return json.load(f)

    offsets = []
    position = 0
    # Binary mode so that len(line) counts bytes: in text mode it counts
    # decoded characters (and newline translation can drop a "\r"), which
    # drifts away from the real file position on any non-ASCII content.
    with open(data_path, "rb") as f:
        for line in f:
            offsets.append(position)
            position += len(line)

    with open(offsets_path, "w") as f:
        json.dump(offsets, f)

    return offsets


def build_roberta_index_optimized(data_file: str, metric: Metric, normalized: bool):
    """Assemble a ``PileIndexOptimized`` for one Pile data file.

    The data file is looked up under ``pile/train`` and its faiss index
    under ``indexes/roberta-large``; both must already exist. The line
    offsets are obtained through ``create_data_offsets``, which stores them
    next to the data file with an ``.offsets`` suffix.

    Parameters
    ----------
    data_file : str
        Name of Pile data file.
    metric : Metric
        Forwarded to ``load_index``.
    normalized : bool
        Forwarded to ``load_index``.

    Returns
    -------
    PileIndexOptimized
        Index over the data file, created without an embedding model.
    """
    train_dir = "pile/train"
    data_path = os.path.join(train_dir, data_file)
    index_path = os.path.join("indexes/roberta-large", data_file + ".index")
    offsets_path = os.path.join(train_dir, data_file + ".offsets")

    # Both inputs have to be present before anything is loaded.
    for required_path in (data_path, index_path):
        assert os.path.exists(required_path), str(required_path)

    return PileIndexOptimized(
        load_index(index_path, metric, normalized),
        create_data_offsets(data_path, offsets_path),
        data_path,
    )
