{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b52f73e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "decomposition (generic function with 1 method)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(\"decomposition.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a77c929c",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DelimitedFiles  # writedlm\n",
    "using SparseArrays\n",
    "using Statistics      # mean\n",
    "\n",
    "using LightGraphs\n",
    "using ProgressMeter\n",
    "using TickTock"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a123bb0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "get_term (generic function with 1 method)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "    get_node_types(input)\n",
    "\n",
    "Number of node-label classes for dataset `input`, i.e. the one-hot width used\n",
    "when encoding node labels. Throws an `ArgumentError` for any dataset name that\n",
    "is not listed here.\n",
    "\"\"\"\n",
    "function get_node_types(input::AbstractString)\n",
    "    input == \"MUTAG\"   && return 7\n",
    "    input == \"PC-3H\"   && return 46\n",
    "    input == \"SW-620H\" && return 66\n",
    "    throw(ArgumentError(\"unknown dataset: $input (expected MUTAG, PC-3H or SW-620H)\"))\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    get_term(str, num)\n",
    "\n",
    "Number of graphs to process: `num`, except that for MUTAG it is capped at 188.\n",
    "\"\"\"\n",
    "function get_term(str::AbstractString, num::Int)\n",
    "    return str == \"MUTAG\" ? min(num, 188) : num\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d300446",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"data/PC-3H/PC-3H_node_labels.txt\""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nam = \"PC-3H\"            # other supported datasets: \"MUTAG\", \"SW-620H\"\n",
    "intended_term = 400      # requested number of graphs; get_term may lower it\n",
    "\n",
    "node_types = get_node_types(nam)\n",
    "term = get_term(nam, intended_term)\n",
    "max_coar_level = 8       # coarsening levels 1:max_coar_level are generated\n",
    "\n",
    "# Every raw data file shares the prefix data/<nam>/<nam>\n",
    "name = \"data/$nam/$nam\"\n",
    "\n",
    "ad_file          = \"$(name)_A.txt\"\n",
    "graph_ind_file   = \"$(name)_graph_indicator.txt\"\n",
    "graph_label_file = \"$(name)_graph_labels.txt\"\n",
    "node_label_file  = \"$(name)_node_labels.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "12d6d3b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"coarsened_eq1\""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sub-folder of data/<nam>/ that receives the coarsened node and feature files\n",
    "coar_folder_name = \"coarsened_eq1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9d65e94a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): everything below is commented out and superseded by the\n",
    "# configuration cell above (get_node_types / get_term); kept for reference only.\n",
    "\n",
    "# nam = \"MUTAG\" #\"PC-3H\" #\"SW-620H\"\n",
    "# node_types = 7 #66 for \"SW-620H\", 46 for \"PC-3H\", 7 for \"MUTAG\"\n",
    "# term = 188 # length(graph_label_mat) or 188 for MUTAG, 1000 or something for others\n",
    "\n",
    "# name = \"data/MUTAG\"\n",
    "# name = \"data/SW-620H/SW-620H\"\n",
    "# name = \"data/\" * nam * \"/\" * nam\n",
    "\n",
    "# ad_file = name * \"_A.txt\"\n",
    "# graph_ind_file = name * \"_graph_indicator.txt\"\n",
    "# graph_label_file = name * \"_graph_labels.txt\"\n",
    "# node_label_file = name * \"_node_labels.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c2be4ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read a file holding one integer per line; blank lines (e.g. a trailing\n",
    "# newline at the end of the file) are skipped instead of crashing parse().\n",
    "function read_int_lines(path)\n",
    "    open(path) do file\n",
    "        [parse(Int, line) for line in eachline(file) if !isempty(strip(line))]\n",
    "    end\n",
    "end\n",
    "\n",
    "graph_ind_mat = read_int_lines(graph_ind_file)      # graph id of every node\n",
    "graph_label_mat = read_int_lines(graph_label_file)\n",
    "node_label_mat = read_int_lines(node_label_file)    # label of every node\n",
    "\n",
    "# Edge list: every non-blank line of the _A file is \"node1, node2\".\n",
    "ad_mat = open(ad_file) do file\n",
    "    [parse.(Int, split(line, ',')) for line in eachline(file) if !isempty(strip(line))]\n",
    "end\n",
    "\n",
    "@assert length(node_label_mat) == length(graph_ind_mat) \"node-label and graph-indicator files must list the same nodes\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6b95a2c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "onehot (generic function with 1 method)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "    onehot(vec, number)\n",
    "\n",
    "One-hot encode 0-based integer labels: returns a `length(vec)`-by-`number` `Int`\n",
    "matrix whose row `i` holds a single 1 in column `vec[i] + 1`.\n",
    "Throws an `ArgumentError` if a label lies outside `0:number-1`.\n",
    "\"\"\"\n",
    "function onehot(vec, number)\n",
    "    mat = zeros(Int, length(vec), number)\n",
    "    for (row, label) in enumerate(vec)\n",
    "        0 <= label < number ||\n",
    "            throw(ArgumentError(\"label $label is outside 0:$(number - 1)\"))\n",
    "        mat[row, label + 1] = 1\n",
    "    end\n",
    "    return mat\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf06b3fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:58:58\u001b[39m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8786825074593354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:57:28\u001b[39m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7638103015157351\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:51:55\u001b[39m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6948804891256706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:51:29\u001b[39m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6626227071878\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|████████████████████████████████████████▉|  ETA: 0:00:06\u001b[39m"
     ]
    }
   ],
   "source": [
    "alpha = 1e7            # passed to decomposition() as its fourth argument\n",
    "write_outputs = true   # set to false to run the coarsening without writing any files\n",
    "\n",
    "# Nodes and intra-graph edges depend only on the raw data, not on the coarsening\n",
    "# level L, so group them by graph once instead of rescanning all of ad_mat\n",
    "# (with linear `in` lookups) for every (L, graph) pair.\n",
    "graph_nodes = [Int[] for _ in 1:term]\n",
    "for (node, g) in enumerate(graph_ind_mat)\n",
    "    1 <= g <= term && push!(graph_nodes[g], node)\n",
    "end\n",
    "\n",
    "graph_edge_lists = [Tuple{Int,Int}[] for _ in 1:term]\n",
    "for (node1, node2) in ad_mat\n",
    "    g = graph_ind_mat[node1]\n",
    "    # keep an edge only when both endpoints belong to the same processed graph\n",
    "    if 1 <= g <= term && graph_ind_mat[node2] == g\n",
    "        push!(graph_edge_lists[g], (node1, node2))\n",
    "    end\n",
    "end\n",
    "\n",
    "for L in 1:max_coar_level\n",
    "    sparsednodes = Vector{Vector{Int}}[]  # per graph: [0-based column ids, 0-based row ids] of A_new's nonzeros\n",
    "    sparsedfeatures = Vector{Int}[]       # per graph: labels of the nodes appearing as columns of A_new\n",
    "    before = Int[]                        # node count of each input graph\n",
    "    after = Float64[]                     # nnz(A_new) / 2 per graph (Float64: nnz may be odd)\n",
    "\n",
    "    @showprogress for graph_number in 1:term\n",
    "        nodenumbers = graph_nodes[graph_number]\n",
    "        edge_list = graph_edge_lists[graph_number]\n",
    "\n",
    "        node_feature = [node_label_mat[node] for node in nodenumbers]\n",
    "        push!(before, length(node_feature))\n",
    "        node_feature_oh = onehot(node_feature, node_types)'\n",
    "\n",
    "        # Shift global node ids to local 1-based ids; this assumes each graph's\n",
    "        # nodes are numbered contiguously.\n",
    "        buffer = minimum(nodenumbers) - 1\n",
    "        n_local = maximum(nodenumbers) - buffer\n",
    "\n",
    "        rr = [node2 - buffer for (_, node2) in edge_list]\n",
    "        cc = [node1 - buffer for (node1, _) in edge_list]\n",
    "        vv = ones(Float64, length(rr))\n",
    "        # Add both orientations so A is symmetric; the explicit n_local x n_local size\n",
    "        # keeps A aligned with the node features even when trailing nodes are isolated\n",
    "        # (without it sparse() sizes A by the largest index that has an edge).\n",
    "        # NOTE(review): sparse() sums duplicates, so if the _A file already lists both\n",
    "        # directions of an edge its weight becomes 2.0 -- confirm this is intended.\n",
    "        A = sparse(vcat(rr, cc), vcat(cc, rr), vcat(vv, vv), n_local, n_local)\n",
    "        A_new = decomposition(A, L, copy(node_feature_oh), alpha)\n",
    "\n",
    "        i, j, _ = findnz(A_new)\n",
    "        push!(after, length(i) / 2)\n",
    "\n",
    "        new_node_feature = [node_label_mat[node] for node in unique(j .+ buffer)]\n",
    "        push!(sparsednodes, [j .- 1, i .- 1])\n",
    "        push!(sparsedfeatures, new_node_feature)\n",
    "    end\n",
    "\n",
    "    # NOTE(review): `before` counts nodes while `after` counts half the nonzeros of\n",
    "    # A_new (presumably edges), so this ratio mixes units -- confirm the intended metric.\n",
    "    println(mean(after ./ before))\n",
    "\n",
    "    if write_outputs\n",
    "        out_dir = \"data/\" * nam * \"/\" * coar_folder_name\n",
    "        mkpath(out_dir)  # the output folder may not exist on a fresh checkout\n",
    "        suffix = \"_coer_lv_\" * string(L) * \"_with_nf_first_\" * string(term) * \".txt\"\n",
    "\n",
    "        open(out_dir * \"/\" * nam * \"_node\" * suffix, \"w\") do f\n",
    "            for rows in sparsednodes\n",
    "                writedlm(f, rows, \",\")\n",
    "            end\n",
    "        end\n",
    "\n",
    "        writedlm(out_dir * \"/\" * nam * \"_feat\" * suffix, sparsedfeatures, '\\t')\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c814e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.8.2",
   "language": "julia",
   "name": "julia-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
