{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66ccff92",
   "metadata": {},
   "source": [
    "## Beam Search Experiments\n",
    "\n",
    "This notebook explores preliminary experiments that use beam search together with multi-shot sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e72522f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../dataset\")\n",
    "sys.path.append(\"../../model\")\n",
    "sys.path.append(\"../../\")\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sampling import *\n",
    "\n",
    "from train_bgp import Model as GraphBgpModel\n",
    "from bgp_semantics import BgpSemantics\n",
    "from factbase import *\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import torch\n",
    "import argparse\n",
    "from snapshot import ModelSnapshot\n",
    "import os\n",
    "\n",
    "class SampleDescriptor:\n",
    "    def __init__(self, num_nodes, num_networks, program):\n",
    "        self.num_nodes = num_nodes\n",
    "        self.num_networks = num_networks\n",
    "        self.program = program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "55725a82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'router': router: Constant, 'network': network: Constant, 'external': external: Constant, 'route_reflector': route_reflector: Constant, 'ibgp': ibgp: Constant × Constant, 'ebgp': ebgp: Constant × Constant, 'bgp_route': bgp_route: Constant × Constant × int × int × int × int × int × int, 'connected': connected: Constant × Constant × int, 'fwd': fwd: Constant × Constant × Constant, 'reachable': reachable: Constant × Constant × Constant, 'trafficIsolation': trafficIsolation: Constant × Constant × Constant × Constant}\n",
      "using model at ../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt\n",
      "model iterations 4\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Setup: build the BGP fact base and feature registry, then load the pretrained model checkpoint.\n",
    "# The checkpoint and the model are placed on the CPU; the trailing comment preserves the CUDA auto-detect variant.\n",
    "device = torch.device(\"cpu\") # torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "s = BgpSemantics()\n",
    "\n",
    "# The predicate declarations are printed for reference and passed to the FactBase constructor.\n",
    "predicate_declarations = s.decls()\n",
    "print(predicate_declarations)\n",
    "prog = FactBase(predicate_declarations)\n",
    "# `feature` is called with a feature name elsewhere in this notebook; the result exposes `.idx`.\n",
    "feature = prog.feature_registry.feature\n",
    "\n",
    "# NOTE(review): this initial value is overwritten by the checkpoint tuple unpacked below.\n",
    "excluded_feature_indices = set([1])\n",
    "features = prog.feature_registry.get_all_features()\n",
    "\n",
    "# Placeholder; rebound to the loaded model below.\n",
    "model = None\n",
    "# Default value of mask_parameters' `without_static_routes` argument.\n",
    "NO_STATIC_ROUTES = True\n",
    "# Read by mask_parameters to decide whether the bgp_route arguments are masked.\n",
    "protocol = \"bgp\"\n",
    "\n",
    "# The checkpoint unpacks into (state dict, hidden dimension, number of edge types, excluded feature indices).\n",
    "# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.\n",
    "model_path = \"../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt\"\n",
    "state_dict, HIDDEN_DIM, NUM_EDGE_TYPES, excluded_feature_indices = torch.load(model_path, map_location=device)\n",
    "model = GraphBgpModel(features, HIDDEN_DIM, NUM_EDGE_TYPES, excluded_feature_indices).to(device)\n",
    "\n",
    "# presumably adapts an older GATConv parameter layout to the current one -- verify in the helper's definition\n",
    "state_dict = convert_old_gat_conv_state_dict(state_dict)\n",
    "model.load_state_dict(state_dict)\n",
    "# Attach the feature lookup to the model instance.\n",
    "model.feature = feature\n",
    "print(\"using model at\", model_path)\n",
    "\n",
    "print(\"model iterations\", model.num_iterations)\n",
    "\n",
    "def mask_parameters(x, decls, with_prob_static_route=True, without_static_routes=NO_STATIC_ROUTES):\n",
    "    \"\"\"Build a boolean mask selecting the entries of `x` that are to be predicted.\n",
    "\n",
    "    An entry is selected when it lies in one of the masked feature columns\n",
    "    (last dimension of `x`) and its current value is greater than -1.\n",
    "\n",
    "    Masked columns:\n",
    "      * predicate_connected_arg2 [weight] -- always\n",
    "      * predicate_bgp_route_arg2 [LP], arg3 [AS], arg5 [MED] -- only when the\n",
    "        notebook-level `protocol` equals 'bgp'\n",
    "\n",
    "    bgp_route argument layout, for reference:\n",
    "    gateway, network, LP, AS, OT, MED, IS_EBGP, SPEAKER_ID.\n",
    "\n",
    "    `decls`, `with_prob_static_route` and `without_static_routes` are accepted\n",
    "    but not read by this function.\n",
    "\n",
    "    Returns a bool tensor with the same shape as `x`.\n",
    "    \"\"\"\n",
    "    masked_names = [\"predicate_connected_arg2\"]\n",
    "    if protocol == \"bgp\":\n",
    "        # Arguments 2, 3 and 5 of bgp_route are the ones being masked.\n",
    "        masked_names.extend(\"predicate_bgp_route_arg\" + str(arg) for arg in (2, 3, 5))\n",
    "\n",
    "    selected = torch.zeros_like(x)\n",
    "    for name in masked_names:\n",
    "        col = feature(name).idx\n",
    "        selected[:, :, col] = x[:, :, col] > -1\n",
    "\n",
    "    return selected.bool()\n",
    "\n",
    "def sample_random_prediction(model, prediction_features, batched_data, mask):\n",
    "    \"\"\"Fill the masked entries of `batched_data.x` with uniformly random integers.\n",
    "\n",
    "    For every feature in `prediction_features` one value from [0, 32) is drawn\n",
    "    per index of dimension 0 and broadcast over dimension 1; the column of\n",
    "    predicate_bgp_route_arg6 is drawn from {0, 1} instead. Entries outside\n",
    "    `mask` keep their value from `batched_data.x`.\n",
    "\n",
    "    `model` is not used; it is kept so the signature stays unchanged.\n",
    "\n",
    "    Returns a tensor with the same shape as `batched_data.x`.\n",
    "    \"\"\"\n",
    "    x = batched_data.x\n",
    "    # Size the draws from the tensor that is actually being filled. The previous\n",
    "    # version read the notebook-global `data`, which is only defined in a later\n",
    "    # cell and need not have the same size as `batched_data`.\n",
    "    num_rows = x.size(0)\n",
    "\n",
    "    r = torch.zeros_like(x)\n",
    "    for f in prediction_features:\n",
    "        # Draw directly on x's device so no notebook-global `device` is needed.\n",
    "        r[:, :, f.idx] = torch.randint(0, 32, size=[num_rows, 1], device=x.device)\n",
    "    r[:, :, feature(\"predicate_bgp_route_arg6\").idx] = torch.randint(0, 2, size=[num_rows, 1], device=x.device)\n",
    "\n",
    "    return mask * r + mask.logical_not() * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "4c271c7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['bgp-n0.logic', 'bgp-n1.logic', 'bgp-n2.logic', 'bgp-n3.logic', 'bgp-n4.logic', 'bgp-n5.logic', 'bgp-n6.logic', 'bgp-n7.logic', 'bgp-n8.logic', 'bgp-n9.logic', 'bgp-n10.logic', 'bgp-n11.logic', 'bgp-n12.logic', 'bgp-n13.logic', 'bgp-n14.logic', 'bgp-n15.logic', 'bgp-n16.logic', 'bgp-n17.logic', 'bgp-n18.logic', 'bgp-n19.logic', 'bgp-n20.logic', 'bgp-n21.logic', 'bgp-n22.logic', 'bgp-n23.logic']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "def get_id(filename):\n",
    "    str_n = filename.split(\"-n\", 1)[1].split(\".\", 1)[0]\n",
    "    if \"-unsatsample\" in str_n:\n",
    "        str_n = str_n.split(\"-unsatsample\", 1)[0]\n",
    "    return int(str_n)\n",
    "\n",
    "# Collect the serialized '.logic' programs of the dataset, ordered by their numeric id.\n",
    "dataset = \"../consistency/dataset-ported/bgp-qlty-reqs-16/\"\n",
    "files = [name for name in os.listdir(dataset) if name.endswith(\".logic\")]\n",
    "files.sort(key=get_id)\n",
    "print(files)\n",
    "programs = [torch.load(os.path.join(dataset, name)) for name in files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "95c637be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef4bef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run beam search on the first program of the dataset and report the consistency\n",
    "# of the predicted configuration as checked by the BGP semantics.\n",
    "from model.beam import beam_search\n",
    "\n",
    "# NOTE(review): `num_samples` and `random` are not read anywhere in this cell.\n",
    "num_samples = 1\n",
    "random = True\n",
    "num_shots = 4\n",
    "sampling_mode = \"topk\"\n",
    "\n",
    "# Index reported as 'Sample' in the summary line; this cell evaluates a single sample.\n",
    "# (The summary previously referenced an undefined `j`.)\n",
    "sample_idx = 0\n",
    "\n",
    "descriptor = programs[0]\n",
    "\n",
    "# Raw graph data / dicts are wrapped so that `descriptor.program` is always available.\n",
    "if type(descriptor) is Data or type(descriptor) is dict:\n",
    "    descriptor = SampleDescriptor(0, 0, FactBase.from_data(descriptor))\n",
    "data, names = descriptor.program.to_torch_data(return_node_names=True)\n",
    "\n",
    "prediction_features = [\n",
    "    feature(\"predicate_connected_arg2\"),  # OSPF weights\n",
    "    # bgp_route: LP x AS x -OT x MED x -IS_EBGP x -SPEAKER_ID\n",
    "    feature(\"predicate_bgp_route_arg2\"),  # BGP LP\n",
    "    feature(\"predicate_bgp_route_arg3\"), # BGP AS\n",
    "    #feature(\"predicate_bgp_route_arg4\"), # BGP ORIGIN_TYPE\n",
    "    feature(\"predicate_bgp_route_arg5\"), # BGP MED\n",
    "    #feature(\"predicate_bgp_route_arg6\"), # BGP IS_EBGP\n",
    "    #feature(\"predicate_bgp_route_arg7\") # SPEAKER_ID\n",
    "]\n",
    "\n",
    "# Work on a clone with an extra dimension inserted at position 1 of x; edges are\n",
    "# passed through the bidirectional/reflexive helpers before the search.\n",
    "batched_data = data.clone().to(device)\n",
    "batched_data.x = batched_data.x.unsqueeze(1)\n",
    "batched_data.edge_index = reflexive(bidirectional(batched_data.edge_index), num_nodes=batched_data.x.size(0))\n",
    "batched_data.edge_type = reflexive_bidirectional_edge_type(batched_data.edge_type, batched_data.x.size(0))\n",
    "mask = mask_parameters(batched_data.x, predicate_declarations).to(device)\n",
    "\n",
    "best_consistency = 0\n",
    "\n",
    "# Start the timer right before the search; `tstart` was previously never defined,\n",
    "# so computing `timeelapsed` below raised a NameError.\n",
    "tstart = time.time()\n",
    "# Index 0 along dimension 1 of the beam search result is written back into `data`.\n",
    "data.x = beam_search(model, prediction_features, batched_data, mask, iterative=True, \n",
    "    number_of_shots=num_shots, inverted=False, mode=sampling_mode, beam_n=128, beam_k=8)[:,0]\n",
    "\n",
    "timeelapsed = time.time() - tstart\n",
    "predicted_program = FactBase.from_data(data, decls=predicate_declarations, names=names)\n",
    "consistency, summary = s.check(predicted_program, return_summary=True)\n",
    "best_consistency = max(consistency, best_consistency)\n",
    "\n",
    "def get_value(k):\n",
    "    \"\"\"Return summary[k], or 1.0 when the summary has no entry for `k`.\"\"\"\n",
    "    if k in summary.keys(): return summary[k]\n",
    "    else: return 1.0\n",
    "\n",
    "num_nodes = len(descriptor.program.constants(\"router\")) + len(descriptor.program.constants(\"route_reflector\"))\n",
    "num_networks = len(descriptor.program.constants(\"network\"))\n",
    "\n",
    "print(\"Consistency %0.2f (best %0.2f) (Nodes %d, Sample %d)\" % (consistency, best_consistency, num_nodes, sample_idx))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
