# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp

import torch
import torch.nn.functional as F
import torchaudio
from einops import rearrange
from librosa import filters
from torch import nn


class ChromaExtractor(nn.Module):
    """Compute per-frame chroma features from a waveform, optionally one-hot quantized.

    A power spectrogram is projected onto a librosa chroma filterbank, every
    frame is normalized across the chroma bins and, when ``argmax`` is set,
    replaced by a one-hot vector marking its strongest bin.

    Args:
        sample_rate (int): Sample rate handed to the chroma filterbank.
        n_chroma (int): Number of chroma bins.
        radix2_exp (int): Exponent of the fallback window length, used when
            ``winlen`` is not given (e.g. 12 -> 2^12 samples).
        nfft (int, optional): FFT size. Falls back to the window length.
        winlen (int, optional): Window length. Falls back to ``2**radix2_exp``.
        winhop (int, optional): Hop size. Falls back to a quarter of the window length.
        argmax (bool, optional): Emit one-hot frames instead of normalized values.
            Defaults to False.
        norm (float, optional): Order ``p`` of the norm used to normalize each frame.
            Defaults to inf.
    """

    def __init__(
        self,
        sample_rate: int,
        n_chroma: int = 12,
        radix2_exp: int = 12,
        nfft: tp.Optional[int] = None,
        winlen: tp.Optional[int] = None,
        winhop: tp.Optional[int] = None,
        argmax: bool = False,
        norm: float = torch.inf,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.argmax = argmax
        self.norm = norm

        # Falsy size arguments (None or 0) fall back to derived defaults.
        self.winlen = winlen if winlen else 2**radix2_exp
        self.nfft = nfft if nfft else self.winlen
        self.winhop = winhop if winhop else self.winlen // 4

        # The filterbank maps FFT bins to chroma bins. It is registered as a
        # non-persistent buffer, so it follows the module across devices but
        # is left out of the state dict.
        chroma_filters = filters.chroma(
            sr=sample_rate, n_fft=self.nfft, n_chroma=self.n_chroma, tuning=0
        )
        self.register_buffer(
            "fbanks", torch.from_numpy(chroma_filters), persistent=False
        )

        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            hop_length=self.winhop,
            win_length=self.winlen,
            pad=0,
            center=True,
            power=2,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features from ``wav``.

        Args:
            wav (torch.Tensor): Waveform with time on the last dimension.
                Presumably shaped [B, 1, T]: dimension 1 is squeezed after the
                spectrogram and the result must be 3-D -- confirm with callers.

        Returns:
            torch.Tensor: Chroma arranged as [batch, frames, chroma bins];
                one-hot along the last dimension when ``argmax`` is enabled.
        """
        num_samples = wav.shape[-1]
        if num_samples < self.nfft:
            # A waveform nullified by conditioner dropout can be shorter than
            # one FFT frame; zero-pad it symmetrically up to nfft samples,
            # putting the extra sample on the right when the gap is odd.
            missing = self.nfft - num_samples
            left = missing // 2
            right = missing - left
            wav = F.pad(wav, (left, right), mode="constant", value=0)
            assert (
                wav.shape[-1] == self.nfft
            ), f"expected len {self.nfft} but got {wav.shape[-1]}"

        power_spec = self.spec(wav)
        power_spec = power_spec.squeeze(1)
        chroma = torch.einsum("cf,...ft->...ct", self.fbanks, power_spec)
        # Normalize across the chroma-bin axis, independently for every frame.
        chroma = F.normalize(chroma, p=self.norm, dim=-2, eps=1e-6)
        chroma = rearrange(chroma, "b d t -> b t d")

        if not self.argmax:
            return chroma

        # Quantize in place: keep only the strongest bin of each frame.
        strongest = chroma.argmax(-1, keepdim=True)
        chroma.zero_()
        chroma.scatter_(dim=-1, index=strongest, value=1)
        return chroma
