"""Topk Retriever"""

from openicl import DatasetReader
from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl.utils.logging import get_logger
import torch
from torch.utils.data import DataLoader
from typing import Optional
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import tqdm
import faiss
import copy
import numpy as np
from accelerate import Accelerator

# Module-level logger, created through the project's logging helper and keyed
# by this module's import name.
logger = get_logger(__name__)


class TopkRetriever(BaseRetriever):
    """Topk In-context Learning Retriever Class

    Embeds the index split with a :obj:`SentenceTransformer` model, stores the
    embeddings in a FAISS inner-product index, and for every test sample
    retrieves the ids of the ``ice_num`` nearest index samples.

    Args:
        dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
        ice_separator (:obj:`str`, optional): A string that separates each in-context example.
        ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
        prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
        sentence_transformers_model_name (:obj:`str`, optional): Name of the :obj:`SentenceTransformer` model used to embed texts. Defaults to ``all-mpnet-base-v2``.
        ice_num (:obj:`int`, optional): The number of data in the in-context examples.
        index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
        test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
        tokenizer_name (:obj:`str`, optional): Name passed to :obj:`AutoTokenizer.from_pretrained`. Defaults to ``gpt2-xl``.
        batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
        accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.

    Attributes:
        device (:obj:`str`): ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        batch_size (:obj:`int`): Batch size for the :obj:`DataLoader`.
        tokenizer_name (:obj:`str`): Name the tokenizer was loaded from.
        tokenizer (:obj:`AutoTokenizer`): Tokenizer used to encode/decode the corpus texts.
        encode_dataset (:obj:`DatasetEncoder`): Encoded test-split corpus.
        dataloader (:obj:`DataLoader`): Loader over :obj:`encode_dataset`.
        model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
        select_datalist: Input-field corpus generated from the index split.
        embed_list (:obj:`np.ndarray`): Stacked embeddings of the index split, one row per sample.
        index (:obj:`IndexIDMap`): Index generated with FAISS.
    """

    model = None

    def __init__(
        self,
        dataset_reader: DatasetReader,
        ice_separator: Optional[str] = "\n",
        ice_eos_token: Optional[str] = "\n",
        prompt_eos_token: Optional[str] = "",
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        ice_num: Optional[int] = 1,
        index_split: Optional[str] = "train",
        test_split: Optional[str] = "test",
        tokenizer_name: Optional[str] = "gpt2-xl",
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
    ) -> None:
        super().__init__(
            dataset_reader,
            ice_separator,
            ice_eos_token,
            prompt_eos_token,
            ice_num,
            index_split,
            test_split,
            accelerator,
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = batch_size
        self.tokenizer_name = tokenizer_name
        gen_datalist = self.dataset_reader.generate_input_field_corpus(self.test_ds)

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Mark the "asking to pad a fast tokenizer" notice as already emitted.
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
        # Reuse the EOS token for padding and pad on the right.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "right"

        self.encode_dataset, self.dataloader = self._build_dataloader(gen_datalist)

        self.model = SentenceTransformer(sentence_transformers_model_name)

        self.model = self.model.to(self.device)
        self.model.eval()

        self.index = self.create_index()

    def _build_dataloader(self, datalist):
        """Encode ``datalist`` and wrap it in a batched, padded :obj:`DataLoader`.

        Args:
            datalist: Corpus produced by
                ``dataset_reader.generate_input_field_corpus``.

        Returns:
            tuple: ``(encoded_dataset, dataloader)`` where ``encoded_dataset``
            is the :obj:`DatasetEncoder` built from ``datalist`` and
            ``dataloader`` iterates it with ``self.batch_size``.
        """
        encoded = DatasetEncoder(datalist, tokenizer=self.tokenizer)
        collator = DataCollatorWithPaddingAndCuda(
            tokenizer=self.tokenizer, device=self.device
        )
        loader = DataLoader(
            encoded, batch_size=self.batch_size, collate_fn=collator
        )
        return encoded, loader

    def create_index(self):
        """Embed the index split and store the embeddings in a FAISS index.

        Sets ``self.select_datalist`` and ``self.embed_list`` as side effects.

        Returns:
            :obj:`IndexIDMap`: Inner-product index whose entries are keyed by
            the ``metadata["id"]`` of each embedded sample.

        Raises:
            ValueError: If the index split yields no samples to embed.
        """
        self.select_datalist = self.dataset_reader.generate_input_field_corpus(
            self.index_ds
        )
        _, dataloader = self._build_dataloader(self.select_datalist)
        index = faiss.IndexIDMap(
            faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension())
        )
        res_list = self.forward(
            dataloader, process_bar=True, information="Creating index for index set..."
        )
        if not res_list:
            # np.stack would raise an opaque ValueError on an empty list;
            # raise the same exception type with an actionable message.
            raise ValueError(
                "Cannot create a retrieval index: the index dataset is empty."
            )
        # FAISS ids are 64-bit; NumPy's default integer width is platform
        # dependent, so request int64 explicitly.
        id_list = np.array(
            [res["metadata"]["id"] for res in res_list], dtype=np.int64
        )  # (train_sample_num,)
        self.embed_list = np.stack(
            [res["embed"] for res in res_list]
        )  # (train_sample_num, embedding_size)
        index.add_with_ids(self.embed_list, id_list)
        return index

    def knn_search(self, ice_num):
        """Retrieve the nearest index-sample ids for every test sample.

        Args:
            ice_num (int): Number of neighbours requested per test sample.
                Clamped to the number of vectors in the index, because FAISS
                pads missing neighbours with ``-1`` labels.

        Returns:
            list: One list per test sample (positioned by the sample's
            ``metadata["id"]``) holding the ids of its retrieved in-context
            examples. Every list is empty when no neighbour can be retrieved.
        """
        sample_num = len(self.dataloader.dataset)
        rtr_idx_list = [[] for _ in range(sample_num)]

        k = min(ice_num, self.index.ntotal)
        if 0 < ice_num and k < ice_num:
            logger.warning(
                "Requested %d in-context examples but the index only holds %d; "
                "retrieving %d per query.",
                ice_num,
                self.index.ntotal,
                k,
            )

        if k > 0:
            res_list = self.forward(
                self.dataloader, process_bar=True, information="Embedding test set..."
            )
            logger.info("Retrieving data for test set...")
            for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
                idx = entry["metadata"]["id"]
                # FAISS searches a batch of queries: add a leading batch axis.
                embed = np.expand_dims(entry["embed"], axis=0)
                # search() returns (distances, labels); keep the labels of the
                # single query row.
                near_ids = self.index.search(embed, k)[1][0].tolist()
                rtr_idx_list[idx] = near_ids

        return rtr_idx_list  # list of each test query's corresponding ICE's id's (not 'idx' value for each sample)

    def forward(self, dataloader, process_bar=False, information=""):
        """Embed every sample yielded by ``dataloader``.

        Each batch's ``input_ids`` are decoded back to text with
        ``self.tokenizer`` and the text is embedded with ``self.model``.

        Args:
            dataloader (:obj:`DataLoader`): Loader whose batches contain
                ``"input_ids"`` and ``"metadata"`` entries.
            process_bar (bool, optional): When true, log ``information`` and
                show a progress bar (on the main process only).
            information (str, optional): Message logged when ``process_bar``
                is true.

        Returns:
            list: One ``{"embed": ..., "metadata": ...}`` dict per sample, in
            iteration order.
        """
        res_list = []
        # Iterate a deep copy so the caller's dataloader object is not the one
        # being consumed here.
        # NOTE(review): presumably guards against the in-place batch mutation
        # below (``pop``) — confirm whether the copy is still required.
        _dataloader = copy.deepcopy(dataloader)
        if process_bar:
            logger.info(information)
            _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process)
        for entry in _dataloader:
            with torch.no_grad():
                metadata = entry.pop("metadata")
                raw_text = self.tokenizer.batch_decode(
                    entry["input_ids"], skip_special_tokens=True, verbose=False
                )
                res = self.model.encode(raw_text, show_progress_bar=False)
            res_list.extend(
                [{"embed": r, "metadata": m} for r, m in zip(res, metadata)]
            )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}}
        return res_list

    def retrieve(self):
        """Return the retrieved in-context example ids for every test sample.

        Returns:
            list: Result of :meth:`knn_search` called with ``self.ice_num``.
        """
        return self.knn_search(self.ice_num)
