from transformers import Wav2Vec2Model, Wav2Vec2Processor
from transformers.modeling_outputs import BaseModelOutput
import torch
import torch.nn.functional as F
import sys
import os
from transformers.models.wav2vec2 import Wav2Vec2Config

sys.path.append(os.curdir)
from mmhug.registry import HF_MODELS


def linear_interpolation(features, seq_len):
    """Linearly resample axis 1 of a 3-D tensor to ``seq_len`` steps.

    The computation is done in float32 with ``align_corners=True`` (endpoints
    of the input and output grids are aligned), and the result is cast back
    to the dtype of ``features``.

    Args:
        features: 3-D tensor laid out as (batch, length, channels); axis 1 is
            the one that gets resampled.
        seq_len: number of output steps along axis 1.

    Returns:
        Tensor of shape (batch, seq_len, channels) with ``features.dtype``.
    """
    source_dtype = features.dtype
    # F.interpolate resamples the last axis, so move it there first.
    channels_first = features.transpose(1, 2).float()
    resampled = F.interpolate(
        channels_first, size=seq_len, mode="linear", align_corners=True
    )
    return resampled.to(source_dtype).transpose(1, 2)


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor of shape (batch,) holding each sequence's length;
            it is cast to ``torch.long``.
        max_len: width of the mask. Defaults to ``lengths.max()``, or 0 when
            ``lengths`` is empty.

    Returns:
        Bool tensor of shape (batch, max_len) on ``lengths.device`` where
        ``mask[b, t]`` is True iff ``t < lengths[b]``.
    """
    lengths = lengths.to(torch.long)
    if max_len is None:
        # torch.max raises on an empty tensor, so an empty batch gets width 0.
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0

    # Create the index row directly on the target device (no host->device
    # copy) and rely on broadcasting instead of explicit expand() calls.
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)


# Borrowed from https://github.com/MeiGen-AI/MultiTalk/blob/main/src/audio_analysis/wav2vec2.py#L9
@HF_MODELS.register_module(force=True)
class Wav2Vec2InterpModel(Wav2Vec2Model):
    """
    Extension of Hugging Face's Wav2Vec2Model that inserts a linear-interpolation
    step to align the CNN feature sequence to a target `seq_len` (e.g., video
    latent length) before the Transformer encoder runs.

    Inherits:
        transformers.Wav2Vec2Model

    Overrides / adds:
        - forward(): CNN feature extraction + interpolation + projection + encoder.
        - feature_extract(): CNN feature extraction + interpolation only.

    Constructor args:
        config: Wav2Vec2 configuration forwarded to the parent constructor.
        norm_out: if True, a LayerNorm over `config.output_hidden_size` is
            applied to the final hidden states (after the adapter, if any).

    Note:
        `config.mask_time_prob` is forced to 0.0, disabling random time masking
        driven by that config value.
    """

    def __init__(self, config: Wav2Vec2Config, norm_out: bool = False):
        super().__init__(config)
        if norm_out:
            self.norm_out = torch.nn.LayerNorm(
                config.output_hidden_size, eps=config.layer_norm_eps
            )
        else:
            self.norm_out = None
        self.config.mask_time_prob = 0.0

    def forward(
        self,
        input_values,
        seq_len: int,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Extract features, interpolate to `seq_len`, then run the Transformer encoder.

        Args:
            input_values (torch.FloatTensor):
                Raw audio waveform batch of shape (batch_size, num_samples), as
                produced by the HF Wav2Vec2 processor / feature extractor.
            seq_len (int):
                Target number of time-steps after interpolation (e.g., matching video frames).
            attention_mask (torch.Tensor, optional):
                (batch_size, num_samples) mask over the input waveform; nonzero
                entries mark real audio. It is mapped onto the `seq_len`
                interpolated frames (see `_get_interpolated_attention_mask`).
            mask_time_indices (torch.BoolTensor, optional):
                Indices where the model should apply masking (for pretraining scenarios).
            output_attentions (bool, optional):
                Whether to return self-attention weights (default from config).
            output_hidden_states (bool, optional):
                Whether to return hidden states of all layers (default from config).
            return_dict (bool, optional):
                Whether to return a BaseModelOutput or a plain tuple (default from config).

        Returns:
            BaseModelOutput or tuple:
                - last_hidden_state: encoder output after the optional adapter and
                  optional `norm_out`; (batch_size, seq_len, hidden_size) when no
                  adapter is configured (an adapter may change length/width).
                - hidden_states (if asked): encoder hidden states (before the
                  adapter / `norm_out`), each (batch_size, seq_len, hidden_size).
                - attentions (if asked): encoder self-attention weights.

        Notes:
            - This overrides the HF `forward()` to inject `linear_interpolation`
              between CNN feature extraction and feature projection.
        """
        # Resolve per-call flags against the config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # 1. CNN feature extraction, resampled to `seq_len` frames.
        extract_features = self.feature_extract(input_values, seq_len=seq_len)

        # 2. Map the waveform-level mask onto the interpolated frames, if provided.
        if attention_mask is not None:
            attention_mask = self._get_interpolated_attention_mask(
                attention_mask,
                num_samples=input_values.shape[-1],
                seq_len=extract_features.shape[1],
            )

        # 3. Projection + optional masking
        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states,
            mask_time_indices=mask_time_indices,
            attention_mask=attention_mask,
        )

        # 4. Transformer encoder
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        # 5. Optional adapter, then optional output LayerNorm.
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)
        if self.norm_out is not None:
            hidden_states = self.norm_out(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_interpolated_attention_mask(self, attention_mask, num_samples, seq_len):
        """
        Map a waveform-level attention mask onto `seq_len` interpolated frames.

        The parent's `_get_feature_vector_attention_mask` computes valid lengths
        in CNN-frame units and assumes the mask width equals the CNN output
        length. After interpolation that no longer holds. With a smaller
        `seq_len` it indexes out of range; with a larger one it masks out valid
        frames.

        Instead, interpolated token j is placed at CNN-frame position
        j * (T - 1) / (seq_len - 1), where T is the CNN output length of the
        padded input. This is the mapping used by
        `F.interpolate(..., align_corners=True)`. The token is kept iff that
        position lies within the sample's valid CNN frames, so it never blends
        in padded frames.

        Args:
            attention_mask: (batch_size, num_samples) waveform mask; nonzero
                entries mark real audio.
            num_samples: padded waveform length that was fed to the CNN.
            seq_len: number of interpolated frames.

        Returns:
            torch.BoolTensor of shape (batch_size, seq_len).
        """
        device = attention_mask.device
        # Valid CNN frames per sample, and total CNN frames of the padded input,
        # via the parent's conv output-length formula. The adapter is excluded
        # because here it runs after the encoder.
        valid_samples = attention_mask.to(torch.long).sum(dim=-1)
        valid_frames = self._get_feat_extract_output_lengths(
            valid_samples, add_adapter=False
        ).to(torch.long)
        total_frames = int(
            self._get_feat_extract_output_lengths(
                torch.tensor(num_samples, device=device), add_adapter=False
            )
        )

        token_ids = torch.arange(seq_len, device=device)
        # Integer form of  j * (T - 1) / (S - 1) <= L - 1,  which avoids float
        # rounding at the boundary. When S == 1 the single token maps to frame 0.
        denom = max(seq_len - 1, 1)
        lhs = token_ids.unsqueeze(0) * max(total_frames - 1, 0)
        rhs = (valid_frames - 1).unsqueeze(1) * denom
        return lhs <= rhs

    def feature_extract(
        self,
        input_values,
        seq_len: int,
    ):
        """
        Stand-alone feature-extraction + interpolation. Does *not* run the Transformer encoder.

        Args:
            input_values (torch.FloatTensor): raw audio batch (batch, num_samples).
            seq_len (int): target token length after interpolation.

        Returns:
            torch.FloatTensor: interpolated CNN features (batch, seq_len, feature_dim),
            before feature projection.

        Use-case:
            - When you only need to extract and align audio features without encoding.
        """
        # CNN feature extraction; output is channels-first, so move time to axis 1.
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        # Linear interpolation along the time axis to `seq_len` frames.
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
        return extract_features


if __name__ == "__main__":
    # Demo: embed a wav file with features resampled to the clip's video frame
    # count, then print the resulting shapes.
    import librosa

    fps = 25
    # Fall back to CPU so the demo still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): add_adapter=True on the base-960h checkpoint presumably
    # leaves the adapter freshly initialised (no adapter weights) — confirm.
    audio_encoder = HF_MODELS.build(
        dict(
            type="Wav2Vec2InterpModel",
            from_pretrained=dict(
                pretrained_model_name_or_path="checkpoints/wav2vec2-base-960h",
                local_files_only=True,
                add_adapter=True,
                adapter_stride=1,
                num_adapter_layers=3,
            ),
        )
    ).to(device)
    # Inference demo: make sure train-only behaviour (dropout etc.) is off.
    audio_encoder.eval()
    processor = Wav2Vec2Processor.from_pretrained(
        "checkpoints/wav2vec2-base-960h", local_files_only=True
    )
    audio_demo_path = "demo_assets/21_ayana_0_10_10_1.wav"
    audio, sr = librosa.load(audio_demo_path, sr=16000)
    # Number of video frames spanned by the clip; features are resampled to it.
    video_length = len(audio) / sr * fps
    audio_feature = processor(
        audio, sampling_rate=sr, return_tensors="pt"
    ).input_values[0]
    audio_feature = audio_feature.float().to(device=device)[None]
    # No gradients needed: skip building an autograd graph.
    with torch.no_grad():
        embeddings = audio_encoder(
            audio_feature, seq_len=int(video_length), output_hidden_states=True
        )
    # 1, n, 768
    last_hidden_state = embeddings.last_hidden_state
    # 13 * [1, n, 768]
    hidden_states = embeddings.hidden_states
    print(last_hidden_state.shape)
    print(len(hidden_states))
    print([h.shape for h in hidden_states])
