"""Nearest neighbor index for text data."""

import os
import json

import faiss
import numpy as np

from tqdm import tqdm

from afsl.adapters.faiss import *
from afsl.acquisition_functions.itl_noiseless import ITLNoiseless
from afsl.acquisition_functions.itl import ITL
from afsl.acquisition_functions.ctl import CTL
from afsl.acquisition_functions.vtl import VTL
from afsl.acquisition_functions.lazy_vtl import LazyVTL
from afsl.acquisition_functions.random import Random
from afsl.acquisition_functions.uncertainty_sampling import UncertaintySampling

import logging

from metric import Metric, get_index_class
from utils import get_device

logger = logging.getLogger(__name__)


class PileIndex:
    """Nearest neighbor index pairing a FAISS index with the data items it was built from."""

    def __init__(self, index: faiss.IndexFlat, data_dict: dict, metric: Metric, normalized: bool, embedding_model=None):
        """Initialize pile index.

        Parameters
        ----------
        index : faiss.IndexFlat
            Nearest neighbor index; must be an instance of
            ``get_index_class(metric)``.
        data_dict : dict
            Mapping from index position to data item.
            data_dict[i] should correspond to the vector at index position i,
            and it must hold exactly ``index.ntotal`` entries. Any object with
            ``len`` and integer indexing works (split_index_data passes a list).
        metric : Metric
            Metric the index was built for.
        normalized : bool
            Flag recording whether the indexed vectors were L2-normalized
            (load_index normalizes them when this is True). It is only stored
            here; the methods of this class do not read it.
        embedding_model : callable, optional
            Embedding model for text data, needed only by ``string_query``.
            Must expose ``embedding_dimension`` and should be identical to the
            embedding model used to create the index.

        Raises
        ------
        AssertionError
            If the index type does not match ``metric``, the number of data
            items differs from ``index.ntotal``, or the embedding model lacks
            ``embedding_dimension``.
        """

        self.index = index
        assert isinstance(self.index, get_index_class(metric)), f"Index is not of the correct type for metric `{metric}`: {type(self.index)}"
        self.data_dict = data_dict
        self.metric = metric
        self.normalized = normalized
        assert len(self.data_dict) == self.index.ntotal

        self.embedding_model = embedding_model
        if self.embedding_model is not None:
            assert hasattr(self.embedding_model, "embedding_dimension")

    def vector_query(self, query_vector: np.ndarray, num_neighbors: int):
        """Nearest neighbor query for a single query vector.

        Parameters
        ----------
        query_vector : np.ndarray
            Query of shape ``(d,)`` or ``(1, d)`` where ``d == self.index.d``.
        num_neighbors : int
            Number of neighbors to return; must not exceed the number of
            indexed vectors.

        Returns
        -------
        values : np.ndarray
            Shape ``(num_neighbors,)``: the scores/distances reported by the
            index for each neighbor.
        neighbors : np.ndarray
            Shape ``(num_neighbors,)``: index positions of the neighbors.
        vectors : np.ndarray
            Shape ``(num_neighbors, d)``: the reconstructed neighbor vectors.
        data_items : list
            ``self.data_dict[i]`` for each neighbor position ``i``.

        Raises
        ------
        ValueError
            If the query is not a single vector, or if ``num_neighbors``
            exceeds ``self.index.ntotal``.
        AssertionError
            If the query dimension does not match the index dimension.
        """

        # Accept a plain 1-D vector by promoting it to a single-row batch.
        query_vector = np.atleast_2d(query_vector)
        if query_vector.ndim != 2 or query_vector.shape[0] != 1:
            raise ValueError(f"vector_query expects a single query vector, got shape {query_vector.shape}")
        assert self.index.d == query_vector.shape[1]
        # With fewer stored vectors than requested, FAISS pads the result with
        # label -1, which is not a valid data position (and would silently
        # select the last item when data_dict is a list).
        if num_neighbors > self.index.ntotal:
            raise ValueError(f"Requested {num_neighbors} neighbors but the index only holds {self.index.ntotal} vectors")

        distances, labels, reconstructed = self.index.search_and_reconstruct(query_vector, num_neighbors)
        values = distances.reshape(num_neighbors)
        neighbors = labels.reshape(num_neighbors)
        vectors = reconstructed.reshape(num_neighbors, -1)
        data_items = [self.data_dict[i] for i in neighbors]

        return values, neighbors, vectors, data_items

    def string_query(self, query_str: str, num_neighbors: int):
        """Nearest neighbor string query.

        Embeds ``query_str`` with ``self.embedding_model`` and delegates to
        ``vector_query``.

        Parameters
        ----------
        query_str : str
            String to query.
        num_neighbors : int
            Number of neighbors to return.

        Returns
        -------
        tuple
            ``(values, neighbors, vectors, data_items)``, exactly as returned by
            ``vector_query``.

        Raises
        ------
        AssertionError
            If no embedding model was provided.
        """

        assert self.embedding_model, "string_query requires an embedding_model"

        # Embed a one-element batch; the model's output is moved to CPU and
        # converted to numpy for FAISS.
        query_vector = self.embedding_model([query_str]).cpu().numpy()

        return self.vector_query(query_vector, num_neighbors)


def data_to_dict(data_path: str):
    """Read a Pile data file into a dictionary.

    Parameters
    ----------
    data_path : str
        Path to a Pile data file in JSON Lines format; every line must be a
        JSON object with a ``'text'`` key.

    Returns
    -------
    dict
        Mapping from 0-based line number to that line's ``'text'`` value.
    """

    print("Reading data file: ", data_path)

    # JSON Lines is UTF-8 by definition; don't depend on the platform's
    # locale-dependent default encoding.
    with open(data_path, "r", encoding="utf-8") as data_file:
        texts = [json.loads(line)["text"] for line in tqdm(data_file)]

    return dict(enumerate(texts))

def load_index(index_path: str, metric: Metric, normalized: bool):
    """Read a FAISS index from disk and copy its vectors into a fresh index.

    Parameters
    ----------
    index_path : str
        Path of the serialized FAISS index.
    metric : Metric
        Selects the index class (via ``get_index_class``) to rebuild into.
    normalized : bool
        When True, L2-normalize the vectors (in place) before adding them.

    Returns
    -------
    faiss.Index
        A new index of class ``get_index_class(metric)`` holding all vectors
        from the file.
    """
    source = faiss.read_index(index_path)

    embeddings = source.reconstruct_n(0, source.ntotal)
    if normalized:
        faiss.normalize_L2(embeddings)  # modifies `embeddings` in place

    index_cls = get_index_class(metric)
    rebuilt = index_cls(source.d)
    rebuilt.add(embeddings)
    return rebuilt

def build_index(data_path: str, index_path: str, metric: Metric, normalized: bool):
    """Build index from Pile data and index files.

    Loads the vector index first (see load_index), then the data items
    (see data_to_dict).

    Returns
    -------
    tuple
        ``(index, data_dict)``.
    """
    return load_index(index_path, metric, normalized), data_to_dict(data_path)


def build_roberta_index(
    data_file: str,
    metric: Metric,
    normalized: bool,
    *,
    data_dir: str = "pile/train",
    index_dir: str = "indexes/roberta-large",
):
    """Convenience method to build roberta index.

    Parameters
    ----------
    data_file : str
        Name of the Pile data file inside ``data_dir``. The matching index is
        expected at ``<index_dir>/<data_file>.index``.
    metric : Metric
        Metric to rebuild the index for (see load_index).
    normalized : bool
        Whether to L2-normalize the stored vectors while loading.
    data_dir : str, optional
        Directory holding the Pile data files. Defaults to ``"pile/train"``.
    index_dir : str, optional
        Directory holding the precomputed index files. Defaults to
        ``"indexes/roberta-large"``.

    Returns
    -------
    PileIndex
        Pile index, created without an embedding model (so ``string_query``
        is not usable on it).

    Raises
    ------
    FileNotFoundError
        If the data file or the index file does not exist.
    """

    data_path = os.path.join(data_dir, data_file)
    index_path = os.path.join(index_dir, data_file + ".index")
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    for path in (data_path, index_path):
        if not os.path.exists(path):
            raise FileNotFoundError(str(path))
    index, data_dict = build_index(data_path, index_path, metric, normalized)
    return PileIndex(index, data_dict, metric, normalized)


def split_index_data(pile_index: PileIndex, num_splits: int):
    """Lazily partition a PileIndex into ``num_splits`` contiguous shards.

    Every shard except the last holds ``ntotal // num_splits`` vectors; the
    last shard additionally absorbs the ``ntotal % num_splits`` leftovers.
    Each shard is a fresh index of the class for ``pile_index.metric``, and
    its data items are a list aligned with the shard's own positions.

    Parameters
    ----------
    pile_index : PileIndex
        Index to split.
    num_splits : int
        Number of splits to make.

    Yields
    ------
    PileIndex
        One shard per split, sharing the metric, normalization flag and
        embedding model of ``pile_index``.
    """

    make_index = get_index_class(pile_index.metric)
    source = pile_index.index
    items = pile_index.data_dict

    base_size, leftover = divmod(source.ntotal, num_splits)
    last_split = num_splits - 1
    for split in range(num_splits):
        start = split * base_size
        size = base_size + leftover if split == last_split else base_size

        embeddings = source.reconstruct_n(start, size)
        shard_items = [items[pos] for pos in range(start, start + size)]

        shard = make_index(source.d)
        shard.add(embeddings)
        yield PileIndex(
            shard,
            shard_items,
            pile_index.metric,
            pile_index.normalized,
            embedding_model=pile_index.embedding_model,
        )


def _nearest_neighbour_search(index, query, num_neighbours):
    """Exact k-NN retrieval from ``index``; under ABSIP also searches with ``-query``."""
    import time

    t_start = time.time()
    values, indices, vectors, data_items = index.vector_query(query, num_neighbours)
    data_items = list(data_items)
    if index.metric == Metric.ABSIP:
        # Also consider the best matches for -query; candidates from both
        # searches are merged by descending score below. Only concatenate
        # here, so float32 results are not upcast by float64 placeholders.
        values2, indices2, vectors2, data_items2 = index.vector_query(-query, num_neighbours)
        values = np.concatenate([values, values2])
        indices = np.concatenate([indices, indices2])
        vectors = np.concatenate([vectors, vectors2])
        data_items.extend(data_items2)
    if index.metric == Metric.L2:
        # Keep the index's own ordering (FAISS returns nearest first).
        order = np.arange(num_neighbours)
    else:
        # Largest score first; stable so ties keep the index's ordering.
        order = np.argsort(-values, kind="stable")[:num_neighbours]
    t_retrieval = time.time() - t_start
    return (
        values[order],
        indices[order],
        vectors[order],
        # Select texts in Python rather than through a numpy string array.
        [data_items[i] for i in order],
        RetrievalTime(faiss=t_retrieval, afsl=0),
    )


def _one_nearest_neighbour(index, query, num_neighbours):
    """Retrieve only the single nearest neighbour and repeat it ``num_neighbours`` times."""
    import time

    t_start = time.time()
    values, indices, vectors, data_items = index.vector_query(query, 1)
    t_retrieval = time.time() - t_start
    return (
        np.repeat(values, num_neighbours, axis=0),
        [indices[0]] * num_neighbours,
        np.repeat(vectors, num_neighbours, axis=0),
        [data_items[0]] * num_neighbours,
        RetrievalTime(faiss=t_retrieval, afsl=0),
    )


def _make_acquisition_function(acquisition_function_name, num_neighbours, noise, threads):
    """Instantiate the AFSL acquisition function registered under ``acquisition_function_name``.

    Raises
    ------
    NotImplementedError
        If the name is not a supported acquisition function.
    """
    import torch

    # Factories are lambdas so only the selected acquisition function is built.
    factories = {
        "Random-preselected": lambda: Random(
            mini_batch_size=num_neighbours, num_workers=threads
        ),
        "ITL": lambda: ITL(
            target=torch.Tensor(),
            num_workers=threads,
            subsample=False,
            force_nonsequential=False,
            noise_std=noise,
        ),
        "ITL-noiseless": lambda: ITLNoiseless(
            target=torch.Tensor(),
            num_workers=threads,
            subsample=False,
            force_nonsequential=False,
        ),
        "UncertaintySampling": lambda: UncertaintySampling(
            num_workers=threads,
            subsample=False,
            noise_std=noise,
        ),
        "CTL": lambda: CTL(
            target=torch.Tensor(),
            num_workers=threads,
            subsample=False,
            force_nonsequential=False,
            noise_std=noise,
        ),
        "VTL": lambda: VTL(
            target=torch.Tensor(),
            num_workers=threads,
            subsample=False,
            force_nonsequential=False,
            noise_std=noise,
        ),
        "LazyVTL": lambda: LazyVTL(
            target=torch.Tensor(),
            num_workers=threads,
            subsample=False,
            noise_std=noise,
        ),
    }
    try:
        factory = factories[acquisition_function_name]
    except KeyError:
        raise NotImplementedError(
            f"Unknown acquisition function: {acquisition_function_name!r}"
        ) from None
    return factory()


def get_neighbours(
    index,
    query,
    num_neighbours,
    acquisition_function_name,
    k,
    seed,
    noise,
):
    """Retrieve ``num_neighbours`` items for ``query`` with the named strategy.

    Parameters
    ----------
    index : PileIndex
        Index to retrieve from.
    query : np.ndarray
        Query embedding, passed to ``index.vector_query`` or
        ``Retriever.search`` (presumably shape ``(1, d)`` — confirm with caller).
    num_neighbours : int
        Number of items to return.
    acquisition_function_name : str
        ``"NearestNeighbour"`` (exact k-NN), ``"ONN"`` (the single nearest
        neighbour repeated ``num_neighbours`` times), or one of the AFSL
        acquisition functions ``"Random-preselected"``, ``"ITL"``,
        ``"ITL-noiseless"``, ``"UncertaintySampling"``, ``"CTL"``, ``"VTL"``,
        ``"LazyVTL"``.
    k : int
        Forwarded to ``Retriever.search`` as ``k``; unused by the
        NearestNeighbour and ONN strategies.
    seed : int
        Currently unused. NOTE(review): not forwarded to any acquisition
        function — confirm whether Random-preselected should be seeded.
    noise : float
        Forwarded as ``noise_std`` to acquisition functions that accept it.

    Returns
    -------
    tuple
        ``(values, indices, vectors, data_items, times)``. For NearestNeighbour
        and ONN, ``times`` is a ``RetrievalTime`` with ``afsl=0``; otherwise it
        is the fourth value returned by ``Retriever.search``.

    Raises
    ------
    NotImplementedError
        If ``acquisition_function_name`` is not recognised.
    """
    if acquisition_function_name == "NearestNeighbour":
        return _nearest_neighbour_search(index, query, num_neighbours)

    if acquisition_function_name == "ONN":
        return _one_nearest_neighbour(index, query, num_neighbours)

    threads = 1
    acquisition_function = _make_acquisition_function(
        acquisition_function_name, num_neighbours, noise, threads
    )

    itl = Retriever(
        index.index,
        acquisition_function,
        device=get_device(),
    )

    values, indices, vectors, times = itl.search(query, N=num_neighbours, k=k, threads=threads)
    data_items = [index.data_dict[i] for i in indices]

    return values, indices, vectors, data_items, times
