{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch cross-domain integrated gradients demo\n",
    "\n",
    "We demostrate the PyTorch Cross-domain IG library on a Convolutional Neural Network (CNN) for a synthetic timeseries dataset. \n",
    "\n",
    "We consider a two-channel time-domain input (mixed signal), generated as the mixture of two sinusoidal inputs (source signal). The goal is to classify the frequency of the source signal. We train a CNN to classify the mixed inputs and then deploy Cross-domain IG to evaluate which input components affect the classification output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from cross_domain_saliency_maps.torch_ig.cross_domain_integrated_gradients import FourierIG\n",
    "from cross_domain_saliency_maps.torch_ig.cross_domain_integrated_gradients import ICAIG\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.utils import shuffle\n",
    "from sklearn.decomposition import FastICA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Input generation\n",
    "\n",
    "We generate signals as the sum of two oscillating source signals with frequencies at $f_1 Hz$ and $f_2 Hz$. Each source signal window is sampled from the single-oscillation time-signal:\n",
    "$$x_i(t) = cos(2 \\pi f_i + \\phi)$$\n",
    "\n",
    "We sample windows of $8$ seconds at a sampling frequency $f_s = 32 Hz$, generating $256$ time-points for each sample, $\\boldsymbol{x} \\in \\mathbb{R}^{256}$. The source vectors $\\boldsymbol{x}$ are then linearly mixed using the ```mixing_matrix``` (A) to form the mixed inputs $\\boldsymbol{x}_{mixed}$:\n",
    "$$ X_{mixed} = A \\cdot X$$\n",
    "where $X = [\\boldsymbol{x_1}, \\boldsymbol{x_2}] \\in \\mathbb{R}^{2 \\times 256}$ and similarly for $X_{mixed}$.\n",
    "\n",
    "We generate windows of two classes, where only $f_1$ is informative of the class:\n",
    "1. **Class 1:** $f_1 \\sim \\mathcal{N}(1.0, 0.1)$, $f_2 \\sim \\mathcal{N}(4.0, 0.1)$\n",
    "1. **Class 2:** $f_1 \\sim \\mathcal{N}(2.5.0, 0.1)$, $f_2 \\sim \\mathcal{N}(4.0, 0.1)$\n",
    "\n",
    "We generate $10^4$ windows for training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = 32.0 #Hz\n",
    "simulation_time = 8.0 #seconds\n",
    "\n",
    "n_samples_per_class = 10_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_signals(t, f1, f2, f_std, mixing_matrix, n_samples_per_class):\n",
    "    phi = np.random.uniform(0, 2 * np.pi, size = n_samples_per_class)\n",
    "    freqs = np.random.normal(loc = f1, scale = f_std, size = n_samples_per_class)\n",
    "    x1 = np.cos(2 * np.pi * freqs[..., None] * t[None, ...] + phi[..., None])\n",
    "\n",
    "    phi = np.random.uniform(0, 2 * np.pi, size = n_samples_per_class)\n",
    "    freqs = np.random.normal(loc = f2, scale = f_std, size = n_samples_per_class)\n",
    "    x2 = np.cos(2 * np.pi * freqs[..., None] * t[None, ...] + phi[..., None])\n",
    "\n",
    "    X1 = mixing_matrix[0, 0] * x1 + mixing_matrix[0, 1] * x2\n",
    "    X2 = mixing_matrix[1, 0] * x1 + mixing_matrix[1, 1] * x2\n",
    "\n",
    "    X1 = X1 - X1.mean(axis = -1)[..., None]\n",
    "    X2 = X2 - X2.mean(axis = -1)[..., None]\n",
    "\n",
    "    x_mixed = np.stack([X1, X2], axis = 1)\n",
    "    x_source = np.stack([x1, x2], axis = 1)\n",
    "\n",
    "    return x_mixed, x_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_timepoints = int(fs * simulation_time)\n",
    "t = np.linspace(0, simulation_time, N_timepoints)\n",
    "\n",
    "mixing_matrix = np.array([[0.5, 0.75], \n",
    "                          [0.9, 0.1]])\n",
    "\n",
    "# Generate training samples\n",
    "\n",
    "X1, X1_unmixed = generate_signals(t, f1 = 1.0, f2 = 4.0, f_std = 0.1, mixing_matrix = mixing_matrix, n_samples_per_class = n_samples_per_class)\n",
    "y1 = np.zeros((n_samples_per_class))\n",
    "\n",
    "X2, _ = generate_signals(t, f1 = 2.5, f2 = 4.0, f_std = 0.1, mixing_matrix = mixing_matrix, n_samples_per_class = n_samples_per_class)\n",
    "y2 = np.ones((n_samples_per_class))\n",
    "\n",
    "X_train = np.concatenate([X1, X2], axis = 0)\n",
    "y_train = np.concatenate([y1, y2], axis = 0)\n",
    "\n",
    "X_train, y_train = shuffle(X_train, y_train)\n",
    "\n",
    "# Generate test samples\n",
    "\n",
    "X1, _ = generate_signals(t, f1 = 1.0, f2 = 4.0, f_std = 0.1, mixing_matrix = mixing_matrix, n_samples_per_class = n_samples_per_class)\n",
    "y1 = np.zeros((n_samples_per_class))\n",
    "\n",
    "X2, _ = generate_signals(t, f1 = 2.5, f2 = 4.0, f_std = 0.1, mixing_matrix = mixing_matrix, n_samples_per_class = n_samples_per_class)\n",
    "y2 = np.ones((n_samples_per_class))\n",
    "\n",
    "X_test = np.concatenate([X1, X2], axis = 0)\n",
    "y_test = np.concatenate([y1, y2], axis = 0)\n",
    "\n",
    "X_test, y_test = shuffle(X_test, y_test)\n",
    "\n",
    "X_train = torch.from_numpy(X_train).type(torch.float32).to(device)\n",
    "y_train = torch.from_numpy(y_train).type(torch.float32).to(device)\n",
    "\n",
    "training_set = torch.utils.data.TensorDataset(X_train, y_train)\n",
    "training_dataloader = torch.utils.data.DataLoader(training_set,\n",
    "                                                  batch_size = 128,\n",
    "                                                  shuffle = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We plot the first two seconds of an example input window:\n",
    "- **Left panel.** We plot the original unmixed channels in which the slower (green) oscillation bears the class information. The fast signal oscillates always at the same $4.0 Hz$ frequency regardless of the sample class.\n",
    "- **Right panel.** We plot the input mixed channels in which the class information is mixed in both input channels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(t, X1_unmixed[0, 0, :], color = 'C2', label = 'Channel 1')\n",
    "plt.plot(t, X1_unmixed[0, 1, :], color = 'C3', label = 'Channel 2')\n",
    "plt.xlim([0, 2])\n",
    "plt.title('Source (unmixed) signals')\n",
    "plt.legend()\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(t, X1[0, 0, :], label = 'Channel 1')\n",
    "plt.plot(t, X1[0, 1, :], label = 'Channel 2')\n",
    "plt.xlim([0, 2])\n",
    "plt.title('Input (mixed) signals')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The CNN\n",
    "\n",
    "We build a three-channel CNN to classify the mixed inputs based on $f_1$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = torch.nn.Conv1d(in_channels = 2, \n",
    "                                     out_channels = 32, \n",
    "                                     kernel_size = 5,\n",
    "                                     padding = 'same')\n",
    "        self.conv2 = torch.nn.Conv1d(in_channels = 32, \n",
    "                                     out_channels = 32, \n",
    "                                     kernel_size = 5,\n",
    "                                     padding = 'same')\n",
    "        self.conv3 = torch.nn.Conv1d(in_channels = 32, \n",
    "                                     out_channels = 32, \n",
    "                                     kernel_size = 5,\n",
    "                                     padding = 'same')\n",
    "        \n",
    "        self.fc = torch.nn.Linear(32, 2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "\n",
    "        x = torch.mean(x, dim = -1) # flatten all dimensions except batch\n",
    "\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "\n",
    "class SoftmaxNet(torch.nn.Module):\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.softmax = torch.nn.Softmax(dim = -1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.model(x)\n",
    "        x = self.softmax(x)\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy_fn(y_true, y_pred):\n",
    "    correct = torch.eq(y_true, y_pred).sum().item() # torch.eq() calculates where two tensors are equal\n",
    "    acc = (correct / len(y_pred)) * 100 \n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training\n",
    "\n",
    "We now train the CNN to classify our times series samples. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 20\n",
    "\n",
    "for epoch in range(n_epochs):  # loop over the dataset multiple times\n",
    "\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(training_dataloader, 0):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels.long())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print statistics\n",
    "        running_loss += loss.item()\n",
    "        if i % 150 == 149:    # print every 2000 mini-batches\n",
    "            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 150:.3f}')\n",
    "            running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "softmax_model = SoftmaxNet(model = model)\n",
    "softmax_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross-domain saliency maps\n",
    "\n",
    "We use cross-domain IG to generate saliency maps on two domains:\n",
    "1. **Frequency domain** via the Fourier transform.  \n",
    "2. **Independent components domain** via the Independent Component Analysis (ICA). \n",
    "\n",
    "Our library provides direct implementations of both domains. Extending our supported domains is easily via the ```Domain``` base class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_ig_iterations = 500\n",
    "\n",
    "x = X_test[y_test == 0][:1]\n",
    "x_baseline = np.zeros_like(x)\n",
    "\n",
    "fourierIG = FourierIG(softmax_model, n_ig_iterations, output_channel = 0, device = device)\n",
    "ig_frequency = fourierIG.run(x, x_baseline).cpu().numpy()\n",
    "\n",
    "fastICA = FastICA(tol = 1e-9)\n",
    "fastICA.fit(x[0].T)\n",
    "print(\"FasICA converged in \", fastICA.n_iter_, \" iterations.\")\n",
    "\n",
    "icaIG = ICAIG(softmax_model, fastICA, n_ig_iterations, output_channel = 0, device = device)\n",
    "ig_ica = icaIG.run(x, x_baseline).cpu().numpy()\n",
    "\n",
    "X_ica = fastICA.transform(x[0].T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot results\n",
    "\n",
    "We not plot the IG results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Frequency-domain IG\n",
    "The IG attribution maps return maps for the two input channels in the frequency domain. The input sample we have used here is oscillating at $1.0 Hz$ and the saliency map highlights this frequency in both input channels. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xf = np.linspace(0.0, 1.0/(2.0*(1/fs)), N_timepoints//2)\n",
    "plt.figure()\n",
    "plt.plot(xf, 2 * ig_frequency[0, 0, :128], label = 'Input Channel 0')\n",
    "plt.plot(xf, 2 * ig_frequency[0, 1, :128], label = 'Input Channel 1')\n",
    "\n",
    "plt.xlabel('Freq. (Hz)')\n",
    "plt.ylabel('ICA IG')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ICA IG\n",
    "\n",
    "We now demonstrate the IG in the basis defined by the independent components."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we visualise the decomposition of the inputs into the ICA basis. Observe that through ICA we have successfully recovered the two original single-frequency oscillations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_decomp = icaIG.domain.forward_transform(x[0].T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "information_channel = np.argmax(ig_ica)\n",
    "\n",
    "channel_colors = ['C3','C3']\n",
    "channel_colors[information_channel] = 'C2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "\n",
    "for i in range(2):\n",
    "    plt.plot(x_decomp[:, i], color = channel_colors[i])\n",
    "plt.title('Reconstructed Source (unmixed) channels.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now visualise the ICA IG. The attributions highlight the $f_1 = 1.0Hz$ oscillation (slower oscillation - green) as the one tilting the model towards its final classification. Observe that these attributions exactly match the frequency ones. However, here, each channel is composed of only one single component oscillating at a single frequency. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(nrows=1, ncols=2)\n",
    "\n",
    "OFFSET = 5.0\n",
    "scale = 1.0\n",
    "\n",
    "for i in range(2):\n",
    "    axes[0].plot(t, scale * x_decomp[..., i] - i * OFFSET, \n",
    "                 color=channel_colors[i], linewidth = 1.0,\n",
    "                 alpha = 1.0)\n",
    "    axes[0].set_xlim([t[0], t[-1]])\n",
    "    axes[0].spines[['top', 'right']].set_visible(False)\n",
    "    \n",
    "    axes[1].barh(y = - i * OFFSET, width=ig_ica[i], \n",
    "                 height = 0.75, color=channel_colors[i])\n",
    "    axes[1].set_xlim([0, max(ig_ica) * 1.1])\n",
    "    axes[1].spines[['top', 'right', 'left']].set_visible(False)\n",
    "\n",
    "ymin, ymax = axes[0].get_ylim()\n",
    "\n",
    "axes[1].vlines(0.0, ymin, ymax, color = 'black', linestyles = 'dashed')\n",
    "\n",
    "axes[1].set_xticks([ig_ica.min(), ig_ica.max()],\n",
    "                   [f\"{ig_ica.min():.1f}\", f\"{ig_ica.max():.1f}\"])\n",
    "axes[0].set_yticks([- i * OFFSET for i in range(2)], ['Component ' + str(i) for i in range(2)])\n",
    "axes[1].set_yticks([])\n",
    "    \n",
    "margin_ratio = 1.15\n",
    "axes[0].set_ylim([margin_ratio * ymin, margin_ratio * ymax])\n",
    "axes[1].set_ylim([margin_ratio * ymin, margin_ratio * ymax])\n",
    "\n",
    "axes[0].set_xlabel('Time (s)')\n",
    "axes[1].set_xlabel('ICA IG')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
