{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:56:02.554818Z",
     "start_time": "2024-07-17T17:56:01.344555Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from model.RawNet3 import RawNet3_detect\n",
    "from model.RawNetBasicBlock import Bottle2neck\n",
    "import librosa\n",
    "from IPython.display import Audio, display\n",
    "import torch.nn.functional as F\n",
    "import soundfile as sf\n",
    "import torch.nn as nn\n",
    "from scipy.fft import fft, ifft  # For performing Fourier Transform\n",
    "import numpy as np\n",
    "import torchaudio\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.encoder_type ECA\n"
     ]
    }
   ],
   "source": [
    "# Build the RawNet3 detector. NOTE(review): these hyperparameters presumably must\n",
    "# match the configuration the checkpoint loaded in the next cell was trained with.\n",
    "model = RawNet3_detect(\n",
    "    encoder_type='ECA',\n",
    "    nOut=256,\n",
    "    sinc_stride=10,\n",
    "    log_sinc=True,\n",
    "    norm_sinc=True,\n",
    "    out_bn=True,\n",
    "    block=Bottle2neck,\n",
    "    model_scale=8,\n",
    "    context=True,\n",
    "    summed=True,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:56:02.607964Z",
     "start_time": "2024-07-17T17:56:02.555589Z"
    }
   },
   "id": "8daf348e3ffd02e6"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "<All keys matched successfully>"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): torch.load unpickles the file; only load trusted checkpoints\n",
    "# (newer PyTorch versions accept weights_only=True for plain state dicts).\n",
    "model.load_state_dict(\n",
    "    torch.load('../weights/epoch_13.pth', map_location=torch.device('cpu'))['model_state_dict']\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:56:02.673260Z",
     "start_time": "2024-07-17T17:56:02.605260Z"
    }
   },
   "id": "8389a65026601b23"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# Bounds that attack_pgd keeps X + delta within (via clamp).\n",
    "lower_limit = -1\n",
    "upper_limit = 1\n",
    "\n",
    "\n",
    "def clamp(X, lower_limit, upper_limit):\n",
    "    \"\"\"Clip ``X`` element-wise to ``[lower_limit, upper_limit]``.\n",
    "\n",
    "    Bounds may be tensors broadcastable to ``X`` (as used by attack_pgd) or\n",
    "    Python scalars. Scalars are converted to tensors first: ``torch.min(X, 1)``\n",
    "    would otherwise be read as a reduction over dim 1, not an element-wise min.\n",
    "    The upper bound is applied first, so where ``lower_limit > upper_limit``\n",
    "    the result equals ``lower_limit``.\n",
    "    \"\"\"\n",
    "    upper = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)\n",
    "    lower = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)\n",
    "    return torch.max(torch.min(X, upper), lower)\n",
    "\n",
    "\n",
    "def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,\n",
    "               norm, early_stop=False, mixup=False):\n",
    "    \"\"\"Projected Gradient Descent (PGD) attack with random restarts.\n",
    "\n",
    "    Searches for a perturbation ``delta`` inside an ``epsilon`` ball (L-inf or L2)\n",
    "    that increases the cross-entropy of ``model(X + delta)`` w.r.t. ``y``, while\n",
    "    keeping ``X + delta`` inside ``[lower_limit, upper_limit]`` (module-level globals).\n",
    "\n",
    "    Args:\n",
    "        model: callable returning logits of shape (B, C) for the batch ``X``.\n",
    "        X: input batch; the ``[index, :]`` indexing below assumes shape (B, L).\n",
    "        y: integer class labels, shape (B,).\n",
    "        epsilon: radius of the perturbation ball.\n",
    "        alpha: step size of each gradient step.\n",
    "        attack_iters: gradient steps per restart (0 keeps the random start).\n",
    "        restarts: number of random restarts; per sample, the delta with the\n",
    "            highest final loss across restarts is returned.\n",
    "        norm: \"l_inf\" or \"l_2\".\n",
    "        early_stop: if True, each step only updates samples the model still\n",
    "            classifies correctly, and a restart ends once none remain.\n",
    "        mixup: unused; kept for interface compatibility.\n",
    "\n",
    "    Returns:\n",
    "        Detached tensor shaped like ``X`` holding the selected perturbation.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``norm`` is not \"l_inf\" or \"l_2\".\n",
    "    \"\"\"\n",
    "    # Allocate on X's device so the comparisons with all_loss also work on GPU.\n",
    "    max_loss = torch.zeros(y.shape[0], device=X.device)  # (B)\n",
    "    max_delta = torch.zeros_like(X)  # (B, L)\n",
    "\n",
    "    for _ in range(restarts):\n",
    "        # Random start inside the epsilon ball.\n",
    "        delta = torch.zeros_like(X)  # (B, L)\n",
    "        if norm == \"l_inf\":\n",
    "            delta.uniform_(-epsilon, epsilon)\n",
    "        elif norm == \"l_2\":\n",
    "            delta.normal_()\n",
    "            d_flat = delta.view(delta.size(0), -1)\n",
    "            n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1)\n",
    "            r = torch.zeros_like(n).uniform_(0, 1)\n",
    "            delta *= r / n * epsilon\n",
    "        else:\n",
    "            raise ValueError(f\"unsupported norm: {norm!r}\")\n",
    "        delta = clamp(delta, lower_limit - X, upper_limit - X)\n",
    "        delta.requires_grad = True\n",
    "        for _ in range(attack_iters):\n",
    "            output = model(X + delta)\n",
    "            if early_stop:\n",
    "                # Only keep attacking samples that are still classified correctly.\n",
    "                index = torch.where(output.max(1)[1] == y)[0]\n",
    "                if len(index) == 0:\n",
    "                    break\n",
    "            else:\n",
    "                index = slice(None)\n",
    "            loss = F.cross_entropy(output, y)\n",
    "            # autograd.grad returns d(loss)/d(delta) without leaving .grad on the model's parameters.\n",
    "            grad = torch.autograd.grad(loss, delta)[0]\n",
    "            with torch.no_grad():\n",
    "                d = delta[index, :]\n",
    "                g = grad[index, :]\n",
    "                x = X[index, :]\n",
    "                if norm == \"l_inf\":\n",
    "                    d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)\n",
    "                elif norm == \"l_2\":\n",
    "                    g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1)\n",
    "                    scaled_g = g / (g_norm + 1e-10)\n",
    "                    d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)\n",
    "                # Project back so that X + delta stays within [lower_limit, upper_limit].\n",
    "                d = clamp(d, lower_limit - x, upper_limit - x)\n",
    "                delta[index, :] = d\n",
    "        # Score this restart's final delta once, after the inner loop, so all_loss is\n",
    "        # always defined (also for attack_iters == 0 or an immediate early stop).\n",
    "        with torch.no_grad():\n",
    "            all_loss = F.cross_entropy(model(X + delta), y, reduction='none')\n",
    "        improved = all_loss >= max_loss\n",
    "        max_delta[improved] = delta.detach()[improved]\n",
    "        max_loss = torch.max(max_loss, all_loss)\n",
    "    return max_delta"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:56:02.677453Z",
     "start_time": "2024-07-17T17:56:02.670353Z"
    }
   },
   "id": "fc43f5046e161b55"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# Input clip, resampled to 16 kHz by librosa.\n",
    "audio_path = \"Elon Mask Fake.wav\"\n",
    "audio, sr = librosa.load(audio_path, sr=16000)\n",
    "# Add a leading batch dimension so the model sees shape (1, L).\n",
    "audio = torch.tensor(audio)[None, :]\n",
    "# Target class for the attack. NOTE(review): presumably 1 = fake, given the file name -- confirm.\n",
    "label = torch.tensor([1])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:56:04.048587Z",
     "start_time": "2024-07-17T17:56:02.676456Z"
    }
   },
   "id": "35e2049cece1d412"
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "# attack_pgd starts from a random delta; seed the RNG so the result is reproducible.\n",
    "torch.manual_seed(0)\n",
    "model.eval()  # put the model in eval mode before attacking (returns the same module)\n",
    "delta = attack_pgd(model, audio, label, epsilon=0.0005, alpha=0.0001,\n",
    "                   attack_iters=2, restarts=1, norm='l_inf')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:57:08.574560Z",
     "start_time": "2024-07-17T17:57:03.635050Z"
    }
   },
   "id": "ffdd299b73811395"
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([[ 8.0873e-05, -4.0000e-04,  1.9834e-05,  ..., -3.0759e-04,\n          2.9107e-04, -3.1622e-04]])"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the perturbation; with norm='l_inf' every entry lies in [-epsilon, epsilon].\n",
    "delta"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:57:09.174971Z",
     "start_time": "2024-07-17T17:57:09.167608Z"
    }
   },
   "id": "78563bafc470f0dc"
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [],
   "source": [
    "# Adversarial waveform = original clip + PGD perturbation.\n",
    "attacked_audio = torch.add(audio, delta)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:57:09.999231Z",
     "start_time": "2024-07-17T17:57:09.995242Z"
    }
   },
   "id": "fadffe0234460e66"
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "# Save the attacked audio with soundfile (disabled; the next cell saves it with torchaudio).\n",
    "# sf.write('attacked_audio.wav', attacked_audio[0].numpy(), samplerate=16000)\n",
    "# sf.write('original_audio.wav', audio[0].numpy(), samplerate=16000)\n",
    "# sf.write('attack.wav', delta[0].numpy(), samplerate=16000)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:57:10.240373Z",
     "start_time": "2024-07-17T17:57:10.234880Z"
    }
   },
   "id": "dfb191a0e2bcfe21"
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [
    "# Save the attacked audio (16 kHz, matching the librosa load above).\n",
    "torchaudio.save('attacked_audio.wav', attacked_audio, sample_rate=16000)\n",
    "# \n",
    "# # Save the original audio\n",
    "# torchaudio.save('original_audio.wav', audio, sample_rate=16000)\n",
    "# \n",
    "# # Save the perturbation itself as audio\n",
    "# torchaudio.save('attack.wav', delta, sample_rate=16000)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T17:57:11.809507Z",
     "start_time": "2024-07-17T17:57:11.804142Z"
    }
   },
   "id": "7a043b248a44e594"
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [],
   "source": [
    "# STFT settings: 1024-sample frames with a 512-sample hop (50% overlap).\n",
    "n_fft, hop_length, win_length = 1024, 512, 1024"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-16T23:58:44.085103Z",
     "start_time": "2024-07-16T23:58:43.966131Z"
    }
   },
   "id": "77ef663de08e0f6e"
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [
    "# audio is already a tensor after the loading cell; torch.as_tensor accepts tensors and\n",
    "# NumPy arrays alike, without the copy and the warning that torch.tensor(tensor) emits.\n",
    "audio = torch.as_tensor(audio)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-16T23:59:10.891562Z",
     "start_time": "2024-07-16T23:59:10.877298Z"
    }
   },
   "id": "b89e2551a82fea40"
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "# Complex STFT of the clip; for a (1, L) input the result is (1, n_fft // 2 + 1, frames).\n",
    "stft = torch.stft(\n",
    "    audio,\n",
    "    n_fft=n_fft,\n",
    "    hop_length=hop_length,\n",
    "    win_length=win_length,\n",
    "    window=torch.hann_window(win_length),\n",
    "    return_complex=True,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T00:02:13.236254Z",
     "start_time": "2024-07-17T00:02:13.117737Z"
    }
   },
   "id": "f38b4979a280a90f"
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([1, 513, 779])"
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, frequency bins = n_fft // 2 + 1, frames)\n",
    "stft.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T00:10:13.548736Z",
     "start_time": "2024-07-17T00:10:13.403540Z"
    }
   },
   "id": "956fcb934e3b3e15"
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([[-4.4405e-03+0.0000e+00j, -4.3200e-03+0.0000e+00j,\n         -2.6289e-03+0.0000e+00j,  ...,\n         -9.6706e-03+0.0000e+00j,  3.6776e-02+0.0000e+00j,\n         -7.7417e-02+0.0000e+00j],\n        [ 2.0231e-03-1.3583e-09j,  1.4025e-03+1.0221e-04j,\n          7.7536e-04+7.6833e-04j,  ...,\n          8.3327e-03+3.8058e-03j, -2.7566e-02+1.9257e-02j,\n          1.4384e-02-2.7868e-02j],\n        [ 1.2919e-03-1.6165e-09j,  9.0993e-04-8.4998e-04j,\n          2.8173e-03+6.2566e-04j,  ...,\n         -6.0342e-03+1.1191e-03j, -1.8533e-03-3.4098e-02j,\n         -1.6743e-02-3.3127e-02j],\n        ...,\n        [ 2.9226e-05+4.2322e-10j, -1.9674e-08+1.7462e-10j,\n         -1.8859e-08-4.7905e-08j,  ...,\n         -4.2608e-08+1.8801e-08j, -9.6159e-08-4.4703e-08j,\n         -4.7806e-05+1.5938e-04j],\n        [-2.8926e-05-3.8793e-10j, -1.3795e-08-1.0652e-08j,\n         -3.1258e-08-2.9628e-08j,  ...,\n         -3.6787e-08+1.0943e-08j, -2.9802e-08-5.7742e-08j,\n         -1.3344e-04+9.9288e-05j],\n        [ 2.9050e-05+0.0000e+00j,  4.6566e-10+0.0000e+00j,\n         -3.0501e-08+0.0000e+00j,  ...,\n         -1.9092e-08+0.0000e+00j,  1.1176e-08+0.0000e+00j,\n         -1.6630e-04+0.0000e+00j]])"
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Complex spectrogram of the first (only) batch item: rows are frequency bins, columns frames.\n",
    "stft[0]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-17T00:10:23.606152Z",
     "start_time": "2024-07-17T00:10:23.480739Z"
    }
   },
   "id": "14847e8e9511fd9d"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "3dc071b23727fd4e"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
