from Algorithms.Common.workflow_common import BaseWorkflow
from Algorithms.ETD.brain_etd import BrainETD, ETDWorkflowController
from Algorithms.ETD.config_etd import ETDConfig
from Algorithms.ETD.etd_runner import ETDWorker

def run_workers(worker, conn):
    """Call ``worker.step(conn)`` once and discard whatever it returns.

    Args:
        worker: Any object exposing a ``step`` method.
        conn: Passed straight through as the single argument to ``step``.
    """
    advance = worker.step
    advance(conn)


class ETDWorkFlow(BaseWorkflow):
    """ETD variant of ``BaseWorkflow``.

    Plugs the ETD-specific brain (``BrainETD``), worker (``ETDWorker``) and
    workflow controller (``ETDWorkflowController``) into the base workflow.
    It also adds a per-step hook, ``roll_extra``, which embeds the states at
    step ``t`` with the brain. It then exchanges the embeddings with the
    workers over the pipes in ``self.parents``.
    """

    # Label for this workflow; its use is defined in BaseWorkflow
    # (presumably a name/log prefix — confirm there).
    desc: str = "ETD-"
    brain: BrainETD
    config: ETDConfig
    wf_controller: ETDWorkflowController

    def init_wf_controller(self):
        """Create the ETD workflow controller from ``self.config``."""
        self.wf_controller = ETDWorkflowController(self.config)

    def get_worker_cls(self):
        """Return the worker class used by this workflow (``ETDWorker``)."""
        return ETDWorker

    def get_brain_cls(self):
        """Return the brain class used by this workflow (``BrainETD``)."""
        return BrainETD

    def roll_extra(self, t):
        """Compute ETD embeddings for step ``t`` and exchange them with workers.

        The brain embeds ``states[:, t]`` and ``next_states[:, t]`` in one
        batched call. The results are written into the controller's
        ``etd_emb_states[:, t]`` and ``etd_emb_next_states[:, t]`` buffers.
        Row ``i`` of those slices is sent to ``self.parents[i]`` as a tuple
        ``(emb_s, emb_s_)``. Then one reply is received from every pipe and
        stored in ``etd_dists[worker_id, t]``.

        Args:
            t: Index into the second axis of the controller buffers.

        Raises:
            ValueError: If the embedding slices have fewer rows than there are
                worker pipes. Previously some pipes were never sent to, and
                the receive loop then blocked forever waiting on them.
        """
        ctrl = self.wf_controller
        (
            ctrl.etd_emb_states[:, t],
            ctrl.etd_emb_next_states[:, t],
        ) = self.brain.get_etd_embedding(
            ctrl.states[:, t],
            ctrl.next_states[:, t],
            batch=True,
        )

        # Read back the stored slices (not the raw return value) so workers
        # receive data exactly as it sits in the controller buffers.
        emb_states = ctrl.etd_emb_states[:, t]
        emb_next_states = ctrl.etd_emb_next_states[:, t]

        n_parents = len(self.parents)
        n_rows = min(len(emb_states), len(emb_next_states))
        if n_rows < n_parents:
            # zip() would silently skip the extra pipes, and the recv() loop
            # below would then wait forever on a pipe that was never sent to.
            raise ValueError(
                f"ETD embeddings have {n_rows} rows but there are "
                f"{n_parents} worker pipes"
            )

        # Send to every pipe before receiving any reply, so the workers on
        # the other ends can process their embeddings concurrently.
        for parent, emb_s, emb_s_ in zip(self.parents, emb_states, emb_next_states):
            parent.send((emb_s, emb_s_))

        # One reply per worker. Judging by the original local name
        # (etd_int_reward) this is an ETD intrinsic reward — confirm in ETDWorker.
        for worker_id, parent in enumerate(self.parents):
            ctrl.etd_dists[worker_id, t] = parent.recv()

    def rollout_post(self):
        """Run the base post-rollout step, then ``wf_controller.get_future_states()``."""
        super().rollout_post()
        self.wf_controller.get_future_states()


def main_etd(config: ETDConfig):
    """Entry point: construct an ``ETDWorkFlow`` from ``config`` and run it.

    Args:
        config: ETD configuration forwarded unchanged to ``ETDWorkFlow``.
    """
    workflow = ETDWorkFlow(config)
    workflow.run()
