import datetime
import glob
import os
import wave
import torch
import demucs.api
import demucs.separate
import librosa
import numpy as np
from speechbrain.pretrained import SepformerSeparation as separator
from utils import *

def resample_audio(input_file, output_file, target_sample_rate):
    """Resample an audio file to ``target_sample_rate`` by invoking ``ffmpeg``.

    Args:
        input_file: Path of the audio file to read.
        output_file: Path of the resampled file to write.
        target_sample_rate: Desired sample rate; passed to ffmpeg's ``-ar``
            option after conversion with ``str``.

    Returns:
        The ``subprocess.CompletedProcess`` of the ffmpeg invocation so the
        caller can inspect ``returncode``. A non-zero exit status does not
        raise.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    # Imported locally: this file has no top-level ``import subprocess`` and
    # would otherwise depend on ``from utils import *`` leaking the name.
    import subprocess

    command = [
        'ffmpeg',
        '-i', input_file,
        '-ar', str(target_sample_rate),
        output_file,
    ]
    # Argument list (no shell), so paths containing spaces or shell
    # metacharacters are passed through verbatim.
    return subprocess.run(command)

# File-name stems ending in one of these suffixes are outputs of this
# pipeline and must not be fed back into it as inputs.
_DERIVED_SUFFIXES = ("_vocals", "_background", "_vocals_16k", "_background_16k")
# Vocal/background pairs whose SNR (in dB) is at or below this are discarded.
_MIN_SNR_DB = 4
# Sample rate of the "<stem>_vocals_16k.wav" copy.
_TARGET_SAMPLE_RATE = 16000


def _write_vocals_and_background(separator, wav_path, vocals_path, background_path):
    """Separate ``wav_path`` and save the vocal stem and the sum of all other stems."""
    _, separated = separator.separate_audio_file(wav_path)
    vocals = separated["vocals"]
    accompaniment = [audio for name, audio in separated.items() if name != "vocals"]
    if not accompaniment:
        raise ValueError("no non-vocal stems returned for " + wav_path)

    # Sum the non-vocal stems in the order the separator returned them.
    background = accompaniment[0]
    for stem_audio in accompaniment[1:]:
        background = torch.add(background, stem_audio)

    demucs.api.save_audio(vocals, vocals_path, samplerate=separator.samplerate)
    demucs.api.save_audio(background, background_path, samplerate=separator.samplerate)


def _snr_db(signal, noise):
    """Return ``10 * log10(mean(signal**2) / mean(noise**2))`` with zero-power guards."""
    if signal.size == 0:
        return float("-inf")
    signal_power = np.mean(signal ** 2)
    if signal_power <= 0:
        # A silent "signal" carries nothing worth keeping, whatever the noise is.
        return float("-inf")
    if noise.size == 0:
        return float("inf")
    noise_power = np.mean(noise ** 2)
    if noise_power <= 0:
        # Avoid dividing by zero: no noise at all means an unbounded SNR.
        return float("inf")
    return 10 * np.log10(signal_power / noise_power)


def _ensure_16k_vocals(vocals_path, vocals_16k_path):
    """Create the 16 kHz copy of ``vocals_path`` unless it already exists."""
    if os.path.exists(vocals_16k_path):
        return
    # Read the header and close the file *before* ffmpeg opens the same path.
    with wave.open(vocals_path, 'rb') as audio_file:
        sample_rate = audio_file.getframerate()
    # NOTE(review): when the vocals are already at 16 kHz no "_16k" file is
    # written (same as before) - confirm downstream code copes with that.
    if sample_rate != _TARGET_SAMPLE_RATE:
        resample_audio(vocals_path, vocals_16k_path, _TARGET_SAMPLE_RATE)


def demucs_samplerate(root_folder):
    """Separate, SNR-filter and resample every source WAV under ``root_folder``.

    The tree is walked exactly three directory levels deep using
    ``get_all_folders`` (presumably returns the sub-directory paths of its
    argument - TODO confirm against ``utils``). For every ``*.wav`` found at
    that depth whose stem does not end in one of ``_DERIVED_SUFFIXES``:

    1. If ``<stem>_vocals.wav`` and ``<stem>_background.wav`` do not both
       exist, the file is separated with Demucs (``mdx_extra``) and both are
       written next to it. The background is the sum of all non-vocal stems.
    2. The SNR of vocals over background is computed. If it is at or below
       ``_MIN_SNR_DB`` both files are deleted.
    3. Otherwise ``<stem>_vocals_16k.wav`` is produced with ``resample_audio``
       unless it already exists or the vocals are already at 16 kHz.

    Args:
        root_folder: Directory at the top of the three-level folder tree.
    """
    start = datetime.datetime.now()

    # Building a Separator loads the model, so do it at most once per run and
    # only if some file actually needs separating. The local name also avoids
    # shadowing the module-level ``separator`` alias imported from speechbrain.
    demucs_separator = None

    for folder in get_all_folders(root_folder):
        for subfolder in get_all_folders(folder):
            for subsubfolder in get_all_folders(subfolder):
                # glob.escape keeps folder names containing [, ], * or ? literal;
                # os.path.join supplies the separator for the current OS.
                pattern = os.path.join(glob.escape(subsubfolder), "*.wav")
                for wav_path in glob.glob(pattern):
                    stem = os.path.splitext(os.path.basename(wav_path))[0]
                    if stem.endswith(_DERIVED_SUFFIXES):
                        continue

                    vocals_path = os.path.join(subsubfolder, stem + "_vocals.wav")
                    background_path = os.path.join(subsubfolder, stem + "_background.wav")

                    if not (os.path.exists(background_path) and os.path.exists(vocals_path)):
                        if demucs_separator is None:
                            demucs_separator = demucs.api.Separator(model="mdx_extra", segment=12)
                        _write_vocals_and_background(
                            demucs_separator, wav_path, vocals_path, background_path
                        )

                    # sr=None keeps each file's native sample rate.
                    vocals_audio, _ = librosa.load(vocals_path, sr=None)
                    background_audio, _ = librosa.load(background_path, sr=None)
                    snr = _snr_db(vocals_audio, background_audio)

                    if snr <= _MIN_SNR_DB:
                        # Too noisy to be useful: drop both separated files.
                        for stale_path in (vocals_path, background_path):
                            if os.path.exists(stale_path):
                                os.remove(stale_path)
                                print("delete success: " + stale_path)
                    else:
                        vocals_16k_path = os.path.join(subsubfolder, stem + "_vocals_16k.wav")
                        _ensure_16k_vocals(vocals_path, vocals_16k_path)

                    print("saved:" + wav_path)

    end = datetime.datetime.now()