import os
import copy
import json
import torch
import pickle
import argparse
from datetime import datetime
from rich import print as rich_print
from rich.console import Console


from transformers import AutoTokenizer, BartConfig
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Tuple, Optional, Dict, Any, Set, Union

from model import rename_state_dict, TrainInferenceComboHFR2C2Model
from hexa.models.modeling_r2c2 import R2C2ForConditionalGeneration

from hexa.utils import logging
from hexa.utils.message import Message
from hexa.utils.document import Document
from hexa.utils.constants import Constant
from hexa.utils.metrics import AverageMetric, F1Metric, RougeMetric, BleuMetric
from hexa.utils.search_retriever import SearchQuerySearchEngineRetriever

from hexa.utils.inference_utils import (
    colorize,
    clean_text,
    concat_enc_outs,
    Module,
    Decision,
    MemoryUtils,
)

from hexa.utils.base_utils import (
    build_config,
    set_seed,
)

from multitask_datasets import (
    MultiTaskCollator,
    InhouseDataset
)

# Document whose three constructor fields are all empty strings; used below to
# pad retrieved-document lists and as the reference for "empty" checks.
BLANK_DOC = Document('', '', '')
# Shared constants object; the agent reads special tokens and control strings
# (e.g. ALL_SPECIAL_TOKENS, GENERATE_KNOWLEDGE, DO_SEARCH) from it.
CONST = Constant()
# Rich console used for the debug prints in the inference pipeline.
console = Console()

# NOTE: runs at import time with a fixed value (42) — presumably seeds the RNGs
# for reproducibility; confirm against hexa.utils.base_utils.set_seed.
set_seed(42)


def _str2bool(text):
    """Interpret a textual flag as a boolean.

    Returns True when ``text`` is 'yes' or 'true' (case-insensitive) and
    False for every other string.
    """
    return text.lower() in ('yes', 'true')


class BB3InferenceAgent():

    def __init__(self, config, tokenizer=None):
        """
        Wire up the tokenizer, model, retriever and dataset helpers from ``config``.

        Args:
            config: project configuration object; this constructor reads
                ``dataset``, ``trainer.fp16``, ``knowledge_conditioning``,
                ``memory_decision``, ``debug``, ``device_id``,
                ``hf_tokenizer_path`` and ``local_rank`` from it, and writes
                ``fp16`` onto both ``config`` and ``config.dataset``.
            tokenizer: optional ready-made tokenizer; when falsy, one is loaded
                with ``AutoTokenizer.from_pretrained(config.hf_tokenizer_path)``.
        """
        self.config = config
        self.n_docs = config.dataset.n_docs
        self.knowledge_conditioning = config.knowledge_conditioning
        self.memory_decision = config.memory_decision
        self.debug = config.debug
        # self.config is the same object as config, so this mutates the caller's
        # config; it is set before load_model() reads self.config.
        self.config.fp16 = config.trainer.fp16

        self.device = torch.device(config.device_id)
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(config.hf_tokenizer_path)
        self.model = self.load_model()
        self.search_retriever = SearchQuerySearchEngineRetriever(config.dataset, self.device)

        # Mirror the fp16 flag onto the dataset config before the dataset helper is built
        # (note: the retriever above was constructed before this assignment).
        config.dataset.fp16 = config.trainer.fp16
        self.dataset_fn = InhouseDataset(config.dataset, self.tokenizer)
        self.collator_fn = MultiTaskCollator(
            # NOTE(review): hard-coded 0 here, while self.pad_token_id below comes from
            # the tokenizer — confirm the two always agree.
            pad_token_id=0,
            text_truncate=config.dataset.truncate,
            n_docs=self.n_docs,
            device=self.device,
        )

        self.pad_token_id = self.tokenizer.pad_token_id
        # Start with empty memories, dialogue turns and history.
        self.reset()

        self.model_type = 'R2C2'
        self.rank = config.local_rank

    def reset(self):
        self.memories = {}
        self.dialogue_turns = {}
        self.histories = []

    @staticmethod
    def clean_text(text: str) -> str:
        """
        Strip every token in ``CONST.ALL_SPECIAL_TOKENS`` from ``text``.

        For each token, three removals are applied in order: the token with a
        leading space, the token with a trailing space, and finally the bare
        token, so one adjacent space is removed together with the token.
        """
        for special in CONST.ALL_SPECIAL_TOKENS:
            for pattern in (f" {special}", f"{special} ", special):
                text = text.replace(pattern, '')
        return text


    @staticmethod
    def pad_like_parlai(vec: torch.Tensor) -> torch.Tensor:
        batch_size, len_item = vec.size()
        len_tobe_added = 8 - len_item % 8
        if len_tobe_added != 8:
            zero_tensor = torch.zeros((batch_size, len_tobe_added)) + 0
            vec = torch.cat([vec, zero_tensor.to(vec.device)], dim=1).long()
        return vec


    @staticmethod
    def receive_inputs() -> Message:
        """
        Prompt on stdin for one user message and wrap it in a ``Message``.

        The returned message has ``id`` 'human', ``episode_done`` False, and
        ``text`` set to the typed line with every literal backslash-n sequence
        turned into a real newline.

        Raises:
            StopIteration: when stdin reaches end-of-file.
        """
        reply = Message()
        reply['id'] = 'human'

        prompt = colorize("[Enter Your Message]", 'yellow') + ' '
        try:
            typed = input(prompt)
        except EOFError:
            raise StopIteration

        reply['text'] = typed.replace('\\n', '\n')
        reply['episode_done'] = False
        return reply


    @staticmethod
    def split_memories(raw_memories: Dict) -> str:
        self_memories = [
            m.replace('your persona: ', '')
            for m in raw_memories
            if m.startswith('your')
        ]
        partner_memories = [
            m.replace("partner's persona: ", '')
            for m in raw_memories
            if m.startswith('partner')
        ]

        memories = f"your persona: {' '.join(self_memories)}\npartner's persona: {' '.join(partner_memories)}"
        return memories


    @staticmethod
    def check_empty_doc(top_docs: List[Document]) -> bool:
        """Return True when every document's ``_text`` equals that of ``BLANK_DOC`` (also True for an empty list)."""
        doc_texts = [doc._text for doc in top_docs]
        return all(doc_text == BLANK_DOC._text for doc_text in doc_texts)


    def load_model(self) -> R2C2ForConditionalGeneration:
        """Instantiate ``TrainInferenceComboHFR2C2Model`` from ``self.config`` and return it."""
        return TrainInferenceComboHFR2C2Model(self.config)


    def get_memory_str(self):
        if len(self.memories)==0:
            return ''
        memories = self.split_memories(self.memories)
        memories += '\n'
        return memories


    def get_dialogue_history_str(self):
        full_history = '\n'.join(self.histories)
        if full_history:
            full_history += '\n'
        return full_history


    def retrieved_doc_scores_from_memories(
        self,
    ) -> Tuple[List[List[Document]], torch.Tensor]:
        """
        Expose the stored memories as a single row of scored documents.

        Each memory key becomes a ``Document`` (text only, empty id and title)
        with score 1. The row is padded with ``BLANK_DOC`` / score 0 up to
        ``self.n_docs``; when there are more memories than ``n_docs`` the row
        is simply as long as the number of memories.

        Returns:
            A one-element list holding the document row, and a tensor built
            from the matching one-row score list, moved to ``self.device``.
        """
        memory_docs = [
            Document(docid='', text=memory_text, title='')
            for memory_text in self.memories.keys()
        ]
        memory_scores = [1 for _ in memory_docs]

        # Row width is at least n_docs, and grows if there are more memories.
        row_width = max(self.n_docs, len(memory_docs))
        n_padding = row_width - len(memory_docs)
        if n_padding:
            memory_docs = memory_docs + [BLANK_DOC] * n_padding
            memory_scores = memory_scores + [0] * n_padding

        score_tensor = torch.Tensor([memory_scores]).to(self.device)
        return [memory_docs], score_tensor


    def retrieved_doc_scores_from_internet(
        self,
        query: str,
    ) -> Tuple[List[List[Document]], torch.Tensor]:
        """
        Run ``query`` through the search retriever.

        Besides returning the retriever's documents and scores unchanged, the
        stringified documents are stored on ``self.search_retriever.top_docs``.
        """
        retrieved_docs, retrieved_scores = self.search_retriever.retrieve_and_score(query)
        stringified = []
        for doc_row in retrieved_docs:
            stringified.append([str(doc) for doc in doc_row])
        self.search_retriever.top_docs = stringified
        return retrieved_docs, retrieved_scores


    def concat_docs_and_input(
        self,
        input_text: str,
        top_docs: List[List[Document]],
        max_n_docs: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize ``input_text``, combine it with ``top_docs`` and build encoder inputs.

        The text is tokenized, left-truncated to ``text_truncate - offset``
        tokens, passed twice through ``_add_start_end_tokens`` (first with
        start and end, then with end only) and combined with the documents by
        the dataset helper. The result is padded by the collator and, if
        3-dimensional, flattened to 2 dimensions.

        Returns:
            Dict with ``input_ids`` and an ``attention_mask`` that is True
            wherever ``input_ids`` differs from ``self.pad_token_id``.
        """
        dataset = self.dataset_fn

        # Tokenize and frame the raw input text.
        token_ids = torch.LongTensor(self.tokenizer.encode(input_text))
        token_budget = dataset.text_truncate - dataset.offset
        token_ids = dataset._check_truncate(token_ids, token_budget, truncate_left=True)
        token_ids = dataset._add_start_end_tokens(token_ids, add_start=True, add_end=True)
        token_ids = dataset._add_start_end_tokens(token_ids, add_start=False, add_end=True)
        token_ids = token_ids.unsqueeze(0)

        # Combine the documents with the framed input.
        non_pad_lengths = token_ids.ne(self.pad_token_id).sum(1)
        with_docs = dataset.concat_docs_and_input(
            input=token_ids,
            input_lengths=non_pad_lengths,
            top_docs=top_docs,
            max_num_docs=max_n_docs,
            right_padded=True,
        )
        with_docs = dataset._check_dimension(with_docs).to(self.device)
        if with_docs.dim() == 1:
            with_docs = with_docs.unsqueeze(0)

        # Pad through the collator, then flatten any leading doc dimension.
        padded_ids = self.collator_fn._pad_sequence(
            [{'docs_text_vec': with_docs}], 'docs_text_vec', max_n_docs
        )
        if len(padded_ids.shape) == 3:
            padded_ids = padded_ids.view(-1, padded_ids.shape[-1])

        return {
            'input_ids': padded_ids,
            'attention_mask': padded_ids.ne(self.pad_token_id),
        }


    def compute_metrics(
            self,
            input_text: str,
            label_ids: torch.Tensor,
            predicted_text: str,
    ):
        """
        Score one prediction against its label.

        Runs ``self.model.forward_loss`` on ``input_text`` with ``label_ids``
        as targets (retrieval skipped, no documents added) and computes F1 and
        BLEU between ``predicted_text`` and the first decoded label sequence.

        Returns:
            Dict with ``loss`` (summed loss divided by the target-token sum),
            ``metric_loss`` (loss as a list), ``metric_target_tokens``,
            ``metric_correct``, ``f1`` and ``bleu``.
        """
        input_ids = self._turn_text_to_vec(input_text)[0]
        n_rows = input_ids.shape[0]
        forward_kwargs = dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            lm_labels=label_ids,
            # All-True: retrieval is skipped for every row.
            skip_retrieval_vec=torch.ones((n_rows,)).to(self.device).bool(),
            # All-False: no documents were added to any row.
            is_doc_added=torch.zeros((n_rows,)).to(self.device).bool(),
        )

        with torch.no_grad():
            outputs = self.model.forward_loss(**forward_kwargs)

        loss = outputs['loss'].unsqueeze(0).cpu()
        n_target_tokens = outputs['metric_target_tokens'].cpu()
        n_correct = outputs['metric_correct'].cpu()

        reference = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)[0]
        f1_score = F1Metric.compute(reference, [predicted_text])
        bleu_score = BleuMetric.compute(reference, [predicted_text])

        return {
            'loss': loss.sum() / n_target_tokens.sum(),
            'metric_loss': loss.tolist(),
            'metric_target_tokens': n_target_tokens,
            'metric_correct': n_correct,
            'f1': f1_score,
            'bleu': bleu_score,
        }

    def get_knowledge_response(
        self,
        knowledge_type: str,
        full_history: str,
        user_utterance: str,
        top_docs: List[List[Document]],
        top_doc_scores: torch.Tensor,
        num_beams=1,
        min_length=1,
        max_length=128,
        do_sample=False,
    ) -> Dict[str, Any]:
        """
        Generate a knowledge sentence conditioned on retrieved documents.

        The prompt is ``full_history + user_utterance`` followed by a control
        suffix chosen by ``knowledge_type``. The prompt is combined with
        ``top_docs``, encoded, and the merged encoder states are handed to
        ``self.generate`` in place of raw input text.

        Args:
            knowledge_type: 'search' (suffix ``CONST.GENERATE_KNOWLEDGE``) or
                'memory' (suffix ``CONST.ACCESS_MEMORY``).
            full_history: dialogue history text placed before the utterance.
            user_utterance: the latest user turn.
            top_docs: retrieved documents, one list per batch row.
            top_doc_scores: scores for ``top_docs``; its second dimension is
                used as the number of documents per row.
            num_beams, min_length, max_length, do_sample: decoding options
                forwarded to ``self.generate``.

        Returns:
            The dict from ``self.generate`` with ``top_docs`` and
            ``top_doc_scores`` replaced by the first row of the inputs.
        """
        assert knowledge_type in ['search', 'memory']

        if knowledge_type == 'search':
            suffix = CONST.GENERATE_KNOWLEDGE
        elif knowledge_type == 'memory':
            suffix = CONST.ACCESS_MEMORY

        input_text = full_history + user_utterance + ' ' + suffix

        batch = self.concat_docs_and_input(
            input_text=input_text,
            top_docs=top_docs,
            max_n_docs=top_doc_scores.size(1),
        )

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        encoder_hidden_states, encoder_attention_mask = self.enc_forward_with_docs(input_ids, attention_mask)

        model_kwargs = {
            'encoder_outputs': BaseModelOutput(encoder_hidden_states, None, encoder_attention_mask),
            'attention_mask': encoder_attention_mask,
        }
        generated_output = self.generate(
            # Encoder outputs are supplied directly, so no text is re-encoded.
            input_text=None,
            # Forward the caller's limit; previously this argument was dropped
            # and generate() always fell back to its own default.
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            do_sample=do_sample,
            **model_kwargs,
        )

        generated_output['top_docs'] = top_docs[0]
        generated_output['top_doc_scores'] = top_doc_scores[0]

        return generated_output


    def combine_knowledge_response(
        self,
        full_history: str,
        user_utterance: str,
        skm_output: Dict[str, Any],
        mkm_output: Dict[str, Any],
        ckm_output: Dict[str, Any],
        num_beams: int=1,
        do_sample: bool=False,
        max_length: int=128,
        min_length: int=1,
        return_input_text: bool=False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate one dialogue response conditioned on all available knowledge at once.

        Every knowledge output (search, memory, contextual) with a truthy
        ``text`` is recorded and appended to the prompt, wrapped in the special
        tokens of its module. The prompt is the special-token-cleaned history,
        the user utterance, a newline, and the wrapped knowledge segments.

        Returns:
            The dict from ``self.generate`` plus ``knowledge_obs`` (the recorded
            knowledge texts and their top docs) and, when ``return_input_text``
            is set, ``input_text`` (the prompt that was used).
        """
        knowledge_obs = Message({})
        knowledge_sources = (
            ('search_knowledge', 'search_knowledge_top_docs', skm_output),
            ('memory_knowledge', 'memory_knowledge_top_docs', mkm_output),
            ('contextual_knowledge', 'contextual_knowledge_top_docs', ckm_output),
        )
        for text_key, docs_key, module_output in knowledge_sources:
            if module_output['text']:
                knowledge_obs[text_key] = module_output['text']
                knowledge_obs[docs_key] = module_output['top_docs']

        # One "<open> knowledge <close>" segment per knowledge module that produced text.
        segments = []
        for module in Module.knowledge_modules():
            if module.message_name() in knowledge_obs:
                tokens = module.special_tokens()
                segments.append(
                    f"{tokens[0]} {knowledge_obs[module.message_name()]} {tokens[1]}"
                )
        knowledge_suffix = '\n' + ''.join(segments)

        input_text = self.clean_text(full_history) + user_utterance + knowledge_suffix
        generated_output = self.generate(
            input_text,
            num_beams=num_beams,
            do_sample=do_sample,
            max_length=max_length,
            min_length=min_length,
            **kwargs
        )
        generated_output['knowledge_obs'] = knowledge_obs

        if return_input_text:
            generated_output['input_text'] = input_text
        return generated_output


    def separate_knowledge_response(
        self,
        full_history: str,
        user_utterance: str,
        skm_output: Dict[str, Any],
        mkm_output: Dict[str, Any],
        ckm_output: Dict[str, Any],
        num_beams: int=1,
        do_sample: bool=False,
        max_length: int=128,
        min_length: int=1,
        return_input_text: bool=False,
    ) -> Dict[str, Any]:
        """
        Generate one dialogue response per available knowledge source and keep the best.

        For every knowledge output (search, memory, contextual) with a truthy
        ``text``, a response is generated from the context plus that knowledge
        wrapped in its module's special tokens. The response with the highest
        score is selected.

        Selection rules:
            * Only generated candidates can win over non-generated ones.
            * A missing score (``generate`` returns ``None`` when decoding
              without beams) is treated as ``-inf`` so candidates remain
              comparable; ties go to the first of search, memory, contextual.
            * When nothing was generated, ``text`` is '' and ``max_score`` is ``-inf``.

        Returns:
            A ``Message`` holding the per-source ``*_dialogue`` texts and
            scores, ``text``, ``max_score``, an empty ``knowledge_obs`` and,
            when ``return_input_text`` is set, ``input_text`` — the prompt of
            the selected candidate (the bare context when none was generated).
        """
        _input_text = full_history + '\n' + user_utterance

        dialogue_obs = Message({})
        knowledge_obs = Message({})

        if skm_output['text']:
            knowledge_obs['search_knowledge'] = skm_output['text']

        if mkm_output['text']:
            knowledge_obs['memory_knowledge'] = mkm_output['text']

        if ckm_output['text']:
            knowledge_obs['contextual_knowledge'] = ckm_output['text']

        # Prompts keyed by dialogue message name, so the selected candidate's
        # prompt can be looked up by name rather than by list position (the
        # number of generated candidates can be smaller than three).
        input_texts = {}

        for m in Module.knowledge_modules():
            if m.message_name() in knowledge_obs:
                task_type = m.message_name().split('_')[0]
                tokens = m.special_tokens()
                input_text = _input_text + f"{tokens[0]} {knowledge_obs[m.message_name()]} {tokens[1]}"
                output = self.generate(
                    input_text,
                    num_beams=num_beams,
                    do_sample=do_sample,
                    max_length=max_length,
                    min_length=min_length,
                )
                dialogue_key = f'{task_type}_dialogue'
                input_texts[dialogue_key] = input_text
                dialogue_obs[dialogue_key] = output['text']
                dialogue_obs[f'{task_type}_dialogue_score'] = output['score']

        neg_inf = -float('inf')
        candidates = []
        for m in [
            Module.SEARCH_DIALOGUE,
            Module.MEMORY_DIALOGUE,
            Module.CONTEXTUAL_DIALOGUE,
        ]:
            name = m.message_name()
            score = dialogue_obs.get(f"{name}_score", neg_inf)
            if score is None:
                # Greedy decoding yields no sequence score; keep it comparable.
                score = neg_inf
            candidates.append((name in dialogue_obs, score, name))

        # Generated candidates outrank missing ones; max() keeps the first on ties.
        _, best_score, best_name = max(candidates, key=lambda c: (c[0], c[1]))

        dialogue_obs['text'] = dialogue_obs.get(best_name, '')
        dialogue_obs['max_score'] = best_score
        dialogue_obs['knowledge_obs'] = {}

        if return_input_text:
            dialogue_obs['input_text'] = input_texts.get(best_name, _input_text)
        return dialogue_obs


    def combine_all_response(
        self,
        sdm: Dict[str, Any],
        mdm: Dict[str, Any],
        sgm: Dict[str, Any],
        mgm_self: Dict[str, Any],
        mgm_partner: Dict[str, Any],
        srm: Dict[str, Any],
        km: Dict[str, Any],
    ) -> Message:
        """
        Merge the outputs of all modules into a single reply ``Message``.

        Args:
            sdm: search-decision output; its ``text`` is stored in the reply.
            mdm: memory-decision output; its ``text`` is stored in the reply.
            sgm: search-query output; its ``text`` is stored in the reply.
            mgm_self: generated memory about the agent itself.
            mgm_partner: generated memory about the conversation partner.
            srm: final response output; every key except ``top_docs`` seeds the reply.
            km: knowledge observations (knowledge texts and ``*_top_docs`` lists).

        Returns:
            A ``Message`` with the response fields, the decision/query/memory
            texts, an updated copy of the memories, the knowledge and dialogue
            texts per module, and document titles/content/urls for modules
            that do not skip search.
        """


        # Work on a deep copy; self.memories itself is not reassigned in this method.
        mems = copy.deepcopy(self.memories)

        # Seed the reply with the final response fields.
        srm = Message(srm)
        reply = Message(
            {
                k : v
                for k, v in srm.items()
                if k not in ['top_docs']  # leave as list for future use cases
            }
        )

        # force_set is used throughout — presumably it overwrites keys that
        # may already exist in the reply; confirm against Message.
        reply.force_set(Module.SEARCH_DECISION.message_name(), sdm.get('text', ''))
        reply.force_set(Module.MEMORY_DECISION.message_name(), mdm.get('text', ''))
        reply.force_set(Module.SEARCH_QUERY.message_name(), sgm.get('text', ''))

        reply.force_set(
            f'{Module.MEMORY_GENERATOR.message_name()}_self',
            mgm_self.get('text', ''),
        )

        reply.force_set(
            f'{Module.MEMORY_GENERATOR.message_name()}_partner',
            mgm_partner.get('text', ''),
        )

        # Add each generated memory that MemoryUtils accepts as valid to the copied store.
        for message, person in zip([mgm_self, mgm_partner], ['self', 'partner']):
            if MemoryUtils.is_valid_memory(
                mems,
                message.get('text', ''),
                MemoryUtils.get_memory_prefix(person, self.model_type),
            ):
                mems = MemoryUtils.add_memory(
                    MemoryUtils.add_memory_prefix(
                        message['text'], person, self.model_type
                    ),
                    mems,
                )

        reply.force_set('memories', mems)

        # Knowledge texts; missing entries default to ''.
        # NOTE(review): the loop below sets these again for every module where
        # is_knowledge() is true — these three calls look redundant; confirm.
        reply.force_set(
            Module.SEARCH_KNOWLEDGE.message_name(),
            km.get(Module.SEARCH_KNOWLEDGE.message_name(), ''),
        )

        reply.force_set(
            Module.CONTEXTUAL_KNOWLEDGE.message_name(),
            km.get(Module.CONTEXTUAL_KNOWLEDGE.message_name(), ''),
        )

        reply.force_set(
            Module.MEMORY_KNOWLEDGE.message_name(),
            km.get(Module.MEMORY_KNOWLEDGE.message_name(), ''),
        )

        for m in Module:
            # set all the knowledge responses
            if m.is_knowledge():
                reply.force_set(m.message_name(), km.get(m.message_name(), ''))
            # if separate, set all of the dialogue responses as well
            elif m.is_dialogue():
                reply.force_set(m.message_name(), srm.get(m.message_name(), ''))
                # Keep an existing score, otherwise default to -inf.
                reply.force_set(
                    f"{m.message_name()}_score",
                    reply.get(f"{m.message_name()}_score", -float('inf')),
                )

            # For modules that do not skip search, expose the documents they used;
            # a single empty Document is the fallback when none are recorded.
            if not m.skip_search():
                docs = km.get(
                    f'{m.message_name()}_top_docs', [Document("", "", "")]
                )
                reply.force_set(
                    f'{m.message_name()}_doc_titles', [d.get_title() for d in docs]
                )
                reply.force_set(
                    f'{m.message_name()}_doc_content', [d.get_text() for d in docs]
                )
                reply.force_set(
                    f'{m.message_name()}_doc_urls', [d.get_id() for d in docs]
                )

        return reply


    @torch.no_grad()
    def enc_forward_with_docs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_size=1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the document-expanded inputs and merge them with ``concat_enc_outs``.

        Runs the model's encoder without gradients, then hands its last hidden
        state and the attention mask to ``concat_enc_outs`` together with the
        model dimension and ``batch_size``.

        Returns:
            The hidden states and attention mask produced by ``concat_enc_outs``.
        """
        encoder_out = self.model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        merged_states, merged_mask = concat_enc_outs(
            self.model.config.d_model,
            bsz=batch_size,
            encoder_hidden_states=encoder_out.last_hidden_state,
            attention_mask=attention_mask,
            right_padded=True,
        )
        return merged_states, merged_mask

    def _turn_text_to_vec(self, input_text):
        start_token_id = self.tokenizer.bos_token_id
        end_token_id = self.tokenizer.eos_token_id
        if input_text is not None:
            input_ids = self.tokenizer.encode(input_text)
            max_len = min(len(input_ids), self.config.dataset.truncate - 3)
            input_ids = input_ids[-max_len:]
            input_ids = [start_token_id] + input_ids + [end_token_id, end_token_id]
            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
            input_ids = self.pad_like_parlai(input_ids)
            bsz = input_ids.shape[0]
        else:
            input_ids = None
            bsz = 1
        return input_ids, bsz

    @torch.no_grad()
    def generate(
        self,
        input_text=None,
        max_length=128,
        min_length=1,
        num_beams=1,
        do_sample=False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run ``self.model.generate`` and decode the result.

        Args:
            input_text: prompt to encode, or None when the encoder inputs are
                supplied through ``kwargs`` instead.
            max_length, min_length, num_beams, do_sample: decoding options.
            **kwargs: forwarded unchanged to ``self.model.generate``.

        Returns:
            Dict with ``text``, ``score``, ``top_docs`` (None) and
            ``top_doc_scores`` (None). With exactly one returned sequence,
            ``text`` is a string and ``score`` is the sequence score when
            ``num_beams > 1``, else None. Otherwise both are lists with one
            entry per sequence, the score being the sum of the exponentiated
            first-step scores for that sequence.
        """
        input_ids, bsz = self._turn_text_to_vec(input_text)

        # Every row's decoder prompt is the pair [eos, bos].
        decoder_prompt = torch.tensor(
            [self.tokenizer.eos_token_id, self.tokenizer.bos_token_id],
            dtype=torch.long,
            device=self.device
        ).repeat(bsz, 1)

        result = self.model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            do_sample=do_sample,
            decoder_input_ids=decoder_prompt,
            output_scores=True,
            return_dict_in_generate=True,
            **kwargs
        )

        def _to_text(token_ids):
            # Decode one sequence, dropping special tokens and keeping spacing as is.
            return self.tokenizer.decode(
                token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

        if len(result.sequences) == 1:
            only_sequence = result.sequences[0].tolist()
            score = result.sequences_scores.item() if num_beams > 1 else None
            text = _to_text(only_sequence)
        else:
            first_step_mass = torch.exp(result.scores[0]).sum(-1)
            text, score = [], []
            for row, sequence in enumerate(result.sequences):
                token_ids = sequence.tolist()
                score.append(first_step_mass[row].item())
                text.append(_to_text(token_ids))

        return {
            'text': text,
            'score': score,
            'top_docs': None,
            'top_doc_scores': None,
        }


    def update_memory(self, self_message: Message):
        """
        Fold the memories generated in ``self_message`` into ``self.memories``.

        For both 'self' and 'partner', a non-empty generated memory that
        ``MemoryUtils.is_valid_memory`` accepts is prefixed and added to the
        store. Afterwards the memory-usage bookkeeping is updated with the
        memory knowledge that the message reports as used.
        """
        generator_key = Module.MEMORY_GENERATOR.message_name()
        for speaker in ['self', 'partner']:
            candidate = self_message.get(f"{generator_key}_{speaker}")
            if candidate and MemoryUtils.is_valid_memory(
                self.memories,
                candidate,
                MemoryUtils.get_memory_prefix(speaker, self.model_type),
            ):
                prefixed = MemoryUtils.add_memory_prefix(
                    candidate, speaker, self.model_type
                )
                self.memories = MemoryUtils.add_memory(prefixed, self.memories)

        # Record which stored memory (if any) was used for this turn.
        memory_used = self_message.get(Module.MEMORY_KNOWLEDGE.message_name(), '')
        self.memories = MemoryUtils.update_memory_usage(memory_used, self.memories)

    def update_history(self, user_utterance: str, self_message: Message):
        self.histories.append(user_utterance)
        self.histories.append(self_message['text'])
        turn_cnt = len(self.dialogue_turns)
        turn_message = copy.deepcopy(self_message)
        turn_message['human_input'] = user_utterance
        turn_message['current_memories'] = copy.deepcopy(self.memories)
        turn_message['dialogue_history'] = copy.deepcopy(self.histories)
        self.dialogue_turns[turn_cnt] = turn_message

    def _infer(self, user_utterance, full_history, memories,
               return_input_text=False, greedy=False, return_final_samples=False):

        """
        Search Decision
        - Only looks at the final turn of dialogue, generally.
        - Default inference uses greedy decoding.
        - Skip retrieval
        """
        input_text = user_utterance + ' ' + CONST.IS_SEARCH_REQUIRED
        sdm_output = self.generate(
            input_text,
            num_beams=1 if greedy else self.config.beam_size.search_decision,
            do_sample=False if greedy else self.config.do_sample.search_decision
        )

        if self.debug:
            # logging.info('search_decision: {}'.format(sdm_output['text']))
            console.print('search_decision: {}'.format(sdm_output['text']), style='bold')

        """
        Memory Decision
        - Only looks at the final turn of dialogue, along with the store of memories.
        - Default inference uses greedy decoding.
        - Skip retrieval
        """
        input_text = memories + user_utterance + ' ' + CONST.IS_MEMORY_REQUIRED
        mdm_output = self.generate(
            input_text,
            num_beams=1 if greedy else self.config.beam_size.memory_decision,
            do_sample=False if greedy else self.config.do_sample.memory_decision
        )

        if self.debug:
            # logging.info('memory_decision: {}'.format(mdm_output['text']))
            console.print('memory_decision: {}'.format(mdm_output['text']), style='bold')

        """
        {Memory, Search, Contextual} Knowledge Module
        """

        """
        Search Knowledge Response
        - Skip-retrieval FALSE
        """
        if sdm_output['text'] == CONST.DO_SEARCH:
            """
            Search Query Generation Module
            - Used for generating a search query given a dialogue context.
            - Default inference uses greedy decoding.
            - Skip retrieval
            """
            input_text = full_history + user_utterance + ' ' + CONST.GENERATE_QUERY
            sgm_output = self.generate(
                input_text,
                max_length=128,
                min_length=1,
                num_beams=1 if greedy else self.config.beam_size.search_query_generation,
                do_sample=False if greedy else self.config.do_sample.search_query_generation
            )

            if self.debug:
                # logging.info('search_query: {}'.format(sgm_output['text']))
                console.print('search_query: {}'.format(sgm_output['text']), style='bold')

            # input_text = full_history + '\n' + user_utterance + ' ' + CONST.GENERATE_KNOWLEDGE
            search_top_docs, search_top_doc_scores = self.retrieved_doc_scores_from_internet(query=sgm_output['text'])
            skm_output = self.get_knowledge_response(
                knowledge_type='search',
                full_history=full_history,
                user_utterance=user_utterance,
                top_docs=search_top_docs,
                top_doc_scores=search_top_doc_scores,
                num_beams=1 if greedy else self.config.beam_size.search_knowledge,
                do_sample=False if greedy else self.config.do_sample.search_knowledge,
                min_length=1,
            )

        else:
            sgm_output = {
                'text': None
            }
            skm_output = {
                'text': None,
                'score': None,
                'top_docs': None,
                'top_doc_scores': None,
            }

        if self.debug:
            # logging.info('search_knowledge: {}'.format(skm_output['text']))
            console.print('search_knowledge: {}'.format(skm_output['text']), style='bold')

        """
        Memory Knowledge Response
        """
        memory_top_docs, memory_top_doc_scores = self.retrieved_doc_scores_from_memories()
        is_empty_docs = self.check_empty_doc(memory_top_docs[0])

        if mdm_output['text'] == CONST.DO_ACCESS_MEMORY and not is_empty_docs:
            # input_text = full_history + '\n' + user_utterance + ' ' + CONST.ACCESS_MEMORY
            mkm_output = self.get_knowledge_response(
                knowledge_type='memory',
                full_history=full_history,
                user_utterance=user_utterance,
                top_docs=memory_top_docs,
                top_doc_scores=memory_top_doc_scores,
                num_beams=1 if greedy else self.config.beam_size.memory_knowledge,
                do_sample=False if greedy else self.config.do_sample.memory_knowledge,
                min_length=1,
            )
        else:
            mkm_output = {
                'text': None,
                'score': None,
                'top_docs': None,
                'top_doc_scores': None,
            }

        if self.debug:
            # logging.info('memory_knowledge: {}'.format(mkm_output['text']))
            console.print('memory_knowledge: {}'.format(mkm_output['text']), style='bold')

        """
        Contextual Knowledge Response
        """
        if sdm_output['text'] != CONST.DO_SEARCH and mdm_output['text'] != CONST.DO_ACCESS_MEMORY:
            input_text = full_history + user_utterance + ' ' + CONST.EXTRACT_ENTITY
            ckm_output = self.generate(
                input_text,
                max_length=128,
                min_length=1,
                num_beams=1 if greedy else self.config.beam_size.contextual_knowledge,
                do_sample=False if greedy else self.config.do_sample.contextual_knowledge
            )

        else:
            ckm_output = {
                'text': None,
                'score': None,
                'top_docs': None,
                'top_doc_scores': None,
            }

        if self.debug:
            # logging.info('contextual_knowledge: {}'.format(ckm_output['text']))
            console.print('contextual_knowledge: {}'.format(ckm_output['text']), style='bold')

        """
        Combine Knowledge Response
        - Three possible ways depending on agent.combine_knowledge_response
        - 1. combined
        - 2. separate
        - 3. both
        """

        if self.knowledge_conditioning == 'combined':
            if return_final_samples:
                knowledge_output = self.combine_knowledge_response(
                    full_history=full_history,
                    user_utterance=user_utterance,
                    skm_output=skm_output,
                    mkm_output=mkm_output,
                    ckm_output=ckm_output,
                    num_beams=1 if greedy else self.config.beam_size.knowledge_response,
                    do_sample=True,
                    num_return_sequences=return_final_samples,
                    return_input_text=return_input_text,
                )
            else:
                knowledge_output = self.combine_knowledge_response(
                    full_history=full_history,
                    user_utterance=user_utterance,
                    skm_output=skm_output,
                    mkm_output=mkm_output,
                    ckm_output=ckm_output,
                    num_beams=1 if greedy else self.config.beam_size.knowledge_response,
                    do_sample=False if greedy else self.config.do_sample.knowledge_response,
                    return_input_text=return_input_text,
                )

        elif self.knowledge_conditioning == 'separate':
            knowledge_output = self.separate_knowledge_response(
                full_history=full_history,
                user_utterance=user_utterance,
                skm_output=skm_output,
                mkm_output=mkm_output,
                ckm_output=ckm_output,
                num_beams=1 if greedy else self.config.beam_size.dialogue_response,
                do_sample=False if greedy else self.config.do_sample.dialogue_response,
                return_input_text=return_input_text,
            )

        else:
            reply_combined = self.combine_knowledge_response(
                full_history=full_history,
                user_utterance=user_utterance,
                skm_output=skm_output,
                mkm_output=mkm_output,
                ckm_output=ckm_output,
                num_beams=1 if greedy else self.config.beam_size.knowledge_response,
                do_sample=False if greedy else self.config.do_sample.knowledge_response,
                return_input_text=return_input_text,
            )

            reply_separate = self.separate_knowledge_response(
                full_history=full_history,
                user_utterance=user_utterance,
                skm_output=skm_output,
                mkm_output=mkm_output,
                ckm_output=ckm_output,
                num_beams=1 if greedy else self.config.beam_size.dialogue_response,
                do_sample=False if greedy else self.config.do_sample.dialogue_response,
                return_input_text=return_input_text,
            )

            if reply_separate['max_score'] > reply_combined['score']:
                reply_combined['text'] = reply_separate['text']
                reply_combined['score'] = reply_separate['max_score']
                if return_input_text:
                    reply_combined['input_text'] = reply_separate['input_text']

            knowledge_output = reply_combined

        """
        Memory Generation Module
        - Used for generating a new memory to write to the long-term memory store.
        - Conditioned on the last turn of the dialogue context.
        - Default inference uses beam search in the 3B model and greedy decoding in the 30B/175B models.
        - Skip retrieval
        """
        input_text = full_history + user_utterance + ' ' + CONST.GENERATE_MEMORY
        mgm_output = self.generate(
            input_text,
            max_length=128,
            min_length=1,
            num_beams=1 if greedy else self.config.beam_size.memory_generation,
            do_sample=False if greedy else self.config.do_sample.memory_generation
        )

        if self.debug:
            # logging.info('partner_memory: {}'.format(mgm_output['text']))
            console.print('partner_memory: {}'.format(mgm_output['text']), style='bold')

        return (full_history,
                sdm_output, mdm_output, sgm_output,
                skm_output, mkm_output, ckm_output,
                mgm_output, knowledge_output)

    def act(self, user_utterance, do_print=True, update_history=True):
        """
        Produce the bot reply for one user turn.

        Runs `self._infer` on the user utterance together with the current
        dialogue-history and memory strings, generates a memory from the bot's
        own reply, and merges every module output into a single message via
        `self.combine_all_response`.

        Args:
            user_utterance: text of the current user turn.
            do_print: when truthy, print the reply text with `rich_print`.
            update_history: when truthy, pass the turn to `self.update_history`.

        Returns:
            The message returned by `self.combine_all_response`.
        """
        history_str = self.get_dialogue_history_str()
        memory_str = self.get_memory_str()

        # `_infer` returns a 9-tuple; only the entries consumed below get a name.
        (_, sdm_output, mdm_output, sgm_output,
         _, _, _,
         partner_mgm_output, knowledge_output) = self._infer(user_utterance, history_str, memory_str)

        # Generate New Memories
        # - Used for generating a new memory to write to the long-term memory store.
        # - Here the input is the bot's own reply followed by the memory-generation prompt.
        # - Default inference uses beam search in the 3B model and greedy decoding
        #   in the 30B/175B models.
        memory_prompt = knowledge_output['text'] + ' ' + CONST.GENERATE_MEMORY
        self_mgm_output = self.generate(
            memory_prompt,
            num_beams=self.config.beam_size.memory_generation,
            do_sample=self.config.do_sample.memory_generation,
        )

        if self.debug:
            console.print('self_memory: {}'.format(self_mgm_output['text']), style='bold')

        # Combine them all in the srm batch reply.
        self_message = self.combine_all_response(
            sdm=sdm_output,
            mdm=mdm_output,
            sgm=sgm_output,
            mgm_self=self_mgm_output,
            mgm_partner=partner_mgm_output,
            srm=knowledge_output,
            km=knowledge_output['knowledge_obs'],
        )

        # NOTE: the long-term memory store is not updated here (the call to
        # `self.update_memory` is disabled); only the history may be updated.
        if update_history:
            self.update_history(user_utterance, self_message)

        if self.debug:
            console.print('[persona_dict]')
            for memory_key, _ in self.memories.items():
                console.print(' ㄴ: {}'.format(memory_key), style='bold')

        if do_print:
            rich_print(f"[cyan][Blenderbot3]_{self.rank}[/] [bold]{self_message['text']}")

        return self_message

    def dummy_act(self, user_utterance, do_print=True, history=None, greedy=False, memory_label=None,
                  return_loss=False, lm_labels=None, return_final_samples=False):
        """
        Produce a bot reply without touching the stored history or memories.

        Unlike `act`, this method never calls `self.update_history` or
        `self.update_memory`; it only reads the agent state.

        Args:
            user_utterance: text of the current user turn.
            do_print: when truthy, print the reply text with `rich_print`.
            history: dialogue history string to use; when falsy (``None`` or
                empty) the agent's own history string is used instead.
            greedy: forwarded to `self._infer`; when truthy the self-memory
                generation also runs with ``num_beams=1`` and no sampling.
            memory_label: when truthy, used in place of the generated reply as
                the input of the self-memory generation.
            return_loss: when truthy, `self._infer` is asked to return its input
                text and metrics are computed against `lm_labels` and stored on
                the message under 'loss', 'bleu', 'f1' and 'tokens'.
            lm_labels: label tensor, required when `return_loss` is truthy; it
                is moved to `self.device` before computing metrics.
            return_final_samples: forwarded to `self._infer`.

        Returns:
            The message returned by `self.combine_all_response`, with the
            'memories' field (and the metric fields, if requested) force-set.

        Raises:
            AssertionError: if `return_loss` is truthy and `lm_labels` is None.
        """
        if history:
            full_history = history
        else:
            full_history = self.get_dialogue_history_str()
        memories = self.get_memory_str()

        (full_history, sdm_output, mdm_output, sgm_output,
         skm_output, mkm_output, ckm_output,
         mgm_output, knowledge_output) = self._infer(user_utterance, full_history, memories,
                                                     greedy=greedy, return_input_text=return_loss,
                                                     return_final_samples=return_final_samples)

        # Generate New Memories
        # - Used for generating a new memory to write to the long-term memory store.
        # - Conditioned on `memory_label` when given, otherwise on the bot's reply.
        # - Default inference uses beam search in the 3B model and greedy decoding
        #   in the 30B/175B models.
        if memory_label:
            input_text = memory_label + ' ' + CONST.GENERATE_MEMORY
        else:
            reply_text = knowledge_output['text']
            # The reply can be a list of texts instead of a single string;
            # in that case only the first entry is used. isinstance() (rather
            # than an exact type comparison) also accepts str subclasses.
            if isinstance(reply_text, str):
                input_text = reply_text + ' ' + CONST.GENERATE_MEMORY
            else:
                input_text = reply_text[0] + ' ' + CONST.GENERATE_MEMORY
        self_mgm_output = self.generate(
            input_text,
            num_beams=1 if greedy else self.config.beam_size.memory_generation,
            do_sample=False if greedy else self.config.do_sample.memory_generation
        )

        if self.debug:
            console.print('self_memory: {}'.format(self_mgm_output['text']), style='bold')

        # Combine them all in the srm batch reply.
        self_message = self.combine_all_response(
            sdm=sdm_output,
            mdm=mdm_output,
            sgm=sgm_output,
            mgm_self=self_mgm_output,
            mgm_partner=mgm_output,
            srm=knowledge_output,
            km=knowledge_output['knowledge_obs'],
        )
        self_message.force_set('memories', memories)
        if return_loss:
            assert lm_labels is not None, 'lm_labels must be provided when return_loss is set'
            # Calculate loss and metrics for the generated reply.
            metrics = self.compute_metrics(
                knowledge_output['input_text'],
                lm_labels.to(self.device),
                knowledge_output['text'],
            )
            self_message.force_set('loss', metrics['metric_loss'])
            self_message.force_set('bleu', metrics['bleu'])
            self_message.force_set('f1', metrics['f1'])
            self_message.force_set('tokens', metrics['metric_target_tokens'])

        if self.debug:
            console.print('[persona_dict]')
            for k in self.memories:
                console.print(' ㄴ: {}'.format(k), style='bold')

        if do_print:
            rich_print(f"[cyan][Blenderbot3]_{self.rank}[/] [bold]{self_message['text']}")

        return self_message

    def evaluate(self,
                 input_ids=None,
                 lm_labels=None,
                 task_ids=None,
                 skip_retrieval_vec=None,
                 is_doc_added=None,
                 max_n_docs=None,
                 **kwargs
                 ):
        """
        Run one evaluation turn and return the metrics of the generated reply.

        The user utterance is decoded from `input_ids`, a reply is generated
        through `self._infer`, and metrics are computed against `lm_labels`.
        The generated reply is then replaced by the decoded ground truth before
        the memory and history are updated.

        Args:
            input_ids: token ids of the user utterance; must hold exactly one
                example (asserted).
            lm_labels: token ids of the ground-truth reply.
            task_ids, skip_retrieval_vec, is_doc_added, max_n_docs, **kwargs:
                accepted but not used by this method.

        Returns:
            The value returned by `self.compute_metrics`.
        """
        assert len(input_ids) == 1, 'Assert batch size of 1 for now'
        decoded_inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        user_utterance = decoded_inputs[0]

        history_str = self.get_dialogue_history_str()
        memory_str = self.get_memory_str()

        # `_infer` returns a 9-tuple; only the entries consumed below get a name.
        (_, sdm_output, mdm_output, sgm_output,
         _, _, _,
         partner_mgm_output, knowledge_output) = self._infer(
            user_utterance, history_str, memory_str, return_input_text=True)

        # Calculate loss and metric here, while 'text' still holds the generated reply.
        metrics = self.compute_metrics(
            knowledge_output['input_text'],
            lm_labels,
            knowledge_output['text'],
        )

        # Replace generated text with ground truth; the generated reply is
        # kept under 'generated_text'.
        knowledge_output['generated_text'] = knowledge_output['text']
        decoded_labels = self.tokenizer.batch_decode(lm_labels, skip_special_tokens=True)
        knowledge_output['text'] = decoded_labels[0]

        # Generate New Memories
        # - Used for generating a new memory to write to the long-term memory store.
        # - Here the input is the ground-truth reply followed by the memory-generation prompt.
        # - Default inference uses beam search in the 3B model and greedy decoding
        #   in the 30B/175B models.
        memory_prompt = knowledge_output['text'] + ' ' + CONST.GENERATE_MEMORY
        self_mgm_output = self.generate(
            memory_prompt,
            num_beams=self.config.beam_size.memory_generation,
        )

        if self.debug:
            console.print('self_memory: {}'.format(self_mgm_output['text']), style='bold')

        # Combine them all in the srm batch reply.
        self_message = self.combine_all_response(
            sdm=sdm_output,
            mdm=mdm_output,
            sgm=sgm_output,
            mgm_self=self_mgm_output,
            mgm_partner=partner_mgm_output,
            srm=knowledge_output,
            km=knowledge_output['knowledge_obs'],
        )

        # Update memories, then the dialogue history.
        self.update_memory(self_message)
        self.update_history(user_utterance, self_message)

        if self.debug:
            console.print('[persona_dict]')
            for memory_key, _ in self.memories.items():
                console.print(' ㄴ: {}'.format(memory_key), style='bold')

        rich_print("[cyan][Blenderbot3][/] [bold]{}".format(self_message['text']))

        # The combined message is not returned; callers get the metrics.
        return metrics

    def postprocess_dialogue_turns(self):
        """
        Return a cleaned-up deep copy of `self.dialogue_turns`.

        For every turn message:
        - the 'knowledge_obs' entry is dropped (a message without that entry
          is accepted as well instead of raising ``KeyError``),
        - the 'text' entry is stored under the key 'bot_response',
        - any value equal to negative infinity is replaced by -100.

        `self.dialogue_turns` itself is left untouched.

        Returns:
            A plain dict mapping each turn key to a plain dict of its fields.
        """
        # Deep copy so that the returned values never alias the live turns.
        turns = copy.deepcopy(self.dialogue_turns)
        negative_infinity = -float("inf")

        processed = {}
        for turn_cnt, msg in turns.items():
            new_msg = {}
            for key, value in msg.items():
                if key == 'knowledge_obs':
                    # Dropped from the output; skipping it here (instead of
                    # popping it up front) tolerates turns that lack the key.
                    continue
                if key == 'text':
                    new_msg['bot_response'] = value
                elif value == negative_infinity:
                    new_msg[key] = -100
                else:
                    new_msg[key] = value
            processed[turn_cnt] = new_msg

        return processed

    def save_chat(self):
        """
        Save the post-processed dialogue turns as JSON.

        The file is written to
        ``<config.chat_save_dir>/<YYYY_MM_DD_HH_MM>/dialogue_turns.json``, where
        the sub-directory is derived from the current local time; missing
        directories are created.
        """
        # Same result as formatting with "%Y/%m/%d %H:%M" and then replacing
        # '/', ':' and ' ' by '_'.
        sub_dir = datetime.now().strftime("%Y_%m_%d_%H_%M")

        save_dir = os.path.join(self.config.chat_save_dir, sub_dir)
        os.makedirs(save_dir, exist_ok=True)

        # Build the payload before opening the file, so a failure while
        # post-processing does not leave an empty file behind.
        dialogue_turns = self.postprocess_dialogue_turns()

        save_path = os.path.join(save_dir, 'dialogue_turns.json')
        with open(save_path, 'w') as f:
            json.dump(dialogue_turns, f)

        # `style` is an argument of console.print; it used to be passed to
        # str.format, which silently ignores unused keyword arguments.
        console.print("Save the chat history to {}".format(save_path), style="italic red")


if __name__ == "__main__":
    # Build the run configuration. A server_port of -1 means the port is
    # taken from the PORT1 environment variable instead.
    opt = build_config()
    server_port = opt.server_port
    if server_port == -1:
        server_port = os.environ['PORT1']
    opt.server_port = server_port
    opt.dataset.search_server = 'http://127.0.0.1:{}'.format(opt.server_port)
    opt.dataset.skip_retrieval_token = 'no_passages_used'

    # Create the agent and put its model in eval mode on the agent's device
    # (converted to half precision first when fp16 is enabled).
    set_seed(opt.seed)
    agent = BB3InferenceAgent(opt)
    if opt.trainer.fp16:
        agent.model.half()
    agent.model.eval()
    agent.model.to(agent.device)

    # Interactive loop: read a user turn, answer it, stop on "<exit>".
    console.print("Enter <exit> to terminate the conversation", style="italic red")

    while True:
        user_utterance = agent.receive_inputs()['text']
        if user_utterance == '<exit>':
            console.print("Terminate the conversation", style="italic red")
            break

        agent.act(user_utterance)

    # Optionally save the dialogue history once the conversation is over.
    if opt.save_chat:
        agent.save_chat()