import argparse
import glob
import os
import librosa
from tqdm import tqdm
import multiprocessing as mp
import logging

def process_file(file):
    """Score one generated clip against its ground-truth counterpart.

    Reads the module-level ``args`` namespace (``generated_audio_dir``,
    ``gt_audio_dir``, ``duration``, ``delta``, ``threshold``), so ``args``
    must exist in the process that runs this function.

    Args:
        file: Path of the generated audio file *without* the ``.wav``
            extension, located under ``args.generated_audio_dir``.

    Returns:
        The onset-peak IoU between ground truth and generated audio, or
        ``None`` when no peaks could be detected in either of the two files.
    """
    generated_audio_path = f'{file}.wav'
    # Map the file into the ground-truth tree via its path relative to the
    # generated directory. A plain substring replace would also rewrite any
    # later occurrence of the directory name and breaks when the two
    # directory arguments differ in trailing separators.
    relative_path = os.path.relpath(generated_audio_path, args.generated_audio_dir)
    audio_path = os.path.join(args.gt_audio_dir, relative_path)

    # Both files are analysed before any check so that per-file warnings
    # emitted by detect_audio_peaks are logged for each of them.
    audio_peaks = detect_audio_peaks(audio_path, duration=args.duration, delta=args.delta)
    generated_audio_peaks = detect_audio_peaks(generated_audio_path, duration=args.duration, delta=args.delta)

    # len() is used instead of truthiness because the peaks may be a NumPy
    # array, whose truth value is ambiguous.
    if len(audio_peaks) == 0:
        logging.warning(f"No GT peaks detected in {audio_path}")
        return None
    if len(generated_audio_peaks) == 0:
        logging.warning(f"\nNo Generated peaks detected in {generated_audio_path}")
        return None

    return calc_intersection_over_union(audio_peaks, generated_audio_peaks, args.threshold)

# Function to detect audio peaks using the Onset Detection algorithm
def detect_audio_peaks(audio_file, duration=2, required_sr=16000, delta=0.2):
    """Return onset times (in seconds) found at the start of an audio file.

    Args:
        audio_file: Path of the audio file to analyse.
        duration: Number of seconds, counted from the start of the file,
            that are considered.
        required_sr: Sample rate the audio is resampled to on load.
        delta: Threshold offset forwarded to ``librosa.onset.onset_detect``.

    Returns:
        The onset times produced by ``librosa.frames_to_time``, or an empty
        list when the (truncated) signal is empty or its absolute amplitude
        never reaches 0.01.
    """
    y, sr = librosa.load(audio_file, sr=required_sr)
    y = y[:int(sr * duration)]

    # An empty signal has no maximum; bail out instead of raising ValueError.
    if y.size == 0:
        logging.warning(f"Audio is empty: audio_file: {audio_file}, sr: {sr}")
        return []

    # Use the array's own reduction rather than the builtin max(), which
    # iterates sample by sample in Python.
    if abs(y).max() < 0.01:
        logging.warning(f"Audio amplitude is too low: audio_file: {audio_file}, sr: {sr}, max: {y.max()}")
        return []

    # Calculate the onset envelope
    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    # Get the onset events
    onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, delta=delta)
    onset_times = librosa.frames_to_time(onset_frames, sr=sr)
    return onset_times


# Function to calculate Intersection over Union (IoU) for audio and generated_audio peaks
def calc_intersection_over_union(audio_peaks, generated_audio_peaks, threshold):
    """Compute the IoU between two sequences of peak times.

    Each reference peak is matched greedily, in order, to the first
    still-unmatched generated peak whose distance is strictly below
    ``threshold``; every generated peak can be matched at most once.

    Args:
        audio_peaks: Peak times of the reference audio.
        generated_audio_peaks: Peak times of the generated audio.
        threshold: Maximum (exclusive) time difference for two peaks to
            count as a match.

    Returns:
        ``matches / (len(audio_peaks) + len(generated_audio_peaks) - matches)``,
        or ``0.0`` when both sequences are empty.
    """
    reference_count = len(audio_peaks)
    generated_count = len(generated_audio_peaks)

    # With no peaks on either side the union is empty; return 0.0 (the same
    # value produced when only one side is empty) instead of dividing by zero.
    if reference_count == 0 and generated_count == 0:
        return 0.0

    already_matched = [False] * generated_count
    matches = 0
    for reference_peak in audio_peaks:
        for idx, candidate in enumerate(generated_audio_peaks):
            if already_matched[idx]:
                continue
            if abs(candidate - reference_peak) < threshold:
                already_matched[idx] = True
                matches += 1
                break

    return matches / (reference_count + generated_count - matches)


# Script entry point: score every generated .wav against its ground-truth
# counterpart in parallel and average the per-file IoU scores.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--generated_audio_dir", type=str, required=True)
    parser.add_argument("--gt_audio_dir", type=str, required=True)
    parser.add_argument("--result_save_path", type=str, required=True)
    # Passed to calc_intersection_over_union as the peak-matching threshold.
    parser.add_argument("--threshold", type=float, default=0.1) 
    # Passed to detect_audio_peaks as the onset-detection delta.
    parser.add_argument("--delta", type=float, default=0.2) 
    parser.add_argument("--duration", type=float, default=2.0, help="Duration in seconds to consider for audio peaks")
    
    # `args` is a module-level name that process_file reads directly.
    args = parser.parse_args()
    # The log file path is derived from the result path by swapping '.txt'
    # for '.log'.
    # NOTE(review): if result_save_path contains no '.txt', logging_path is
    # identical to result_save_path — confirm this is acceptable.
    logging_path = args.result_save_path.replace('.txt', '.log')
    # Log to the file only; no console handler is installed.
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        handlers=[
            logging.FileHandler(logging_path),
        ]
    )
    # Collect every .wav below the generated directory (recursively) and drop
    # the 4-character '.wav' suffix; process_file appends it again.
    files = glob.glob(os.path.join(args.generated_audio_dir, '**', '*.wav'), recursive=True)
    files = [f[:-4] for f in files]

    # One worker per CPU core.
    # NOTE(review): process_file relies on the global `args`, which is only
    # assigned inside this __main__ block; presumably this requires the
    # 'fork' start method so workers inherit it — verify on spawn platforms.
    n_cores = mp.cpu_count()
    with mp.Pool(n_cores) as pool:
        results = list(tqdm(pool.imap(process_file, files), total=len(files)))
    # process_file returns None when no peaks were detected in one of the
    # two files; those entries are excluded from the average.
    valid_results = [r for r in results if r is not None]
    # Number of files that produced no score (not used in the lines shown here).
    exception = len(results) - len(valid_results)

    score_agg = 0

    for score in valid_results:
        score_agg += score
        
    total = len(valid_results)
    # Mean score over the valid files.
    # NOTE(review): raises ZeroDivisionError when valid_results is empty
    # (no .wav files found, or every file was skipped).
    score_agg /= total