# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchaudio


# Default hyper-parameters for SpeechEmbedder / SpkrEmbedder.
#   num_mels    -- mel bands in the input spectrogram (LSTM input size)
#   n_fft       -- STFT size used to build the mel front end
#   emb_dim     -- size of the output embedding
#   lstm_hidden -- LSTM hidden size (projection input size)
#   lstm_layers -- number of stacked LSTM layers
#   window      -- length, in mel frames, of each window fed to the LSTM
#   stride      -- hop, in mel frames, between consecutive windows
EMBEDDER_PARAMS = dict(
    num_mels=40,
    n_fft=512,
    emb_dim=256,
    lstm_hidden=768,
    lstm_layers=3,
    window=80,
    stride=40,
)


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Typically called with ``requires_grad=False`` to freeze networks and
    avoid unnecessary gradient computation.

    Parameters:
        nets (nn.Module | list | tuple) -- a network, or a list/tuple of
            networks; ``None`` entries are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    # Accept tuples as well as lists; previously a tuple was wrapped as a
    # single "network" and crashed on ``.parameters()``.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad


class LinearNorm(nn.Module):
    """Affine projection from the LSTM hidden size to the embedding size.

    Despite the name, no normalisation is applied here: this wraps a single
    ``nn.Linear(hp["lstm_hidden"], hp["emb_dim"])``. The attribute name
    ``linear_layer`` is kept so saved state-dict keys keep matching.
    """

    def __init__(self, hp):
        super().__init__()
        in_features = hp["lstm_hidden"]
        out_features = hp["emb_dim"]
        self.linear_layer = nn.Linear(in_features, out_features)

    def forward(self, x):
        projected = self.linear_layer(x)
        return projected


class SpeechEmbedder(nn.Module):
    """LSTM model mapping a mel spectrogram to a single unit-norm embedding.

    The spectrogram is cut into overlapping windows of ``hp["window"]``
    frames (hop ``hp["stride"]``). Each window is run through a stacked LSTM,
    the last hidden state is projected to ``hp["emb_dim"]`` and L2-normalised,
    and the per-window embeddings are averaged and L2-normalised again.
    """

    def __init__(self, hp):
        super(SpeechEmbedder, self).__init__()
        self.lstm = nn.LSTM(hp["num_mels"],
                            hp["lstm_hidden"],
                            num_layers=hp["lstm_layers"],
                            batch_first=True)
        self.proj = LinearNorm(hp)
        self.hp = hp

    def forward(self, mel):
        """Embed a single ``(num_mels, T)`` mel spectrogram.

        ``T`` must be at least ``hp["window"]`` (``Tensor.unfold`` raises
        otherwise). Returns an ``(emb_dim,)`` tensor with unit L2 norm, or
        all zeros if the averaged embedding is exactly zero.
        """
        # (num_mels, T) -> (num_mels, T', window)
        mels = mel.unfold(1, self.hp["window"], self.hp["stride"])
        mels = mels.permute(1, 2, 0)  # (T', window, num_mels)
        x, _ = self.lstm(mels)  # (T', window, lstm_hidden)
        x = x[:, -1, :]  # (T', lstm_hidden), use last frame only
        x = self.proj(x)  # (T', emb_dim)
        # Per-window L2 normalisation. F.normalize clamps the norm to a tiny
        # epsilon, so an all-zero window yields zeros instead of NaN (0 / 0),
        # which would otherwise poison the mean below.
        x = F.normalize(x, p=2, dim=1)  # (T', emb_dim)

        x = x.mean(dim=0)  # (emb_dim,)
        norm = x.norm(p=2)
        if norm != 0:
            x = x / norm
        return x


class SpkrEmbedder(nn.Module):
    """Compute one speaker embedding from one or more waveforms.

    Wraps a frozen, pretrained ``SpeechEmbedder`` loaded from
    ``embedder_path`` together with the log-mel front end it consumes.
    Waveforms recorded at ``rate`` are resampled to ``RATE`` first.
    """

    RATE = 16000  # sample rate (Hz) the mel filterbank and embedder run at
    # Minimum waveform length (samples at RATE) when ``pad=True``; presumably
    # chosen so short clips still yield at least one embedder window with the
    # default hop/window settings -- TODO confirm.
    PAD_LENGTH = 14000

    def __init__(
        self,
        embedder_path,
        embedder_params=EMBEDDER_PARAMS,
        rate=16000,
        hop_length=160,
        win_length=400,
        pad=False,
    ):
        """
        Parameters:
            embedder_path (str) -- path to a saved ``SpeechEmbedder`` state dict
            embedder_params (dict) -- embedder hyper-parameters
                (see EMBEDDER_PARAMS)
            rate (int) -- sample rate of the waveforms given to ``forward``;
                they are resampled to ``RATE`` when it differs
            hop_length (int) -- STFT hop length in samples
            win_length (int) -- STFT window length in samples
            pad (bool) -- right-pad waveforms shorter than PAD_LENGTH
                with zeros
        """
        super(SpkrEmbedder, self).__init__()
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        embedder_pt = torch.load(embedder_path, map_location="cpu")
        self.embedder = SpeechEmbedder(embedder_params)
        self.embedder.load_state_dict(embedder_pt)
        self.embedder.eval()
        set_requires_grad(self.embedder, requires_grad=False)
        self.embedder_params = embedder_params

        # (num_mels, n_fft // 2 + 1) mel filterbank, registered as a buffer so
        # it follows the module across devices.
        mel_basis = librosa.filters.mel(
            sr=self.RATE,
            n_fft=embedder_params["n_fft"],
            n_mels=embedder_params["num_mels"],
        )
        self.register_buffer('mel_basis', torch.from_numpy(mel_basis))

        self.resample = None
        if rate != self.RATE:
            self.resample = torchaudio.transforms.Resample(rate, self.RATE)
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad = pad

    def get_mel(self, y):
        """Return the log10 mel power spectrogram of waveform ``y``.

        ``y`` is assumed to already be sampled at ``RATE``. It is a 1-D
        ``(samples,)`` or 2-D ``(batch, samples)`` tensor, and the result is
        ``(num_mels, frames)`` or ``(batch, num_mels, frames)`` respectively.
        """
        if self.pad and y.shape[-1] < self.PAD_LENGTH:
            y = F.pad(y, (0, self.PAD_LENGTH - y.shape[-1]))

        window = torch.hann_window(self.win_length).to(y)
        # return_complex must be passed explicitly on recent PyTorch; the
        # legacy real (..., 2) output is no longer accepted for real input.
        spec = torch.stft(y, n_fft=self.embedder_params["n_fft"],
                          hop_length=self.hop_length,
                          win_length=self.win_length,
                          window=window,
                          return_complex=True)
        magnitudes = spec.abs() ** 2  # power spectrum |X|^2
        mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
        return mel

    def forward(self, inputs):
        """Return a single L2-normalised embedding averaged over ``inputs``.

        Parameters:
            inputs (iterable of Tensor) -- waveforms sampled at ``rate``;
                each is ``(samples,)`` or ``(1, samples)``
        Returns:
            Tensor of shape ``(emb_dim,)``
        """
        dvecs = []
        for wav in inputs:
            if self.resample is not None:
                # The mel filterbank and embedder expect RATE; without this
                # the ``rate`` constructor argument had no effect.
                wav = self.resample(wav)
            mel = self.get_mel(wav)
            if mel.dim() == 3:
                mel = mel.squeeze(0)
            dvecs.append(self.embedder(mel))
        dvecs = torch.stack(dvecs)

        dvec = torch.mean(dvecs, dim=0)
        # Guard against 0 / 0, matching SpeechEmbedder's final normalisation.
        norm = torch.norm(dvec)
        if norm != 0:
            dvec = dvec / norm

        return dvec
