{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7664ce59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import torch\n",
    "import dgl\n",
    "import cupy as cp\n",
    "import collections\n",
    "import random\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "from torch_geometric.datasets import Planetoid \n",
    "from pygod.utils import gen_attribute_outliers\n",
    "from pygod.models import CoLA, DOMINANT\n",
    "from dgl.data import FraudDataset\n",
    "from pygod.utils.metric import eval_roc_auc, eval_recall_at_k, eval_precision_at_k\n",
    "from sklearn.metrics import precision_score, accuracy_score\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from sklearn.neighbors import NearestCentroid\n",
    "\n",
    "from outlier_generator import inject_edge_outlier, generate_node_feature_outliers\n",
    "\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "00156adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_pyg = 'Cora'  # dataset name passed to torch_geometric's Planetoid\n",
    "data_dgl = 'cora'  # dataset name passed to load_graph_dataset\n",
    "NUM_OUTLIER = 200  # k used by the recall@k / precision@k metrics below\n",
    "NUM_CANDIDATE = 50  # NOTE(review): not referenced by any cell in this notebook\n",
    "\n",
    "# Seed every RNG the pipeline may touch so that Restart & Run All is reproducible\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "dgl.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "097aa6c6",
   "metadata": {},
   "source": [
    "##### 1. Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e4995b26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# PyG Data object (consumed by pygod's DOMINANT below) whose tensors are replaced by the\n",
    "# DGL copy of the same dataset, so both libraries see identical features, labels and splits.\n",
    "data = Planetoid(root=f'/tmp/{data_pyg}', name=data_pyg)[0]\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_dgl)\n",
    "data.x = feat\n",
    "data.edge_index = torch.stack(graph.edges())  # row 0 = source nodes, row 1 = destination nodes\n",
    "data.y = labels\n",
    "data.train_mask = train_mask\n",
    "data.test_mask = test_mask\n",
    "data.val_mask = val_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83e7db2e",
   "metadata": {},
   "source": [
    "##### 2. Inject node-feature outliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3397e13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the positional arguments (0, 1, 10, 10) are forwarded unchanged to\n",
    "# generate_node_feature_outliers; their meaning is defined in outlier_generator -- confirm.\n",
    "# None of the six returned values is used by later cells of this notebook.\n",
    "(\n",
    "    (feat_class1_to_class2, label_class1_k_orig, label_class1_k_convert),\n",
    "    (feat_class2_to_class1, label_class2_k_orig, label_class2_k_convert),\n",
    ") = generate_node_feature_outliers(feat, labels, 0, 1, 10, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "548c3493",
   "metadata": {},
   "source": [
    "##### Inject edge outliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1ebb27bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): neither new_graph nor edge_outlier_labels is used by later cells;\n",
    "# the influence-function section below still propagates features over `graph`.\n",
    "new_graph, edge_outlier_labels = inject_edge_outlier(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7f22ee6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13264"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "735683a6",
   "metadata": {},
   "source": [
    "##### Functions for generating label outliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab63ac09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All unordered pairs of class ids. NOTE(review): the earlier placeholder\n",
    "# `itertools.combinations(iterable, r)` used undefined names (NameError on a fresh\n",
    "# kernel); pairs of classes are presumably what label-outlier generation needs -- confirm.\n",
    "class_pairs = list(itertools.combinations(range(number_classes), 2))\n",
    "class_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65cdd622",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c73d922",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e410d9b7",
   "metadata": {},
   "source": [
    "##### 3. Outlier detection via DOMINANT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "85bff270",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Dominant_outlier_dectection(data, y_true=None, k=None, num_layers=4):\n",
    "    \"\"\"Fit a DOMINANT detector on `data` and return recall@k of its outlier scores.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : torch_geometric.data.Data\n",
    "        Graph to fit and score.\n",
    "    y_true : array-like, optional\n",
    "        Ground-truth labels for the metric; defaults to ``data.y.numpy()``.\n",
    "        NOTE(review): recall@k expects binary outlier labels (1 = outlier), while\n",
    "        ``data.y`` holds the node class labels in this notebook -- pass the\n",
    "        outlier labels explicitly.\n",
    "    k : int, optional\n",
    "        Cut-off for recall@k; defaults to the global ``NUM_OUTLIER`` at call time.\n",
    "    num_layers : int\n",
    "        Number of layers of the DOMINANT model.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Recall@k as returned by ``eval_recall_at_k``.\n",
    "    \"\"\"\n",
    "    if y_true is None:\n",
    "        y_true = data.y.numpy()\n",
    "    if k is None:\n",
    "        k = NUM_OUTLIER\n",
    "    model = DOMINANT(num_layers=num_layers)\n",
    "    model.fit(data)\n",
    "    # Only the raw outlier scores are needed; threshold_ is a fitted attribute set by fit()\n",
    "    outlier_scores = model.decision_function(data)\n",
    "    return eval_recall_at_k(y_true, outlier_scores, k=k, threshold=model.threshold_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2bf3bad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DOMINANT(num_layers=4)  # init. detection model\n",
    "model.fit(data)\n",
    "outlier_scores = model.decision_function(data)  # raw outlier score per node\n",
    "prob = model.predict_proba(data)\n",
    "# Binary outlier predictions + confidence. Bound to `pred_labels` (not `labels`) so the\n",
    "# class labels loaded in section 1 are not overwritten when earlier cells are re-run.\n",
    "pred_labels, confidence = model.predict(data, return_confidence=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "df567965",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): data.y holds the node class labels set in section 1, not binary\n",
    "# outlier labels, so the recall/precision@k below are not outlier-detection metrics\n",
    "# until it is replaced with outlier ground truth.\n",
    "recall_at_k = eval_recall_at_k(\n",
    "    data.y.numpy(), outlier_scores, k=NUM_OUTLIER, threshold=model.threshold_\n",
    ")\n",
    "precision_at_k = eval_precision_at_k(\n",
    "    data.y.numpy(), outlier_scores, k=NUM_OUTLIER, threshold=model.threshold_\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e5607f68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability:\n",
      "[[0.90694323 0.09305677]\n",
      " [0.87451244 0.12548756]\n",
      " [0.91590343 0.08409657]\n",
      " ...\n",
      " [0.98888129 0.01111871]\n",
      " [0.98795747 0.01204253]\n",
      " [0.98511295 0.01488705]]\n"
     ]
    }
   ],
   "source": [
    "# Probabilities returned by DOMINANT.predict_proba (two columns, see output)\n",
    "print('Probability:', prob, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "118102bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall@200: 0.021591055134301505\n",
      "Precision@200: 0.84\n"
     ]
    }
   ],
   "source": [
    "for metric_name, metric_value in (('Recall', recall_at_k), ('Precision', precision_at_k)):\n",
    "    print(f'{metric_name}@{NUM_OUTLIER}:', metric_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8296e201",
   "metadata": {},
   "source": [
    "##### 4. Outlier detection via influence functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3a2ed9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 regularisation strength passed to SimplifiedGraphNeuralNetwork. The model itself is\n",
    "# constructed and fitted in the training cell below, so no unused instance is built here.\n",
    "l2_regularlization_term = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ca9b400",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgc_propagate(g, x, num_hops=2):\n",
    "    \"\"\"Return `x` after `num_hops` rounds of degree-normalised neighbour summation on `g`.\n",
    "\n",
    "    Each round scales by in_degree^-1/2 (degrees clamped at 1) before and after summing\n",
    "    the incoming messages, i.e. D^-1/2 A D^-1/2 x with in-degrees used on both sides.\n",
    "    \"\"\"\n",
    "    inv_sqrt_deg = g.in_degrees().float().clamp(min=1).pow(-0.5).to(x.device).unsqueeze(1)\n",
    "    for _ in range(num_hops):\n",
    "        g.ndata['h'] = x * inv_sqrt_deg\n",
    "        g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "        x = g.ndata.pop('h') * inv_sqrt_deg\n",
    "    return x\n",
    "\n",
    "\n",
    "feat0 = sgc_propagate(graph, data.x.clone())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e94106f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): despite the names these cover every node (not only train_mask);\n",
    "# the model below is fitted on all of them.\n",
    "train_x = np.array(feat0.numpy(), dtype=np.float32)\n",
    "train_y = np.array(data.y.numpy(), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "79bf247f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot targets (one column per class value seen in train_y) for the loss/gradient code\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "one_hot_labels_train = enc.fit_transform(train_y.reshape(-1, 1)).toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6558d8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the L2-regularised model on all nodes. one_hot_labels_train comes from the\n",
    "# encoding cell above, so it is not recomputed here.\n",
    "lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "# Raw linear class scores (logits) of the fitted model for every node\n",
    "logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eb1916d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 2707/2707 [00:22<00:00, 122.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# Total and per-node loss gradients at the fitted parameters, plus the Hessian.\n",
    "# NOTE(review): grad is requested with l2_reg=True but fast_hess_cuda takes no\n",
    "# regularisation argument -- confirm the Hessian includes the L2 term.\n",
    "train_total_grad, train_indiv_grad = lr.grad(\n",
    "    train_x, logits_train_y, one_hot_labels_train, l2_reg=True\n",
    ")\n",
    "hess = fast_hess_cuda(train_x, logits_train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5816d57d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse-Hessian-vector product for the total gradient (cholskey=True presumably selects\n",
    "# a Cholesky solve); the next cell converts the result to NumPy with cp.asnumpy.\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(\n",
    "    hess, train_total_grad.T, cholskey=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2b3e7037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring the inverse-HVP back to host memory as a NumPy array\n",
    "loss_grad_hvp = cp.asnumpy(a=loss_grad_hvp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2047dd0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10038, 1)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per model parameter; for Cora 10038 = (1433 features + intercept) * 7 classes\n",
    "np.shape(loss_grad_hvp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9d281d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted influence of each node: its per-node gradient dotted with the inverse-HVP\n",
    "pred_infl = train_indiv_grad.dot(loss_grad_hvp).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3945dca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularised loss of the fitted model (baseline for the leave-one-out check).\n",
    "# NOTE(review): named *val* but computed on the same nodes the model was fitted on.\n",
    "ori_val_loss, ave_ori_val_loss = lr.log_loss(\n",
    "    logits_train_y, one_hot_labels_train, l2_reg=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e9765c5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Brute-force check of pred_infl: refit without node i and record the change in the\n",
    "# regularised loss over all nodes. One full fit per node, so it is opt-in.\n",
    "RUN_LOO_CHECK = False\n",
    "act_infl = []\n",
    "if RUN_LOO_CHECK:\n",
    "    for i in tqdm(range(len(pred_infl))):\n",
    "        lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "        lr_new.fit(np.delete(train_x, i, axis=0), np.delete(train_y, i))\n",
    "        logits_train_y_new = train_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "        new_ori_val_loss, _ = lr_new.log_loss(logits_train_y_new, one_hot_labels_train, l2_reg=True)\n",
    "        act_infl.append(new_ori_val_loss - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "24c1d835",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8ba67c6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fa2ae647c70>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAD4CAYAAAAUymoqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgDklEQVR4nO3deXxU5dnG8d89k0wIIJssoqCgUCtYtZqCWn3VtgjiguJSbOvWWlzbqlVEra0WbVHrUq3V8rZW9LVStyoqLkjd6xbcUZEIChGEoLKEJJNl7vePHGwIk8wcMpnJcn0/nzgzz3meM/dzmHjlLDNj7o6IiEgYkVwXICIi7Y/CQ0REQlN4iIhIaAoPEREJTeEhIiKh5eW6gGzp27evDxkyJNdliIi0K/Pnz1/t7v0at3ea8BgyZAjFxcW5LkNEpF0xs0+SteuwlYiIhKbwEBGR0BQeIiISmsJDRERCU3iIiHRQy957lbf//QDla8oyvu5Oc7WViEhn8cWKJfz6sAv4+AMnmufU1dzFCZfswvcvuSJjz6E9DxGRDubyI6dQ8g7EKyNUrI8Sr4pw5+/e55WHbs/Ycyg8REQ6kBUlb1PydoK6WtukPV4Z4b7rH8vY8yg8REQ6kHWff0ZefvJla8pqM/Y8Cg8RkQ5kyG77kOw7/vJjCUaPG5Sx51F4iIh0IAWFW3H6NaMpKExgVp8i+QUJevRxjrnwlxl7Hl1tJSLSwYw/fQqDd36Q+6+/l9WfxikaO4Sjzj2Hnv0GZ+w5FB4iIh3QNw46km8cdGSrrV+HrUREJDSFh4iIhKbwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJDSFh4iIhKbwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJLSMhIeZjTOzhWZWYmZTkyw3M7sxWP62me2ZaqyZ9TGzuWa2KLjt3Wid25tZuZmdn4k5iIhI+locHmYWBW4GDgFGAMeb2YhG3Q4Bhgc/k4Fb0hg7FZjn7sOBecHjhq4HHmtp/SIiEl4m9jxGASXuvtjdq4FZwIRGfSYAd3i9l4FeZjYwxdgJwMzg/kzgyI0rM7MjgcXAggzULyIiIWUiPLYDljV4XBq0pdOnubED3H0FQHDbH8DMugEXApenKszMJptZsZkVl5WVpT0hERFpXibCw5K0eZp90hnb2OXA9e5enqowd5/h7kXuXtSvX79U3UVEJE15GVhHKTC4weNBwPI0+8SaGbvSzAa6+4rgENeqoH00cIyZXQ30AhJmVuXuf8rAXEREJA2Z2PN4DRhuZkPNLAZMAmY36jMbODG46mpvYG1wKKq5sbOBk4L7JwEPAbj7/u4+xN2HADcAv1NwiIhkV4v3PNy91szOBp4AosBt7r7AzE4Plt8KzAHGAyVABXBKc2ODVU8H7jGznwBLgWNbWquIiGSGuac6xdAxFBUVeXFxca7LEBFpV8xsvrsXNW7XO8xFRCQ0hYeIiISm8BARkdAUHiIiEprCQ0REQlN4iIhIaAoPEREJTeEhIiKhKTxERCQ0hYeIiISm8BARkdAUHiIiEprCQ0REQlN4iIhIaAoPEREJTeEhIiKhKTxERCQ0hYeIiISm8BARkdAUHiIiEprCQ0REQsvLdQEi0rnFK+M8d+/LLFv4KUN23Z79Jo4mVpCf67IkBYWHiOTMqmWr+dneF1O5vpLK8ioKu3fhbxffxZ9e/j29B/TKdXnSDB22EpGc+eMZM1izai2V5VUAVJZX8fmnX3LLebfntjBJSeEhIjmRSCQofuItEnWJTdrraut48cHXclSVpEvhISI5Y2ZJ2yOR5O3Sdig8RCQnIpEI+xxeRDQ/ukl7XiyPA47bN0dVSboUHiKSMz+/5af0H9yXwq26EM2LUrhVF7bdaQCn/eHEXJcmKehqKxHJmd79e/L3D/7IK3Nep3ThcnYYOZiisbsTjUZTD5ac
UniISEYt//ANHrnlNlYtXcue3x3BQSedTmG3Xk32j+ZF2feIb2WvQMkIHbYSkYxwr2HJK+ez8t0fs1vRf4hv+IxbprzKGXv8mPIvV+a6PMkwhYeItJh7HYkvTmZAv4fZfd8NjPrueqbe/AknT/2MVcuMe6dPz3WJkmEKDxFpufjTePxtunT1r5oKuznjf/Q5fQfW8NwDn+SwOGkNGQkPMxtnZgvNrMTMpiZZbmZ2Y7D8bTPbM9VYM+tjZnPNbFFw2ztoH2Nm883sneD2O5mYg4hsOY8/SyQS36w9UQe7f7ucLt30vo2OpsXhYWZR4GbgEGAEcLyZjWjU7RBgePAzGbgljbFTgXnuPhyYFzwGWA0c7u7fAE4C7mzpHESkhSJ9SHb9TSJhVFVEmHDG3tmvSVpVJvY8RgEl7r7Y3auBWcCERn0mAHd4vZeBXmY2MMXYCcDM4P5M4EgAd3/D3ZcH7QuALmZWkIF5iMgWssKjSRoedUbXXoMZc+q52S9KWlUmwmM7YFmDx6VBWzp9mhs7wN1XAAS3/ZM899HAG+6++f4yYGaTzazYzIrLysrSnI6INOZejcefw6vm4ol1my23vO2h57Vg3XDrTl1dAfGqrsTzruDc2/5ONKp3BXQ0mfgXTXYw09Psk87Y5E9qNhK4Cji4qT7uPgOYAVBUVJTWekVkU0vffZIlr1zBoJ2qGbpLNXgN3uM3RLoes0m/SOEYvMsrWPUbRKyAvPzdKDRdk9NRZSI8SoHBDR4PApan2SfWzNiVZjbQ3VcEh7hWbexkZoOAfwEnuvtHGZiDiDRSHa9h2nHX8Prc+eTl96Wu1hi+WwXT7lxCV36Lx/bA8oZtMsYsBgWjc1SxZFMm/ix4DRhuZkPNLAZMAmY36jMbODG46mpvYG1wKKq5sbOpPyFOcPsQgJn1Ah4FLnL3FzNQv4g04O6899JCLj18OsWPv0V1VYSK9VHilREWvtGVmy/ZDqjBKx/IdamSQy3e83D3WjM7G3gCiAK3ufsCMzs9WH4rMAcYD5QAFcApzY0NVj0duMfMfgIsBY4N2s8GhgGXmtmlQdvB7v7VnomIbJlPFy3nF/tdytqyzc9rANRUR3j2oV6cd90yoknOfUjnYe6d41RAUVGRFxcX57oMkTZr/ZflHD/oNOKV1c32i0Sc2UtKiPW/ASs4MDvFSc6Y2Xx3L2rcrksgRDq50kUrmH3z4xQ/+VbK4ABn2G5V5HcfBbH/yUp90jYpPEQ6qZrqGv73wrt48KY5eCL1EYi8fMgviHDOn4/Hek3CdCVVp6bwEOmEvli5hinfu5xPFpSm1b+ga4yJvziUI84cS9/ttm7l6qQ9UHiIdCLuzu2/+Sf3XPUgtTV1aY2JRCPc8MIVDNtjaCtXJ+2J9jtFOpEXH3yVB65/JK3gsIgx7JtDuf3DGxUcshnteYh0Ig/c8ChVG5J+ms9X8gvymHzNCRx22sHk5et/EZKcXhkiHdiGL96kZv0zdO8zkGi38az/sjzlmIOO348jzhxHJKIDE9I0vTpEOqCqDVW8/q8J2Nrv04VbiK/6DTXL92Ximf3IL0j+N+NWfbpz5SMXccFtZyk4JCW9QkQ6oIeun8LOuy2kS6ETK3AKuyXIy4vznUP/Sb/telJQGAPArP4w1Y+vPJ77Vv2NUeP3TLFmkXo6bCXSQaxe/gUzfz2Llx+Zz89//xaF3RKb9amrqebW147n4b+tpfjxN+g7eGsm/vxQhn1TJ8QlHIWHSAew/styztxrCus+L6euto5oNPmb/hIJp1vXPI47/wiOO/+ILFcpHYkOW4l0AI/OeIoN6yqpq62/BHfe/X2o3LD5r3c0LwKxzT6mSCQ0hYdIO+TuvP/KIp7554uUfrict59dQHWDz6V6/pGeFD+9FZUbIiQSUF1lxCsjVHAF+tZmyQQdthJpZ9aUrWXKmGms+OgzLGLU1dTRd9DWRPIiJGrrz3O4G1dM3oFd967g2+Pj9BqwPSO/cx4Dd/pG
jquXjkLhIdJO1NVV8Y/Lr+Gu6W9TV7vpstWln2ONvtU5Lz+P6ppvcMyvrspildJZKDxE2oHXn3qNKyddxbovgEYhAVBdVUPvbXpSUFjA58u/xN3Za8xuTLn97KzXKp2DwkOkDXN3bjjtT8z563MkC42GEnXOHSV/4suVayjoWkC3Hl2zU6R0SgoPkTaqeM5dPH/vTB6/cytSXdsSzYuyz+F7YWb02aZ3dgqUTk3hIdLGvP38e/zygN8ADvQg1R5HJGL07NeDk6cdn43yRACFh0ibctLXzmZ5yUrqg6P50ACwiHPkzw7lxMuOpVvPbq1en8hGCg+RNmBt2TJ+tOM5VG3YGBipgsPZqleCyx74ObsdeGArVyeyOYWHSI5de/IpPH7nenAjndAAY9/Dh/Kre6eRH+uShQpFNqfwEMmR+66/m8f+8n8s/bCQdA5RgbPf4Xn84LLfMfybO7Z2eSLNUniIZFnlhnIuOOBYFr3ZlUQineCo/5DDo87qwRk3zsBMv7aSe3oVimTR36eeRn7kXQ4/sZY3d6zm2Yd7UhOPphhlTDjzYM686adZqVEkHQoPkSwoXbKUZ/93Esf99HOieU6sAL49fi3Hnb2Kcw4bTkV54wCp39uIxvJ4aO1MCgr0YYbStig8RFrZMf2PZd3nCe6av4bCbv/9no2u3RNss301R00u467rtmkwwunaPcHF/5zC6EP2zn7BImnQR7KLtJJnH3qWMZGjWbvaGbRTnK5b1W3Wp6CLc+CENcEjJz9WxwFH5fOvNfcpOKRN056HSCs4dsBhVKyPYZEInjDilVEiTfypFq+MAM6AwXHOu+1c9vzud7Jaq8iWUHiIZNAvDvw5HxUvI16x8f0X9VdSrfo0xrKPChi6SxXRBqc3KjcYD8/cmpF7V3DdC48QaSphRNoYvVJFMmRM5Cjee2458Yoo9aGx6SW4004dwuef5VOxPkJleYR4pfHav7di8B7f54b/zFFwSLuiPQ+RFhoTOZj6DzDcGBrJfba0gJNG78Ie+6+nT/8aBg6Jc+JVL2arTJGM0p86IlvoggvO5ei+R1AfHJDOm/0SCXj92e4M2+dUBYe0axkJDzMbZ2YLzazEzKYmWW5mdmOw/G0z2zPVWDPrY2ZzzWxRcNu7wbKLgv4LzWxsJuYgEsa4bmNYcNMnrPsiFrSk9y5xMOYm7ufoc37QitWJtL4Wh4eZRYGbgUOAEcDxZjaiUbdDgOHBz2TgljTGTgXmuftwYF7wmGD5JGAkMA74c7AekVb31xk3MTZ/Il7Tg5rqCMnObWzKAWfQsErO+vPxzE3cm51CRVpZJs55jAJK3H0xgJnNAiYA7zXoMwG4w90deNnMepnZQGBIM2MnAAcG42cCzwAXBu2z3D0OLDGzkqCGlzIwF5EmjYkcDhTw37+50tvb6Duwkr9/+GgrViaSfZk4bLUdsKzB49KgLZ0+zY0d4O4rAILb/iGeDwAzm2xmxWZWXFZWlvaERBqa98i/GRM5mvrggHT3NsA5+uL9uftTBYd0PJnY80j2W+Rp9kln7JY8X32j+wxgBkBRUVGq9YpsZkzkMKALqQNjo40vsxrmJh5qtbpEci0Tex6lwOAGjwcBy9Ps09zYlcGhLYLbVSGeT6RFHrv70WBvI93gqN/T6NazFgpQcEiHl4k9j9eA4WY2FPiU+pPZjS8lmQ2cHZzTGA2sdfcVZlbWzNjZwEnA9OD2oQbt/zCz64BtqT8J/2oG5iECwI92OJSVpRu/ZyO9vY3uvWrJy4tz76rHWrk6kbahxeHh7rVmdjbwBPXvkrrN3ReY2enB8luBOcB4oASoAE5pbmyw6unAPWb2E2ApcGwwZoGZ3UP9SfVa4Cx33/wT50RCev7BZ7np9Gv5clVh0JLeCfFoXh3lOztzX1JwSOdh9RdAdXxFRUVeXFyc6zKkjRoTOZaG78Vo3n9/ZyxvA09WKzSk4zKz+e5e1Lhd7zCXTu2D4jf59bgxhA+OBH/7+FoF
h3Ra+mwr6bQuPvgEIraSeGWEbj1q2bAuP8WIhu8Sf6C1yxNp0xQe0umUfvgi1508jYVvdgO2Ihp1IlHn69/cwAdvdEsy4r97GwoNkXoKD+k0EokEP9nlh+Tnr2PFx12prtr0qO3KT528WILa6o3t9aHRc+sadhq9C1c9Mj3LFYu0XQoP6RSef/AVfjvxGgCG7mJUVWz+cWhVGyIM362C94u7A/V7Iz1613DvqtlZrlak7VN4SIfm7pxRNIWP3vg4aDHWrG7mczT9q/8w7ZGpjBr7rVauUKR9UnhIhxWv2MCZRRew9INVNHzD35rV+UTzEtTVbnrYyqz+ENX2u8T424J/ZL9gkXZE4SEdjrvz5K2nMXC7V1m6cBiNL791NzwBBYV1xCuj5OUniEZh5KhyzrvrAXpv3SP5ikXkKwoP6VDcq5j7lzPYf+yLzLmrb5Mfs5lIQN9taujeq4r8fGfiOf/D/pMuym6xIu2YwkM6hPJ1Fdz1qzNZXVpK322q2WW3GFv1rMMMmvoQhU+XFNBj6wT/t/RuCgsLkncSkaQUHtLufbZkJWcWnUFNHKoqehPNS/Dw7f0459pl5McSVMeTnSA3Tr7iOH548bFZr1ekI1B4SLtVHd/ALT+7ksdu/zA4+V1/bqOuNkJdLdw0dRDT7lzCZT8eStWGSLAHYuy0+w7c+sYfclm6SLun8JB2afFb7/CLfS+jqrKZj013iHVx7n/vXd58oTtvvTqaE39/I7GCVB9DIiKpKDykXfli5RpumHwrLz08n1Tft5FIQKxLgsXvFbI+fhSnXndp1uoU6egUHtJuLF34KT/d9TwSdYmgpengMHN6bl1Lv50vpNegH7CzpfOlTiKSLoWHtBuXH/2HBsHRFCea5xR2S3Dq78fTe/APs1KbSGej8JB2oaa6hqXvlabsF4nAXgcXceGdP6NH7+5ZqEykc1J4SLtgZvVHqZr84kunoCtc9eQvGbnvPlmsTKRz0jcJSruQl5/HbgeMTLrMIvDjaftwf9kdCg6RLNGeh7QZ678s5+WH51NbU8voQ/ekzza9N1l+0Z0/44y9prCmbN1XeyCxLvnc+NKV7LT70BxULNJ5mTf12Q0dTFFRkRcXF+e6DGnC8w+8wlUn3EgkGsHdSdQlOO0PJ3LEmeM26VdbU8tLs4tZ8NKHDN11MGNOPIBIRDvQIq3FzOa7e9Fm7QoPybW1q9fxg+3PoLqqepP2WGGMW1+/msE7b5ejykSkqfDQn2ySc/956DUi0c3fh1FXU8fTs17MQUUikorCQ3KuJl5LIrH5HnCiLkF1ZXWSESKSawoPybnRh+6Z9HPTY4Uxvn3U6BxUJCKpKDwk5wbs0I8TLzuOgsIYkWgEM6OgawFjTz6QXUYPz3V5IpKELtWVNuH7U46kaOwe/Puu56mtreOAY/dhxD4757osEWmCwkNazVvPLGD2nx9n/Zcb2H/iaMaechCxLrEm+++0+xB22n1I9goUkS2m8JBW8c+rH+TO395HvCIOwHsvfcicv87jjy9e0WyAiEj7oHMeknHrPl/PHZfd81VwAMQr4pR+uJx//+OFHFYmIpmi8JCMe/eFD8iLbb5TW7Uhzgv/eiUHFYlIpumwlbRIIpHghtP+wpMzn6Wuto7C7l04dPIYkn1ygUWM3v175qBKEck0hYe0yG+OupqXH57/1ePK8iruu+5huvfuhtmmb9+IdcnnsNMPzkGVIpJpLTpsZWZ9zGyumS0Kbns30W+cmS00sxIzm5rOeDO7KOi/0MzGBm1dzexRM/vAzBaY2fSW1C8tU1VRtUlwNJRIOP0G96Wwexe69ehKQWGMM64/mZ2/NSzLVYpIa2jpnsdUYJ67Tw9CYSpwYcMOZhYFbgbGAKXAa2Y2293fa2q8mY0AJgEjgW2Bp8zsa8Eq/+DuT5tZDJhnZoe4+2MtnIdsgSXvLG1yWeW6Sh784nY+eLWEinUV7LL31+i6VWEWqxOR1tTSE+YTgJnB/ZnAkUn6
jAJK3H2xu1cDs4JxzY2fAMxy97i7LwFKgFHuXuHuTwME63odGNTCOcgWGrzztk0u69K9ADNjl9HD2WvM7goOkQ6mpeExwN1XAAS3/ZP02Q5Y1uBxadDW3PjmxgBgZr2Aw4F5TRVnZpPNrNjMisvKytKdk6Spe6/u7Dwq+WGoH1w0McvViEg2pQwPM3vKzN5N8jMh1diNq0jSlupLRJodY2Z5wN3Aje6+uKmVuPsMdy9y96J+/fqlVayEc91zv2XX/b7+1eNINMLEcw5l0tSjcliViLS2lOc83P17TS0zs5VmNtDdV5jZQGBVkm6lwOAGjwcBy4P7TY1vbgzADGCRu9+Qqn5pXbFYPtc/N43a2lrWrV5Pr/499c1+Ip1AS3/LZwMnBfdPAh5K0uc1YLiZDQ1Ock8KxjU3fjYwycwKzGwoMBx4FcDMrgB6Aue0sHbJoLy8PPps01vBIdJJtPQ3fTowxswWUX811XQAM9vWzOYAuHstcDbwBPA+cI+7L2hufLD8HuA94HHgLHevM7NBwCXACOB1M3vTzE5t4RxERCQkfYe5iIg0Sd9hLiIiGaPwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJDSFh4iIhKbwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJDSFh4iIhKbwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJDSFh4iIhKbwEBGR0BQeIiISmsJDRERCU3iIiEhoCg8REQlN4SEiIqEpPEREJDSFh4iIhNai8DCzPmY218wWBbe9m+g3zswWmlmJmU1NZ7yZXRT0X2hmY5Osc7aZvduS+kVEZMu0dM9jKjDP3YcD84LHmzCzKHAzcAgwAjjezEY0Nz5YPgkYCYwD/hysZ+M6JwLlLaxdRES2UEvDYwIwM7g/EzgySZ9RQIm7L3b3amBWMK658ROAWe4ed/clQEmwHsysO3AecEULaxcRkS3U0vAY4O4rAILb/kn6bAcsa/C4NGhrbnxzY6YB1wIVqYozs8lmVmxmxWVlZenNSEREUspL1cHMngK2SbLokjSfw5K0+ZaMMbM9gGHufq6ZDUn1xO4+A5gBUFRUlOo5N5FIJFhcfC/lKx8EYmy948kMHnlQmFWIiHRYKcPD3b/X1DIzW2lmA919hZkNBFYl6VYKDG7weBCwPLjf1PimxuwD7GVmHwe19zezZ9z9wFTzCCORSPDOnIkMG/EBBYMSeAJqa17mzTlHsMf4azL5VCIi7VJLD1vNBk4K7p8EPJSkz2vAcDMbamYx6k+Ez04xfjYwycwKzGwoMBx41d1vcfdt3X0IsB/wYaaDA2Dx/PsZNuIDCrsmiEQgmgcFhc7OI2bzxYpFmX46EZF2p6XhMR0YY2aLgDHBY8xsWzObA+DutcDZwBPA+8A97r6gufHB8nuA94DHgbPcva6Ftaat/LMHKOiS2Kw9UWeUvnN3tsoQEWmzUh62ao67fw58N0n7cmB8g8dzgDnpjg+WXQlc2cxzfwzsGrrotMRIJCDSKFrdIRIpbJ2nFBFpR/QO8yT67nQKdTWbn7OPRGBo0Qk5qEhEpG1ReCQxaMSBLHz/CKqrjMoNESrKI1RVRFi6/AK69Up24ZmISOfSosNWHdke46/hixWTKX1nFpFIF4YWncDXd1RwiIiAwqNZfQYOp8/AS3NdhohIm6PDViIiEprCQ0REQlN4iIhIaAoPEREJTeEhIiKhmXuoD5ttt8ysDPiklVbfF1jdSutuDzr7/EHboLPPHzruNtjB3fs1buw04dGazKzY3YtyXUeudPb5g7ZBZ58/dL5toMNWIiISmsJDRERCU3hkxoxcF5BjnX3+oG3Q2ecPnWwb6JyHiIiEpj0PEREJTeEhIiKhKTwaMLM+ZjbXzBYFt72b6DfOzBaaWYmZTU1nvJldFPRfaGZjk6xztpm92zozS1+2t4GZdTWzR83s
AzNbYGbTW3+W6c+nwXIzsxuD5W+b2Z6pxrbk9ZAL2dwGZjbGzOab2TvB7XeyM8umZfs1ECzf3szKzez81p1dK3B3/QQ/wNXA1OD+VOCqJH2iwEfAjkAMeAsY0dx4YETQrwAYGoyPNljnROAfwLudbRsAXYGDgj4x4HngkCzPucn5NOgzHngMMGBv4JXWej3k6N8929vgm8C2wf1dgU870/wbrPN+4F7g/FzOf0t+tOexqQnAzOD+TODIJH1GASXuvtjdq4FZwbjmxk8AZrl73N2XACXBejCz7sB5wBUZncmWy+o2cPcKd38aIFjX68CgjM4otebms9EE4A6v9zLQy8wGphgb+vWQQ1ndBu7+hrsvD9oXAF3MrKCV5paObL8GMLMjgcXUz7/dUXhsaoC7rwAIbvsn6bMdsKzB49KgrbnxzY2ZBlwLVGRiAhmQi20AgJn1Ag4H5rVsCqGlrK2ZPq2yLXIg29ugoaOBN9w9vsXVt1xW529m3YALgcszVH/WdbpvEjSzp4Bk3yd7SbqrSNKW6nrnpGPMbA9gmLufa2ZD0nz+FmtL26BBTXnA3cCN7r44zToyJZ35NNUn49siR7K9DepXaDYSuAo4OJ3+rSjb878cuN7dy82SDW/7Ol14uPv3mlpmZivNbKC7rwh2R1cl6VYKDG7weBCwcfe7qfFNjdkH2MvMPqb+36K/mT3j7gduwdTS1sa2wUYzgEXufkO42WREqtqa6xNrZuyWbotcyPY2wMwGAf8CTnT3jzIyiy2X7fmPBo4xs6uBXkDCzKrc/U+ZmExW5PqkS1v6Aa5h05NbVyfpk0f9ccqh/Pfk2MjmxgMj2fQE6WIanSAFhtA2TphnfRtQf77nfiCSozk3OZ8GfQ5l05Olr7bm66ETbINeQb+jc/2az8X8G633MtrhCfOcF9CWfoCtqT/evii47RO0bwvMadBvPPAh9VdYXJJqfLDskqD/QpJcTUTbCY+sbgPq/0pz4H3gzeDn1BzMe7P5AKcDpwf3Dbg5WP4OUNSar4cc/dtnbRsAvwI2NPg3fxPo31nm3+h5L6Mdhoc+nkRERELT1VYiIhKawkNEREJTeIiISGgKDxERCU3hISIioSk8REQkNIWHiIiE9v9+tzreYco3nAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): y_outlier is not defined anywhere in this notebook -- presumably it came\n",
    "# from a removed `gen_attribute_outliers(data, n=NUM_OUTLIER, k=NUM_CANDIDATE)` cell.\n",
    "# Plot each node's predicted influence against its index, coloured by the outlier flag\n",
    "# (plotting pred_infl against itself would only draw the diagonal).\n",
    "fig, ax = plt.subplots()\n",
    "points = ax.scatter(np.arange(len(pred_infl)), pred_infl, c=y_outlier.numpy())\n",
    "ax.set(xlabel='node index', ylabel='predicted influence',\n",
    "       title='Influence-function scores by node')\n",
    "ax.legend(*points.legend_elements(), title='y_outlier');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93991588",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the nodes flagged as outliers (y_outlier == 1)\n",
    "idx = np.nonzero(y_outlier.numpy() == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b61897f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted influence of the outlier nodes selected above\n",
    "pred_infl[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ec215f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `acc_inf`, the name used here before, is defined nowhere (NameError on a\n",
    "# fresh kernel). Presumably the leave-one-out influences `act_infl` were meant; they only\n",
    "# exist once the brute-force check has run, so fall back to None otherwise.\n",
    "globals().get('act_infl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33442647",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
