import os
import argparse
import torchaudio
import torch
import random
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from collections import defaultdict

# --- CONFIGURATION ---
# Dataset roots scanned when no --paths argument is given on the command line.
# Each root's folder name is used as its dataset label in the report.
DEFAULT_PATHS = ["/storage/data/LibriTTS", "/storage/data/FMA"]
# ---------------------

def check_file(file_path, target_duration=5.0, compressed_exts=(".mp3", ".ogg")):
    """
    Simulate the training loader's read path for one file to catch real training failures.

    Compressed files (suffix in ``compressed_exts``) are decoded in full. Every
    other file gets a lazy, seek-based read of ``target_duration`` seconds at a
    random offset, or a full read if it is not longer than that window.

    Args:
        file_path: Path to the audio file (``str`` or any ``os.PathLike``).
        target_duration: Length, in seconds, of the window read by the seek test.
        compressed_exts: Filename suffixes (matched case-insensitively) that take
            the full-decode path. Defaults to ``(".mp3", ".ogg")``.

    Returns:
        A ``(status, message)`` tuple. ``status`` is one of ``"valid"``,
        ``"corrupt_metadata"``, ``"empty"``, ``"corrupt_decode"``,
        ``"corrupt_seek"`` or ``"error_unknown"``. ``message`` is ``None`` for
        valid files, otherwise a human-readable reason. This function never raises
        for ordinary errors; every failure is reported through ``status``.
    """
    try:
        # Accept Path objects as well as strings (str.lower() below needs a str).
        path_str = os.fspath(file_path)

        if isinstance(compressed_exts, str):
            compressed_exts = (compressed_exts,)
        compressed_suffixes = tuple(ext.lower() for ext in compressed_exts)

        # 1. Metadata check: the header must be parseable.
        try:
            info = torchaudio.info(path_str)
        except Exception as e:
            # Backends signal unreadable headers with different exception types;
            # all of them mean the same thing here.
            return "corrupt_metadata", str(e)

        sr = info.sample_rate
        total_frames = info.num_frames

        # NOTE(review): some torchaudio backends may report 0 frames for compressed
        # formats whose length is not stored in the header. Confirm for the
        # installed backend before treating "empty" MP3/OGG results as corrupt.
        if total_frames == 0:
            return "empty", "Zero frames"

        is_compressed = path_str.lower().endswith(compressed_suffixes)
        src_target_len = int(target_duration * sr)

        # 2. Simulation of the loading logic.
        if is_compressed:
            # Compressed strategy: full load, which verifies decoding integrity
            # (this often fails on bad headers or damaged frames).
            try:
                torchaudio.load(path_str)
            except Exception as e:
                return "corrupt_decode", str(e)
        else:
            # Uncompressed strategy (e.g. WAV/FLAC): lazy load via a seek test.
            try:
                if total_frames > src_target_len:
                    start = random.randint(0, total_frames - src_target_len)
                    waveform, _ = torchaudio.load(
                        path_str, frame_offset=start, num_frames=src_target_len
                    )
                    # torchaudio.load may return fewer frames than requested
                    # without raising. When the header promised enough frames,
                    # a short read means the body is truncated.
                    got = waveform.shape[-1]
                    if got < src_target_len:
                        return (
                            "corrupt_seek",
                            f"Short read at offset {start}: got {got} of {src_target_len} frames",
                        )
                else:
                    torchaudio.load(path_str)
            except Exception as e:
                return "corrupt_seek", str(e)

        return "valid", None

    except Exception as e:
        return "error_unknown", str(e)

def _collect_files(paths, extensions):
    """
    Find audio files under each root, matching suffixes case-insensitively.

    Each root is walked once, however many extensions are requested, and
    ``.WAV`` matches ``.wav``. Roots that do not exist are reported and skipped.

    Returns:
        ``(files, dataset_names)``. ``files`` is a list of
        ``(file_path_str, dataset_label)`` pairs. ``dataset_names`` lists the
        label of every existing root, including roots with no matching files.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    files = []
    dataset_names = []
    for p in paths:
        root_path = Path(p)
        if not root_path.exists():
            print(f"Warning: Path not found: {p}")
            continue

        # Use the folder name (e.g., 'LibriTTS') as the dataset label.
        dataset_name = root_path.name
        dataset_names.append(dataset_name)

        # Check the cheap suffix test first so is_file() only stats candidates;
        # is_file() drops directories whose names happen to end in an extension.
        curr_files = [
            f for f in root_path.rglob("*")
            if f.name.lower().endswith(suffixes) and f.is_file()
        ]
        print(f"  -> {dataset_name}: Found {len(curr_files)} files.")

        files.extend((str(f), dataset_name) for f in curr_files)
    return files, dataset_names


def scan_paths(paths, extensions, num_workers=16):
    """
    Check the integrity of every audio file under ``paths`` in parallel.

    Args:
        paths: Dataset root directories. Each root's folder name is its label.
        extensions: Filename suffixes to include (matched case-insensitively).
        num_workers: Number of worker processes used to run ``check_file``.

    Returns:
        A ``defaultdict`` mapping dataset label to
        ``{"valid": int, "corrupt": int, "errors": list[str]}``. ``errors`` holds
        up to the first five failure samples, formatted as
        ``"<file name> [<status>]: <message>"``. Every existing root appears,
        even if it contained no matching files.
    """
    max_error_samples = 5

    def new_entry():
        return {"valid": 0, "corrupt": 0, "errors": []}

    stats = defaultdict(new_entry)

    # 1. Collection phase
    print(f"Collecting files from {len(paths)} locations...")
    all_files, dataset_names = _collect_files(paths, extensions)
    # Register every existing root so empty datasets still show up in the report.
    for name in dataset_names:
        stats.setdefault(name, new_entry())

    print(f"\nChecking integrity of {len(all_files)} files using {num_workers} workers...")

    # 2. Parallel processing
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Map future -> (file_path, dataset_name)
        futures = {
            executor.submit(check_file, f_path): (f_path, ds_name)
            for f_path, ds_name in all_files
        }

        for future in tqdm(as_completed(futures), total=len(all_files), unit="files"):
            file_path, dataset_name = futures[future]
            entry = stats[dataset_name]

            try:
                status, error_msg = future.result()
            except Exception as e:
                # The worker itself failed (e.g. a crashed process); count the file as corrupt.
                status, error_msg = "system_error", str(e)
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"System Error on {file_path}: {e}")

            if status == "valid":
                entry["valid"] += 1
                continue

            entry["corrupt"] += 1
            if len(entry["errors"]) < max_error_samples:
                entry["errors"].append(f"{Path(file_path).name} [{status}]: {error_msg}")

    return stats

def print_report(stats):
    """
    Print a per-dataset integrity table followed by sample error messages.

    Args:
        stats: Mapping of dataset label to
            ``{"valid": int, "corrupt": int, "errors": list[str]}``, i.e. the
            structure returned by ``scan_paths``.

    The VALID (%) column is the fraction of a dataset's files that passed. The
    SHARE (%) column is the dataset's fraction of all *valid* files across every
    dataset. Every ratio falls back to 0 when its denominator is 0, so an empty
    scan prints a zero report instead of raising ``ZeroDivisionError``.
    """
    def pct(numerator, denominator):
        # Percentage that degrades to 0 instead of dividing by zero.
        return (numerator / denominator * 100) if denominator > 0 else 0.0

    def row(label, valid, corrupt, valid_pct, share_pct):
        # Format the percentage with its '%' sign first, then pad, so the sign
        # sits next to the number and every row (including TOTAL) lines up.
        valid_str = f"{valid_pct:.1f}%"
        share_str = f"{share_pct:.1f}%"
        return f"{label:<20} | {valid:<10} | {corrupt:<10} | {valid_str:<10} | {share_str}"

    print("\n" + "=" * 80)
    print(f"{'DATASET':<20} | {'VALID':<10} | {'CORRUPT':<10} | {'VALID (%)':<10} | {'SHARE (%)':<10}")
    print("-" * 80)

    total_valid_global = sum(d["valid"] for d in stats.values())
    total_corrupt_global = sum(d["corrupt"] for d in stats.values())

    for dataset in sorted(stats):
        data = stats[dataset]
        v, c = data["valid"], data["corrupt"]
        print(row(dataset, v, c, pct(v, v + c), pct(v, total_valid_global)))

    print("-" * 80)
    print(row(
        "TOTAL",
        total_valid_global,
        total_corrupt_global,
        pct(total_valid_global, total_valid_global + total_corrupt_global),
        pct(total_valid_global, total_valid_global),
    ))
    print("=" * 80)

    # Same (sorted) order as the table above.
    print("\n--- Corruption Diagnosis ---")
    for dataset in sorted(stats):
        errors = stats[dataset]["errors"]
        if errors:
            # These are the first errors recorded, not a frequency ranking.
            print(f"\n[{dataset}] Sample Errors:")
            for err in errors:
                print(f"  - {err}")

if __name__ == "__main__":
    # Command-line interface: which dataset roots to scan and how many worker
    # processes to use for the integrity checks.
    cli = argparse.ArgumentParser()
    cli.add_argument("--paths", nargs='+', default=DEFAULT_PATHS, help="List of paths to scan")
    cli.add_argument("--workers", type=int, default=32, help="Number of CPU workers")
    options = cli.parse_args()

    # Common audio extensions
    EXTENSIONS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a', '.opus']

    print_report(scan_paths(options.paths, EXTENSIONS, options.workers))