import os
import torch
import torchaudio
import gc
from tqdm import tqdm
from audioseal import AudioSeal
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from model import FingerprintGenerator
import argparse
import json

def extract_watermark(audio_tensor, audioseal_det, device, sample_rate, num_segments=16):
    """Decode watermark message bits from consecutive, equal-length segments of audio.

    The waveform is cut into ``num_segments`` consecutive segments of
    ``T // num_segments`` samples each. The trailing ``T % num_segments``
    samples are ignored. Each segment is passed to
    ``audioseal_det.detect_watermark``, and the returned messages are
    concatenated along dim 1 in segment order.

    Args:
        audio_tensor: waveform tensor. A 1-D ``(T,)`` input is reshaped to
            ``(1, 1, T)``. A 2-D ``(B, T)`` input is reshaped to ``(B, 1, T)``.
            A 3-D input is used as-is.
        audioseal_det: detector exposing
            ``detect_watermark(x, sample_rate=..., message_threshold=...)``,
            which returns ``(detection_result, message)``. The message is
            assumed to be at least 2-D, presumably ``(batch, nbits)``.
        device: device the audio is moved to before detection.
        sample_rate: sample rate forwarded to the detector.
        num_segments: number of segments to decode (default 16).

    Returns:
        CPU tensor of the concatenated message bits with size-1 dims squeezed.
        For a single-item batch this is ``(num_segments * nbits,)``.

    Raises:
        ValueError: if the input rank is not 1, 2 or 3; if ``num_segments``
            is < 1; or if the audio has fewer samples than ``num_segments``,
            which would make every segment empty.
    """
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
    elif audio_tensor.ndim == 2:
        audio_tensor = audio_tensor.unsqueeze(1)
    elif audio_tensor.ndim != 3:
        raise ValueError(f"expected a 1-D, 2-D or 3-D audio tensor, got {audio_tensor.ndim}-D")

    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")

    total_samples = audio_tensor.shape[-1]
    segment_length = total_samples // num_segments
    if segment_length == 0:
        # Zero-length segments would be handed to the detector and fail obscurely.
        raise ValueError(
            f"audio has {total_samples} samples, fewer than num_segments={num_segments}"
        )

    # Move once instead of once per segment; slicing afterwards yields views.
    audio_tensor = audio_tensor.to(device)
    extracted_bits = []
    with torch.no_grad():
        for i in range(num_segments):
            seg = audio_tensor[..., i * segment_length:(i + 1) * segment_length]
            _, detected_msg = audioseal_det.detect_watermark(
                seg, sample_rate=sample_rate, message_threshold=0.5
            )
            extracted_bits.append(detected_msg.detach().cpu())
    return torch.cat(extracted_bits, dim=1).squeeze()

def extract_fingerprint(audio_tensor, processor, wav2vec_model, model, device, sample_rate):
    """Compute the binary fingerprint of a waveform.

    Processing steps:

    1. Squeeze size-1 dims from the waveform and convert it to a NumPy array.
    2. Run it through ``processor``.
    3. Encode the result with ``wav2vec_model``.
    4. Feed the last hidden state, together with an all-ones attention mask,
       to ``model``.
    5. Set each output coordinate to 1 when strictly positive, else 0.

    Returns:
        CPU ``torch.long`` tensor of bits with size-1 dims squeezed.
    """
    samples = audio_tensor.squeeze().cpu().numpy()
    with torch.no_grad():
        features = processor(
            samples,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding=False,
            return_attention_mask=False,
        )
        encoded = wav2vec_model(features.input_values.to(device))
        # Drop the batch axis, then put a leading axis back: [1, T, D] for a single clip.
        hidden = encoded.last_hidden_state.squeeze(0)[None]
        mask = torch.ones(1, hidden.size(1), device=device)
        logits = model(hidden, mask)
        bits = torch.gt(logits, 0).long().squeeze()
    return bits.detach().cpu()

def bits_to_hex(bit_tensor):
    """Render a sequence of 0/1 bits as an upper-case, fixed-width hex string.

    The bits are flattened in row-major order and read most-significant first.
    The result is zero-padded to ``ceil(n_bits / 4)`` digits, so leading zero
    bits are kept. For example, 256 bits always give 64 hex characters, which
    keeps stored fingerprints/watermarks comparable by length.

    Args:
        bit_tensor: tensor (or array) of 0/1 values of any shape, including
            0-d. Values are converted with ``int()`` first.

    Returns:
        Upper-case hex string without a ``0x`` prefix, or ``""`` when there
        are no bits.

    Raises:
        ValueError: if any value is not 0 or 1 after ``int()`` conversion.
    """
    # reshape(-1) handles 0-d tensors; tolist() works on any device (unlike .numpy()).
    bits = [int(b) for b in bit_tensor.reshape(-1).tolist()]
    if not bits:
        return ""
    if any(b not in (0, 1) for b in bits):
        raise ValueError("bit_tensor must contain only 0/1 values")
    value = int("".join(str(b) for b in bits), 2)
    width = (len(bits) + 3) // 4
    return format(value, f"0{width}X")

def _load_fingerprint_model(checkpoint_path, device):
    """Build a FingerprintGenerator on *device*, load weights from *checkpoint_path*, set eval mode."""
    model = FingerprintGenerator().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The checkpoint is used directly as a state dict. Keys may carry a leading
    # "module." (presumably from a DataParallel/DDP wrapper). Strip only that
    # prefix so nested names containing "module." are left intact.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _load_mono(audio_path, target_sr):
    """Load *audio_path*, average multi-channel audio to one channel, resample to *target_sr*."""
    wav, sr = torchaudio.load(audio_path)  # (channels, frames)
    if wav.size(0) > 1:
        # Downstream watermark/fingerprint extraction expects a single channel.
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav


def main(args):
    """Compare each .wav file's watermark bits with its recomputed fingerprint bits.

    Files ending in ``.wav`` (case-insensitive) in ``args.test_dir`` are
    processed in sorted order so runs are reproducible. For each file:

    * extract the AudioSeal watermark bits;
    * compute the fingerprint bits with the model from ``args.checkpoint``;
    * report their Hamming distance;
    * label the file "Tampered" when the distance exceeds ``args.threshold``,
      and "Authentic" otherwise.

    All results are written to ``<test_dir>/results.json``.

    Args:
        args: parsed CLI namespace with ``test_dir``, ``checkpoint``,
            ``threshold`` and ``sample_rate`` attributes.

    Raises:
        ValueError: if the watermark and fingerprint bit tensors differ in
            shape and so cannot be compared element-wise.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Only the detector is needed for evaluation; no watermark is embedded here.
    audioseal_det = AudioSeal.load_detector("audioseal_detector_16bits").to(device)
    model = _load_fingerprint_model(args.checkpoint, device)

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
    wav2vec_model.eval()

    wav_files = sorted(
        fname for fname in os.listdir(args.test_dir) if fname.lower().endswith(".wav")
    )
    results = []

    for fname in tqdm(wav_files, desc="Evaluating"):
        wav = _load_mono(os.path.join(args.test_dir, fname), args.sample_rate)

        wm_bits = extract_watermark(wav.unsqueeze(0), audioseal_det, device, args.sample_rate)
        fp_bits = extract_fingerprint(wav, processor, wav2vec_model, model, device, args.sample_rate)
        if wm_bits.shape != fp_bits.shape:
            # Without this check, mismatched shapes could silently broadcast.
            raise ValueError(
                f"{fname}: watermark has shape {tuple(wm_bits.shape)} but fingerprint "
                f"has shape {tuple(fp_bits.shape)}; cannot compute Hamming distance"
            )

        hamm = (wm_bits != fp_bits).sum().item()
        verdict = "Tampered" if hamm > args.threshold else "Authentic"

        # tqdm.write keeps per-file lines from corrupting the progress bar.
        tqdm.write(f"{fname}: Hamming distance = {hamm} | {verdict}")

        results.append({
            "fname": fname,
            "fingerprint": bits_to_hex(fp_bits),
            "watermark": bits_to_hex(wm_bits),
            "hamming_distance": hamm,
            "verdict": verdict,
        })

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    out_path = os.path.join(args.test_dir, "results.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\n✅ Results saved to {out_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    cli = argparse.ArgumentParser(
        description="Evaluate Hamming distance between watermark and fingerprint",
    )
    cli.add_argument("--test_dir", type=str, required=True,
                     help="Directory containing test audio files")
    cli.add_argument("--checkpoint", type=str, required=True,
                     help="Path to fingerprint model checkpoint (.pth)")
    cli.add_argument("--threshold", type=int, default=50,
                     help="Hamming distance threshold")
    cli.add_argument("--sample_rate", type=int, default=16000,
                     help="Audio sample rate")
    args = cli.parse_args()
    main(args)


