# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from torchaudio.functional.filtering import highpass_biquad, treble_biquad


def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Compute the K-weighted loudness of every gating block of a waveform.

    This is a simplified loudness measure: the signal is K-weighted, cut into
    overlapping blocks of 0.4 seconds (75 % overlap), and the channel-weighted
    mean energy of each block is converted to a log scale. No gating of quiet
    blocks is applied; a small epsilon is added inside the logarithm instead,
    which keeps the result finite for silent blocks.

    Args:
        waveform (torch.Tensor): audio waveform of dimension `(..., channels, time)`
        sample_rate (int): sampling rate of the waveform
    Returns:
        torch.Tensor: loudness of each block, of dimension `(..., blocks)`
    Raises:
        ValueError: if the waveform has more than 5 channels.
    """
    if waveform.size(-2) > 5:
        raise ValueError("Only up to 5 channels are supported.")

    # Block layout: 0.4 s windows that overlap by 75 %.
    gate_seconds = 0.4
    overlap_ratio = 0.75
    block_len = int(round(gate_seconds * sample_rate))
    hop = int(round(block_len * (1 - overlap_ratio)))

    # K-weighting: a treble shelf followed by a high-pass filter.
    weighted = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
    weighted = highpass_biquad(weighted, sample_rate, 38.0, 0.5)

    # Mean energy of each block: (..., channels, time) -> (..., channels, blocks).
    block_energy = torch.square(weighted).unfold(-1, block_len, hop)
    block_energy = torch.mean(block_energy, dim=-1)

    # Per-channel gains; only as many entries as there are channels are kept.
    channel_gains = torch.tensor(
        [1.0, 1.0, 1.0, 1.41, 1.41], dtype=weighted.dtype, device=weighted.device
    )
    channel_gains = channel_gains[: block_energy.size(-2)]
    summed_energy = torch.sum(channel_gains.unsqueeze(-1) * block_energy, dim=-2)

    # The epsilon trades a little precision in very quiet blocks for a finite log.
    tiny = torch.finfo(torch.float32).eps
    return -0.691 + 10 * torch.log10(summed_energy + tiny)


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Extract (possibly overlapping) frames along the last dimension.

    Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    The input is zero-padded on the right so that `F = ceil(T / stride)`,
    i.e. every input sample is covered by at least one frame.
    see https://github.com/pytorch/pytorch/issues/60466

    The result is built with `as_strided` on the padded tensor, so frames
    share memory whenever `stride < kernel_size`.

    Args:
        a (torch.Tensor): input of size [*OT, T].
        kernel_size (int): number of samples per frame (K).
        stride (int): hop between the starts of two consecutive frames.
    Returns:
        torch.Tensor: frames of size [*OT, F, K].
    Raises:
        ValueError: if `stride` is not strictly positive.
    """
    if stride <= 0:
        # Otherwise this surfaces as an opaque ZeroDivisionError / as_strided error.
        raise ValueError(f"stride must be a positive integer, got {stride}")
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    # F.pad may hand back a tensor with the input's strides when nothing is
    # padded; force a dense layout so the stride arithmetic below is valid.
    a = F.pad(a, (0, tgt_length - length)).contiguous()
    strides = list(a.stride())
    # Replace the time stride by (hop between frames, step inside a frame).
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


class FLoudnessRatio(nn.Module):
    """FSNR loss: loudness ratio between residual and reference, per frequency band.

    Both signals are split into frequency bands, the loudness of the residual
    (`ref_sig - out_sig`) and of the reference is measured in each band, and
    the per-band loudness differences are combined with softmax weights taken
    over the band axis, so the bands with the largest ratio dominate the loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Stored on the module; not used by `forward`.
        overlap (float): Stored on the module; not used by `forward`.
        epsilon (float): Stored on the module; not used by `forward`.
        n_bands (int): number of mel scale bands that we include. With 0 (the default)
            no band splitting is done and the full-band signal is treated as a single band.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
        n_bands: int = 0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon
        if n_bands == 0:
            self.filter = None
        else:
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Return the softmax-weighted sum of per-band loudness ratios.

        Args:
            out_sig (torch.Tensor): estimated signal of size [B, C, T].
            ref_sig (torch.Tensor): reference signal of size [B, C, T].
        Returns:
            torch.Tensor: scalar loss.
        """
        B, _, _ = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        if self.filter is None:
            # No band splitting requested: add a leading axis of size one so the
            # full-band signal goes through the same code path as a single band.
            bands_ref = ref_sig.unsqueeze(0)
            bands_out = out_sig.unsqueeze(0)
        else:
            # Band splitting adds a leading band axis in front of [B, C, T].
            bands_ref = self.filter(ref_sig)
            bands_out = self.filter(out_sig)
        l_noise = self.loudness(bands_ref - bands_out)  # residual loudness
        l_ref = self.loudness(bands_ref)  # reference loudness
        # One row per band, one column per batch item.
        l_ratio = (l_noise - l_ref).view(-1, B)
        # Softmax over bands: the loudest residual bands get the largest weight.
        loss = F.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()


class TLoudnessRatio(nn.Module):
    """TSNR loss: loudness ratio between residual and reference, per time frame.

    Both signals are cut into overlapping frames of `segment` seconds. The
    loudness of the residual (`ref_sig - out_sig`) and of the reference is
    measured on each frame, and the loudness differences are combined with
    softmax weights, so the frames with the largest ratio dominate the loss.

    Input should be [B, C, T] with C == 1, output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float): Evaluate on chunks of that many seconds.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        segment: float = 0.5,
        overlap: float = 0.5,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Return the softmax-weighted sum of per-frame loudness ratios.

        Args:
            out_sig (torch.Tensor): estimated signal of size [B, 1, T].
            ref_sig (torch.Tensor): reference signal of size [B, 1, T].
        Returns:
            torch.Tensor: scalar loss.
        """
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        assert C == 1

        frame = int(self.segment * self.sample_rate)
        stride = int(frame * (1 - self.overlap))
        # `_unfold` returns overlapping strided frames of size [B, 1, F, frame].
        # Merging the batch and frame axes is not expressible as a view when
        # B > 1 and frames overlap, so use `reshape`, which copies when needed.
        gt = _unfold(ref_sig, frame, stride).reshape(-1, 1, frame)
        est = _unfold(out_sig, frame, stride).reshape(-1, 1, frame)
        l_noise = self.loudness(gt - est)  # watermark
        l_ref = self.loudness(gt)  # ground truth
        # NOTE(review): the flattened frames appear to be ordered (batch, frame),
        # so for B > 1 the trailing axis of size B may not index batch items.
        # Kept as is to stay consistent with TFLoudnessRatio; confirm intent.
        l_ratio = (l_noise - l_ref).view(-1, B)
        loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()


class TFLoudnessRatio(nn.Module):
    """TF-loudness ratio loss.

    The signals are split into frequency bands, every band is cut into
    overlapping time frames, and `basic_loudness` is evaluated on the residual
    and on the reference for each (band, frame) tile. The loudness differences
    are weighted by a temperature-scaled softmax and averaged.

    Input should be [B, C, T] with C == 1, output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float): Evaluate on chunks of that many seconds.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        n_bands (int): number of bands to separate. Must be non-zero for `forward` to run.
        clip_min (float): Stored on the module; not used by `forward`.
        temperature (float): temperature of the softmax step
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        segment: float = 0.5,
        overlap: float = 0.5,
        n_bands: int = 16,
        clip_min: float = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.clip_min = clip_min
        self.temperature = temperature
        if n_bands != 0:
            self.n_bands = n_bands
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)
        else:
            # Without a band filter `forward` refuses to run (see its assertions).
            self.filter = None

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Return the softmax-weighted mean of per-tile loudness ratios.

        Args:
            out_sig (torch.Tensor): estimated signal of size [B, 1, T].
            ref_sig (torch.Tensor): reference signal of size [B, 1, T].
        Returns:
            torch.Tensor: scalar loss.
        """
        B, C, T = ref_sig.shape

        assert ref_sig.shape == out_sig.shape
        assert C == 1
        assert self.filter is not None

        # Split into bands and fold the band axis into the leading axis.
        n_rows = B * self.n_bands
        ref_bands = self.filter(ref_sig).view(n_rows, 1, -1)
        out_bands = self.filter(out_sig).view(n_rows, 1, -1)

        frame_len = int(self.segment * self.sample_rate)
        hop = int(frame_len * (1 - self.overlap))

        # The frames returned by `_unfold` overlap in memory, so they are copied
        # with `contiguous()` before being flattened to one frame per row.
        ref_frames = _unfold(ref_bands, frame_len, hop).squeeze(1).contiguous()
        ref_frames = ref_frames.view(-1, 1, frame_len)
        out_frames = _unfold(out_bands, frame_len, hop).squeeze(1).contiguous()
        out_frames = out_frames.view(-1, 1, frame_len)

        residual_loudness = basic_loudness(out_frames - ref_frames, sample_rate=self.sample_rate)
        reference_loudness = basic_loudness(ref_frames, sample_rate=self.sample_rate)

        # NOTE(review): the flattened values appear to be ordered
        # (band, batch, frame, block), so for B > 1 the trailing axis of size B
        # may not index batch items - confirm this grouping is intended.
        ratio = (residual_loudness - reference_loudness).view(-1, B)
        weights = F.softmax(ratio / self.temperature, dim=0)
        return (weights * ratio).mean()
