import os
import json
import torch
from glob import glob
from tqdm import tqdm

def read_txt(file_name):
    """Return every line of the text file ``file_name`` as a list of strings.

    Lines are returned exactly as ``readlines`` yields them, so each entry
    keeps its trailing newline (when the file has one for that line).
    """
    with open(file_name) as handle:
        return handle.readlines()


# Sampling rate passed to both the audio loader and the VAD below.
SAMPLING_RATE = 16000

# Load the Silero VAD model together with its helper utilities through
# torch.hub. force_reload=True asks torch.hub for a fresh download instead of
# reusing a cached copy; onnx=False selects the PyTorch (non-ONNX) model.
model, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    force_reload=True,
    onnx=False,
)
# Inference below runs on CUDA tensors, so the model has to live on the GPU.
model = model.cuda()

# `utils` is a 5-tuple of helpers; unpack them in the order the hub entry
# point returns them.
get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = utils

# Where the collected speech intervals are written, and which recordings to scan.
output_path = '/data/localssd/texture/speech_non_speech_ambient_timesteps.json'
wav_list = sorted(glob('/data/localssd/actionkid/audio/*.wav'))

# Number of samples at the end of each track in which no speech segment may
# start. NOTE(review): presumably two 2.56 s windows -- confirm with the
# consumer of the JSON file.
_TAIL_MARGIN_SAMPLES = 2.56 * SAMPLING_RATE * 2


def _separated_speech_path(wav_path):
    """Return the path of the separated ``*_speech.wav`` track for ``wav_path``.

    The separated track lives in a sibling of the directory containing
    ``wav_path``; which sibling is chosen depends on whether the path mentions
    ``unbalanced_audio`` or ``unbalanced_wav``. The file name is the original
    name with its last four characters (the ``.wav`` suffix) replaced by
    ``_speech.wav``.
    """
    if 'unbalanced_audio' in wav_path:
        folder = 'separated_unbalanced_audio'
    elif 'unbalanced_wav' in wav_path:
        folder = 'separated_unbalanced_wav'
    else:
        folder = 'separated_audio'
    parts = wav_path.split('/')
    return os.path.join("/".join(parts[:-2]), folder, f"{parts[-1][:-4]}_speech.wav")


def _clip_to_max_onset(timestamps, max_onset):
    """Keep the segments that start at or before ``max_onset``.

    A kept segment that ends after ``max_onset`` has its ``end`` truncated to
    ``int(max_onset)``; segments that are empty (``start == end``) after that
    are dropped. The input dicts are copied, never modified.

    Args:
        timestamps: iterable of dicts with ``'start'`` and ``'end'`` sample
            indices.
        max_onset: latest admissible start/end position, in samples.

    Returns:
        A new list of segment dicts.
    """
    kept = []
    for segment in timestamps:
        if segment['start'] > max_onset:
            continue
        end = int(max_onset) if segment['end'] > max_onset else segment['end']
        if segment['start'] == end:
            continue
        kept.append({**segment, 'end': end})
    return kept


# Run the VAD on the separated speech track of every recording.
all_data = []
for wav_path in tqdm(wav_list):  # wav_list is already sorted
    file_name = wav_path.strip()

    # Load the separated speech track that belongs to this recording.
    wav = read_audio(_separated_speech_path(file_name), sampling_rate=SAMPLING_RATE).cuda()
    max_onset = wav.shape[0] - _TAIL_MARGIN_SAMPLES

    # Extract the speech timestamps (sample indices).
    speech_timestamps = get_speech_timestamps(wav, model, threshold=0.7, sampling_rate=SAMPLING_RATE)
    if not speech_timestamps:
        continue

    clipped = _clip_to_max_onset(speech_timestamps, max_onset)
    if not clipped:
        continue

    # The entry is keyed by the original recording path, not the speech track.
    all_data.append({
        'name': file_name,
        'speech_timestamps': clipped,
    })

# Save all the speech intervals to a single JSON file. The list is written
# even when it is empty so the file always contains valid JSON ("[]") rather
# than being left as a zero-byte file.
with open(output_path, 'w') as f:
    json.dump(all_data, f, indent=2)


