"""Topk FedRetriever"""

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl import DatasetReader, PromptTemplate
from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl.utils.logging import get_logger

import torch
from torch.utils.data import DataLoader

from typing import List, Union, Optional, Tuple, Dict, Any
import faiss
import copy
import numpy as np
import tqdm


logger = get_logger(__name__)


class TopkFedRetriever(TopkRetriever):
    """Top-k (nearest-neighbour) in-context-example retriever for federated use.

    The sentence-embedding model is not stored on the retriever; it is passed
    to every retrieval call (``retrieve``, ``retrieve_with_budget``,
    ``create_index``). The FAISS index over ``self.index_ds`` is built lazily
    on the first retrieval and cached in ``self.index`` afterwards.
    """

    def __init__(
        self,
        dataset_reader: DatasetReader,
        ice_separator: Optional[str] = "\n",
        ice_eos_token: Optional[str] = "\n",
        prompt_eos_token: Optional[str] = "",
        ice_num: Optional[int] = 1,
        index_split: Optional[str] = "train",
        test_split: Optional[str] = "test",
        tokenizer_name: Optional[str] = "gpt2-xl",
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
    ) -> None:
        # BaseRetriever.__init__ is called directly, bypassing
        # TopkRetriever.__init__: the embedding model is supplied per call
        # instead of being set up here.
        BaseRetriever.__init__(
            self,
            dataset_reader,
            ice_separator,
            ice_eos_token,
            prompt_eos_token,
            ice_num,
            index_split,
            test_split,
            accelerator,
        )
        self.index = None  # FAISS index; built lazily by create_index()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = batch_size
        self.tokenizer_name = tokenizer_name

        # The tokenizer is only used to batch/pad text: forward() decodes the
        # token ids back to text before handing them to the embedding model.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Pre-set this flag so the "Asking-to-pad-a-fast-tokenizer" warning is
        # (presumably) treated as already emitted.
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
        # Reuse EOS as the pad token (the default gpt2-xl tokenizer presumably
        # has none).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "right"

    def _search_ids(self, embed, k):
        """Return up to ``k`` ids from ``self.index`` nearest to ``embed``.

        ``embed`` is a single query with a leading batch axis of size 1.
        A non-positive ``k`` returns ``[]``. FAISS pads its result with ``-1``
        when fewer than ``k`` vectors are indexed; those placeholders are
        dropped so callers never get an id that would silently index the end
        of a Python list.
        """
        if k <= 0:
            return []
        query = np.ascontiguousarray(embed, dtype=np.float32)
        near_ids = self.index.search(query, k)[1][0].tolist()
        return [i for i in near_ids if i != -1]

    def knn_search(self, ice_num, query_dataloader, orig_idxs, model):
        """Embed every query and retrieve its ``ice_num`` nearest index entries.

        Returns:
            A list where position ``i`` holds the retrieved ids for the query
            whose ``metadata["id"]`` is ``i``. The ids are the index set's
            ``metadata["id"]`` values (see ``create_index``), not the ``idx``
            column. Assumes query ids are ``0..len-1`` — TODO confirm against
            DatasetEncoder.
        """
        res_list = self.forward(
            model,
            query_dataloader,
            orig_idxs=orig_idxs,
            process_bar=True,
            information="Embedding test set...",
        )
        rtr_idx_list = [[] for _ in range(len(res_list))]
        logger.info("Retrieving data for test set...")
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry["metadata"]["id"]
            embed = np.expand_dims(entry["embed"], axis=0)
            rtr_idx_list[idx] = self._search_ids(embed, ice_num)
        return rtr_idx_list

    def knn_search_with_budget(
        self, res_list: List[Any], ice_budgets: List[int], model=None
    ) -> List[List[int]]:
        """Retrieve a per-query number of nearest index entries.

        Args:
            res_list: already-embedded queries, as produced by ``forward``.
            ice_budgets: ``ice_budgets[i]`` is how many examples to retrieve
                for the query whose ``metadata["id"]`` is ``i``. Budgets
                ``<= 0`` yield an empty list.
            model: unused; kept for interface compatibility.

        Returns:
            A list where position ``i`` holds the retrieved index-set ids
            (``metadata["id"]`` values, not the ``idx`` column) for query ``i``.
        """
        rtr_idx_list = [[] for _ in range(len(res_list))]
        logger.info("Retrieving data for test set given local budget...")
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry["metadata"]["id"]
            embed = np.expand_dims(entry["embed"], axis=0)
            rtr_idx_list[idx] = self._search_ids(embed, ice_budgets[idx])
        return rtr_idx_list

    def retrieve_with_budget(
        self, res_list: List[Any], ice_budgets: List[int], model=None
    ) -> List[List]:
        """Build the index if needed, then run ``knn_search_with_budget``.

        ``model`` is required: it is switched to eval mode and, on the first
        call, used to embed the index set.
        """
        model.eval()
        if self.index is None:
            self.index = self.create_index(model)

        return self.knn_search_with_budget(
            res_list, ice_budgets=ice_budgets, model=model
        )

    def create_dataloader(self, datalist):
        """Wrap ``datalist`` in an unshuffled DataLoader of ``self.batch_size``.

        Batches are padded by ``DataCollatorWithPaddingAndCuda`` (given
        ``self.device``). Order is preserved (no shuffling), which
        ``forward`` relies on to align batches with ``orig_idxs``.
        """
        encode_dataset = DatasetEncoder(datalist, tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(
            tokenizer=self.tokenizer, device=self.device
        )
        return DataLoader(encode_dataset, batch_size=self.batch_size, collate_fn=co)

    def retrieve(
        self,
        query_dataset: Union[Dataset, DatasetDict],
        split: Optional[str] = None,
        ice_num: Optional[int] = None,
        model=None,
    ) -> List[List]:
        """Retrieve ``ice_num`` (default ``self.ice_num``) examples per query.

        Args:
            query_dataset: the queries. When it is a ``DatasetDict``, ``split``
                selects the split to use. The selected dataset must have an
                ``idx`` column.
            split: split name, required for a ``DatasetDict``; ignored for a
                plain ``Dataset``.
            ice_num: number of examples per query.
            model: embedding model (required); the index is built with it on
                the first call and then reused.

        Raises:
            ValueError: ``query_dataset`` is a ``DatasetDict`` and ``split`` is
                None.
        """
        if isinstance(query_dataset, DatasetDict):
            if split is None:
                raise ValueError(
                    "`split` must be given when `query_dataset` is a DatasetDict"
                )
            query_dataset = query_dataset[split]

        # Build (once) the index of index-set embedding vectors.
        model.eval()
        if self.index is None:
            self.index = self.create_index(model)

        query_datalist = self.dataset_reader.generate_input_field_corpus(query_dataset)
        query_dataloader = self.create_dataloader(query_datalist)

        if ice_num is None:
            ice_num = self.ice_num

        orig_idxs = query_dataset["idx"]

        return self.knn_search(
            ice_num, query_dataloader, orig_idxs=orig_idxs, model=model
        )

    def create_index(self, model=None):
        """Embed ``self.index_ds`` with ``model`` and return a FAISS index.

        The index is an inner-product ``IndexFlatIP`` wrapped in an
        ``IndexIDMap``, keyed by each entry's ``metadata["id"]``. No
        normalisation is applied here, so scores are raw inner products unless
        the model itself normalises. Side effects: sets
        ``self.select_datalist`` and ``self.embed_list``.

        Raises:
            ValueError: the index set is empty.
        """
        orig_idxs = self.index_ds["idx"]
        self.select_datalist = self.dataset_reader.generate_input_field_corpus(
            self.index_ds
        )
        dataloader = self.create_dataloader(self.select_datalist)
        index = faiss.IndexIDMap(
            faiss.IndexFlatIP(model.get_sentence_embedding_dimension())
        )
        res_list = self.forward(
            model,
            dataloader,
            orig_idxs=orig_idxs,
            process_bar=True,
            information="Creating index for index set...",
        )
        if not res_list:
            raise ValueError("index set is empty; cannot build a FAISS index")
        # FAISS requires int64 ids and float32 vectors; be explicit, because
        # np.array of Python ints is int32 on some platforms.
        id_list = np.array(
            [res["metadata"]["id"] for res in res_list], dtype=np.int64
        )  # (train_sample_num,)
        self.embed_list = np.stack([res["embed"] for res in res_list]).astype(
            np.float32, copy=False
        )  # (train_sample_num, embedding_size)
        index.add_with_ids(self.embed_list, id_list)
        return index

    def forward(
        self, model, dataloader, orig_idxs=None, process_bar=False, information=""
    ):
        """Embed every batch of ``dataloader`` with ``model.encode``.

        Token ids are decoded back to text (special tokens skipped) and
        re-encoded by the sentence model.

        Args:
            orig_idxs: per-sample ``idx`` values aligned with the dataloader's
                (unshuffled) order. When None, each record's ``"idx"`` is None.

        Returns:
            A list of ``{"embed": ..., "metadata": ..., "idx": ...}`` dicts, one
            per sample, in dataloader order.
        """
        res_list = []
        # NOTE(review): the deep copy is kept from the original; iteration does
        # not appear to mutate the loader, so it may be unnecessary — confirm
        # before removing.
        _dataloader = copy.deepcopy(dataloader)
        batch_size = dataloader.batch_size
        if process_bar:
            logger.info(information)
            _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process)
        for batch_idx, entry in enumerate(_dataloader):
            cur_bs = len(entry["metadata"])
            if orig_idxs is None:
                batch_orig_idxs = [None] * cur_bs
            else:
                start = batch_idx * batch_size
                batch_orig_idxs = orig_idxs[start : start + cur_bs]
            with torch.no_grad():
                metadata = entry.pop("metadata")
                raw_text = self.tokenizer.batch_decode(
                    entry["input_ids"], skip_special_tokens=True, verbose=False
                )
                res = model.encode(raw_text, show_progress_bar=False)
            res_list.extend(
                {"embed": r, "metadata": m, "idx": orig_id}
                for r, m, orig_id in zip(res, metadata, batch_orig_idxs)
            )
        return res_list

    def generate_label_prompt(
        self,
        idx: Optional[int] = None,
        query: Optional[Dict] = None,
        ice: Optional[str] = None,
        label=None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        remain_sep: Optional[bool] = False,
    ) -> str:
        """Build the prompt for one query completed with ``label``.

        The query is either passed directly (``query``, a dict) or looked up as
        ``self.test_ds[idx]``; an explicit ``query`` takes precedence.

        Formatting precedence:
        1. ``prompt_template``, if given;
        2. ``ice_template``, if it defines an ``ice_token``;
        3. a plain fallback, ``ice + <input columns joined by " "> + " " +
           str(label)``, where a ``None`` ice counts as ``""``.

        ``self.prompt_eos_token`` is appended in every case.
        """
        assert (idx is not None) or (query is not None), (
            "either `idx` or `query` must be given"
        )
        if idx is None:
            assert isinstance(query, dict), (
                "`query` must be a dict when `idx` is not given"
            )

        if query is None:
            assert self.test_ds is not None, "`self.test_ds` is needed to look up `idx`"
            query = self.test_ds[idx]

        if prompt_template is not None:
            return (
                prompt_template.generate_label_prompt_item(
                    query, ice, label, remain_sep
                )
                + self.prompt_eos_token
            )
        if ice_template is not None and ice_template.ice_token is not None:
            return (
                ice_template.generate_label_prompt_item(query, ice, label, remain_sep)
                + self.prompt_eos_token
            )
        prefix_prompt = " ".join(
            str(query[ctx]) for ctx in self.dataset_reader.input_columns
        )
        ice_text = "" if ice is None else ice
        return ice_text + prefix_prompt + " " + str(label) + self.prompt_eos_token


class TopkServerRetriever(BaseRetriever):
    """Top-k (nearest-neighbour) retriever whose queries are ``self.test_ds``.

    The query dataloader over ``self.test_ds`` is built once in ``__init__``.
    The FAISS index over ``self.index_ds`` is rebuilt on every ``retrieve``
    call, so each call uses the embedding model passed to it.
    """

    def __init__(
        self,
        dataset_reader: DatasetReader,
        ice_separator: Optional[str] = "\n",
        ice_eos_token: Optional[str] = "\n",
        prompt_eos_token: Optional[str] = "",
        ice_num: Optional[int] = 1,
        index_split: Optional[str] = "train",
        test_split: Optional[str] = "test",
        tokenizer_name: Optional[str] = "gpt2-xl",
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
    ) -> None:
        BaseRetriever.__init__(
            self,
            dataset_reader,
            ice_separator,
            ice_eos_token,
            prompt_eos_token,
            ice_num,
            index_split,
            test_split,
            accelerator,
        )
        self.index = None  # FAISS index; (re)built by every retrieve() call
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = batch_size
        self.tokenizer_name = tokenizer_name

        # The tokenizer only batches/pads text; forward() decodes the ids back
        # to text before the embedding model sees them.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Pre-set this flag so the "Asking-to-pad-a-fast-tokenizer" warning is
        # (presumably) treated as already emitted.
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
        # Reuse EOS as the pad token (the default gpt2-xl tokenizer presumably
        # has none).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "right"

        # Queries are always the test split, so their loader and original
        # "idx" values are prepared once.
        gen_datalist = self.dataset_reader.generate_input_field_corpus(self.test_ds)
        self.dataloader = self._create_dataloader(gen_datalist)
        self.encode_dataset = self.dataloader.dataset
        self.orig_idxs = self.test_ds["idx"]

    def _create_dataloader(self, datalist):
        """Wrap ``datalist`` in an unshuffled DataLoader of ``self.batch_size``.

        Batches are padded by ``DataCollatorWithPaddingAndCuda`` (given
        ``self.device``). Order is preserved, which ``forward`` relies on to
        align batches with ``orig_idxs``.
        """
        encode_dataset = DatasetEncoder(datalist, tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(
            tokenizer=self.tokenizer, device=self.device
        )
        return DataLoader(encode_dataset, batch_size=self.batch_size, collate_fn=co)

    def _search_ids(self, embed, k):
        """Return up to ``k`` ids from ``self.index`` nearest to ``embed``.

        ``embed`` is a single query with a leading batch axis of size 1.
        A non-positive ``k`` returns ``[]``. FAISS pads its result with ``-1``
        when fewer than ``k`` vectors are indexed; those placeholders are
        dropped.
        """
        if k <= 0:
            return []
        query = np.ascontiguousarray(embed, dtype=np.float32)
        near_ids = self.index.search(query, k)[1][0].tolist()
        return [i for i in near_ids if i != -1]

    def create_index(self, model=None):
        """Embed ``self.index_ds`` with ``model`` and return a FAISS index.

        The index is an inner-product ``IndexFlatIP`` wrapped in an
        ``IndexIDMap``, keyed by each entry's ``metadata["id"]``. No
        normalisation is applied here. Side effects: sets
        ``self.select_datalist`` and ``self.embed_list``.

        Raises:
            ValueError: the index set is empty.
        """
        self.select_datalist = self.dataset_reader.generate_input_field_corpus(
            self.index_ds
        )
        orig_idxs = self.index_ds["idx"]
        dataloader = self._create_dataloader(self.select_datalist)
        index = faiss.IndexIDMap(
            faiss.IndexFlatIP(model.get_sentence_embedding_dimension())
        )
        res_list = self.forward(
            model,
            dataloader,
            orig_idxs=orig_idxs,
            process_bar=False,
            information="Creating index for index set...",
        )
        if not res_list:
            raise ValueError("index set is empty; cannot build a FAISS index")
        # FAISS requires int64 ids and float32 vectors; be explicit, because
        # np.array of Python ints is int32 on some platforms.
        id_list = np.array(
            [res["metadata"]["id"] for res in res_list], dtype=np.int64
        )  # (train_sample_num,)
        self.embed_list = np.stack([res["embed"] for res in res_list]).astype(
            np.float32, copy=False
        )  # (train_sample_num, embedding_size)
        index.add_with_ids(self.embed_list, id_list)
        return index

    def forward(
        self, model, dataloader, orig_idxs=None, process_bar=False, information=""
    ):
        """Embed every batch of ``dataloader`` with ``model.encode``.

        Token ids are decoded back to text (special tokens skipped) and
        re-encoded by the sentence model.

        Args:
            orig_idxs: per-sample ``idx`` values aligned with the dataloader's
                (unshuffled) order. When None, each record's ``"idx"`` is None.

        Returns:
            A list of ``{"embed": ..., "metadata": ..., "idx": ...}`` dicts, one
            per sample, in dataloader order.
        """
        res_list = []
        # NOTE(review): the deep copy is kept from the original; iteration does
        # not appear to mutate the loader, so it may be unnecessary — confirm
        # before removing.
        _dataloader = copy.deepcopy(dataloader)
        batch_size = dataloader.batch_size
        if process_bar:
            logger.info(information)
            _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process)
        for batch_idx, entry in enumerate(_dataloader):
            cur_bs = len(entry["metadata"])
            if orig_idxs is None:
                batch_orig_idxs = [None] * cur_bs
            else:
                start = batch_idx * batch_size
                batch_orig_idxs = orig_idxs[start : start + cur_bs]
            with torch.no_grad():
                metadata = entry.pop("metadata")
                raw_text = self.tokenizer.batch_decode(
                    entry["input_ids"], skip_special_tokens=True, verbose=False
                )
                res = model.encode(raw_text, show_progress_bar=False)
            res_list.extend(
                {"embed": r, "metadata": m, "idx": orig_id}
                for r, m, orig_id in zip(res, metadata, batch_orig_idxs)
            )
        return res_list

    def knn_search(self, ice_num, model, use_trange=True):
        """Embed ``self.test_ds`` and retrieve ``ice_num`` neighbours per query.

        Returns:
            A list where position ``i`` holds the retrieved index-set ids
            (``metadata["id"]`` values, not the ``idx`` column) for the query
            whose ``metadata["id"]`` is ``i``.
        """
        res_list = self.forward(
            model,
            self.dataloader,
            orig_idxs=self.orig_idxs,
            process_bar=use_trange,
            information="Embedding test set...",
        )
        rtr_idx_list = [[] for _ in range(len(res_list))]
        if use_trange:
            # Log before creating the bar so the message precedes it.
            logger.info("Retrieving data for test set...")
            range_obj = tqdm.tqdm(res_list, disable=not self.is_main_process)
        else:
            range_obj = res_list
        for entry in range_obj:
            idx = entry["metadata"]["id"]
            embed = np.expand_dims(entry["embed"], axis=0)
            rtr_idx_list[idx] = self._search_ids(embed, ice_num)
        return rtr_idx_list

    def retrieve(self, model=None, ice_num=None, use_trange=True):
        """Rebuild the index with ``model`` and retrieve for every test query.

        ``model`` is required. ``ice_num`` defaults to ``self.ice_num``.
        """
        model.eval()
        self.index = self.create_index(model)
        if ice_num is None:
            ice_num = self.ice_num
        return self.knn_search(ice_num, model, use_trange)
