{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchaudio\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from audioseal import AudioSeal\n",
    "from audio_augmentation import AudioProcessor\n",
    "import gc\n",
    "from transformers import Wav2Vec2Processor, Wav2Vec2Model\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from model import FingerprintGenerator\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "audioseal_gen = AudioSeal.load_generator(\"audioseal_wm_16bits\").to(device)\n",
    "audioseal_det = AudioSeal.load_detector(\"audioseal_detector_16bits\").to(device)\n",
    "\n",
    "model = FingerprintGenerator().to(device)\n",
    "chkpt_name = \"./chkpt-t0.05/SpeeCheck_best.pth\"  \n",
    "test_dir = \"./protected_vox\"\n",
    "log_dir = f\"./logs/epoch{chkpt_name.split('_')[-1].split('.')[0]}_{os.path.basename(os.path.dirname(chkpt_name)).split('-')[1]}_{test_dir.split('_')[-1]}_{timestamp}/\"\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "\n",
    "checkpoint = torch.load(chkpt_name, map_location=device)\n",
    "state_dict = {k.replace(\"module.\", \"\"): v for k, v in checkpoint.items()}\n",
    "model.load_state_dict(state_dict)\n",
    "print(model)\n",
    "model.eval()\n",
    "\n",
    "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-base\")\n",
    "wav2vec_model = Wav2Vec2Model.from_pretrained(\"facebook/wav2vec2-base\").to(device)\n",
    "wav2vec_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_RATE = 16000\n",
    "MAX_DURATION = 20.0\n",
    "MIN_DURATION = 2.0\n",
    "MAX_LEN = int(SAMPLE_RATE * MAX_DURATION)\n",
    "MIN_LEN = int(SAMPLE_RATE * MIN_DURATION)\n",
    "\n",
    "def safe_vad(waveform, sample_rate=SAMPLE_RATE):\n",
    "    try:\n",
    "        vad = torchaudio.transforms.Vad(sample_rate=sample_rate)\n",
    "        voiced = vad(waveform)\n",
    "        return voiced if voiced.numel() > 0 else None\n",
    "    except Exception as e:\n",
    "        print(\"VAD error:\", e)\n",
    "        return None\n",
    "\n",
    "\n",
    "def load_random_segment_from_same_speaker(current_path, sample_rate=16000, max_duration=3.0):\n",
    "    filename = os.path.basename(current_path)\n",
    "    speaker_id = filename.split(\"_\")[0]        # id10293\n",
    "    if test_dir == \"./protected_vox\":\n",
    "        speaker_dir = os.path.join(\"/data/voxceleb1\", \"test\", speaker_id)\n",
    "    elif test_dir == \"./protected_libri\":\n",
    "        speaker_dir = os.path.join(\"/data/LibriSpeech\", \"dev-clean-wav\", speaker_id)\n",
    "    else:\n",
    "        raise ValueError(f\"Invalid test_dir: {test_dir}\")\n",
    "    candidates = []\n",
    "\n",
    "    for session_name in os.listdir(speaker_dir):\n",
    "        session_path = os.path.join(speaker_dir, session_name)\n",
    "        if not os.path.isdir(session_path):\n",
    "            continue\n",
    "        for f in os.listdir(session_path):\n",
    "            if f.endswith('.wav'):\n",
    "                full_path = os.path.join(session_path, f)\n",
    "                if os.path.abspath(full_path) != os.path.abspath(current_path): \n",
    "                    candidates.append(full_path)\n",
    "\n",
    "    if not candidates:\n",
    "        return None\n",
    "\n",
    "    random.shuffle(candidates)\n",
    "    max_len = int(max_duration * sample_rate)\n",
    "    min_len = int(1.0 *sample_rate)\n",
    "    accumulated = []\n",
    "    total_len = 0\n",
    "    for full_path in candidates: \n",
    "        waveform, sr = torchaudio.load(full_path)\n",
    "        if sr != sample_rate:\n",
    "            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)\n",
    "\n",
    "        waveform = waveform.to(torch.float32)\n",
    "        voiced = safe_vad(waveform)\n",
    "        if voiced is None or voiced.size(-1) < min_len:\n",
    "            print(f\"VAD failed for {full_path}\")\n",
    "            continue\n",
    "        \n",
    "        accumulated.append(voiced)\n",
    "        total_len += voiced.size(-1)\n",
    "        if total_len >= max_len:\n",
    "            break\n",
    "\n",
    "    if total_len < max_len:\n",
    "        print(f\"VAD is too short for {current_path}\")\n",
    "        return None\n",
    "\n",
    "    merged_voiced = torch.cat(accumulated, dim=-1)\n",
    "    start = random.randint(0, merged_voiced.size(-1) - max_len)\n",
    "    return merged_voiced[..., start:start + max_len]\n",
    "\n",
    "\n",
    "def generate_benign(waveform, sr):\n",
    "    return [\n",
    "        AudioProcessor.benign_resample(waveform, sr),\n",
    "        AudioProcessor.benign_compression(waveform, sr),\n",
    "        AudioProcessor.benign_reencode(waveform, sr),\n",
    "        AudioProcessor.benign_noise_suppression(waveform, sr)\n",
    "    ]\n",
    "    \n",
    "\n",
    "def generate_malicious(waveform, sr, audio_path):\n",
    "    T = waveform.size(-1)/SAMPLE_RATE\n",
    "    segments_cache = {\n",
    "        1.0: load_random_segment_from_same_speaker(audio_path, sr, max_duration=0.1*T),\n",
    "        2.0: load_random_segment_from_same_speaker(audio_path, sr, max_duration=0.3*T),\n",
    "        3.0: load_random_segment_from_same_speaker(audio_path, sr, max_duration=0.5*T),\n",
    "    }\n",
    "    if (segments_cache[1.0] is None) or (segments_cache[2.0] is None) or (segments_cache[3.0] is None):\n",
    "        print(f\"Failed to load segments for {audio_path}\")\n",
    "        return None\n",
    "    else:\n",
    "        return [\n",
    "            AudioProcessor.malicious_delete(waveform, sr, ratio=0.1),\n",
    "            AudioProcessor.malicious_delete(waveform, sr, ratio=0.3),\n",
    "            AudioProcessor.malicious_delete(waveform, sr, ratio=0.5),\n",
    "            AudioProcessor.malicious_splice(waveform, sr, segments_cache[1.0]),\n",
    "            AudioProcessor.malicious_splice(waveform, sr, segments_cache[2.0]),\n",
    "            AudioProcessor.malicious_splice(waveform, sr, segments_cache[3.0]),\n",
    "            AudioProcessor.malicious_silence(waveform, sr, ratio=0.1),\n",
    "            AudioProcessor.malicious_silence(waveform, sr, ratio=0.3),\n",
    "            AudioProcessor.malicious_silence(waveform, sr, ratio=0.5),\n",
    "            AudioProcessor.malicious_substitute(waveform, sr, segments_cache[1.0]),\n",
    "            AudioProcessor.malicious_substitute(waveform, sr, segments_cache[2.0]),\n",
    "            AudioProcessor.malicious_substitute(waveform, sr, segments_cache[3.0]),\n",
    "            AudioProcessor.malicious_reorder(waveform, sr),\n",
    "            AudioProcessor.malicious_voice_conversion(waveform, sr),\n",
    "        ]\n",
    "\n",
    "def extract_watermark(audio_tensor):\n",
    "    if audio_tensor.ndim == 1:\n",
    "        audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)\n",
    "    if audio_tensor.ndim == 2:\n",
    "        audio_tensor = audio_tensor.unsqueeze(1)\n",
    "        \n",
    "    T = audio_tensor.shape[-1]\n",
    "    segment_length = T // 16\n",
    "    segments = [audio_tensor[..., i * segment_length:(i + 1) * segment_length] for i in range(16)]\n",
    "    extracted_bits = []\n",
    "    for seg in segments:\n",
    "        seg = seg.to(device)\n",
    "        with torch.no_grad():\n",
    "            _, detected_msg = audioseal_det.detect_watermark(seg, sample_rate=16000, message_threshold=0.5)\n",
    "        detected_msg = detected_msg.detach().cpu() \n",
    "        extracted_bits.append(detected_msg)\n",
    "    return torch.cat(extracted_bits, dim=1).squeeze() \n",
    "\n",
    "def extract_fingerprint(audio_tensor):\n",
    "    audio_1d = audio_tensor.squeeze().cpu().numpy()\n",
    "    with torch.no_grad():\n",
    "        inputs = processor(audio_1d, sampling_rate=SAMPLE_RATE, return_tensors=\"pt\", padding=False, return_attention_mask=False)\n",
    "        hidden_states = wav2vec_model(inputs.input_values.to(device)).last_hidden_state.squeeze(0)\n",
    "        hidden_states = hidden_states.unsqueeze(0)  # [1, T, D]\n",
    "        T = hidden_states.size(1)\n",
    "        attention_mask = torch.ones(1, T).to(device)\n",
    "        hash_vector = model(hidden_states, attention_mask)\n",
    "        hash_bits = torch.sign(hash_vector)\n",
    "        hash_bits = (hash_bits > 0).long().squeeze()\n",
    "    return hash_bits.detach().cpu() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benign_types = [\"resample\", \"compression\", \"reencode\", \"noise_suppression\"]\n",
    "malicious_types = [\n",
    "    \"delete_minor\", \"delete_moderate\", \"delete_severe\",\n",
    "    \"splice_minor\", \"splice_moderate\", \"splice_severe\",\n",
    "    \"silence_minor\", \"silence_moderate\", \"silence_severe\",\n",
    "    \"substitute_minor\", \"substitute_moderate\", \"substitute_severe\",\n",
    "    \"reorder\", \"voice_conversion\"\n",
    "]\n",
    "benign_hamm_dict = {k: [] for k in benign_types}\n",
    "malicious_hamm_dict = {k: [] for k in malicious_types}\n",
    "\n",
    "for fname in tqdm(os.listdir(test_dir), desc=\"Evaluating\"):\n",
    "    audio_path = os.path.join(test_dir, fname)\n",
    "    wav, sr = torchaudio.load(audio_path)\n",
    "    wm_bits = extract_watermark(wav.unsqueeze(0))\n",
    "    fp_bits = extract_fingerprint(wav)\n",
    "\n",
    "    benign_samples = generate_benign(wav.unsqueeze(0), SAMPLE_RATE)\n",
    "    for op_name, w in zip(benign_types, benign_samples):\n",
    "        wm = extract_watermark(w)\n",
    "        fp = extract_fingerprint(w)\n",
    "        hamm = (wm != fp).sum().item()\n",
    "        benign_hamm_dict[op_name].append(hamm)\n",
    "\n",
    "    mal_samples = generate_malicious(wav, SAMPLE_RATE, audio_path)\n",
    "\n",
    "    for op_name_mal, w_mal in zip(malicious_types, mal_samples):\n",
    "        wm = extract_watermark(w_mal)\n",
    "        fp = extract_fingerprint(w_mal.squeeze())\n",
    "        hamm = (wm != fp).sum().item()\n",
    "        malicious_hamm_dict[op_name_mal].append(hamm)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "    gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "results = {}\n",
    "for op in benign_types:\n",
    "    vals = benign_hamm_dict[op]\n",
    "    if vals:\n",
    "        results[op] = {\n",
    "            \"mean\": float(np.mean(vals))\n",
    "        }\n",
    "for op in malicious_types:\n",
    "    vals = malicious_hamm_dict[op]\n",
    "    if vals:\n",
    "        results[op] = {\n",
    "            \"mean\": float(np.mean(vals))\n",
    "        }\n",
    "\n",
    "with open(f'{log_dir}/hamming_distances_per_operation.json', 'w') as f:\n",
    "    json.dump(results, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "pos_hamm_all = torch.tensor([x for v in benign_hamm_dict.values() for x in v])\n",
    "neg_hamm_all = torch.tensor([x for v in malicious_hamm_dict.values() for x in v])\n",
    "\n",
    "labels = np.concatenate([\n",
    "    np.ones_like(pos_hamm_all),\n",
    "    np.zeros_like(neg_hamm_all)\n",
    "])\n",
    "scores = np.concatenate([\n",
    "    -pos_hamm_all,\n",
    "    -neg_hamm_all\n",
    "])\n",
    "fpr, tpr, thresholds = roc_curve(labels, scores)\n",
    "best_threshold = thresholds[np.argmax(tpr - fpr)]\n",
    "best_hamm_threshold = -best_threshold\n",
    "print(\"best_hamm_threshold: \", best_hamm_threshold)\n",
    "\n",
    "# median\n",
    "pos_median = np.median(pos_hamm_all)\n",
    "neg_median = np.median(neg_hamm_all)\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "bins = np.arange(0, 255, 3)\n",
    "plt.hist(pos_hamm_all, bins=bins, alpha=0.8, density=True, label='Benign', color='#4F9DD9', edgecolor='black', linewidth=0.3)\n",
    "plt.hist(neg_hamm_all, bins=bins, alpha=0.8, density=True, label='Malicious', color='#F5A15A', edgecolor='black', linewidth=0.3)\n",
    "plt.axvline(pos_median, color='#4F9DD9', linestyle='--', linewidth=1.0)\n",
    "plt.axvline(neg_median, color='#F5A15A', linestyle='--', linewidth=1.0)\n",
    "\n",
    "plt.annotate(\n",
    "    \"\", xy=(neg_median, 0.05), xytext=(pos_median, 0.05),\n",
    "    arrowprops=dict(arrowstyle='<->', color='black', lw=0.5)\n",
    ")\n",
    "plt.text((pos_median + neg_median)/2, 0.05, f\"{neg_median - pos_median:.2f}\",\n",
    "         ha='center', va='bottom', fontsize=9)\n",
    "\n",
    "\n",
    "plt.xlabel(\"Hamming Distance\", fontsize=14)\n",
    "plt.ylabel(\"Density\", fontsize=14)\n",
    "plt.legend(loc='upper right', frameon=True, fontsize=12)\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"./figure/hamm_distribution.pdf\", format='pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, roc_curve\n",
    "def evaluate_hamming_benign(pos_hamm, neg_hamm, threshold):\n",
    "    labels = np.concatenate([\n",
    "        np.ones_like(pos_hamm),\n",
    "        np.zeros_like(neg_hamm)\n",
    "    ])\n",
    "    scores = np.concatenate([\n",
    "        -pos_hamm,\n",
    "        -neg_hamm\n",
    "    ])\n",
    "    preds = (scores > -threshold).astype(int)\n",
    "\n",
    "    acc = accuracy_score(labels, preds)\n",
    "    auc = roc_auc_score(labels, scores)\n",
    "    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()\n",
    "    tpr = tp / (tp + fn)\n",
    "    fpr = fp / (fp + tn)\n",
    "\n",
    "    fpr_curve, tpr_curve, thresholds = roc_curve(labels, scores)\n",
    "    fnr_curve = 1 - tpr_curve\n",
    "    eer_idx = np.nanargmin(np.abs(fpr_curve - fnr_curve))\n",
    "    eer = (fpr_curve[eer_idx] + fnr_curve[eer_idx]) / 2\n",
    "    \n",
    "    return {\n",
    "        \"Acc\": acc,\n",
    "        \"TPR\": tpr,\n",
    "        \"FPR\": fpr,\n",
    "        \"AUC\": auc,\n",
    "        \"EER\": eer,\n",
    "    }\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.ticker as ticker\n",
    "import json\n",
    "\n",
    "all_metrics = {}\n",
    "pos_hamm_all = torch.tensor([x for v in benign_hamm_dict.values() for x in v])\n",
    "neg_hamm_all = torch.tensor([x for v in malicious_hamm_dict.values() for x in v])\n",
    "\n",
    "for op_name, values in benign_hamm_dict.items():\n",
    "    pos_array = torch.tensor(values).numpy()\n",
    "    neg_samples = np.random.choice(neg_hamm_all, size=len(pos_array), replace=False)\n",
    "    all_metrics[op_name] = evaluate_hamming_benign(pos_array, neg_samples, best_hamm_threshold)\n",
    "\n",
    "json_metrics = {op: {k: f\"{float(v)*100:.2f}\" for k,v in metrics.items()} \n",
    "               for op, metrics in all_metrics.items()}\n",
    "with open(f'{log_dir}/benign_evaluation.json', 'w') as f:\n",
    "    json.dump(json_metrics, f, indent=2)\n",
    "\n",
    "print(\"\\n=== Evaluation Metrics for Benign Operations ===\")\n",
    "for op_name, metrics in all_metrics.items():\n",
    "    print(f\"\\n{op_name.upper()}\")\n",
    "    for k, v in metrics.items():\n",
    "        print(f\"  {k}: {v:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix\n",
    "def evaluate_hamming_malicious(pos_hamm, neg_hamm, threshold):\n",
    "    labels = np.concatenate([\n",
    "        np.ones_like(pos_hamm),\n",
    "        np.zeros_like(neg_hamm)\n",
    "    ])\n",
    "    scores = np.concatenate([\n",
    "        pos_hamm,\n",
    "        neg_hamm\n",
    "    ])\n",
    "    preds = (scores >= threshold).astype(int)  \n",
    "\n",
    "    acc = accuracy_score(labels, preds)\n",
    "    auc = roc_auc_score(labels, scores)\n",
    "    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()\n",
    "    tpr = tp / (tp + fn)\n",
    "    fpr = fp / (fp + tn)\n",
    "\n",
    "    fpr_curve, tpr_curve, thresholds = roc_curve(labels, scores)\n",
    "    fnr_curve = 1 - tpr_curve\n",
    "    eer_idx = np.nanargmin(np.abs(fpr_curve - fnr_curve))\n",
    "    eer = (fpr_curve[eer_idx] + fnr_curve[eer_idx]) / 2\n",
    "    \n",
    "    return {\n",
    "        \"Acc\": acc,\n",
    "        \"TPR\": tpr,\n",
    "        \"FPR\": fpr,\n",
    "        \"AUC\": auc,\n",
    "        \"EER\": eer,\n",
    "    }\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_metrics = {}\n",
    "\n",
    "for op_name, values in malicious_hamm_dict.items():\n",
    "    neg_array = torch.tensor(values).numpy()\n",
    "    pos_samples = np.random.choice(pos_hamm_all, size=len(neg_array), replace=False)\n",
    "    all_metrics[op_name] = evaluate_hamming_malicious(neg_array, pos_samples, best_hamm_threshold)\n",
    "\n",
    "\n",
    "# Save metrics to JSON file\n",
    "json_metrics = {op: {k: f\"{float(v)*100:.2f}\" for k,v in metrics.items()} \n",
    "               for op, metrics in all_metrics.items()}\n",
    "with open(f'{log_dir}/malicious_evaluation.json', 'w') as f:\n",
    "    json.dump(json_metrics, f, indent=2)\n",
    "\n",
    "print(\"\\n=== Evaluation Metrics for Malicious Operations ===\")\n",
    "for op_name, metrics in all_metrics.items():\n",
    "    print(f\"\\n{op_name.upper()}\")\n",
    "    for k, v in metrics.items():\n",
    "        print(f\"  {k}: {v:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "neg_hamm_all_sampled = np.random.choice(neg_hamm_all, size=pos_hamm_all.shape[0], replace=False)\n",
    "overall_metrics = evaluate_hamming_benign(pos_hamm_all, neg_hamm_all_sampled, best_hamm_threshold)\n",
    "\n",
    "\n",
    "json_metrics = {k: f\"{float(v)*100:.2f}\" for k,v in overall_metrics.items()}\n",
    "with open(f'{log_dir}/overall_evaluation.json', 'w') as f:\n",
    "    json.dump(json_metrics, f, indent=2)\n",
    "print(\"\\n=== Overall Evaluation ===\")\n",
    "for k, v in overall_metrics.items():\n",
    "    print(f\"  {k}: {v:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SpeechVerifier",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
