{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e767c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "physics\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "from utils_data import load_dataset\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.colors import ListedColormap\n",
    "from scipy.stats import gaussian_kde\n",
    "import os\n",
    "from matplotlib.ticker import MultipleLocator, FormatStrFormatter\n",
    "from model import preprocess_feature\n",
    "\n",
    "def get_ent(true_dist, false_dist, pos_scores, neg_scores):\n",
    "    \n",
    "    d_true_pos = true_dist(pos_scores)\n",
    "    d_false_pos = false_dist(pos_scores)\n",
    "    \n",
    "    d_true_neg = true_dist(neg_scores)\n",
    "    d_false_neg = false_dist(neg_scores)\n",
    "    \n",
    "    p_true_pos = d_true_pos / (d_true_pos + d_false_pos)\n",
    "    p_false_neg = d_false_neg / (d_true_neg + d_false_neg)\n",
    "    \n",
    "    ent = - (np.log(p_true_pos).sum() + np.log(p_false_neg).sum() ) / (pos_scores.shape[0] + neg_scores.shape[0])\n",
    "    \n",
    "    return ent\n",
    "\n",
    "datasets = ['cora','citeseer','pubmed'] + ['computer','photo','cs','physics'] \n",
    "\n",
    "\n",
    "\n",
    "NAME_DICT = {'cora': 'Cora',\n",
    "             'citeseer': 'Citeseer',\n",
    "             'pubmed': 'Pubmed',\n",
    "            'computer': 'Amazon-Computer',\n",
    "            'photo': 'Amazon-Photo',\n",
    "            'cs': 'Coauthor-CS',\n",
    "            'physics': 'Coauthor-Physics'}\n",
    "\n",
    "SIZE = 19\n",
    "\n",
    "fig = plt.figure(figsize=(24, 9), dpi=400)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "for name in datasets[6:]:\n",
    "    print(name)\n",
    "    num = 1\n",
    "    dataset = load_dataset(name)\n",
    "    graph = dataset[0]\n",
    "    feat = dataset[1]\n",
    "\n",
    "    mode = 'transductive'\n",
    "    dataname = name\n",
    "\n",
    "    if not os.path.exists(f'store_feats/{mode}/{dataname}'):\n",
    "        os.makedirs(f'store_feats/{mode}/{dataname}')\n",
    "\n",
    "    if not os.path.exists(f'store_feats/{mode}/{dataname}/10_hop.npy'):\n",
    "        Feat = preprocess_feature(graph, feat)\n",
    "        np.save(f'store_feats/{mode}/{dataname}/10_hop.npy', Feat.cpu().numpy())\n",
    "    else:\n",
    "        Feat = np.load(f'store_feats/{mode}/{dataname}/10_hop.npy') \n",
    "        Feat = torch.from_numpy(Feat)\n",
    "\n",
    "\n",
    "    for i in range(10):\n",
    "        ax = fig.add_subplot(2, 5, num)\n",
    "\n",
    "        if i == 0:\n",
    "            feat = feat\n",
    "        else:\n",
    "            feat = Feat[:i].sum(0) / i\n",
    "    #         feat = Feat[i+1]\n",
    "        num += 1       \n",
    "\n",
    "        neg_edges = dgl.sampling.global_uniform_negative_sampling(graph, graph.num_edges())\n",
    "        neg_graph = dgl.graph(neg_edges, num_nodes = graph.num_nodes())\n",
    "\n",
    "        feat = F.normalize(feat, p = 2, dim = -1)\n",
    "        graph.ndata['h'] = feat\n",
    "        graph = graph.remove_self_loop()\n",
    "        graph.apply_edges(fn.u_sub_v('h', 'h', 'm'))\n",
    "        m = graph.edata['m']\n",
    "        pos_m = m.pow(2).sum(-1)\n",
    "        if name in ['computer', 'photo', 'cs', 'physics']:\n",
    "        # subsampling\n",
    "            idx = np.arange(graph.num_edges())\n",
    "            np.random.shuffle(idx)\n",
    "\n",
    "            pos_m = pos_m[idx][:20000]\n",
    "        kde0 = gaussian_kde(pos_m, bw_method = 0.2)\n",
    "\n",
    "        feat = F.normalize(feat, p = 2, dim = -1)\n",
    "        neg_graph.ndata['h'] = feat\n",
    "        neg_graph = neg_graph.remove_self_loop()\n",
    "        neg_graph.apply_edges(fn.u_sub_v('h', 'h', 'm'))\n",
    "        m = neg_graph.edata['m']\n",
    "        neg_m = m.pow(2).sum(-1)\n",
    "        if name in ['computer', 'photo', 'cs', 'physics']:\n",
    "        # subsampling\n",
    "            idx = np.arange(graph.num_edges())\n",
    "            np.random.shuffle(idx)\n",
    "\n",
    "            neg_m = neg_m[idx][:20000]\n",
    "        kde1 = gaussian_kde(neg_m, bw_method = 0.2)\n",
    "\n",
    "\n",
    "        x = np.linspace(0, 2, 1000)\n",
    "        kde0_x = kde0(x)\n",
    "        kde1_x = kde1(x)\n",
    "\n",
    "        inters_x = np.minimum(kde0_x, kde1_x)\n",
    "\n",
    "        ax.plot(x, kde0_x, color='b', label=r'$p(\\hat{A}_{ij}|A_{ij} = 1)$')\n",
    "        ax.fill_between(x, kde0_x, 0, color='b', alpha=0.2)\n",
    "        ax.plot(x, kde1_x, color='orange', label=r'$p(\\hat{A}_{ij}|A_{ij} = 0)$')\n",
    "        ax.fill_between(x, kde1_x, 0, color='orange', alpha=0.2)\n",
    "        ax.plot(x, inters_x, color='r')\n",
    "        #     ax.fill_between(x, inters_x, 0, facecolor='none', edgecolor='r', hatch='xx', label='intersection')\n",
    "\n",
    "        ax.set_xlabel(r'$\\ell_2$ distance', size=SIZE+8)\n",
    "        ax.set_ylabel('probability denstiy', size=SIZE+2)\n",
    "        ax.set_ylim(0, 4)\n",
    "        ax.set_xlim(0, 2)\n",
    "\n",
    "        xmajorLocator = MultipleLocator(0.5)  \n",
    "        xminorLocator = MultipleLocator(0.5)\n",
    "        ax.xaxis.set_major_locator(xmajorLocator)\n",
    "        ax.xaxis.set_minor_locator(xminorLocator)\n",
    "        ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))\n",
    "\n",
    "\n",
    "        ymajorLocator = MultipleLocator(1)  \n",
    "        yminorLocator = MultipleLocator(0.5)\n",
    "        ax.yaxis.set_major_locator(ymajorLocator)\n",
    "        ax.yaxis.set_minor_locator(yminorLocator)\n",
    "        ax.yaxis.set_major_formatter(FormatStrFormatter('%d'))\n",
    "\n",
    "        labels = ax.get_xticklabels() + ax.get_yticklabels()\n",
    "        [label.set_fontsize(SIZE+4) for label in labels]\n",
    "\n",
    "        area_inters_x = np.trapz(inters_x, x)\n",
    "        handles, labels = plt.gca().get_legend_handles_labels()\n",
    "        #     labels[2] += f': {area_inters_x * 100:.1f} %'\n",
    "        plt.legend(handles, labels, prop = {'size': 18}, loc = 'best')\n",
    "\n",
    "        num_edges = pos_m.shape[0]\n",
    "        idx = np.arange(num_edges)\n",
    "\n",
    "        if num_edges > 10000:\n",
    "\n",
    "            np.random.shuffle(idx)\n",
    "            pos_m = pos_m[:10000]\n",
    "\n",
    "            np.random.shuffle(idx)\n",
    "            neg_m = neg_m[:10000]\n",
    "\n",
    "        ent = get_ent(kde0, kde1, pos_m, neg_m)\n",
    "\n",
    "        ent = np.around(ent, 4)\n",
    "#         plt.title(f'K = {i}, ' + r'$H(A|\\hat{A}) = $' +  f'{ent}', fontsize = 22)\n",
    "        plt.title(f'K = {i}', fontsize = 22)\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f'FIGS/distribution_aug_{name}.pdf', bbox_inches='tight')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ab3ec4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch] *",
   "language": "python",
   "name": "conda-env-torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
