import numpy as np
import pyloudnorm as pyln
import torch

from config import SILENCE_RATIO, SR
import warnings

warnings.filterwarnings("ignore", message="Possible clipped samples in output.")


def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
    """
    Apply loudness normalization on an audio signal.
    :param wav: waveform signal to normalize.
    :param sr: sampling rate.
    :param target_lufs: LUFS points to normalize to.
    :return: normalized signal.
    """
    meter = pyln.Meter(sr)
    loudness = meter.integrated_loudness(wav)
    normalized_wav = pyln.normalize.loudness(wav, loudness, target_lufs)
    peak = np.max(np.abs(normalized_wav))
    if peak > 1.0:
        normalized_wav = normalized_wav / max(peak, 1e-12)
    return np.clip(normalized_wav, -1.0, 1.0)


def frame_rms_torch(sig, win, hop):
    """
    Calculates the RMS of a signal with a moving window.
    :param sig: signal for calculation.
    :param win: analysis window size.
    :param hop: analysis window hop size.
    :return: RMS of signal.
    """
    dev = sig.device
    frames = sig.unfold(0, win, hop)
    if frames.size(0) and (frames.size(0) - 1) * hop == sig.numel() - win:
        frames = frames[:-1]
    rms = torch.sqrt((frames ** 2).mean(1) + 1e-12)
    return rms.to(dev)


def compute_speaker_activity_masks(refs_tensors, win, hop):
    """
    Computes individual voice activity for each speaker and determines which frames
    have at least 2 active speakers.
    :param refs_tensors: references that compose the mixture.
    :param win: analysis window size.
    :param hop: analysis window hop size.
    :return: (multi_speaker_mask, individual_speaker_masks)
        - multi_speaker_mask: boolean mask of frames where at least 2 speakers are active
        - individual_speaker_masks: list of boolean masks, one per speaker
    """
    device = refs_tensors[0].device
    individual_masks = []
    lengths = []

    for ref in refs_tensors:
        rms = frame_rms_torch(ref, win, hop)
        threshold = SILENCE_RATIO * torch.sqrt((ref ** 2).mean())
        voiced = rms > threshold
        individual_masks.append(voiced)
        lengths.append(voiced.numel())

    L_max = max(lengths)
    padded_masks = []
    for mask, L in zip(individual_masks, lengths):
        if L < L_max:
            padded = torch.cat([mask, torch.zeros(L_max - L, dtype=torch.bool, device=device)])
        else:
            padded = mask
        padded_masks.append(padded)

    stacked = torch.stack(padded_masks, dim=0)
    active_count = stacked.sum(dim=0)
    multi_speaker_mask = active_count >= 2
    return multi_speaker_mask, padded_masks