from .action import APIAction
from utils import api_util
import json
import copy
import uuid


class ChainToolReflection(APIAction):
    """Summarize the latest tool observation, reflect on it, and advance the plan.

    For one trajectory node, ``run``:

    1. asks ``policy`` for a summary of the neighbouring observation with
       respect to the current sub-plan (``summary_gen``);
    2. asks ``policy`` to reflect on that summary, expecting a dict
       ``{"status": 0|1|2|3, "reason": ...}`` (``reflec_gen``);
    3. records free/oracle summaries and reflections (and an optional
       reward) on the node via ``state.update``;
    4. uses the chosen reflection's ``status`` to either re-run the current
       sub-plan step (status 1 or 2) or move on (any other status), then
       spawns ``self.beam`` child trajectories.

    ``prompt``, ``logger``, ``reward``, ``teacher_forcing_list`` and the
    ``load_serialized_obj``/``save_serialized_obj`` helpers are provided by
    ``APIAction`` (not shown here).
    """

    name: str = "tool_reflection"
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, handler):
        super(ChainToolReflection, self).__init__(profile, handler)

        # Callable that takes a list of chat messages and returns the reply text.
        self.policy = policy
        # Forwarded to ``state.next_traj`` as ``param_filter``.
        self.trans_param_list = [0, 1, 2, 3]
        # Number of child trajectories spawned per node.
        self.beam = 1
        # When True, the oracle summary/reflection drive the transition
        # instead of the model's own output.
        self.teacher_forcing = "reflection" in self.teacher_forcing_list

    def calculate_reflection_reward(self, traj_id, state, free_reflection, oracle_reflection):
        """Score the model's (free) reflection.

        Returns:
            0 when the reflection was ignored ("##ignore##", i.e. a format
            failure under teacher forcing). Under teacher forcing: 1 if the
            free status equals the oracle status, else 0. Otherwise a score
            derived from the neighbouring API node's memory:
            status 0 -> mean of ``api_reward`` and ``parameter_reward``,
            status 1 -> ``1 - api_reward``, status 2 -> ``1 - parameter_reward``.

        Raises:
            Exception: for any other status when not teacher forcing.
        """
        if free_reflection == "##ignore##":
            return 0

        if self.teacher_forcing:
            return 1 if free_reflection["status"] == oracle_reflection["status"] else 0

        api_traj = state.retrieve_api(traj_id, neighbor=True)
        api_memory = api_traj.get_memory()

        status = free_reflection["status"]
        if status == 0:
            return (api_memory["api_reward"] + api_memory["parameter_reward"]) / 2
        elif status == 1:
            return 1 - api_memory["api_reward"]
        elif status == 2:
            return 1 - api_memory["parameter_reward"]
        else:
            # NOTE(review): ``_reflec_check`` accepts status 3, so a
            # model-emitted 3 lands here and aborts the run — confirm how
            # status 3 should be scored.
            raise Exception("Not Support Status: {}".format(status))

    def get_oracle_reflection(self, state, traj_id, message):
        """Return the fixed oracle reflection (always status 0)."""
        return {
            "status": 0,
            "reason": "the summary solved this problem."
        }

    def reflec_reformat(self, reflec_text, e):
        """Ask the policy to rewrite malformed reflection text.

        Args:
            reflec_text: the text that failed to parse.
            e: the parse error; it is formatted into the prompt as ``mistake``.

        Returns:
            ``(prompts, reflec_format_text)``: the reformat conversation
            (ending with the assistant reply) and the reply text.
        """
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['reformat'].format(text=reflec_text, mistake=e)
        }]
        reflec_format_text = self.policy(prompts)
        self.logger.info("reformat: {}".format(reflec_format_text))
        prompts.append({"role": "assistant", "content": reflec_format_text})
        return prompts, reflec_format_text

    def reflec_format_check(self, reflec_text):
        """Parse ``reflec_text``, allowing one policy-driven reformat retry.

        Returns:
            ``(format_conv, reflection)``. ``format_conv`` is None when the
            first parse succeeded, else the reformat conversation.

        Raises:
            Whatever ``_reflec_check`` raises on the reformatted text.
        """
        format_conv = None
        try:
            reflection = self._reflec_check(reflec_text)
        except Exception as e:
            format_conv, reflec_format_text = self.reflec_reformat(reflec_text, e)
            reflection = self._reflec_check(reflec_format_text)

        return format_conv, reflection

    def _reflec_check(self, reflection_text):
        """Extract and validate a reflection dict from model output.

        Strips ``<reflection>`` tags and a ```json fence, keeps the outermost
        ``{...}`` span, and parses it as JSON or as a Python dict literal.

        Returns:
            A dict with ``status`` in {0, 1, 2, 3} and a ``reason`` key.

        Raises:
            ValueError / SyntaxError: if the text cannot be parsed into such
            a dict. The message is also shown to the policy by
            ``reflec_reformat`` as the ``mistake``.
        """
        import ast  # local import keeps the module's import block unchanged

        text = reflection_text

        # filter tag
        if "<reflection>" in text:
            text = text.replace("<reflection>", "").replace("</reflection>", "")

        if "```json" in text:
            s_pos = text.find("```json")
            e_pos = text.rfind("```")
            if e_pos <= s_pos:
                # Unclosed fence: rfind found the opening marker itself, which
                # would slice to an empty string. Keep everything after it.
                e_pos = len(text)
            text = text[s_pos:e_pos].replace("```json", "")
        if "{" in text and "}" in text:
            s_pos = text.find("{")
            e_pos = text.rfind("}")
            text = text[s_pos:e_pos + 1]

        # The text is untrusted model output, so it is never ``eval``-ed.
        # Try strict JSON first (handles true/false/null), then fall back to
        # Python literals (handles single quotes, True/False/None).
        try:
            reflection = json.loads(text)
        except ValueError:
            reflection = ast.literal_eval(text.strip())

        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if not isinstance(reflection, dict):
            raise ValueError(
                "reflection must be a dict, got {}".format(type(reflection).__name__)
            )
        if "status" not in reflection or "reason" not in reflection:
            raise ValueError("reflection must contain 'status' and 'reason' keys")
        if reflection['status'] not in (0, 1, 2, 3):
            raise ValueError(
                "reflection status must be one of 0, 1, 2, 3, got {!r}".format(reflection['status'])
            )
        return reflection

    def reflec_gen(self, traj_id, state, summary):
        """Ask the policy to reflect on ``summary`` for the current sub-plan.

        Returns:
            ``(prompts, format_conv, reflection)``. ``reflection`` is the
            parsed dict, or "##ignore##" when parsing failed under teacher
            forcing.

        Raises:
            Exception: on a parse failure when not teacher forcing.
        """
        plan = state.get_cur_subplan(traj_id)
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['reflection'].format(
                observation=summary,
                plan=plan
            )
        }]
        reflection_text = self.policy(prompts)

        try:
            format_conv, reflection = self.reflec_format_check(reflection_text)
        except Exception as e:
            format_conv = None
            if self.teacher_forcing:
                reflection = "##ignore##"
            else:
                raise Exception("Reflection Format Error {}".format(e)) from e
        # NOTE(review): the assistant turn stores the parsed dict (or
        # "##ignore##"), not the raw ``reflection_text``. ``summary_gen`` and
        # ``reflec_reformat`` store raw text — confirm this difference is intended.
        prompts.append({"role": "assistant", "content": reflection})
        return prompts, format_conv, reflection

    def get_oracle_summary(self, state, traj_id, message):
        """Return the golden answer for the current sub-plan step.

        Falls back to a fixed success sentence when ``message`` has no
        ``"golden"`` entry, and to "" when the sub-plan index runs past it.
        """
        if 'golden' not in message:
            return "The function call result perfectly resolved the query."

        idx = state.get_cur_subplan_idx(traj_id)
        if idx < len(message['golden']):
            return message['golden'][idx]['conv']["answer"]
        else:
            return ""

    def get_observation(self, traj_id, state):
        """Return the ``observation`` stored on the neighbouring observation node."""
        obs_traj = state.retrieve_observation(traj_id, neighbor=True)
        observation = obs_traj.get_memory()["observation"]
        return observation

    def summary_gen(self, traj_id, state):
        """Ask the policy to summarize the latest observation for the sub-plan.

        Returns:
            ``(prompts, summary)``: the conversation (ending with the
            assistant reply) and the cleaned summary text.
        """
        plan = state.get_cur_subplan(traj_id)
        observation = self.get_observation(traj_id, state)

        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['summary'].format(
                observation=observation,
                plan=plan
            )
        }]

        summary = self.policy(prompts)
        # Drops every backslash and every "\n\n" pair (paragraphs are joined
        # with no separator). NOTE(review): confirm "" rather than "\n" is intended.
        summary = summary.replace("\\", "").replace("\n\n", "")
        prompts.append({"role": "assistant", "content": summary})
        return prompts, summary

    def run(self, state, message, traj_id, recover=False):
        """Generate (or restore) the reflection for ``traj_id`` and advance the plan.

        Args:
            state: trajectory store (project type; not shown here).
            message: task record; its optional ``"golden"`` entry supplies
                oracle summaries.
            traj_id: id of the trajectory node to process.
            recover: when True and ``load_serialized_obj`` restored nothing,
                stop without generating.

        Returns:
            "leaf_node" when recovering a node that was never serialized,
            otherwise None.
        """
        if not self.load_serialized_obj(state, traj_id):
            if recover:
                return "leaf_node"
            self._record_reflection(state, message, traj_id)
            self.save_serialized_obj(state, traj_id)

        self._advance_plan(state, traj_id)

    def _record_reflection(self, state, message, traj_id):
        """Generate summary and reflection, score them, and store them on ``traj_id``."""
        oracle_summary = self.get_oracle_summary(state, traj_id, message)
        oracle_reflection = self.get_oracle_reflection(state, traj_id, message)

        summary_conv, free_summary = self.summary_gen(traj_id, state)
        reflection_conv, format_conv, free_reflection = self.reflec_gen(traj_id, state, free_summary)

        self.logger.info("free reflection: {}".format(free_reflection))

        if self.teacher_forcing:
            summary, reflection = oracle_summary, oracle_reflection
        else:
            summary, reflection = free_summary, free_reflection
        # ``_advance_plan`` mutates the stored "reflection" dict in place
        # (status forced to 0 or 3). Store a copy so those edits cannot leak
        # into "oracle_reflection"/"free_reflection", which would otherwise be
        # the same object if ``state.update`` keeps references.
        reflection = copy.deepcopy(reflection)

        if self.reward:
            reward = self.calculate_reflection_reward(traj_id, state, free_reflection, oracle_reflection)
            self.logger.info("[assistant] step-{} {}: reflection reward = {}.".format(
                str(state.step), self.name, reward
            ))
        else:
            reward = -1

        state.update(traj_id, {
            "summary_conv": summary_conv,
            "summary": summary,
            "oracle_summary": oracle_summary,
            "free_summary": free_summary,
            "reflection_conv": reflection_conv,
            "reflection_format_conv": format_conv,
            "reflection": reflection,
            "oracle_reflection": oracle_reflection,
            "free_reflection": free_reflection,
            "reward": reward,
            "child_id": [str(uuid.uuid4()) for _ in range(self.beam)]
        })

    def _advance_plan(self, state, traj_id):
        """Apply the stored reflection to the local plan and spawn child nodes."""
        traj = state.get(traj_id)
        reflection = traj.get_memory()["reflection"]

        plan_traj = state.retrieve_plan(traj_id, neighbor=True)
        plan = plan_traj.get_memory()['plan']
        local_plan = state.get_local_plan(traj_id, state)[traj_id]

        # Cap retries: once exec_count exceeds 1, always advance.
        if local_plan['exec_count'] > 1:
            reflection['status'] = 0

        if reflection['status'] == 1 or reflection['status'] == 2:
            # If re-execution, keep the plan index, and the execution count + 1
            sub_plan_idx = local_plan['plan_idx']
            exec_count = local_plan['exec_count'] + 1
        else:
            # Any other status: move on to the next sub-plan step.
            sub_plan_idx = local_plan['plan_idx'] + 1
            exec_count = local_plan['exec_count']

        # update local plan
        local_plan["plan_idx"] = sub_plan_idx
        local_plan["exec_count"] = exec_count
        state.update_local(plan_traj.get_node_id(), {traj_id: local_plan})

        # Past the last sub-plan step: mark the plan as finished.
        if sub_plan_idx == len(plan):
            reflection['status'] = 3

        traj.set_trans_param(reflection["status"])
        for idx in range(self.beam):
            state.next_traj(traj, traj.get_memory()['child_id'][idx], param_filter=self.trans_param_list)

        self.logger.info("[assistant] step-{} {}:[plan_idx: {}/{}, exec_count = {}]: [{}].".format(
            str(state.step), self.name,
            local_plan['plan_idx'],
            len(plan),
            local_plan['exec_count'],
            traj.get_memory()['reflection']
        ))


class PesudoChainToolReflection(ChainToolReflection):
    """Model-free stand-in for ``ChainToolReflection``.

    Rather than calling ``policy`` for a summary and a reflection, it stores
    fixed placeholder values (reflection status 0, reward 1). It then runs
    the same plan-transition step as the parent: retry on status 1/2,
    otherwise advance, and mark status 3 once the plan is exhausted.
    """

    def run(self, state, message, traj_id, recover=False):
        restored = self.load_serialized_obj(state, traj_id)
        if not restored:
            if recover:
                return "leaf_node"

            children = [str(uuid.uuid4()) for _ in range(self.beam)]
            placeholder = {
                "summary_conv": [],
                "summary": "summary",
                "oracle_summary": "oracle_summary",
                "free_summary": "free_summary",
                "reflection_conv": [],
                "reflection_format_conv": None,
                "reflection": {"status": 0, "reason": "reflection"},
                "oracle_reflection": {"status": 0, "reason": "oracle_reflection"},
                "free_reflection": {"status": 0, "reason": "free_reflection"},
                "reward": 1,
                "child_id": children,
            }
            state.update(traj_id, placeholder)
            self.save_serialized_obj(state, traj_id)

        node = state.get(traj_id)
        verdict = node.get_memory()["reflection"]

        plan_node = state.retrieve_plan(traj_id, neighbor=True)
        plan = plan_node.get_memory()['plan']
        local_plan = state.get_local_plan(traj_id, state)[traj_id]

        # Retries are capped: beyond one extra execution, always move on.
        if local_plan['exec_count'] > 1:
            verdict['status'] = 0

        retry = verdict['status'] == 1 or verdict['status'] == 2
        if retry:
            # Stay on the same sub-plan step and count one more attempt.
            next_idx = local_plan['plan_idx']
            next_count = local_plan['exec_count'] + 1
        else:
            # Step accepted: go to the next sub-plan step.
            next_idx = local_plan['plan_idx'] + 1
            next_count = local_plan['exec_count']

        local_plan["plan_idx"] = next_idx
        local_plan["exec_count"] = next_count
        state.update_local(plan_node.get_node_id(), {traj_id: local_plan})

        finished = next_idx == len(plan)
        if finished:
            verdict['status'] = 3

        node.set_trans_param(verdict["status"])
        for i in range(self.beam):
            state.next_traj(node, node.get_memory()['child_id'][i], param_filter=self.trans_param_list)

        template = "[assistant] step-{} {}:[plan_idx: {}/{}, exec_count = {}]: [{}]."
        self.logger.info(template.format(
            str(state.step),
            self.name,
            local_plan['plan_idx'],
            len(plan),
            local_plan['exec_count'],
            node.get_memory()['reflection'],
        ))

class BiTreeToolReflection(ChainToolReflection):
    """``ChainToolReflection`` that spawns two child trajectories per node."""

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)
        # Two children per node: ``run`` loops ``range(self.beam)`` when
        # creating child ids and calling ``state.next_traj``.
        self.beam = 2

class TreeToolReflection(BiTreeToolReflection):
    """Tree-shaped reflection action; currently branches exactly like its parent."""

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)
        # NOTE(review): same value BiTreeToolReflection already sets. If
        # "Tree" was meant to branch wider than "BiTree", this looks like a
        # copy-paste slip — confirm the intended beam width.
        self.beam = 2
