from Algorithms.Common.workflow_common import BaseWorkflow
from Algorithms.Slav3.brain_slav3 import BrainSlav3, Slav3WorkflowController
from Algorithms.Slav3.config_slav3 import Slav3Config
from Algorithms.Slav3.slav3_runner import Slav3Worker



class SlaV3WorkFlow(BaseWorkflow):
    """Workflow wiring for the SLA-v3 algorithm.

    Plugs ``BrainSlav3``, ``Slav3Worker`` and ``Slav3WorkflowController``
    into ``BaseWorkflow`` and adds the per-step SLA bookkeeping hooks
    ``roll_t_spec``, ``roll_extra`` and ``get_intrinsic_post``.
    """

    # Label for this workflow; presumably used as a name/log prefix by
    # BaseWorkflow -- confirm against the base class.
    desc: str = "SlaV3-"
    brain: BrainSlav3
    config: Slav3Config
    wf_controller: Slav3WorkflowController

    def init_wf_controller(self):
        """Create the SLA-v3 workflow controller from ``self.config``."""
        self.wf_controller = Slav3WorkflowController(self.config)

    def get_brain_cls(self):
        """Return the brain class used by this workflow (``BrainSlav3``)."""
        return BrainSlav3

    def get_worker_cls(self):
        """Return the worker class used by this workflow (``Slav3Worker``)."""
        return Slav3Worker

    def roll_t_spec(self, worker_id, t, s_, info):
        """Record whether step ``t`` of worker ``worker_id`` starts an episode.

        Sets ``sla_is_starts[worker_id, t]`` to ``info['episode_step'] == 1``.
        ``s_`` is accepted to match the hook signature but is not used here.
        """
        self.wf_controller.sla_is_starts[worker_id, t] = (info['episode_step'] == 1)

    def get_intrinsic_post(self, info):
        """Run the controller's SLA GAE computation (``calc_sla_gae``).

        ``info`` is not used. The method returns ``None``.
        """
        self.wf_controller.calc_sla_gae()

    def roll_extra(self, t):
        """Compute the SLA values for step ``t`` and exchange them with workers.

        1. Evaluate the brain's SLA value for the current and next states in
           column ``t``. Store the results in ``sla_values[:, t]`` and
           ``sla_next_values[:, t]``.
        2. Send each worker its next-state SLA value through the matching
           entry of ``self.parents``.
        3. Receive ``(sla_int_reward, sla_ret)`` from every worker and store
           the pair at index ``[worker_id, t]``.
        """
        ctrl = self.wf_controller

        state = ctrl.get_sla_value_state(is_next_state=False)[:, t]
        next_state = ctrl.get_sla_value_state(is_next_state=True)[:, t]
        values, next_values = self.brain.calculate_sla_value(state, next_state)
        ctrl.sla_values[:, t] = values
        ctrl.sla_next_values[:, t] = next_values

        # Send to every worker before receiving any reply. This presumably
        # lets the workers handle their requests in parallel. The stored
        # column is read back, so the values sent are exactly what was
        # written above.
        for parent, sla_value_s_ in zip(self.parents, ctrl.sla_next_values[:, t]):
            parent.send(sla_value_s_)

        for worker_id, parent in enumerate(self.parents):
            sla_int_reward, sla_ret = parent.recv()
            ctrl.sla_int_rewards[worker_id, t] = sla_int_reward
            # NOTE(review): sla_ret is stored into ``episode_steps`` rather
            # than an SLA-return buffer -- confirm this target is intended.
            ctrl.episode_steps[worker_id, t] = sla_ret






def main_slav3(config: Slav3Config):
    """Build an SLA-v3 workflow from *config* and run it."""
    workflow = SlaV3WorkFlow(config)
    workflow.run()
