from .action import APIAction
from utils import api_util
import json
import copy
import uuid


class ChainToolSelection(APIAction):
    """Action that picks an API for the current sub-plan and fills in its parameters.

    Both steps query ``policy`` with chat-style prompts and validate the reply;
    when teacher forcing is active the oracle ("golden") values are used for
    execution in place of the model's answer.
    """
    
    # Key used to look up this action's templates in ``self.prompt``.
    name: str = "tool_selection"
    # Field names for this action; the last entry is the action name itself.
    fields: list = ["sess_id", "step", "traj_id", name]
    
    def __init__(self, profile, policy, api_env, handler):
        """Store the policy model and API environment on top of the base action.

        Args:
            profile: Forwarded unchanged to ``APIAction.__init__``.
            policy: Callable that receives a list of chat messages; its return
                value is used as the model reply.
            api_env: Kept on the instance as ``api_env``.
            handler: Forwarded unchanged to ``APIAction.__init__``.
        """
        super(ChainToolSelection, self).__init__(profile, handler)

        self.api_env = api_env
        self.policy = policy
        self.beam = 1
        # Teacher forcing applies to this action only when "selection" is listed.
        self.teacher_forcing = "selection" in self.teacher_forcing_list
        
    def get_oracle_parameter(self, traj_id, state, message):
        if 'golden' not in message:
            return {}
        
        idx = state.get_cur_subplan_idx(traj_id)
        if idx < len(message['golden']):
            return message['golden'][idx]["parameters"]
        else:
            return {}

    def calculate_parameter_reward(self, traj_id, state, message, free_parameter, oracle_parameter):
        if free_parameter == '##ignore##' or isinstance(free_parameter, str) or isinstance(oracle_parameter, str):
            return 0

        if len(free_parameter) != len(oracle_parameter):
            return 0
        
        # idx = state.get_cur_subplan_idx(traj_id)
        # required_param = message['golden'][idx]['api_info']['api_param']['required']
        reward = 1
        for k, v in oracle_parameter.items():
            
            if v is None:
                continue

            if k not in free_parameter:
                reward = 0
                break
            if k in free_parameter and free_parameter[k] != v:
                reward = 0
                break
        return reward

    # for parameter selection
    def is_parameter_regen(self, traj_id, state, message):
        
        if state.is_previous_plan(traj_id):
            return False, None
        else:
            reflec_traj = state.retrieve_reflection(traj_id, neighbor=True)
            assert reflec_traj is not None
            reflec_mem = reflec_traj.get_memory()
            
            if reflec_mem['reflection']['status'] == 2:
                return True, reflec_traj.get_memory()['reflection']['reason']
            else:
                return True, None

    def parameter_regen(self, traj_id, state, message, reflection, api):
        
        api_idx = api["ID"]
        api_info = message['api_list'][api_idx]
        properties = api_info['api_info']['api_param']['properties']

        if len(properties):
            prompts, param_text = self._parameter_regen(traj_id, state, message, reflection, api_idx)
            try:
                format_conv, parameter = self.parameter_format_check(traj_id, state, message, api_idx, param_text, api_info)
            except Exception as e:
                format_conv = None
                if self.teacher_forcing:
                    parameter = "##ignore##"
                else:
                    raise Exception("Parameter Format Error {}".format(e))
        else:
            prompts = None
            parameter = {}
            format_conv = None
        
        return prompts, format_conv, parameter
        
    def _parameter_regen(self, traj_id, state, message, reflection, api_idx):
        
        plan = state.get_cur_subplan(traj_id)
        system_prompt = self.build_parameter_sys_msg(message, api_idx)
        parameter_prompt = self.prompt[self.name]['re_parameter'].format(query=message["query"], plan=plan, reflection=reflection)
        
        if "mistral" in self.policy_aka:
            prompts = [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
                {"role": "user", "content": parameter_prompt},
            ]
        else:
            prompts = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": parameter_prompt},
            ]

        parameter_format_text = self.policy(prompts)
        prompts.append({"role": "assistant", "content": parameter_format_text})
        return prompts, parameter_format_text

    def parameter_reformat(self, parameter_text, e):
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['param_reformat'].format(text=parameter_text, mistake=e)
        }]
        parameter_format_text = self.policy(prompts)
        prompts.append({"role": "assistant", "content": parameter_format_text})
        return prompts, parameter_format_text

    def parameter_format_check(self, traj_id, state, message, api_idx, parameter_text, api_info):
        format_conv = None
        try:
            parameter = self._parameter_check(parameter_text, api_info)
        except Exception as e:
            if "Server Parameter Check Exception" in str(e):
                format_conv, parameter_format_text = self._parameter_regen(traj_id, state, message, parameter_text + str(e), api_idx)
                print("---> parameter regen: {} for e: {}".format(parameter_format_text, e))
            else:
                format_conv, parameter_format_text = self.parameter_reformat(parameter_text, e)
                print("---> parameter reformat {} for e: {}".format(parameter_format_text, e))
            parameter = self._parameter_check(parameter_format_text, api_info)
        return format_conv, parameter

    def _parameter_check(self, parameter_text, api_info):
        if "thought" in parameter_text:
            parameter_text = str(eval(parameter_text)['parameter'])
        else:
            # filter tag
            if "<param>" in parameter_text:
                parameter_text = parameter_text.replace("<param>", "").replace("</param>", "")

            if "```json" in parameter_text:
                s_pos = parameter_text.rfind("```json")
                e_pos = parameter_text.rfind("```")
                parameter_text = parameter_text[s_pos:e_pos].replace("```json", "")
            if "{" in parameter_text and "}" in parameter_text:
                s_pos = parameter_text.find("{")
                e_pos = parameter_text.rfind("}")
                parameter_text = parameter_text[s_pos:e_pos+1]

        parameter_text = parameter_text.replace("\\\\", "").replace("null", "None").replace("false", "False").replace("true", "True")
        
        print("_parameter_check : {}".format(parameter_text))
        parameter = eval(parameter_text)
        assert isinstance(parameter, dict)
        
        category_name = api_info['category_name']
        tool_name = api_info['tool_name']
        api_name = api_info['api_info']["api_name"]
            
        req_status, req_resp = True, ""
        # req_status, req_resp = self.api_env.check_param(category_name, tool_name, api_name, parameter, fmt=True)

        if req_status:
            self.logger.info("Parameter Check Pass: {}".format(req_resp))
            return parameter
        else:
            raise Exception("Server Parameter Check Exception: {}".format(req_resp))
        
    def parameter_gen(self, traj_id, state, message, api):
        
        api_idx = api["ID"]
        api_info = message['api_list'][api_idx]
        properties = api_info['api_info']['api_param']['properties']

        if len(properties):
            
            plan = state.get_cur_subplan(traj_id)
            system_prompt = self.build_parameter_sys_msg(message, api_idx)

            parameter_prompt = self.prompt[self.name]['parameter'].format(query=message["query"], plan=plan)

            if "mistral" in self.policy_aka:
                prompts = [
                    {"role": "user", "content": system_prompt},
                    {"role": "assistant", "content": ""},
                    {"role": "user", "content": parameter_prompt},
                ]
            else:
                prompts = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": parameter_prompt},
                ]
            
            param_text = self.policy(prompts)
            try:
                format_conv, parameter = self.parameter_format_check(traj_id, state, message, api_idx, param_text, api_info)
            except Exception as e:
                format_conv = None
                if self.teacher_forcing:
                    parameter = "##ignore##"
                else:
                    raise Exception("Parameter Format Error {}".format(e))
            prompts.append({"role": "assistant", "content": parameter})
        else:
            prompts = None
            parameter = {}
            format_conv = None

        return prompts, format_conv, parameter
    
    # for api selection
    def is_regen(self, traj_id, state, message):
        
        if state.is_previous_plan(traj_id):
            return False, None
        else:
            reflec_traj = state.retrieve_reflection(traj_id, neighbor=True)
            assert reflec_traj is not None
            reflec_mem = reflec_traj.get_memory()
            
            if reflec_mem['reflection']['status'] == 1:
                return True, reflec_traj.get_memory()['reflection']['reason']
            elif reflec_mem['reflection']['status'] == 2:
                return True, None
            else:
                return False, None
        
    def copy_plan(self, traj_id, state, message):
        prev_api_mem = state.retrieve_api(traj_id)[-2].get_memory()
        return prev_api_mem['api_conv'], prev_api_mem['api_format_conv'], prev_api_mem['api']

    def api_regen(self, traj_id, state, message, reflection):
        
        prev_api_mem = state.retrieve_api(traj_id)[-2].get_memory()
        plan = state.get_cur_subplan(traj_id)
        system_prompt = self.build_sys_msg(message)

        if "mistral" in self.policy_aka:
            prompts = [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
                {"role": "user", "content":self.prompt[self.name]['re_selection'].format(plan=plan, api=prev_api_mem['api'], reflection=reflection)},
            ]
        else:
            prompts = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content":self.prompt[self.name]['re_selection'].format(plan=plan, api=prev_api_mem['api'], reflection=reflection)},
            ]
        
        api_text = self.policy(prompts)
        try:
            format_conv, api = self.api_format_check(api_text, message["api_list"])
        except Exception as e:
            format_conv = None
            if self.teacher_forcing:
                api = "##ignore##"
            else:
                raise Exception("API Format Error {}".format(e))
        prompts.append({"role": "assistant", "content": api})
        return prompts, format_conv, api

    def get_oracle_api(self, traj_id, state, message):
        if 'golden' not in message:
            return {"ID": -1, "reason": "no golden"}
        
        idx = state.get_cur_subplan_idx(traj_id)
        if idx < len(message['golden']):
            return message['golden'][idx]['conv']["selection"]
        else:
            return {"ID": -1, "reason": "no golden"}

    def calculate_api_reward(self, free_api, oracle_api):
        
        if free_api == '##ignore##' or isinstance(free_api, str) or isinstance(oracle_api, str):
            return 0
        
        if free_api["ID"] == oracle_api['ID'] and free_api["api_name"] == oracle_api['api_name'] :
            reward = 1
        else:
            reward = 0
        return reward

    def api_gen(self, traj_id, state, message):
        
        plan = state.get_cur_subplan(traj_id)
        system_prompt = self.build_sys_msg(message)
        
        if "mistral" in self.policy_aka:
            prompts = [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
                {"role": "user", "content":self.prompt[self.name]['selection'].format(plan=plan)},
            ]
        else:
            prompts = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content":self.prompt[self.name]['selection'].format(plan=plan)},
            ]
        
        api_text = self.policy(prompts)
        
        try:
            format_conv, api = self.api_format_check(api_text, message["api_list"])
        except Exception as e:
            format_conv = None
            if self.teacher_forcing:
                api = "##ignore##"
            else:
                raise Exception("API Format Error {}".format(e))
        prompts.append({"role": "assistant", "content": api})
        return prompts, format_conv, api

    def api_reformat(self, api_text, e):

        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['api_reformat'].format(text=api_text, mistake=e)
        }]
        api_format_text = self.policy(prompts)
        prompts.append({"role": "assistant", "content": api_format_text})
        return prompts, api_format_text

    def api_format_check(self, api_text, api_list):
        
        format_conv = None
        try:
            api = self._api_check(api_text, api_list)
        except Exception as e:
            format_conv, api_format_text = self.api_reformat(api_text, e)
            api = self._api_check(api_format_text, api_list)
        return format_conv, api

    def _api_check(self, api_text, api_list):
        
        # filter tag
        if "<api>" in api_text:
            api_text = api_text.replace("<api>", "").replace("</api>", "")

        if "```json" in api_text:
            s_pos = api_text.find("```json")
            e_pos = api_text.rfind("```")
            api_text = api_text[s_pos:e_pos].replace("```json", "")
        if "{" in api_text and "}" in api_text:
            s_pos = api_text.find("{")
            e_pos = api_text.rfind("}")
            api_text = api_text[s_pos:e_pos+1]
        
        # self.logger.info("api_text: {}".format(api_text))
        api = eval(api_text)
        assert isinstance(api, dict)
        assert "api_name" in api and isinstance(api["api_name"], str)
        assert "ID" in api and isinstance(api["ID"], int)
        assert api["ID"] < len(api_list) and api["ID"] >= 0
        return api

    def run(self, state, message, traj_id, recover=False):
        """Select an API and then its parameters for trajectory ``traj_id``.

        Works in two stages, each stored through ``state.update`` followed by
        ``save_serialized_obj``: first the API choice, then the parameters.
        A stage is skipped when its result is already available
        (``load_serialized_obj`` returns a truthy value for the API stage;
        ``"parameter"`` is present in the trajectory memory for the second).

        Args:
            state: Trajectory store; queried and updated throughout.
            message: Task payload; ``message['api_list']`` is indexed by the
                chosen API's ``"ID"``.
            traj_id: Identifier of the trajectory being processed.
            recover: When true nothing new is generated; the method returns
                ``"leaf_node"`` at the first stage that has no stored result.

        Returns:
            ``"leaf_node"`` in recover mode when a stage is missing, otherwise
            ``None``.
        """
        
        # ---- Stage 1: API selection ----
        if not self.load_serialized_obj(state, traj_id):
            if recover: return "leaf_node"

            oracle_api = self.get_oracle_api(traj_id, state, message)
            
            regen_flag, reflection = self.is_regen(traj_id, state, message)

            if regen_flag:
                if reflection is not None:
                    conv, format_conv, free_api = self.api_regen(traj_id, state, message, reflection)
                else:
                    # copy from previous api
                    conv, format_conv, free_api = self.copy_plan(traj_id, state, message)
            else:
                conv, format_conv, free_api = self.api_gen(traj_id, state, message)

            self.logger.info("oracle api: {}".format(oracle_api))
            self.logger.info("free api: {}".format(free_api))

            # Under teacher forcing the oracle choice is carried forward; the
            # model's own choice is still recorded as "free_api" below.
            if self.teacher_forcing:
                api = oracle_api
            else:
                api = free_api

            if self.reward:
                reward = self.calculate_api_reward(free_api, oracle_api)
                self.logger.info("[assistant] step-{} {}: api reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                # -1 marks "reward not computed".
                reward = -1

            # NOTE(review): when the oracle placeholder {"ID": -1, ...} is
            # carried forward, api_list[-1] selects the last API - confirm
            # this is intended.
            state.update(traj_id, {
                "api_conv": conv,
                "api_format_conv": format_conv,
                "api": api,
                "api_reward": reward,
                "free_api": free_api,
                "oracle_api": oracle_api,
                "api_json": message['api_list'][api["ID"]]
            })
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        api = traj.get_memory()['api']
        # [1:-1] strips the first and last character of the JSON text for the log line.
        self.logger.info("[assistant] step-{} {}: api=[{}].".format(
            str(state.step), self.name, json.dumps(api)[1:-1]
        ))
        
        # ---- Stage 2: parameter generation ----
        if "parameter" not in traj.get_memory():
            if recover: return "leaf_node"

            oracle_parameter = self.get_oracle_parameter(traj_id, state, message)

            regen_flag, reflection = self.is_parameter_regen(traj_id, state, message)
            if regen_flag and reflection is not None:
                conv, format_conv, free_parameter = self.parameter_regen(traj_id, state, message, reflection, api)
            else:
                conv, format_conv, free_parameter = self.parameter_gen(traj_id, state, message, api)
            
            self.logger.info("oracle parameter: {}".format(oracle_parameter))
            self.logger.info("free parameter: {}".format(free_parameter))

            if self.teacher_forcing:
                parameter = oracle_parameter
            else:
                parameter = free_parameter

            if self.reward:
                # Parameters only earn a reward when the API itself was right.
                if traj.get_memory()['api_reward'] == 1:
                    reward = self.calculate_parameter_reward(traj_id, state, message, free_parameter, oracle_parameter)
                else:
                    reward = 0
                self.logger.info("[assistant] step-{} {}: parameter reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                reward = -1
            
            # A fresh child id is created here and used below as the next trajectory.
            state.update(traj_id, {
                "param_conv": conv,
                "param_format_conv": format_conv,
                "parameter": parameter,
                "parameter_reward": reward,
                "free_parameter": free_parameter,
                "oracle_parameter": oracle_parameter,
                "child_id": [str(uuid.uuid4())]
            })
            self.save_serialized_obj(state, traj_id)
        else:
            parameter = traj.get_memory()['parameter']
        
        # NOTE(review): ``traj`` was fetched before the update above, so the
        # reads below presumably rely on state.update mutating that same
        # trajectory object - confirm.
        traj.set_trans_param("*")
        state.next_traj(traj, traj.get_memory()['child_id'][0])

        self.logger.info("[assistant] step-{} {}: parameter=[{}].".format(
            str(state.step), self.name, json.dumps(traj.get_memory()['parameter'])[1:-1]
        ))


class ChainToughtToolSelection(ChainToolSelection):
    """Tool selection that asks the policy for a thought before the parameters.

    The thought is fed into the ``thought_parameter`` prompt that actually
    produces the parameters, while the returned conversation is built from the
    plain ``parameter`` prompt with the thought attached to the final
    assistant turn.
    """

    def __init__(self, profile, policy, api_env, handler):
        """Initialise exactly like ``ChainToolSelection``."""
        super().__init__(profile, policy, api_env, handler)

    def thought_gen(self, system_prompt, message, plan):
        """Query the policy with the ``thought`` prompt and return its reply.

        Args:
            system_prompt: System text for the selected API.
            message: Task payload; ``message['query']`` is inserted in the prompt.
            plan: Current sub-plan, inserted in the prompt.
        """
        thought_prompt = self.prompt[self.name]['thought'].format(query=message["query"], plan=plan)

        if "mistral" in self.policy_aka:
            # For mistral policies the system text is sent as a user turn
            # followed by an empty assistant turn.
            prompts = [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
            ]
        else:
            prompts = [{"role": "system", "content": system_prompt}]
        prompts.append({"role": "user", "content": thought_prompt})

        return self.policy(prompts)

    def parameter_gen(self, traj_id, state, message, api):
        """Generate parameters for the selected API via an intermediate thought.

        Args:
            traj_id: Trajectory being processed.
            state: Object providing ``get_cur_subplan(traj_id)``.
            message: Task payload; ``message['query']`` and
                ``message['api_list']`` are read.
            api: Selected API dict; only ``api['ID']`` is read.

        Returns:
            ``(rprompts, format_conv, parameter)``. For an API without declared
            properties this is ``(None, None, {})``. Otherwise ``rprompts`` is
            the plain ``parameter`` conversation ending with an assistant turn
            whose content is ``{"parameter": ..., "thought": ...}``.

        Raises:
            Exception: "Parameter Format Error ..." when validation fails and
                teacher forcing is off.
        """
        api_idx = api["ID"]
        api_info = message['api_list'][api_idx]

        # Nothing to generate for an API that declares no parameters.
        if not len(api_info['api_info']['api_param']['properties']):
            return None, None, {}

        plan = state.get_cur_subplan(traj_id)
        system_prompt = self.build_parameter_sys_msg(message, api_idx)

        thought_text = self.thought_gen(system_prompt, message, plan)
        templates = self.prompt[self.name]
        parameter_prompt = templates['thought_parameter'].format(
            query=message["query"], plan=plan, thought=thought_text
        )
        prompt = templates['parameter'].format(query=message["query"], plan=plan)

        def conversation(user_text):
            # Build a fresh message list each time so the two conversations
            # never share dict objects.
            if "mistral" in self.policy_aka:
                return [
                    {"role": "user", "content": system_prompt},
                    {"role": "assistant", "content": ""},
                    {"role": "user", "content": user_text},
                ]
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ]

        prompts = conversation(parameter_prompt)   # sent to the policy
        rprompts = conversation(prompt)            # returned to the caller

        param_text = self.policy(prompts)
        try:
            format_conv, parameter = self.parameter_format_check(
                traj_id, state, message, api_idx, param_text, api_info
            )
        except Exception as e:
            if not self.teacher_forcing:
                raise Exception("Parameter Format Error {}".format(e))
            format_conv, parameter = None, "##ignore##"

        rprompts.append({
            "role": "assistant",
            "content": {"parameter": parameter, "thought": thought_text},
        })
        return rprompts, format_conv, parameter

class RejectToolSelection(ChainToughtToolSelection):
    """Thought-based tool selection that rejects any step without a perfect reward.

    ``run`` follows the same two stages as ``ChainToolSelection.run`` but
    raises ``Exception("Reject...")`` before storing a stage whose reward is
    not exactly ``1``.
    """

    def run(self, state, message, traj_id, recover=False):
        """Select an API and its parameters, rejecting imperfect results.

        Args:
            state: Trajectory store; queried and updated throughout.
            message: Task payload; ``message['api_list']`` is indexed by the
                chosen API's ``"ID"``.
            traj_id: Identifier of the trajectory being processed.
            recover: When true nothing new is generated; the method returns
                ``"leaf_node"`` at the first stage that has no stored result.

        Returns:
            ``"leaf_node"`` in recover mode when a stage is missing, otherwise
            ``None``.

        Raises:
            Exception: "Reject..." when a freshly computed API or parameter
                reward is not ``1``. Because the reward is ``-1`` whenever
                ``self.reward`` is falsy, that configuration always rejects.
        """
        
        # ---- Stage 1: API selection ----
        if not self.load_serialized_obj(state, traj_id):
            if recover: return "leaf_node"

            oracle_api = self.get_oracle_api(traj_id, state, message)
            
            regen_flag, reflection = self.is_regen(traj_id, state, message)

            if regen_flag:
                if reflection is not None:
                    conv, format_conv, free_api = self.api_regen(traj_id, state, message, reflection)
                else:
                    # copy from previous api
                    conv, format_conv, free_api = self.copy_plan(traj_id, state, message)
            else:
                conv, format_conv, free_api = self.api_gen(traj_id, state, message)

            self.logger.info("oracle api: {}".format(oracle_api))
            self.logger.info("free api: {}".format(free_api))

            if self.teacher_forcing:
                api = oracle_api
            else:
                api = free_api

            if self.reward:
                reward = self.calculate_api_reward(free_api, oracle_api)
                self.logger.info("[assistant] step-{} {}: api reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                reward = -1

            # Reject before anything is stored, so a rejected step leaves no trace in state.
            if reward != 1:
                raise Exception("Reject...")

            state.update(traj_id, {
                "api_conv": conv,
                "api_format_conv": format_conv,
                "api": api,
                "api_reward": reward,
                "free_api": free_api,
                "oracle_api": oracle_api,
                "api_json": message['api_list'][api["ID"]]
            })
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        api = traj.get_memory()['api']
        # [1:-1] strips the first and last character of the JSON text for the log line.
        self.logger.info("[assistant] step-{} {}: api=[{}].".format(
            str(state.step), self.name, json.dumps(api)[1:-1]
        ))
        
        # ---- Stage 2: parameter generation ----
        if "parameter" not in traj.get_memory():
            if recover: return "leaf_node"

            oracle_parameter = self.get_oracle_parameter(traj_id, state, message)

            regen_flag, reflection = self.is_parameter_regen(traj_id, state, message)
            if regen_flag and reflection is not None:
                conv, format_conv, free_parameter = self.parameter_regen(traj_id, state, message, reflection, api)
            else:
                conv, format_conv, free_parameter = self.parameter_gen(traj_id, state, message, api)
            
            self.logger.info("oracle parameter: {}".format(oracle_parameter))
            self.logger.info("free parameter: {}".format(free_parameter))

            if self.teacher_forcing:
                parameter = oracle_parameter
            else:
                parameter = free_parameter

            if self.reward:
                # Parameters only earn a reward when the API itself was right.
                if traj.get_memory()['api_reward'] == 1:
                    reward = self.calculate_parameter_reward(traj_id, state, message, free_parameter, oracle_parameter)
                else:
                    reward = 0
                self.logger.info("[assistant] step-{} {}: parameter reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                reward = -1
            
            # Same rejection rule as for the API stage.
            if reward != 1:
                raise Exception("Reject...")

            state.update(traj_id, {
                "param_conv": conv,
                "param_format_conv": format_conv,
                "parameter": parameter,
                "parameter_reward": reward,
                "free_parameter": free_parameter,
                "oracle_parameter": oracle_parameter,
                "child_id": [str(uuid.uuid4())]
            })
            self.save_serialized_obj(state, traj_id)
        else:
            parameter = traj.get_memory()['parameter']
        
        # NOTE(review): ``traj`` was fetched before the update above, so the
        # reads below presumably rely on state.update mutating that same
        # trajectory object - confirm.
        traj.set_trans_param("*")
        state.next_traj(traj, traj.get_memory()['child_id'][0])

        self.logger.info("[assistant] step-{} {}: parameter=[{}].".format(
            str(state.step), self.name, json.dumps(traj.get_memory()['parameter'])[1:-1]
        ))

class BiTreeToolSelection(ChainToolSelection):
    """``ChainToolSelection`` variant; currently only re-sets ``beam`` to 1."""

    def __init__(self, profile, policy, api_env, handler):
        """Initialise the parent action, then set ``beam`` to 1."""
        super().__init__(profile, policy, api_env, handler)
        self.beam = 1

class TreeToolSelection(BiTreeToolSelection):
    """``BiTreeToolSelection`` variant; currently only re-sets ``beam`` to 1."""

    def __init__(self, profile, policy, api_env, handler):
        """Initialise the parent action, then set ``beam`` to 1."""
        super().__init__(profile, policy, api_env, handler)
        self.beam = 1

    