from .action import APIAction
import uuid


class ChainInit(APIAction):
    """Entry action that seeds a trajectory with child branches and advances to them.

    ``ChainInit`` uses a single branch (``beam = 1``). ``run`` is written
    against ``self.beam``, so a subclass only has to change ``self.beam``
    to fan out into that many child trajectories.
    """

    name: str = "chain_init"
    # Presumably the key fields APIAction uses for (de)serialization/logging;
    # the last entry is this action's name string. TODO confirm against APIAction.
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, handler):
        """Create the action.

        Args:
            profile: Forwarded to ``APIAction.__init__``.
            policy: Stored on ``self.policy``; not used by ``run`` itself.
            handler: Forwarded to ``APIAction.__init__``.
        """
        super().__init__(profile, handler)
        self.policy = policy
        # Number of child trajectories ``run`` creates and advances to.
        self.beam = 1

    def run(self, state, message, traj_id, recover=False):
        """Create (or restore) ``self.beam`` child ids for ``traj_id`` and advance to them.

        If ``load_serialized_obj`` reports nothing for ``traj_id``, fresh UUID4
        strings are generated, stored under the trajectory's ``"child_id"``
        key via ``state.update``, and saved with ``save_serialized_obj``.
        The trajectory's transfer parameter is then reset to ``"*"`` and each
        of the first ``self.beam`` child ids is handed to ``state.next_traj``.

        Args:
            state: Session state exposing ``update``, ``get``, ``next_traj``
                and ``step``.
            message: Unused by this action.
            traj_id: Identifier of the trajectory being initialised.
            recover: When True and nothing could be loaded for ``traj_id``,
                return early without creating any children.

        Returns:
            ``"leaf_node"`` when ``recover`` is True and nothing was loaded;
            otherwise ``None``.
        """
        if not self.load_serialized_obj(state, traj_id):
            if recover:
                # Nothing to restore in recover mode: report a leaf, create nothing.
                return "leaf_node"
            child_ids = [str(uuid.uuid4()) for _ in range(self.beam)]
            state.update(traj_id, {"child_id": child_ids})
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        # Reset the transfer parameter before branching.
        traj.set_trans_param("*")
        # Memory is read per child rather than cached, in case next_traj
        # mutates the parent trajectory (unverified from here).
        for i in range(self.beam):
            state.next_traj(traj, traj.get_memory()["child_id"][i])

        self.logger.info("[assistant] step-{} {}: start beam={}...".format(str(state.step), self.name, self.beam))


class TreeInit(ChainInit):
    """Entry action that fans a trajectory out into ``beam`` (= 2) child branches.

    Behaves like ``ChainInit`` except that ``self.beam`` is 2, so two child
    ids are generated (or restored) and each is passed to ``state.next_traj``.
    """

    # NOTE(review): shares the name "chain_init" with ChainInit, which may be a
    # copy-paste slip. Left as-is because ``fields`` embeds it and it may be
    # used for serialization/dispatch elsewhere -- confirm before renaming.
    name: str = "chain_init"
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, handler):
        """Build the action via ``ChainInit`` and widen the beam to two branches."""
        super().__init__(profile, policy, handler)
        self.beam = 2

    def run(self, state, message, traj_id, recover=False):
        """Create (or restore) ``self.beam`` child ids for ``traj_id`` and advance to each.

        Args:
            state: Session state exposing ``update``, ``get``, ``next_traj``
                and ``step``.
            message: Unused by this action.
            traj_id: Identifier of the trajectory being initialised.
            recover: When True and nothing could be loaded for ``traj_id``,
                return early without creating any children.

        Returns:
            ``"leaf_node"`` when ``recover`` is True and nothing was loaded;
            otherwise ``None``.
        """
        restored = self.load_serialized_obj(state, traj_id)
        if not restored:
            if recover:
                return "leaf_node"
            fresh_ids = [str(uuid.uuid4()) for _ in range(self.beam)]
            state.update(traj_id, {"child_id": fresh_ids})
            self.save_serialized_obj(state, traj_id)

        parent = state.get(traj_id)
        # Clear the transfer parameter before spawning branches.
        parent.set_trans_param("*")
        for branch in range(self.beam):
            branch_id = parent.get_memory()["child_id"][branch]
            state.next_traj(parent, branch_id)

        self.logger.info("[assistant] step-{} {}: start beam={}...".format(str(state.step), self.name, self.beam))
