import glob
import os
import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

import warnings
# Suppress every UserWarning for the rest of this interpreter session.
# NOTE(review): presumably meant to hide torchaudio backend/deprecation
# notices that would otherwise interleave with the tqdm progress bar; the
# filter is not scoped to any module, so UserWarnings from other libraries
# are hidden too — consider adding module="torchaudio" if that was the intent.
warnings.filterwarnings("ignore", category=UserWarning)


def check_file(filepath):
    """Probe one audio file and report it if it cannot be read.

    Two probes are run: a metadata read (``torchaudio.info``) and a decode of
    a single frame (``torchaudio.load`` with ``num_frames=1``). The second
    probe catches files whose header parses but whose audio payload does not.

    Args:
        filepath: Path of the audio file to check.

    Returns:
        ``filepath`` if either probe raised, otherwise ``None``.
    """
    try:
        # Metadata/header probe; the returned object itself is not needed.
        torchaudio.info(filepath)
        # Decode just one frame so the check stays cheap for long files.
        torchaudio.load(filepath, frame_offset=0, num_frames=1)
    except Exception:
        # Any exception counts as "corrupt": the exact type raised for a
        # decode failure is not pinned down here. NOTE(review): this also
        # swallows programming/API errors (e.g. AttributeError if a torchaudio
        # function is missing in the installed version), which would flag
        # EVERY file — sanity-check the reported count before deleting.
        return filepath
    return None

def _find_audio_files(roots, extensions):
    """Return sorted, de-duplicated paths under ``roots`` ending in ``extensions``."""
    # Accept a single path (str or path-like) as well as an iterable of paths.
    if isinstance(roots, (str, os.PathLike)):
        roots = [roots]

    found = set()
    for root in roots:
        # Escape the root so glob metacharacters in directory names
        # ("[", "*", "?") are matched literally rather than as a pattern.
        base = glob.escape(os.fspath(root))
        for ext in extensions:
            # Build e.g. "[wW][aA][vV]" so ".WAV" files are found as well.
            ext_pattern = "".join(
                f"[{c.lower()}{c.upper()}]" if c.isalpha() else glob.escape(c)
                for c in ext
            )
            pattern = os.path.join(base, "**", f"*.{ext_pattern}")
            found.update(glob.glob(pattern, recursive=True))

    # The set drops duplicates from overlapping roots; sorting makes the
    # processing order (and the reported example) deterministic.
    return sorted(found)


def _find_corrupt_files(paths, max_workers):
    """Run ``check_file`` over ``paths`` in a process pool; return the failures."""
    # Batch work per IPC round-trip: the default chunksize=1 sends one task
    # per file, which dominates runtime for large datasets of small files.
    # Capped so the progress bar still advances in reasonably small steps.
    workers = max_workers or os.cpu_count() or 1
    chunksize = max(1, min(64, len(paths) // (workers * 4)))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = executor.map(check_file, paths, chunksize=chunksize)
        # check_file returns the path for a bad file and None otherwise.
        return [r for r in tqdm(results, total=len(paths)) if r]


def _delete_files(paths):
    """Delete each path, reporting (not raising on) failures; return the count removed."""
    removed = 0
    for path in paths:
        try:
            os.remove(path)
        except OSError as err:
            print(f"Could not remove {path}: {err}")
        else:
            removed += 1
    return removed


def clean_dataset(roots, max_workers=32):
    """Scan audio datasets for unreadable files and optionally delete them.

    Recursively collects files with a ``wav``/``mp3``/``flac``/``ogg``
    extension (matched case-insensitively) under every root, probes each one
    with ``check_file`` in a process pool, reports how many failed, and asks
    on stdin before deleting them.

    Args:
        roots: A single directory path (``str`` or path-like) or an iterable
            of directory paths.
        max_workers: Number of worker processes for the check. Defaults to
            32, the previously hard-coded value. On shared/HPC nodes, prefer
            the number of CPUs the job may actually use
            (e.g. ``len(os.sched_getaffinity(0))`` on Linux).
    """
    extensions = ('wav', 'mp3', 'flac', 'ogg')

    # 1. Find all files
    print("Scanning directories...")
    audio_files = _find_audio_files(roots, extensions)
    print(f"Found {len(audio_files)} files. Checking for corruption...")

    # 2. Parallel check
    bad_files = _find_corrupt_files(audio_files, max_workers)

    # 3. Report & delete
    print(f"\nFound {len(bad_files)} corrupt files.")
    if not bad_files:
        return

    print("Example corrupt file:", bad_files[0])
    try:
        confirm = input("Delete these files? (y/n): ")
    except EOFError:
        # No interactive stdin (e.g. a batch job): never delete by default.
        confirm = ""

    if confirm.strip().lower() == 'y':
        removed = _delete_files(bad_files)
        print(f"Cleanup complete. Removed {removed} of {len(bad_files)} files.")
    else:
        print("Files kept. Training will likely crash without try...except.")

if __name__ == "__main__":
    # Dataset roots to scan — edit these to match where your data lives.
    data_roots = [
        "/storage/data/LibriTTS",
        "/storage/data/FMA",
    ]
    clean_dataset(data_roots)

    # Optional, disabled diagnostic for choosing a worker count.
    # os.cpu_count() reports the host's *logical* CPUs, which can exceed what
    # this process is allowed to run on (e.g. in a SLURM allocation or a
    # cpuset-pinned container). os.sched_getaffinity(0) reports the CPUs the
    # process may actually use; it is only available on some Unix platforms
    # (including Linux).
    #
    # host_cpus = os.cpu_count()
    # usable_cpus = len(os.sched_getaffinity(0))
    #
    # print(f"Host CPUs:   {host_cpus}")
    # print(f"Usable CPUs: {usable_cpus}")
