from .action import APIAction
from utils import api_util
import json
import copy
import uuid


class ChainTaskPlanning(APIAction):
    """Action that asks the policy model for a plan (a Python list of steps).

    The class relies on members that are not defined in this file and are
    presumably provided by ``APIAction``: ``prompt``, ``policy_aka``,
    ``logger``, ``reward``, ``teacher_forcing_list``, ``build_sys_msg``,
    ``load_serialized_obj`` and ``save_serialized_obj``.
    """

    name: str = "task_planning"
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, handler):
        """Store the policy callable and the planning configuration.

        Args:
            profile: Forwarded unchanged to ``APIAction.__init__``.
            policy: Callable that takes a list of chat messages and returns
                the model's text reply.
            handler: Forwarded unchanged to ``APIAction.__init__``.
        """
        super().__init__(profile, handler)

        self.policy = policy
        # Number of child trajectories created for every plan.
        self.beam = 1
        self.teacher_forcing = "plan" in self.teacher_forcing_list

    def calculate_reward(self, free_plan, oracle_plan):
        """Return 1 when both plans have the same number of steps, else 0.

        Args:
            free_plan: Plan produced by the model, or the ``"##ignore##"``
                placeholder string when parsing failed.
            oracle_plan: Reference plan list.

        Returns:
            int: 1 on a length match, 0 otherwise.
        """
        # The "##ignore##" placeholder is a str; comparing its character
        # count with the oracle's step count could give a spurious match.
        if not isinstance(free_plan, (list, tuple)):
            return 0
        return 1 if len(free_plan) == len(oracle_plan) else 0

    def get_oracle_plan(self, message):
        """Collect the ``conv.plan`` entry of every item in ``message['golden']``.

        Returns:
            list: One entry per golden item; empty when ``message`` has no
            ``'golden'`` key.
        """
        if 'golden' not in message:
            return []
        return [golden['conv']["plan"] for golden in message['golden']]

    def reformat(self, plan_text):
        """Ask the policy to rewrite ``plan_text`` using the reformat prompt.

        Returns:
            tuple: ``(conversation, reformatted_text)`` where the
            conversation holds the user request and the assistant reply.
        """
        prompts = [{
            "role": "user",
            "content": self.prompt[self.name]['plan_reformat'].format(text=plan_text)
        }]
        plan_format_text = self.policy(prompts)
        self.logger.info("plan reformat: {}".format(plan_format_text))
        prompts.append({"role": "assistant", "content": plan_format_text})
        return prompts, plan_format_text

    @staticmethod
    def _strip_plan_wrappers(plan_text):
        """Remove <plan> tags and json fences, then keep the outermost [...]."""
        if "<plan>" in plan_text:
            plan_text = plan_text.replace("<plan>", "").replace("</plan>", "")

        if "```json" in plan_text:
            open_pos = plan_text.rfind("```json")
            close_pos = plan_text.rfind("```")
            # Without a closing fence the search above finds the opening
            # fence itself; keep everything up to the end of the text
            # instead of slicing the plan away.
            if close_pos <= open_pos:
                close_pos = len(plan_text)
            plan_text = plan_text[open_pos:close_pos].replace("```json", "")

        if "[" in plan_text and "]" in plan_text:
            plan_text = plan_text[plan_text.find("["):plan_text.rfind("]") + 1]
        return plan_text

    @staticmethod
    def _eval_plan_list(plan_text):
        """Evaluate ``plan_text`` and return it only if it is a list."""
        # NOTE(security): eval() runs arbitrary expressions taken from model
        # output. Kept to preserve which outputs are accepted; consider
        # ast.literal_eval / json.loads if that output is not trusted.
        plan = eval(plan_text)
        # Explicit check instead of assert so it still runs under ``-O``.
        if not isinstance(plan, list):
            raise ValueError(
                "plan must be a list, got {}".format(type(plan).__name__)
            )
        return plan

    def plan_format_check(self, plan_text):
        """Parse the raw model output into a plan list.

        Output containing the word ``thought`` is evaluated as a mapping and
        its ``'plan'`` entry is used; any other output has its wrappers
        stripped. If the result does not evaluate to a list, the policy is
        asked once to reformat it.

        Args:
            plan_text: Raw text returned by the policy.

        Returns:
            tuple: ``(format_conv, plan)``; ``format_conv`` is the reformat
            conversation, or ``None`` when no reformat was needed.

        Raises:
            Exception: Whatever the evaluation raises when the ``thought``
                branch or the reformatted text cannot be parsed, including
                ``ValueError`` when the value is not a list.
        """
        if "thought" in plan_text:
            # NOTE(security): same eval() caveat as in _eval_plan_list.
            plan_text = str(eval(plan_text)['plan'])
        else:
            plan_text = self._strip_plan_wrappers(plan_text)

        self.logger.info("plan_text: {}".format(plan_text))
        format_conv = None

        try:
            plan = self._eval_plan_list(plan_text)
        except Exception:
            # One retry: let the model rewrite its own output.
            format_conv, plan_format_text = self.reformat(plan_text)
            plan = self._eval_plan_list(plan_format_text)

        return format_conv, plan

    def _build_prompts(self, system_prompt, user_prompt):
        """Return the chat message list for one policy request."""
        if "mistral" in self.policy_aka:
            # The system text goes in a user turn followed by an empty
            # assistant turn (presumably because that model has no system
            # role -- TODO confirm).
            return [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
                {"role": "user", "content": user_prompt},
            ]
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _request_plan(self, prompts):
        """Query the policy with ``prompts`` and parse the reply into a plan."""
        plan_text = self.policy(prompts)

        try:
            format_conv, plan = self.plan_format_check(plan_text)
        except Exception as e:
            format_conv = None
            if not self.teacher_forcing:
                raise Exception("Plan Format Error {}".format(e)) from e
            # Under teacher forcing the oracle plan is used downstream, so an
            # unparsable free plan is only recorded as a placeholder.
            plan = "##ignore##"

        prompts.append({"role": "assistant", "content": plan})
        return prompts, format_conv, plan

    def plan_gen(self, traj_id, state, message):
        """Generate a first plan for ``message['query']``.

        Returns:
            tuple: ``(conv, format_conv, plan)``.

        Raises:
            Exception: "Plan Format Error ..." when the reply cannot be
                parsed and teacher forcing is off.
        """
        system_prompt = self.build_sys_msg(message)
        plan_prompt = self.prompt[self.name]['plan'].format(query=message["query"])
        return self._request_plan(self._build_prompts(system_prompt, plan_prompt))

    def _is_replan(self, traj_id, state, message):
        """Return True when a neighbouring task reflection exists for ``traj_id``."""
        ref_traj = state.retrieve_task_reflection(traj_id, neighbor=True)
        if ref_traj is None:
            return False
        # Invariant check: a stored reflection is expected to have a
        # non-zero status.
        assert ref_traj.get_memory()['reflection']['status'] != 0
        return True

    def plan_regen(self, traj_id, state, message):
        """Generate a new plan from the previous plan and the reflection reason.

        Returns:
            tuple: ``(conv, format_conv, plan)``.

        Raises:
            Exception: "Plan Format Error ..." when the reply cannot be
                parsed and teacher forcing is off.
        """
        # The second-to-last retrieved plan is used as the previous plan.
        plan_traj = state.retrieve_plan(traj_id)[-2]
        prev_plan = json.dumps(plan_traj.get_memory()['plan'])
        reflection = state.retrieve_task_reflection(traj_id, neighbor=True).get_memory()['reflection']['reason']

        system_prompt = self.build_sys_msg(message)
        plan_prompt = self.prompt[self.name]['replan'].format(
            prev_plan=prev_plan, query=message["query"], reflection=reflection
        )
        return self._request_plan(self._build_prompts(system_prompt, plan_prompt))

    def run(self, state, message, traj_id, recover=False):
        """Produce (or reload) the plan for ``traj_id`` and spawn child trajectories.

        Args:
            state: Project state object used to store and advance trajectories.
            message: Mapping with at least ``'query'`` and optionally ``'golden'``.
            traj_id: Identifier of the trajectory being planned.
            recover: When true and nothing could be loaded for ``traj_id``,
                stop without generating a plan.

        Returns:
            ``"leaf_node"`` in the recover case described above, otherwise
            ``None``.
        """
        if not self.load_serialized_obj(state, traj_id):
            if recover:
                return "leaf_node"

            if self._is_replan(traj_id, state, message):
                conv, format_conv, free_plan = self.plan_regen(traj_id, state, message)
            else:
                conv, format_conv, free_plan = self.plan_gen(traj_id, state, message)
            oracle_plan = self.get_oracle_plan(message)

            self.logger.info("free plan: {}".format(free_plan))
            plan = oracle_plan if self.teacher_forcing else free_plan

            if self.reward:
                reward = self.calculate_reward(free_plan, oracle_plan)
                self.logger.info("[assistant] step-{} {}: plan reward = {}.".format(
                    str(state.step), self.name, reward
                ))
            else:
                # -1 marks "reward not computed".
                reward = -1

            state.update(traj_id, {
                "conv": conv,
                "format_conv": format_conv,
                "plan": plan,
                "free_plan": free_plan,
                "oracle_plan": oracle_plan,
                "reward": reward,
                "child_id": [str(uuid.uuid4()) for _ in range(self.beam)]
            })
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        traj.set_trans_param("*")
        state.update_local(traj_id, {traj_id: {"plan_idx": 0, "exec_count": 0}})
        for idx in range(self.beam):
            state.next_traj(traj, traj.get_memory()['child_id'][idx])

        self.logger.info("[assistant] step-{} {}: {}.".format(
            str(state.step), self.name, traj.get_memory()["plan"])
        )

class FineChainTaskPlanning(ChainTaskPlanning):
    """Planning variant using the ``fine_*`` prompts and a model extraction pass.

    The raw reply is first passed through the ``fine_extract`` prompt and
    only the extracted text is parsed as a plan list. Note that
    ``plan_format_check`` here takes ``(message, plan_text)``, unlike the
    parent's ``(plan_text)``.
    """

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)

    def _fine_prompts(self, system_prompt, user_prompt):
        """Return the chat message list for one policy request."""
        if "mistral" in self.policy_aka:
            # The system text goes in a user turn followed by an empty
            # assistant turn (presumably because that model has no system
            # role -- TODO confirm).
            return [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
                {"role": "user", "content": user_prompt},
            ]
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _fine_request(self, message, prompts):
        """Query the policy with ``prompts`` and parse the reply into a plan."""
        plan_text = self.policy(prompts)

        try:
            format_conv, plan = self.plan_format_check(message, plan_text)
        except Exception as e:
            format_conv = None
            if not self.teacher_forcing:
                raise Exception("Plan Format Error {}".format(e)) from e
            # Placeholder recorded when the reply is unparsable under
            # teacher forcing.
            plan = "##ignore##"

        prompts.append({"role": "assistant", "content": plan})
        return prompts, format_conv, plan

    @staticmethod
    def _fine_eval_list(plan_text):
        """Evaluate ``plan_text`` and return it only if it is a list."""
        # NOTE(security): eval() runs arbitrary expressions taken from model
        # output. Kept to preserve which outputs are accepted; consider
        # ast.literal_eval / json.loads if that output is not trusted.
        plan = eval(plan_text)
        # Explicit check instead of assert so it still runs under ``-O``.
        if not isinstance(plan, list):
            raise ValueError(
                "plan must be a list, got {}".format(type(plan).__name__)
            )
        return plan

    def plan_regen(self, traj_id, state, message):
        """Generate a new plan from the previous plan and the reflection reason.

        Returns:
            tuple: ``(conv, format_conv, plan)``.

        Raises:
            Exception: "Plan Format Error ..." when the reply cannot be
                parsed and teacher forcing is off.
        """
        # The second-to-last retrieved plan is used as the previous plan.
        plan_traj = state.retrieve_plan(traj_id)[-2]
        # Steps are concatenated with no separator (assumes every step is a
        # str -- TODO confirm).
        prev_plan = "".join(plan_traj.get_memory()['plan'])
        reflection = state.retrieve_task_reflection(traj_id, neighbor=True).get_memory()['reflection']['reason']

        system_prompt = self.build_sys_msg(message)
        plan_prompt = self.prompt[self.name]['fine_replan'].format(
            prev_plan=prev_plan, query=message["query"], reflection=reflection
        )
        return self._fine_request(message, self._fine_prompts(system_prompt, plan_prompt))

    def plan_gen(self, traj_id, state, message):
        """Generate a first plan for ``message['query']`` with the ``fine_plan`` prompt.

        Returns:
            tuple: ``(conv, format_conv, plan)``.

        Raises:
            Exception: "Plan Format Error ..." when the reply cannot be
                parsed and teacher forcing is off.
        """
        system_prompt = self.build_sys_msg(message)
        plan_prompt = self.prompt[self.name]['fine_plan'].format(query=message["query"])
        return self._fine_request(message, self._fine_prompts(system_prompt, plan_prompt))

    def extract(self, message, plan_text):
        """Ask the policy to extract the plan from ``plan_text``.

        Returns:
            The policy's reply to the ``fine_extract`` prompt.
        """
        system_prompt = self.build_sys_msg(message)
        plan_prompt = self.prompt[self.name]['fine_extract'].format(plan=plan_text)
        return self.policy(self._fine_prompts(system_prompt, plan_prompt))

    def plan_format_check(self, message, plan_text):
        """Extract and parse a plan list from the raw model output.

        Args:
            message: Mapping used to build the system prompt for extraction.
            plan_text: Raw text returned by the policy.

        Returns:
            tuple: ``(format_conv, plan)``; ``format_conv`` is the reformat
            conversation, or ``None`` when no reformat was needed.

        Raises:
            Exception: Whatever the evaluation raises when the reformatted
                text cannot be parsed, including ``ValueError`` when the
                value is not a list.
        """
        plan_text = self.extract(message, plan_text)

        format_conv = None
        try:
            plan = self._fine_eval_list(plan_text)
        except Exception:
            # One retry: let the model rewrite the extracted text.
            format_conv, plan_format_text = self.reformat(plan_text)
            plan = self._fine_eval_list(plan_format_text)

        return format_conv, plan

class ChainToughtTaskPlanning(ChainTaskPlanning):
    """Planning variant that first asks for a "thought", then for a plan.

    The plan request is built from the ``thought_plan`` prompt, but the
    returned conversation pairs the plain ``plan`` prompt with an assistant
    turn that holds both the thought and the plan.
    """

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)

    def _as_chat(self, system_prompt, user_prompt):
        """Return the chat message list for one policy request."""
        if "mistral" in self.policy_aka:
            # System text is sent as a user turn plus an empty assistant turn.
            opening = [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": ""},
            ]
        else:
            opening = [{"role": "system", "content": system_prompt}]
        return opening + [{"role": "user", "content": user_prompt}]

    def thought_gen(self, system_prompt, message):
        """Return the policy's reply to the ``thought`` prompt for the query."""
        question = self.prompt[self.name]['thought'].format(query=message["query"])
        return self.policy(self._as_chat(system_prompt, question))

    def plan_gen(self, traj_id, state, message):
        """Generate a plan conditioned on a freshly generated thought.

        Returns:
            tuple: ``(conv, format_conv, plan)`` where ``conv`` is the plain
            plan conversation ending with an assistant turn whose content is
            ``{"thought": ..., "plan": ...}``.

        Raises:
            Exception: "Plan Format Error ..." when the reply cannot be
                parsed and teacher forcing is off.
        """
        query = message["query"]
        system_prompt = self.build_sys_msg(message)
        thought_text = self.thought_gen(system_prompt, message)

        templates = self.prompt[self.name]
        guided_prompt = templates['thought_plan'].format(thought=thought_text, query=query)
        plain_prompt = templates['plan'].format(query=query)

        # Conversation actually sent to the policy.
        prompts = self._as_chat(system_prompt, guided_prompt)
        # Conversation handed back to the caller.
        rprompts = self._as_chat(system_prompt, plain_prompt)

        plan_text = self.policy(prompts)
        format_conv = None
        try:
            format_conv, plan = self.plan_format_check(plan_text)
        except Exception as e:
            if not self.teacher_forcing:
                raise Exception("Plan Format Error {} for {}".format(e, plan_text))
            plan = "##ignore##"

        answer = {"thought": thought_text, "plan": plan}
        rprompts.append({"role": "assistant", "content": answer})
        return rprompts, format_conv, plan

class RejectTaskPlanning(ChainToughtTaskPlanning):
    """Thought-based planning that rejects any plan whose reward is not 1."""

    def _plan_or_reject(self, state, message, traj_id):
        """Generate a plan, score it, and store it unless it is rejected."""
        if self._is_replan(traj_id, state, message):
            generate = self.plan_regen
        else:
            generate = self.plan_gen
        conv, format_conv, free_plan = generate(traj_id, state, message)
        oracle_plan = self.get_oracle_plan(message)

        self.logger.info("free plan: {}".format(free_plan))
        plan = oracle_plan if self.teacher_forcing else free_plan

        # -1 marks "reward not computed".
        reward = -1
        if self.reward:
            reward = self.calculate_reward(free_plan, oracle_plan)
            self.logger.info("[assistant] step-{} {}: plan reward = {}.".format(
                str(state.step), self.name, reward
            ))

        # NOTE(review): with rewards disabled the value stays -1, so every
        # plan is rejected here -- confirm that is intended.
        if reward != 1:
            raise Exception("Reject...")

        record = {
            "conv": conv,
            "format_conv": format_conv,
            "plan": plan,
            "free_plan": free_plan,
            "oracle_plan": oracle_plan,
            "reward": reward,
            "child_id": [str(uuid.uuid4()) for _ in range(self.beam)]
        }
        state.update(traj_id, record)
        self.save_serialized_obj(state, traj_id)

    def run(self, state, message, traj_id, recover=False):
        """Produce (or reload) the plan for ``traj_id`` and spawn child trajectories.

        Args:
            state: Project state object used to store and advance trajectories.
            message: Mapping with at least ``'query'`` and optionally ``'golden'``.
            traj_id: Identifier of the trajectory being planned.
            recover: When true and nothing could be loaded for ``traj_id``,
                stop without generating a plan.

        Returns:
            ``"leaf_node"`` in the recover case described above, otherwise
            ``None``.

        Raises:
            Exception: "Reject..." when a newly generated plan does not get
                reward 1.
        """
        loaded = self.load_serialized_obj(state, traj_id)
        if not loaded:
            if recover:
                return "leaf_node"
            self._plan_or_reject(state, message, traj_id)

        traj = state.get(traj_id)
        traj.set_trans_param("*")
        state.update_local(traj_id, {traj_id: {"plan_idx": 0, "exec_count": 0}})
        for idx in range(self.beam):
            child = traj.get_memory()['child_id'][idx]
            state.next_traj(traj, child)

        final_plan = traj.get_memory()["plan"]
        self.logger.info("[assistant] step-{} {}: {}.".format(
            str(state.step), self.name, final_plan)
        )

class BiTreeTaskPlanning(ChainTaskPlanning):
    """Chain planning that creates two child trajectories per plan."""

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)
        # Overrides the parent's beam of 1.
        self.beam = 2

class TreeTaskPlanning(BiTreeTaskPlanning):
    """Tree planning variant; currently configured like ``BiTreeTaskPlanning``."""

    def __init__(self, profile, policy, handler):
        super().__init__(profile, policy, handler)
        # NOTE(review): same beam as the parent class -- confirm that a
        # wider beam was not intended here.
        self.beam = 2
