import torch
import librosa
import os

from transformers import WhisperFeatureExtractor
from .glm4.speech_tokenizer.modeling_whisper import WhisperVQEncoder
from .glm4_utils import extract_speech_token
from torch import nn


class Glm4Tokenizer(nn.Module):
    """Turn speech audio into GLM-4 discrete speech tokens.

    Wraps a ``WhisperVQEncoder`` and its matching ``WhisperFeatureExtractor``,
    both loaded from the same pretrained checkpoint, and delegates the actual
    token extraction to ``extract_speech_token``.
    """

    def __init__(self, tokenizer_path):
        """Load the VQ encoder and the feature extractor.

        Args:
            tokenizer_path: Checkpoint location passed to both
                ``from_pretrained`` calls.
        """
        super().__init__()
        # Inference only, so the encoder is switched to eval mode right away.
        # Assigning it on this nn.Module registers it as a submodule (assuming
        # WhisperVQEncoder is an nn.Module — TODO confirm), so moving the
        # tokenizer with ``.to(...)`` moves the encoder too.
        self.whisper_model = WhisperVQEncoder.from_pretrained(tokenizer_path).eval()
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(tokenizer_path)

    def tokenize(self, speech=None, audio_path=None, sr=16000):
        """Return the speech tokens of one utterance as a ``(1, T)`` tensor.

        ``audio_path`` takes precedence when it is truthy; otherwise ``speech``
        is used.

        Args:
            speech: In-memory waveform: a ``torch.Tensor``, a NumPy array, or a
                (possibly nested) list of samples. A 1-D waveform gets a
                leading axis; higher-rank input is passed through unchanged.
            audio_path: Audio file to load with librosa. It is always loaded
                at 16 kHz, and ``sr`` is ignored in this case.
            sr: Sample rate of ``speech``. It must be truthy when ``speech``
                is used.

        Returns:
            A tensor built from the first (only) entry returned by
            ``extract_speech_token``, with a leading batch axis of size 1.

        Raises:
            ValueError: If no audio source is given, or ``sr`` is falsy while
                ``speech`` is used.
        """
        if audio_path:
            # Resampled to 16 kHz on load; the returned ``sr`` reflects that.
            audio, sr = librosa.load(audio_path, sr=16000)
            waveform = torch.as_tensor(audio)
        else:
            # Explicit checks instead of ``assert`` so they survive ``python -O``.
            if speech is None:
                raise ValueError("either `speech` or `audio_path` must be provided")
            if not sr:
                raise ValueError("`sr` is required when passing `speech`")
            # as_tensor returns tensors as-is and converts lists / NumPy arrays,
            # so every input type goes through the same shape handling below.
            waveform = torch.as_tensor(speech)

        # Add a leading axis to 1-D waveforms only. Input that already has a
        # leading dimension is forwarded untouched.
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # extract_speech_token is given a batch of one (waveform, rate) pair
        # and we keep only its result for that single utterance.
        audio_tokens = extract_speech_token(
            self.whisper_model, self.feature_extractor, [(waveform, sr)]
        )[0]
        return torch.tensor(audio_tokens).unsqueeze(0)
