import os
import torch
import torchaudio
import random
import argparse
from tqdm import tqdm
import torch.multiprocessing as mp
from audio_augmentation import AudioProcessor
from transformers import Wav2Vec2Processor, Wav2Vec2Model

SAMPLE_RATE = 16000
MAX_DURATION = 20.0
MIN_DURATION = 2.0
MAX_LEN = int(SAMPLE_RATE * MAX_DURATION)
MIN_LEN = int(SAMPLE_RATE * MIN_DURATION)

def safe_vad(waveform, sample_rate=SAMPLE_RATE):
    """Apply torchaudio's ``Vad`` transform to ``waveform`` without raising.

    Args:
        waveform: Audio tensor passed straight to the ``Vad`` transform.
        sample_rate: Sample rate the transform is configured with.

    Returns:
        The transform's output if it holds at least one element, otherwise
        ``None``. Any exception raised along the way is printed and also
        results in ``None``.
    """
    result = None
    try:
        trimmed = torchaudio.transforms.Vad(sample_rate=sample_rate)(waveform)
        # An empty output means nothing was kept; report it as a failure.
        if trimmed.numel() > 0:
            result = trimmed
    except Exception as err:
        print("VAD error:", err)
    return result

def load_random_segment_from_same_speaker(current_path, sample_rate=16000, max_duration=3.0):
    """Cut a random voiced segment out of other recordings of the same speaker.

    The speaker directory is taken to be two levels above ``current_path``
    (``<speaker>/<session>/<file>.wav``). Every other ``.wav`` file found in
    any session directory of that speaker is a candidate. Candidates are
    visited in random order, resampled to ``sample_rate`` when needed, passed
    through ``safe_vad``, and concatenated until at least
    ``max_duration`` seconds have been collected.

    Args:
        current_path: Path of the reference recording; it is excluded from the
            candidate list.
        sample_rate: Target sample rate for the loaded audio and for VAD.
        max_duration: Length of the returned segment in seconds.

    Returns:
        A tensor whose last dimension has exactly
        ``int(max_duration * sample_rate)`` samples, taken at a random offset
        from the concatenated voiced audio, or ``None`` when there are no
        candidates or not enough voiced audio could be collected.
    """
    session_dir = os.path.dirname(current_path)      # e.g., id10285/LbFFEF1pHO0
    speaker_dir = os.path.dirname(session_dir)        # e.g., id10285
    current_abs = os.path.abspath(current_path)       # loop-invariant, computed once

    candidates = []
    for session_name in os.listdir(speaker_dir):
        session_path = os.path.join(speaker_dir, session_name)
        if not os.path.isdir(session_path):
            continue
        for file_name in os.listdir(session_path):
            if not file_name.endswith('.wav'):
                continue
            full_path = os.path.join(session_path, file_name)
            if os.path.abspath(full_path) != current_abs:
                candidates.append(full_path)

    if not candidates:
        return None

    random.shuffle(candidates)
    max_len = int(max_duration * sample_rate)
    min_len = int(1.0 * sample_rate)  # voiced chunks shorter than one second are discarded
    accumulated = []
    total_len = 0
    for full_path in candidates:
        waveform, sr = torchaudio.load(full_path)
        if sr != sample_rate:
            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
            print(f"Resampled {full_path} from {sr} to {sample_rate}")

        waveform = waveform.to(torch.float32)
        # Run VAD at the rate the audio was just resampled to, not the module default.
        voiced = safe_vad(waveform, sample_rate)
        if voiced is None or voiced.size(-1) < min_len:
            # Covers both an actual VAD failure and a too-short voiced result.
            print(f"VAD failed for {full_path}")
            continue

        accumulated.append(voiced)
        total_len += voiced.size(-1)
        if total_len >= max_len:
            break

    # `not accumulated` guards torch.cat against an empty list, which would
    # otherwise raise when nothing survived VAD and max_len is 0.
    if not accumulated or total_len < max_len:
        print(f"VAD is too short for {current_path}")
        return None

    merged_voiced = torch.cat(accumulated, dim=-1)
    start = random.randint(0, merged_voiced.size(-1) - max_len)
    return merged_voiced[..., start:start + max_len]

def generate_benign_samples(waveform, sr):
    """Build the benign (content-preserving) variants of ``waveform``.

    Args:
        waveform: Audio passed unchanged to each ``AudioProcessor`` transform.
        sr: Sample rate forwarded to each transform.

    Returns:
        A four-element list with the outputs of, in order,
        ``benign_resample``, ``benign_compression``, ``benign_reencode`` and
        ``benign_noise_suppression``.
    """
    resampled = AudioProcessor.benign_resample(waveform, sr)
    compressed = AudioProcessor.benign_compression(waveform, sr)
    reencoded = AudioProcessor.benign_reencode(waveform, sr)
    denoised = AudioProcessor.benign_noise_suppression(waveform, sr)
    return [resampled, compressed, reencoded, denoised]
    
def generate_malicious_samples(waveform, sr, audio_path):
    """Build the malicious (content-altering) variants of ``waveform``.

    For each ratio in (0.1, 0.3, 0.5) a donor segment lasting that fraction of
    the clip is loaded from other recordings of the same speaker; these
    segments feed the splice and substitute transforms.

    Args:
        waveform: Audio passed unchanged to each ``AudioProcessor`` transform;
            its last dimension is treated as the sample axis.
        sr: Sample rate of ``waveform``. It is used both to derive the clip
            duration and to load the donor segments.
        audio_path: Path of the source recording, used to locate the speaker's
            other recordings.

    Returns:
        A 14-element list in this fixed order: delete x3, splice x3,
        silence x3, substitute x3 (each at ratios 0.1/0.3/0.5), reorder,
        voice conversion. ``None`` if any donor segment could not be loaded.
    """
    ratios = (0.1, 0.3, 0.5)
    # Clip length in seconds, derived from the rate the caller passed in
    # rather than the module-level constant.
    duration = waveform.size(-1) / sr

    segments = []
    for ratio in ratios:
        segment = load_random_segment_from_same_speaker(audio_path, sr, max_duration=ratio * duration)
        if segment is None:
            # Stop at the first failure: the remaining loads are costly
            # (directory scan, decoding, VAD) and their results would be discarded.
            print(f"Failed to load segments for {audio_path}")
            return None
        segments.append(segment)

    samples = [AudioProcessor.malicious_delete(waveform, sr, ratio=ratio) for ratio in ratios]
    samples += [AudioProcessor.malicious_splice(waveform, sr, segment) for segment in segments]
    samples += [AudioProcessor.malicious_silence(waveform, sr, ratio=ratio) for ratio in ratios]
    samples += [AudioProcessor.malicious_substitute(waveform, sr, segment) for segment in segments]
    samples.append(AudioProcessor.malicious_reorder(waveform, sr))
    samples.append(AudioProcessor.malicious_voice_conversion(waveform, sr))
    return samples
        
def process_subset(rank, audio_files, save_dir, device):
    """Extract wav2vec2 features for one worker's share of the audio files.

    For every file the anchor waveform, its benign variants and its malicious
    variants are each passed through ``facebook/wav2vec2-base``; the resulting
    ``last_hidden_state`` tensors are moved to CPU and written to
    ``<save_dir>/<speaker>_<session>_<file>.pt``.

    Files shorter than ``MIN_LEN`` samples, and files for which malicious
    sample generation fails, are skipped with a printed message. Files longer
    than ``MAX_LEN`` samples are truncated.

    Args:
        rank: Worker index (not used inside this function).
        audio_files: Paths of the ``.wav`` files this worker should process.
        save_dir: Directory the ``.pt`` files are written to.
        device: CUDA device the model runs on, e.g. ``"cuda:0"``.
    """
    torch.cuda.set_device(device)
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
    model.eval()
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    for audio_path in tqdm(audio_files, desc=f"GPU{device}"):
        waveform, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
        # NOTE(review): squeeze(0) only drops the leading axis when it has size 1.
        # A multi-channel file would keep its leading axis, so size(0) below would
        # be the channel count and the file would be skipped as "too short".
        # Presumably inputs are mono - confirm against the dataset.
        waveform = waveform.squeeze(0)

        if waveform.size(0) > MAX_LEN:
            waveform = waveform[:MAX_LEN]
        if waveform.size(0) < MIN_LEN:
            print(f"[SKIP] {audio_path} too short: {waveform.size(0)/SAMPLE_RATE:.2f}s")
            continue

        benign = generate_benign_samples(waveform.unsqueeze(0), SAMPLE_RATE)
        malicious = generate_malicious_samples(waveform.unsqueeze(0), SAMPLE_RATE, audio_path)
        if malicious is None:
            print(f"Malicious sample generation failed for {audio_path}")
            continue

        # Order matters: anchor first, then benign, then malicious. The slices
        # used when saving rely on it.
        all_samples = [waveform] + [b.squeeze(0) for b in benign] + [m.squeeze(0) for m in malicious]

        processed_data = []
        with torch.no_grad():
            for sample in all_samples:
                inputs = processor(sample.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=False, return_attention_mask=False)
                hidden_states = model(inputs.input_values.to(device)).last_hidden_state.squeeze(0)
                processed_data.append({"hidden_states": hidden_states.cpu()})

        speaker_id = os.path.basename(os.path.dirname(os.path.dirname(audio_path)))     # id10282
        session_id = os.path.basename(os.path.dirname(audio_path))                      # qkZNuvX1UNo
        # splitext strips only the extension; str.replace would also remove a
        # '.wav' occurring elsewhere in the file name.
        file_id = os.path.splitext(os.path.basename(audio_path))[0]                     # 00001
        base_name = f"{speaker_id}_{session_id}_{file_id}"                              # id10282_qkZNuvX1UNo_00001

        # Derive the split point from the benign count instead of hard-coding
        # indices, so the grouping stays correct if either generator changes
        # how many samples it returns.
        first_negative = 1 + len(benign)
        torch.save({
            "anchor": processed_data[0],
            "positives": processed_data[1:first_negative],
            "negatives": processed_data[first_negative:],
        }, os.path.join(save_dir, f"{base_name}.pt"))

def run_worker(rank, save_dir, chunks):
    """Entry point for one spawned worker process.

    Selects this worker's file list from ``chunks`` by ``rank`` and processes
    it on the CUDA device with the same index.

    Args:
        rank: Worker index; also the CUDA device ordinal.
        save_dir: Directory the feature files are written to.
        chunks: Sequence of per-worker file lists, indexed by rank.
    """
    process_subset(rank, chunks[rank], save_dir, f"cuda:{rank}")

def preprocess_and_save_parallel(data_dir, save_dir, num_gpus=4, num_samples=5000):
    """Sample audio files under ``data_dir`` and preprocess them on several GPUs.

    All ``.wav`` files below ``data_dir`` are collected recursively, up to
    ``num_samples`` of them are drawn at random, and the selection is dealt
    round-robin to worker processes (one per GPU) started with the ``spawn``
    method.

    Args:
        data_dir: Root directory searched recursively for ``.wav`` files.
        save_dir: Output directory; created if it does not exist.
        num_gpus: Maximum number of worker processes / CUDA devices to use.
        num_samples: Maximum number of files to process.

    Raises:
        ValueError: If ``num_gpus`` is less than 1.
    """
    if num_gpus < 1:
        # Zero workers would otherwise exit silently without processing anything.
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")

    os.makedirs(save_dir, exist_ok=True)
    audio_files = []
    for dirpath, _, filenames in os.walk(data_dir):
        audio_files.extend(
            os.path.join(dirpath, filename)
            for filename in filenames
            if filename.endswith('.wav')
        )

    print(f"Found {len(audio_files)} audio files.")
    audio_files = random.sample(audio_files, min(num_samples, len(audio_files)))
    if not audio_files:
        # Nothing to do: avoid spawning workers that would only load the model.
        return

    # Never start more workers than there are files, so none gets an empty chunk.
    num_workers = min(num_gpus, len(audio_files))
    chunks = [audio_files[i::num_workers] for i in range(num_workers)]

    mp.set_start_method("spawn", force=True)
    mp.spawn(run_worker, args=(save_dir, chunks), nprocs=num_workers)

if __name__ == "__main__":
    # Command-line entry point: parse the options and start the parallel run.
    cli = argparse.ArgumentParser(description="Preprocess and save audio files in parallel.")
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--data_dir", str, "../vox_train/", "Path to the directory containing the audio files."),
        ("--save_dir", str, "../preprocessed_vox_train", "Path to save the preprocessed data."),
        ("--num_gpus", int, 4, "Number of GPUs to use for parallel processing."),
        ("--num_samples", int, 5000, "Number of samples to use for training."),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    options = cli.parse_args()

    preprocess_and_save_parallel(
        data_dir=options.data_dir,
        save_dir=options.save_dir,
        num_gpus=options.num_gpus,
        num_samples=options.num_samples,
    )


