from .state.finite_state import ActionStateValue, TaskStateValue
from .state.corvus_state import CorvusState
from .registry import auto_register
from action.task_init import *
from action.task_planning import *
from action.tool_selection import *
from action.tool_execution import *
from action.tool_reflection import *
from action.task_reflection import *
from parallel.runner import ActionQueue
from multiprocessing import Value
from utils import api_util, log
import json
import copy


@auto_register("corvus")
class Corvus:
    """Tool-using agent that steps a ``CorvusState`` through its action pipeline.

    The pipeline consists of six actions (task init, task planning, tool
    selection, tool execution, tool reflection, task reflection). The concrete
    action classes are chosen by the core profile's ``memory`` setting:
    ``chain``, ``chain_thought``, ``chain_reject``, ``bitree`` or ``tree``.

    After a session finishes, :meth:`post_process` turns the explored
    trajectories into teacher-forcing conversation records (``self.conv``) and
    trajectory summaries (``self.traj``).
    """

    def __init__(self, profile, policy, api_env, handler):
        """Create the agent and instantiate its actions.

        Args:
            profile: Agent profile; ``profile.load_core()`` must return a dict
                with at least ``"memory"`` and ``"max_trail"`` keys.
            policy: Policy object forwarded to every action.
            api_env: Tool/API environment forwarded to the selection and
                execution actions.
            handler: Handler object forwarded to every action.

        Raises:
            Exception: if the configured memory type is not supported.
        """
        self.logger = log.get_loguru()

        self._step = 0

        self.profile = profile
        self.policy = policy
        self.api_env = api_env
        self.handler = handler

        self.core_profile = self.profile.load_core()
        self.core_memory = self.core_profile['memory']
        # Upper bound on ``state.step``; ``run`` returns 0 once it is exceeded.
        self.max_trail = self.core_profile["max_trail"]

        # multiprocessing shared int counting sessions handled by ``post_process``.
        self.processed = Value('i', 0)
        self.traj = []
        self.conv = []

        # prepare the action in pipeline
        self.build_action()
        # NOTE(review): not used by ``run`` at the moment (actions run sequentially there).
        self.action_runner = ActionQueue(1)

    def build_action(self):
        """Instantiate the six pipeline actions for ``self.core_memory``.

        Selection and execution actions additionally receive ``api_env``.

        Raises:
            Exception: if ``self.core_memory`` is not one of the supported types.
        """
        param_pack = (self.profile, self.policy, self.handler)
        api_param_pack = (self.profile, self.policy, self.api_env, self.handler)

        if self.core_memory == "chain":
            self.init_action = ChainInit(*param_pack)
            self.planning_action = ChainTaskPlanning(*param_pack)
            self.selection_action = ChainToolSelection(*api_param_pack)
            self.execution_action = ChainToolExecution(*api_param_pack)
            self.tool_reflection_action = ChainToolReflection(*param_pack)
            self.task_reflection_action = ChainTaskReflection(*param_pack)
        elif self.core_memory == "chain_thought":
            self.init_action = ChainInit(*param_pack)
            self.planning_action = ChainToughtTaskPlanning(*param_pack)
            self.selection_action = ChainToughtToolSelection(*api_param_pack)
            self.execution_action = ChainToolExecution(*api_param_pack)
            self.tool_reflection_action = ChainToolReflection(*param_pack)
            self.task_reflection_action = ChainTaskReflection(*param_pack)
        elif self.core_memory == "chain_reject":
            self.init_action = ChainInit(*param_pack)
            self.planning_action = RejectTaskPlanning(*param_pack)
            self.selection_action = RejectToolSelection(*api_param_pack)
            self.execution_action = ChainToolExecution(*api_param_pack)
            self.tool_reflection_action = ChainToolReflection(*param_pack)
            self.task_reflection_action = ChainTaskReflection(*param_pack)
        elif self.core_memory == "bitree":
            self.init_action = ChainInit(*param_pack)
            self.planning_action = BiTreeTaskPlanning(*param_pack)
            self.selection_action = BiTreeToolSelection(*api_param_pack)
            self.execution_action = ChainToolExecution(*api_param_pack)
            self.tool_reflection_action = BiTreeToolReflection(*param_pack)
            self.task_reflection_action = BiTreeTaskReflection(*param_pack)
        elif self.core_memory == "tree":
            self.init_action = TreeInit(*param_pack)
            self.planning_action = TreeTaskPlanning(*param_pack)
            self.selection_action = TreeToolSelection(*api_param_pack)
            self.execution_action = TreeToolExecution(*api_param_pack)
            self.tool_reflection_action = TreeToolReflection(*param_pack)
            self.task_reflection_action = TreeTaskReflection(*param_pack)
        else:
            raise Exception("Not Support Memory: {}".format(self.core_memory))

    def _get_current_action(self, agent_state):
        """Return the action handling *agent_state*, or ``None`` if none maps to it."""
        if agent_state == ActionStateValue.TaskInit:
            return self.init_action
        elif agent_state == ActionStateValue.TaskPlanning:
            return self.planning_action
        elif agent_state == ActionStateValue.ToolSelection:
            return self.selection_action
        elif agent_state == ActionStateValue.ToolExecution:
            return self.execution_action
        elif agent_state == ActionStateValue.ToolReflection:
            return self.tool_reflection_action
        elif agent_state == ActionStateValue.TaskReflection:
            return self.task_reflection_action
        return None

    def _recover(self, state, message):
        """Replay the actions yielded by ``state.recover_loop()`` in recover mode.

        Each pass replays every yielded node. Passes repeat while at least one
        replayed action returned something other than ``"leaf_node"``. The
        method stops once a pass yields nothing, or once every result is
        ``"leaf_node"``.
        """
        while True:
            recover_result = []
            for agent_state, traj_id in state.recover_loop():
                action = self._get_current_action(agent_state)
                if action is not None:
                    recover_result.append(action.run(state, message, traj_id, recover=True))

            if not any(r != "leaf_node" for r in recover_result):
                return

    def pre_process(self, message):
        """Create and initialise the session state for ``message["session_id"]``.

        Any persisted progress is replayed via ``_recover``.
        """
        state = CorvusState(self.profile, message["session_id"])
        state.initialize_state()
        self._recover(state, message)
        return state

    def _collect_pending(self, action_loop, state, message):
        """Resolve each ``(agent_state, traj_id)`` from *action_loop* into a bound call.

        The loop is consumed completely before anything runs. Each entry is a
        pair of ``(action.run, (state, message, traj_id))``.
        """
        return [
            (self._get_current_action(agent_state).run, (state, message, traj_id))
            for agent_state, traj_id in action_loop
        ]

    def _run_pending(self, pending):
        """Run each collected action call in order.

        One failing action is logged and does not abort the others.
        """
        for run_action, run_args in pending:
            try:
                run_action(*run_args)
            except Exception as e:
                self.logger.error("Courvus run action err {}".format(e))

    def run(self, state, message):
        """Execute one step of the agent loop.

        Returns:
            0 when the session should stop; 1 otherwise. The session stops in
            any of these cases:

            - ``state.step`` exceeds ``max_trail``;
            - chain memory: ``state.stop()`` is true after running the actions;
            - tree memory: no actions are pending;
            - tree memory: ``state.stop()`` is true before running them.
        """
        if state.step > self.max_trail:
            return 0

        self.logger.info("Run current loop for step: {} start...".format(state.step))

        if "chain" in self.core_memory:
            pending = self._collect_pending(state.run_action_loop(message), state, message)
            self._run_pending(pending)
            if state.stop():
                return 0
        else:
            pending = self._collect_pending(state.run_action_tree_loop(message), state, message)
            if not pending:
                return 0
            if state.stop():
                return 0
            self._run_pending(pending)

        self.logger.info("Run current loop for step: {} down...".format(state.step))
        return 1

    def post_process(self, state, message):
        """Store the session's conversations and trajectories, then bump ``processed``."""
        converstations = self.build_converstation(state, message)
        self.conv.extend(converstations)

        trajs = self.build_trajectories(state, message)
        self.traj.extend(trajs)

        # ``Value.value += 1`` is a non-atomic read-modify-write; hold the
        # Value's own lock so concurrent processes do not lose increments.
        with self.processed.get_lock():
            self.processed.value += 1

    def build_trajectories(self, state, message):
        """Build trajectory summaries for the configured memory family.

        Returns ``[]`` and logs if the memory type is unsupported.
        """
        if "chain" in self.core_profile['memory']:
            return self.get_chain_trajectories(state, message)
        elif "tree" in self.core_profile['memory']:
            return self.get_tree_trajectories(state, message)
        else:
            self.logger.info("Not support {} memory for Trajectories".format(self.core_profile['memory']))
            return []

    def build_converstation(self, state, message):
        """Build teacher-forcing conversation records for the configured memory family.

        Returns ``[]`` and logs if the memory type is unsupported.
        """
        if "chain" in self.core_profile['memory']:
            return self.build_chain_converstation(state, message)
        elif "tree" in self.core_profile['memory']:
            return self.build_tree_converstation(state, message)
        else:
            self.logger.info("Not support {} memory for conversation".format(self.core_profile['memory']))
            return []

    @staticmethod
    def _make_conv(message, state_name, reward, free, oracle, conv, reject_field=None):
        """Build one conversation record.

        Keys appear in a fixed order: ``query_id``, ``state``, ``reward``,
        ``free``, then ``reject`` (only when *reject_field* is given),
        ``oracle`` and ``conv``.
        """
        record = {
            "query_id": message['session_id'],
            "state": state_name,
            "reward": reward,
            "free": free,
        }
        if reject_field:
            record.update(reject_field)
        record["oracle"] = oracle
        record["conv"] = conv
        return record

    @staticmethod
    def _sibling_reject(state, traj, reward_key, free_key):
        """Return *free_key* from the sibling node chosen by *reward_key*, or ``""`` if there is none."""
        sibling = state._retrieve_sibling(traj.get_node_id(), reward_key)
        if sibling is None:
            return ""
        return sibling.get_memory()[free_key]

    @staticmethod
    def _record_reward(reward_traj, prefix, reward):
        """Store *reward* under ``prefix + <number of rewards recorded so far>``."""
        reward_traj['{}{}'.format(prefix, len(reward_traj))] = reward

    def _collect_node_conversations(self, state, traj, message, converstations, reward_traj, with_reject):
        """Append the conversation records produced by one trajectory node.

        Args:
            state: Session state; used to look up sibling nodes when
                *with_reject* is true.
            traj: Trajectory node; its ``agent_state`` selects which records
                are emitted.
            message: Incoming message; supplies ``session_id``.
            converstations: List the records are appended to.
            reward_traj: Dict the node's rewards are recorded into.
            with_reject: When true, planning/api/parameter/reflection records
                get a ``"reject"`` field from a sibling node.

        Returns:
            True iff the node is a TaskReflection with ``reward == 1`` and
            ``oracle_reflection['status'] == 0``.
        """
        memory = traj.get_memory()
        agent_state = traj.agent_state

        def reject_field(reward_key, free_key):
            # Only tree memories attach a rejected sibling sample.
            if not with_reject:
                return None
            return {"reject": self._sibling_reject(state, traj, reward_key, free_key)}

        if agent_state == ActionStateValue.TaskPlanning:
            converstations.append(self._make_conv(
                message, "planning", memory["reward"], memory['free_plan'],
                memory['oracle_plan'], memory['conv'],
                reject_field('reward', 'free_plan'),
            ))
            self._record_reward(reward_traj, "plan", memory["reward"])
            if memory['format_conv'] is not None:
                converstations.append(self._make_conv(
                    message, "planning_format", memory["reward"], memory['free_plan'],
                    memory['oracle_plan'], memory['format_conv'],
                ))
            return False

        if agent_state == ActionStateValue.ToolSelection:
            converstations.append(self._make_conv(
                message, "api", memory["api_reward"], memory['free_api'],
                memory['oracle_api'], memory['api_conv'],
                reject_field('api_reward', 'free_api'),
            ))
            self._record_reward(reward_traj, "api", memory["api_reward"])

            converstations.append(self._make_conv(
                message, "parameter", memory["parameter_reward"], memory['free_parameter'],
                memory['oracle_parameter'], memory['param_conv'],
                reject_field('parameter_reward', 'free_parameter'),
            ))
            self._record_reward(reward_traj, "param", memory["parameter_reward"])

            if memory['api_format_conv'] is not None:
                converstations.append(self._make_conv(
                    message, "api_format", memory["api_reward"], memory['free_api'],
                    memory['oracle_api'], memory['api_format_conv'],
                ))
            if memory['param_format_conv'] is not None:
                converstations.append(self._make_conv(
                    message, "parameter_format", memory["parameter_reward"], memory['free_parameter'],
                    memory['oracle_parameter'], memory['param_format_conv'],
                ))
            return False

        if agent_state == ActionStateValue.ToolReflection or agent_state == ActionStateValue.TaskReflection:
            converstations.append(self._make_conv(
                message, "summary", memory["reward"], memory['free_summary'],
                memory['oracle_summary'], memory['summary_conv'],
            ))
            converstations.append(self._make_conv(
                message, "reflection", memory["reward"], memory['free_reflection'],
                memory['oracle_reflection'], memory['reflection_conv'],
                reject_field('reward', 'free_reflection'),
            ))
            self._record_reward(reward_traj, "reflec", memory["reward"])

            solved = (agent_state == ActionStateValue.TaskReflection
                      and memory["reward"] == 1
                      and memory['oracle_reflection']['status'] == 0)

            if memory['reflection_format_conv'] is not None:
                converstations.append(self._make_conv(
                    message, "reflection_format", memory["reward"], memory['free_reflection'],
                    memory['oracle_reflection'], memory['reflection_format_conv'],
                ))
            return solved

        # TaskInit / ToolExecution nodes produce no conversation records.
        return False

    def build_chain_converstation(self, state, message):
        """Build teacher-forcing records for chain memories.

        Records from every stopped path share one ``reward_traj`` dict and one
        ``final_reward``. ``final_reward`` is 1 if any TaskReflection node
        succeeded.
        """
        converstations = []
        reward_traj = {}
        final_reward = 0

        for paths in state.trajectories.get_all_paths(filter_stop=True):
            # The last element of each path is skipped, as in the trajectory builders.
            for traj in paths[:-1]:
                if self._collect_node_conversations(state, traj, message, converstations,
                                                    reward_traj, with_reject=False):
                    final_reward = 1

        for conv in converstations:
            conv['final_reward'] = final_reward
            conv['reward_traj'] = reward_traj
        return converstations

    def build_tree_converstation(self, state, message):
        """Build teacher-forcing records (with sibling ``reject`` samples) for tree memories.

        Each stopped path is processed independently, with its own
        ``reward_traj`` and ``final_reward``. When a path contains several
        planning nodes, only records from the last plan onward are kept.
        """
        tree_converstations = []
        for paths in state.trajectories.get_all_paths(filter_stop=True):
            converstations = []
            final_reward = 0
            reward_traj = {}
            for traj in paths[:-1]:
                if traj.agent_state == ActionStateValue.TaskPlanning:
                    # A new plan discards the records gathered so far on this path
                    # (reward_traj and final_reward are intentionally kept).
                    converstations = []
                if self._collect_node_conversations(state, traj, message, converstations,
                                                    reward_traj, with_reject=True):
                    final_reward = 1

            for conv in converstations:
                conv['final_reward'] = final_reward
                conv['reward_traj'] = reward_traj
            tree_converstations.extend(converstations)

        return tree_converstations

    @staticmethod
    def _empty_steps():
        """Return fresh, empty per-path step lists keyed by step kind."""
        return {
            "plan": [],
            "api": [],
            "parameter": [],
            "exec": [],
            "summary": [],
            "reflection": [],
            "final_summary": [],
            "final_reflection": [],
        }

    def _collect_steps(self, paths, api_key, reset_on_plan):
        """Gather the free-text outputs of every node in *paths* except the last.

        Args:
            paths: One trajectory path (sequence of nodes).
            api_key: Memory key used for the API step (``'api_json'`` or ``'free_api'``).
            reset_on_plan: When true, each planning node discards all steps
                collected so far, so only the last plan's steps remain.
        """
        steps = self._empty_steps()
        for traj in paths[:-1]:
            memory = traj.get_memory()
            agent_state = traj.agent_state
            if agent_state == ActionStateValue.TaskPlanning:
                if reset_on_plan:
                    steps = self._empty_steps()
                steps["plan"].append(memory['free_plan'])
            elif agent_state == ActionStateValue.ToolSelection:
                steps["api"].append(memory[api_key])
                steps["parameter"].append(memory['free_parameter'])
            elif agent_state == ActionStateValue.ToolExecution:
                steps["exec"].append(memory['free_observation'])
            elif agent_state == ActionStateValue.ToolReflection:
                steps["summary"].append(memory['free_summary'])
                steps["reflection"].append(memory['free_reflection'])
            elif agent_state == ActionStateValue.TaskReflection:
                steps["final_summary"].append(memory['free_summary'])
                steps["final_reflection"].append(memory['free_reflection'])
        return steps

    def _trajectory_record(self, message, paths, steps):
        """Assemble the output record for one path from its collected *steps*."""
        return {
            "session_id": message['session_id'],
            "query": message['query'],
            "available_tools": message['api_list'],
            "answer": {
                "method": "corvus_{}".format(self.core_memory),
                "total_steps": len(paths),
                "final_answer": steps["final_summary"],
                "plan_traj": steps["plan"],
                "api_traj": steps["api"],
                "exec_traj": steps["exec"],
                "parameter_traj": steps["parameter"],
                "answer_details": steps["summary"],
                "reflection_traj": steps["reflection"],
                "answer_reflection": steps["final_reflection"],
            },
        }

    def get_chain_trajectories(self, state, message):
        """Summarise every path (stopped or not) of a chain-memory session.

        API steps are taken from ``memory['api_json']``.
        """
        return [
            self._trajectory_record(message, paths,
                                    self._collect_steps(paths, 'api_json', reset_on_plan=False))
            for paths in state.trajectories.get_all_paths()
        ]

    def get_tree_trajectories(self, state, message):
        """Summarise every stopped path of a tree-memory session.

        API steps are taken from ``memory['free_api']``. Only the steps after
        the last planning node of each path are kept.
        """
        return [
            self._trajectory_record(message, paths,
                                    self._collect_steps(paths, 'free_api', reset_on_plan=True))
            for paths in state.trajectories.get_all_paths(filter_stop=True)
        ]
