{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b9a64a2-0b33-40ab-a7fd-ddc8627eccfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from scipy.stats import pearsonr\n",
    "from sklearn.svm import SVC\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a92301-7cf9-446a-a2af-b64cf20c4656",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MUTAG collection from TUDataset; its files are kept under ./data\n",
    "dataset = TUDataset(root='data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8b01744-de54-4c90-a4e7-a6a08be6555a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed pairwise matrix over the graphs in `dataset`. A later cell turns it\n",
    "# into an SVC kernel via exp(-lam * M), so it is presumably an (N, N) distance\n",
    "# matrix in dataset order -- TODO confirm against whatever wrote this file.\n",
    "M = np.load(file='ggd_values.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04f84793-b942-499f-9154-dc7d2fd9e65f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate a precomputed-kernel SVC with kernel exp(-lam * M) on a random\n",
    "# 90/10 train/test split. `lam` is tuned on repeated random 90/10 hold-out\n",
    "# splits of the training portion only (Monte-Carlo CV, not k-fold).\n",
    "SEED = 0           # seeds every shuffle so Restart & Run All is reproducible\n",
    "N_CV_REPEATS = 10  # random hold-out splits used to score each lam\n",
    "lams = [0.01, 0.075, 0.1]\n",
    "\n",
    "assert M.shape == (len(dataset), len(dataset)), 'M must be (N, N) and in dataset order'\n",
    "rng = random.Random(SEED)  # private RNG; leaves the global `random` state alone\n",
    "\n",
    "# Outer split: the first n shuffled positions form the test set.\n",
    "idx = list(range(len(dataset)))\n",
    "rng.shuffle(idx)\n",
    "n = len(dataset) // 10\n",
    "assert n > 0, 'dataset too small for a 10% test split'\n",
    "idx_train = idx[n:]\n",
    "idx_test = idx[:n]\n",
    "train_dataset = dataset[idx_train]\n",
    "test_dataset = dataset[idx_test]\n",
    "\n",
    "# Reorder rows AND columns into a new array. Overwriting `M` in place meant a\n",
    "# re-run of this cell permuted an already-permuted matrix and silently\n",
    "# misaligned it with `dataset` and `y`.\n",
    "M_perm = M[np.ix_(idx, idx)]\n",
    "\n",
    "# Flat (N,) label array. `.item()` assumes one label per graph; stacking the\n",
    "# per-graph label tensors directly gave an (N, 1) array, so `y_pred == labels`\n",
    "# broadcast to an (n, n) matrix and the accuracy was not a scalar.\n",
    "y = np.array([data.y.item() for data in dataset])\n",
    "y_perm = y[idx]  # labels in the same order as the rows of M_perm\n",
    "\n",
    "\n",
    "def kernel_svc_accuracy(D, labels, fit_rows, eval_rows, lam):\n",
    "    \"\"\"Fit an SVC on the kernel exp(-lam * D) and return its hold-out accuracy.\n",
    "\n",
    "    D         -- square matrix whose rows and columns share one ordering\n",
    "    labels    -- 1-D label array in that same ordering\n",
    "    fit_rows  -- positions used to train the SVC\n",
    "    eval_rows -- positions whose labels are predicted and scored\n",
    "    lam       -- kernel width; the kernel is exp(-lam * D)\n",
    "    Returns the fraction of `eval_rows` predicted correctly, as a float.\n",
    "    \"\"\"\n",
    "    clf = SVC(kernel='precomputed')\n",
    "    clf.fit(np.exp(-lam * D[np.ix_(fit_rows, fit_rows)]), labels[fit_rows])\n",
    "    pred = clf.predict(np.exp(-lam * D[np.ix_(eval_rows, fit_rows)]))\n",
    "    return float(np.mean(pred == labels[eval_rows]))\n",
    "\n",
    "\n",
    "# Tune lam on the training portion only; the test split is never touched here.\n",
    "M_cv = M_perm[n:, n:]\n",
    "y_cv = y_perm[n:]\n",
    "n_cv = len(idx_train) // 10\n",
    "assert n_cv > 0, 'training split too small for a 10% validation split'\n",
    "best_k_count = np.zeros(len(lams))  # summed validation accuracy per lam\n",
    "for _ in range(N_CV_REPEATS):\n",
    "    idx_cv = list(range(len(idx_train)))\n",
    "    rng.shuffle(idx_cv)\n",
    "    idx_train_cv = idx_cv[n_cv:]\n",
    "    idx_test_cv = idx_cv[:n_cv]\n",
    "    for k, lam in enumerate(lams):\n",
    "        best_k_count[k] += kernel_svc_accuracy(M_cv, y_cv, idx_train_cv, idx_test_cv, lam)\n",
    "\n",
    "best_lam = int(np.argmax(best_k_count))\n",
    "lam = lams[best_lam]\n",
    "\n",
    "# Final model: fit on the full training portion, score on the held-out test split.\n",
    "M_ = np.exp(-lam * M_perm)\n",
    "model = SVC(kernel='precomputed')\n",
    "model.fit(M_[n:, n:], y_perm[n:])\n",
    "y_pred = model.predict(M_[:n, n:])\n",
    "acc = float(np.mean(y_pred == y_perm[:n]))\n",
    "print('{}, lam: {}, Acc: {}'.format(dataset.name, lam, acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87912f0-c6ee-4493-8b4a-72bc095205c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce6f0a04-b619-4bc9-a827-d8c74a0baa9f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03d6aa0-e38b-4bc8-959c-57cb9ec7f570",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
