# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import faiss  # @manual=//faiss/python:pyfaiss

from .benchmark_io import BenchmarkIO
from .utils import timer

logger = logging.getLogger(__name__)


# Important: filenames end with . without extension (npy, codec, index),
# when writing files, you are required to filename + "npy" etc.

@dataclass
class IndexDescriptorClassic:
    bucket: Optional[str] = None
    # either path or factory should be set,
    # but not both at the same time.
    path: Optional[str] = None
    factory: Optional[str] = None
    codec_alias: Optional[str] = None
    construction_params: Optional[List[Dict[str, int]]] = None
    search_params: Optional[Dict[str, int]] = None
    # range metric definitions
    # key: name
    # value: one of the following:
    #
    # radius
    #    [0..radius) -> 1
    #    [radius..inf) -> 0
    #
    # [[radius1, score1], ...]
    #    [0..radius1) -> score1
    #    [radius1..radius2) -> score2
    #
    # [[radius1_from, radius1_to, score1], ...]
    #    [radius1_from, radius1_to) -> score1,
    #    [radius2_from, radius2_to) -> score2
    range_metrics: Optional[Dict[str, Any]] = None
    radius: Optional[float] = None
    training_size: Optional[int] = None

    def __hash__(self):
        return hash(str(self))

@dataclass
class DatasetDescriptor:
    # namespace possible values:
    # 1. a hive namespace
    # 2. 'std_t', 'std_d', 'std_q' for the standard datasets
    #    via faiss.contrib.datasets.dataset_from_name()
    #    t - training, d - database, q - queries
    #    eg. "std_t"
    # 3. 'syn' for synthetic data
    # 4. None for local files
    namespace: Optional[str] = None

    # tablename possible values, corresponding to the
    # namespace value above:
    # 1. a hive table name
    # 2. name of the standard dataset as recognized
    #    by faiss.contrib.datasets.dataset_from_name()
    #    eg. "bigann1M"
    # 3. d_seed, eg. 128_1234 for 128 dimensional vectors
    #    with seed 1234
    # 4. a local file name (relative to benchmark_io.path)
    tablename: Optional[str] = None

    # partition names and values for hive
    # eg. ["ds=2021-09-01"]
    partitions: Optional[List[str]] = None

    # number of vectors to load from the dataset
    num_vectors: Optional[int] = None

    embedding_column: Optional[str] = None

    # only when the embedding column is a map
    embedding_column_key: Optional[Any] = None

    embedding_id_column: Optional[str] = None

    # only used when previous_assignment_table is set
    # this represents the centroid id that the embedding was mapped to
    # in a previous clustering job
    centroid_id_column: Optional[str] = None

    # filters on the dataset where each filter is a
    # string rep of a filter expression
    filters: Optional[List[str]] = None

    # unused in open-source
    splits_distribution: Optional[List[List[bytes]]] = None

    # unused in open-source
    splits: Optional[List[bytes]] = None

    # unused in open-source
    serialized_df: Optional[str] = None

    sampling_rate: Optional[float] = None

    # sampling column for xdb
    sampling_column: Optional[str] = None

    # blob store
    bucket: Optional[str] = None
    path: Optional[str] = None

    # desc_name
    desc_name: Optional[str] = None

    filename_suffix: Optional[str] = None

    normalize_L2: bool = False

    def __hash__(self):
        return hash(self.get_filename())

    def get_filename(
        self,
        prefix: Optional[str] = None,
    ) -> str:
        if self.desc_name is not None:
            return self.desc_name

        filename = ""
        if prefix is not None:
            filename += prefix + "_"
        if self.namespace is not None:
            filename += self.namespace + "_"
        assert self.tablename is not None
        filename += self.tablename
        if self.partitions is not None:
            filename += "_" + "_".join(
                self.partitions
            ).replace("=", "_").replace("/", "_")
        if self.num_vectors is not None:
            filename += f"_{self.num_vectors}"
        if self.filename_suffix is not None:
            filename += f"_{self.filename_suffix}"
        filename += "."

        self.desc_name = filename
        return self.desc_name

    def get_kmeans_filename(self, k):
        return f"{self.get_filename()}kmeans_{k}."

    def k_means(self, io, k, dry_run):
        logger.info(f"k_means {k} {self}")
        kmeans_vectors = DatasetDescriptor(
            tablename=f"{self.get_filename()}kmeans_{k}"
        )
        kmeans_filename = kmeans_vectors.get_filename() + "npy"
        meta_filename = kmeans_vectors.get_filename() + "json"
        if not io.file_exist(kmeans_filename) or not io.file_exist(
            meta_filename
        ):
            if dry_run:
                return None, None, kmeans_filename
            x = io.get_dataset(self)
            kmeans = faiss.Kmeans(d=x.shape[1], k=k, gpu=True)
            _, t, _ = timer("k_means", lambda: kmeans.train(x))
            io.write_nparray(kmeans.centroids, kmeans_filename)
            io.write_json({"k_means_time": t}, meta_filename)
        else:
            t = io.read_json(meta_filename)["k_means_time"]
        return kmeans_vectors, t, None

@dataclass
class IndexBaseDescriptor:
    d: int
    metric: str
    desc_name: Optional[str] = None
    flat_desc_name: Optional[str] = None
    bucket: Optional[str] = None
    path: Optional[str] = None
    num_threads: int = 1

    def get_name(self) -> str:
        raise NotImplementedError()

    def get_path(self, benchmark_io: BenchmarkIO) -> Optional[str]:
        if self.path is not None:
            return self.path
        self.path = benchmark_io.get_remote_filepath(self.desc_name)
        return self.path

    @staticmethod
    def param_dict_list_to_name(param_dict_list):
        if not param_dict_list:
            return ""
        l = 0
        n = ""
        for param_dict in param_dict_list:
            n += IndexBaseDescriptor.param_dict_to_name(param_dict, f"cp{l}")
            l += 1
        return n

    @staticmethod
    def param_dict_to_name(param_dict, prefix="sp"):
        if not param_dict:
            return ""
        n = prefix
        for name, val in param_dict.items():
            if name == "snap":
                continue
            if name == "lsq_gpu" and val == 0:
                continue
            if name == "use_beam_LUT" and val == 0:
                continue
            n += f"_{name}_{val}"
        if n == prefix:
            return ""
        n += "."
        return n


@dataclass
class CodecDescriptor(IndexBaseDescriptor):
    # either path or factory should be set,
    # but not both at the same time.
    factory: Optional[str] = None
    construction_params: Optional[List[Dict[str, int]]] = None
    training_vectors: Optional[DatasetDescriptor] = None
    normalize_l2: bool = False
    is_spherical: bool = False
    FILENAME_PREFIX: str = "xt"

    def __post_init__(self):
        self.get_name()

    def is_trained(self):
        return self.factory is None and self.path is not None

    def is_valid(self):
        return self.factory is not None or self.path is not None

    def get_name(self) -> str:
        if self.desc_name is not None:
            return self.desc_name
        if self.factory is not None:
            self.desc_name = self.name_from_factory()
            return self.desc_name
        if self.path is not None:
            self.desc_name = self.name_from_path()
            return self.desc_name
        raise ValueError("name, factory or path must be set")

    def flat_name(self) -> str:
        if self.flat_desc_name is not None:
            return self.flat_desc_name
        self.flat_desc_name = f"Flat.d_{self.d}.{self.metric.upper()}."
        return self.flat_desc_name

    def path(self, benchmark_io) -> str:
        if self.path is not None:
            return self.path
        return benchmark_io.get_remote_filepath(self.get_name())

    def name_from_factory(self) -> str:
        assert self.factory is not None
        name = f"{self.factory.replace(',', '_')}."
        assert self.d is not None
        assert self.metric is not None
        name += f"d_{self.d}.{self.metric.upper()}."
        if self.factory != "Flat":
            assert self.training_vectors is not None
            name += self.training_vectors.get_filename(CodecDescriptor.FILENAME_PREFIX)
        name += IndexBaseDescriptor.param_dict_list_to_name(self.construction_params)
        return name

    def name_from_path(self):
        assert self.path is not None
        filename = os.path.basename(self.path)
        ext = filename.split(".")[-1]
        if filename.endswith(ext):
            name = filename[:-len(ext)]
        else: # should never hit this rather raise value error
            name = filename
        return name

    def alias(self, benchmark_io: BenchmarkIO):
        if hasattr(benchmark_io, "bucket"):
            return CodecDescriptor(desc_name=self.get_name(), bucket=benchmark_io.bucket, path=self.get_path(benchmark_io), d=self.d, metric=self.metric)
        return CodecDescriptor(desc_name=self.get_name(), d=self.d, metric=self.metric)


@dataclass
class IndexDescriptor(IndexBaseDescriptor):
    codec_desc: Optional[CodecDescriptor] = None
    database_desc: Optional[DatasetDescriptor] = None
    FILENAME_PREFIX: str = "xb"

    def __hash__(self):
        return hash(str(self))

    def __post_init__(self):
        self.get_name()

    def is_built(self):
        return self.codec_desc is None and self.database_desc is None

    def get_name(self) -> str:
        if self.desc_name is None:
            self.desc_name = self.codec_desc.get_name() + self.database_desc.get_filename(prefix=IndexDescriptor.FILENAME_PREFIX)

        return self.desc_name

    def flat_name(self):
        if self.flat_desc_name is not None:
            return self.flat_desc_name
        self.flat_desc_name = self.codec_desc.flat_name() + self.database_desc.get_filename(prefix=IndexDescriptor.FILENAME_PREFIX)
        return self.flat_desc_name

    # alias is used to refer when index is uploaded to blobstore and refered again
    def alias(self, benchmark_io: BenchmarkIO):
        if hasattr(benchmark_io, "bucket"):
            return IndexDescriptor(desc_name=self.get_name(), bucket=benchmark_io.bucket, path=self.get_path(benchmark_io), d=self.d, metric=self.metric)
        return IndexDescriptor(desc_name=self.get_name(), d=self.d, metric=self.metric)

@dataclass
class KnnDescriptor(IndexBaseDescriptor):
    index_desc: Optional[IndexDescriptor] = None
    gt_index_desc: Optional[IndexDescriptor] = None
    query_dataset: Optional[DatasetDescriptor] = None
    search_params: Optional[Dict[str, int]] = None
    reconstruct: bool = False
    FILENAME_PREFIX: str = "q"
    # range metric definitions
    # key: name
    # value: one of the following:
    #
    # radius
    #    [0..radius) -> 1
    #    [radius..inf) -> 0
    #
    # [[radius1, score1], ...]
    #    [0..radius1) -> score1
    #    [radius1..radius2) -> score2
    #
    # [[radius1_from, radius1_to, score1], ...]
    #    [radius1_from, radius1_to) -> score1,
    #    [radius2_from, radius2_to) -> score2
    range_metrics: Optional[Dict[str, Any]] = None
    radius: Optional[float] = None
    k: int = 1

    range_ref_index_desc: Optional[str] = None

    def __hash__(self):
        return hash(str(self))

    def get_name(self):
        if self.desc_name is not None:
            return self.desc_name
        name = self.index_desc.get_name()
        name += IndexBaseDescriptor.param_dict_to_name(self.search_params)
        name += self.query_dataset.get_filename(KnnDescriptor.FILENAME_PREFIX)
        name += f"k_{self.k}."
        name += f"t_{self.num_threads}."
        if self.reconstruct:
            name += "rec."
        else:
            name += "knn."
        self.desc_name = name
        return name

    def flat_name(self):
        if self.flat_desc_name is not None:
            return self.flat_desc_name
        name = self.index_desc.flat_name()
        name += self.query_dataset.get_filename(KnnDescriptor.FILENAME_PREFIX)
        name += f"k_{self.k}."
        name += f"t_{self.num_threads}."
        if self.reconstruct:
            name += "rec."
        else:
            name += "knn."
        self.flat_desc_name = name
        return name
