{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f807d8b3",
   "metadata": {},
   "source": [
    "# Graph Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e30387b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "import math\n",
    "import time\n",
    "import datetime\n",
    "import os\n",
    "import pickle as pkl\n",
    "import seaborn as sns\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, '../')\n",
    "import utils\n",
    "from ahk import AHK_graphon\n",
    "\n",
    "from matplotlib import pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a3f1910",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading training data\n",
    "\n",
    "def load(name,trainfilepath):\n",
    "    assert name in [\"ego\",\"comm\"]\n",
    "\n",
    "    A_train, A_val, A_test = np.load(trainfilepath,allow_pickle=True)\n",
    "\n",
    "    train, val, test = [],[],[]\n",
    "\n",
    "    for A in A_train:\n",
    "        train.append(nx.from_numpy_array(A))\n",
    "    for A in A_val:\n",
    "        val.append(nx.from_numpy_array(A))\n",
    "    for A in A_test:\n",
    "        test.append(nx.from_numpy_array(A))\n",
    "        \n",
    "    return train,val,test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4e0278",
   "metadata": {},
   "outputs": [],
   "source": [
    "name=\"ego\"\n",
    "\n",
    "datadir=\"../dataset/\"+name+\"/\"\n",
    "\n",
    "filepath=datadir+name+\"_train_val_test.npy\"\n",
    "\n",
    "train,val,test=load(name,filepath)\n",
    "\n",
    "print(\"{} train {} val {} test graphs\".format(len(train),len(val),len(test)))\n",
    "sizes=np.array(list(G.number_of_nodes() for G in train))\n",
    "print(\"Max train size: \", np.max(sizes), \"Min train size: \", np.min(sizes))\n",
    "\n",
    "traindata=utils.batch_nx_to_world(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5324ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning settings:\n",
    "\n",
    "# Initialize model\n",
    "binbounds=utils.uni_bins(1)\n",
    "learnmodel=AHK_graphon(traindata[0].sig,binbounds)\n",
    "\n",
    "settings={}\n",
    "\n",
    "settings['num_pi_b']=50\n",
    "settings['batchsize']=5\n",
    "settings['soft']=0.001\n",
    "settings['numepochs']=50\n",
    "settings['early_stop']=3 #Number of epochs with no log-likelihood improvement required for early stopping\n",
    "settings['bingain']=0.01 #the factor by which a bin refinement has improved log-likelihood in order\n",
    "                         #to continue bin refinements\n",
    "settings['learn_bins']=False\n",
    "settings['with_trace']=False\n",
    "settings['randombatches']=False\n",
    "settings['adaptbatchsize']=False\n",
    "settings['ubias']=0.0\n",
    "settings['savepath']='../Experiments/DAG/'\n",
    "\n",
    "#Adam params:\n",
    "settings['ad_alpha']=0.01\n",
    "settings['ad_beta1']=0.9\n",
    "settings['ad_beta2']=0.999\n",
    "settings['ad_epsilon']=1e-8\n",
    "\n",
    "settings['seed']=0\n",
    "settings['method']=\"adam\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16847b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# learn:\n",
    "rng=np.random.default_rng(seed=settings['seed'])\n",
    "learnmodel.rand_init(rng)\n",
    "best,loglik,_=learnmodel.learn(settings,traindata,rng,exact_gradients=False,info_each_epoch=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee01d8d",
   "metadata": {},
   "source": [
    "# Generate graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24ac43cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify a range of target sample sizes,\n",
    "# baseline: the empirical size distribution in the test set:\n",
    "sizes=np.array(list(G.number_of_nodes() for G in test))\n",
    "\n",
    "# scale sizes:\n",
    "scalefactor=1.5\n",
    "sizes=scalefactor*sizes\n",
    "\n",
    "print(\"Target sizes: avg:\", np.average(sizes), \"Min: \", np.min(sizes), \"Max: \", np.max(sizes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8661f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate and plot:\n",
    "numsamples=10\n",
    "for n in range(numsamples):\n",
    "    w=best.sample_world(int(rng.choice(sizes)),rng)\n",
    "    w_nx=w.to_nx()\n",
    "    nx.draw_networkx(w_nx)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b5facf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
