import datetime
import glob
import soundfile as sf
from speechbrain.inference.separation import SepformerSeparation as separator
import torchaudio
from utils import *
import wave
from pydub import AudioSegment
# Load the pretrained SepFormer enhancement checkpoint once at import time;
# the downloaded files are cached under ``savedir`` for subsequent runs.
model = separator.from_hparams(
    source="speechbrain/sepformer-dns4-16k-enhancement",
    savedir="pretrained_models/sepformer-dns4-16k-enhancement",
)

def resample_audio(input_file, output_file, target_sample_rate):
    """Resample *input_file* to *target_sample_rate* Hz with ffmpeg into *output_file*.

    ``-y`` makes ffmpeg overwrite an existing *output_file* instead of stopping
    to ask for confirmation on stdin, which would stall an unattended batch run.
    ``check=True`` turns a failed conversion into an exception rather than
    letting the caller continue with a missing or partial output file.

    Args:
        input_file: Path of the audio file to convert.
        output_file: Path the resampled audio is written to.
        target_sample_rate: Desired sample rate in Hz (passed to ``-ar``).

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Imported explicitly rather than relying on ``from utils import *``.
    import subprocess

    command = [
        "ffmpeg",
        "-y",  # overwrite output without an interactive prompt
        "-i", input_file,
        "-ar", str(target_sample_rate),
        output_file,
    ]
    subprocess.run(command, check=True)

def enhance_vocals(root_folder, output_folder="L:\\"):
    """Run SepFormer speech enhancement on every ``*.wav`` file in *root_folder*.

    Only the top level of *root_folder* is scanned (no recursion).

    Inputs whose sample rate is not 16 kHz are first converted with
    ``resample_audio`` into a temporary directory. That directory is removed
    when the function returns, so the intermediate copies do not clutter
    *root_folder* and are not picked up again as inputs on a later run.

    For each input ``<name>.wav``, the ``[:, :, 0]`` slice of the result of
    ``model.separate_file`` is saved as ``<output_folder>/<name>_enhanced.wav``
    at 16 kHz.

    Args:
        root_folder: Directory containing the ``.wav`` files to enhance.
        output_folder: Existing directory the enhanced files are written to.
            Defaults to ``L:\\``, the location previously hard-coded here.
    """
    import os
    import tempfile

    # Rate fed to the 16 kHz SepFormer checkpoint and used when saving output.
    model_sample_rate = 16000

    # os.path.join instead of a literal "\\" so the pattern works on any OS.
    wav_files = sorted(glob.glob(os.path.join(root_folder, "*.wav")))

    with tempfile.TemporaryDirectory() as scratch_dir:
        for wav_path in wav_files:
            stem = os.path.splitext(os.path.basename(wav_path))[0]

            # Only the header is needed. Close the file before the slow model
            # call instead of keeping it open for the whole inference.
            with wave.open(wav_path, "rb") as audio_file:
                sample_rate = audio_file.getframerate()

            if sample_rate != model_sample_rate:
                model_input = os.path.join(scratch_dir, f"{stem}_16k.wav")
                resample_audio(wav_path, model_input, model_sample_rate)
            else:
                model_input = wav_path

            est_sources = model.separate_file(path=model_input)
            # NOTE(review): [:, :, 0] presumably selects the first estimated
            # source from a (batch, time, sources) tensor -- confirm.
            torchaudio.save(
                os.path.join(output_folder, f"{stem}_enhanced.wav"),
                est_sources[:, :, 0].detach().cpu(),
                model_sample_rate,
            )