import numpy as np
import os
from collections import Counter
import copy

from pe.histogram import Histogram
from pe.constant.data import CLEAN_HISTOGRAM_COLUMN_NAME
from pe.constant.data import LOOKAHEAD_EMBEDDING_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import HISTOGRAM_NEAREST_NEIGHBORS_VOTING_IDS_COLUMN_NAME
from pe.logging import execution_logger


class NearestNeighbors(Histogram):
    """Compute the nearest neighbors histogram. Each private sample will vote for their closest `num_nearest_neighbors`
    synthetic samples to construct the histogram. The l2 norm of the votes from each private sample is normalized to 1.
    """

    def __init__(
        self,
        embedding,
        mode,
        lookahead_degree,
        lookahead_log_folder=None,
        voting_details_log_folder=None,
        api=None,
        num_nearest_neighbors=1,
        backend="auto",
        vote_normalization_level="sample",
    ):
        """Constructor.

        :param embedding: The :py:class:`pe.embedding.Embedding` object to compute the embedding of samples
        :type embedding: :py:class:`pe.embedding.Embedding`
        :param mode: The distance metric to use for finding the nearest neighbors. It should be one of the following:
            "l2" (l2 distance), "cos_sim" (cosine similarity), "ip" (inner product). Not all backends support all
            modes
        :type mode: str
        :param lookahead_degree: The degree of lookahead to compute the embedding of synthetic samples. If it is 0, the
            original embedding is used. If it is greater than 0, the embedding of the synthetic samples is computed by
            averaging the embeddings of the synthetic samples generated by the variation API for `lookahead_degree`
            times
        :type lookahead_degree: int
        :param lookahead_log_folder: The folder to save the logs of the lookahead. If it is None, the logs are not
            saved. Defaults to None
        :type lookahead_log_folder: str, optional
        :param voting_details_log_folder: The folder to save the logs of the voting details. If it is None, the logs
            are not saved. Defaults to None
        :type voting_details_log_folder: str, optional
        :param api: The API to generate synthetic samples. It should be provided when `lookahead_degree` is greater
            than 0. Defaults to None
        :type api: :py:class:`pe.api.API`, optional
        :param num_nearest_neighbors: The number of nearest neighbors to consider for each private sample, defaults to
            1
        :type num_nearest_neighbors: int, optional
        :param backend: The backend to use for finding the nearest neighbors. It should be one of the following:
            "faiss" (FAISS), "sklearn" (scikit-learn), "auto" (using FAISS if available, otherwise scikit-learn).
            Defaults to "auto". FAISS supports GPU and is much faster when the number of synthetic samples and/or
            private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
            https://faiss.ai/
        :type backend: str, optional
        :param vote_normalization_level: The level of normalization for the votes. It should be one of the following:
            "sample" (normalize the votes from each private sample to have l2 norm = 1), "client" (normalize the votes
            from all private samples of the same client to have l2 norm = 1). Defaults to "sample"
        :type vote_normalization_level: str, optional
        :raises ValueError: If the `api` is not provided when `lookahead_degree` is greater than 0
        :raises ValueError: If the `backend` is unknown
        """
        super().__init__()
        self._embedding = embedding
        self._mode = mode
        self._lookahead_degree = lookahead_degree
        self._lookahead_log_folder = lookahead_log_folder
        self._voting_details_log_folder = voting_details_log_folder
        self._api = api
        self._num_nearest_neighbors = num_nearest_neighbors
        if self._lookahead_degree > 0 and self._api is None:
            raise ValueError("API should be provided when lookahead_degree is greater than 0")
        if backend.lower() == "faiss":
            from pe.histogram.nearest_neighbor_backend.faiss import search

            self._search = search
        elif backend.lower() == "sklearn":
            from pe.histogram.nearest_neighbor_backend.sklearn import search

            self._search = search
        elif backend.lower() == "auto":
            from pe.histogram.nearest_neighbor_backend.auto import search

            self._search = search
        else:
            raise ValueError(f"Unknown backend: {backend}")

        self._vote_normalization_level = vote_normalization_level

    def _log_lookahead(self, syn_data, lookahead_id):
        """Log the lookahead data.

        :param syn_data: The lookahead data
        :type syn_data: :py:class:`pe.data.Data`
        :param lookahead_id: The ID of the lookahead
        :type lookahead_id: int
        """
        if self._lookahead_log_folder is None:
            return
        labels = set(list(syn_data.data_frame[LABEL_ID_COLUMN_NAME].values))
        assert len(labels) == 1
        label = list(labels)[0]
        iteration = syn_data.metadata["iteration"]
        log_folder = os.path.join(
            self._lookahead_log_folder, f"{iteration}", f"label-id{label}_lookahead{lookahead_id}"
        )
        syn_data.save_checkpoint(log_folder)

    def _log_voting_details(self, priv_data, syn_data, ids):
        """Log the voting details.

        :param priv_data: The private data
        :type priv_data: :py:class:`pe.data.Data`
        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :param ids: The IDs of the nearest neighbors for each private sample
        :type ids: np.ndarray
        """
        if self._voting_details_log_folder is None:
            return
        labels = set(list(priv_data.data_frame[LABEL_ID_COLUMN_NAME].values))
        assert len(labels) == 1
        label = list(labels)[0]
        iteration = syn_data.metadata["iteration"]
        log_folder = os.path.join(self._voting_details_log_folder, f"{iteration}", f"label-id{label}")
        priv_data = copy.deepcopy(priv_data)
        priv_data.data_frame[HISTOGRAM_NEAREST_NEIGHBORS_VOTING_IDS_COLUMN_NAME] = list(ids)
        priv_data.save_checkpoint(log_folder)

    def _compute_lookahead_embedding(self, syn_data):
        """Compute the embedding of synthetic samples with lookahead.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :return: The synthetic data with the computed embedding in the column
            :py:const:`pe.constant.data.LOOKAHEAD_EMBEDDING_COLUMN_NAME`
        :rtype: :py:class:`pe.data.Data`
        """
        if self._lookahead_degree == 0:
            syn_data = self._embedding.compute_embedding(syn_data)
            syn_data.data_frame[LOOKAHEAD_EMBEDDING_COLUMN_NAME] = syn_data.data_frame[self._embedding.column_name]
        else:
            embedding_list = []
            for lookahead_id in range(self._lookahead_degree):
                variation_data = self._api.variation_api(syn_data=syn_data)
                variation_data = self._embedding.compute_embedding(variation_data)
                self._log_lookahead(syn_data=variation_data, lookahead_id=lookahead_id)
                embedding_list.append(
                    np.stack(
                        variation_data.data_frame[self._embedding.column_name].values,
                        axis=0,
                    )
                )
            embedding = np.mean(embedding_list, axis=0)
            syn_data.data_frame[LOOKAHEAD_EMBEDDING_COLUMN_NAME] = list(embedding)
        self._log_lookahead(syn_data=syn_data, lookahead_id=-1)

        return syn_data

    def compute_histogram(self, priv_data, syn_data):
        """Compute the nearest neighbors histogram.

        :param priv_data: The private data
        :type priv_data: :py:class:`pe.data.Data`
        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :raises ValueError: If the `vote_normalization_level` is unknown
        :return: The private data, possibly with the additional embedding column, and the synthetic data, with the
            computed histogram in the column :py:const:`pe.constant.data.CLEAN_HISTOGRAM_COLUMN_NAME` and possibly with
            the additional embedding column
        :rtype: tuple[:py:class:`pe.data.Data`, :py:class:`pe.data.Data`]
        """
        execution_logger.info(
            f"Histogram: computing nearest neighbors histogram for {len(priv_data.data_frame)} private "
            f"samples and {len(syn_data.data_frame)} synthetic samples"
        )

        priv_data = self._embedding.compute_embedding(priv_data)
        syn_data = self._compute_lookahead_embedding(syn_data)

        priv_embedding = np.stack(priv_data.data_frame[self._embedding.column_name].values, axis=0).astype(np.float32)
        syn_embedding = np.stack(syn_data.data_frame[LOOKAHEAD_EMBEDDING_COLUMN_NAME].values, axis=0).astype(
            np.float32
        )

        _, ids = self._search(
            syn_embedding=syn_embedding,
            priv_embedding=priv_embedding,
            num_nearest_neighbors=self._num_nearest_neighbors,
            mode=self._mode,
        )
        self._log_voting_details(priv_data=priv_data, syn_data=syn_data, ids=ids)

        priv_data = priv_data.reset_index(drop=True)
        if self._vote_normalization_level == "client":
            priv_data_list = priv_data.split_by_client()
        elif self._vote_normalization_level == "sample":
            priv_data_list = priv_data.split_by_index()
        else:
            raise ValueError(f"Unknown vote normalization level: {self._vote_normalization_level}")

        count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
        for sub_priv_data in priv_data_list:
            sub_count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
            sub_ids = ids[sub_priv_data.data_frame.index]
            counter = Counter(list(sub_ids.flatten()))
            sub_count[list(counter.keys())] = list(counter.values())
            sub_count /= np.linalg.norm(sub_count)
            count += sub_count

        syn_data.data_frame[CLEAN_HISTOGRAM_COLUMN_NAME] = count

        execution_logger.info(
            f"Histogram: finished computing nearest neighbors histogram for {len(priv_data.data_frame)} private "
            f"samples and {len(syn_data.data_frame)} synthetic samples"
        )

        return priv_data, syn_data
