# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import json
import os
import random
import warnings

import sphn
import torchaudio
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from .dist import get_rank, get_world_size

# JSON caches of discovered audio paths are stored here, relative to the current working directory.
CACHE_DIR = '.cache/datafiles/'


def _ensure_cache_dir(path=CACHE_DIR):
    """Create the cache directory at import time, announcing it when it was missing."""
    if os.path.exists(path):
        return
    print(f"Cache directory {path} does not exist. Creating it.")
    os.makedirs(path, exist_ok=True)


_ensure_cache_dir()

def _cache_filename(audio_dir):
    """Return the cache file name for a directory: ``<basename>_<path with '/' -> '_'>.json``."""
    return os.path.basename(audio_dir.rstrip('/')) + '_' + audio_dir.replace('/', '_') + '.json'


def _scan_audio_dir(audio_dir, extensions):
    """Recursively glob ``audio_dir`` for every extension and return the matches sorted."""
    # NOTE: audio_dir is used as a glob prefix, so glob metacharacters in it
    # (``*``, ``?``, ``[``) are interpreted as patterns rather than literally.
    matches = []
    for ext in extensions:
        matches.extend(glob.glob(os.path.join(audio_dir, f"**/*.{ext}"), recursive=True))
    return sorted(matches)


def _read_path_cache(cache_file):
    """Return the cached path list, or None if the cache is missing, unreadable or malformed."""
    try:
        with open(cache_file, 'r') as f:
            cached = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"Ignoring unreadable cache {cache_file} ({e}); rescanning.")
        return None
    if not isinstance(cached, list):
        print(f"Ignoring malformed cache {cache_file}; rescanning.")
        return None
    return cached


def _write_path_cache(cache_file, audio_files):
    """Write the path list via a temp file + ``os.replace`` so readers never see a partial file.

    Failures are reported and otherwise ignored, because the cache is only an optimization.
    """
    import tempfile

    cache_dir = os.path.dirname(cache_file) or '.'
    tmp_path = None
    try:
        # Re-create the directory in case it was removed after this module was imported.
        os.makedirs(cache_dir, exist_ok=True)
        with tempfile.NamedTemporaryFile('w', dir=cache_dir, suffix='.tmp', delete=False) as f:
            tmp_path = f.name
            json.dump(audio_files, f)
        os.replace(tmp_path, cache_file)
    except OSError as e:
        print(f"Could not write cache {cache_file} ({e}); continuing without it.")
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup of the temp file


def _balance_and_interleave(sources):
    """Oversample each non-empty source to the longest one's length, then interleave them.

    Returns ``[A1, B1, A2, B2, ...]`` so any contiguous stretch of the list mixes
    all sources. Empty sources are skipped because they cannot be oversampled.
    """
    non_empty = [src for src in sources if src]
    if len(non_empty) < len(sources):
        print(f"Skipping {len(sources) - len(non_empty)} empty source(s) in mixed dataset.")
    if not non_empty:
        return []
    max_len = max(len(src) for src in non_empty)
    balanced = []
    for src in non_empty:
        repeats = -(-max_len // len(src))  # ceil(max_len / len(src))
        balanced.append((src * repeats)[:max_len])
    return [path for group in zip(*balanced) for path in group]


def get_cached_audio_files(audio_dir, extensions=('wav', 'mp3', 'flac', 'ogg'), cache_dir=None):
    """Find all audio files recursively, caching the path list on disk for faster reloads.

    Args:
        audio_dir: A directory path, or a list of directory paths. For a list,
            each source is resolved (and cached) separately. Smaller sources are
            oversampled by repetition to the length of the largest one, and the
            results are interleaved ``[A1, B1, A2, B2, ...]``. Empty sources are
            skipped.
        extensions: File extensions (without the dot) to match. Matching uses
            ``glob``, so it is case-sensitive on case-sensitive filesystems.
        cache_dir: Directory for the JSON cache files. Defaults to ``CACHE_DIR``.

    Returns:
        A list of file paths. For a single directory the list is sorted.

    Note:
        Caches are keyed only by the directory path (not by ``extensions``) and
        are never invalidated. Delete the JSON file in the cache directory to
        force a rescan. Empty scan results are not cached.
    """
    if isinstance(audio_dir, list):
        print(f"Processing mixed dataset from {len(audio_dir)} sources...")
        # Each source goes through the single-directory (cached) path below.
        sources = [get_cached_audio_files(d, extensions, cache_dir) for d in audio_dir]
        interleaved = _balance_and_interleave(sources)
        print(f"Balanced Mixed Dataset: {len(interleaved)} total files (Interleaved).")
        return interleaved

    if cache_dir is None:
        cache_dir = CACHE_DIR
    cache_file = os.path.join(cache_dir, _cache_filename(audio_dir))

    audio_files = _read_path_cache(cache_file)
    if audio_files:
        print(f"Loading audio files from cache: {cache_file}")
        return audio_files

    print(f"Finding audio files in {audio_dir}...")
    audio_files = _scan_audio_dir(audio_dir, extensions)
    if not audio_files:
        # Don't persist an empty result: a mistyped or not-yet-mounted path
        # would otherwise be cached as empty forever.
        print(f"No audio files found in {audio_dir}; not caching the empty result.")
        return audio_files

    print(f"Caching {len(audio_files)} audio paths to {cache_file}")
    _write_path_cache(cache_file, audio_files)
    return audio_files


class AudioDataset(Dataset):
    """Map-style dataset yielding fixed-length mono clips from audio files found recursively.

    Each item is a tensor of shape ``(1, target_length)``. It is a random
    ``target_duration``-second crop of the file (or the whole file, if shorter),
    resampled to ``target_sr``, averaged to mono and zero-padded or trimmed to
    exactly ``target_length`` samples.
    """

    # Upper bound on how many files (idx, idx+1, ...) __getitem__ tries before giving up.
    _MAX_LOAD_ATTEMPTS = 100

    def __init__(self, audio_dir, target_sr=24000, target_duration=5.0, extensions=('wav', 'mp3', 'flac', 'ogg')):
        """
        Initialize the dataset.

        Args:
            audio_dir: Directory containing audio files, or a list of directories
                (balanced and interleaved by get_cached_audio_files).
            target_sr: Target sample rate.
            target_duration: Target clip duration in seconds.
            extensions: Audio file extensions (without the dot) to look for.
        """
        self.audio_dir = audio_dir
        self.target_sr = target_sr
        self.target_duration = target_duration
        self.target_length = int(target_sr * target_duration)

        # Find all audio files recursively; the path list is cached on disk.
        self.audio_files = get_cached_audio_files(audio_dir, extensions)

        print(f"Found {len(self.audio_files)} audio files.")

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return a clip for ``idx``, skipping forward past unreadable files.

        If a file fails to load, a warning is issued and the next index
        (wrapping around) is tried, so a few corrupt files do not crash
        training.

        Raises:
            IndexError: ``idx`` is out of range.
            RuntimeError: no file could be loaded within ``_MAX_LOAD_ATTEMPTS``
                attempts (or after trying every file). The error is chained to
                the last load error.
        """
        # Index once up front so an out-of-range idx raises IndexError as usual.
        first_file = self.audio_files[idx]
        num_files = len(self.audio_files)
        max_attempts = min(num_files, self._MAX_LOAD_ATTEMPTS)
        last_error = None
        for attempt in range(max_attempts):
            audio_file = first_file if attempt == 0 else self.audio_files[(idx + attempt) % num_files]
            try:
                return self._load_clip(audio_file)
            except Exception as e:  # deliberate best-effort: any decode failure skips the file
                last_error = e
                warnings.warn(f"Skipping unreadable audio file {audio_file}: {e}")
        raise RuntimeError(
            f"Could not load any audio file starting at index {idx} after {max_attempts} attempts"
        ) from last_error

    def _load_clip(self, audio_file):
        """Load a random ``target_duration``-second crop of ``audio_file`` and post-process it."""
        if audio_file.lower().endswith(('.mp3', '.ogg')):
            # Compressed formats: decode the whole file and crop afterwards,
            # because seeking into them triggers libmpg123 sync errors.
            audio, sr = torchaudio.load(audio_file)
            crop_len = int(self.target_duration * sr)
            if audio.shape[1] > crop_len:
                start = random.randint(0, audio.shape[1] - crop_len)
                audio = audio[:, start:start + crop_len]
        else:
            # WAV/FLAC: read only the crop to avoid loading huge files into RAM.
            info = torchaudio.info(audio_file)
            crop_len = int(self.target_duration * info.sample_rate)
            if info.num_frames > crop_len:
                start = random.randint(0, info.num_frames - crop_len)
                audio, sr = torchaudio.load(audio_file, frame_offset=start, num_frames=crop_len)
            else:
                audio, sr = torchaudio.load(audio_file)
        return self._postprocess(audio, sr)

    def _postprocess(self, audio, sr):
        """Resample ``(channels, frames)`` audio to target_sr, average to mono, pad/trim to target_length."""
        if sr != self.target_sr:
            audio = torchaudio.functional.resample(audio, sr, self.target_sr)
        # Average (rather than sum) channels so the downmix stays within the input's amplitude range.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        if audio.shape[1] < self.target_length:
            pad_length = self.target_length - audio.shape[1]
            audio = F.pad(audio, (0, pad_length), "constant", 0)
        else:
            audio = audio[:, :self.target_length]
        return audio


def get_audio_dataloader(dataset: Dataset, batch_size=16, num_workers=16, shuffle=True, distributed=False):
    """Create a DataLoader for ``dataset``.

    Args:
        dataset: Map-style dataset to draw batches from.
        batch_size: Number of items per batch.
        num_workers: Number of loader worker processes; 0 loads in the main process.
        shuffle: Whether to shuffle. In distributed mode shuffling is delegated to
            the DistributedSampler; call ``loader.sampler.set_epoch(epoch)`` each
            epoch to get a different order per epoch.
        distributed: If True, shard the dataset across ranks with a
            DistributedSampler built from ``get_world_size()`` / ``get_rank()``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
    """
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    # prefetch_factor only applies to worker processes; DataLoader raises a
    # ValueError if a non-default value is given while num_workers == 0.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 4

    if distributed:
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
        )
        return DataLoader(dataset, sampler=sampler, **loader_kwargs)
    return DataLoader(dataset, shuffle=shuffle, **loader_kwargs)
