from ..zipformer.model_config import ZipformerConfig
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class ZipformerForAudioCaptioningConfig(ZipformerConfig):
    """Configuration for a Zipformer model used for audio captioning.

    Extends ``ZipformerConfig`` with text-tokenizer, decoder and loss settings.
    The encoder fields come from ``ZipformerConfig``; the fields below are
    added on top with the listed defaults.
    """

    # Identifier written into the config for this model family.
    name: str = "zipformer-captioning"
    # Tokenizer used for caption text; presumably a pretrained tokenizer
    # name (e.g. a Hugging Face hub id) — TODO confirm against the loader.
    text_tokenizer_type: str = "t5-small"
    # Decoder hyper-parameters. Meanings below are inferred from field names;
    # confirm against the model implementation.
    decoder_nhead: int = 8  # attention heads per decoder layer
    decoder_bias: bool = False  # whether decoder linear layers use bias
    num_decoder_layers: int = 6  # number of stacked decoder layers
    decoder_shared_emb: bool = False  # presumably ties input/output embeddings
    decoder_dropout: float = 0.1  # dropout rate inside the decoder
    decoder_activation: str = "gelu"  # activation name for decoder FFNs
    decoder_norm_first: bool = False  # pre-norm (True) vs post-norm (False)
    # Label smoothing factor, presumably for the caption cross-entropy loss.
    label_smoothing: float = 0.1

    @classmethod
    def from_preset(cls, preset: str, **kwargs):
        """Build a config from a ``ZipformerConfig`` encoder preset.

        The encoder fields are taken from ``ZipformerConfig.from_preset(preset)``.
        ``name`` is set to ``"zipformer-captioning"``. Fields not present in
        the base preset keep their dataclass defaults.

        Args:
            preset: Encoder preset name; one of ``"base"`` or ``"large"``.
            **kwargs: Field overrides. They are applied last, so they take
                precedence over preset values, including ``name``.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If ``preset`` is not one of the supported names.
        """
        supported = ("base", "large")
        # Validate before doing any work so a bad name fails fast.
        if preset not in supported:
            raise ValueError(
                f"Unsupported preset '{preset}' for {cls.__name__}. Supported: {list(supported)}"
            )

        # Resolve only the requested encoder preset rather than building every
        # preset up front on each call.
        config_dict = asdict(ZipformerConfig.from_preset(preset))
        # The base preset carries the encoder's own name; replace it.
        config_dict["name"] = "zipformer-captioning"
        config_dict.update(kwargs)
        return cls(**config_dict)
    
@dataclass
class ZipformerForAudioCaptioningWithMaskingConfig(ZipformerForAudioCaptioningConfig):
    """Captioning configuration variant that adds a parallel-decoding setting.

    Inherits every field of ``ZipformerForAudioCaptioningConfig`` and
    overrides ``name`` and ``from_preset``.
    """

    # Identifier written into the config for this model variant.
    name: str = 'zipformer-masked-captioning'
    # Probability used for parallel decoding; presumably the chance of using
    # masked/parallel decoding instead of autoregressive decoding during
    # training — TODO confirm against the model implementation.
    parallel_decoding_prob: float = 0.75

    @classmethod
    def from_preset(cls, preset: str, **kwargs):
        """Build a masked-captioning config from a ``ZipformerConfig`` preset.

        Delegates to the parent's ``from_preset`` and sets ``name`` to
        ``"zipformer-masked-captioning"`` unless the caller overrides it.

        Args:
            preset: Encoder preset name accepted by the parent class.
            **kwargs: Field overrides. They take precedence over preset values
                and over the default ``name``.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: Propagated from the parent when ``preset`` is unsupported.
        """
        # Zero-argument super() keeps ``cls`` bound to this subclass, so the
        # parent builds (and names in its error message) the masked config.
        # The default name goes before ``kwargs`` so callers can still override it.
        return super().from_preset(
            preset, **{"name": "zipformer-masked-captioning", **kwargs}
        )