from .action import APIAction
from utils import api_util
import json
import copy
import uuid


class ChainTaskReflection(APIAction):
    """Summarize a trajectory, reflect on the summary, and advance the state.

    ``run`` asks the policy for a summary of the results gathered so far, asks
    it again to judge that summary (the "reflection"), optionally scores the
    reflection, stores everything in the trajectory memory and finally creates
    ``self.beam`` child trajectories.
    """

    # Key of this action's prompt templates in ``self.prompt``; also used in logs.
    name: str = "task_reflection"
    # NOTE(review): presumably the fields identifying a serialized record --
    # confirm against APIAction.
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, handler):
        """Set up the action.

        Args:
            profile: Forwarded unchanged to ``APIAction.__init__``.
            policy: Callable that takes a list of chat messages
                (``{"role": ..., "content": ...}``) and returns the reply text.
            handler: Forwarded unchanged to ``APIAction.__init__``.
        """
        super().__init__(profile, handler)

        self.policy = policy
        # Passed as ``param_filter`` to ``state.next_traj`` in ``run``.
        self.trans_param_list = [0, 1]
        # Number of child trajectories created per reflection step.
        self.beam = 1
        # ``teacher_forcing_list`` is not defined in this class; presumably it
        # is set by APIAction.__init__ -- confirm.
        self.teacher_forcing = "freflection" in self.teacher_forcing_list

    def calculate_reflection_reward(self, message, traj_id, state, free_reflection, oracle_reflection):
        """Score the policy's reflection as 1 (correct) or 0 (incorrect).

        Args:
            message: Task message; ``message['golden']`` is read when teacher
                forcing is off.
            traj_id: Id of the trajectory being reflected on.
            state: Search state used to look up the reflection boundary.
            free_reflection: Reflection dict produced by the policy, or the
                sentinel string ``"##ignore##"`` when it could not be parsed.
            oracle_reflection: Reference reflection dict.

        Returns:
            int: 1 if the reflection status matches the expected one, else 0.
        """
        if free_reflection == "##ignore##":
            return 0

        if self.teacher_forcing:
            return int(free_reflection["status"] == oracle_reflection["status"])

        # Collect the sub-plan indices whose trajectory earned a reward >= 1.
        pass_set = set()
        for ref_traj in state.retrieve_reflection_boundary(traj_id):
            reward = ref_traj.get_memory()['reward']
            idx = state.get_cur_subplan_idx(ref_traj.get_node_id())
            if reward >= 1:
                pass_set.add(idx)

        # Status 0 is expected when the number of passed sub-plans equals the
        # number of golden entries, status 1 otherwise.
        expected_status = 0 if len(message['golden']) == len(pass_set) else 1
        return int(free_reflection["status"] == expected_status)

    def get_oracle_reflection(self, state, traj_id, message):
        """Return the reference reflection.

        The arguments are not used; the result is always a fresh dict with
        status 0 and a fixed reason.
        """
        return {
            "status": 0,
            "reason": "the summary solved this problem."
        }

    def reflec_reformat(self, reflec_text, e):
        """Ask the policy to rewrite a reflection that failed the format check.

        Args:
            reflec_text: The reflection text that could not be parsed.
            e: The error raised by the failed check; it is interpolated into
                the ``reformat`` prompt as ``mistake``.

        Returns:
            tuple: ``(prompts, reflec_format_text)`` where ``prompts`` is the
            two-message conversation and ``reflec_format_text`` is the reply.
        """
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['reformat'].format(text=reflec_text, mistake=e)
        }]
        reflec_format_text = self.policy(prompts)
        self.logger.info("reformat: {}".format(reflec_format_text))
        prompts.append({"role": "assistant", "content": reflec_format_text})
        return prompts, reflec_format_text

    def reflec_format_check(self, reflec_text):
        """Parse a reflection, retrying once through ``reflec_reformat``.

        Args:
            reflec_text: Raw reflection text from the policy.

        Returns:
            tuple: ``(format_conv, reflection)``. ``format_conv`` is ``None``
            when the first parse succeeded, otherwise the reformat
            conversation.

        Raises:
            Exception: Whatever ``_reflec_check`` raises when the reformatted
                text is still invalid.
        """
        format_conv = None
        try:
            reflection = self._reflec_check(reflec_text)
        except Exception as e:
            format_conv, reflec_format_text = self.reflec_reformat(reflec_text, e)
            reflection = self._reflec_check(reflec_format_text)

        return format_conv, reflection

    def _reflec_check(self, reflection_text):
        """Extract and validate the reflection dict embedded in ``reflection_text``.

        ``<reflection>`` tags and a "```json" code fence are stripped, then the
        span from the first ``{`` to the last ``}`` is parsed.

        Args:
            reflection_text: Text produced by the policy.

        Returns:
            dict: A dict with a ``status`` in ``{0, 1, 2}`` and a ``reason``.

        Raises:
            ValueError: If the parsed value is not a dict, lacks a required
                key, has an invalid status, or cannot be parsed as a literal.
            SyntaxError: If the text is not a valid JSON or Python literal.
        """
        import ast  # local import: only this method needs it

        # filter tag
        if "<reflection>" in reflection_text:
            reflection_text = reflection_text.replace("<reflection>", "").replace("</reflection>", "")

        fence = "```json"
        if fence in reflection_text:
            s_pos = reflection_text.find(fence)
            e_pos = reflection_text.rfind("```")
            # Without a closing fence, rfind lands on the opening fence and
            # the slice would be empty; keep everything after the opening.
            if e_pos <= s_pos:
                e_pos = len(reflection_text)
            reflection_text = reflection_text[s_pos:e_pos].replace(fence, "")

        if "{" in reflection_text and "}" in reflection_text:
            s_pos = reflection_text.find("{")
            e_pos = reflection_text.rfind("}")
            reflection_text = reflection_text[s_pos:e_pos + 1]

        # SECURITY: this text comes straight from the model. It used to go
        # through eval(), which executes arbitrary code; only literal parsers
        # are used now.
        try:
            reflection = json.loads(reflection_text)
        except ValueError:
            # Not strict JSON: accept a Python-literal dict (single quotes,
            # True/False/None) as eval() previously did.
            reflection = ast.literal_eval(reflection_text.strip())

        # Explicit raises instead of assert: asserts vanish under ``-O`` and
        # carry no message for the reformat prompt.
        if not isinstance(reflection, dict):
            raise ValueError(
                "reflection must be a dict, got {}".format(type(reflection).__name__)
            )
        missing = [key for key in ("status", "reason") if key not in reflection]
        if missing:
            raise ValueError("reflection is missing key(s): {}".format(", ".join(missing)))
        if reflection['status'] not in [0, 1, 2]:
            raise ValueError(
                "reflection status must be 0, 1 or 2, got {!r}".format(reflection['status'])
            )
        return reflection

    def reflec_gen(self, traj_id, state, summary, message):
        """Ask the policy to reflect on ``summary`` and parse its answer.

        Args:
            traj_id: Unused here.
            state: Unused here.
            summary: Summary text inserted into the ``reflection`` prompt.
            message: Task message; ``message["query"]`` is inserted into the
                prompt.

        Returns:
            tuple: ``(prompts, format_conv, reflection)``. ``reflection`` is
            the parsed dict, or ``"##ignore##"`` when parsing failed under
            teacher forcing.

        Raises:
            Exception: When parsing fails and teacher forcing is off.
        """
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['reflection'].format(
                summary=summary,
                query=message["query"]
            )
        }]
        reflection_text = self.policy(prompts)
        try:
            format_conv, reflection = self.reflec_format_check(reflection_text)
        except Exception as e:
            format_conv = None
            if not self.teacher_forcing:
                raise Exception("Reflection Format Error {}".format(e)) from e
            # Under teacher forcing the oracle reflection is used downstream,
            # so an unparsable answer is only marked, not fatal.
            reflection = "##ignore##"
        # NOTE(review): the assistant turn stores the parsed reflection, not
        # the raw reply text -- confirm that consumers of this conversation
        # expect that.
        prompts.append({"role": "assistant", "content": reflection})
        return prompts, format_conv, reflection

    def get_oracle_summary(self, state, traj_id, message):
        """Return the reference summary for this trajectory.

        Args:
            state: Search state used to look up the current sub-plan index.
            traj_id: Id of the trajectory being summarized.
            message: Task message; ``golden`` and ``answer`` are read when
                present.

        Returns:
            str: A fixed sentence when ``message`` has no ``golden`` key;
            otherwise ``message['answer']`` if present; otherwise the answer
            of golden entry ``idx - 1``, or ``""`` when that index is past the
            end.
        """
        if 'golden' not in message:
            return "The function call result perfectly resolved the query."

        idx = state.get_cur_subplan_idx(traj_id)
        if 'answer' in message:
            return message['answer']
        # NOTE(review): an ``idx`` of 0 would index the last golden entry;
        # presumably sub-plan indices start at 1 -- confirm.
        if idx - 1 < len(message['golden']):
            return message['golden'][idx - 1]['conv']["answer"]
        return ""

    def get_summary(self, traj_id, state):
        """Join the summaries stored along the reflection path of ``traj_id``.

        Returns:
            str: One ``"[i] <summary>."`` line per trajectory whose memory has
            a ``summary``, numbered from 0, joined with newlines.
        """
        summary = []
        for ref_traj in state.retrieve_reflection(traj_id):
            memory = ref_traj.get_memory()
            if "summary" in memory:
                summary.append("[{}] {}.".format(len(summary), memory["summary"]))

        return "\n".join(summary)

    def summary_gen(self, traj_id, state, message):
        """Ask the policy to summarize the collected results for the query.

        Returns:
            tuple: ``(prompts, summary)`` where ``prompts`` is the two-message
            conversation and ``summary`` is the reply with backslashes and
            blank-line breaks removed.
        """
        summary = self.get_summary(traj_id, state)
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['summary'].format(
                summary=summary,
                query=message["query"]
            )
        }]

        summary = self.policy(prompts)
        summary = summary.replace("\\", "").replace("\n\n", "")
        prompts.append({"role": "assistant", "content": summary})
        return prompts, summary

    def run(self, state, message, traj_id, recover=False):
        """Execute the reflection step for ``traj_id``.

        Unless a serialized result could be loaded, the summary and reflection
        are generated, scored and written to the trajectory memory. Afterwards
        ``self.beam`` child trajectories are created.

        Args:
            state: Search state holding the trajectories.
            message: Task message (``query`` is required; ``golden`` and
                ``answer`` are read when present).
            traj_id: Id of the trajectory to process.
            recover: When true and nothing could be loaded, nothing is
                generated and ``"leaf_node"`` is returned.

        Returns:
            ``"leaf_node"`` in the recover case described above, else ``None``.
        """
        if not self.load_serialized_obj(state, traj_id):
            if recover:
                return "leaf_node"

            oracle_summary = self.get_oracle_summary(state, traj_id, message)
            oracle_reflection = self.get_oracle_reflection(state, traj_id, message)

            summary_conv, free_summary = self.summary_gen(traj_id, state, message)
            reflection_conv, format_conv, free_reflection = self.reflec_gen(traj_id, state, free_summary, message)

            self.logger.info("free summary: {}".format(free_summary))
            self.logger.info("free reflection: {}".format(free_reflection))

            # Teacher forcing continues with the oracle values; the policy's
            # own output is still stored under the ``free_*`` keys.
            if self.teacher_forcing:
                summary = oracle_summary
                reflection = oracle_reflection
            else:
                summary = free_summary
                reflection = free_reflection

            if self.reward:
                reward = self.calculate_reflection_reward(message, traj_id, state, free_reflection, oracle_reflection)
                self.logger.info("[assistant] step-{} {}: reflection reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                # -1 marks "reward not computed".
                reward = -1

            state.update(traj_id, {
                "summary_conv": summary_conv,
                "summary": summary,
                "oracle_summary": oracle_summary,
                "free_summary": free_summary,
                "reflection_conv": reflection_conv,
                "reflection_format_conv": format_conv,
                "reflection": reflection,
                "oracle_reflection": oracle_reflection,
                "free_reflection": free_reflection,
                "reward": reward,
                "child_id": [str(uuid.uuid4()) for _ in range(self.beam)]
            })
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        reflection = traj.get_memory()["reflection"]
        # The status is forced to 0 before transitioning, whatever the
        # reflection said.
        reflection["status"] = 0
        traj.set_trans_param(reflection["status"])
        for idx in range(self.beam):
            state.next_traj(traj, traj.get_memory()['child_id'][idx], param_filter=self.trans_param_list)

        self.logger.info("[assistant] step-{} {}: [{}].".format(
            str(state.step), self.name,
            traj.get_memory()['reflection']
        ))


class PesudoChainTaskReflection(ChainTaskReflection):
    """Reflection action that stores fixed placeholder values.

    Unlike the parent's ``run``, the policy is never called: the trajectory
    memory is filled with constant summaries and reflections and a reward of 1.
    """

    def run(self, state, message, traj_id, recover=False):
        """Write placeholder reflection data for ``traj_id`` and create children.

        Args:
            state: Search state holding the trajectories.
            message: Unused here.
            traj_id: Id of the trajectory to process.
            recover: When true and nothing could be loaded, nothing is written
                and ``"leaf_node"`` is returned.

        Returns:
            ``"leaf_node"`` in the recover case described above, else ``None``.
        """
        already_loaded = self.load_serialized_obj(state, traj_id)
        if not already_loaded:
            if recover:
                return "leaf_node"

            # Built key by key so the insertion order matches the stored layout.
            record = {"summary_conv": []}
            # Each summary placeholder is simply the name of its own key.
            for key in ("summary", "oracle_summary", "free_summary"):
                record[key] = key
            record["reflection_conv"] = []
            record["reflection_format_conv"] = []
            # A separate dict per key: "reflection" is mutated further down.
            for key in ("reflection", "oracle_reflection", "free_reflection"):
                record[key] = {"status": 0, "reason": "reflection"}
            record["reward"] = 1
            record["child_id"] = [str(uuid.uuid4()) for _ in range(self.beam)]

            state.update(traj_id, record)
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        stored_reflection = traj.get_memory()["reflection"]
        # The status is forced to 0 before transitioning.
        stored_reflection["status"] = 0
        traj.set_trans_param(stored_reflection["status"])

        for slot in range(self.beam):
            child_id = traj.get_memory()['child_id'][slot]
            state.next_traj(traj, child_id, param_filter=self.trans_param_list)

        log_line = "[assistant] step-{} {}: [{}].".format(
            str(state.step), self.name,
            traj.get_memory()['reflection']
        )
        self.logger.info(log_line)


class BiTreeTaskReflection(ChainTaskReflection):
    """Task reflection action with a beam width of 2."""

    def __init__(self, profile, policy, handler):
        """Initialize like ``ChainTaskReflection``, then set ``beam`` to 2."""
        super().__init__(profile, policy, handler)
        # The inherited ``run`` creates one child trajectory per beam slot.
        self.beam = 2

class TreeTaskReflection(BiTreeTaskReflection):
    """Task reflection action with a beam width of 1."""

    def __init__(self, profile, policy, handler):
        """Initialize like ``BiTreeTaskReflection``, then set ``beam`` to 1."""
        super().__init__(profile, policy, handler)
        # Overrides the value of 2 assigned by BiTreeTaskReflection.__init__.
        self.beam = 1
