import asyncio
import json
import math
import os
import random
import re
import string
import sys as _sys
from collections import Counter
from enum import Enum
from typing import Dict, List
from typing import List, Tuple, Optional, Dict, Any, Set, Union

import contractions
import torch
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize

from hexa.utils.message import Message
from hexa.utils.metrics import PPLMetric
from hexa.utils.constants import Constant
import hexa.utils.prompts as PROMPT
from hexa.utils import logging

CONST = Constant()

STEMMER = PorterStemmer()

# Very limited number of stopwords:
# We don't remove "you", "me", etc. because of their strong importance in memory records.
_STOP_WORDS = set(
    list(string.punctuation)
    + ['a', 'an', 'the', 'am', 'are', 'is', 'as']
    + [
        STEMMER.stem(adv)
        for adv in (
            'really',
            'actually',
            'up',
            'so',
            'just',
            'now',
            'how',
            'then',
            'also',
            'very',
            'too',
            'most',
            'but',
        )
    ]
)

# Personal pronouns that are additionally filtered out by
# `normal_tokenizer(..., include_pronouns=True)`, e.g. for the memory/text overlap
# check in `MemoryUtils.get_available_memories`, where only content words matter.
_PRONOUNS = (
    'i', 'me', 'my', 'mine', 'myself',
    'you', 'your', 'yours', 'yourself', 'yourselves',
    'he', 'him', 'his', 'himself',
    'she', 'her', 'hers', 'herself',
    'it', 'its', 'itself',
    'we', 'us', 'our', 'ours', 'ourselves',
    'they', 'them', 'their', 'theirs', 'themselves',
)

# Referenced by `normal_tokenizer` when `include_pronouns=True` (it was previously
# undefined, so that path raised NameError). Both raw and stemmed pronoun forms are
# included because tokens are stemmed (when `stem=True`) before stop-word filtering,
# and the stemmer may rewrite some pronouns (e.g. "themselves").
_STOP_WORDS_WITH_PRONOUNS = (
    _STOP_WORDS | set(_PRONOUNS) | {STEMMER.stem(p) for p in _PRONOUNS}
)


# parlai > utils > strings
def colorize(text, style):
    """
    Wrap ``text`` in ANSI color escape codes for ``style``.

    Colors are emitted only when running under IPython or when stdout is a TTY;
    otherwise ``text`` is returned untouched. The styles ``red``, ``yellow``,
    ``green``, ``blue`` and ``brightblack`` are fixed colors. Every other style is
    resolved against the palette named by the ``PARLAI_COLORSTYLE`` environment
    variable: ``steamroller`` (also used when the variable is unset) or
    ``spermwhale`` (case-insensitive). Unknown style names fall back to magenta,
    and an unknown palette name leaves the text uncolored.
    """
    try:
        __IPYTHON__  # noqa: F821 -- only defined when running under IPython
    except NameError:
        if not _sys.stdout.isatty():
            return text

    palette_name = os.environ.get('PARLAI_COLORSTYLE')
    reset = '\033[0;0m'

    fixed_colors = (
        ('red', '\033[0;31m'),
        ('yellow', '\033[0;93m'),
        ('green', '\033[0;32m'),
        ('blue', '\033[0;34m'),
        ('brightblack', '\033[0;90m'),
    )
    for name, code in fixed_colors:
        if style == name:
            return code + text + reset

    if palette_name is None or palette_name.lower() == 'steamroller':
        blue = '\033[1;94m'
        bold = '\033[1m'
        light = '\033[0m'
        magenta = '\033[0;95m'
        hl_red = '\033[1;31m'
        hl_blue = '\033[0;34m'
    elif palette_name.lower() == 'spermwhale':
        blue = '\033[1;94m'
        bold = '\033[1;37;40m'
        light = '\033[0;37;40m'
        magenta = '\033[0;95m'
        hl_red = '\033[1;37;41m'
        hl_blue = '\033[1;37;44m'
    else:
        # No colorstyle specified/found.
        return text

    # Map each semantic style to one of the palette's colors.
    palette_roles = (
        ('highlight', hl_red),
        ('highlight2', hl_blue),
        ('text', light),
        ('bold_text', bold),
        ('labels', blue),
        ('eval_labels', blue),
        ('label_candidates', light),
        ('id', light),
        ('text2', magenta),
        ('field', hl_blue),
    )
    for name, code in palette_roles:
        if style == name:
            return code + text + reset
    return magenta + text + reset


# parlai > agents > fid > fid.py > concat_enc_outs
def concat_enc_outs(
    embedding_size: int,
    bsz: int,
    encoder_hidden_states: torch.Tensor,
    attention_mask: torch.BoolTensor,
    right_padded: bool = True,
    pad_token_id: int = 0,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    """
    Concatenate encoder outputs (the "FiD" step).

    Each query/document pair is encoded independently, so for every batch item the
    non-masked positions of its ``n_docs`` encoder outputs are concatenated before
    being sent to the decoder.

    :param embedding_size:
        emb/hidden size of the encoder representations.
    :param bsz:
        batch size; ``encoder_hidden_states.size(0)`` must be a multiple of it
        (``n_docs = encoder_hidden_states.size(0) // bsz``).
    :param encoder_hidden_states:
        expected ``[bsz * n_docs, seqlen, embedding_size]`` output representations
        from the encoder.
    :param attention_mask:
        expected ``[bsz * n_docs, seqlen]`` encoder mask; non-zero/True marks the
        positions to keep. Boolean and 0/1 integer masks are both accepted.
    :param right_padded:
        whether the output is right padded (True) or left padded (False).
    :param pad_token_id:
        value used to fill the padded positions of the output representations.

    :return (new_out, new_mask):
        ``new_out`` of shape ``[bsz, max_len, embedding_size]`` and ``new_mask`` of
        shape ``[bsz, max_len]`` (same dtype as ``attention_mask``), where
        ``max_len`` is the longest concatenated length in the batch.

    :raises ValueError:
        if ``bsz`` is not positive or does not divide the number of encoded rows.
    """
    n_rows = encoder_hidden_states.size(0)
    if bsz <= 0 or n_rows % bsz != 0:
        raise ValueError(
            f'encoder_hidden_states has {n_rows} rows, which is not a multiple of '
            f'a positive bsz (got bsz={bsz}).'
        )
    n_docs = n_rows // bsz

    split_enc_out = encoder_hidden_states.split([n_docs] * bsz, dim=0)
    split_mask = attention_mask.split([n_docs] * bsz, dim=0)

    concat_outs: List[torch.Tensor] = []
    for enc_i, mask_i in zip(split_enc_out, split_mask):
        # `.bool()` makes 0/1 integer masks select positions instead of being used
        # as gather indices; `reshape` also works for non-contiguous masks.
        keep = mask_i.reshape(-1).bool()
        concat_outs.append(enc_i.reshape(-1, embedding_size)[keep])

    max_len = max(out_i.size(0) for out_i in concat_outs)
    # `new_full` / `new_zeros` keep the dtype and device of the source tensors.
    new_out = encoder_hidden_states.new_full(
        (bsz, max_len, embedding_size), pad_token_id
    )
    new_mask: torch.BoolTensor = attention_mask.new_zeros((bsz, max_len))  # type: ignore

    for i, out_i in enumerate(concat_outs):
        length_i = out_i.size(0)
        if right_padded:
            span = slice(0, length_i)
        else:
            span = slice(max_len - length_i, max_len)
        new_out[i, span] = out_i
        new_mask[i, span] = True

    return new_out, new_mask

def normal_tokenizer(
    text: str,
    remove_contractions: bool = True,
    stem: bool = True,
    include_pronouns: bool = False,
) -> List[str]:
    """
    Return normalized tokens for `text`.

    The text is stripped and lower-cased. Contractions are optionally expanded with
    `contractions.fix`, and the text is tokenized with NLTK's `word_tokenize`.
    Tokens are optionally stemmed with the module-level `STEMMER`, and stop words
    are dropped.

    :param text: input text; must be a `str`.
    :param remove_contractions: expand contractions before tokenizing.
    :param stem: stem each token before stop-word filtering.
    :param include_pronouns: filter with `_STOP_WORDS_WITH_PRONOUNS` (presumably
        stop words plus pronouns) instead of `_STOP_WORDS`.
    :return: remaining tokens in their original order; `[]` for blank text.
    :raises TypeError: if `text` is not a `str`.
    """
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if not isinstance(text, str):
        raise TypeError('The only valid arg type is str.')
    text = text.strip().lower()
    if not text:
        return []

    if remove_contractions:
        text = contractions.fix(text)

    tokens = word_tokenize(text)

    if stem:
        tokens = [STEMMER.stem(t) for t in tokens]
    stop_words = _STOP_WORDS_WITH_PRONOUNS if include_pronouns else _STOP_WORDS
    return [t for t in tokens if t not in stop_words]


def no_overlap(
    data_source: List[str],
    canidate_text: str,
    non_overlap_threshold: int = 1,
    remove_contractions: bool = True,
) -> bool:
    """
    Return True if `canidate_text` brings enough new tokens relative to every text
    in `data_source`.

    Both sides are normalized with `normal_tokenizer`. The candidate passes only if,
    for each text in `data_source`, at least `non_overlap_threshold` of the
    candidate's tokens do NOT appear in that text. We use this for de-duplicating
    (for example memory entries): a candidate whose tokens are already contained in
    an existing entry is rejected.

    :param data_source: existing texts (e.g. memories already on record).
    :param canidate_text: the new text being considered (name kept for callers).
    :param non_overlap_threshold: minimum number of new candidate tokens required
        with respect to each existing text.
    :param remove_contractions: forwarded to `normal_tokenizer`.
    :return: False as soon as one existing text covers too much of the candidate,
        True otherwise (including when `data_source` is empty).
    """
    candidate_tokens = set(
        normal_tokenizer(canidate_text, remove_contractions=remove_contractions)
    )
    for dst in data_source:
        dst_tokens = set(
            normal_tokenizer(dst, remove_contractions=remove_contractions)
        )
        # Count the candidate's tokens that are missing from `dst`, as documented.
        # (`dst_tokens - candidate_tokens` would instead reject a candidate that
        # extends an existing entry and keep one that only repeats part of it.)
        n_non_overlap = len(candidate_tokens - dst_tokens)
        logging.debug(
            f'"{dst}" AND "{canidate_text}" ---- non-overlapping tokens: {n_non_overlap}'
        )
        if n_non_overlap < non_overlap_threshold:
            logging.info(
                f'"{canidate_text}" has too much overlap with "{dst}": discarding repeated entry.'
            )
            return False

    return True



def clean_text(text: str, special_tokens: Optional[List[str]] = None) -> str:
    """
    Remove all special tokens from an incoming text.

    For each token, occurrences preceded by a space are removed first, then
    occurrences followed by a space, then bare occurrences. This way a token sitting
    between two words does not leave a double space behind.

    :param text: text to clean.
    :param special_tokens: tokens to strip; defaults to `CONST.ALL_SPECIAL_TOKENS`.
        Empty tokens are ignored (removing " " would otherwise delete every space).
    :return: the cleaned text.
    """
    tokens = CONST.ALL_SPECIAL_TOKENS if special_tokens is None else special_tokens
    for token in tokens:
        if not token:
            continue
        for variant in (f" {token}", f"{token} ", token):
            text = text.replace(variant, '')

    return text


class Decision(Enum):
    """
    Policy values for a decision step; each member's value is its lowercase name.
    """

    # NOTE(review): the consuming logic is not in this file; the names suggest
    # "always take the action", "never take it" and "let a model decide".
    ALWAYS = 'always'
    NEVER = 'never'
    COMPUTE = 'compute'


class Module(Enum):
    """
    Sub-modules of the modular chatbot.

    Each member's value is a short tag (e.g. ``'sdm'``). ``tag_to_agent`` maps the
    tags to the snake_case names used for message fields, agent names and opt keys.
    Prompt helpers read the module-level ``CONST`` and ``PROMPT`` objects.
    """

    SEARCH_DECISION = 'sdm'
    MEMORY_DECISION = 'mdm'
    SEARCH_QUERY = 'sgm'
    MEMORY_GENERATOR = 'mgm'
    CONTEXTUAL_KNOWLEDGE = 'ckm'
    MEMORY_KNOWLEDGE = 'mkm'
    SEARCH_KNOWLEDGE = 'skm'
    CONTEXTUAL_DIALOGUE = 'crm'
    MEMORY_DIALOGUE = 'mrm'
    SEARCH_DIALOGUE = 'srm'
    VANILLA_DIALOGUE = 'vrm'
    GROUNDED_DIALOGUE = 'grm'
    OPENING_DIALOGUE = 'orm'

    @staticmethod
    def dialogue_modules() -> List['Module']:
        """Return the ``*_DIALOGUE`` modules."""
        return [
            Module.CONTEXTUAL_DIALOGUE,
            Module.MEMORY_DIALOGUE,
            Module.SEARCH_DIALOGUE,
            Module.VANILLA_DIALOGUE,
            Module.GROUNDED_DIALOGUE,
            Module.OPENING_DIALOGUE,
        ]

    @staticmethod
    def knowledge_modules() -> List['Module']:
        """Return the ``*_KNOWLEDGE`` modules."""
        return [
            Module.CONTEXTUAL_KNOWLEDGE,
            Module.MEMORY_KNOWLEDGE,
            Module.SEARCH_KNOWLEDGE,
        ]

    @staticmethod
    def decision_modules() -> List['Module']:
        """Return the ``*_DECISION`` modules."""
        return [
            Module.SEARCH_DECISION,
            Module.MEMORY_DECISION,
        ]

    def decision_do_key(self) -> str:
        """
        Key of the "do" reply for a decision module.

        Raises KeyError for modules other than SEARCH_DECISION / MEMORY_DECISION.
        """
        return {
            Module.SEARCH_DECISION: 'search_decision_do_search_reply',
            Module.MEMORY_DECISION: 'memory_decision_do_access_reply',
        }[self]

    def decision_dont_key(self) -> str:
        """
        Key of the "don't" reply for a decision module.

        Raises KeyError for modules other than SEARCH_DECISION / MEMORY_DECISION.
        """
        return {
            Module.SEARCH_DECISION: 'search_decision_dont_search_reply',
            Module.MEMORY_DECISION: 'memory_decision_dont_access_reply',
        }[self]

    def message_name(self) -> str:
        """
        Name used to access output of a module in a Message.
        """
        return self.tag_to_agent()[self.value]

    def agent_name(self):
        """
        Display name for user's, and debugging, sake.
        """
        return f"{self.tag_to_agent()[self.value]}_agent"

    def model_file_path_key(self):
        """
        Opt key for model file path for this agent.
        """
        return f"{self.tag_to_agent()[self.value]}_response_model_path"

    def tag(self):
        """Return the module's short tag (its enum value)."""
        return self.value

    def is_dialogue(self):
        """Whether this is one of ``dialogue_modules()``."""
        return self in Module.dialogue_modules()

    def is_knowledge(self):
        """Whether this is one of ``knowledge_modules()``."""
        return self in Module.knowledge_modules()

    def skip_search(self):
        """True unless this is the memory-knowledge or search-knowledge module."""
        return self.value not in ['mkm', 'skm']

    def is_one_turn_history(self):
        """True for the memory decision, search decision and memory generator modules."""
        return self.value in ['mdm', 'sdm', 'mgm']

    @staticmethod
    def tag_to_agent() -> Dict[str, str]:
        """Map each module tag to its snake_case agent/message name."""
        return {
            'sdm': 'search_decision',
            'mdm': 'memory_decision',
            'sgm': 'search_query',
            'mgm': 'memory_generator',
            'ckm': 'contextual_knowledge',
            'mkm': 'memory_knowledge',
            'skm': 'search_knowledge',
            'crm': 'contextual_dialogue',
            'mrm': 'memory_dialogue',
            'srm': 'search_dialogue',
            'vrm': 'vanilla_dialogue',
            'grm': 'grounded_dialogue',
            'orm': 'opening_dialogue',
        }

    ##############
    # R2C2 Func. #
    ##############
    def r2c2_prompt(self):
        """
        Prompt token for this module (empty string for dialogue modules).
        """
        return {
            'sdm': CONST.IS_SEARCH_REQUIRED,
            'mdm': CONST.IS_MEMORY_REQUIRED,
            'sgm': CONST.GENERATE_QUERY,
            'mgm': CONST.GENERATE_MEMORY,
            'mkm': CONST.ACCESS_MEMORY,
            'ckm': CONST.EXTRACT_ENTITY,
            'skm': CONST.GENERATE_KNOWLEDGE,
            'mrm': '',
            'crm': '',
            'srm': '',
            'vrm': '',
            'grm': '',
            'orm': '',
        }[self.value]

    def special_tokens(self):
        """
        Return the (begin, end) special tokens wrapping this module's output.

        Only defined for the knowledge modules; raises KeyError otherwise.
        """
        return {
            Module.CONTEXTUAL_KNOWLEDGE.message_name(): (
                CONST.BEGIN_ENTITY,
                CONST.END_ENTITY,
            ),
            Module.MEMORY_KNOWLEDGE.message_name(): (
                CONST.BEGIN_MEMORY,
                CONST.END_MEMORY,
            ),
            Module.SEARCH_KNOWLEDGE.message_name(): (
                CONST.TOKEN_KNOWLEDGE,
                CONST.TOKEN_END_KNOWLEDGE,
            ),
        }[self.message_name()]

    #############
    # OPT Func. #
    #############
    # The OPT helpers below use the module-level `PROMPT` (hexa.utils.prompts), the
    # same prompt module `MemoryUtils` uses. They previously imported
    # `projects.bb3.prompts` locally, which shadowed it and is not part of this package.
    def opt_prompt(self):
        """
        Prompt token for OPT models.
        """
        return {
            'sdm': "Person 2 must decide whether to search the internet.\n\n",
            'mdm': "A conversation between two persons. Person 2 must consult their notes about Person 1.\n\n",
            'sgm': "Person 2 must write a search query for a search engine.\n\n",
            'mgm': "A conversation between two persons. Person 2 writes a note about Person 1 to help remember information for later.\n\n",
            'ckm': "A conversation between two persons. Person 2 recalls a previous topic in the conversation.\n\n",
            'skm': "A conversation between two persons. Person 2 finds an interesting fact from the internet.\n\n",
            'mkm': "A conversation between two persons. Person 2 recalls an interesting fact about Person 1 or Person 2.\n\n",
            'crm': "A conversation between two persons. Person 2 would like to continue talking about a previous topic in the conversation.\n\n",
            'mrm': "A conversation between two persons. Person 2 would like to chat about an interesting fact about Person 1 or Person 2.\n\n",
            'srm': "A conversation between two persons. Person 2 would like to tell Person 1 about something Person 2 found on the internet.\n\n",
            'vrm': "A conversation between two persons.\n\n",
            'grm': "A conversation between two persons. Person 2 responds in a given style.\n\n",
            'orm': "A conversation between two persons. Person 2 begins the conversation given information about Person 1.\n\n",
        }[self.value]

    def opt_final_prefix(self):
        """
        Final prefix to put after constructing context for OPT.
        """
        return {
            'sdm': PROMPT.SEARCH_DECISION,
            'mdm': PROMPT.MEMORY_DECISION,
            'sgm': PROMPT.QUERY_GEN_PREFIX,
            'mgm': PROMPT.MEMORY_GEN_PREFIX,
            'mkm': PROMPT.MEMORY_KNOWLEDGE_PREFIX,
            'ckm': PROMPT.CONTEXTUAL_KNOWLEDGE_PREFIX,
            'skm': PROMPT.SEARCH_KNOWLEDGE_PREFIX,
            'mrm': PROMPT.SELF_PREFIX,
            'crm': PROMPT.SELF_PREFIX,
            'srm': PROMPT.SELF_PREFIX,
            'vrm': PROMPT.SELF_PREFIX,
            'grm': PROMPT.SELF_PREFIX,
            'orm': PROMPT.OPENING_PREFIX,
        }[self.value]

    def opt_shots(self) -> str:
        """Few-shot examples for this module, looked up in ``PROMPT.SHOTS``."""
        return PROMPT.SHOTS[self]

    def opt_pre_context_tok(self):
        """
        Token placed before the context: used by the memory/search knowledge
        modules and the grounded dialogue module, empty otherwise.
        """
        if (
            self.is_knowledge() and self is not Module.CONTEXTUAL_KNOWLEDGE
        ) or self is Module.GROUNDED_DIALOGUE:
            return PROMPT.PRE_CONTEXT_TOK
        return ''

    def opt_post_context_tok(self):
        """
        Token placed after the context: used by the contextual, memory and search
        dialogue modules, empty otherwise.
        """
        if self.is_dialogue() and self not in [
            Module.VANILLA_DIALOGUE,
            Module.OPENING_DIALOGUE,
            Module.GROUNDED_DIALOGUE,
        ]:
            return PROMPT.POST_CONTEXT_TOK
        return ''

    def opt_dialogue_knowledge_prefix(self):
        """
        Prefix for the knowledge line given to a knowledge-grounded dialogue module.

        Raises KeyError for modules other than 'mrm', 'crm' and 'srm'.
        """
        return {
            'mrm': "Personal Fact: ",
            'crm': "Previous Topic: ",
            'srm': "Interesting Fact: ",
        }[self.value]




################
# Memory Utils #
################


class MemoryUtils:
    """
    Static helpers for formatting, validating, de-duplicating and selecting chatbot
    memories.

    Memory strings are recognized in three forms:
    - R2C2 format: ``"your persona: ..."`` / ``"partner's persona: ..."``.
    - OPT fine-tuned format: ``PROMPT.SELF_MEMORY_PREFIX`` /
      ``PROMPT.PARTNER_MEMORY_PREFIX`` followed by ``':'``.
    - OPT prompt format: the text starts with ``PROMPT.SELF_PREFIX`` /
      ``PROMPT.PARTNER_PREFIX``.

    The memory store passed to several helpers maps each memory to the number of
    turns since it was last used.
    """

    # Normalization used by `_token_f1`: punctuation, then English articles.
    _RE_PUNC = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
    _RE_ART = re.compile(r'\b(a|an|the)\b')

    @staticmethod
    def is_opt_ft_mem_format(memory_text: str) -> bool:
        """
        Whether ``memory_text`` contains ``':'`` and starts with an OPT fine-tuning
        memory prefix (self or partner).
        """
        # Grouped so the ':' requirement applies to both prefixes (with the previous
        # `a and b or c`, partner memories were exempt from it).
        return ':' in memory_text and memory_text.startswith(
            (PROMPT.SELF_MEMORY_PREFIX, PROMPT.PARTNER_MEMORY_PREFIX)
        )

    @staticmethod
    def _is_r2c2_format(memory_text: str) -> bool:
        """
        Whether ``memory_text`` contains ``':'`` and starts with an R2C2 persona prefix.
        """
        # Same grouping fix as in `is_opt_ft_mem_format`.
        return ':' in memory_text and memory_text.startswith(
            ("your persona", "partner's persona")
        )

    @staticmethod
    def _is_opt_prompt_format(memory_text: str) -> bool:
        """Whether ``memory_text`` starts with the OPT self or partner prefix."""
        return memory_text.startswith((PROMPT.SELF_PREFIX, PROMPT.PARTNER_PREFIX))

    @staticmethod
    def validate_memory_format(memory_text: str):
        """Assert that ``memory_text`` is in one of the recognized memory formats."""
        assert (
            MemoryUtils._is_r2c2_format(memory_text)
            or MemoryUtils._is_opt_prompt_format(memory_text)
            or MemoryUtils.is_opt_ft_mem_format(memory_text)
        ), f'Provided memory "{memory_text}" has invalid format for chatbot memory field.'

    @staticmethod
    def split_prefix_memory(memory_text: str) -> Tuple[str, str]:
        """
        Split a validated memory into ``(prefix, memory)``, both stripped.

        Splits on the first ``':'``. Without a colon, the text must be in OPT prompt
        format; the speaker prefix is then cut off along with the following character.
        """
        MemoryUtils.validate_memory_format(memory_text)
        if ':' in memory_text:
            prfx, mem = memory_text.split(':', 1)
        else:
            # prompt agent sometimes says things like,
            # "Person 2 is"...
            assert MemoryUtils._is_opt_prompt_format(memory_text)
            if memory_text.startswith(PROMPT.SELF_PREFIX):
                speaker = PROMPT.SELF_PREFIX
            else:
                speaker = PROMPT.PARTNER_PREFIX
            prfx, mem = memory_text[: len(speaker)], memory_text[len(speaker) + 1 :]
        return prfx.strip(), mem.strip()

    @staticmethod
    def is_valid_memory(
        chatbot_memory: Union[List[str], Dict[str, int]],
        new_memory: str,
        new_memory_prefix: str,
    ) -> bool:
        """
        Return whether the memory is valid.

        Placeholder/failure texts are rejected. So is a new memory that `no_overlap`
        considers a repeat of an existing memory with the same prefix (same person).
        """
        if new_memory in (
            CONST.NOPERSONA,
            PROMPT.NO_MEMORY,
            '',
            APIUtils.METASEQ_FAIL_MESSAGE_TEXT,
        ):
            return False

        if not chatbot_memory:
            # No need to dedup if there is no existing memory
            return True

        # Keep only the memories about the same person (self or other); split each
        # memory once.
        person_memories = []
        for mem in chatbot_memory:
            prefix, body = MemoryUtils.split_prefix_memory(mem)
            if prefix == new_memory_prefix:
                person_memories.append(body)
        if not person_memories:
            # No memory on record for this person
            return True

        return no_overlap(person_memories, new_memory)

    @staticmethod
    def get_memory_prefix(self_or_partner: str, model_type: str) -> str:
        """
        Return memory prefix for ``'self'``/``'partner'`` and ``'R2C2'``/``'OPT'``.
        """
        assert self_or_partner in ['self', 'partner']
        assert model_type in ['R2C2', 'OPT']
        self_prefix = (
            'your persona' if model_type == 'R2C2' else PROMPT.SELF_MEMORY_PREFIX
        )
        partner_prefix = (
            "partner's persona"
            if model_type == 'R2C2'
            else PROMPT.PARTNER_MEMORY_PREFIX
        )
        if self_or_partner == 'self':
            return self_prefix
        else:
            return partner_prefix

    @staticmethod
    def add_memory_prefix(memory: str, self_or_partner: str, model_type: str) -> str:
        """
        Ensure that the memory has a "persona" prefix.

        :param memory:
            memory to prefix
        :param self_or_partner:
            string representing if this is a self memory or partner memory
        :param model_type:
            which model we're working with

        :return prefixed_mem:
            return a memory with the appropriate prefix.
        """
        assert self_or_partner in ['self', 'partner']
        assert model_type in ['R2C2', 'OPT']
        prefix = MemoryUtils.get_memory_prefix(self_or_partner, model_type)
        if model_type == 'R2C2':
            if not memory.startswith(prefix):
                memory = f"{prefix}: {memory}"
            return memory

        # OPT: rewrite the speaker reference to the requested side, then prefix
        # unless the memory already starts with a "Person ..." reference.
        if self_or_partner == 'self':
            memory = memory.replace("Person 1", "Person 2")
        else:
            memory = memory.replace("Person 2", "Person 1")
        if not memory.startswith('Person'):
            memory = f"{prefix}: {memory}"
        return memory

    @staticmethod
    def maybe_add_memory_prefix(
        memory: str, self_or_partner: str, model_type: str
    ) -> str:
        """
        Maybe add it if it's not there.
        """
        if not MemoryUtils.is_opt_ft_mem_format(memory):
            memory = MemoryUtils.add_memory_prefix(memory, self_or_partner, model_type)
        return memory

    @staticmethod
    def _build_query_representation(
        query: str, dictionary,
    ) -> Dict[str, Any]:
        """
        Build an IDF-weighted bag-of-words representation of ``query``.

        :param dictionary:
            object with ``tokenize(text)`` and a ``freq`` token-count mapping.
        :return:
            ``{'words': {token: 1 / (1 + log(1 + freq[token]))},
            'norm': sqrt(number of tokens)}``.
        """
        words = list(dictionary.tokenize(query.lower()))
        weights = {w: 1.0 / (1.0 + math.log(1.0 + dictionary.freq[w])) for w in words}
        return {'words': weights, 'norm': math.sqrt(len(words))}

    @staticmethod
    def _score_memory(query_rep: Dict[str, Any], memory: str, dictionary) -> float:
        """
        Score ``memory`` against ``query_rep``: the sum of the query weights of the
        distinct memory tokens that also occur in the query (0 for an empty memory).
        """
        # NOTE(review): replaces the undefined `score_match(query, m, 0, dictionary)`
        # call; intended to match ParlAI's score_match with length_penalty=0
        # (no length normalization) - confirm if exact parity matters.
        if memory == '':
            return 0.0
        weights = query_rep['words']
        tokens = dict.fromkeys(dictionary.tokenize(memory.lower()))
        return sum(weights[w] for w in tokens if w in weights)

    @staticmethod
    def maybe_reduce_memories(
        text: str, memories: List[str], dictionary, max_memories: int = 32,
    ) -> List[str]:
        """
        TFIDF-Match memories with the textual input to reduce num memories.

        :param text:
            incoming text used as the query
        :param memories:
            memories from which to choose
        :param dictionary:
            tokenizer/frequency provider (see `_build_query_representation`); only
            used when there are at least `max_memories` memories.
        :param max_memories:
            number of memories to keep; the default 32 is 512 / 16, assuming
            16 tokens max per memory.

        :return memories:
            `memories` unchanged when there are fewer than `max_memories`; otherwise
            the `max_memories` highest-scoring memories, best first (ties keep
            their input order).
        """
        if not memories or len(memories) < max_memories:
            return memories
        query = MemoryUtils._build_query_representation(text, dictionary)
        scores = [MemoryUtils._score_memory(query, m, dictionary) for m in memories]
        ranked = sorted(range(len(memories)), key=lambda i: scores[i], reverse=True)
        return [memories[i] for i in ranked[:max_memories]]

    @staticmethod
    def _token_f1(guess: str, answer: str) -> float:
        """
        Unigram F1 between `guess` and `answer` after lowercasing and removing
        punctuation, articles and extra whitespace (0.0 when nothing overlaps).
        """
        # NOTE(review): replaces the undefined `F1Metric.compute(guess, [answer])`;
        # intended to follow ParlAI's F1 definition - confirm if exact parity matters.
        def _normalize(s: str) -> List[str]:
            s = MemoryUtils._RE_PUNC.sub(' ', s.lower())
            s = MemoryUtils._RE_ART.sub(' ', s)
            return s.split()

        guess_tokens = _normalize(guess)
        answer_tokens = _normalize(answer)
        num_same = sum((Counter(guess_tokens) & Counter(answer_tokens)).values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(guess_tokens)
        recall = num_same / len(answer_tokens)
        return 2 * precision * recall / (precision + recall)

    @staticmethod
    def get_available_memories(
        text: str,
        memories: Dict[str, int],
        in_session_memories: Set[str],
        dictionary: Optional = None,
        ignore_in_session_memories: bool = False,
        memory_overlap_threshold: float = 0.0,
        memory_hard_block_for_n_turns: int = 0,
        memory_soft_block_decay_factor: float = 0.0,
    ) -> List[str]:
        """
        Return available memories.

        :param text:
            incoming partner text
        :param memories:
            all memories, mapped to the number of turns since each was last used
        :param in_session_memories:
            set of memories generated within the current conversation session
        :param dictionary:
            passed to `maybe_reduce_memories`; needed only when 32+ memories
            remain available
        :param ignore_in_session_memories:
            whether to ignore memories generated within the session
        :param memory_overlap_threshold:
            if > 0, drop memories whose token F1 with `text` (pronouns and stop
            words removed) is below this value
        :param memory_hard_block_for_n_turns:
            drop memories used fewer than this many turns ago
        :param memory_soft_block_decay_factor:
            if > 0, randomly drop a memory with probability
            `factor ** turns_since_used`
        """
        available_memory = []
        # Tokenized once, and only if the overlap check actually runs.
        non_stopword_text = None
        for memory, turns_since_used in memories.items():
            turns_since_used = int(turns_since_used)
            # check if we should ignore in session memories
            if ignore_in_session_memories and memory in in_session_memories:
                continue
            # check overlap
            if memory_overlap_threshold > 0:
                non_stopword_memory = ' '.join(
                    normal_tokenizer(memory.split(':')[-1], include_pronouns=True)
                )
                if non_stopword_text is None:
                    non_stopword_text = ' '.join(
                        normal_tokenizer(text, include_pronouns=True)
                    )
                if (
                    MemoryUtils._token_f1(non_stopword_memory, non_stopword_text)
                    < memory_overlap_threshold
                ):
                    continue
            # check hard block
            if turns_since_used < memory_hard_block_for_n_turns:
                continue
            # check soft block
            if memory_soft_block_decay_factor > 0 and random.random() < (
                memory_soft_block_decay_factor**turns_since_used
            ):
                continue

            available_memory.append(memory)
        return MemoryUtils.maybe_reduce_memories(text, available_memory, dictionary)

    @staticmethod
    def add_memory(memory: str, memories: Dict[str, int]) -> Dict[str, int]:
        """
        Add memory to the memory store (in place) with a usage count of 0.

        :param memory:
            memory to add; empty memories are ignored
        :param memories:
            all the memories; must not already contain `memory`

        :return memories:
            return memories with new memory
        """
        if not memory:
            return memories
        assert memory not in memories
        memories[memory] = 0
        return memories

    @staticmethod
    def update_memory_usage(
        used_memory: str, memories: Dict[str, int]
    ) -> Dict[str, int]:
        """
        Update memories (in place) to indicate that a memory was used.

        The used memory's counter is reset to 0 and every other counter is
        incremented; nothing changes if `used_memory` is empty or unknown.

        :param used_memory:
            used memory
        :param memories:
            all the memories

        :return memories:
            return memories with usage updated
        """
        if not used_memory or used_memory not in memories:
            return memories
        for mem in memories:
            if mem == used_memory:
                memories[mem] = 0
            else:
                memories[mem] += 1
        return memories


#################
# OPT API UTILS #
#################


class APIUtils:
    """
    Helpers for sending completion requests to a Metaseq (OPT) API server and
    post-processing its responses.
    """

    DEFAULT_KEY = os.environ.get("USER", "parlai")
    DEFAULT_SERVER = "DEFAULT_API_SERVER:6010"
    DEFAULT_SERVER_TIMEOUT = 600
    METASEQ_FAIL_MESSAGE_TEXT = 'METASEQ RESPONSE FAILED.'

    @staticmethod
    def is_request_failed_response(resp):
        """
        Whether the requests to Metaseq worker have failed.

        It checks this based on the existence of the failure reasons as they get
        accumulated in `make_request` function calls, or on the failure text.
        """
        return len(
            resp.get('failures', [])
        ) > 0 or APIUtils.METASEQ_FAIL_MESSAGE_TEXT in resp.get('text', '')

    @staticmethod
    async def make_request(
        session,
        server: str,
        api_key: str,
        prompt: str,
        min_tokens: int = 0,
        n: int = 1,
        max_tokens: int = 32,
        best_of: int = 1,
        top_p: float = -1,
        echo: bool = False,
        stop: Optional[str] = None,
        temperature: float = 1.0,
        num_retry_on_api_exception=-1,
        lambda_decay: float = -1,
        omega_bound: float = 0.3,
        request_delay: float = 0.5,
        alpha_presence: float = 0.0,
        alpha_frequency: float = 0.0,
        alpha_presence_src: float = 0.0,
        alpha_frequency_src: float = 0.0,
        alpha_src_penalty_end_idx: int = -1,
    ):
        """
        POST one completion request to ``{server}/completions``, retrying on failure.

        The request is retried on API errors (an ``'error'`` key in the response),
        undecodable JSON, timeouts and client OS errors. Before each retry the delay
        is doubled and awaited (not slept), so other requests on the event loop keep
        running. ``n`` is accepted for compatibility but is not sent in the payload.

        :param session:
            HTTP client session exposing ``post(url, json=..., headers=...)`` as an
            async context manager.
        :param num_retry_on_api_exception:
            maximum number of retries; negative means retry forever.
        :param request_delay:
            base back-off delay in seconds (doubled before every retry).
        :return:
            the decoded JSON response on success, or ``{'failures': [...]}`` with one
            message per failed attempt once the retry budget is exhausted.
        """
        data = {
            'prompt': prompt,
            'min_tokens': min_tokens,
            'max_tokens': max_tokens,
            'best_of': best_of,
            'top_p': top_p,
            'stop': stop,
            'temperature': temperature,
            'echo': echo,
            'lambda_decay': lambda_decay,
            'omega_bound': omega_bound,
            'alpha_presence': alpha_presence,
            'alpha_frequency': alpha_frequency,
            'alpha_presence_src': alpha_presence_src,
            'alpha_frequency_src': alpha_frequency_src,
            "alpha_src_penalty_end_idx": alpha_src_penalty_end_idx,
        }
        headers = {'Authorization': f'Bearer {api_key}'}
        past_exceptions: List[str] = []
        while True:
            if (
                num_retry_on_api_exception >= 0
                and len(past_exceptions) > num_retry_on_api_exception
            ):
                logging.error('Reached maximum retries, returning failure message.')
                return {
                    'failures': past_exceptions,
                }
            try:
                logging.debug(f'Making request: {data}')
                async with session.post(
                    f'{server}/completions', json=data, headers=headers
                ) as resp:
                    resp_text = await resp.text()
                    obj = json.loads(resp_text)
                if 'error' in obj:
                    logging.warning(f"Error: {obj['error']}")
                    past_exceptions.append(f"API Error: {obj['error']}")
                    logging.debug(past_exceptions[-1])
                else:
                    debug = json.dumps(obj, sort_keys=True)
                    logging.debug(f'GPT-Z response: {debug}')
                    return obj
            # The exception classes below are disjoint; JSON errors are listed first
            # so handling them does not depend on `aiohttp` being importable.
            except json.decoder.JSONDecodeError as e:
                past_exceptions.append(
                    f'Got a bad response, {resp_text}. Retrying.\n{e}'
                )
                logging.debug(past_exceptions[-1])
            except asyncio.TimeoutError as e:
                past_exceptions.append(
                    f'Timout a response for prompt {len(prompt)}\n{e}'
                )
                logging.warning(past_exceptions[-1])
            # NOTE(review): `aiohttp` is not imported in this file; it must be
            # imported for this handler (and `async_request_many`) to work.
            except aiohttp.client_exceptions.ClientOSError as e:
                past_exceptions.append(
                    f'Retrying a response for prompt {len(prompt)}\n{e}'
                )
                logging.warning(past_exceptions[-1])

            # Exponential back-off. Every failure path reaches this point (API errors
            # previously `continue`d past the sleep and retried immediately), and the
            # delay is awaited instead of blocking the event loop with `time.sleep`.
            request_delay *= 2
            await asyncio.sleep(request_delay)

    @staticmethod
    async def async_request_many(
        server: str,
        api_key: str,
        prompts: List[str],
        timeout: Optional[int] = None,
        max_num_tries: int = -1,
        **kwargs,
    ):
        """
        Send one `make_request` per prompt concurrently over a shared session.

        :param timeout:
            total client timeout in seconds (None for no limit).
        :param max_num_tries:
            forwarded as `num_retry_on_api_exception`.
        :return:
            the responses, in the same order as `prompts`.
        """
        connector = aiohttp.TCPConnector(limit=0)
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(
            timeout=client_timeout, connector=connector
        ) as session:
            tasks = [
                asyncio.ensure_future(
                    APIUtils.make_request(
                        session=session,
                        server=server,
                        api_key=api_key,
                        prompt=prompt,
                        num_retry_on_api_exception=max_num_tries,
                        **kwargs,
                    )
                )
                for prompt in prompts
            ]
            return await asyncio.gather(*tasks)

    @staticmethod
    def request_many(
        server,
        api_key,
        prompts: List[str],
        timeout: Optional[int] = None,
        max_num_tries: int = -1,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Synchronous wrapper around `async_request_many` (runs it with `asyncio.run`,
        so it must not be called from inside a running event loop).
        """
        return asyncio.run(  # type: ignore
            APIUtils.async_request_many(
                server=server,
                api_key=api_key,
                prompts=prompts,
                timeout=timeout,
                max_num_tries=max_num_tries,
                **kwargs,
            )
        )

    @staticmethod
    def compute_perplexities(
        observations: List['Message'], results: List[Dict[str, Any]]
    ) -> Tuple[List['PPLMetric'], List['PPLMetric']]:
        """
        Compute perplexities from API call.

        Positive token log-probabilities are clamped to 0 (with a warning).

        :param observations:
            incoming observations; each must have a 'prompt' field
        :param results:
            results from API call (with per-token log probs and text offsets)

        :return (label_ppls, all_ppls):
            perplexity over the tokens starting from the last token whose text
            offset is within the prompt length, and perplexity over all tokens.
        """
        label_perplexities = []
        all_perplexities = []
        for obs, result in zip(observations, results):
            # need text offsets to figure out what comes from the prompt
            prompt_len = len(obs['prompt'])
            logprobs = result['choices'][0]['logprobs']
            text_off = logprobs['text_offset']
            starts = [i for i, off in enumerate(text_off) if off <= prompt_len]
            assert len(starts) > 0
            start_label = starts[-1]
            all_log_probs = logprobs['token_logprobs']
            if not all(lp <= 0 for lp in all_log_probs):
                logging.warning(
                    f'Out of {len(all_log_probs)} log probs, {len([lp for lp in all_log_probs if lp > 0])} are > 0. '
                    'Clamping to 0'
                )
                all_log_probs = [min(lp, 0) for lp in all_log_probs]

            log_probs = all_log_probs[start_label:]
            label_perplexities.append(PPLMetric(-sum(log_probs), len(log_probs)))
            all_perplexities.append(PPLMetric(-sum(all_log_probs), len(all_log_probs)))
        return label_perplexities, all_perplexities
