{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e06b5c1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:21.054961Z",
     "start_time": "2024-11-20T06:10:21.051627Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Lower TensorFlow's C++ log verbosity. Set here, before tensorflow is\n",
    "# imported in the next cell, so the setting is in place when TF loads.\n",
    "os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fed698e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:25.227794Z",
     "start_time": "2024-11-20T06:10:21.058682Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from scipy import sparse\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import accuracy_score\n",
    "import urllib.request\n",
    "import tarfile\n",
    "import os\n",
    "from Linear_Separability_numpy_mul import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d80bdd5",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97055052",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:25.247566Z",
     "start_time": "2024-11-20T06:10:25.231988Z"
    }
   },
   "outputs": [],
   "source": [
    "def download_and_extract_cora(url='https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',\n",
    "                              filename='cora.tgz', target_dir='cora'):\n",
    "    \"\"\"Download and unpack the Cora dataset unless ``target_dir`` already exists.\n",
    "\n",
    "    The archive is saved as ``filename`` in the working directory, extracted\n",
    "    there, and then deleted (also when download or extraction fails).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an archive member path would resolve outside the\n",
    "            working directory (absolute path or '..' traversal).\n",
    "    \"\"\"\n",
    "    if os.path.exists(target_dir):\n",
    "        print(\"Cora dataset already exists.\")\n",
    "        return\n",
    "\n",
    "    print(\"Downloading Cora dataset...\")\n",
    "    try:\n",
    "        urllib.request.urlretrieve(url, filename)\n",
    "        print(\"Extracting Cora dataset...\")\n",
    "        with tarfile.open(filename, 'r:gz') as tar_ref:\n",
    "            root = os.path.realpath(os.getcwd())\n",
    "            for member in tar_ref.getmembers():\n",
    "                # extractall() trusts member names; reject any that escape root.\n",
    "                dest = os.path.realpath(os.path.join(root, member.name))\n",
    "                if os.path.commonpath([root, dest]) != root:\n",
    "                    raise ValueError(f\"Unsafe path in archive: {member.name}\")\n",
    "            tar_ref.extractall()\n",
    "    finally:\n",
    "        # Drop the (possibly partial) archive so a later retry starts clean.\n",
    "        if os.path.exists(filename):\n",
    "            os.remove(filename)\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"Load Cora and return ``(features, adj, labels)``.\n",
    "\n",
    "    Returns:\n",
    "        features: 2-D float array, one row per paper, from the columns of\n",
    "            ``cora/cora.content`` between the paper id and the class label.\n",
    "        adj: undirected citation graph (each cited/citing pair linked both\n",
    "            ways, duplicates collapsed to weight 1) passed through\n",
    "            ``preprocess_adj`` (self-loops + degree normalisation).\n",
    "        labels: integer class ids from a LabelEncoder over the label column.\n",
    "\n",
    "    Citations that mention a paper absent from ``cora.content`` are dropped,\n",
    "    and blank lines in either file are ignored.\n",
    "    \"\"\"\n",
    "    download_and_extract_cora()\n",
    "    print(\"Loading and preprocessing data...\")\n",
    "\n",
    "    with open(\"cora/cora.content\", 'r') as f:\n",
    "        # str.split() with no argument also strips; empty rows are skipped\n",
    "        # instead of crashing on row[0] below.\n",
    "        content = [row for row in (line.split() for line in f) if row]\n",
    "\n",
    "    node_ids = [int(row[0]) for row in content]\n",
    "    id_map = {node_id: i for i, node_id in enumerate(node_ids)}\n",
    "\n",
    "    features = np.array([list(map(float, row[1:-1])) for row in content])\n",
    "    labels = LabelEncoder().fit_transform([row[-1] for row in content])\n",
    "\n",
    "    edges = []\n",
    "    with open(\"cora/cora.cites\", 'r') as f:\n",
    "        for line in f:\n",
    "            tokens = line.split()\n",
    "            if not tokens:\n",
    "                continue\n",
    "            src, dst = map(int, tokens)\n",
    "            if src in id_map and dst in id_map:\n",
    "                edges.append((id_map[src], id_map[dst]))\n",
    "\n",
    "    # Vectorised replacement for setting adj[i, j] = adj[j, i] = 1 per edge in\n",
    "    # a lil_matrix: add both directions, let the CSR conversion merge\n",
    "    # duplicates, then clamp every stored entry back to 1.\n",
    "    num_nodes = len(node_ids)\n",
    "    edge_arr = np.asarray(edges, dtype=np.int64).reshape(-1, 2)\n",
    "    rows = np.concatenate([edge_arr[:, 0], edge_arr[:, 1]])\n",
    "    cols = np.concatenate([edge_arr[:, 1], edge_arr[:, 0]])\n",
    "    adj = sparse.coo_matrix((np.ones(len(rows)), (rows, cols)),\n",
    "                            shape=(num_nodes, num_nodes)).tocsr()\n",
    "    adj.data[:] = 1.0\n",
    "\n",
    "    adj = preprocess_adj(adj)\n",
    "\n",
    "    return features, adj, labels\n",
    "\n",
    "def preprocess_adj(adj):\n",
    "    \"\"\"Add self-loops and apply symmetric degree normalisation.\n",
    "\n",
    "    With ``A_hat = adj + I`` and ``d`` the row sums of ``A_hat``, returns\n",
    "    ``(A_hat @ D).T @ D`` where ``D = diag(d ** -0.5)`` (rows whose sum is 0\n",
    "    get 0 instead of inf). For a symmetric ``adj`` this equals\n",
    "    ``D @ A_hat @ D``. The result is a scipy COO matrix.\n",
    "    \"\"\"\n",
    "    graph = sparse.coo_matrix(adj)\n",
    "    a_hat = graph + sparse.eye(graph.shape[0])\n",
    "    degree = np.array(a_hat.sum(1))\n",
    "    inv_sqrt_degree = np.power(degree, -0.5).flatten()\n",
    "    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.\n",
    "    d_half = sparse.diags(inv_sqrt_degree)\n",
    "    col_scaled = a_hat.dot(d_half)\n",
    "    return col_scaled.transpose().dot(d_half).tocoo()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45af91e5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:28.954100Z",
     "start_time": "2024-11-20T06:10:25.250233Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Fixed seeds so the node split, the probe subset and (via TensorFlow's\n",
    "# global seed) the weight initialisation in later cells are repeatable.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "TRAIN_FRACTION = 0.8\n",
    "NUM_PROBE_NODES = 500  # training nodes used for the separability metrics\n",
    "\n",
    "features, adj, labels = load_data()\n",
    "\n",
    "features = tf.constant(features, dtype=tf.float32)\n",
    "labels = tf.constant(labels, dtype=tf.int32)\n",
    "\n",
    "# scipy COO -> tf.SparseTensor, reordered into canonical index order.\n",
    "adj = adj.tocoo()\n",
    "indices = np.column_stack((adj.row, adj.col))\n",
    "adj_tf = tf.SparseTensor(indices=indices, values=adj.data.astype(np.float32), dense_shape=adj.shape)\n",
    "adj_tf = tf.sparse.reorder(adj_tf)\n",
    "\n",
    "# Random train/test split over all nodes.\n",
    "num_nodes = features.shape[0]\n",
    "idx = np.arange(num_nodes)\n",
    "np.random.shuffle(idx)\n",
    "train_size = int(TRAIN_FRACTION * num_nodes)\n",
    "idx_train = idx[:train_size]\n",
    "idx_test = idx[train_size:]\n",
    "\n",
    "# Fixed subset of training nodes on which layer separability is measured.\n",
    "partial_indices = np.random.choice(idx_train, NUM_PROBE_NODES, replace=False)\n",
    "partial_labels = labels.numpy()[partial_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7977d7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:28.970788Z",
     "start_time": "2024-11-20T06:10:28.957362Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sanity check: size of the training split.\n",
    "np.shape(idx_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e09ff05a",
   "metadata": {},
   "source": [
    "# Network structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d914e9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:28.985209Z",
     "start_time": "2024-11-20T06:10:28.973047Z"
    }
   },
   "outputs": [],
   "source": [
    "class GCN(tf.keras.Model):\n",
    "    \"\"\"Two-layer GCN over a fixed sparse adjacency matrix.\n",
    "\n",
    "    Each layer multiplies the node features by ``adj`` (sparse @ dense) and\n",
    "    then applies a Dense layer: ReLU for the hidden layer, no activation for\n",
    "    the output layer (raw logits).\n",
    "\n",
    "    Args:\n",
    "        input_dim: not used; the Dense layers infer their input size on first call.\n",
    "        hidden_dim: width of the hidden layer.\n",
    "        output_dim: number of logits per node.\n",
    "        adj: tf.SparseTensor used for propagation in both layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, adj):\n",
    "        super().__init__()\n",
    "        self.adj = adj\n",
    "        self.gcn = tf.keras.layers.Dense(hidden_dim, activation='relu', name='gcn_layer')\n",
    "        self.output_layer = tf.keras.layers.Dense(output_dim, name='output_layer')\n",
    "\n",
    "    def call(self, x, return_intermediate=False):\n",
    "        \"\"\"Return logits, or ``(logits, {layer_name: output})`` if requested.\"\"\"\n",
    "        hidden = self.gcn(tf.sparse.sparse_dense_matmul(self.adj, x))\n",
    "        logits = self.output_layer(tf.sparse.sparse_dense_matmul(self.adj, hidden))\n",
    "        if not return_intermediate:\n",
    "            return logits\n",
    "        return logits, {'gcn_layer': hidden, 'output_layer': logits}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc62ca3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:29.202880Z",
     "start_time": "2024-11-20T06:10:28.987493Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "input_dim = features.shape[1]\n",
    "hidden_dim = 16\n",
    "output_dim = len(np.unique(labels.numpy()))  # number of classes\n",
    "\n",
    "model_with_intermediate = GCN(input_dim, hidden_dim, output_dim, adj_tf)\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n",
    "# The model's output layer has no activation, so the loss consumes logits.\n",
    "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4712067",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:10:29.901667Z",
     "start_time": "2024-11-20T06:10:29.205549Z"
    }
   },
   "outputs": [],
   "source": [
    "num_epochs = 200\n",
    "\n",
    "# One forward pass (it also builds the Dense weights) to find out how many\n",
    "# intermediate outputs the model reports.\n",
    "output, intermediate_outputs = model_with_intermediate(features, return_intermediate=True)\n",
    "per_layer_shape = (len(intermediate_outputs), num_epochs)\n",
    "baseline_shape = (1, num_epochs)\n",
    "\n",
    "# Separability metrics: one row per reported layer, one column per epoch.\n",
    "LS_1_squence, LS_2_squence, J_w_squence, LDA_squence = (np.zeros(per_layer_shape) for _ in range(4))\n",
    "\n",
    "# Single-row baseline metrics. NOTE(review): no cell in this notebook writes\n",
    "# to these, so they are saved and plotted as zeros.\n",
    "LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base = (np.zeros(baseline_shape) for _ in range(4))\n",
    "\n",
    "# Per-epoch training curves.\n",
    "train_loss_squence, train_accuracy_squence, test_loss_squence, test_accuracy_squence = (np.zeros((num_epochs,)) for _ in range(4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2343048d",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5341e3b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:14:39.231441Z",
     "start_time": "2024-11-20T06:10:29.906987Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Loop-invariant values, computed once.\n",
    "y_train = tf.gather(labels, idx_train)\n",
    "y_test = tf.gather(labels, idx_test)\n",
    "partial_labels_col = partial_labels.reshape(-1, 1)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # One gradient step on the training nodes (full-graph forward pass).\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model_with_intermediate(features)\n",
    "        loss = loss_fn(y_train, tf.gather(logits, idx_train))\n",
    "    grads = tape.gradient(loss, model_with_intermediate.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(grads, model_with_intermediate.trainable_variables))\n",
    "\n",
    "    # Training metrics use the logits of the step above (pre-update weights).\n",
    "    train_pred = tf.argmax(tf.gather(logits, idx_train), axis=1)\n",
    "    train_acc = accuracy_score(y_train.numpy(), train_pred.numpy())\n",
    "    train_loss_squence[epoch] = loss.numpy()\n",
    "    train_accuracy_squence[epoch] = train_acc\n",
    "\n",
    "    # A single post-update forward pass serves both the test metrics and the\n",
    "    # per-layer separability probes.\n",
    "    output, intermediate_outputs = model_with_intermediate(features, return_intermediate=True)\n",
    "    test_logits = tf.gather(output, idx_test)\n",
    "    test_loss = loss_fn(y_test, test_logits)\n",
    "    test_acc = accuracy_score(y_test.numpy(), tf.argmax(test_logits, axis=1).numpy())\n",
    "    test_loss_squence[epoch] = test_loss.numpy()\n",
    "    test_accuracy_squence[epoch] = test_acc\n",
    "    print(f\"Epoch {epoch}, Loss: {loss.numpy():.4f}, Train Accuracy: {train_acc:.4f}, \"\n",
    "          f\"Test Loss: {test_loss.numpy():.4f}, Test Accuracy: {test_acc:.4f}\")\n",
    "\n",
    "    # Separability metrics of each layer's output on the fixed probe nodes;\n",
    "    # W presumably comes from the Linear_Separability_numpy_mul star import.\n",
    "    for layer_i, intermediate_output in enumerate(intermediate_outputs.values()):\n",
    "        partial_output = tf.gather(intermediate_output, partial_indices)\n",
    "        (LS_1_squence[layer_i, epoch], LS_2_squence[layer_i, epoch],\n",
    "         J_w_squence[layer_i, epoch], LDA_squence[layer_i, epoch]) = W(partial_output, partial_labels_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b8faeca",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:14:39.251114Z",
     "start_time": "2024-11-20T06:14:39.239025Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# Shape check of each layer's output restricted to the probe nodes.\n",
    "output, intermediate_outputs = model_with_intermediate(features, return_intermediate=True)\n",
    "\n",
    "for name in intermediate_outputs:\n",
    "    intermediate_output = intermediate_outputs[name]\n",
    "    partial_output = tf.gather(intermediate_output, partial_indices)\n",
    "    print(f\"{name} output: {partial_output.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5bf2ae0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:14:41.425472Z",
     "start_time": "2024-11-20T06:14:39.256844Z"
    }
   },
   "outputs": [],
   "source": [
    "x_plot = np.arange(num_epochs)  # epoch indices 0 .. num_epochs - 1\n",
    "# NOTE(review): 'gcn_layer_1' differs from the model's intermediate key\n",
    "# 'gcn_layer'; presumably a deliberate display label - confirm.\n",
    "layer_name_list = ['gcn_layer_1', 'output_layer']\n",
    "\n",
    "separability_series = (LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "                       LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base)\n",
    "Separability_figure = plot_Separability_figure(layer_name_list, x_plot, *separability_series)\n",
    "\n",
    "training_curves = (train_loss_squence, train_accuracy_squence, test_loss_squence, test_accuracy_squence)\n",
    "net_figure = plot_net_figure(layer_name_list, x_plot, *training_curves)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a9c5997",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:14:41.432722Z",
     "start_time": "2024-11-20T06:14:41.428210Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bundle this run's layer labels and metric arrays for pickling in the next cell.\n",
    "info = dict(\n",
    "    layer_name_list=layer_name_list,\n",
    "    x_plot=x_plot,\n",
    "    LS_1_squence=LS_1_squence,\n",
    "    LS_2_squence=LS_2_squence,\n",
    "    J_w_squence=J_w_squence,\n",
    "    LDA_squence=LDA_squence,\n",
    "    train_loss_squence=train_loss_squence,\n",
    "    train_accuracy_squence=train_accuracy_squence,\n",
    "    test_loss_squence=test_loss_squence,\n",
    "    test_accuracy_squence=test_accuracy_squence,\n",
    "    LS_1_squence_base=LS_1_squence_base,\n",
    "    LS_2_squence_base=LS_2_squence_base,\n",
    "    J_w_squence_base=J_w_squence_base,\n",
    "    LDA_squence_base=LDA_squence_base,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "210ce64e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T06:14:41.439885Z",
     "start_time": "2024-11-20T06:14:41.435798Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "def save_info(info, file_path):\n",
    "    \"\"\"Pickle ``info`` to ``file_path``, creating missing parent directories.\n",
    "\n",
    "    Only load the resulting file from a trusted location: unpickling can\n",
    "    execute arbitrary code.\n",
    "    \"\"\"\n",
    "    parent_dir = os.path.dirname(file_path)\n",
    "    if parent_dir:\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "    with open(file_path, 'wb') as file:\n",
    "        pickle.dump(info, file)\n",
    "\n",
    "save_info(info, '../saved_data/hidden_layers_1.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fca37e72",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ba8a7fb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a50c88c7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-lsm",
   "language": "python",
   "name": "tensorflow-lsm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
