from .finite_state import FiniteState, ActionStateValue, TaskStateValue
from .traj import TreeTrajectory
from utils import api_util
from collections import Counter
from enum import Enum
import math
import uuid
import json
import random


class CorvusState(FiniteState):
    
    def __init__(self, profile, session_id, max_frontier=2):
        '''
            Set up an empty search state for one session.

            profile: forwarded unchanged to FiniteState.__init__.
            session_id: identifier kept on the instance and later handed to
                the trajectory tree by initialize_state().
            max_frontier: upper bound on how many leaves
                run_action_tree_loop() selects per step. Defaults to 2,
                the value that used to be hard-coded here.
        '''
        FiniteState.__init__(self, profile)

        # The trajectory tree does not exist until initialize_state() runs.
        self.trajectories = None
        self._transition = self.initialize_transition()
        self._step = 0

        self.session_id = session_id
        self.max_frontier = max_frontier
    
    @property
    def step(self):
        return self._step

    def get_session_id(self):
        return self.session_id

    def get(self, traj_id):
        return self.trajectories.get(traj_id)

    def update(self, traj_id, new_traj_cont, trans_param=None):
        '''
            only update the traj content in memory
        '''
        self.trajectories.update(traj_id, new_traj_cont, trans_param)
    
    def update_local(self, traj_id, new_traj_cont):
        self.trajectories.update(traj_id, new_traj_cont, local=True)
    
    def dumps(self, traj_id):
        '''
            convert the Traj object into a string.
        '''
        return json.dumps(self.trajectories.get_memory(traj_id))

    def loads(self, traj_id, serial):
        '''
            convert the serialized string into a Traj object.
        '''
        self.trajectories.update(traj_id, json.loads(serial))

    def next_traj(self, parent_traj, traj_id, param_filter=[0]):
        task_state, agent_state, trans_param = self.state_transition(*parent_traj.get_transition(), param_filter=param_filter)
        self.trajectories.next_node(task_state, agent_state, trans_param, parent_traj, traj_id)

    def retrieve_api(self, traj_id, neighbor=False):
        '''
            Collect the ToolSelection nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list, or
                None when the list is empty.

            NOTE: this class defines retrieve_api a second time further
            down; that later definition is the one bound on the class.
            This copy is kept in step with it so that neither refers to
            an undefined name.
        '''
        # The filter previously named ReactActionValue, which this module
        # never imports; the state enum in use here is ActionStateValue.
        selections = self._retrieve(traj_id, xfilter={ActionStateValue.ToolSelection})
        if not neighbor:
            return selections
        return selections[-1] if selections else None
    
    def state_transition(self, t_state, a_state, trans_param, param_filter):
        '''
            state transition function: to obtain the next execution state.
        '''
        if trans_param not in param_filter: trans_param = "*"
        task_state, agent_state = self._transition[("?", a_state, trans_param)]
        if task_state == "?": task_state = t_state
        return task_state, agent_state, trans_param
    
    def initialize_transition(self):
        '''
            Build the transition table consulted by state_transition().

            Keys are ("?", agent_state, trans_param); values are
            (task_state, agent_state). A "?" task state in a value means
            "keep the current task state"; "*" as trans_param is the entry
            used when the parameter is not honoured by the caller's filter.
        '''
        act = ActionStateValue
        task = TaskStateValue
        table = {}

        # Unconditional pipeline: each stage hands over to the next one.
        pipeline = (
            act.TaskInit,
            act.TaskPlanning,
            act.ToolSelection,
            act.ToolExecution,
            act.ToolReflection,
        )
        for current, following in zip(pipeline, pipeline[1:]):
            table[("?", current, "*")] = ("?", following)

        # Tool reflection: parameters 0-2 go back to tool selection,
        # parameter 3 moves on to task reflection.
        for outcome in (0, 1, 2):
            table[("?", act.ToolReflection, outcome)] = (task.PartiallySolved, act.ToolSelection)
        table[("?", act.ToolReflection, 3)] = (task.PartiallySolved, act.TaskReflection)

        # Task reflection: parameter 1 restarts planning, parameter 0
        # marks the task as solved.
        table[("?", act.TaskReflection, 1)] = (task.UnSolved, act.TaskPlanning)
        table[("?", act.TaskReflection, 0)] = (task.Solved, "*")
        return table

    def initialize_state(self):
        '''
            Create a new trajectory tree for this session, starting from
            an unsolved task in the TaskInit agent state with "*" as the
            transition parameter.
        '''
        self.trajectories = TreeTrajectory(
            TaskStateValue.UnSolved,
            ActionStateValue.TaskInit,
            "*",
            self.session_id,
        )
    
    def recover_loop(self):

        self._step += 1
        actions = []
        for traj in self.trajectories.find_all_leaves():
            task_state, agent_state, trans_param = traj.get_transition()
            actions.append((agent_state, traj.get_node_id()))
        return actions

    def _cost_func(self, message, traj):
        
        traj_id = traj.get_node_id()
        traj_path = self.trajectories.find_ancestors(traj_id)

        g = len(traj_path)

        if 'golden' in message:
            
            minimum_cost = {"final": 1}
            for golden in message['golden']:
                if golden["conv"]["ID"] not in minimum_cost:
                    minimum_cost[golden["conv"]["ID"]] = 3
                else:
                    minimum_cost[golden["conv"]["ID"]] += 3
            
            for tp in traj_path:
                memory = tp.get_memory()
                t_id = tp.get_node_id()
                if tp.agent_state == ActionStateValue.ToolSelection:
                    
                    if len(memory) and "api_reward" in memory and "parameter_reward" in memory:
                        s_id = memory['free_api']["ID"]
                        if 2 == memory['api_reward'] + memory['parameter_reward'] and s_id in minimum_cost and minimum_cost[s_id] > 0:
                            minimum_cost[s_id] -= 1
                elif tp.agent_state == ActionStateValue.TaskReflection:
                    if len(memory) and "reward" in memory:
                        if 1 == memory['reward'] and minimum_cost["final"] > 0:
                            minimum_cost["final"] -= 1
                elif tp.agent_state == ActionStateValue.ToolReflection or tp.agent_state == ActionStateValue.ToolExecution:
                    s_m = self.retrieve_api(t_id, neighbor=True).get_memory()
                    if 'free_api' in s_m:
                        s_id = s_m['free_api']["ID"]
                        if len(memory) and "reward" in memory:
                            if 1 == memory['reward'] and s_id in minimum_cost and minimum_cost[s_id] > 0:
                                minimum_cost[s_id] -= 1
            
            h = sum(list(minimum_cost.values())) * 1.01
        else:
            h = 0
        return g, h

    def _ranking(self, message, frontier):
        
        cost = []
        for traj in frontier:
            g, h = self._cost_func(message, traj)
            # print(g, h)
            f = g + h
            cost.append(f)
        return cost, [index for index, value in sorted(enumerate(cost), key=lambda x: x[1])]

    def run_action_tree_loop(self, message):

        self._step += 1

        # 1. collect tree candidate node
        frontier = self.trajectories.find_all_leaves()
        random.shuffle(frontier)
        
        # print(len(frontier))
        # 2. node ranking
        if len(frontier) > self.max_frontier:
            cost, sorted_indices = self._ranking(message, frontier)
            # print(cost, sorted_indices)
        
        selected_frontier = []
        for traj_idx, traj in enumerate(frontier):
            task_state, agent_state, trans_param = traj.get_transition()
            if traj.task_state == TaskStateValue.Solved:
                continue
            if len(frontier) > self.max_frontier:
                if traj_idx not in sorted_indices[:self.max_frontier]:
                    continue
            selected_frontier.append((agent_state, traj.get_node_id()))
        
        assert len(selected_frontier) <= self.max_frontier
        return selected_frontier

    def run_action_loop(self, message):

        self._step += 1

        frontier = self.trajectories.find_leaves()
        selected_frontier = []
        for traj_idx, traj in enumerate(frontier):
            task_state, agent_state, trans_param = traj.get_transition()
            selected_frontier.append((agent_state, traj.get_node_id()))
        return selected_frontier

    def stop(self):

        traj_leaves = self.trajectories.find_all_leaves()
        if len(traj_leaves) == 0:
            return True
        
        stop_count = 0
        for traj in traj_leaves:
            if traj.task_state == TaskStateValue.Solved:
                stop_count += 1
        print("stop_count: ", self._step, stop_count, len(traj_leaves))
        return stop_count > 0
    
    def _retrieve(self, 
        traj_id,
        xfilter={
            ActionStateValue.TaskPlanning,
            ActionStateValue.ToolSelection,
            ActionStateValue.ToolExecution,
            ActionStateValue.ToolReflection,
            ActionStateValue.TaskReflection,
        }
    ):
        traj_path = self.trajectories.find_ancestors(traj_id)
        plan_list = []
        for traj in traj_path:
            if traj.agent_state in xfilter:
                plan_list.append(traj)
        return plan_list

    def _retrieve_boundary(self, 
        traj_id,
        xfilter={
            ActionStateValue.TaskPlanning,
            ActionStateValue.ToolSelection,
            ActionStateValue.ToolExecution,
            ActionStateValue.ToolReflection,
            ActionStateValue.TaskReflection,
        },
        boundary={
            ActionStateValue.TaskPlanning,
        }
    ):
        traj_path = self.trajectories.find_ancestors(traj_id)
        plan_list = []
        for traj in traj_path[::-1]:
            if traj.agent_state in boundary:
                break
            if traj.agent_state in xfilter:
                plan_list.append(traj)
        return plan_list

    def _retrieve_path(self, traj_id):
        return self.trajectories.find_ancestors(traj_id)

    def find_neighbor_ancestors(self, plan_traj_id, traj_id):
        
        traj_path = self.trajectories.find_ancestors(traj_id)
        ancestors_node_list = []
        for traj in traj_path[::-1]:
            
            if traj.get_node_id() == traj_id:
                continue
            if traj.get_node_id() == plan_traj_id:
                ancestors_node_list.append(traj)
                break
            ancestors_node_list.append(traj)
        return ancestors_node_list
    
    # for tool selection and reflection
    def get_cur_subplan_idx(self, traj_id):
        
        plan_traj = self.retrieve_plan(traj_id, neighbor=True)
        local_plan = plan_traj.get_memory(local=True)

        if traj_id not in local_plan:
            # 1. traverse the ancestor nodes to find the nearest parent node
            ancestors_node_list = self.find_neighbor_ancestors(plan_traj.get_node_id(), traj_id)
            for ancestors_node in ancestors_node_list:
                ancestors_node_id = ancestors_node.get_node_id()
                if ancestors_node_id in local_plan:
                    # 2. obtain the plan index based on the parent node
                    local_plan[traj_id] = {
                        "plan_idx": local_plan[ancestors_node_id]['plan_idx'],
                        "exec_count": local_plan[ancestors_node_id]['exec_count'],
                    }
                    self.update_local(traj_id, local_plan)
                    break
        
        return local_plan[traj_id]['plan_idx']

    def get_cur_subplan(self, traj_id):
        
        plan_traj = self.retrieve_plan(traj_id, neighbor=True)
        plan = plan_traj.get_memory()
        plan_idx = self.get_cur_subplan_idx(traj_id)
        return plan['plan'][plan_idx]
    
    # for tool execution and tool reflection (to calculate reward) and tool selection to recover plan
    def retrieve_api(self, traj_id, neighbor=False):
        '''
            Collect the ToolSelection nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list, or
                None when the list is empty.
        '''
        selections = self._retrieve(traj_id, xfilter={ActionStateValue.ToolSelection})
        if not neighbor:
            return selections
        return selections[-1] if selections else None
    
    # for tool reflection
    def retrieve_observation(self, traj_id, neighbor=False):
        '''
            Collect the ToolExecution nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list, or
                None when the list is empty.
        '''
        executions = self._retrieve(traj_id, xfilter={ActionStateValue.ToolExecution})
        if not neighbor:
            return executions
        return executions[-1] if executions else None

    def retrieve_plan(self, traj_id, neighbor=False):
        '''
            Collect the TaskPlanning nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list.
                Unlike the other retrieve_* helpers this raises IndexError
                when the list is empty.
        '''
        plans = self._retrieve(traj_id, xfilter={ActionStateValue.TaskPlanning})
        return plans[-1] if neighbor else plans

    def get_local_plan(self, traj_id, state):
        plan_traj = state.retrieve_plan(traj_id, neighbor=True)
        local_plan = plan_traj.get_memory(local=True)
        if traj_id not in local_plan:
            # 1. traverse the ancestor nodes to find the nearest parent node
            ancestors_node_list = state.find_neighbor_ancestors(plan_traj.get_node_id(), traj_id)
            for ancestors_node in ancestors_node_list:
                ancestors_node_id = ancestors_node.get_node_id()
                if ancestors_node_id in local_plan:
                    # 2. obtain the plan index based on the parent node
                    local_plan[traj_id] = {
                        "plan_idx": local_plan[ancestors_node_id]['plan_idx'],
                        "exec_count": local_plan[ancestors_node_id]['exec_count'],
                    }
                    state.update_local(traj_id, local_plan)
                    break
        return local_plan
    
    # for task reflection
    def retrieve_reflection(self, traj_id, neighbor=False):
        '''
            Collect the ToolReflection nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list, or
                None when the list is empty.
        '''
        reflections = self._retrieve(traj_id, xfilter={ActionStateValue.ToolReflection})
        if not neighbor:
            return reflections
        return reflections[-1] if reflections else None

    def retrieve_reflection_boundary(self, traj_id):
        '''
            Return the ToolReflection nodes met when walking back from
            traj_id, stopping at the first TaskPlanning node
            (see _retrieve_boundary for the ordering of the result).
        '''
        wanted = {ActionStateValue.ToolReflection}
        stop_at = {ActionStateValue.TaskPlanning}
        return self._retrieve_boundary(traj_id, xfilter=wanted, boundary=stop_at)

    # for task planning (to regen plan)
    def retrieve_task_reflection(self, traj_id, neighbor=False):
        '''
            Collect the TaskReflection nodes found by _retrieve() for traj_id.

            neighbor=False: return the whole list.
            neighbor=True: return only the last element of that list, or
                None when the list is empty.
        '''
        reflections = self._retrieve(traj_id, xfilter={ActionStateValue.TaskReflection})
        if not neighbor:
            return reflections
        return reflections[-1] if reflections else None
    
    def is_previous_plan(self, traj_id):
        '''
            Tell whether the direct parent of traj_id is a TaskPlanning
            node.
        '''
        parent_id = self.get(traj_id).get_parent_node_id()
        parent = self.get(parent_id)
        return bool(parent.agent_state == ActionStateValue.TaskPlanning)
    
    def _retrieve_sibling(self, traj_id, prefix):

        traj = self.get(traj_id)
        p_id = traj.get_parent_node_id()
        
        for sibling_id in self.get(p_id).children:
            if sibling_id != traj_id:
                sbling_traj = self.get(sibling_id)
                memory = sbling_traj.get_memory()
                if prefix in memory:
                    if memory[prefix] < 1:
                        return sbling_traj
        return None
