from .traj import TreeTrajectory
import uuid
import json


class ReflectionState:
    """Per-session state for a plan -> select -> execute -> reflect tool loop.

    ``initialize_state`` stores the session's step tree in
    ``self.trajectories`` (a ``TreeTrajectory``). The remaining methods export
    finished paths of that tree and look up ancestor steps by agent state.

    NOTE(review): ``ReflectionActionValue`` and ``TaskStateValue`` are used
    below but are not imported by this module's visible import lines; confirm
    they are brought into module scope elsewhere.
    """

    def initialize_transition(self):
        """Return the transition table of the agent state machine.

        Keys are 3-tuples, presumably ``(task_state, agent_state,
        trans_param)``. Values are ``(next_task_state, next_agent_state)``
        pairs. ``"?"`` and ``"*"`` appear to act as wildcard / "unchanged"
        markers; their exact meaning is defined by whoever consumes this
        table (TODO confirm).

        The fixed step order is ToolPlanning -> ToolSelection ->
        ToolExecution -> ToolReflection. A ToolReflection step's
        ``trans_param`` then decides the outcome:

        * 3    -> task UnSolved, go back to ToolPlanning
        * 0, 1 -> task PartiallySolved, go back to ToolSelection
        * 2    -> task Solved
        """
        return {
            ("?", ReflectionActionValue.ToolPlanning, "*"): ("?", ReflectionActionValue.ToolSelection),
            ("?", ReflectionActionValue.ToolSelection, "*"): ("?", ReflectionActionValue.ToolExecution),
            ("?", ReflectionActionValue.ToolExecution, "*"): ("?", ReflectionActionValue.ToolReflection),

            ("?", ReflectionActionValue.ToolReflection, 3): (TaskStateValue.UnSolved, ReflectionActionValue.ToolPlanning),
            ("?", ReflectionActionValue.ToolReflection, 0): (TaskStateValue.PartiallySolved, ReflectionActionValue.ToolSelection),
            ("?", ReflectionActionValue.ToolReflection, 1): (TaskStateValue.PartiallySolved, ReflectionActionValue.ToolSelection),
            ("?", ReflectionActionValue.ToolReflection, 2): (TaskStateValue.Solved, "*"),
        }

    def initialize_state(self, session_id):
        """Create a fresh trajectory tree for ``session_id``.

        The tree is constructed with task state ``UnSolved``, agent state
        ``ToolPlanning`` and ``"*"`` as its transition parameter.
        """
        self.session_id = session_id
        self.trajectories = TreeTrajectory(
            TaskStateValue.UnSolved,
            ReflectionActionValue.ToolPlanning,
            "*",
            self.session_id,
        )

    def get_last_status(self, message):
        """Export every root-to-leaf path of the trajectory tree.

        Args:
            message: mapping providing ``"session_id"``, ``"query"`` and
                ``"api_list"``; these are copied into every training record.

        Returns:
            ``(stat_info, train_traj)``, both with one entry per path in the
            order returned by ``get_all_paths()``:

            * ``stat_info`` entries are ``{"step": <path length>,
              "status": <trans_param of the path's last node>}``.
            * ``train_traj`` entries are JSON strings (``ensure_ascii=False``)
              describing that path alone.
        """
        stat_info = []
        train_traj = []
        for path in self.trajectories.get_all_paths():
            train_traj.append(self._path_to_train_record(path, message))
            stat_info.append({
                "step": len(path),
                "status": path[-1].trans_param,
            })
        return stat_info, train_traj

    def _path_to_train_record(self, path, message):
        """Serialize a single root-to-leaf path into a JSON training record."""
        # Fresh accumulators per path, so a record never contains steps that
        # belong to a different path.
        plan_traj = []
        api_traj = []
        parameter_traj = []
        all_traj = []
        response = ""

        # The last node is not serialized here; get_last_status only reports
        # its trans_param.
        for traj in path[:-1]:
            state = traj.agent_state
            if state == ReflectionActionValue.ToolPlanning:
                plan_traj.append(traj.get_memory()["plan"])
            elif state == ReflectionActionValue.ToolExecution:
                local_memory = traj.get_memory(local=True)
                # "api"/"parameter" are only recorded when this step's local
                # memory carries an API call.
                if "api" in local_memory:
                    api_traj.append(local_memory["api"])
                    parameter_traj.append(local_memory["parameter"])
                all_traj.append(traj.get_memory()["observation_compr"])
            elif state == ReflectionActionValue.ToolReflection:
                memory = traj.get_memory()
                # The latest reflection wins; prefer its final summary, then
                # its interim summary, otherwise clear any stale answer.
                if "final_summary" in memory:
                    response = memory["final_summary"]
                elif "summary" in memory:
                    response = memory["summary"]
                else:
                    response = ""
            # ToolSelection steps contribute nothing to the record.

        return json.dumps({
            "session_id": message["session_id"],
            "query": message["query"],
            "available_tools": message["api_list"],
            "answer": {
                "method": "react",
                "total_steps": len(all_traj),
                "final_answer": response,
                "plan_traj": plan_traj,
                "api_traj": api_traj,
                "parameter_traj": parameter_traj,
                "answer_details": all_traj,
            },
        }, ensure_ascii=False)

    def _retrieve(self, traj_id, xfilter=None):
        """Return the ancestors of ``traj_id`` whose agent_state is in ``xfilter``.

        Args:
            traj_id: id passed to ``self.trajectories.find_ancestors``.
            xfilter: collection of agent states to keep. ``None`` (the
                default) keeps ToolPlanning, ToolSelection, ToolExecution and
                ToolReflection steps.

        Returns:
            List of matching steps, in the order ``find_ancestors`` yields
            them.
        """
        if xfilter is None:
            # Built lazily rather than as a default argument, so the enum is
            # only looked up when the method actually runs.
            xfilter = {
                ReflectionActionValue.ToolPlanning,
                ReflectionActionValue.ToolSelection,
                ReflectionActionValue.ToolExecution,
                ReflectionActionValue.ToolReflection,
            }
        return [
            traj
            for traj in self.trajectories.find_ancestors(traj_id)
            if traj.agent_state in xfilter
        ]

    def _retrieve_by_state(self, traj_id, agent_state, neighbor):
        """Shared body of the ``retrieve_*`` helpers for a single agent state.

        With ``neighbor`` true, return only the last matching step, or
        ``None`` if there is none. Presumably that is the nearest ancestor,
        assuming ``find_ancestors`` is root-first (TODO confirm). Otherwise
        return the full list of matches.
        """
        matches = self._retrieve(traj_id, xfilter={agent_state})
        if neighbor:
            return matches[-1] if matches else None
        return matches

    def retrieve_plan(self, traj_id, neighbor=False):
        """Ancestor ToolPlanning steps of ``traj_id`` (last one / None if ``neighbor``)."""
        return self._retrieve_by_state(traj_id, ReflectionActionValue.ToolPlanning, neighbor)

    def retrieve_api(self, traj_id, neighbor=False):
        """Ancestor ToolSelection steps of ``traj_id`` (last one / None if ``neighbor``)."""
        return self._retrieve_by_state(traj_id, ReflectionActionValue.ToolSelection, neighbor)

    def retrieve_tool(self, traj_id, neighbor=False):
        """Ancestor ToolExecution steps of ``traj_id`` (last one / None if ``neighbor``)."""
        return self._retrieve_by_state(traj_id, ReflectionActionValue.ToolExecution, neighbor)

    def retrieve_reflection(self, traj_id, neighbor=False):
        """Ancestor ToolReflection steps of ``traj_id`` (last one / None if ``neighbor``)."""
        return self._retrieve_by_state(traj_id, ReflectionActionValue.ToolReflection, neighbor)