import warnings

import os
import numpy as np
import pydub
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from PIL import Image

import csv

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util, torch_util

class SpectrogramConverter:
    """
    Convert between pydub audio segments and magnitude spectrograms.

    The forward direction runs an STFT (torchaudio ``Spectrogram``) followed by
    a mel filter bank (``MelScale``). The reverse direction runs
    ``InverseMelScale`` followed by ``GriffinLim`` phase reconstruction. All
    four transforms are built once in ``__init__`` from ``params`` and moved to
    ``self.device``.
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        """
        Build the torchaudio transforms described by ``params``.

        Args:
            params: Spectrogram configuration; the attributes read here are
                n_fft, hop_length, win_length, num_griffin_lim_iters,
                num_frequencies, sample_rate, min_frequency, max_frequency,
                mel_scale_norm and mel_scale_type.
            device: Requested torch device, resolved through
                ``torch_util.check_device``. Any device string starting with
                "mps" is replaced by "cpu" (with a warning).
        """
        self.p = params

        self.device = torch_util.check_device(device)

        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        # power=None keeps the complex STFT; the magnitude is taken later.
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        # power=1.0 because the input handed to Griffin-Lim is a magnitude
        # (not power) spectrogram.
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        # NOTE(review): the iterative-solver arguments below are disabled,
        # presumably because the installed torchaudio no longer accepts
        # them - confirm against the pinned torchaudio version.
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            # max_iter=params.max_mel_iters,
            # tolerance_loss=1e-5,
            # tolerance_change=1e-8,
            # sgdargs=None,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def _amplitudes_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Shared front end of ``spectrogram_from_audio`` and ``getspec``.

        Args:
            audio: Segment whose frame rate must equal ``self.p.sample_rate``.

        Returns:
            ``(amplitudes, amplitudes_mel)`` tensors on ``self.device``, as
            produced by ``mel_amplitudes_from_waveform``.

        Raises:
            AssertionError: If the segment's frame rate differs from the
                configured sample rate.
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # One row of samples per mono channel of the segment.
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        return self.mel_amplitudes_from_waveform(waveform_tensor)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute the mel-scaled magnitude spectrogram of an audio segment.

        Args:
            audio: Segment whose frame rate must equal ``self.p.sample_rate``.

        Returns:
            Mel-scaled amplitudes as a numpy array on the CPU.
        """
        _, amplitudes_mel = self._amplitudes_from_audio(audio)
        return amplitudes_mel.cpu().numpy()

    def getspec(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute the linear-frequency magnitude spectrogram of an audio segment.

        Identical to ``spectrogram_from_audio`` except that the amplitudes are
        returned before the mel filter bank is applied.

        Args:
            audio: Segment whose frame rate must equal ``self.p.sample_rate``.

        Returns:
            Linear-frequency amplitudes as a numpy array on the CPU.
        """
        amplitudes, _ = self._amplitudes_from_audio(audio)
        return amplitudes.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a mel-scaled magnitude spectrogram.

        Args:
            spectrogram: Mel amplitudes, e.g. the output of
                ``spectrogram_from_audio``.
            apply_filters: When true, run ``audio_util.apply_filters`` (with
                ``compression=False``) on the reconstructed segment.

        Returns:
            The reconstructed audio segment.
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_util.audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalisation is deliberately switched off here.
            # NOTE(review): exact effect is defined in audio_util - confirm.
            normalize=False,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = audio_util.apply_filters(
                segment,
                compression=False,
            )

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Compute linear and mel-scaled magnitudes of a waveform.

        Args:
            waveform: Waveform tensor accepted by the STFT transform.

        Returns:
            ``(amplitudes, amplitudes_mel)``: the STFT magnitudes and the same
            magnitudes after the mel filter bank.
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Return both the linear magnitudes and their mel-scaled version
        return amplitudes, self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Approximately invert ``mel_amplitudes_from_waveform``.

        Args:
            amplitudes_mel: Mel-scaled magnitude spectrogram.

        Returns:
            Waveform tensor recovered by Griffin-Lim.
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)

class SpectrogramImageConverter_mod:
    """
    Thin wrapper around ``SpectrogramConverter`` that fixes the channel count
    of a segment (mono or stereo, per ``params.stereo``) before converting it.
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        """
        Args:
            params: Spectrogram configuration shared with the inner converter.
            device: Torch device string forwarded to ``SpectrogramConverter``.
        """
        self.p = params
        self.device = device
        self.converter = SpectrogramConverter(params=params, device=device)

    def spectrogram2_from_audio(
        self,
        segment: pydub.AudioSegment,
    ):
        """
        Compute both spectrogram variants of an audio segment.

        The segment is first forced to two channels when ``params.stereo`` is
        set and to one channel otherwise; a message is printed whenever the
        channel count has to change.

        Args:
            segment: Audio segment to convert; its frame rate must equal
                ``params.sample_rate``.

        Returns:
            ``(spec, spectrogram)``: the result of the inner converter's
            ``getspec`` (linear-frequency amplitudes) and of its
            ``spectrogram_from_audio`` (mel-scaled amplitudes), both numpy
            arrays.

        Raises:
            AssertionError: On a sample-rate mismatch.
        """
        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"

        channel_count = segment.channels
        notice = None
        if self.p.stereo:
            target_channels = 2
            if channel_count == 1:
                notice = "WARNING: Mono audio but stereo=True, cloning channel"
            elif channel_count > 2:
                notice = "WARNING: Multi channel audio, reducing to stereo"
        else:
            target_channels = 1
            if channel_count > 1:
                notice = "WARNING: Stereo audio but stereo=False, setting to mono"

        # Only touch the segment when its layout actually has to change.
        if notice is not None:
            print(notice)
            segment = segment.set_channels(target_channels)

        linear_amplitudes = self.converter.getspec(segment)
        mel_amplitudes = self.converter.spectrogram_from_audio(segment)
        return linear_amplitudes, mel_amplitudes

    def audio2_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
        max_value: float = 30e6,
    ) -> pydub.AudioSegment:
        """
        Reconstruct audio from a mel-scaled magnitude spectrogram.

        Args:
            spectrogram: Mel amplitudes to invert.
            apply_filters: Forwarded to the inner converter's
                ``audio_from_spectrogram``.
            max_value: Accepted but not used by this method.

        Returns:
            The reconstructed audio segment.
        """
        return self.converter.audio_from_spectrogram(
            spectrogram,
            apply_filters=apply_filters,
        )

def combine_matrices_with_mask(matrix1, matrix2, mask_image_path):
    """
    Merge two matrices element-wise according to a grayscale mask image.

    Both matrices are squeezed (all length-1 axes removed) before merging.
    Where the mask pixel is brighter than 128 the value comes from ``matrix1``,
    elsewhere from ``matrix2``.

    Args:
        matrix1: Array used where the mask is bright.
        matrix2: Array used where the mask is dark.
        mask_image_path: Path of the mask image; it is converted to 8-bit
            grayscale before thresholding.

    Returns:
        The merged array with a new leading axis of length 1.
    """
    first_2d, second_2d = matrix1.squeeze(), matrix2.squeeze()

    # True where the mask is bright, i.e. where matrix1 wins.
    selector = np.array(Image.open(mask_image_path).convert('L')) > 128

    blended_2d = np.where(selector, first_2d, second_2d)
    return blended_2d[np.newaxis, ...]

# ---------------------------------------------------------------------------
# Batch script.
#
# For every music clip in `music_dir`:
#   1. find the noise clip and mask image whose file names contain the clip's
#      prefix (the part of the name before '_im');
#   2. build a threshold spectrogram from the noise clip (boosted by 21 dB
#      wherever its dB value is positive) and keep it only where the mask
#      image is dark (pixel <= `threshold`);
#   3. fit a scalar gain `t` >= 1 for the music mel spectrogram by gradient
#      descent;
#   4. resynthesise the music with that gain (capped at 4), export it, and
#      record the fitted value.
# All fitted values are written to a single CSV file at the end.
# ---------------------------------------------------------------------------

music_dir = '/root/autodl-tmp/Audio_tests/BlendnMask/samp5000_normed/r2/r2n_-24/'
noise_dir = '/root/autodl-tmp/Audio_tests/BlendnMask/noise1000/naud_norm/n_norm-24/'
mask_dir = '/root/autodl-tmp/Audio_tests/BlendnMask/noise1000/mask_1000/'
res_dir = '/root/autodl-tmp/Audio_tests/BlendnMask/recon_res/r2n-24_1022/'
csv_path = '/root/autodl-tmp/Audio_tests/BlendnMask/results/1022/1022_output_-24.csv'

# Sorted so that processing order (and therefore CSV row order) is reproducible.
music_list = sorted(os.listdir(music_dir))
noise_list = sorted(os.listdir(noise_dir))
mask_list = sorted(os.listdir(mask_dir))

# Make sure both output locations exist before hours of work are spent.
os.makedirs(res_dir, exist_ok=True)
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

# Spectrogram parameters handed to SpectrogramParams. Everything that is not
# listed here (mel-scale type/norm, Griffin-Lim iterations, ...) is left at
# whatever SpectrogramParams provides by default.
step_size_ms = 10
window_duration_ms = 100
padded_duration_ms = 400
num_frequencies = 512
min_frequency = 0
max_frequency = 10000

threshold = 100  # mask pixels at or below this gray level are "inside" the mask
alpha = 0.14  # weight of the energy penalty that pulls t downwards
num_epochs = 1000
max_gain = 4  # upper bound on the gain actually applied to the audio

# Stateless helpers, built once instead of once per file.
amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80.0)
triplet_loss_fn = nn.TripletMarginLoss(margin=0, p=1, reduction='sum')


def _last_match(names, key):
    """Return the last entry of *names* that contains *key*, or None."""
    found = None
    for name in names:
        if key in name:
            found = name
    return found


def _make_params(frame_rate):
    """Build SpectrogramParams for a clip with the given frame rate."""
    return SpectrogramParams(
        sample_rate=frame_rate,
        window_duration_ms=window_duration_ms,
        padded_duration_ms=padded_duration_ms,
        step_size_ms=step_size_ms,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        num_frequencies=num_frequencies,
    )


c = 0
optim_ts2 = []
for i in music_list:
    noisename = i.split('_im')[0]
    noise_match = _last_match(noise_list, noisename)
    mask_match = _last_match(mask_list, noisename)
    if noise_match is None or mask_match is None:
        # Without this guard a miss would either raise NameError (first clip)
        # or silently reuse the previous clip's noise/mask files.
        print(f'WARNING: no noise/mask file matching {noisename!r} for {i}, skipping')
        continue

    musicfile = os.path.join(music_dir, i)
    noisefile = os.path.join(noise_dir, noise_match)
    maskfile = os.path.join(mask_dir, mask_match)

    segment1 = pydub.AudioSegment.from_file(musicfile)
    noise1 = pydub.AudioSegment.from_file(noisefile)

    # Separate converters: each one is configured with its own clip's frame
    # rate, and the music converter is the one used for resynthesis below.
    music_converter = SpectrogramImageConverter_mod(
        params=_make_params(segment1.frame_rate), device='cpu'
    )
    _, spectrog1 = music_converter.spectrogram2_from_audio(segment1)

    noise_converter = SpectrogramImageConverter_mod(
        params=_make_params(noise1.frame_rate), device='cpu'
    )
    _, spectrog2 = noise_converter.spectrogram2_from_audio(noise1)
    # The wrapper already owns a SpectrogramConverter built from the same
    # params and device, so reuse it for the raw mel transforms.
    noise_transforms = noise_converter.converter

    # Read the mask as grayscale; 1.0 where the pixel is dark, 0.0 elsewhere.
    with Image.open(maskfile) as mask_image:
        mask_array = np.array(mask_image.convert('L'))
    binary_mask = torch.from_numpy((mask_array <= threshold).astype(np.float32))

    # Noise mel -> linear -> power dB, boost positive dB values by 21 dB,
    # back to amplitude (dB / 20 undoes the power-dB scaling and takes the
    # square root in one step), then back to mel.
    noise_linear = noise_transforms.inverse_mel_scaler(torch.from_numpy(spectrog2))
    magnitude2 = torch.abs(noise_linear)
    power_spectrogram = magnitude2 ** 2
    db_spectrogram3 = amplitude_to_db(power_spectrogram)
    db_spectrogram3[db_spectrogram3 > 0] += 21
    spectrog3 = torch.pow(10.0, db_spectrogram3 / 20.0)
    Threshold = noise_transforms.mel_scaler(spectrog3)

    T = Threshold * binary_mask

    # Loop-invariant tensors for the optimisation below.
    music_mel = torch.from_numpy(spectrog1)
    anchor = torch.zeros_like(music_mel)

    t = torch.tensor(1.0, requires_grad=True)
    optimizer = optim.SGD([t], lr=0.01)
    for _ in range(num_epochs):
        optimizer.zero_grad()

        a_scaled = music_mel * t

        # With a zero anchor and margin 0 this hinge term is positive while
        # the masked, scaled music has a smaller L1 norm than the threshold,
        # which pushes t upwards.
        triplet_loss = triplet_loss_fn(anchor, T, a_scaled * binary_mask)

        # Total-energy penalty, which pushes t back downwards.
        custom_loss = alpha * torch.sum(a_scaled)

        loss = custom_loss + triplet_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_([t], max_norm=1.0)
        optimizer.step()

        # The gain is never allowed to attenuate the music.
        with torch.no_grad():
            t.clamp_(min=1.0)

    optimal_t = t.item()
    print(f'Optimal t: {optimal_t}')
    optim_ts2.append([i, optimal_t])
    real_t = min(optimal_t, max_gain) if optimal_t > 1 else 1

    segmentrecon = music_converter.audio2_from_spectrogram(spectrog1 * real_t)
    segmentrecon.export(os.path.join(res_dir, i), format="wav")
    c += 1
    if c % 100 == 0:
        print(c)

# Persist the fitted gains (one row per processed clip).
with open(csv_path, 'w', newline='', encoding='utf-8') as csv_file:
    writer = csv.writer(csv_file, lineterminator='\n')
    writer.writerow(['music_name', 't_value'])
    writer.writerows(optim_ts2)