import io
import logging
import os
from typing import Optional, Union

import soundfile as sf
import torch
from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
from whisper.audio import log_mel_spectrogram
from whisper.model import ModelDimensions

from whisper_model import Whisper_

# Module-wide logger. It uses the fixed name "dump_feature" rather than
# __name__; NOTE(review): presumably so logging config keyed on that name applies.
logger: logging.Logger = logging.getLogger(name="dump_feature")


def load_model(
        name: str,
        device: Optional[Union[str, torch.device]] = None,
        download_root: Optional[str] = None,
        in_memory: bool = False,
) -> Whisper_:
    """
    Load a Whisper checkpoint and build a ``Whisper_`` model from it.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device], optional
        the PyTorch device to put the model into; defaults to "cuda" when
        CUDA is available, otherwise "cpu"
    download_root : str, optional
        path to download the model files; by default it uses
        "$XDG_CACHE_HOME/whisper", falling back to "~/.cache/whisper"
    in_memory : bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        the model with the checkpoint's state_dict loaded and moved to
        ``device``; for official model names the alignment heads are also set

    Raises
    ------
    RuntimeError
        if ``name`` is neither an official model name nor an existing file
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        # With in_memory=True this yields the raw checkpoint bytes, otherwise a
        # file path (see how it is opened below).
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Read eagerly but close the handle deterministically instead of
            # leaving an unclosed file object for the garbage collector.
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True where the torch version allows).
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (possibly large) in-memory bytes before building the model.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


class WhisperFeatureReader(object):
    """Extract intermediate-layer features from a Whisper model for 16 kHz audio.

    Parameters
    ----------
    root : str
        passed to ``load_model`` as ``download_root``.
    ckpt : str
        model name or checkpoint path passed to ``load_model``.
    layer : int
        one-based layer index passed as ``target_layer`` to
        ``Whisper_.extract_features``.
    device : str or torch.device
        device holding the model and the input tensors.
    """

    def __init__(self, root, ckpt, layer, device):
        self.device = device
        logger.info(f"device = {self.device}")

        self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
        # Feature extraction only uses the encoder path; dropping the decoder
        # saves memory.
        self.model.decoder = None
        self.layer = layer  # one-based

    @staticmethod
    def _to_mono(wav):
        """Average a (frames, channels) array into a 1-D mono waveform.

        Arrays that are already 1-D are returned unchanged (same object).
        """
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        return wav

    def read_audio(self, path, ref_len=None):
        """Read ``path`` as a 1-D 16 kHz waveform.

        Multi-channel files are down-mixed to mono by averaging channels. If
        ``ref_len`` is given and differs from the read length by more than
        160 samples (10 ms at 16 kHz), a warning is logged.
        """
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        # sf.read returns (frames, channels) for multi-channel files; the mel
        # front-end and the length bookkeeping below expect a 1-D signal.
        wav = self._to_mono(wav)
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Return the ``self.layer`` hidden states for the audio at ``path``.

        The result is the first batch item of ``extract_features``, truncated
        to ``num_samples // 320`` frames and made contiguous.
        """
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            # 320 samples per output frame -- presumably the mel hop (160) times
            # the encoder's stride-2 downsampling; TODO confirm against Whisper_.
            # Truncating drops frames produced for padding beyond the real audio.
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
