import copy

from Common import Config


class Slav3Config(Config):
    """Config subclass that declares the ``rnd_*`` and ``sla_*`` settings.

    Besides the plain fields, it offers two helpers that are both driven by
    ``sla_value_state_type``: one picks the name of the state entry to use,
    the other derives ``sla_value_net_arch`` from one of the base
    architectures.
    """

    rnd_hidden_size: int
    rnd_use_reward_norm: bool = False
    rnd_state_norm: bool = False
    rnd_int_value_non_episodic: bool = False

    sla_reward: float = 0.001  # sla heuristic scale

    # Index (0, 1 or 2 for the lists used in the methods of this class).
    sla_value_state_type: int = 2
    # Filled in by update_spec_net_arch().
    sla_value_net_arch: dict

    sla_use_erir_mask: bool = False
    sla_erir_coeff: float = 0.05
    sla_rnd_lb: float = 0.06
    sla_shortest_coeff: float = 1000  # epsilon of leaky_relu
    # NOTE(review): annotated as float although the documented values are
    # integer codes -- confirm before changing the annotation.
    sla_rnd_inte_type: float = 0  # 0: no rnd, 1: rnd modulator
    sla_value_model_type: int = 0  # 0: value model, 1: RND model
    sla_reward_clip: float = -1

    def get_sla_state_str(self, is_next_state=False):
        """Return the state-entry name selected by ``sla_value_state_type``.

        :param is_next_state: when truthy, choose among the ``*next_states``
            names; otherwise choose among the ``*states`` names.
        :return: the name found at position ``sla_value_state_type``.
        """
        if is_next_state:
            names = ["next_states", "obs_next_states", "cell_next_states"]
        else:
            names = ["states", "obs_states", "cell_states"]
        return names[self.sla_value_state_type]

    def update_spec_net_arch(self):
        """Derive ``sla_value_net_arch`` from the selected base architecture.

        The base architecture at position ``sla_value_state_type`` (feature,
        obs, cell -- in that order) is deep-copied so the original stays
        untouched, then every entry of its ``"mlp"`` sequence except the
        first and the last is removed.

        NOTE(review): ``feature_net_arch``, ``obs_net_arch`` and
        ``cell_net_arch`` are not declared in this class; presumably they
        come from ``Config`` -- confirm.
        """
        candidates = [
            self.feature_net_arch,
            self.obs_net_arch,
            self.cell_net_arch,
        ]
        chosen = candidates[self.sla_value_state_type]
        self.sla_value_net_arch = copy.deepcopy(chosen)

        # Keep only the first and last "mlp" entries of the private copy.
        mlp_layers = self.sla_value_net_arch["mlp"]
        del mlp_layers[1:-1]




