from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions


class AudioEncoder_(AudioEncoder):
    """Whisper ``AudioEncoder`` with access to intermediate block outputs.

    The constructor is inherited unchanged from ``AudioEncoder``; this
    subclass only adds :meth:`extract_feature`.
    """

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
        """Run the convolutional stem and the first ``target_layer`` blocks.

        Parameters
        ----------
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        target_layer : Optional[int]
            Number of leading entries of ``self.blocks`` to apply. It is used
            as a slice bound (``self.blocks[:target_layer]``), so ``None``
            means every block and ``0`` means none.

        Returns
        -------
        torch.Tensor
            Output of the last applied block (or the position-encoded stem
            output when no block is applied). No further layers of the
            encoder are applied by this method.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # Swap the last two axes so the positional table, indexed by its
        # first axis, lines up with axis 1 of ``x``.
        x = x.permute(0, 2, 1)

        length_x = x.shape[1]
        positional_embedding = self.positional_embedding
        if length_x > positional_embedding.shape[0]:
            # The input is longer than the stored table. Build a longer table
            # for this call only: overwriting the registered buffer would
            # change the module's state_dict shape (breaking later checkpoint
            # loads) and would silently swap the stored table for all
            # subsequent shorter inputs.
            positional_embedding = sinusoids(
                length_x, positional_embedding.shape[1]
            ).to(x.device)
        # Cast back to the activation dtype in case the table's dtype differs.
        x = (x + positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        for block in self.blocks[:target_layer]:
            x = block(x)

        return x


class Whisper_(Whisper):
    """``Whisper`` model whose encoder is an ``AudioEncoder_`` instance."""

    def __init__(self, dims: ModelDimensions):
        """Build the base model, then swap in the feature-extracting encoder.

        Parameters
        ----------
        dims : ModelDimensions
            Model hyper-parameters, forwarded to ``Whisper.__init__``.
        """
        super().__init__(dims)

        # The encoder created by the base constructor is discarded and
        # replaced by a newly constructed AudioEncoder_ built from the same
        # audio-related dimensions.
        cfg = self.dims
        encoder_args = (
            cfg.n_mels,
            cfg.n_audio_ctx,
            cfg.n_audio_state,
            cfg.n_audio_head,
            cfg.n_audio_layer,
        )
        self.encoder = AudioEncoder_(*encoder_args)

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        """Return encoder features for ``mel``.

        Thin wrapper that forwards both arguments to
        ``self.encoder.extract_feature`` and returns its result.
        """
        encoder = self.encoder
        return encoder.extract_feature(mel, target_layer)
