{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51bc6b39-a063-49d9-948d-1c230caea2c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import gb\n",
    "from gb.sims import rtd\n",
    "from gb.sims import NSALoss\n",
    "from gb.sims import cka, gram_linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68ae98a7-d9ee-478a-a7b0-ba3c62868565",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import numpy as np\n",
    "from scipy.stats import wasserstein_distance\n",
    "from itertools import combinations\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad3a643-7a85-4ba4-bc8c-32ab2e184341",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15d4cb7-c2c9-4933-a295-e7501511b589",
   "metadata": {},
   "outputs": [],
   "source": [
    "nsa_criterion = NSALoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70d1b50f-822b-41b7-ba58-39cc7eb84ddc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#Set D to be the size of the dataset\n",
    "D = 2485"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb44303-2382-4532-8a33-01477c4e1281",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Initialize an empty dictionary to store the combined results\n",
    "combined_dict = {}\n",
    "\n",
    "# List of file names\n",
    "path = 'accuracy_vals/'\n",
    "file_names = ['5.pkl', '10.pkl', '15.pkl', '20.pkl', '25.pkl']\n",
    "# Iterate through each file\n",
    "for file_name in file_names:\n",
    "    # Load the dictionary from the .pkl file\n",
    "    file_name = path+file_name\n",
    "    \n",
    "    data = np.load(file_name, allow_pickle=True)\n",
    "    \n",
    "    # Iterate through the keys in the loaded dictionary\n",
    "    for gnn_architecture, accuracy_data in data.items():\n",
    "        # Check if the GNN architecture key exists in the combined_dict\n",
    "        if gnn_architecture not in combined_dict:\n",
    "            combined_dict[gnn_architecture] = {}\n",
    "        \n",
    "        # Iterate through the second-level keys (type of training)\n",
    "        for training_type, accuracy_values in accuracy_data.items():\n",
    "            # Check if the training type key exists in the combined_dict\n",
    "            if training_type not in combined_dict[gnn_architecture]:\n",
    "                combined_dict[gnn_architecture][training_type] = []\n",
    "            \n",
    "            # Append the accuracy values to the combined dictionary\n",
    "            combined_dict[gnn_architecture][training_type].append(accuracy_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49fa9302-b9e0-4e10-ab9f-a3dad75d1511",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8805525d-c476-428a-bf5e-6109f8ae6d20",
   "metadata": {},
   "source": [
    "## Generate similarity values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86f89287-ff44-4774-8ea0-1aa571195e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#attack_type = 'gp'\n",
    "attack_type = 'ge'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0848b752-7b1f-43b9-badc-d805449326f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "###Algorithm\n",
    "#Attack type - fixed, run 2 times, for ge, gp\n",
    "value_dict = {}\n",
    "#Iterate through architectures\n",
    "for arch in ['gcn','gcnsvd','gnnguard','prognn','grand']:\n",
    "    value_dict[arch]={}\n",
    "    #Layers\n",
    "    for layer in ['conv0','conv1']:\n",
    "        #Iterate through perturbations\n",
    "        value_dict[arch][layer]={}\n",
    "        value_dict[arch][layer]['cka']=[]\n",
    "        value_dict[arch][layer]['rtd']=[]\n",
    "        value_dict[arch][layer]['nsa']=[]\n",
    "        for ptb_rate in [5,10,15,20,25]:\n",
    "            print(arch,layer,ptb_rate)\n",
    "            #Compare the clean embedding for the architecture\n",
    "            clean_path = 'feature_vals/'+arch+'_'+'clean_'+str(ptb_rate)+'.npz'\n",
    "            print(clean_path)\n",
    "            with np.load(clean_path) as loader1:\n",
    "                clean_vals = dict(loader1)\n",
    "            #Against the embedding for the same architecture with attack and ptb rate of choice\n",
    "            adv_path = 'feature_vals/'+arch+'_'+attack_type+'_'+str(ptb_rate)+'.npz'\n",
    "            print(adv_path)\n",
    "            with np.load(adv_path) as loader2:\n",
    "                attack_vals = dict(loader2)\n",
    "            X = clean_vals[layer]\n",
    "            Y = attack_vals[layer]\n",
    "            #Given X and Y now\n",
    "            #Calculate CKA\n",
    "            print(\"Calculating CKA\")\n",
    "            cka_from_examples = cka(gram_linear(X), gram_linear(Y))\n",
    "            cka_val = (1-cka_from_examples)\n",
    "\n",
    "            value_dict[arch][layer]['cka'].append(cka_val)\n",
    "            #Calculate RTD\n",
    "            print(\"Calculating RTD\")\n",
    "            full_list_indices = np.random.choice(range(D),400, replace=False)\n",
    "            A = X[full_list_indices]\n",
    "            B = Y[full_list_indices]\n",
    "            rtd_val=rtd(A,B)\n",
    "            print(rtd_val)\n",
    "            value_dict[arch][layer]['rtd'].append(rtd_val)\n",
    "            #Calculate NSA\n",
    "            print(\"Calculating NSA\")\n",
    "            nsa_val=nsa_criterion(torch.tensor(X),torch.tensor(Y))\n",
    "            value_dict[arch][layer]['nsa'].append(nsa_val.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7358db-71e9-45e8-999c-1ad07fa87aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_line_graph(val_dict,layer,metric):\n",
    "    archs = val_dict.keys()\n",
    "    for arch in archs:\n",
    "        y = val_dict[arch][layer][metric]\n",
    "        x = range(len(y))\n",
    "        plt.plot(x, y, label=arch)\n",
    "    plt.xlabel('Perturbation rate',fontsize=19)\n",
    "    plt.ylabel(metric.upper()+\" values\",fontsize=19)\n",
    "    plt.title(metric.upper()+' value vs Perturbation rate for: '+layer)\n",
    "    plt.legend()\n",
    "    x_ticks = [5,10,15,20,25]\n",
    "    plt.xticks(x, x_ticks)  # Customize x-tick values\n",
    "    #plt.show()\n",
    "    savestr=f'g_{metric}_{layer}.jpg'\n",
    "    plt.savefig(savestr)\n",
    "\n",
    "plot_line_graph(value_dict,'conv0','cka')\n",
    "plot_line_graph(value_dict,'conv1','cka')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d39aefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_line_graph(val_dict,layer,metric):\n",
    "    archs = val_dict.keys()\n",
    "    for arch in archs:\n",
    "        y = val_dict[arch][layer][metric]\n",
    "        x = range(len(y))\n",
    "        plt.plot(x, y, label=arch)\n",
    "    plt.xlabel('Perturbation rate',fontsize=19)\n",
    "    plt.ylabel(metric.upper()+\" values\",fontsize=19)\n",
    "    plt.title(metric.upper()+' value vs Perturbation rate for: '+layer)\n",
    "    plt.legend()\n",
    "    x_ticks = [5,10,15,20,25]\n",
    "    plt.xticks(x, x_ticks)  # Customize x-tick values\n",
    "    #plt.show()\n",
    "    savestr=f'ge_{metric}_{layer}.jpg'\n",
    "    plt.savefig(savestr)\n",
    "\n",
    "# Example usage\n",
    "\n",
    "plot_line_graph(value_dict,'conv0','rtd')\n",
    "plot_line_graph(value_dict,'conv1','rtd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41f60138",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_line_graph(val_dict,layer,metric):\n",
    "    archs = val_dict.keys()\n",
    "    for arch in archs:\n",
    "        y = val_dict[arch][layer][metric]\n",
    "        x = range(len(y))\n",
    "        plt.plot(x, y, label=arch)\n",
    "    plt.xlabel('Perturbation rate',fontsize=19)\n",
    "    plt.ylabel(metric.upper()+\" value\",fontsize=19)\n",
    "    #plt.title(metric.upper()+' value vs Perturbation rate for: '+layer)\n",
    "    plt.legend()\n",
    "    x_ticks = [5,10,15,20,25]\n",
    "    plt.xticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "    plt.yticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "    plt.xticks(x, x_ticks)  # Customize x-tick values\n",
    "#    plt.show()\n",
    "    savestr=f'gp_{metric}_{layer}.jpg'\n",
    "    plt.savefig(savestr)\n",
    "\n",
    "# Example usage\n",
    "\n",
    "plot_line_graph(value_dict,'conv0','nsa')\n",
    "plot_line_graph(value_dict,'conv1','nsa')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a89a2191-5910-434c-b95e-645795990aab",
   "metadata": {},
   "source": [
    "## Generate Misclassification Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddff487f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Define the perturbation values (x-axis)\n",
    "perturbations = list(range(5, 26, 5))\n",
    "\n",
    "# Define the data lists\n",
    "GCN = 100-combined_dict['GCN']['evas']\n",
    "GCN_SVD = 100-combined_dict['GCNSVD']['evas']\n",
    "GNNGuard = 100-combined_dict['GNNGuard']['evas']\n",
    "ProGNN = 100-combined_dict['ProGNN']['evas']\n",
    "GRAND = 100-combined_dict['GRAND']['evas']\n",
    "\n",
    "\n",
    "# Create the line graphs\n",
    "plt.plot(perturbations, GCN, label='GCN')\n",
    "plt.plot(perturbations, GCN_SVD, label='GCN-SVD')\n",
    "plt.plot(perturbations, GNNGuard, label='GNNGuard')\n",
    "plt.plot(perturbations, ProGNN, label='ProGNN')\n",
    "plt.plot(perturbations, GRAND, label='GRAND')\n",
    "\n",
    "# Add labels and legend\n",
    "plt.xlabel('Perturbation Rate',fontsize=19)\n",
    "plt.ylabel('Misclassification Rate',fontsize=19)\n",
    "#plt.title('Global Evasion')\n",
    "plt.xticks(perturbations)\n",
    "plt.xticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "plt.yticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "plt.legend()\n",
    "plt.savefig(\"ge_misclassification1.jpg\")\n",
    "# Display the plot\n",
    "#plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af6cc3da",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Define the perturbation values (x-axis)\n",
    "perturbations = list(range(5, 26, 5))\n",
    "\n",
    "# Define the data lists\n",
    "GCN = 100-combined_dict['GCN']['pois']\n",
    "GCN_SVD = 100-combined_dict['GCNSVD']['pois']\n",
    "GNNGuard = 100-combined_dict['GNNGuard']['pois']\n",
    "ProGNN = 100-combined_dict['ProGNN']['pois']\n",
    "GRAND = 100-combined_dict['GRAND']['pois']\n",
    "\n",
    "# Create the line graphs\n",
    "plt.plot(perturbations, GCN, label='GCN')\n",
    "plt.plot(perturbations, GCN_SVD, label='GCN-SVD')\n",
    "plt.plot(perturbations, GNNGuard, label='GNNGuard')\n",
    "plt.plot(perturbations, ProGNN, label='ProGNN')\n",
    "plt.plot(perturbations, GRAND, label='GRAND')\n",
    "\n",
    "# Add labels and legend\n",
    "plt.xlabel('Perturbation Rate',fontsize=19)\n",
    "plt.ylabel('Misclassification Rate',fontsize=19)\n",
    "#plt.title('Global Poisoning')\n",
    "plt.xticks(perturbations)\n",
    "plt.xticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "plt.yticks(fontsize=17)  # Adjust the fontsize as needed\n",
    "plt.legend()\n",
    "plt.savefig(\"gp_misclassification1.jpg\")\n",
    "# Display the plot\n",
    "#plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71218cfb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
