from Algorithms.Slav3.config_slav3 import Slav3Config
from Common.runner import Worker

from Common.utils import *




class Slav3Worker(Worker):
    """Worker specialisation used by the SLAv3 algorithm.

    On top of the base ``Worker`` it keeps a per-episode history of observed
    SLA values and a ``sla_ret`` counter. Both are cleared by ``reset``. It
    also answers extra SLA requests received over a connection
    (see ``step_extra``).
    """

    config: Slav3Config  # algorithm settings (sla_reward_clip, sla_use_erir_mask, ...)
    sla_ret: float  # value sent back alongside each SLA reward in step_extra

    def reset(self):
        """Reset the base worker state, then the SLAv3-specific state."""
        super().reset()
        self.reset_spec()

    def reset_spec(self):
        """Reset everything specific to SLAv3 (currently just the SLA buffer)."""
        self.reset_buffer()

    def reset_buffer(self):
        """Forget previously observed SLA values and zero the ``sla_ret`` counter."""
        self.sla_ret = 0
        self.last_sla_value_buffer = []

    def update_erir(self):
        """Delegate to the base implementation only when the ERIR mask is enabled.

        Returns:
            ``(1, 1)`` when ``config.sla_use_erir_mask`` is falsy; otherwise
            whatever ``Worker.update_erir`` returns.
        """
        if self.config.sla_use_erir_mask:
            return super().update_erir()
        return 1, 1

    def step_slav3(self, sla_value_s_):
        """Convert a newly observed SLA value into a clipped improvement reward.

        The reward is the amount by which ``sla_value_s_`` exceeds the largest
        value recorded so far. It is clipped to ``[0, upper]``, where ``upper``
        is ``config.sla_reward_clip``, or 1000 when that setting is negative.
        The first observation after a reset yields ``0.0``. The new value is
        always appended to the history.

        Args:
            sla_value_s_: the SLA value just observed (presumably a scalar;
                TODO confirm whether arrays can arrive here).

        Returns:
            ``(reward, {})``: the reward and an empty info dict.
        """
        clip_setting = self.config.sla_reward_clip
        # A negative setting means "no user clip"; 1000 acts as the ceiling then.
        if clip_setting < 0:
            upper = 1000
        else:
            upper = clip_setting

        history = self.last_sla_value_buffer
        if history:
            best_so_far = np.max(history)
            reward = np.clip(sla_value_s_ - best_so_far, 0, upper)
        else:
            reward = 0.0

        history.append(sla_value_s_)
        return reward, {}

    def step_extra(self, conn):
        """Serve one SLA request arriving over ``conn``.

        Receives an SLA value, scores it with ``step_slav3`` (which also records
        it), bumps ``sla_ret`` by one and sends ``(reward, sla_ret)`` back.

        Args:
            conn: object exposing ``recv()``/``send()``, presumably one end of
                a ``multiprocessing.Pipe`` (confirm against the caller).
        """
        sla_reward, _ = self.step_slav3(conn.recv())
        # NOTE(review): sla_ret grows by exactly 1 per call, so it counts steps
        # rather than summing sla_reward. Confirm this is the intended meaning.
        self.sla_ret += 1
        conn.send((sla_reward, self.sla_ret))






