from transformers import WhisperConfig


class WhisperVQConfig(WhisperConfig):
    """``WhisperConfig`` subclass carrying extra ``pooling_*``, ``quantize_*``,
    ``skip_language_detection`` and ``encoder_causal_*`` attributes.

    The constructor only records these values on the instance; it does not
    validate or interpret them.
    """

    def __init__(
        self,
        pooling_kernel_size=None,
        pooling_type="max",
        pooling_position=0,
        quantize_vocab_size=None,
        quantize_position=16,
        quantize_commit_coefficient=0.25,
        quantize_loss_scale=1.0,
        quantize_ema_decay=None,
        quantize_restart_interval=None,
        quantize_encoder_only=False,
        quantize_causal_encoder=False,
        quantize_causal_block_size=None,
        skip_language_detection=False,
        encoder_causal_attention=False,
        encoder_causal_convolution=False,
        **kwargs,
    ):
        """Store the extra options as attributes, then initialise the base config.

        Every named argument is saved unchanged on ``self`` under the same
        name. Anything left in ``kwargs`` is forwarded to
        ``WhisperConfig.__init__``.

        NOTE(review): what each option means is decided by the code that
        reads this config, which is not visible here -- confirm against the
        model implementation before relying on any interpretation of the
        names.

        Args:
            pooling_kernel_size: Defaults to ``None``.
            pooling_type: Defaults to ``"max"``.
            pooling_position: Defaults to ``0``.
            quantize_vocab_size: Defaults to ``None``.
            quantize_position: Defaults to ``16``.
            quantize_commit_coefficient: Defaults to ``0.25``.
            quantize_loss_scale: Defaults to ``1.0``.
            quantize_ema_decay: Defaults to ``None``.
            quantize_restart_interval: Defaults to ``None``.
            quantize_encoder_only: Defaults to ``False``.
            quantize_causal_encoder: Defaults to ``False``.
            quantize_causal_block_size: Defaults to ``None``.
            skip_language_detection: Defaults to ``False``.
            encoder_causal_attention: Defaults to ``False``.
            encoder_causal_convolution: Defaults to ``False``.
            **kwargs: Passed through to ``WhisperConfig.__init__``.
        """
        # Insertion order mirrors the signature so attributes are created in
        # the same sequence as before.
        extra_options = {
            "pooling_kernel_size": pooling_kernel_size,
            "pooling_type": pooling_type,
            "pooling_position": pooling_position,
            "quantize_vocab_size": quantize_vocab_size,
            "quantize_position": quantize_position,
            "quantize_commit_coefficient": quantize_commit_coefficient,
            "quantize_loss_scale": quantize_loss_scale,
            "quantize_ema_decay": quantize_ema_decay,
            "quantize_restart_interval": quantize_restart_interval,
            "quantize_encoder_only": quantize_encoder_only,
            "quantize_causal_encoder": quantize_causal_encoder,
            "quantize_causal_block_size": quantize_causal_block_size,
            "skip_language_detection": skip_language_detection,
            "encoder_causal_attention": encoder_causal_attention,
            "encoder_causal_convolution": encoder_causal_convolution,
        }
        for option_name, option_value in extra_options.items():
            setattr(self, option_name, option_value)

        # The extra attributes are assigned before the base initialiser runs,
        # as in the original ordering.
        super().__init__(**kwargs)
