from datasets import Dataset, DatasetDict
from typing import List, Union, Optional, Tuple, Dict
from openicl import DatasetReader, PromptTemplate
from openicl.icl_retriever import BM25Retriever, BaseRetriever
from openicl.utils.check_type import _check_str
from accelerate import Accelerator
from rank_bm25 import BM25Okapi
import numpy as np
from tqdm import trange
from nltk.tokenize import word_tokenize


class BM25FedRetriever(BM25Retriever):
    """BM25 in-context-example retriever that scores caller-supplied queries.

    The BM25 index is built once, at construction time, over the input fields
    of the index split. ``retrieve`` tokenizes whatever query dataset it is
    handed on each call, so the queries do not have to be known up front.
    """

    def __init__(
        self,
        dataset_reader: DatasetReader,
        ice_separator: Optional[str] = "\n",
        ice_eos_token: Optional[str] = "\n",
        prompt_eos_token: Optional[str] = "",
        ice_num: Optional[int] = 1,
        index_split: Optional[str] = "train",
        test_split: Optional[str] = "test",
        accelerator: Optional[Accelerator] = None,
    ) -> None:
        # Deliberately runs BaseRetriever.__init__ rather than
        # BM25Retriever.__init__. NOTE(review): presumably this skips parent
        # setup (e.g. tokenizing a test corpus) that this class does not use;
        # confirm against BM25Retriever.
        BaseRetriever.__init__(
            self,
            dataset_reader,
            ice_separator,
            ice_eos_token,
            prompt_eos_token,
            ice_num,
            index_split,
            test_split,
            accelerator,
        )
        # Tokenized index split: one list of word tokens per document.
        self.index_corpus = [
            word_tokenize(data)
            for data in self.dataset_reader.generate_input_field_corpus(self.index_ds)
        ]
        self.bm25 = BM25Okapi(self.index_corpus)

    def retrieve(
        self,
        query_dataset: Union[Dataset, DatasetDict],
        split: Optional[str] = None,
        ice_num: Optional[int] = None,
        use_trange: bool = True,
        **kwargs
    ) -> List[List]:
        """Return the ``ice_num`` best-scoring index documents for each query.

        Args:
            query_dataset: Queries to score. A ``DatasetDict`` is narrowed to
                ``split`` when ``split`` is given.
            split: Split to select from a ``DatasetDict``. It is ignored for a
                plain ``Dataset``.
            ice_num: Number of ids to return per query. Defaults to
                ``self.ice_num``.
            use_trange: Wrap the loop in a tqdm ``trange`` progress bar. The
                bar is disabled on non-main processes.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            One list per query, in query order. Each list holds Python ``int``
            positions into ``self.index_corpus``, presumably row indices of
            ``self.index_ds``, ordered by descending BM25 score.
        """
        if ice_num is None:
            ice_num = self.ice_num

        if split is not None and isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[split]

        query_corpus = [
            word_tokenize(data)
            for data in self.dataset_reader.generate_input_field_corpus(query_dataset)
        ]

        num_queries = len(query_corpus)
        if use_trange:
            positions = trange(num_queries, disable=not self.is_main_process)
        else:
            positions = range(num_queries)

        rtr_idx_list = []
        for pos in positions:
            # Similarity of this query to every document in the index corpus.
            scores = self.bm25.get_scores(query_corpus[pos])
            # Highest scores first, truncated to ice_num. tolist() converts
            # numpy integers to plain Python ints.
            rtr_idx_list.append(np.argsort(scores)[::-1][:ice_num].tolist())
        return rtr_idx_list

    def generate_label_prompt(
        self,
        idx: Optional[int] = None,
        query: Optional[Dict] = None,
        ice: Optional[str] = None,
        label=None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        remain_sep: Optional[bool] = False,
    ) -> str:
        """Build the prompt for one query filled in with ``label``.

        The query is taken from ``query`` when given. Otherwise it is taken
        from ``self.test_ds[idx]``. An explicit ``query`` wins when both are
        supplied.

        NOTE(review): ``query`` sits between ``idx`` and ``ice``. Callers
        written for an ``(idx, ice, label, ...)`` positional order would
        mis-bind arguments, so prefer keyword arguments. Confirm against
        BaseRetriever's signature.

        Args:
            idx: Row of ``self.test_ds`` to use when ``query`` is not given.
            query: The query row as a dict of column name to value.
            ice: In-context-example text placed before the query. ``None`` is
                treated as empty text in the no-template path.
            label: Label to fill into the prompt.
            ice_template: Used when ``prompt_template`` is None and this
                template has an ``ice_token``.
            prompt_template: Used in preference to ``ice_template``.
            remain_sep: Forwarded to the template's
                ``generate_label_prompt_item``.

        Returns:
            The prompt string, terminated by ``self.prompt_eos_token``.

        Raises:
            ValueError: If neither ``idx`` nor ``query`` is given, or if
                ``idx`` is needed but ``self.test_ds`` is None.
            TypeError: If ``idx`` is None and ``query`` is not a dict.
        """
        if idx is None and query is None:
            raise ValueError("either `idx` or `query` must be provided")
        if idx is None and not isinstance(query, dict):
            raise TypeError(
                f"`query` must be a dict when `idx` is None, got {type(query).__name__}"
            )

        if query is None:
            if self.test_ds is None:
                raise ValueError("`self.test_ds` is None; pass `query` explicitly")
            query = self.test_ds[idx]

        if prompt_template is not None:
            prompt = prompt_template.generate_label_prompt_item(
                query, ice, label, remain_sep
            )
        elif ice_template is not None and ice_template.ice_token is not None:
            prompt = ice_template.generate_label_prompt_item(
                query, ice, label, remain_sep
            )
        else:
            # No usable template. Concatenate the ICE text, then the query's
            # input columns joined by spaces, then the label.
            prefix_prompt = " ".join(
                str(query[col]) for col in self.dataset_reader.input_columns
            )
            # `ice` defaults to None. Treat a missing ICE as empty text rather
            # than failing on `None + str`.
            ice_text = "" if ice is None else ice
            prompt = ice_text + prefix_prompt + " " + str(label)
        return prompt + self.prompt_eos_token


class BM25ServerRetriever(BM25Retriever):
    """Thin wrapper around ``BM25Retriever`` with a ``use_trange`` switch.

    Construction is forwarded unchanged to the parent class. ``retrieve``
    delegates to the parent's ``retrieve`` and forwards only ``use_trange``.
    """

    def __init__(
        self,
        dataset_reader: DatasetReader,
        ice_separator: Optional[str] = "\n",
        ice_eos_token: Optional[str] = "\n",
        prompt_eos_token: Optional[str] = "",
        ice_num: Optional[int] = 1,
        index_split: Optional[str] = "train",
        test_split: Optional[str] = "test",
        accelerator: Optional[Accelerator] = None,
    ) -> None:
        # Pass every argument through positionally, in declaration order.
        parent_args = (
            dataset_reader,
            ice_separator,
            ice_eos_token,
            prompt_eos_token,
            ice_num,
            index_split,
            test_split,
            accelerator,
        )
        super().__init__(*parent_args)

    def retrieve(self, use_trange=True, **kwargs) -> List[List]:
        """Delegate to the parent's ``retrieve``.

        Args:
            use_trange: Forwarded to the parent. NOTE(review): assumes the
                parent's ``retrieve`` accepts this keyword; confirm.
            **kwargs: Accepted so callers can use one call shape for several
                retriever types. They are NOT forwarded, so options such as a
                per-call ``ice_num`` are ignored here.

        Returns:
            Whatever the parent's ``retrieve`` returns.
        """
        return super().retrieve(use_trange=use_trange)
