from Common import Config


def conv_shape(input, kernel_size, stride, padding=0, dilation=1):
    """Return the output size of a convolution/pooling along one dimension.

    Implements ``floor((input + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``.
    With the default ``dilation=1`` this reduces to
    ``(input + 2*padding - kernel_size) // stride + 1``, the original formula.

    Args:
        input: Input size along the dimension (e.g. height or width).
        kernel_size: Size of the (undilated) kernel along that dimension.
        stride: Step between successive kernel applications.
        padding: Amount of padding added to each side of the input.
        dilation: Spacing between kernel elements; 1 means a dense kernel.

    Returns:
        The output size along the dimension.
    """
    # A dilated kernel spans dilation*(k-1)+1 input positions; equals k when dilation == 1.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (input + 2 * padding - effective_kernel) // stride + 1


class ETDConfig(Config):
    """ETD-specific settings layered on top of the base ``Config``."""

    # presumably the discount used when sampling future states — confirm
    future_sample_gamma: float = 0.99
    # presumably the weight of the intrinsic advantage term — confirm
    int_adv_coeff: float = 0.01
    etd_use_int_reward_rms: bool = True
    # normalization used in the model: 0 = none, 1 = batchnorm, 2 = layernorm
    model_norm_type: int = 2
    adv_norm_type: int = 0
    etd_int_reward_momentum: float = 0.9
    # Extra model keyword arguments; an example of a non-empty value:
    #     {
    #         "etd_mlp_arch": [1024, 128],
    #         "etd_encoder_arch": [128, 256, 64],
    #     }
    # NOTE(review): `{}` is a mutable class-level default; if Common.Config does
    # not copy defaults per instance, in-place mutation would be shared — confirm.
    etd_model_kwargs: dict = {}

