import numpy as np

from Algorithms.Common.workflow_common import BaseWorkflow
from Algorithms.RIDE.brain_ride import RIDEWorkflowController
from Algorithms.RIDE.brain_ride import BrainRIDE
from Algorithms.RIDE.config_ride import RIDEConfig


def run_workers(worker, conn):
    """Advance ``worker`` by one step using ``conn``.

    Calls ``worker.step(conn)`` and discards its result, so this function
    always returns ``None``.
    """
    step_fn = worker.step
    step_fn(conn)


class RIDEWorkFlow(BaseWorkflow):
    """Workflow wiring for the RIDE algorithm.

    Plugs the RIDE-specific brain class and workflow controller into the
    generic ``BaseWorkflow`` machinery.
    """

    # Descriptor/prefix string for this workflow; how BaseWorkflow consumes
    # it is not visible in this file.
    desc: str = "RIDE-"
    brain: BrainRIDE
    config: RIDEConfig
    wf_controller: RIDEWorkflowController

    def init_wf_controller(self):
        """Create the RIDE workflow controller from ``self.config``."""
        controller = RIDEWorkflowController(self.config)
        self.wf_controller = controller

    def get_brain_cls(self):
        """Return the brain class used by this workflow (``BrainRIDE``)."""
        brain_cls = BrainRIDE
        return brain_cls

    def roll_t_spec(self, worker_id, t, s_, info):
        """Record a per-step value into the controller's ``erir_mask``.

        Writes ``1 / sqrt(info["eps_nvisit"])`` at index ``[worker_id, t]``
        of ``self.wf_controller.erir_mask``. ``s_`` is accepted to match the
        hook's signature but is not used here.
        """
        n_visit = info["eps_nvisit"]
        # NOTE(review): presumably eps_nvisit is a visit count >= 1; a value
        # of 0 would make numpy produce inf here — confirm upstream.
        self.wf_controller.erir_mask[worker_id, t] = 1 / np.sqrt(n_visit)



def main_ride(config: RIDEConfig):
    """Build a RIDE workflow from ``config`` and run it."""
    RIDEWorkFlow(config).run()
