import abc
from abc import ABC
from typing import Dict, List, Tuple, Optional, Union
import os
import random
import nltk

import hexa.tasks.constants as CONST
import hexa.tasks.prompts as PROMPT
import hexa.tasks.tod_core as tod
from hexa.tasks.utils import f1_metric, extract_entities, _normalize_persona_line, normalize_reply
from hexa.tasks.msc import calc_f1_msc
from hexa.tasks.summarizer import build_summarizer
from hexa.tasks.utils import build_all_personas


# Process-wide singletons, created lazily by the persona mutators below.
SUMMARIZER = None  # set via build_summarizer(opt) on first mutator construction
ALL_PERSONAS = None  # set via build_all_personas(opt) on first PersonasAsDocsMutator
nlp = None  # NOTE(review): not read by the mutators defined here; confirm it is still needed


class AbstractMutator(ABC):
    """
    Base class for dataset mutators.

    A mutator is called on one episode (a list of ``(message, aux)`` pairs) and
    returns a list of new episodes. Subclasses either implement ``mutate`` (a
    per-message transform, the default) or set ``self.episodic_mutation = True``
    and implement ``mutate_episode`` (a whole-episode transform).
    """
    def __init__(self, config):
        self.config = config
        self.episodic_mutation = False  # Set this to True for ones that require episodic, e.g. PersonaAsDocs

    def __call__(self, episode: List[Tuple[Dict, Dict]]) -> List[List[Tuple[Dict, Dict]]]:
        """
        Apply this mutator to one episode.

        :param episode:
            list of ``(message, aux)`` pairs.
        :return:
            list of new episodes, each a list of ``(message, aux)`` pairs.
        :raises ValueError:
            if ``mutate``/``mutate_episode`` returns an unexpected type.
        """
        if self.episodic_mutation:
            new_episode = self.mutate_episode(episode)
            # mutate_episode may return either a list of episodes or one flat
            # episode (a list of tuples); normalize to a list of episodes.
            if len(new_episode) == 0 or isinstance(new_episode[0], list):
                return new_episode
            elif isinstance(new_episode[0], tuple):
                return [new_episode]
            else:
                raise ValueError("Inappropriate return type from the mutator")
        else:
            new_episodes = []
            new_epi = []
            num_excluded = 0  #TODO:  possibly return this in the future for logging
            for curr_idx, (d, aux) in enumerate(episode):
                # aux entries are forwarded to mutate() as keyword arguments
                d, _include_data = self.mutate(d, **aux)
                # mutate() returns either a bool (include this message) or a
                # tuple (include this message, keep processing this episode).
                if isinstance(_include_data, bool):
                    _keep_episode = True
                elif isinstance(_include_data, tuple):
                    _include_data, _keep_episode = _include_data
                else:
                    raise ValueError("Unexpected type returned from the mutator")
                if _include_data:
                    new_epi.append((d, aux))
                    # a message flagged as episode end closes the current output episode
                    if d.get(CONST.EPISODE_END, False):
                        new_episodes.append(new_epi)
                        new_epi = []
                else:
                    num_excluded += 1
                if not _keep_episode:
                    # discard the partially built output episode and skip the
                    # remaining input messages; already-closed episodes are kept
                    num_excluded += len(new_epi) + len(episode) - curr_idx - 1
                    new_epi = []
                    break

            # flush trailing messages that were never closed by an episode-end flag
            if new_epi:
                new_episodes.append(new_epi)


            return new_episodes

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, Union[bool, Tuple[bool, bool]]]:
        """
        Transform a single message (used when ``episodic_mutation`` is False).

        :param message:
            the message dict to transform (may be modified in place).
        :param kwargs:
            the message's aux dict, unpacked.
        :return:
            ``(message, include)`` where ``include`` is either a bool (whether to
            keep this message in the new dataset) or a tuple
            ``(include_message, keep_episode)``; ``keep_episode=False`` stops
            processing of the rest of the episode.
        """
        raise NotImplementedError

    def mutate_episode(self, episode: List[Tuple[Dict, Dict]]) -> Union[List[List[Tuple[Dict, Dict]]], List[Tuple[Dict, Dict]]]:
        """
        Transform a whole episode (used when ``episodic_mutation`` is True).

        :param episode:
            list of (message, aux) representing an episode
        :return:
            either a list of new episodes, or a single new episode (a flat list
            of ``(message, aux)`` tuples), which ``__call__`` wraps in a list.
        """
        raise NotImplementedError


class FlattenMutator(AbstractMutator):
    """
    Split an episode into single-message episodes whose text is the whole
    dialogue history up to (and including) that turn.
    """
    __name__ = 'flatten'

    def __init__(self, config):
        super().__init__(config)
        # Needs the whole episode to accumulate history across turns.
        self.episodic_mutation = True

    def mutate_episode(self, episode):
        history = []
        flattened = []
        for msg, aux in episode:
            history.append(msg[CONST.MESSAGE_TEXT])
            msg[CONST.MESSAGE_TEXT] = '\n'.join(history)
            # A randomly chosen label stands in for the reply in later history.
            reply = random.choice(msg[CONST.LABELS])
            history.append(reply)
            msg[CONST.EPISODE_END] = True
            flattened.append([(msg, aux)])
        return flattened


class AddSelectedSentencesToText(AbstractMutator):
    """
    Append the message's retrieved sentences to its text, wrapped in knowledge tokens.
    """
    __name__ = 'add_selected_sentences_text'

    def mutate(self, message, **kwargs):
        sentences = message.get(CONST.RETRIEVED_SENTENCES, [])
        # A bare string would otherwise be joined character by character.
        if isinstance(sentences, str):
            sentences = [sentences]
        joined = ' '.join(sentences)
        message[CONST.MESSAGE_TEXT] += f'\n{CONST.KNOWLEDGE_TOKEN} {joined} {CONST.END_KNOWLEDGE_TOKEN}'
        return message, True


class TodToSrm(AbstractMutator):
    """
    Convert task-oriented dialogue (TOD) messages into knowledge-grounded
    response examples.

    User and system utterances are accumulated as context and API responses as
    knowledge. When a real system utterance arrives while knowledge is pending,
    the message is rewritten to ``context + knowledge -> system utterance``.
    It is kept only if the response has non-zero F1 overlap with the knowledge.
    """
    __name__ = 'tod_to_srm'
    system_silence = f"{tod.STANDARD_SYSTEM_UTTERANCE}{CONST.DUMMY_TEXT}"
    user_silence = f"{tod.STANDARD_USER_UTTERANCE}{CONST.DUMMY_TEXT}"

    def __init__(self, config):
        super().__init__(config)
        self.context = []
        self.knowledge: List = []

    def mutate(self, message, **kwargs):
        # NOTE(review): defaults to True when aux lacks the flag, which resets
        # state on every message; confirm callers always provide it.
        episode_begin = kwargs.get(CONST.EPISODE_BEGIN, True)
        if episode_begin:
            # A new episode: reset the context AND any knowledge the previous
            # episode left unconsumed, so it cannot ground this dialogue.
            self.context = []
            self.knowledge = []

        user_utt: str = message['text']
        system_utt: str = message['labels'][0]
        system_utt_clean = system_utt.replace(tod.STANDARD_SYSTEM_UTTERANCE, '').replace('\n', ' ')

        if user_utt.startswith(tod.STANDARD_API_SCHEMAS):
            # start of convo; don't need this
            return message, False
        if user_utt.startswith(tod.STANDARD_USER_UTTERANCE) and user_utt != self.user_silence:
            # add user utterances to the conversation
            self.context.append(user_utt.replace(tod.STANDARD_USER_UTTERANCE, ''))

        if user_utt.startswith(tod.STANDARD_RESP) and user_utt != tod.STANDARD_RESP:
            # here's the knowledge grounding
            self.knowledge.append(user_utt.replace(tod.STANDARD_RESP, '').replace('\n', ''))

        # A "real" system turn is neither silence nor an API call.
        is_system_reply = not (
            system_utt == self.system_silence or system_utt.startswith(tod.STANDARD_CALL)
        )

        keep_message = False
        if self.knowledge and is_system_reply:
            str_knowledge = f"{CONST.KNOWLEDGE_TOKEN} {'. '.join(self.knowledge)} {CONST.END_KNOWLEDGE_TOKEN}"
            message['text'] = '\n'.join(self.context + [str_knowledge])
            message['labels'] = [system_utt_clean]
            if f1_metric(system_utt_clean, [str_knowledge]) > 0.0:
                # filter out pathological examples
                keep_message = True
            # knowledge is consumed by this response
            self.knowledge = []

        if is_system_reply:
            self.context.append(system_utt_clean)

        return message, keep_message


class TodToDialog(AbstractMutator):
    """
    Convert task-oriented dialogue (TOD) messages into plain dialogue examples.

    User utterances are buffered, as are system utterances that arrive while
    nothing is buffered. The next real system utterance after buffered context
    becomes the label, and the buffered lines become the text. Schema
    preambles, API calls, silences and ``done`` markers never become examples.
    """
    __name__ = 'tod_to_dialog'
    system_silence = f"{tod.STANDARD_SYSTEM_UTTERANCE}{CONST.DUMMY_TEXT}"
    user_silence = f"{tod.STANDARD_USER_UTTERANCE}{CONST.DUMMY_TEXT}"
    user_done = f"{tod.STANDARD_DONE}"

    def __init__(self, config):
        super().__init__(config)
        self.context = []
        self.knowledge: List = []

    def mutate(self, message, **kwargs):
        raw_user: str = message['text']
        raw_system: str = message['labels'][0]
        cleaned_system = raw_system.replace(tod.STANDARD_SYSTEM_UTTERANCE, '').replace('\n', ' ')

        # Schema preambles and empty system turns never yield examples.
        if raw_user.startswith(tod.STANDARD_API_SCHEMAS) or not cleaned_system:
            return message, False

        if raw_user.startswith(tod.STANDARD_USER_UTTERANCE) and raw_user != self.user_silence:
            cleaned_user = raw_user.replace(tod.STANDARD_USER_UTTERANCE, '').replace('\n', ' ')
            if cleaned_user == self.user_done:
                # Conversation finished: forget anything buffered.
                self.context = []
                return message, False
            if cleaned_user:
                self.context.append(cleaned_user)

        is_real_system_turn = not (
            raw_system == self.system_silence or raw_system.startswith(tod.STANDARD_CALL)
        )
        if not is_real_system_turn:
            return message, False

        if not self.context:
            # Nothing buffered yet: this system line starts the next context.
            self.context.append(cleaned_system)
            return message, False

        if cleaned_system == self.user_done:
            self.context = []
            return message, False

        message['text'] = '\n'.join(self.context)
        message['labels'] = [cleaned_system]
        self.context = []
        return message, True


class SearchQueryClassificationMixin(AbstractMutator):
    """
    Changes the message in the following ways:
    1. Makes the task *only one line of context*
    2. Adds a prompt to the end of the message, indicating a binary choice of __search-required__
    3. Changes the label to indicate whether to search or not.

    Subclasses supply the label via ``get_label``.
    """
    PROMPT = CONST.IS_SEARCH_REQUIRED
    LABEL: str
    __name__ = None

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        # Ask the subclass once, so the label that is validated is exactly the
        # label that gets assigned.
        label = self.get_label(message)
        assert label
        text = message[CONST.MESSAGE_TEXT]
        if not text.endswith(self.PROMPT):
            # keep only the last context line, followed by the prompt
            last_context = text.split('\n')[-1]
            message[CONST.MESSAGE_TEXT] = f"{last_context} {self.PROMPT}"
        message[CONST.LABELS] = [label]
        message.pop(CONST.KNOWLEDGE, None)
        return message, True

    @abc.abstractmethod
    def get_label(self, message: Dict) -> str:
        """
        Return the label.
        """


class SkipRetrievalMutator(AbstractMutator):
    """
    Set ``skip_retrieval = True`` on every message and keep all messages.
    """
    __name__ = 'skip_retrieval_mutator'

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        flag_key = 'skip_retrieval'
        message[flag_key] = True
        return message, True


class FilterSilenceOnlyMutator(AbstractMutator):
    """
    Drop messages whose text, once prompt tokens are removed, is only silence.

    Every message is marked as an episode end. On a silence-only message the
    mutator also asks the caller to stop processing the rest of the episode.
    """
    __name__ = 'filter_silence_only_mutator'

    def mutate(self, message: Dict, **kwargs) -> (Dict, Tuple):
        """
        :param message: message
        :param kwargs: arbitrary kwargs from parser
        :return: message and the flag pair (include_message, include_episode)
        """
        stripped = message['text']
        for prompt_token in (
            CONST.IS_SEARCH_REQUIRED,
            CONST.IS_MEMORY_REQUIRED,
            CONST.ACCESS_MEMORY,
        ):
            stripped = stripped.replace(f' {prompt_token}', '')
        message[CONST.EPISODE_END] = True
        keep = stripped.lower() != '__silence__'
        return message, (keep, keep)


class DoGenerateSearchQueryMutator(SearchQueryClassificationMixin):
    """
    Search-decision mutator whose label is always ``CONST.DO_SEARCH``.
    """
    __name__ = 'do_generate_search_query_mutator'

    def get_label(self, message: dict) -> str:
        # Constant decision, independent of the message contents.
        decision = CONST.DO_SEARCH
        return decision


def build_all_personas(opt) -> List[str]:
    """
    Load the persona list from which we sample memories for MDM.

    Reads ``<datapath>/blended_skill_talk/persona_list.txt``, strips every
    ``||`` separator and newline from each line, and discards the last line.

    NOTE(review): this definition shadows the ``build_all_personas`` imported
    from ``hexa.tasks.utils`` at the top of the module.

    :param opt:
        options mapping; only ``opt['datapath']`` is read.
    :return:
        list of persona strings.
    :raises ValueError:
        if the persona file does not exist.
    """
    personas_path = os.path.join(opt['datapath'], 'blended_skill_talk', 'persona_list.txt')
    if not os.path.exists(personas_path):
        raise ValueError(f"{personas_path} does not exist!!")
    with open(personas_path) as persona_file:
        cleaned = [line.replace('||', '').replace('\n', '') for line in persona_file]
    # NOTE(review): the final line is always dropped, presumably because the
    # file ends with a blank line; confirm against the actual data file.
    return cleaned[:-1]


def merge_personas(personas: List[str], p_to_merge: List[List[str]]) -> List[str]:
    """
    Merge new persona lines into an existing persona list using word overlap.

    ``personas`` is split into partner lines (starting with ``partner``) and own
    lines (starting with ``your``); lines starting with anything else are
    dropped. Each incoming persona is space-joined onto the first existing line
    of its group whose text after the first ``": "`` has F1 > 0.5 with it.
    Otherwise it is appended as a new line carrying the group's prefix.

    :param personas:
        list of original personas.
    :param p_to_merge:
        list of personas to merge.
        first element is partner personas, second is your personas.
    :return merged:
        partner personas followed by your personas.
    """
    groups = (
        ('partner', "partner's persona: "),
        ('your', "your persona: "),
    )
    merged: List[str] = []
    for group_idx, (start, prefix) in enumerate(groups):
        existing = [p.replace('\n', ' ') for p in personas if p.startswith(start)]
        for candidate in p_to_merge[group_idx]:
            for slot, current in enumerate(existing):
                body = current[current.index(':') + 2:]
                if f1_metric(candidate, [body]) > 0.5:
                    existing[slot] = ' '.join([current, candidate])
                    break
            else:
                existing.append(f"{prefix}{candidate}")
        merged.extend(existing)
    return merged


def get_overlap_sentence(
    full_text: str, label: str, docs: List[str], find_speaker: bool = True
) -> Tuple[str, int, Optional[int]]:
    """
    Find the doc line with the highest F1 overlap against ``label``.

    Every doc is split on newlines and each line is scored with
    ``calc_f1_msc`` against the tokenized label. A line wins only with a
    strictly higher score than the best so far, so the first best line wins.

    :param full_text:
        full context
    :param label:
        target to find overlap with
    :param docs:
        list of candidate strings for overlap
    :param find_speaker:
        if True and a line was found, append ``__them__``/``__you__``. The
        speaker comes from the line's position counted back from the end of
        ``full_text`` (raises ValueError if the line is not in ``full_text``).
        Lines containing ``your persona:`` are always ``__you__``.
    :return (best_sentence, best_f1, best_idx):
        best line ('' if none scored above 0), its F1, and the index of the
        doc it came from (None if none).
    """
    best_f1, best_sentence, best_idx = 0, '', None
    try:
        gold_tokens = nltk.word_tokenize(label)
    except IndexError:
        return best_sentence, best_f1, best_idx

    for doc_idx, doc in enumerate(docs):
        for candidate in doc.split('\n'):
            score = calc_f1_msc(candidate, gold_tokens)
            if score > best_f1:
                best_f1, best_sentence, best_idx = score, candidate, doc_idx

    if not (find_speaker and best_sentence != ''):
        return best_sentence, best_f1, best_idx

    # distance 0 is the most recent line of the context
    distance = full_text.split('\n')[::-1].index(best_sentence)
    if 'your persona:' in best_sentence or distance % 2 == 1:
        speaker = '__you__'
    else:
        speaker = '__them__'
    return f"{best_sentence} {speaker}", best_f1, best_idx


class PersonasAsDocsMutator(AbstractMutator):
    """
    This mutator does the following:
    1. Computes memories for context lines with the persona summarizer
    2. Determines which memory has highest overlap with target label
    3. Depending on target, does the following:
        - if knowledge, we set the target as that persona memory
        - if memory_decision, set the target as access_memory if enough overlap, otherwise don't

    Each emitted message forms its own single-message episode. For the
    ``knowledge`` target, turns whose best overlap is not above ``THRESHOLD``
    are dropped.
    """
    __name__ = 'personas_as_docs'

    THRESHOLD = 0.3
    TARGET = 'knowledge'

    def __init__(self, opt) -> None:
        super().__init__(opt)
        # assert FlattenMutator not in opt.mutators, 'cannot use flatten with this mutator.'
        global SUMMARIZER, ALL_PERSONAS
        if SUMMARIZER is None:
            SUMMARIZER = build_summarizer(opt)
        if ALL_PERSONAS is None:
            ALL_PERSONAS = build_all_personas(opt)

        self.episodic_mutation = True
        self.first_partner_text = None
        self.all_personas = []
        self.raw_personas = []
        self.all_text = []

    def compute_context_memories(
            self, episode: List[Tuple[Dict, Dict]]
    ) -> Tuple[List[str], List[str]]:
        """
        Compute the memories from the context.

        :param episode:
            list of (message, aux) pairs

        :return (self_memories, partner_memories):
            one summarizer output per turn, for self (from labels) and partner
            (from the last line of each text).
        """
        partner_texts = [m['text'] for m, aux in episode]
        first_text = [
            t for t in partner_texts[0].split('\n') if not t.startswith('your persona:')
        ]
        if not first_text:
            partner_texts[0] = CONST.DUMMY_TEXT
        else:
            partner_texts[0] = first_text[0]
        self_texts = [m.get('labels', m.get('eval_labels'))[0] for m, aux in episode]

        assert SUMMARIZER is not None
        partner_memories = SUMMARIZER.batch_respond(
            [
                {'text': p.split('\n')[-1], 'episode_done': True}
                for p in partner_texts
            ]
        )
        self_memories = SUMMARIZER.batch_respond(
            [{'text': s, 'episode_done': True} for s in self_texts]
        )

        return self_memories, partner_memories

    def mutate_episode(self, episode: List[Tuple[Dict, Dict]], **kwargs) -> List[List[Tuple[Dict, Dict]]]:
        assert ALL_PERSONAS is not None

        def _maybe_add_persona(text: str, is_partner: bool):
            # Grow the memory bank with one summarizer output, skipping the
            # "no persona" sentinel. Appends to whichever lists the enclosing
            # ``all_personas`` / ``raw_personas`` names are bound to at call time.
            from hexa.tasks.msc.dataloader import NOPERSONA
            prefix = "partner's persona: " if is_partner else "your persona: "
            if text != NOPERSONA:
                raw_personas.append(text)
                all_personas.append(f"{prefix}{text}")

        # compute the personas
        self_memories, partner_memories = self.compute_context_memories(episode)

        # find best overlaps
        all_personas = []
        raw_personas = []
        new_episode = []
        all_text = []
        for i, (msg, aux) in enumerate(episode):
            text = msg['text']
            label = msg['labels'][0]
            if i == 0:
                if 'persona:' in text:
                    # Convai2; session 1
                    personas = [
                        t for t in text.split('\n') if t.startswith('your persona:')
                    ]
                    text = [
                        t for t in text.split('\n') if not t.startswith('your persona:')
                    ]
                    if not text:
                        text = CONST.DUMMY_TEXT
                        msg['text'] = text
                    else:
                        text = text[0]
                    all_personas += personas
                    raw_personas += [
                        p.replace('your persona: ', '').replace(
                            "partner's persona: ", ''
                        )
                        for p in all_personas
                    ]
                elif 'personas' in msg:
                    # MSC
                    all_personas += msg['personas'].split('\n')
                    init_personas = msg.get('init_personas', ['', ''])
                    all_personas = merge_personas(all_personas, init_personas)
                    raw_personas = [
                        p.replace('your persona: ', '').replace(
                            "partner's persona: ", ''
                        )
                        for p in all_personas
                    ]
            assert (
                    'persona:' not in text
            ), "cannot use this mutator with teacher that puts personas in context"
            all_text.append(text)
            best_sentence, best_f1, best_idx = get_overlap_sentence(
                '\n'.join(all_text), msg['labels'][0], raw_personas, False
            )
            if best_f1 > self.THRESHOLD:
                assert best_idx is not None
                if self.TARGET == 'knowledge':
                    # Snapshot the memory bank. all_personas / raw_personas keep
                    # growing below (including the memory summarized from this
                    # turn's own label), so storing the live lists would leak
                    # later memories into this message's docs and leave the docs
                    # longer than the URL/title placeholders.
                    docs = list(all_personas)
                    msg[CONST.RETRIEVED_DOCS] = docs
                    msg[CONST.RETRIEVED_DOCS_URLS] = [''] * len(docs)
                    msg[CONST.RETRIEVED_DOCS_TITLES] = [''] * len(docs)
                    msg[CONST.SELECTED_SENTENCES] = [docs[best_idx]]
                    msg['raw_personas'] = list(raw_personas)
                    msg['old_target'] = msg['labels']
                    msg['labels'] = [best_sentence]
                    msg['text'] = '\n'.join(all_text)
                elif self.TARGET == 'memory_decision':
                    msg.pop('personas', None)
                    msg.pop('init_personas', None)
                    msg['persona_sentence'] = best_sentence
                    msg['old_target'] = msg['labels']
                    msg['labels'] = [CONST.DO_ACCESS_MEMORY]
                    # make sure to include memory in the context!
                    if not best_sentence.startswith('persona:'):
                        best_sentence = f"persona: {best_sentence}"
                    msg['text'] = "\n".join([best_sentence, all_text[-1].split('\n')[-1]])
                new_episode.append([(msg, aux)])
            elif self.TARGET == 'memory_decision':
                msg.pop('personas', None)
                msg.pop('init_personas', None)
                msg['persona_sentence'] = best_sentence
                msg['old_target'] = msg['labels']
                msg['labels'] = [CONST.DONT_ACCESS_MEMORY]
                # include a random memory in context!
                random_pers = (
                    random.choice(ALL_PERSONAS)
                        .replace('your persona:', 'persona:')
                        .replace("partner's persona:", 'persona:')
                )
                if not random_pers.startswith('persona:'):
                    random_pers = f"persona: {random_pers}"
                msg['text'] = "\n".join([random_pers, all_text[-1].split('\n')[-1]])
                new_episode.append([(msg, aux)])
            # memories from this turn become available to later turns only
            _maybe_add_persona(partner_memories[i], True)
            _maybe_add_persona(self_memories[i], False)
            all_text.append(label)

        return new_episode


class ConvertOverlapToPersonasAsDocs(AbstractMutator):
    """
    This mutator converts a knowledge task that originally computed word overlap between
    target and all prior sentences to choose a previous utterance as the knowledge
    sentence.
    The conversion turns all context utterances into persona memories (retrieved docs),
    and all targets into persona memories as well.

    Expects single-message episodes. Returns an empty episode when the target
    cannot be expressed as a persona memory.
    """
    __name__ = 'convert_overlap_to_personas_as_docs'
    TARGET = 'knowledge'
    # Number of summarizer requests sent per batch_respond call.
    SUMMARIZER_BATCH_SIZE = 16

    def __init__(self, opt) -> None:
        super().__init__(opt)
        self.episodic_mutation = True
        global SUMMARIZER
        if SUMMARIZER is None:
            SUMMARIZER = build_summarizer(opt)

    def mutate_episode(self, episode: List[Tuple[Dict, Dict]]) -> List[Tuple[Dict, Dict]]:
        assert len(episode) == 1
        message, aux = episode[0]
        texts = message['text'].split('\n')
        texts[-1] = texts[-1].replace(f' {CONST.GENERATE_KNOWLEDGE}', '')
        all_personas = []
        if 'persona:' in message['text']:
            # Convai2; session 1
            personas = [t for t in texts if t.startswith('your persona:')]
            new_text = [t for t in texts if not t.startswith('your persona:')]
            if not new_text:
                # Must stay a list: '\n'.join over a bare string would insert a
                # newline between every character of the dummy text.
                new_text = [CONST.DUMMY_TEXT]
            message['text'] = '\n'.join(new_text)
            all_personas += personas
        elif 'personas' in message:
            # MSC
            all_personas += message['personas'].split('\n')
        assert SUMMARIZER is not None

        label = message['labels'][0]
        # One request per context line plus a final request for the label (with
        # speaker tokens removed); every request is marked episode_done.
        required_summaries = [
            {'text': t.replace(f' {CONST.GENERATE_KNOWLEDGE}', ''), 'episode_done': True}
            for t in texts
        ]
        required_summaries.append(
            {
                'text': label.replace(f' {CONST.YOU}', '').replace(f' {CONST.THEM}', ''),
                'episode_done': True,
            }
        )
        summaries = []
        batch_size = self.SUMMARIZER_BATCH_SIZE
        for start_idx in range(0, len(required_summaries), batch_size):
            summaries += SUMMARIZER.batch_respond(
                required_summaries[start_idx: start_idx + batch_size]
            )

        # Walking the context backwards from the most recent line, even offsets
        # are attributed to the partner and odd offsets to self.
        context_summaries = list(reversed(summaries[:-1]))
        self_summaries = [
            t
            for i, t in enumerate(context_summaries)
            if i % 2 == 1 and CONST.NOPERSONA not in t
        ]
        partner_summaries = [
            t
            for i, t in enumerate(context_summaries)
            if i % 2 == 0 and CONST.NOPERSONA not in t
        ]
        all_personas = merge_personas(all_personas, [partner_summaries, self_summaries])

        if summaries[-1] == CONST.NOPERSONA and 'persona:' not in label:
            return []

        # for Convai2 we might already have this.
        new_target = label if 'persona:' in label else summaries[-1]
        if 'persona:' not in new_target:
            prefix = "your" if CONST.YOU in label else "partner's"
            new_target = f"{prefix} persona: {new_target}"
        message['original_target'] = message['labels']
        message[CONST.RETRIEVED_DOCS] = all_personas
        message[CONST.RETRIEVED_DOCS_URLS] = [''] * len(all_personas)
        message[CONST.RETRIEVED_DOCS_TITLES] = [''] * len(all_personas)
        message[CONST.SELECTED_SENTENCES] = [new_target]
        message['labels'] = [new_target]
        message['text'] = message['text'].replace(
            CONST.GENERATE_KNOWLEDGE, CONST.ACCESS_MEMORY
        )
        message.pop('skip_retrieval', None)
        return [(message, aux)]


class MemoryDecisionMutator(PersonasAsDocsMutator):
    """
    Memory-decision variant of PersonasAsDocsMutator.

    Each turn is labelled with whether to access memory (based on persona
    overlap), rather than with the persona memory itself.
    """
    TARGET = 'memory_decision'
    __name__ = 'memory_decision_mutator'


class ConvertToHuggingface(AbstractMutator):
    """
    Pass-through mutator: every message is kept unchanged.

    NOTE(review): ``KEYMAP`` maps message keys to names such as ``input_ids``,
    but ``mutate`` does not apply it; confirm whether it is read elsewhere.
    """
    __name__ = 'convert_to_hf'
    KEYMAP = {
        CONST.MESSAGE_TEXT: 'input_ids',
        CONST.LABELS: 'lm_labels',
        CONST.RETRIEVED_DOCS: 'docs',
    }

    def mutate(self, message, **kwargs):
        keep = True
        return message, keep


class CheckedSentenceAsLabel(AbstractMutator):
    """
    Uses the checked sentence (knowledge) as label.

    The aux kwarg ``checked_sentence_kword`` names the message key that holds
    the checked sentence. The original labels are moved to ``dialogue_response``.
    """
    __name__ = 'wow_checked_sentence_as_label'

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        """
        :param message: message to relabel (not modified; a copy is returned).
        :param kwargs: aux kwargs; must include ``checked_sentence_kword``, one of
            ``"checked_sentence"`` or ``CONST.SELECTED_SENTENCES``.
        :return: ``(message, True)``. Messages lacking text or labels are
            returned unchanged.
        """
        assert "checked_sentence_kword" in kwargs, 'key named "checked_sentence_kword" is required'
        checked_key = kwargs["checked_sentence_kword"]
        assert checked_key in ["checked_sentence", CONST.SELECTED_SENTENCES], (
            f"inappropriate checked_sentence_kword: {checked_key!r}"
        )

        if 'text' not in message or 'labels' not in message or not message['labels']:
            # Nothing to relabel. Still return a (message, include) pair, since
            # AbstractMutator.__call__ unpacks mutate()'s result.
            return message, True

        new_message = message.copy()
        labels = new_message.pop('labels')
        # Look up the key selected by the kwarg.
        checked_sentence = new_message.get(checked_key, '')
        if isinstance(checked_sentence, list):
            checked_sentence = ' '.join(checked_sentence)

        new_message['dialogue_response'] = labels
        new_message['labels'] = [checked_sentence]
        return new_message, True


class PromptKnowledgeMutator(AbstractMutator):
    """
    Add the ``CONST.GENERATE_KNOWLEDGE`` prompt to the end of the context, to
    inform the model.

    Assumes flattened data.
    """
    __name__ = "prompt_knowledge_mutator"

    # Declared on the class, like the other Prompt* mutators, so it can be read
    # or overridden without constructing an instance.
    PROMPT = CONST.GENERATE_KNOWLEDGE

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        """Append ``PROMPT`` to ``message['text']`` unless it already ends with it."""
        if not message['text'].endswith(self.PROMPT):
            message['text'] = f"{message['text']} {self.PROMPT}"
        return message, True


class AddSelectedSentencesMutator(AbstractMutator):
    """
    Populate the selected-sentences key of each message.

    A ``checked_sentence`` (when present) always wins. Otherwise an existing
    selected-sentences value is kept, or an empty list is added.
    """
    __name__ = "add_selected_sentences_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        if 'checked_sentence' in message:
            message[CONST.SELECTED_SENTENCES] = [message['checked_sentence']]
            return message, True
        if CONST.SELECTED_SENTENCES not in message:
            message[CONST.SELECTED_SENTENCES] = []
        return message, True


class StyleGenToVRMMutator(AbstractMutator):
    """
    Converts style gen tasks to grounded dialogue tasks by appending the
    message's ``personality`` to the text between style tokens.

    assumes flattened data
    """
    __name__ = "style_gen_to_grm"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        context = message['text']
        persona_style = message['personality']
        assert CONST.BEGIN_STYLE not in context
        style_block = f"{CONST.BEGIN_STYLE} {persona_style} {CONST.END_STYLE}"
        message['text'] = f"{context}\n{style_block}"
        return message, True


class ExtractEntity(AbstractMutator):
    """
    Picks an entity in the context and uses it as the intended target.

    The target is the longest entity found both in the context (all lines but
    the last) and in the first label. Those context lines become the retrieved
    docs, and the original labels are kept under ``response_labels``. Messages
    with no shared entity are dropped.
    """
    __name__ = "extract_entity_for_knowledge_model"

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        new_message = message.copy()
        context_lines = message['text'].split('\n')[:-1]
        context_ent = extract_entities('\n'.join(context_lines))
        label_ent = extract_entities(message['labels'][0])
        shared_ents = set(context_ent).intersection(label_ent)
        if not shared_ents:
            return new_message, False
        # Break length ties on the entity text. Set iteration order for str
        # varies with hash randomization, so max(key=len) alone is not
        # reproducible across runs. Assumes entities are str.
        longest_ent = max(shared_ents, key=lambda ent: (len(ent), ent))
        new_message['response_labels'] = message['labels']
        new_message['labels'] = [longest_ent]
        new_message[CONST.SELECTED_SENTENCES] = [longest_ent]
        new_message[CONST.RETRIEVED_DOCS] = context_lines
        # Distinct placeholder lists, so an in-place edit of one does not alter the other.
        new_message[CONST.RETRIEVED_DOCS_URLS] = [''] * len(context_lines)
        new_message[CONST.RETRIEVED_DOCS_TITLES] = [''] * len(context_lines)
        return new_message, True


class ExtractEntityResponse(AbstractMutator):
    """
    Append an entity shared by the context and the label to the context, for
    the response model.

    The entity is the longest one found both in the context (all lines but the
    last) and in the first label. It is wrapped in ``BEGIN_KNOWLEDGE`` /
    ``END_KNOWLEDGE`` and added as a new last line of the text; labels are
    unchanged. Messages with no shared entity are dropped.
    """
    __name__ = "extract_entity_for_response_model"

    BEGIN_KNOWLEDGE: str = CONST.KNOWLEDGE_TOKEN
    END_KNOWLEDGE: str = CONST.END_KNOWLEDGE_TOKEN

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        new_message = message.copy()
        context_except_last_line = '\n'.join(message['text'].split('\n')[:-1])
        context_ent = extract_entities(context_except_last_line)
        label_ent = extract_entities(message['labels'][0])
        shared_ents = set(context_ent).intersection(label_ent)
        if not shared_ents:
            return new_message, False
        # Break length ties on the entity text. Set iteration order for str
        # varies with hash randomization, so max(key=len) alone is not
        # reproducible across runs. Assumes entities are str.
        longest_ent = max(shared_ents, key=lambda ent: (len(ent), ent))
        knol_text = f'{self.BEGIN_KNOWLEDGE} {longest_ent} {self.END_KNOWLEDGE}'
        new_message['text'] = message['text'] + '\n' + knol_text
        return new_message, True



class ExtractEntityResponseBB3(ExtractEntityResponse):
    """
    ExtractEntityResponse variant that wraps the entity in entity tokens
    (``CONST.BEGIN_ENTITY`` / ``CONST.END_ENTITY``) instead of knowledge tokens.
    """
    __name__ = "extract_entity_for_response_model_bb3"

    BEGIN_KNOWLEDGE: str = CONST.BEGIN_ENTITY
    END_KNOWLEDGE: str = CONST.END_ENTITY


class PromptSearchQueryMutator(AbstractMutator):
    """
    Add the ``PROMPT`` token (``__generate-query__`` here) to the end of the
    context, to inform the model. Subclasses override ``PROMPT``.

    Assumes flattened data.
    """
    __name__ = "prompt_search_query_mutator"

    GENERATE_QUERY = '__generate-query__'
    PROMPT = GENERATE_QUERY

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        context = message['text']
        if context.endswith(self.PROMPT):
            return message, True
        message['text'] = f"{context} {self.PROMPT}"
        return message, True


class PromptEntityExtractionMutator(PromptSearchQueryMutator):
    """
    Append the ``CONST.EXTRACT_ENTITY`` prompt to the end of the context, to
    inform the model.

    Assumes flattened data.
    """
    PROMPT = CONST.EXTRACT_ENTITY
    __name__ = "prompt_extract_entity_mutator"


class PromptAccessMemoryMutator(PromptSearchQueryMutator):
    """
    Append the ``CONST.ACCESS_MEMORY`` prompt to the end of the context, to
    inform the model.

    Assumes flattened data.
    """
    PROMPT = CONST.ACCESS_MEMORY
    __name__ = "prompt_access_memory_mutator"


class PromptMemoryMutator(PromptSearchQueryMutator):
    """
    Append the ``CONST.GENERATE_MEMORY`` prompt to the end of the context, to
    inform the model.

    Assumes flattened data.
    """
    PROMPT = CONST.GENERATE_MEMORY
    __name__ = "prompt_memory_mutator"


class PromptMemoryDecisionMutator(PromptSearchQueryMutator):
    """
    Append the ``CONST.IS_MEMORY_REQUIRED`` prompt to the end of the context,
    to inform the model.

    Assumes flattened data.
    """
    PROMPT = CONST.IS_MEMORY_REQUIRED
    __name__ = "prompt_memory_decision_mutator"


class AddRetrievedDocumentsMutator(AbstractMutator):
    """
    Add retrieved docs and relevant keys.

    Keys already present on the message are left untouched. Missing ones are
    filled as follows:
    - retrieved docs: every context line except the last;
    - selected sentences: a copy of the labels;
    - doc URLs / titles: one empty string per retrieved doc on the message.
    """
    __name__ = "add_retrieved_documents_mutator"

    def mutate(self, message: Dict, **kwargs) -> Tuple[Dict, bool]:
        if CONST.RETRIEVED_DOCS not in message:
            message[CONST.RETRIEVED_DOCS] = message['text'].split('\n')[:-1]
        if CONST.SELECTED_SENTENCES not in message:
            # copy, so later in-place edits of the labels don't leak into it
            message[CONST.SELECTED_SENTENCES] = list(message['labels'])
        # Size placeholders by the docs actually stored (they may come from upstream).
        num_docs = len(message[CONST.RETRIEVED_DOCS])
        if CONST.RETRIEVED_DOCS_URLS not in message:
            message[CONST.RETRIEVED_DOCS_URLS] = [''] * num_docs
        if CONST.RETRIEVED_DOCS_TITLES not in message:
            message[CONST.RETRIEVED_DOCS_TITLES] = [''] * num_docs
        return message, True


class EnsureSameNumberDocsAndTitlesMutator(AbstractMutator):
    """
    Make the doc URL and doc title lists the same length as the retrieved docs.

    Short lists are padded with empty strings; long lists are truncated. Lists
    that already match are left as the same object.
    """
    __name__ = "ensure_same_number_docs_and_titles_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        n_docs = len(message[CONST.RETRIEVED_DOCS])
        for field in (CONST.RETRIEVED_DOCS_URLS, CONST.RETRIEVED_DOCS_TITLES):
            values = message[field]
            shortfall = n_docs - len(values)
            if shortfall > 0:
                message[field] = values + [''] * shortfall
            elif shortfall < 0:
                message[field] = values[:n_docs]
        return message, True


class FixMkmFormattingMutator(AbstractMutator):
    """
    Ad-hoc clean-up for the utterance-overlap Mkm teachers.

    With ConvAI2 and BST, the utterance-overlap mutators found persona sentences
    in the context, but those examples were not formatted correctly by
    ``convert_overlap_to_personas_as_docs``. This mutator fixes them by:

    - appending ``CONST.ACCESS_MEMORY`` to the text unless it already ends with it;
    - removing space-prefixed ``CONST.YOU`` / ``CONST.THEM`` tokens from the
      first label (only that label is kept when a token is present).
    """
    __name__ = "fix_mkm_formatting_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        context = message['text']
        if not context.endswith(CONST.ACCESS_MEMORY):
            message['text'] = f"{context} {CONST.ACCESS_MEMORY}"
        for speaker_tok in (CONST.YOU, CONST.THEM):
            first_label = message['labels'][0]
            if speaker_tok not in first_label:
                continue
            cleaned = first_label.replace(f' {speaker_tok}', '')
            message['labels'] = [cleaned]
            assert speaker_tok not in cleaned
        return message, True


class NormalizeReplyMutator(AbstractMutator):
    """
    Uses string normalization over text, labels, retrieved docs and the selected sentence.

    Normalizes, in order:

    1. the context text. ``your persona: `` / ``partner's persona: `` lines are
       normalized with ``_normalize_persona_line`` and moved to the front (your
       personas first, then partner personas). Every other line goes through
       ``normalize_reply``, except that on the final line a trailing
       ``CONST.IS_MEMORY_REQUIRED`` / ``CONST.ACCESS_MEMORY`` prompt, or a
       ``CONST.BEGIN_MEMORY`` ... ``CONST.END_MEMORY`` wrapper, is preserved
       around the normalized content;
    2. the first label (only that label is kept);
    3. the retrieved documents, if present;
    4. the first selected sentence, if present (only that sentence is kept).
    """
    __name__ = "normalize_reply_mutator"

    @staticmethod
    def _normalize_final_line(line: str) -> str:
        """Normalize the last context line while keeping its prompt/memory tokens intact."""
        for prompt in (CONST.IS_MEMORY_REQUIRED, CONST.ACCESS_MEMORY):
            if prompt in line:
                content = normalize_reply(line.replace(f" {prompt}", ''))
                return f"{content} {prompt}"
        if CONST.BEGIN_MEMORY in line:
            content = line.replace(f'{CONST.BEGIN_MEMORY} ', '').replace(
                f' {CONST.END_MEMORY}', ''
            )
            # Memory lines may carry a persona sentence, which has its own normalizer.
            if 'persona:' in content:
                content = _normalize_persona_line(content)
            else:
                content = normalize_reply(content)
            return f"{CONST.BEGIN_MEMORY} {content} {CONST.END_MEMORY}"
        return normalize_reply(line)

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        # 1. text
        lines = message['text'].split('\n')
        last_idx = len(lines) - 1
        your_personas: List[str] = []
        partner_personas: List[str] = []
        non_personas: List[str] = []
        for i, line in enumerate(lines):
            if line.startswith('your persona: '):
                your_personas.append(_normalize_persona_line(line))
            elif line.startswith("partner's persona: "):
                partner_personas.append(_normalize_persona_line(line))
            elif i == last_idx:
                non_personas.append(self._normalize_final_line(line))
            else:
                non_personas.append(normalize_reply(line))
        message['text'] = '\n'.join(your_personas + partner_personas + non_personas)

        # 2. label: persona labels use the persona normalizer; the memory-decision
        # labels are fixed tokens and are left untouched.
        label = message['labels'][0]
        if 'persona:' in label:
            label = _normalize_persona_line(label)
        elif label not in (CONST.DONT_ACCESS_MEMORY, CONST.DO_ACCESS_MEMORY):
            label = normalize_reply(label)
        message['labels'] = [label]

        # 3. retrieved docs
        if CONST.RETRIEVED_DOCS in message:
            message[CONST.RETRIEVED_DOCS] = [
                _normalize_persona_line(doc) for doc in message[CONST.RETRIEVED_DOCS]
            ]

        # 4. best sentence: an empty selection is left as-is instead of raising IndexError.
        selected = message.get(CONST.SELECTED_SENTENCES)
        if selected:
            message[CONST.SELECTED_SENTENCES] = [_normalize_persona_line(selected[0])]

        return message, True


class MatchTargetToPersonaMutator(AbstractMutator):
    """
    For ad-hoc conversion; matches the knowledge label to a "your" or "partner's"
    prefix.

    Replaces ``labels`` with the first entry of ``CONST.RETRIEVED_DOCS`` that
    contains the current first label as a substring.
    """
    __name__ = "match_target_to_persona_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        """
        :raises ValueError: if no retrieved document contains the label, or the
            first matching document is empty.
        """
        personas = message[CONST.RETRIEVED_DOCS]
        target = message['labels'][0]
        match = next((p for p in personas if target in p), None)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently relabel with the wrong persona.
        if not match:
            raise ValueError(
                f"Label {target!r} is not contained in any retrieved persona"
            )
        message['labels'] = [match]
        return message, True


class ConvertMkmToMrmMutator(AbstractMutator):
    """
    Turn an mkm (memory-knowledge) example into an mrm (memory-response) example.

    The knowledge label is moved into the context as a final
    ``CONST.BEGIN_MEMORY ... CONST.END_MEMORY`` line (after dropping the
    `` CONST.ACCESS_MEMORY`` prompt). The labels become ``message['old_target']``,
    wrapped in a list if it is not one already.
    """
    __name__ = "convert_mkm_to_mrm_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        knowledge = message['labels'][0]
        context = message['text'].replace(f' {CONST.ACCESS_MEMORY}', '')
        message['text'] = f"{context}\n{CONST.BEGIN_MEMORY} {knowledge} {CONST.END_MEMORY}"

        previous = message['old_target']
        message['labels'] = previous if isinstance(previous, list) else [previous]
        return message, True


class FilterSilenceOnlyMemoryDecisionMutator(AbstractMutator):
    """
    Flag memory-decision episodes whose dialogue is only silence.

    When the text has exactly two lines and the second is ``__silence__``
    (case-insensitive), returns ``(True, False)``: the example itself is
    included, but the episode is marked as not kept. Otherwise returns
    ``(True, True)``.
    """
    __name__ = "filter_silence_only_memory_decision_mutator"

    def mutate(self, message: Dict, **kwargs) -> (Dict, Tuple):
        lines = message['text'].split('\n')
        silence_only = len(lines) == 2 and lines[1].lower() == '__silence__'
        return message, (True, not silence_only)


class PrefixSpeakersMutator(AbstractMutator):
    """
    Add control tokens to the speakers, within the context.

    ``message['text']`` is split (see ``get_context_and_utterances``) into leading
    context lines, dialogue utterances, and trailing post-utterance lines.

    - Utterances get ``SELF_PREFIX`` / ``PARTNER_PREFIX`` alternating backwards
      from the final utterance, which gets ``PARTNER_PREFIX``. Utterances equal to
      ``CONST.DUMMY_TEXT`` (case-insensitive) or starting with an entity, memory
      or knowledge token are left unprefixed.
    - Context lines without ``'persona:'`` that do not start with ``'_'`` get
      ``PARTNER_PREFIX``.
    - Post-utterance lines are only stripped of surrounding spaces.
    """
    __name__ = 'prefix_speakers'
    SELF_PREFIX = CONST.YOU_PREFIX
    PARTNER_PREFIX = CONST.PARTNER_PREFIX

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        """Rewrite ``message['text']`` with speaker prefixes; always includes the example."""
        context, utterances, post_utterances = self.get_context_and_utterances(message)
        # Strip only spaces (not tabs/newlines) from both ends of every line.
        context = [c.lstrip(' ').rstrip(' ') for c in context]
        utterances = [u.lstrip(' ').rstrip(' ') for u in utterances]
        post_utterances = [p.lstrip(' ').rstrip(' ') for p in post_utterances]

        new_utts: List[str] = []
        # Walk utterances from last to first: offset 0 (the final utterance) gets
        # PARTNER_PREFIX, odd offsets get SELF_PREFIX. Dummy lines and
        # entity/memory/knowledge lines still consume an offset but stay unprefixed.
        for i, utt in enumerate(reversed(utterances)):
            if utt.lower() != CONST.DUMMY_TEXT.lower() and not any(
                utt.startswith(tok)
                for tok in [
                    CONST.BEGIN_ENTITY,
                    CONST.BEGIN_MEMORY,
                    CONST.KNOWLEDGE_TOKEN,
                ]
            ):
                prefix = self.SELF_PREFIX if i % 2 == 1 else self.PARTNER_PREFIX
                utt = f"{prefix}{utt.lstrip(' ').rstrip(' ')}"
            new_utts.append(utt)
        for i in range(len(context)):
            # LIGHT context starts with underscores
            if 'persona:' not in context[i] and not context[i].startswith('_'):
                context[i] = f"{self.PARTNER_PREFIX}{context[i]}"
        # new_utts was built in reverse order; restore chronological order.
        message['text'] = '\n'.join(context + list(reversed(new_utts)) + post_utterances
        )
        return message, True

    def get_context_and_utterances(
        self, message: Dict
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Split ``message['text']`` into ``(context, utterances, post_utterances)``.

        ``post_utterances`` collects trailing lines that must not be speaker-prefixed:
        the last line for ids containing 'Dialogue' but not 'Vanilla', for
        'msc_dialogue' ids whose text contains a memory/entity/knowledge token, and
        (additionally) for 'Funpedia' ids. ``context_end``, chosen from the message
        id, is the number of leading lines treated as context; the remaining lines
        are utterances. Unrecognized ids get ``context_end == 0`` (no context).
        """
        context, utterances, post_utterances = [], [], []
        texts = message['text'].split('\n')
        m_id = message.get('id', '')
        # 1. Handle WizInt-style examples whose last line is a knowledge/memory/
        #    entity line: keep it aside as a post-utterance.
        context_end = 0
        if ('Dialogue' in m_id and 'Vanilla' not in m_id) or (
            'msc_dialogue' in m_id
            and any(
                tok in message['text']
                for tok in [
                    CONST.BEGIN_MEMORY,
                    CONST.BEGIN_ENTITY,
                    CONST.KNOWLEDGE_TOKEN,
                ]
            )
        ):
            post_utterances = [texts[-1]]
            texts = texts[:-1]
        # 2. Pick the number of leading context lines per teacher (independent of step 1).
        if 'Light' in m_id:
            context_end = 5
        elif 'Funpedia' in m_id:
            # NOTE(review): if step 1 already moved a line, this appends the new
            # last line *after* it, reversing their original order — confirm
            # Funpedia ids never contain 'Dialogue'.
            post_utterances.append(texts[-1])
            texts = texts[:-1]
        elif any(
            m_id.endswith(f'Woi{t}Teacher')
            for t in [
                'SearchDialogue',
                'VanillaDialogue',
                'SearchKnowledge',
                'SearchQuery',
            ]
        ):
            context_end = 2
        # NOTE(review): texts[0] raises IndexError if step 1 emptied a single-line text.
        elif m_id.endswith('MemoryDecisionTeacher') or (
            'msc' in m_id and texts[0].startswith('persona:')
        ):
            context_end = 1
        elif any(
            m_id.endswith(f'Wow{t}Teacher')
            for t in ['SearchDialogue', 'VanillaDialogue', 'SearchKnowledge']
        ):
            context_end = 1
        elif any(
            m_id.endswith(f'BST{t}Teacher')
            for t in [
                'EntityDialogue',
                'VanillaDialogue',
                'StyleGroundingDialogue',
                'EntityKnowledge',
                'MemoryKnowledge',
                'MemoryKnowledgeUttOverlap',
            ]
        ):
            # BST messages are expected to carry 'context_dataset' (KeyError otherwise).
            ctxt_dataset = message['context_dataset']
            if (
                any(tok in m_id for tok in ['MemoryKnowledge', 'MemoryDialogue'])
                and ctxt_dataset == 'wizard_of_wikipedia'
            ):
                context_end = 1
            elif any(
                t in m_id
                for t in [
                    'Knowledge',
                    'Dialogue',
                    'VanillaDialogue',
                    'StyleGroundingDialogue',
                ]
            ) and any('persona:' in t for t in texts):
                # Context ends at the first non-persona line (the second one for
                # wizard_of_wikipedia context), or covers all lines if there is none.
                non_context = [i for i, t in enumerate(texts) if 'persona:' not in t]
                end_idx = 0
                if ctxt_dataset == 'wizard_of_wikipedia':
                    end_idx = 1
                context_end = (
                    non_context[end_idx] if len(non_context) > end_idx else len(texts)
                )
        elif any(
            m_id.endswith(f'Convai2{t}Teacher')
            for t in [
                'EntityDialogue',
                'EntityKnowledge',
                'VanillaDialogue',
                'StyleGroundingDialogue',
            ]
        ):
            # Context is the leading run of persona lines.
            non_context = [i for i, t in enumerate(texts) if 'persona:' not in t]
            context_end = non_context[0] if non_context else len(texts)
            if 'StyleGrounding' in m_id:
                # this has the topic as well
                context_end += 1

        context = texts[:context_end]
        utterances = texts[context_end:] if len(texts) > context_end else []
        return context, utterances, post_utterances


class PrefixSpeakersOPTMutator(PrefixSpeakersMutator):
    """
    ``PrefixSpeakersMutator`` variant that uses the ``PROMPT`` speaker prefixes
    (``f"{PROMPT.SELF_PREFIX}: "`` / ``f"{PROMPT.PARTNER_PREFIX}: "``).

    For dialogue examples (id contains 'dialogue', case-insensitive), it also:

    - removes a leading self prefix from the first label;
    - appends a final ``f"{PROMPT.SELF_PREFIX}: "`` line to the context, unless
      the context already ends with it.

    Both steps are skipped for NQOpen ids and for msc ids whose text starts
    with 'persona:', contains ``CONST.EXTRACT_ENTITY``, or ends with the
    access-memory, search-decision or memory-decision prompt tokens.
    """
    __name__ = 'prefix_speakers_opt'
    SELF_PREFIX = f"{PROMPT.SELF_PREFIX}: "
    PARTNER_PREFIX = f"{PROMPT.PARTNER_PREFIX}: "

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        message, keep = super().mutate(message, **kwargs)
        m_id = message.get('id', '')
        text = message['text']
        is_msc_prompt_task = 'msc' in m_id and (
            text.startswith('persona:')
            or CONST.EXTRACT_ENTITY in text
            or text.endswith(CONST.ACCESS_MEMORY)
            or text.endswith(CONST.IS_SEARCH_REQUIRED)
            or text.endswith(CONST.IS_MEMORY_REQUIRED)
        )
        if 'dialogue' in m_id.lower() and 'NQOpen' not in m_id and not is_msc_prompt_task:
            self_prefix = f"{PROMPT.SELF_PREFIX}: "
            label = message['labels'][0]
            # Remove only the leading prefix; later occurrences inside the
            # response text are content and must be preserved.
            if label.startswith(self_prefix):
                message['labels'] = [label[len(self_prefix):]]
            # Compare the whole text (not its first character, which can never end
            # with a multi-character prefix) so the prefix line is added at most once.
            if not text.endswith(self_prefix):
                message['text'] = f"{text}\n{self_prefix}"
        # Return the parent's flag unchanged: AbstractMutator.__call__ accepts only
        # a bool or an (include, keep) tuple, and ``True * keep`` yields an int.
        return message, keep


class FormatResponseTasksForDecoderOnlyMutator(AbstractMutator):
    """
    Replace the knowledge special tokens of a response example with text prefixes.

    The second-to-last context line must be wrapped in entity, memory or
    knowledge tokens. That wrapper is replaced by
    ``PROMPT.CONTEXTUAL_KNOWLEDGE_PREFIX``, ``PROMPT.MEMORY_KNOWLEDGE_PREFIX`` or
    ``PROMPT.SEARCH_KNOWLEDGE_PREFIX`` respectively. Inside that line,
    ``your persona: `` / ``partner's persona: `` become
    ``PROMPT.SELF_MEMORY_PREFIX`` / ``PROMPT.PARTNER_MEMORY_PREFIX``.
    Lines containing ``CONST.DUMMY_TEXT`` (case-insensitive) are dropped. For ids
    without 'DialogueFrom', plain ``your persona: `` lines are dropped as well.

    Example (speaker prefixes such as "Person 1:" come from a separate mutator):
        Before:
            My favorite game is WWE.
            __knowledge__ Play WWE Games __endknowledge__
        After:
            My favorite game is WWE.
            Interesting Fact: Play WWE Games
    """
    __name__ = 'format_response_tasks_for_decoder_only'

    def mutate(self, message: Dict, **kwargs) -> (Dict, bool):
        raw_text = message['text']
        message['labels'][0]  # fail fast (IndexError/KeyError) when there is no label
        assert any(
            marker in raw_text
            for marker in (
                CONST.BEGIN_ENTITY,
                CONST.BEGIN_MEMORY,
                CONST.KNOWLEDGE_TOKEN,
            )
        )
        dummy = CONST.DUMMY_TEXT.lower()
        lines = [ln for ln in raw_text.split('\n') if dummy not in ln.lower()]
        if 'DialogueFrom' not in message['id']:
            # Keep memory lines even when they mention a persona.
            lines = [
                ln
                for ln in lines
                if 'your persona: ' not in ln or CONST.BEGIN_MEMORY in ln
            ]

        fact = lines[-2]
        if CONST.BEGIN_ENTITY in fact:
            open_tok, close_tok = CONST.BEGIN_ENTITY, CONST.END_ENTITY
            label_prefix = PROMPT.CONTEXTUAL_KNOWLEDGE_PREFIX
        elif CONST.BEGIN_MEMORY in fact:
            open_tok, close_tok = CONST.BEGIN_MEMORY, CONST.END_MEMORY
            label_prefix = PROMPT.MEMORY_KNOWLEDGE_PREFIX
        else:
            assert CONST.KNOWLEDGE_TOKEN in fact
            open_tok, close_tok = CONST.KNOWLEDGE_TOKEN, CONST.END_KNOWLEDGE_TOKEN
            label_prefix = PROMPT.SEARCH_KNOWLEDGE_PREFIX

        fact = (
            fact.replace(open_tok, f"{label_prefix}:")
            .replace(f" {close_tok}", '')
            .replace('your persona: ', f"{PROMPT.SELF_MEMORY_PREFIX}: ")
            .replace("partner's persona: ", f"{PROMPT.PARTNER_MEMORY_PREFIX}: ")
        )
        message['text'] = '\n'.join([*lines[:-2], fact, lines[-1]])
        return message, True