"""Basic Retriever."""
from abc import abstractmethod
from typing import Dict, List, Optional

from mmengine.dist import is_main_process

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.utils.prompt import PromptList


class BaseRetriever:
    """Base class for In-context Learning Example Retriever, without any
    retrieval method implemented.

    Subclasses are expected to override :meth:`retrieve`; the remaining
    methods turn retrieved indices and templates into prompts.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided. Defaults
            to '\\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
    """
    # Class-level placeholders; both are replaced per instance in __init__.
    index_ds = None
    test_ds = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        self.ice_separator = ice_separator
        self.ice_eos_token = ice_eos_token
        self.ice_num = ice_num
        self.is_main_process = is_main_process()
        self.dataset_reader = dataset.reader
        # The train split is the pool in-context examples are drawn from;
        # the test split holds the examples prompts are generated for.
        self.index_ds = dataset.train
        self.test_ds = dataset.test

    # NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is
    # documentation only here: instantiation is not blocked, and calling the
    # base implementation simply returns ``None``.
    @abstractmethod
    def retrieve(self) -> List[List[int]]:
        """Retrieve the in-context example index for each test example."""

    def get_labels(
            self,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None) -> List[str]:
        """Get the labels of the dataset, especially useful for ppl inferencer.

        The labels are resolved with the following priority:

        1. the keys of ``prompt_template.template`` when it is a dict;
        2. the keys of ``ice_template.template`` when it is a dict and
           ``ice_template.ice_token`` is set;
        3. the unique values of the output column of the test set, in order
           of first appearance.

        Args:
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            `List[str]`: The candidate labels.
        """
        if prompt_template is not None and isinstance(prompt_template.template,
                                                      Dict):
            return list(prompt_template.template.keys())
        if ice_template is not None and ice_template.ice_token is not None \
                and isinstance(ice_template.template, Dict):
            return list(ice_template.template.keys())
        # De-duplicate while keeping first-appearance order. A plain ``set``
        # would make the label order depend on string hash randomisation and
        # therefore differ between runs and processes.
        column = self.test_ds[self.dataset_reader.output_column]
        return list(dict.fromkeys(column))

    def generate_ice(self,
                     idx_list: List[int],
                     ice_template: Optional[PromptTemplate] = None) -> str:
        """Generate the in-context example for one test example.

        If ``ice_template.prompt_type`` is ``'meta'``, the separator and the
        eos token are both replaced by empty strings; otherwise the
        retriever's ``ice_separator`` and ``ice_eos_token`` are used.

        Args:
            idx_list (`List[int]`): The index of in-context examples for the
                test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None, which is only valid
                together with an empty ``idx_list``.

        Returns:
            The rendered items joined by the separator and followed by the
            eos token. When the rendered items are `PromptList` objects the
            result is a plain list built from them instead of a `str`.
        """
        if ice_template is None:
            assert len(
                idx_list
            ) == 0, 'You have not specified ice_template while retrieving examples from train set! Please either specify ice_template or use `ZeroRetriever`.'  # noqa

        if ice_template is not None and ice_template.prompt_type == 'meta':
            ice_separator, ice_eos_token = '', ''
        else:
            ice_separator = self.ice_separator
            ice_eos_token = self.ice_eos_token

        generated_ice_list = []
        for idx in idx_list:
            # Look the row up once; it is needed both as the entry and as the
            # source of its label.
            entry = self.index_ds[idx]
            label = entry[self.dataset_reader.output_column]
            generated_ice_list.append(
                ice_template.generate_ice_item(entry, label))

        if len(generated_ice_list) > 0 and isinstance(generated_ice_list[0],
                                                      PromptList):
            # PromptList items cannot be joined with ``str.join``; they are
            # concatenated one by one using PromptList's own ``+``.
            generated_ice = []
            for ice in generated_ice_list:
                generated_ice += ice + ice_separator
            generated_ice.append(ice_eos_token)
        else:
            generated_ice = ice_separator.join(
                generated_ice_list) + ice_eos_token
        return generated_ice

    @staticmethod
    def _select_template(ice_template, prompt_template):
        """Pick the template a prompt is rendered with.

        ``prompt_template`` wins whenever it is given. If ``ice_template`` is
        given as well, ``prompt_template`` must define ``ice_token``. When
        only ``ice_template`` is given, it must define ``ice_token`` itself.

        Raises:
            NotImplementedError: If the required ``ice_token`` is missing, or
                if neither template is provided.
        """
        if prompt_template is not None:
            if ice_template is not None and prompt_template.ice_token is None:
                raise NotImplementedError(
                    'ice_token of prompt_template is not provided')
            return prompt_template
        if ice_template is not None:
            if ice_template.ice_token is None:
                raise NotImplementedError(
                    'ice_token of ice_template is not provided')
            return ice_template
        raise NotImplementedError('Leaving prompt as empty is not supported')

    def generate_label_prompt(self,
                              idx: int,
                              ice: str,
                              label,
                              ice_template: Optional[PromptTemplate] = None,
                              prompt_template: Optional[PromptTemplate] = None,
                              remain_sep: Optional[bool] = False) -> str:
        """Generate the prompt for one test example in perplexity evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            label (`str`): The label of the test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.
            remain_sep (`Optional[bool]`): Whether to remain the sep token.
                Defaults to False.

        Returns:
            Whatever the chosen template's ``generate_label_prompt_item``
            returns.

        Raises:
            NotImplementedError: If the template that has to embed the
                in-context examples has no ``ice_token``, or if no template
                is provided at all.
        """
        template = self._select_template(ice_template, prompt_template)
        return template.generate_label_prompt_item(self.test_ds[idx], ice,
                                                   label, remain_sep)

    def generate_prompt_for_generate_task(
            self,
            idx,
            ice,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt for one test example in generative evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt.

        `gen_field_replace_token` is forwarded to the template as the
        replacement for the output column (presumably so that the reference
        answer is not leaked into the prompt; see
        `PromptTemplate.generate_item` to confirm).

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            Whatever the chosen template's ``generate_item`` returns.

        Raises:
            NotImplementedError: If the template that has to embed the
                in-context examples has no ``ice_token``, or if no template
                is provided at all.
        """
        template = self._select_template(ice_template, prompt_template)
        return template.generate_item(
            self.test_ds[idx],
            output_field=self.dataset_reader.output_column,
            output_field_replace_token=gen_field_replace_token,
            ice_field_replace_token=ice)

    def generate_prompt_for_adv_generate_task(
            self,
            idx,
            ice,
            extra_prompt: Optional[dict] = None,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt for one test example in generative evaluation
        with `prompt_template`, after merging extra fields into the example.
        If `prompt_template` is not provided, the `ice_template` will be used
        to generate the prompt.

        `gen_field_replace_token` is forwarded to the template as the
        replacement for the output column (presumably so that the reference
        answer is not leaked into the prompt; see
        `PromptTemplate.generate_item` to confirm).

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            extra_prompt (`Optional[dict]`): Extra fields merged on top of
                the test example before rendering; keys that also exist in
                the example override it. Defaults to None, meaning no extra
                fields.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            Whatever the chosen template's ``generate_item`` returns.

        Raises:
            NotImplementedError: If the template that has to embed the
                in-context examples has no ``ice_token``, or if no template
                is provided at all.
        """
        template = self._select_template(ice_template, prompt_template)
        # Merge into a fresh dict so neither the dataset row nor the caller's
        # mapping is modified.
        entry = {**self.test_ds[idx], **(extra_prompt or {})}
        return template.generate_item(
            entry,
            output_field=self.dataset_reader.output_column,
            output_field_replace_token=gen_field_replace_token,
            ice_field_replace_token=ice)
