from dataclasses import asdict, dataclass, field
from typing import Dict, List

from .utils.coqpit import MISSING
from .utils.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class SpeakerEncoderConfig(BaseTrainingConfig):
    """Defines parameters for Speaker Encoder model."""

    model: str = "speaker_encoder"
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # model params
    model_params: Dict = field(
        default_factory=lambda: {
            "model_name": "lstm",
            "input_dim": 80,
            "proj_dim": 256,
            "lstm_dim": 768,
            "num_lstm_layers": 3,
            "use_lstm_with_projection": True,
        }
    )

    audio_augmentation: Dict = field(default_factory=lambda: {})

    storage: Dict = field(
        default_factory=lambda: {
            "sample_from_storage_p": 0.66,  # the probability with which we'll sample from the DataSet in-memory storage
            "storage_size": 15,  # the size of the in-memory storage with respect to a single batch
        }
    )

    # training params
    max_train_step: int = 1000000  # end training when number of training steps reaches this value.
    loss: str = "angleproto"
    grad_clip: float = 3.0
    lr: float = 0.0001
    lr_decay: bool = False
    warmup_steps: int = 4000
    wd: float = 1e-6

    # logging params
    tb_model_param_stats: bool = False
    steps_plot_stats: int = 10
    checkpoint: bool = True
    save_step: int = 1000
    print_step: int = 20

    # data loader
    num_speakers_in_batch: int = MISSING
    num_utters_per_speaker: int = MISSING
    num_loader_workers: int = MISSING
    skip_speakers: bool = False
    voice_len: float = 1.6

    def check_values(self):
        super().check_values()
        c = asdict(self)
        assert (
            c["model_params"]["input_dim"] == self.audio.num_mels
        ), " [!] model input dimendion must be equal to melspectrogram dimension."
