from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

from .prompt import PromptList

# Type alias for a single prompt: either a PromptList or a plain string.
PromptType = Union[PromptList, str]


class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions. If it contains the key
            ``eos_token_id``, that value is stored on ``self.eos_token_id``.
        generation_kwargs (Dict, optional): The generation kwargs for the
            model. Defaults to None, in which case a new empty dict is
            stored on the instance.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 generation_kwargs: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        # A default of ``dict()`` in the signature would be created once and
        # shared by every instance; build a fresh dict per instance instead.
        # A dict supplied by the caller is stored as-is (not copied).
        self.generation_kwargs = ({} if generation_kwargs is None else
                                  generation_kwargs)

    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' gen-based evaluation yet, try ppl-based '
                                  'instead.')

    @abstractmethod
    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')

    @abstractmethod
    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        The base implementation has no body and therefore returns None.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """

    def parse_template(self, prompt_template: PromptType,
                       mode: str) -> Union[str, List[str]]:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        This simply delegates to ``self.template_parser.parse_template``.

        Args:
            prompt_template (PromptType or a list of PromptType): A prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            Whatever the template parser returns for the given input.
        """
        return self.template_parser.parse_template(prompt_template, mode)

    def get_ppl_from_template(self,
                              templates: List[PromptType],
                              mask_length=None):
        """Get perplexity given a list of templates.

        The templates are parsed in 'ppl' mode and passed to
        :meth:`get_ppl`.

        Args:
            templates (List[PromptType]): A list of templates.
            mask_length (List[int]): A list of mask lengths. If provided, the
                perplexity will be calculated only on the unmasked tokens.

        Returns:
            The result of :meth:`get_ppl` on the parsed templates.
        """
        inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl(inputs, mask_length)

    def generate_from_template(self, templates: List[PromptType],
                               max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        The templates are parsed in 'gen' mode and passed to
        :meth:`generate`; extra keyword arguments are forwarded unchanged.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.

        Returns:
            The result of :meth:`generate` on the parsed templates.
        """
        inputs = self.parse_template(templates, mode='gen')
        return self.generate(inputs, max_out_len=max_out_len, **kwargs)

    def get_token_len_from_template(
            self,
            templates: Union[PromptType, List[PromptType]],
            mode: str = 'ppl') -> Union[List[int], int]:
        """Get lengths given a list of templates.

        Args:
            templates (Union[List[str], str]): Input template(s).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            Union[List[int], int]: Length(s) of the input tokens. If the input
            is a list, a list of lengths will be returned. Otherwise, an int
            will be returned.
        """
        prompts = self.parse_template(templates, mode=mode)
        assert isinstance(prompts, (list, str)), 'tokens must be list or str'
        # A PromptList is treated as one prompt here, not as a batch.
        is_batched = isinstance(prompts,
                                list) and not isinstance(prompts, PromptList)
        if not is_batched:
            prompts = [prompts]
        prompts = [str(prompt) for prompt in prompts]
        token_lens = [self.get_token_len(prompt) for prompt in prompts]
        return token_lens[0] if not is_batched else token_lens

    def to(self, device):
        """Move ``self.model`` to ``device`` by calling ``self.model.to``.

        ``self.model`` is not set by this base class, so it must have been
        assigned elsewhere (e.g. by a subclass) before calling this method.
        """
        self.model.to(device)


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (Dict, optional): The meta template for the model. When
            given (and non-empty) it must contain a list under ``'round'``
            and may contain a list under ``'reserved_roles'``; every dict
            item in those lists is registered in ``self.roles`` under its
            ``'role'`` name.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # maps role name to config; stays empty without a meta template
        self.roles: Dict[str, dict] = dict()
        if meta_template:
            assert 'round' in meta_template, 'round is required in meta' \
                ' template'
            assert isinstance(meta_template['round'], list)
            keys_to_check = ['round']

            if 'reserved_roles' in meta_template:
                assert isinstance(meta_template['reserved_roles'], list)
                keys_to_check.append('reserved_roles')

            for meta_key in keys_to_check:
                for item in meta_template[meta_key]:
                    assert isinstance(item, (str, dict))
                    if isinstance(item, dict):
                        assert item['role'] not in self.roles, \
                            'role in meta prompt must be unique!'
                        # shallow copy so the caller's template is not
                        # modified by the conversion below
                        self.roles[item['role']] = item.copy()
                        # convert list of string and int into a raw string
                        # for the ease of future prompt processing
                        for key in ['begin', 'end']:
                            value = self.roles[item['role']].get(key, '')
                            if isinstance(value, list):
                                self.roles[item['role']][
                                    key] = self._encode_speical_tokens(value)

    def parse_template(self, prompt_template: PromptType,
                       mode: str) -> Union[str, List[str]]:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            prompt_template (PromptType or list/tuple of PromptType): A
                prompt template (potentially before being wrapped by meta
                template). A plain list or tuple is parsed element by
                element.
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            Union[str, List[str]]: The final string, or a list of parsed
            results when a plain list/tuple was given. A ``str`` input is
            returned unchanged.
        """
        assert isinstance(prompt_template, (str, list, PromptList, tuple))
        if not isinstance(prompt_template, (str, PromptList)):
            return [self.parse_template(p, mode=mode) for p in prompt_template]

        assert mode in ['ppl', 'gen']
        if isinstance(prompt_template, str):
            return prompt_template
        if self.meta_template:

            prompt = ''
            # Whether to keep generating the prompt
            generate = True

            section_stack = []  # stores tuples: (section_name, start_idx)

            for i, item in enumerate(prompt_template):
                if not generate:
                    break
                if isinstance(item, str):
                    prompt += item
                elif isinstance(item, dict) and 'section' in item:
                    if item['pos'] == 'end':
                        section_name, start_idx = section_stack.pop(-1)
                        assert section_name == item['section']
                        if section_name in ['round', 'ice']:
                            dialogue = prompt_template[start_idx:i]
                            round_ranges = self._split_rounds(
                                dialogue, self.meta_template['round'])
                            # Consider inserting multiple round examples into
                            # template
                            num_rounds = len(round_ranges) - 1
                            # A dedicated index name is used here so the
                            # outer ``enumerate`` index is not shadowed.
                            for round_idx in range(num_rounds):
                                start = round_ranges[round_idx]
                                end = round_ranges[round_idx + 1]
                                round_template = dialogue[start:end]
                                role_dict = self._update_role_dict(
                                    round_template)
                                new_str, generate = self._prompt2str(
                                    self.meta_template['round'],
                                    role_dict,
                                    # Start generating only when the mode is in
                                    # generation and the template reaches the
                                    # last round
                                    for_gen=mode == 'gen'
                                    and round_idx == num_rounds - 1
                                    and section_name == 'round')
                                prompt += new_str
                    elif item['pos'] == 'begin':
                        assert item['section'] in [
                            'begin', 'round', 'end', 'ice'
                        ]
                        section_stack.append((item['section'], i + 1))
                    else:
                        raise ValueError(f'Invalid pos {item["pos"]}')
                # if in "begin" or "end" section
                elif section_stack[-1][0] in ['begin', 'end']:
                    role_dict = self._update_role_dict(item)
                    new_str, generate = self._prompt2str(
                        item,
                        role_dict,
                        # never stop generation
                        for_gen=False)
                    prompt += new_str

            prompt = self.meta_template.get('begin', '') + prompt
            # The template-level 'end' is only appended when the prompt was
            # not cut short for generation.
            if generate:
                prompt += self.meta_template.get('end', '')

        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in prompt_template:
                # skip pure section markers
                if isinstance(item, dict) and set(['section', 'pos']) == set(
                        item.keys()):
                    continue
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('prompt', ''):  # it's a dict
                    prompt += last_sep + item.get('prompt', '')
                # every item after the first non-marker one is preceded by a
                # newline
                last_sep = '\n'
        return prompt

    def _split_rounds(
            self, prompt_template: List[Union[str, Dict]],
            single_round_template: List[Union[str, Dict]]) -> List[int]:
        """Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
        prompt_template[res[i]:res[i+1]] represents the i-th round in the
        template.

        A new round starts whenever a role's position in the single round
        template does not advance past the previous role's position. A role
        missing from the single round template is located through its
        ``fallback_role``.

        Raises:
            KeyError: If an item's role is not in the single round template
                and it has no usable ``fallback_role``.
        """
        role_idxs = {
            role_cfg['role']: i
            for i, role_cfg in enumerate(single_round_template)
            if not isinstance(role_cfg, str)
        }
        last_role_idx = -1
        cutoff_idxs = [0]
        for idx, template in enumerate(prompt_template):
            if isinstance(template, str):
                continue
            role_idx = role_idxs.get(template['role'])
            if role_idx is None:
                # Stay consistent with _update_role_dict/_role2str, which
                # also honor 'fallback_role'.
                role_idx = role_idxs.get(template.get('fallback_role'))
                if role_idx is None:
                    raise KeyError(f'{template} neither has an appropriate '
                                   'role nor a fallback role.')
            if role_idx <= last_role_idx:
                cutoff_idxs.append(idx)
            last_role_idx = role_idx
        cutoff_idxs.append(len(prompt_template))
        return cutoff_idxs

    def _update_role_dict(self, prompt: Union[List, str,
                                              Dict]) -> Dict[str, Dict]:
        """Update the default role dict with the given prompt(s).

        Works on a deep copy of ``self.roles``, so the defaults are never
        modified. Each dict prompt is merged into the entry of its role, or
        of its ``fallback_role`` when the role itself is unknown.

        Raises:
            KeyError: If neither the role nor the fallback role of a prompt
                is a known role.
        """
        assert isinstance(prompt, (str, list, dict))
        role_dict = deepcopy(self.roles)
        if isinstance(prompt, str):
            return role_dict
        if isinstance(prompt, dict):
            prompt = [prompt]
        for p in prompt:
            if isinstance(p, dict):
                role = p['role']
                if role not in role_dict:
                    role = p.get('fallback_role', None)
                    if role not in role_dict:
                        # Fail with a meaningful message rather than an
                        # opaque KeyError on the lookup below.
                        raise KeyError(
                            f'{p} neither has an appropriate role nor '
                            'a fallback role.')
                role_dict[role].update(p)
        return role_dict

    def _prompt2str(self,
                    prompt: Union[List, str, Dict],
                    role_dict: Dict[str, Dict],
                    for_gen: bool = False) -> Tuple[str, bool]:
        """Convert the prompts to a string, given an updated role_dict.

        Args:
            prompt (Union[List, str, dict]): The prompt(s) to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded. An empty list yields
            ``('', True)``.
        """
        assert isinstance(prompt, (list, str, dict))

        if isinstance(prompt, str):
            return prompt, True
        if isinstance(prompt, dict):
            return self._role2str(prompt, role_dict, for_gen)

        res = ''
        # Bound up front so that an empty list does not leave it undefined.
        cont = True
        for p in prompt:
            new_str, cont = self._prompt2str(p, role_dict, for_gen)
            res += new_str
            if not cont:
                break
        return res, cont

    def _role2str(self,
                  role_prompt: Dict,
                  role_dict: Dict[str, Dict],
                  for_gen: bool = False) -> Tuple[str, bool]:
        """Convert a role prompt to a string, given an updated role_dict.

        Args:
            role_prompt (Dict): The role prompt to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded.

        Raises:
            KeyError: If neither the role nor its fallback role is present
                in ``role_dict``.
        """
        merged_prompt = role_dict.get(role_prompt['role'])
        if merged_prompt is None:
            merged_prompt = role_dict.get(role_prompt.get('fallback_role'))
        if merged_prompt is None:
            raise KeyError(f'{role_prompt} neither has an appropriate role '
                           'nor a fallback role.')
        res = merged_prompt.get('begin', '')
        # In generation mode, stop right after the 'begin' marker of the
        # role that is expected to generate.
        if for_gen and merged_prompt.get('generate', False):
            return res, False
        res += merged_prompt.get('prompt', '') + merged_prompt.get('end', '')
        return res, True

    def _encode_speical_tokens(self, prompt: List[Union[str, int]]) -> str:
        """Encode the special tokens in the prompt.

        Now this is left for the future work. An earlier draft concatenated
        the string items and rendered each non-string item as
        ``<META_TOKEN_{item}>``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Using List[str|int] as the begin or end '
                                  'of a prompt is not supported yet.')
