import json
import uuid
from enum import Enum

from .finite_state import ActionStateValue, FiniteState, TaskStateValue
from .traj import TreeTrajectory


class ReactState(FiniteState):
    """Finite-state tracker for a ReAct-style tool-use loop.

    The action state cycles ToolPlanning -> ToolSelection -> ToolExecution ->
    ToolReflection.  After a reflection, transition parameter ``0`` starts a
    new planning round (task marked PartiallySolved) and ``1`` marks the task
    Solved.  Every visited state is stored as a node of a ``TreeTrajectory``.
    """

    def __init__(self, profile):
        """Create the state machine; call ``initialize_state`` before use.

        Args:
            profile: Passed through unchanged to ``FiniteState.__init__``.
        """
        super().__init__(profile)

        # Populated by initialize_state() once a session id is known.
        self.trajectories = None
        self._transition = self.initialize_transition()
        # Counts how many times run_action_loop() has been called.
        self._step = 0

    def initialize_transition(self):
        """Build the transition table.

        Keys are ``("?", action_state, trans_param)`` and values are
        ``(next_task_state, next_action_state)``.  A ``"?"`` as the next task
        state means "keep the current task state"; ``"*"`` as ``trans_param``
        is the wildcard that ``state_transition`` substitutes for parameters
        outside its filter.

        Returns:
            dict: The transition table described above.
        """
        return {
            ("?", ActionStateValue.ToolPlanning, "*"): ("?", ActionStateValue.ToolSelection),
            ("?", ActionStateValue.ToolSelection, "*"): ("?", ActionStateValue.ToolExecution),
            ("?", ActionStateValue.ToolExecution, "*"): ("?", ActionStateValue.ToolReflection),

            # Reflection is the only state that branches on its parameter.
            ("?", ActionStateValue.ToolReflection, 0): (TaskStateValue.PartiallySolved, ActionStateValue.ToolPlanning),
            ("?", ActionStateValue.ToolReflection, 1): (TaskStateValue.Solved, "*"),
        }

    def get_session_id(self):
        """Return the session id stored by ``initialize_state``."""
        return self.session_id

    def initialize_state(self, session_id):
        """Start a new trajectory tree for ``session_id``.

        The root node is created with task state ``UnSolved``, action state
        ``ToolPlanning`` and the wildcard transition parameter ``"*"``.
        """
        self.session_id = session_id
        # Use the same enum as the transition table so that states stored in
        # the tree compare equal to the states looked up later.
        self.trajectories = TreeTrajectory(
            TaskStateValue.UnSolved,
            ActionStateValue.ToolPlanning,
            "*",
            self.session_id,
        )

    def state_transition(self, t_state, a_state, trans_param, param_filter):
        """Compute the next execution state from the current one.

        Args:
            t_state: Current task state.
            a_state: Current action state.
            trans_param: Transition parameter produced by the current state.
            param_filter: Container of parameters that are significant for
                this transition; any other value is replaced by ``"*"``.

        Returns:
            tuple: ``(task_state, agent_state, trans_param)`` where
            ``trans_param`` is the (possibly wildcarded) parameter used for
            the lookup.

        Raises:
            KeyError: If the table has no entry for the resulting key.
        """
        if trans_param not in param_filter:
            trans_param = "*"
        task_state, agent_state = self._transition[("?", a_state, trans_param)]
        # "?" means the transition leaves the task state untouched.
        if task_state == "?":
            task_state = t_state
        return task_state, agent_state, trans_param

    def run_action_loop(self):
        """Collect the pending action of every unsolved leaf.

        Returns:
            list: ``(agent_state, node_id)`` pairs, one per leaf whose task
            state is not ``Solved``.
        """
        self._step += 1
        actions = []
        for leaf in self.trajectories.find_leaves():
            task_state, agent_state, _ = leaf.get_transition()
            if task_state != TaskStateValue.Solved:
                actions.append((agent_state, leaf.get_node_id()))
        return actions

    def next_traj(self, parent_traj, traj_id, param_filter=(0,)):
        """Append the successor of ``parent_traj`` to the tree as ``traj_id``.

        Args:
            parent_traj: Node whose ``get_transition()`` supplies the current
                ``(task_state, agent_state, trans_param)``.
            traj_id: Identifier for the new node.
            param_filter: Forwarded to ``state_transition``.  The default is
                an immutable tuple so it cannot be mutated across calls.
        """
        task_state, agent_state, trans_param = self.state_transition(
            *parent_traj.get_transition(), param_filter=param_filter
        )
        self.trajectories.next_node(task_state, agent_state, trans_param, parent_traj, traj_id)

    def get_last_status(self, message):
        """Summarise every path of the trajectory tree.

        Args:
            message: Mapping providing ``'session_id'``, ``'query'`` and
                ``'api_list'``, which are copied into each training record.

        Returns:
            tuple: ``(stat_info, train_traj)``.  ``stat_info`` holds one dict
            per path with its length (``"step"``) and the action state of its
            last node (``"status"``).  ``train_traj`` holds one JSON string
            per path with the collected observations and the final answer.
        """
        stat_info = []
        train_traj = []
        for paths in self.trajectories.get_all_paths():
            observations = []
            final_answer = ""
            # Planning and selection nodes contribute nothing to the record.
            for traj in paths:
                if traj.agent_state == ActionStateValue.ToolExecution:
                    observations.append(traj.get_memory()["observation_compr"])
                elif traj.agent_state == ActionStateValue.ToolReflection:
                    # A later reflection on the same path overrides earlier ones.
                    final_answer = traj.get_memory()["summary"]

            train_traj.append(json.dumps({
                "session_id": message['session_id'],
                "query": message['query'],
                "available_tools": message['api_list'],
                "answer": {
                    "method": "react",
                    "total_steps": len(observations),
                    "final_answer": final_answer,
                    "answer_details": observations,
                },
            }, ensure_ascii=False))

            stat_info.append({
                "step": len(paths),
                "status": paths[-1].agent_state
            })
        return stat_info, train_traj

    def get(self, traj_id):
        """Return whatever ``TreeTrajectory.get`` yields for ``traj_id``."""
        return self.trajectories.get(traj_id)

    def update(self, traj_id, new_traj_cont, trans_param=None):
        """Update the in-memory content of node ``traj_id``.

        Args:
            traj_id: Identifier of the node to update.
            new_traj_cont: New content, forwarded to ``TreeTrajectory.update``.
            trans_param: Optional transition parameter forwarded alongside.
        """
        self.trajectories.update(traj_id, new_traj_cont, trans_param)

    def update_local(self, traj_id, new_traj_cont):
        """Update node ``traj_id`` via ``TreeTrajectory.update(local=True)``."""
        self.trajectories.update(traj_id, new_traj_cont, local=True)

    def dumps(self, traj_id):
        """Serialise the memory of node ``traj_id`` to a JSON string."""
        return json.dumps(self.trajectories.get_memory(traj_id))

    def loads(self, traj_id, serial):
        """Parse the JSON string ``serial`` and store it on node ``traj_id``."""
        self.trajectories.update(traj_id, json.loads(serial))

    def _retrieve(self, traj_id, xfilter=None):
        """Return the ancestors of ``traj_id`` whose action state is in ``xfilter``.

        Args:
            traj_id: Node whose ancestors are inspected (as returned by
                ``TreeTrajectory.find_ancestors``).
            xfilter: Container of action states to keep.  ``None`` (the
                default) keeps all four ReAct action states.

        Returns:
            list: Matching nodes, in the order ``find_ancestors`` yields them.
        """
        if xfilter is None:
            # Resolved per call rather than as a mutable default argument.
            xfilter = {
                ActionStateValue.ToolPlanning,
                ActionStateValue.ToolSelection,
                ActionStateValue.ToolExecution,
                ActionStateValue.ToolReflection,
            }
        return [
            traj
            for traj in self.trajectories.find_ancestors(traj_id)
            if traj.agent_state in xfilter
        ]

    def retrieve_plan(self, traj_id, neighbor=False):
        """Return the ToolPlanning ancestors of ``traj_id``.

        Args:
            traj_id: Node whose ancestors are searched.
            neighbor: If true, return only the last matching node instead of
                the whole list.

        Raises:
            IndexError: If ``neighbor`` is true and no planning node exists.
        """
        plans = self._retrieve(traj_id, xfilter={ActionStateValue.ToolPlanning})
        return plans[-1] if neighbor else plans

    def retrieve_api(self, traj_id, neighbor=False):
        """Return the ToolSelection ancestors of ``traj_id``.

        Args:
            traj_id: Node whose ancestors are searched.
            neighbor: If true, return only the last matching node instead of
                the whole list.

        Raises:
            IndexError: If ``neighbor`` is true and no selection node exists.
        """
        api = self._retrieve(traj_id, xfilter={ActionStateValue.ToolSelection})
        return api[-1] if neighbor else api
    
