import os
import torch
import torchaudio
import soundfile as sf
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from audioseal import AudioSeal
import argparse
from model import FingerprintGenerator
import random


SAMPLE_RATE = 16000  # wav2vec2 features and AudioSeal are both run on 16 kHz audio here
NUM_SEGMENTS = 16  # the waveform is cut into this many slices, one message per slice
BITS_PER_SEGMENT = 16  # message size for the "audioseal_wm_16bits" generator


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

    The prefix is typically added when a model is trained wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel``. Only a *leading* prefix is
    stripped: ``str.replace`` would also corrupt keys that merely contain
    ``module.`` elsewhere (e.g. ``submodule.weight`` -> ``subweight``).
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _find_wav_files(root):
    """Return all ``.wav`` paths under *root* (recursively), sorted for a stable order."""
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root)
        for filename in filenames
        if filename.endswith(".wav")
    )


def _load_mono(wav_path, target_sr=SAMPLE_RATE):
    """Load *wav_path* as a 1-D float tensor sampled at *target_sr* Hz.

    Channels are averaged to mono (a mono file is returned unchanged), and the
    audio is resampled when the file's native rate differs from *target_sr*,
    because every downstream step assumes *target_sr*.
    """
    waveform, sr = torchaudio.load(wav_path)  # [channels, samples]
    waveform = waveform.mean(dim=0)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
    return waveform


def _compute_hash_bits(waveform, processor, wav2vec_model, model, device):
    """Fingerprint *waveform* and return it as a ``[1, n_bits]`` tensor of 0/1 longs.

    A bit is 1 exactly where the generator's output is strictly positive.
    """
    inputs = processor(
        waveform.numpy(),
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=False,
        return_attention_mask=False,
    )
    with torch.no_grad():
        hidden_states = wav2vec_model(inputs.input_values.to(device)).last_hidden_state  # [1, T, D]
        # Single unpadded clip, so every frame is valid.
        mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=device)  # [1, T]
        hash_vector = model(hidden_states, mask)
    return (hash_vector > 0).long()


def _embed_watermark(waveform, hash_bits, audioseal_gen, device):
    """Watermark slice ``i`` of *waveform* with bits ``[16*i, 16*(i+1))`` of *hash_bits*.

    Returns a 1-D numpy array with the same number of samples as *waveform*.
    """
    # tensor_split always yields exactly NUM_SEGMENTS near-equal slices covering
    # every sample; split(len // 16) produced an extra remainder chunk that was
    # never watermarked or written, truncating the output.
    segments = torch.tensor_split(waveform, NUM_SEGMENTS)
    watermarked_segments = []
    with torch.no_grad():  # inference only; avoid building an autograd graph
        for i, segment in enumerate(segments):
            start = i * BITS_PER_SEGMENT
            msg = hash_bits[0, start:start + BITS_PER_SEGMENT].unsqueeze(0).to(device)
            segment = segment.reshape(1, 1, -1).to(device)  # [1, 1, samples]
            wm = audioseal_gen.get_watermark(segment, message=msg, sample_rate=SAMPLE_RATE)
            watermarked_segments.append((segment + wm).cpu())
    return torch.cat(watermarked_segments, dim=-1).reshape(-1).numpy()


def _output_path(wav_path, test_dir, save_dir):
    """Map ``test_dir/a/b/c.wav`` to ``save_dir/a_b_c_watermarked.wav``.

    The relative path is flattened so all outputs land directly in *save_dir*.
    """
    rel_no_ext = os.path.splitext(os.path.relpath(wav_path, start=test_dir))[0]
    flat_name = rel_no_ext.replace(os.sep, "_")
    return os.path.join(save_dir, flat_name + "_watermarked.wav")


def main(args):
    """Fingerprint and watermark every ``.wav`` file under ``args.test_dir``.

    Each file is loaded as 16 kHz mono and hashed to a bit vector with
    wav2vec2 features and the ``FingerprintGenerator`` checkpoint. The bits
    are then embedded 16 at a time into 16 consecutive slices of the audio
    with the 16-bit AudioSeal generator. Results are written as 16 kHz WAV
    files into ``args.save_dir``.

    Args:
        args: namespace with ``save_dir``, ``test_dir`` and ``checkpoint`` paths.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.save_dir, exist_ok=True)

    model = FingerprintGenerator().to(device)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(_strip_module_prefix(checkpoint))
    model.eval()

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
    wav2vec_model.eval()

    audioseal_gen = AudioSeal.load_generator("audioseal_wm_16bits").to(device)

    for wav_path in tqdm(_find_wav_files(args.test_dir), desc="Processing audio"):
        waveform = _load_mono(wav_path)
        hash_bits = _compute_hash_bits(waveform, processor, wav2vec_model, model, device)
        watermarked_audio = _embed_watermark(waveform, hash_bits, audioseal_gen, device)
        save_path = _output_path(wav_path, args.test_dir, args.save_dir)
        sf.write(save_path, watermarked_audio, SAMPLE_RATE)

if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory string options.
    cli = argparse.ArgumentParser()
    for flag, help_text in (
        ("--save_dir", "Save directory"),
        ("--test_dir", "Test directory"),
        ("--checkpoint", "Checkpoint path"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    main(cli.parse_args())