import librosa
import numpy as np
import torch
import torchaudio
from aves import load_feature_extractor
from transformers import (
    ClapModel,
    ClapProcessor,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
)

from wav2vec2_quant import customWav2Vec2ForQuantize, ShuffledWav2Vec2Model
import soundfile as sf


def infer_device():
    """Select the torch device to run models on.

    Preference order is CUDA, then Apple MPS, then CPU. A one-line notice
    naming the chosen backend is printed.

    Returns:
        torch.device: the selected device.
    """
    accelerators = (
        ("cuda", torch.cuda.is_available, "| Using CUDA for computation."),
        ("mps", torch.backends.mps.is_available, "| Using Apple MPS for computation."),
    )
    for device_type, is_available, notice in accelerators:
        # Stop at the first available accelerator; later ones are not probed.
        if is_available():
            print(notice)
            return torch.device(device_type)
    print("| Using CPU for computation.")
    return torch.device("cpu")


def get_waveform(filename, target_sample_rate):
    """Load an audio file as a mono waveform at ``target_sample_rate``.

    torchaudio is tried first. If it raises ``RuntimeError``, the file is
    loaded with librosa at its native sample rate instead.

    Args:
        filename: path to the audio file.
        target_sample_rate: sample rate (Hz) of the returned waveform.

    Returns:
        torch.Tensor of shape ``(1, num_samples)``: the channel-averaged
        waveform, resampled to ``target_sample_rate`` when needed.
    """
    try:
        waveform, sample_rate = torchaudio.load(filename)
    except RuntimeError:
        # Fallback decoder. librosa is imported at module level. With its
        # default mono=True it returns a 1-D array, so a leading channel axis
        # is added to match torchaudio's (channels, samples) layout.
        samples, sample_rate = librosa.load(filename, sr=None)
        waveform = torch.tensor(samples).unsqueeze(0)

    # Average all channels into one, keeping a channel axis of size 1.
    waveform = torch.mean(waveform, dim=0).unsqueeze(0)

    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        waveform = resampler(waveform)

    return waveform


class BioLingual:
    """Audio embeddings from the BioLingual CLAP checkpoint.

    Both processor and model are loaded from ``davidrrobinson/biolingual``.
    Extra positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        checkpoint = "davidrrobinson/biolingual"
        self.processor = ClapProcessor.from_pretrained(checkpoint)
        clap = ClapModel.from_pretrained(checkpoint)

        self.device = infer_device()

        # Inference only: move to the chosen device and switch to eval mode.
        self.model = clap.to(self.device).eval()

    def __call__(self, file_path: str):
        """Embed one audio file.

        The file is loaded as mono at 48 kHz, which is also the rate declared
        to the processor. It is converted to CLAP input features and passed
        through ``get_audio_features``.

        Args:
            file_path: path to the audio file.

        Returns:
            torch.Tensor: the audio embedding with dimension 0 squeezed out.
        """
        clip_rate = 48000
        samples = get_waveform(file_path, clip_rate).squeeze().numpy()

        batch = self.processor(
            audios=samples, return_tensors="pt", sampling_rate=clip_rate
        )
        features = batch["input_features"].to(self.device)

        with torch.no_grad():
            embedding = self.model.get_audio_features(input_features=features)

        return embedding.squeeze(0)


class Aves:
    """Clip-level embeddings from an AVES feature extractor.

    Args:
        aves_model_path: path to the AVES model weights.
        aves_config_path: path to the AVES model config.
        sample_rate: rate (Hz) audio is resampled to before extraction.
            NOTE(review): confirm this matches the rate the checkpoint expects.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        aves_model_path: str,
        aves_config_path: str,
        sample_rate: int = 44100,
        *args,
        **kwargs,
    ):
        # Kept on the instance so inputs can be sent to the same device as the
        # extractor. Unlike infer_device(), MPS is not considered here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.feature_extractor = load_feature_extractor(
            config_path=aves_config_path,
            model_path=aves_model_path,
            device=self.device,
            for_inference=True,
        )
        self.sample_rate = sample_rate

    def __call__(self, file_path):
        """Return the time-averaged last-layer AVES features for one audio file.

        Args:
            file_path: path to the audio file.

        Returns:
            torch.Tensor: the layer ``-1`` features averaged over dimension 1
            (presumably frames; confirm against the aves API), with size-1
            axes squeezed out.
        """
        # Move the input to the extractor's device. If the extractor does not
        # do this itself, a CPU tensor would mismatch a CUDA model.
        waveform = get_waveform(file_path, self.sample_rate).to(self.device)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            features = self.feature_extractor.extract_features(waveform, layers=-1)
        return features.mean(dim=1).squeeze()


class Dolph2Vec:
    """Mean-pooled last-layer embeddings from the Dolph2Vec wav2vec 2.0 model.

    Args:
        dolph2vec_config_path: JSON config for ``Wav2Vec2FeatureExtractor``.
        sample_rate: rate (Hz) audio is resampled to before feature extraction.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self, dolph2vec_config_path: str, sample_rate: int = 44100, *args, **kwargs
    ):
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_json_file(
            dolph2vec_config_path
        )
        encoder = Wav2Vec2Model.from_pretrained(
            "dolphinteam/model-dolph2vec_type-base_data-DolphinChat_version-v0"
        )

        self.device = infer_device()
        self.model = encoder.to(self.device).eval()
        self.sample_rate = sample_rate

    def __call__(self, file_path):
        """Embed one audio file.

        Args:
            file_path: path to the audio file.

        Returns:
            torch.Tensor: the final hidden state averaged over dimension 1,
            with size-1 axes squeezed out.
        """
        mono = get_waveform(file_path, self.sample_rate).squeeze().numpy()

        # No truncation and no fixed-multiple padding: the whole clip is encoded.
        encoded = self.feature_extractor(
            raw_speech=mono,
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
            sampling_rate=self.sample_rate,
            truncation=False,
        )
        input_values = encoded["input_values"].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_values, output_hidden_states=True)

        last_layer = outputs.hidden_states[-1]
        return last_layer.mean(1).squeeze()



class W2VQuantizer:
    """Quantizer outputs from the Dolph2Vec checkpoint via ``customWav2Vec2ForQuantize``.

    Args:
        dolph2vec_config_path: JSON config for ``Wav2Vec2FeatureExtractor``.
        sample_rate: sample rate (Hz) input files must already have. Files are
            not resampled; a mismatch raises ``ValueError``.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, dolph2vec_config_path: str, sample_rate: int = 44100, *args, **kwargs):
        model = customWav2Vec2ForQuantize.from_pretrained(
            "dolphinteam/model-dolph2vec_type-base_data-DolphinChat_version-v0"
        )
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_json_file(
            dolph2vec_config_path
        )

        self.device = infer_device()
        self.model = model.to(self.device).eval()
        self.sample_rate = sample_rate

    def __call__(self, wav_path: str):
        """Quantize one audio file.

        Args:
            wav_path: path to an audio file readable by ``soundfile``.

        Returns:
            tuple: ``(quantized, indices)``.
                ``quantized`` is the pre-projection quantized activity, with
                dimension 0 squeezed and averaged over the next dimension.
                ``indices`` is a NumPy array of codebook indices.

        Raises:
            ValueError: if the file's sample rate differs from ``self.sample_rate``.
        """
        wav, sr = sf.read(wav_path)
        if sr != self.sample_rate:
            raise ValueError(f"Expected {self.sample_rate}, but got {sr} for {wav_path}")

        # soundfile returns multichannel audio as (frames, channels). The
        # feature extractor would treat that 2-D array as a batch of
        # `frames` tiny sequences, so downmix to mono first. This matches the
        # channel averaging done by get_waveform().
        if wav.ndim > 1:
            wav = wav.mean(axis=1)

        inputs = self.feature_extractor(wav, sampling_rate=sr, return_tensors="pt")
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_values=input_values, return_dict=True, return_hidden_activity=True
            )

        indices = outputs.codevectorIdx.cpu().numpy()  # codebook indices
        # Quantized embedding (pre-projection), averaged over the axis after
        # the squeezed batch axis (presumably time; confirm in wav2vec2_quant).
        quantized = outputs.quantizedactivity_preproj.squeeze(0).mean(dim=0)

        return quantized, indices



class ShuffledWav2Vec2Wrapper:
    """Mean-pooled last-layer embeddings from ``ShuffledWav2Vec2Model``.

    The model is loaded from the Dolph2Vec checkpoint.

    Args:
        dolph2vec_config_path: JSON config for ``Wav2Vec2FeatureExtractor``.
        sample_rate: rate (Hz) audio is resampled to before feature extraction.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self, dolph2vec_config_path: str, sample_rate: int = 44100, *args, **kwargs
    ):
        checkpoint = "dolphinteam/model-dolph2vec_type-base_data-DolphinChat_version-v0"
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_json_file(
            dolph2vec_config_path
        )
        shuffled = ShuffledWav2Vec2Model.from_pretrained(checkpoint)

        self.device = infer_device()
        self.model = shuffled.to(self.device).eval()
        self.sample_rate = sample_rate

    def _input_values(self, file_path):
        """Load *file_path* and return the extractor's ``input_values`` on ``self.device``."""
        audio = get_waveform(file_path, self.sample_rate).squeeze().numpy()
        # Whole clip, padded only to the longest item, never truncated.
        settings = {
            "padding": "longest",
            "pad_to_multiple_of": None,
            "return_tensors": "pt",
            "sampling_rate": self.sample_rate,
            "truncation": False,
        }
        values = self.feature_extractor(raw_speech=audio, **settings)["input_values"]
        return values.to(self.device)

    def __call__(self, file_path):
        """Embed one audio file.

        Args:
            file_path: path to the audio file.

        Returns:
            torch.Tensor: the final hidden state averaged over dimension 1,
            with size-1 axes squeezed out.
        """
        values = self._input_values(file_path)
        with torch.no_grad():
            hidden = self.model(values, output_hidden_states=True).hidden_states
        return hidden[-1].mean(1).squeeze()



class Spectrogram:
    """Time-averaged log-mel spectrogram (128 mel bands).

    Args:
        sample_rate: rate (Hz) audio is resampled to before analysis.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, sample_rate: int = 44100, *args, **kwargs):
        self.sample_rate = sample_rate

    def __call__(self, file_path):
        """Return a float tensor of the per-band dB values averaged over frames.

        Power is converted to dB relative to the spectrogram's maximum.
        """
        audio = get_waveform(file_path, self.sample_rate).squeeze().numpy()

        mel_power = librosa.feature.melspectrogram(
            y=audio, sr=self.sample_rate, n_mels=128
        )
        mel_db = librosa.power_to_db(mel_power, ref=np.max)

        # Collapse the frame axis to a single vector per mel band.
        return torch.Tensor(mel_db.mean(axis=1))


class MFCC:
    """Time-averaged MFCC vector.

    Args:
        sample_rate: rate (Hz) audio is resampled to before analysis.
        n_mfcc: number of cepstral coefficients to compute.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, sample_rate: int = 44100, n_mfcc: int = 13, *args, **kwargs):
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc

    def __call__(self, file_path):
        """Return a float tensor of length ``n_mfcc``: each coefficient averaged over frames."""
        audio = get_waveform(file_path, self.sample_rate).squeeze().numpy()
        coefficients = librosa.feature.mfcc(
            y=audio, sr=self.sample_rate, n_mfcc=self.n_mfcc
        )
        return torch.Tensor(coefficients.mean(axis=1))


class SpectralFeatures:
    """Frame-averaged spectral descriptors concatenated into one vector.

    The order is centroid, bandwidth, contrast (all of its bands), rolloff.

    Args:
        sample_rate: rate (Hz) audio is resampled to before analysis.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, sample_rate: int = 44100, *args, **kwargs):
        self.sample_rate = sample_rate

    def __call__(self, file_path):
        """Return a float tensor with each descriptor's rows averaged over frames."""
        audio = get_waveform(file_path, self.sample_rate).squeeze().numpy()

        descriptors = (
            librosa.feature.spectral_centroid,
            librosa.feature.spectral_bandwidth,
            librosa.feature.spectral_contrast,
            librosa.feature.spectral_rolloff,
        )
        per_frame = [describe(y=audio, sr=self.sample_rate) for describe in descriptors]

        feats = np.concatenate([values.mean(axis=1) for values in per_frame])
        return torch.Tensor(feats)
