import numpy as np
import torch.nn as nn
import torch
from torchaudio.transforms import MelSpectrogram
from torchvision import transforms

class MelPreprocessor(nn.Module):
    """Turn raw waveforms into cropped, optionally z-scored log-mel spectrograms.

    :meth:`forward` reads its settings from ``cfg``; the attributes it
    accesses are ``fs``, ``n_fft``, ``hop_length``, ``n_mels`` and
    ``normalizing``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def get_mel(self, x, fs, n_fft, hop_length, n_mels, normalizing=True,
                *, crop_size=64, **kwargs):
        """Compute a cropped log mel spectrogram of ``x``.

        Args:
            x: Waveform tensor. The resulting spectrogram is cropped and
                normalized over its last two axes (mel bins, frames); any
                leading axes are preserved. Presumably (channels, samples) --
                TODO confirm with callers.
            fs: Sample rate in Hz. The mel filterbank spans 0 .. fs / 2.
            n_fft: FFT size.
            hop_length: Hop between successive frames, in samples.
            n_mels: Number of mel bins computed before cropping.
            normalizing: If True, z-score every spectrogram independently
                over its (mel, frame) axes.
            crop_size: Keep at most the first ``crop_size`` mel bins and the
                first ``crop_size`` frames. Shorter outputs are not padded.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape
            ``(..., min(n_mels, crop_size), min(frames, crop_size))``.
        """
        # power=1 yields the mel-filtered *magnitude*. It is squared below,
        # so the result is log(mel(|X|) ** 2), which is not the same as
        # log(mel(|X| ** 2)) from a power=2 spectrogram.
        mel_spec = MelSpectrogram(
            sample_rate=fs,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=0,
            f_max=fs / 2,
            power=1,
        ).to(x.device)  # window/filterbank buffers are created on the CPU
        magnitude = mel_spec(x)
        epsilon = 1e-10  # floor for log() and for the std clamp
        magnitude = torch.log(magnitude ** 2 + epsilon)

        # Crop to at most crop_size x crop_size (mel bins x frames).
        magnitude = magnitude[..., :crop_size, :crop_size]

        # Z-score each spectrogram over its mel and frame axes.
        if normalizing:
            mean = magnitude.mean(dim=(-2, -1), keepdim=True)
            # torch.std defaults to the unbiased (N - 1) estimator.
            std = magnitude.std(dim=(-2, -1), keepdim=True)
            std = std.clamp(min=epsilon)  # avoid division by zero
            # Broadcasting keeps the computation on x's device; no host sync.
            magnitude = (magnitude - mean) / std

        return magnitude

    def forward(self, x):
        """Apply :meth:`get_mel` using the settings stored on ``self.cfg``."""
        return self.get_mel(
            x,
            fs=self.cfg.fs,
            n_fft=self.cfg.n_fft,
            hop_length=self.cfg.hop_length,
            n_mels=self.cfg.n_mels,
            normalizing=self.cfg.normalizing,
        )