"""Dialog agent interface and classes."""

from abc import ABC, abstractmethod
import logging
from convlab.nlu import NLU
from convlab.dst import DST
from convlab.policy import Policy
from convlab.nlg import NLG
from copy import deepcopy
import time
import pdb
from pprint import pprint


class Agent(ABC):
    """Abstract base for every dialog agent.

    Concrete agents must provide a constructor taking a ``name``, a
    ``response`` method and an ``init_session`` method.
    """

    @abstractmethod
    def __init__(self, name: str):
        # Store the role/name of this agent so subclasses can branch on it.
        self.name: str = name

    @abstractmethod
    def response(self, observation):
        """Produce the agent's reply to ``observation``.

        Whether the input and the reply are strings or lists of tuples
        depends on which modules the concrete agent is made of.

        Example:
            A pipeline agent built from NLU, DST and Policy consumes a str and
            returns a list of tuples.
        Args:
            observation (str or list of tuples):
                What the agent receives this turn.
        Returns:
            response (str or list of tuples):
                What the agent answers.
        """
        ...

    @abstractmethod
    def init_session(self, **kwargs):
        """Clear per-session state so a fresh session can start."""
        ...


class PipelineAgent(Agent):
    """Pipeline dialog agent base class, including NLU, DST, Policy and NLG.

    The combination modes of pipeline agent modules are flexible. The only thing you have to make sure is that
    the APIs of the agents match.

    Example:
        If agent A is (nlu, tracker, policy), then agent B should be like (tracker, policy, nlg) to ensure API
        matching.
    The valid module combinations are as follows:
           =====   =====    ======  ===     ==      ===
            NLU     DST     Policy  NLG     In      Out
           =====   =====    ======  ===     ==      ===
            \\+      \\+        \\+    \\+      nl      nl
             o      \\+        \\+    \\+      da      nl
             o      \\+        \\+     o      da      da
            \\+      \\+        \\+     o      nl      da
             o       o        \\+     o      da      da
           =====   =====    ======  ===     ==      ===
    """

    def __init__(
        self,
        nlu: NLU,
        dst: DST,
        policy: Policy,
        nlg: NLG,
        name: str,
        return_semantic_acts: bool = False,
        debug: bool = False,
    ):
        """The constructor of PipelineAgent class.

        Here are some special combination cases:

            1. If you use word-level DST (such as Neural Belief Tracker), you should set the nlu parameter \
             to None. The agent will combine the modules automatically.

            2. If you want to aggregate DST and Policy as a single module, set dst to None.

        Args:
            nlu (NLU):
                The natural language understanding module of agent (may be None).

            dst (DST):
                The dialog state tracker of agent (may be None).

            policy (Policy):
                The dialog policy module of agent.

            nlg (NLG):
                The natural language generator module of agent (may be None).

            name (str):
                Either "user" or "sys"; the other value becomes the opponent name.

            return_semantic_acts (bool):
                If True, `response` returns the policy output action instead of the NLG output.

            debug (bool):
                If True, `_complex_response` prints the intermediate input and output actions.
        """
        super(PipelineAgent, self).__init__(name=name)
        assert self.name in ["user", "sys"]
        self.opponent_name = "user" if self.name == "sys" else "sys"
        self.nlu = nlu
        self.dst = dst
        self.policy = policy
        self.nlg = nlg
        self.return_semantic_acts = return_semantic_acts
        self.debug = debug

        self.init_session()
        self.agent_saves = []
        self.history = []
        self.turn = 0
        self.output_action = []

        # Warn about modules that will not contribute anything to save_info().
        for module_name in ("nlu", "dst", "policy", "nlg"):
            if not hasattr(getattr(self, module_name), "info_dict"):
                logging.warning("%s info_dict is not initialized", module_name)

        self.response_type = self.get_response_type()

    def get_response_type(self):
        """Choose the response strategy from the capabilities of the policy.

        Returns:
            "conduct_system" if the policy has `get_conduct`; otherwise, for the
            user agent only, "utterance_to_user", "need_conduct_user" or
            "action_utterance_to_user" when the corresponding policy flag
            (`system_utterance`, `need_conduct`, `system_action_utterance`) is
            truthy, checked in that order; "default" in every other case.
        """
        if hasattr(self.policy, "get_conduct"):
            return "conduct_system"
        if self.name == "user":
            if getattr(self.policy, "system_utterance", False):
                return "utterance_to_user"
            if getattr(self.policy, "need_conduct", False):
                return "need_conduct_user"
            if getattr(self.policy, "system_action_utterance", False):
                return "action_utterance_to_user"
        return "default"

    def state_replace(self, agent_state):
        """
        this interface is reserved to replace all internal states of agent
        the code snippet example below is for the scenario when the agent state only depends on self.history and self.dst.state
        """
        self.history = deepcopy(agent_state["history"])
        self.dst.state = deepcopy(agent_state["dst_state"])

    def state_return(self):
        """
        this interface is reserved to return all internal states of agent
        the code snippet example below is for the scenario when the agent state only depends on self.history and self.dst.state
        """
        agent_state = {}
        agent_state["history"] = deepcopy(self.history)
        agent_state["dst_state"] = deepcopy(self.dst.state)

        return agent_state

    def _do_nlu(self, observation):
        """Return the NLU prediction for `observation`, or `observation` itself without NLU."""
        if self.nlu is None:
            return observation
        # Context is every history entry except the observation just appended.
        return self.nlu.predict(observation, context=[x[1] for x in self.history[:-1]])

    def _do_dst(self, observation):
        """Store the opponent action in the DST state and return the updated state.

        Without a DST, `observation` is passed through unchanged.
        """
        if self.dst is None:
            return observation

        if self.name == "sys":
            self.dst.state["user_action"] = observation
        else:
            self.dst.state["system_action"] = observation

        return self.dst.update(observation)

    def _do_policy(self, state):
        """Return the policy's action for `state`."""
        return self.policy.predict(state)

    def _do_nlg(self, action):
        """Return the NLG output for `action`, or `action` itself without NLG."""
        if self.nlg is None:
            return action
        return self.nlg.generate(action)

    def _update_dst(self, model_response):
        """Record this agent's own turn in the DST state (requires self.dst)."""
        self.dst.state["history"].append([self.name, model_response])
        if self.name == "sys":
            self.dst.state["system_action"] = self.output_action

            # Booking acts are mirrored into the 'booked' section of the state.
            if isinstance(self.output_action, list):
                for intent, domain, slot, value in self.output_action:
                    if intent.lower() == "book":
                        self.dst.state["booked"][domain] = [{slot: value}]
        else:
            self.dst.state["user_action"] = self.output_action
            # user dst is also updated by itself
            self.dst.update(self.output_action)

    def _complex_response(self, observation, action=None, conduct="default"):
        """Run the pipeline with an optional given system action and/or conduct.

        Args:
            observation: The opponent input (utterance or dialogue acts).
            action: For the user agent without NLU, replaces the NLU output; for
                the user agent it also makes the policy receive
                `sys_act=action, sys_utt=observation` (unless a conduct is given).
            conduct: For the user agent, any value other than "default" makes
                the policy receive `sys_conduct=conduct`.
        Returns:
            The NLG output, or the output action when there is no NLG.
        """
        # get dialog act
        if self.name != "sys" and self.nlu is None and action is not None:
            self.input_action = action
        else:
            self.input_action = self._do_nlu(observation)
        # get rid of reference problem
        self.input_action = deepcopy(self.input_action)
        if self.debug:
            print(f"({self.name}) input_action: ", self.input_action)

        # get state (deep-copied to get rid of reference problem)
        state = deepcopy(self._do_dst(self.input_action))

        # get action (deep-copied to get rid of reference problem)
        if self.name != "sys" and conduct != "default":
            self.output_action = deepcopy(
                self.policy.predict(state, sys_conduct=conduct)
            )
        elif self.name != "sys" and action is not None:
            self.output_action = deepcopy(
                self.policy.predict(sys_act=action, sys_utt=observation)
            )
        else:
            self.output_action = deepcopy(self.policy.predict(state))

        if self.debug:
            print(f"({self.name}) action: ", self.output_action)

        # get model response
        return self._do_nlg(self.output_action)

    def _before_prediction(self, observation):
        """Append the opponent's observation to the DST history and the agent history."""
        if self.dst is not None:
            # [['sys', sys_utt], ['user', user_utt],...]
            self.dst.state["history"].append([self.opponent_name, observation])
        self.history.append([self.opponent_name, observation])

    def _after_prediction(self, model_response):
        """Record this agent's response, advance the turn counter and save module info."""
        if self.dst is not None:
            self._update_dst(model_response)

        self.history.append([self.name, model_response])

        self.turn += 1
        self.agent_saves.append(self.save_info())

    def _default_response(self, observation):
        """Plain NLU -> DST -> Policy -> NLG pass."""
        # nlu
        self.input_action = deepcopy(self._do_nlu(observation))

        # dst
        self.state = deepcopy(self._do_dst(self.input_action))

        # policy
        self.output_action = deepcopy(self._do_policy(self.state))

        # nlg
        return self._do_nlg(self.output_action)

    def _conduct_do_policy(self, state):
        """Return the policy action for `state` and the policy's conduct ("neutral" if unsupported)."""
        output_action = deepcopy(self._do_policy(state))
        conduct = "neutral"
        if hasattr(self.policy, "get_conduct"):
            conduct = self.policy.get_conduct()
        return output_action, conduct

    def _conduct_response(self, observation, emotion=None):
        """Pipeline pass for a system whose policy also produces a conduct.

        Args:
            observation: The user input.
            emotion: If given and the observation is not a str, it overwrites
                `user_emotion` in the tracked state.
        """
        # nlu
        self.input_action = deepcopy(self._do_nlu(observation))

        # dst
        self.state = deepcopy(self._do_dst(self.input_action))
        # overwrite emotion if emotion is provided and observation is not a string;
        # without a DST the state is the raw input and may not be a dict.
        if (
            emotion is not None
            and not isinstance(observation, str)
            and isinstance(self.state, dict)
        ):
            self.state["user_emotion"] = emotion

        # policy
        self.output_action, conduct = self._conduct_do_policy(self.state)

        # nlg
        if hasattr(self.nlg, "require_conduct"):
            return self.nlg.generate(self.output_action, conduct, observation)
        return self._do_nlg(self.output_action)

    def _utterance_semantic_response(self, utterance, action):
        """User-side pass whose policy consumes both the system act and utterance.

        If `action` is None, the system act is obtained from the NLU.
        """
        if self.name != "user":
            raise Exception("This function is only for user agent")
        # nlu
        self.input_action = action
        if action is None:
            self.input_action = deepcopy(self._do_nlu(utterance))

        # dst
        self.state = deepcopy(self._do_dst(self.input_action))

        # policy
        self.output_action = deepcopy(
            self.policy.predict(sys_act=self.input_action, sys_utt=utterance)
        )

        # nlg
        return self._do_nlg(self.output_action)

    def _user_conduct_response(self, observation, conduct):
        """User-side pass whose policy also receives the system conduct."""
        if self.name != "user":
            raise Exception("This function is only for user agent")
        # nlu
        self.input_action = deepcopy(self._do_nlu(observation))

        # dst
        self.state = deepcopy(self._do_dst(self.input_action))

        self.output_action = deepcopy(
            self.policy.predict(self.state, sys_conduct=conduct)
        )

        # nlg
        return self._do_nlg(self.output_action)

    def response(self, observation, **kwargs):
        """Generate agent response using the agent modules.

        Keyword args used depending on `self.response_type`: `emotion`
        ("conduct_system"), `action` ("utterance_to_user",
        "action_utterance_to_user") and `conduct` ("need_conduct_user",
        defaults to "Neutral").
        """
        # Please ensure this is consistent with deploy.server.ServerCtrl._turn()
        self._before_prediction(observation)
        # response generation
        if self.response_type == "conduct_system":
            emotion = kwargs.get("emotion", None)
            model_response = self._conduct_response(observation, emotion)
        elif self.response_type == "utterance_to_user":
            action = kwargs.get("action", None)
            model_response = self._utterance_semantic_response(observation, action)
        elif self.response_type == "action_utterance_to_user":
            action = kwargs.get("action", None)
            model_response = self._complex_response(observation, action)
        elif self.response_type == "need_conduct_user":
            conduct = kwargs.get("conduct", "Neutral")
            model_response = self._user_conduct_response(observation, conduct)
        else:
            # default response generation
            model_response = self._default_response(observation)

        self._after_prediction(model_response)

        if self.return_semantic_acts:
            return self.output_action
        return model_response

    def save_info(self):
        """Collect the `info_dict` of every module that defines one.

        Returns:
            A dict keyed by "nlu"/"dst"/"policy"/"nlg" (only modules with an
            `info_dict` appear), or None if collecting the info failed.
        """
        try:
            infos = {}
            for module_name in ("nlu", "dst", "policy", "nlg"):
                module = getattr(self, module_name)
                if hasattr(module, "info_dict"):
                    infos[module_name] = module.info_dict
        except Exception:
            # Best effort: collecting debug info must not break a dialogue turn.
            logging.warning("failed to collect module info_dict", exc_info=True)
            infos = None

        return infos

    def is_terminated(self):
        """Return the policy's termination flag, or None if the policy does not provide one."""
        if hasattr(self.policy, "is_terminated"):
            return self.policy.is_terminated()
        return None

    def get_reward(self):
        """Return the policy's reward, or None if the policy does not provide one."""
        if hasattr(self.policy, "get_reward"):
            return self.policy.get_reward()
        return None

    def init_session(self, **kwargs):
        """Init the attributes of all modules; `kwargs` are forwarded to the policy."""
        if self.nlu is not None:
            self.nlu.init_session()
        if self.dst is not None:
            self.dst.init_session()

            if self.name == "sys":
                self.dst.state["history"].append([self.name, "null"])  # TODO: ??

        if self.policy is not None:
            self.policy.init_session(**kwargs)
        if self.nlg is not None:
            self.nlg.init_session()
        self.history = []

    def get_in_da(self):
        """Return the input dialogue act of the last turn."""
        return self.input_action

    def get_out_da(self):
        """Return the output dialogue act (the user policy's `semantic_action` if available)."""
        if self.name == "user" and hasattr(self.policy, "semantic_action"):
            return self.policy.semantic_action
        return self.output_action


# Agent for the HHU DialCrowd dialogue server. It is a standalone variant of PipelineAgent with minor modifications.


class DialogueAgent(Agent):
    """System-side pipeline agent (NLU, DST, Policy, NLG) for the HHU DialCrowd dialogue server.

    Besides generating responses, it records per-turn states, actions, outputs
    and action probabilities, plus module info, in `agent_saves`.
    """

    def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str = "sys"):
        """The constructor of DialogueAgent class.

        Here are some special combination cases:

            1. If you use word-level DST (such as Neural Belief Tracker), you should set the nlu parameter \
             to None. The agent will combine the modules automatically.

            2. If you want to aggregate DST and Policy as a single module, set dst to None.

        Args:
            nlu (NLU):
                The natural language understanding module of agent (may be None).

            dst (DST):
                The dialog state tracker of agent (may be None).

            policy (Policy):
                The dialog policy module of agent.

            nlg (NLG):
                The natural language generator module of agent (may be None).

            name (str):
                Must be "sys".
        """

        super(DialogueAgent, self).__init__(name=name)
        assert self.name in ["sys"]
        self.opponent_name = "user"
        self.nlu = nlu
        self.dst = dst
        self.policy = policy
        self.nlg = nlg
        self.module_names = ["nlu", "dst", "policy", "nlg"]
        self.init_session()
        self.history = []
        self.session_id = None
        self.ENDING_DIALOG = False
        self.USER_RATED = False
        self.USER_GOAL_ACHIEVED = None
        self.taskID = None
        self.feedback = None
        self.requested_feedback = False
        self.sys_state_history = []
        self.sys_action_history = []
        self.sys_utterance_history = []
        self.sys_output_history = []
        self.action_mask_history = []
        self.action_prob_history = []
        self.turn = 0
        self.agent_saves = {
            "session_id": None,
            "agent_id": None,
            "user_id": None,
            "timestamp": None,
            "dialogue_info": [],
            "dialogue_info_fundamental": [],
        }
        self.initTime = int(time.time())
        self.lastUpdate = int(time.time())

        logging.info("Dialogue Agent info_dict check")
        for module_name in self.module_names:
            if not hasattr(getattr(self, module_name), "info_dict"):
                logging.warning("%s info_dict is not initialized", module_name)

    def response(self, observation):
        """Generate the system response for one user turn and record the turn.

        Args:
            observation: The user input (utterance, or dialogue acts when nlu is None).
        Returns:
            The NLG output, or the policy output action when nlg is None.
        """

        self.sys_utterance_history.append(observation)
        fundamental_info = {"observation": observation}

        if self.dst is not None:
            # [['sys', sys_utt], ['user', user_utt],...]
            self.dst.state["history"].append([self.opponent_name, observation])
        self.history.append([self.opponent_name, observation])

        # get dialog act
        if self.nlu is not None:
            self.input_action = self.nlu.predict(
                observation, context=[x[1] for x in self.history[:-1]]
            )
        else:
            self.input_action = observation
        # get rid of reference problem
        self.input_action = deepcopy(self.input_action)
        fundamental_info["input_action"] = self.input_action

        # get state
        if self.dst is not None:
            self.dst.state["user_action"] = self.input_action
            state = self.dst.update(self.input_action)
        else:
            state = self.input_action

        state = deepcopy(state)  # get rid of reference problem
        fundamental_info["state"] = state
        self.sys_state_history.append(state)

        # get action (deep-copied to get rid of reference problem)
        self.output_action = deepcopy(self.policy.predict(state))
        if hasattr(self.policy, "last_action"):
            self.sys_action_history.append(self.policy.last_action)
        else:
            self.sys_action_history.append(self.output_action)

        fundamental_info["output_action"] = self.output_action

        if hasattr(self.policy, "prob"):
            self.action_prob_history.append(self.policy.prob)

        # get model response
        if self.nlg is not None:
            model_response = self.nlg.generate(self.output_action)
        else:
            model_response = self.output_action

        self.sys_output_history.append(model_response)

        fundamental_info["model_response"] = model_response

        if self.dst is not None:
            self.dst.state["history"].append([self.name, model_response])
            self.dst.state["system_action"] = self.output_action
            # If system takes booking action add booking info to the 'booked' section of the belief state.
            # The intent check is case-insensitive, as in PipelineAgent._update_dst.
            if isinstance(self.output_action, list):
                for intent, domain, slot, value in self.output_action:
                    if intent.lower() == "book":
                        self.dst.state["booked"][domain] = [{slot: value}]
        self.history.append([self.name, model_response])

        self.turn += 1
        self.lastUpdate = int(time.time())

        self.agent_saves["dialogue_info_fundamental"].append(fundamental_info)
        self.agent_saves["dialogue_info"].append(self.get_info())
        return model_response

    def get_info(self):
        """Return a deep copy of each module's `info_dict` (None when absent), keyed by module name."""
        info_dict = {}
        for name in self.module_names:
            module = getattr(self, name)
            module_info = getattr(module, "info_dict", None)
            info_dict[name] = deepcopy(module_info)

        return info_dict

    def is_terminated(self):
        """Return the policy's termination flag, or None if the policy does not provide one."""
        if hasattr(self.policy, "is_terminated"):
            return self.policy.is_terminated()
        return None

    def retrieve_reward(self):
        """Convert the collected user feedback into one reward per system turn.

        Every system turn defaults to reward 1. `self.feedback` is presumably a
        mapping from dialogue turn numbers (str or int) to dicts with an "isGood"
        entry (verify against the server); turn `t` maps to system turn
        `int((t - 2) / 2)`. Feedback that maps outside the recorded turns is
        ignored, and no feedback at all yields all-1 rewards.
        """
        rewards = [1] * len(self.sys_state_history)
        for turn in self.feedback or {}:
            turn_number = int((int(turn) - 2) / 2)
            # A negative index would silently overwrite the last reward.
            if not 0 <= turn_number < len(rewards):
                continue
            # TODO possibly use self.feedback[turn]["text"] to check whether rating belongs to the right utterance of the system
            rewards[turn_number] = int(self.feedback[turn]["isGood"])
        return rewards

    def get_reward(self):
        """Return the policy's reward, or None if the policy does not provide one."""
        if hasattr(self.policy, "get_reward"):
            return self.policy.get_reward()
        return None

    def init_session(self, **kwargs):
        """Init the attributes of all modules; `kwargs` are forwarded to the policy."""
        if self.nlu is not None:
            self.nlu.init_session()
        if self.dst is not None:
            self.dst.init_session()
            self.dst.state["history"].append([self.name, "null"])
        if self.policy is not None:
            self.policy.init_session(**kwargs)
        if self.nlg is not None:
            self.nlg.init_session()
        self.history = []

    def get_in_da(self):
        """Return the input dialogue act of the last turn."""
        return self.input_action

    def get_out_da(self):
        """Return the output dialogue act of the last turn."""
        return self.output_action

    def print_ending_agent_summary(self):
        """Print session id, task id, goal status and the state/action histories."""
        print("session_id")
        print(self.session_id)
        print("taskID")
        print(self.taskID)
        print("USER_GOAL_ACHIEVED")
        print(self.USER_GOAL_ACHIEVED)
        print("sys_state_history")
        print(self.sys_state_history)
        print("sys_action_history")
        print(self.sys_action_history)

    def is_inactive(self):
        """True once the agent is at least 600 s old and has not been updated for at least 60 s."""
        currentTime = int(time.time())
        return (
            currentTime - self.initTime >= 600 and currentTime - self.lastUpdate >= 60
        )


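# shutong: variant of DialogueAgent that fixes dialcrowd_server not taking the correct input; it also passes the policy's conduct and the user utterance to the NLG.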


class EmoLoopDialogueAgent(Agent):
    """System-side DialCrowd agent whose policy conduct is passed on to the NLG.

    Like DialogueAgent, it records per-turn states, actions, outputs and action
    probabilities, plus module info, in `agent_saves`, and prints the tracked
    state and the system action for debugging.
    """

    def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str = "sys"):
        """The constructor of EmoLoopDialogueAgent class.

        Here are some special combination cases:

            1. If you use word-level DST (such as Neural Belief Tracker), you should set the nlu parameter \
             to None. The agent will combine the modules automatically.

            2. If you want to aggregate DST and Policy as a single module, set dst to None.

        Args:
            nlu (NLU):
                The natural language understanding module of agent (may be None).

            dst (DST):
                The dialog state tracker of agent (may be None).

            policy (Policy):
                The dialog policy module of agent.

            nlg (NLG):
                The natural language generator module of agent (may be None); its
                `generate` must accept `conduct` and `user_utt` keyword arguments.

            name (str):
                Must be "sys".
        """

        super(EmoLoopDialogueAgent, self).__init__(name=name)
        assert self.name in ["sys"]
        self.opponent_name = "user"
        self.nlu = nlu
        self.dst = dst
        self.policy = policy
        self.nlg = nlg
        self.module_names = ["nlu", "dst", "policy", "nlg"]
        self.init_session()
        self.history = []
        self.session_id = None
        self.ENDING_DIALOG = False
        self.USER_RATED = False
        self.USER_GOAL_ACHIEVED = None
        self.taskID = None
        self.feedback = None
        self.requested_feedback = False
        self.sys_state_history = []
        self.sys_action_history = []
        self.sys_utterance_history = []
        self.sys_output_history = []
        self.action_mask_history = []
        self.action_prob_history = []
        self.turn = 0
        self.agent_saves = {
            "session_id": None,
            "agent_id": None,
            "user_id": None,
            "timestamp": None,
            "dialogue_info": [],
            "dialogue_info_fundamental": [],
        }
        self.initTime = int(time.time())
        self.lastUpdate = int(time.time())

        logging.info("Dialogue Agent info_dict check")
        for module_name in self.module_names:
            if not hasattr(getattr(self, module_name), "info_dict"):
                logging.warning("%s info_dict is not initialized", module_name)

    def response(self, observation):
        """Generate the system response for one user turn and record the turn.

        Args:
            observation: The user input (utterance, or dialogue acts when nlu is None).
        Returns:
            The NLG output (conditioned on the policy conduct), or the policy
            output action when nlg is None.
        """

        self.sys_utterance_history.append(observation)
        fundamental_info = {"observation": observation}

        if self.dst is not None:
            # [['sys', sys_utt], ['user', user_utt],...]
            self.dst.state["history"].append([self.opponent_name, observation])
        self.history.append([self.opponent_name, observation])

        # get dialog act
        if self.nlu is not None:
            self.input_action = self.nlu.predict(
                observation, context=[x[1] for x in self.history[:-1]]
            )
        else:
            self.input_action = observation
        # get rid of reference problem
        self.input_action = deepcopy(self.input_action)
        fundamental_info["input_action"] = self.input_action

        # get state
        if self.dst is not None:
            self.dst.state["user_action"] = self.input_action
            state = self.dst.update(self.input_action)

            print("------state------")
            pprint(state)
            print("------end of state------")
            # NOTE(review): emotion is computed here but not used further below.
            if hasattr(self.dst, "get_emotion"):
                emotion = self.dst.get_emotion()
            else:
                emotion = "neutral"
        else:
            state = self.input_action

        state = deepcopy(state)  # get rid of reference problem
        fundamental_info["state"] = state
        self.sys_state_history.append(state)

        # get action (deep-copied to get rid of reference problem)
        self.output_action = deepcopy(self.policy.predict(state))
        if hasattr(self.policy, "get_conduct"):
            conduct = self.policy.get_conduct()
        else:
            conduct = "neutral"

        print("------sys_action------")
        pprint(self.output_action)
        # Print the conduct obtained above; calling get_conduct() again here
        # raised AttributeError for policies that do not provide it.
        pprint(conduct)
        print("------end of sys_action------")
        if hasattr(self.policy, "last_action"):
            self.sys_action_history.append(self.policy.last_action)
        else:
            self.sys_action_history.append(self.output_action)

        fundamental_info["output_action"] = self.output_action

        if hasattr(self.policy, "prob"):
            self.action_prob_history.append(self.policy.prob)

        # get model response
        if self.nlg is not None:
            model_response = self.nlg.generate(
                self.output_action, conduct=conduct, user_utt=observation
            )
        else:
            model_response = self.output_action

        self.sys_output_history.append(model_response)

        fundamental_info["model_response"] = model_response

        if self.dst is not None:
            self.dst.state["history"].append([self.name, model_response])
            self.dst.state["system_action"] = self.output_action
            # If system takes booking action add booking info to the 'booked' section of the belief state.
            # The intent check is case-insensitive, as in PipelineAgent._update_dst.
            if isinstance(self.output_action, list):
                for intent, domain, slot, value in self.output_action:
                    if intent.lower() == "book":
                        self.dst.state["booked"][domain] = [{slot: value}]
        self.history.append([self.name, model_response])

        self.turn += 1
        self.lastUpdate = int(time.time())

        self.agent_saves["dialogue_info_fundamental"].append(fundamental_info)
        self.agent_saves["dialogue_info"].append(self.get_info())
        return model_response

    def get_info(self):
        """Return a deep copy of each module's `info_dict` (None when absent), keyed by module name."""
        info_dict = {}
        for name in self.module_names:
            module = getattr(self, name)
            module_info = getattr(module, "info_dict", None)
            info_dict[name] = deepcopy(module_info)

        return info_dict

    def is_terminated(self):
        """Return the policy's termination flag, or None if the policy does not provide one."""
        if hasattr(self.policy, "is_terminated"):
            return self.policy.is_terminated()
        return None

    def retrieve_reward(self):
        """Convert the collected user feedback into one reward per system turn.

        Every system turn defaults to reward 1. `self.feedback` is presumably a
        mapping from dialogue turn numbers (str or int) to dicts with an "isGood"
        entry (verify against the server); turn `t` maps to system turn
        `int((t - 2) / 2)`. Feedback that maps outside the recorded turns is
        ignored, and no feedback at all yields all-1 rewards.
        """
        rewards = [1] * len(self.sys_state_history)
        for turn in self.feedback or {}:
            turn_number = int((int(turn) - 2) / 2)
            # A negative index would silently overwrite the last reward.
            if not 0 <= turn_number < len(rewards):
                continue
            # TODO possibly use self.feedback[turn]["text"] to check whether rating belongs to the right utterance of the system
            rewards[turn_number] = int(self.feedback[turn]["isGood"])
        return rewards

    def get_reward(self):
        """Return the policy's reward, or None if the policy does not provide one."""
        if hasattr(self.policy, "get_reward"):
            return self.policy.get_reward()
        return None

    def init_session(self, **kwargs):
        """Init the attributes of all modules; `kwargs` are forwarded to the policy."""
        if self.nlu is not None:
            self.nlu.init_session()
        if self.dst is not None:
            self.dst.init_session()
            self.dst.state["history"].append([self.name, "null"])
        if self.policy is not None:
            self.policy.init_session(**kwargs)
        if self.nlg is not None:
            self.nlg.init_session()
        self.history = []

    def get_in_da(self):
        """Return the input dialogue act of the last turn."""
        return self.input_action

    def get_out_da(self):
        """Return the output dialogue act of the last turn."""
        return self.output_action

    def print_ending_agent_summary(self):
        """Print session id, task id, goal status and the state/action histories."""
        print("session_id")
        print(self.session_id)
        print("taskID")
        print(self.taskID)
        print("USER_GOAL_ACHIEVED")
        print(self.USER_GOAL_ACHIEVED)
        print("sys_state_history")
        print(self.sys_state_history)
        print("sys_action_history")
        print(self.sys_action_history)

    def is_inactive(self):
        """True once the agent is at least 600 s old and has not been updated for at least 60 s."""
        currentTime = int(time.time())
        return (
            currentTime - self.initTime >= 600 and currentTime - self.lastUpdate >= 60
        )


class NeuLoopDialogueAgent(Agent):
    """Pipeline dialog agent (NLU, DST, Policy, NLG) with per-session bookkeeping.

    Besides producing system responses, the agent records per-turn histories
    (observations, states, actions, outputs) and a copy of each module's
    ``info_dict`` in ``agent_saves``. It also stores user feedback that
    ``retrieve_reward`` converts into rewards, and keeps timestamps used by
    ``is_inactive``.
    """

    def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str = "sys"):
        """The constructor of NeuLoopDialogueAgent class.

        Here are some special combination cases:

            1. If you use word-level DST (such as Neural Belief Tracker), you should set the ``nlu``
               parameter to None; the raw observation is then passed on unchanged as the input action.

            2. If you want to aggregate DST and Policy as a single module, set ``dst`` to None.

        Args:
            nlu (NLU):
                The natural language understanding module of agent.

            dst (DST):
                The dialog state tracker of agent.

            policy (Policy):
                The dialog policy module of agent.

            nlg (NLG):
                The natural language generator module of agent.

            name (str):
                Speaker name of this agent. Only ``"sys"`` is accepted.
        """

        super(NeuLoopDialogueAgent, self).__init__(name=name)
        assert self.name in ["sys"]
        self.opponent_name = "user"
        self.nlu = nlu
        self.dst = dst
        self.policy = policy
        self.nlg = nlg
        self.module_names = ["nlu", "dst", "policy", "nlg"]
        # Defined up front so get_in_da()/get_out_da() return None instead of
        # raising AttributeError before the first call to response().
        self.input_action = None
        self.output_action = None
        self.init_session()
        self.history = []
        self.session_id = None
        self.ENDING_DIALOG = False
        self.USER_RATED = False
        self.USER_GOAL_ACHIEVED = None
        self.taskID = None
        self.feedback = None
        self.requested_feedback = False
        self.sys_state_history = []
        self.sys_action_history = []
        self.sys_utterance_history = []
        self.sys_output_history = []
        self.action_mask_history = []
        self.action_prob_history = []
        self.turn = 0
        self.agent_saves = {
            "session_id": None,
            "agent_id": None,
            "user_id": None,
            "timestamp": None,
            "dialogue_info": [],
            "dialogue_info_fundamental": [],
        }
        self.initTime = int(time.time())
        self.lastUpdate = int(time.time())

        logging.info("Dialogue Agent info_dict check")
        if not hasattr(self.nlu, "info_dict"):
            logging.warning("nlu info_dict is not initialized")
        if not hasattr(self.dst, "info_dict"):
            logging.warning("dst info_dict is not initialized")
        if not hasattr(self.policy, "info_dict"):
            logging.warning("policy info_dict is not initialized")
        if not hasattr(self.nlg, "info_dict"):
            logging.warning("nlg info_dict is not initialized")

    def response(self, observation):
        """Generate the system response for one user turn.

        The observation is run through NLU (if present), DST (if present),
        the policy and NLG (if present). Intermediate results are recorded in
        the per-turn histories and in ``agent_saves``. Also increments
        ``self.turn`` and refreshes ``self.lastUpdate``.

        Args:
            observation: The user input passed to NLU, or used directly as the
                input action when no NLU is attached.

        Returns:
            The NLG output, or the (deep-copied) policy output action when no
            NLG is attached.
        """
        # NOTE(review): despite its name, this list receives the incoming user
        # observation rather than a system utterance -- confirm intent.
        self.sys_utterance_history.append(observation)
        fundamental_info = {"observation": observation}

        if self.dst is not None:
            self.dst.state["history"].append(
                [self.opponent_name, observation]
            )  # [['sys', sys_utt], ['user', user_utt],...]
        self.history.append([self.opponent_name, observation])

        # get dialog act
        if self.nlu is not None:
            self.input_action = self.nlu.predict(
                observation, context=[x[1] for x in self.history[:-1]]
            )
        else:
            self.input_action = observation
        # get rid of reference problem
        self.input_action = deepcopy(self.input_action)
        fundamental_info["input_action"] = self.input_action

        # get state
        if self.dst is not None:
            self.dst.state["user_action"] = self.input_action
            state = self.dst.update(self.input_action)
            print("------state------")
            pprint(state)
            print("------end of state------")
        else:
            state = self.input_action

        state = deepcopy(state)  # get rid of reference problem
        fundamental_info["state"] = state
        self.sys_state_history.append(state)

        # get action (deep-copied to get rid of reference problem)
        self.output_action = deepcopy(self.policy.predict(state))
        # Policies without get_conduct() fall back to a neutral conduct.
        if hasattr(self.policy, "get_conduct"):
            conduct = self.policy.get_conduct()
        else:
            conduct = "neutral"

        print("------sys_action------")
        pprint(self.output_action)
        # Print the conduct resolved above: calling get_conduct() again here
        # would raise AttributeError for policies that lack the method.
        pprint(conduct)
        print("------end of sys_action------")
        if hasattr(self.policy, "last_action"):
            self.sys_action_history.append(self.policy.last_action)
        else:
            self.sys_action_history.append(self.output_action)

        fundamental_info["output_action"] = self.output_action

        if hasattr(self.policy, "prob"):
            self.action_prob_history.append(self.policy.prob)

        # get model response
        if self.nlg is not None:
            # NOTE(review): the policy's `conduct` is resolved above, but a fixed
            # "neutral" conduct is passed to the NLG -- confirm this is intended.
            model_response = self.nlg.generate(
                self.output_action, conduct="neutral", user_utt=observation
            )
        else:
            model_response = self.output_action

        self.sys_output_history.append(model_response)

        fundamental_info["model_response"] = model_response

        if self.dst is not None:
            self.dst.state["history"].append([self.name, model_response])
            self.dst.state["system_action"] = self.output_action
            # If system takes booking action add booking info to the 'book-booked' section of the belief state
            if isinstance(self.output_action, list):
                for intent, domain, slot, value in self.output_action:
                    if intent == "book":
                        # NOTE(review): each "book" act replaces the domain's entry, so only
                        # the last booked slot per domain is kept -- confirm this is intended.
                        self.dst.state["booked"][domain] = [{slot: value}]
        self.history.append([self.name, model_response])

        self.turn += 1
        self.lastUpdate = int(time.time())

        self.agent_saves["dialogue_info_fundamental"].append(fundamental_info)
        self.agent_saves["dialogue_info"].append(self.get_info())
        return model_response

    def get_info(self):
        """Return a deep copy of each module's ``info_dict`` (None when a module has none), keyed by module name."""
        info_dict = {}
        for name in self.module_names:
            module = getattr(self, name)
            module_info = getattr(module, "info_dict", None)
            info_dict[name] = deepcopy(module_info)

        return info_dict

    def is_terminated(self):
        """Return ``self.policy.is_terminated()``, or None if the policy has no such method."""
        if hasattr(self.policy, "is_terminated"):
            return self.policy.is_terminated()
        return None

    def retrieve_reward(self):
        """Build one reward per recorded system turn from collected user feedback.

        Every entry of ``sys_state_history`` starts with a reward of 1.
        ``self.feedback`` maps a turn key (int or numeric string) to a dict
        containing an ``"isGood"`` rating. Each rating is converted with
        ``int`` and stored at index ``(turn - 2) // 2``. Keys whose index
        falls outside the recorded turns are ignored; this includes keys
        below 2, which map to negative indices.

        Returns:
            list[int]: One reward per entry of ``sys_state_history``.
        """
        rewards = [1] * len(self.sys_state_history)
        if not self.feedback:
            # No feedback collected (``feedback`` starts out as None) -> keep the defaults.
            return rewards
        for turn, entry in self.feedback.items():
            turn_number = (int(turn) - 2) // 2
            # Reject negative indices too: they would silently wrap around
            # and overwrite the reward of the last turn.
            if not 0 <= turn_number < len(rewards):
                continue
            # TODO possibly use entry["text"] to check whether rating belongs to the right utterance of the system
            rewards[turn_number] = int(entry["isGood"])
        return rewards

    def get_reward(self):
        """Return ``self.policy.get_reward()``, or None if the policy has no such method."""
        if hasattr(self.policy, "get_reward"):
            return self.policy.get_reward()
        return None

    def init_session(self):
        """Reset the NLU, DST, Policy and NLG modules (those that are not None) and clear ``history``.

        The DST history is seeded with a ``[self.name, "null"]`` entry after reset.
        """
        if self.nlu is not None:
            self.nlu.init_session()
        if self.dst is not None:
            self.dst.init_session()
            self.dst.state["history"].append([self.name, "null"])
        if self.policy is not None:
            self.policy.init_session()
        if self.nlg is not None:
            self.nlg.init_session()
        self.history = []

    def get_in_da(self):
        """Return the input dialog act of the latest turn (None before the first response)."""
        return self.input_action

    def get_out_da(self):
        """Return the output dialog act of the latest turn (None before the first response)."""
        return self.output_action

    def print_ending_agent_summary(self):
        """Print session id, task id, goal status and state/action histories to stdout."""
        print("session_id")
        print(self.session_id)
        print("taskID")
        print(self.taskID)
        print("USER_GOAL_ACHIEVED")
        print(self.USER_GOAL_ACHIEVED)
        print("sys_state_history")
        print(self.sys_state_history)
        print("sys_action_history")
        print(self.sys_action_history)

    def is_inactive(self):
        """Return True when the session is at least 600 s old and idle for at least 60 s."""
        currentTime = int(time.time())
        return (
            currentTime - self.initTime >= 600 and currentTime - self.lastUpdate >= 60
        )
