import json
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional

from numpy import extract


@dataclass
class ModelDimensions:
    """Integer size parameters describing one model variant.

    A plain value container: every field is a required ``int`` and the class
    defines no behavior of its own. The per-field notes below are inferred
    from the field names only and should be confirmed against the model code
    that consumes this class.
    """
    n_mels: int  # presumably the number of mel-spectrogram bins - confirm
    n_audio_ctx: int  # presumably the audio-side context length - confirm
    n_audio_state: int  # presumably the audio-side hidden width - confirm
    n_audio_head: int  # presumably the audio-side attention head count - confirm
    n_audio_layer: int  # presumably the audio-side layer count - confirm
    n_vocab: int  # presumably the vocabulary size - confirm
    n_text_ctx: int  # presumably the text-side context length - confirm
    n_text_state: int  # presumably the text-side hidden width - confirm
    n_text_head: int  # presumably the text-side attention head count - confirm
    n_text_layer: int  # presumably the text-side layer count - confirm


# Preset dimensions keyed by model name.
#
# Each row of the table is (name, n_mels, width, heads, layers, n_vocab).
# In every preset the audio and text sides use the same width, head count and
# layer count, and the context lengths are fixed at 1500 (audio) and 448
# (text), so only the values that actually vary are tabulated.
_MODEL_DIMS = {
    name: ModelDimensions(
        n_mels=mels,
        n_audio_ctx=1500,
        n_audio_state=width,
        n_audio_head=heads,
        n_audio_layer=layers,
        n_vocab=vocab,
        n_text_ctx=448,
        n_text_state=width,
        n_text_head=heads,
        n_text_layer=layers,
    )
    for name, mels, width, heads, layers, vocab in (
        ("tiny.en", 80, 384, 6, 4, 51864),
        ("tiny", 80, 384, 6, 4, 51865),
        ("base.en", 80, 512, 8, 6, 51864),
        ("base", 80, 512, 8, 6, 51865),
        ("small.en", 80, 768, 12, 12, 51864),
        ("small", 80, 768, 12, 12, 51865),
        ("medium.en", 80, 1024, 16, 24, 51864),
        ("medium", 80, 1024, 16, 24, 51865),
        ("large-v2", 80, 1280, 20, 32, 51865),
        ("large-v3", 128, 1280, 20, 32, 51866),
    )
}


@dataclass
class WhisperConfig:
    """Configuration for one Whisper model variant.

    ``model_name`` must be a key of ``_MODEL_DIMS``. After construction,
    ``__post_init__`` has:

    * set ``dims`` to ``_MODEL_DIMS[model_name]`` (any value passed in for
      ``dims`` is overwritten);
    * replaced a falsy ``extra_tokens`` / ``extra_languages`` with an empty
      tuple / dict;
    * forced ``is_multilingual`` to ``False`` when ``model_name`` contains
      ``'.en'``;
    * added ``len(extra_languages)`` to ``num_languages`` and
      ``len(extra_languages) + len(extra_tokens)`` to ``vocab_size``.

    NOTE: the increments are applied to whatever ``vocab_size`` and
    ``num_languages`` were passed in, so feeding already-incremented values
    (e.g. from a saved config) back into the constructor together with the
    same extras counts the extras twice.

    NOTE(review): the ``vocab_size`` default is independent of
    ``dims.n_vocab`` and is not adjusted per model - confirm this is intended.
    """
    model_name: str = 'medium'
    # Always replaced in __post_init__; None is only a construction placeholder.
    dims: Optional[ModelDimensions] = None
    apply_padding_mask: bool = False
    is_multilingual: bool = True
    context_size: int = 448
    blank_id: int = 50256
    vocab_size: int = 51865
    num_languages: int = 99
    extra_languages: Optional[Dict] = None
    extra_tokens: Optional[Tuple] = None

    def __post_init__(self):
        """Validate ``model_name`` and derive the dependent fields.

        Raises:
            AssertionError: if ``model_name`` is not a key of ``_MODEL_DIMS``.
        """
        # Raised explicitly (instead of via an `assert` statement) so the check
        # still runs under `python -O`; the exception type and message are the
        # same as the assert produced.
        if self.model_name not in _MODEL_DIMS:
            raise AssertionError(f"{self.model_name} not found.")
        self.dims = _MODEL_DIMS[self.model_name]

        # Normalize "not provided" (or empty) extras to empty containers so
        # the len() calls below are always valid.
        self.extra_tokens = self.extra_tokens or tuple()
        self.extra_languages = self.extra_languages or dict()

        if '.en' in self.model_name:
            self.is_multilingual = False

        # Each extra language counts as both a language and a vocabulary
        # entry; each extra token counts as a vocabulary entry.
        self.num_languages += len(self.extra_languages)
        self.vocab_size += len(self.extra_languages)
        self.vocab_size += len(self.extra_tokens)

    def to_dict(self) -> Dict:
        """Convert config to dictionary.

        Returns the instance's own attribute dictionary (not a copy), so
        mutating the result mutates the config. ``dims`` is left as a
        ``ModelDimensions`` object.
        """
        return self.__dict__

    def to_json(self, json_path: str):
        """Save config to a JSON file.

        ``dims`` holds a ``ModelDimensions`` dataclass instance, which the
        ``json`` encoder cannot serialize on its own; it is expanded into a
        plain dict via ``dataclasses.asdict`` while encoding.

        Args:
            json_path: Path of the file to write (overwritten if it exists).

        Raises:
            TypeError: if the config holds a value that is neither
                JSON-serializable nor a dataclass instance.
        """
        # Local import: the module header only imports `dataclass`.
        from dataclasses import asdict

        # json.dump emits ASCII-only text by default, so pinning the encoding
        # does not change the bytes written; it only removes the dependency on
        # the platform default.
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=4, default=asdict)
