{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pickle as pkl\n",
    "from glob import glob\n",
    "import os\n",
    "import argparse\n",
    "import regex as re\n",
    "import pandas as pd\n",
    "import statistics\n",
    "import copy\n",
    "\n",
    "from MisInfoSpread import MisInfoSpread\n",
    "from MisInfoSpread import MisInfoSpreadState\n",
    "\n",
    "def flatten(state):\n",
    "    return [val * i for val, adj in zip(state.node_states, state.adjacency_matrix) for i in adj]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_inference( dataset_path, model_path, nodes, max_steps, st, count_inf, count_actions):\n",
    "    \n",
    "    # device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n",
    "    print(f\"Using device: {device}\")\n",
    "\n",
    "    misinfo = MisInfoSpread(num_nodes=nodes, max_time_steps=max_steps, \n",
    "                        trust_on_source=st, count_infected_nodes=count_inf, \n",
    "                        count_actions=count_actions)\n",
    "\n",
    "    model = misinfo.get_nnet_model().to(device)\n",
    "    model.load_state_dict(torch.load(model_path , map_location=torch.device(device)))\n",
    "\n",
    "    states = pkl.load(open(dataset_path, 'rb'))\n",
    "    for state in states:\n",
    "        state.node_states = [int(x) for x in state.node_states]\n",
    "        for i in range(len(state.adjacency_matrix)):\n",
    "            for j in range(len(state.adjacency_matrix[i])):\n",
    "                if state.adjacency_matrix[i][j] != 0:\n",
    "                    state.adjacency_matrix[i][j] = 1\n",
    "                else:\n",
    "                    state.adjacency_matrix[i][j] = 0\n",
    "\n",
    "    candidate_nodes = misinfo.find_neighbor_batch(states)\n",
    "\n",
    "    actions_dict = {i: [] for i in range(len(states))}\n",
    "\n",
    "    while any(candidate_node for candidate_node in candidate_nodes):\n",
    "        blockernode_np = []\n",
    "        count = 0\n",
    "        for state, cand_nodes in zip(states, candidate_nodes):\n",
    "            print(\"Processing states \", count, end='\\r')\n",
    "            if cand_nodes:\n",
    "                expectation_values = []\n",
    "                for cand_node in cand_nodes:\n",
    "                    temp_ns, _, _ = misinfo.step(copy.deepcopy(state), [cand_node])\n",
    "                    output_tensor = torch.FloatTensor(flatten(temp_ns)).view(1, -1).to(device)\n",
    "                    expected_infection = model(output_tensor).detach().cpu().numpy()\n",
    "                    expectation_values.append( (expected_infection, cand_node) )\n",
    "\n",
    "                # sort the expectation values based on the expected infection\n",
    "                expectation_values.sort(key=lambda x: x[0], reverse=True)\n",
    "\n",
    "                if len(expectation_values) < count_actions:\n",
    "                    blockernode_np.append([node for _, node in expectation_values])\n",
    "                    actions_dict[count].append([node for _, node in expectation_values])\n",
    "                else:\n",
    "                    blockernode_np.append([node for _, node in expectation_values[:count_actions]])\n",
    "                    actions_dict[count].append([node for _, node in expectation_values[:count_actions]])\n",
    "            else:\n",
    "                blockernode_np.append([])\n",
    "\n",
    "            count += 1\n",
    "        next_states, rewards, done = misinfo.step_batch(states, blockernode_np)\n",
    "        states = next_states\n",
    "        candidate_nodes = misinfo.find_neighbor_batch(states)\n",
    "        # print count of dones\n",
    "        print(\"Done: \", done.count(True), \" \"*20)\n",
    "        if all(done):\n",
    "            break\n",
    "    \n",
    "    inf_rate = []\n",
    "    for state in states:\n",
    "        inf_rate.append(state.node_states.count(-1.0)/len(state.node_states))\n",
    "\n",
    "    mean = round(statistics.mean(inf_rate), 4)\n",
    "    std_dev = round(statistics.stdev(inf_rate), 4)\n",
    "    print(f\"Mean: {mean}, Std Dev: {std_dev}\")\n",
    "    return mean, std_dev, actions_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n",
      "Done:  62                     \n",
      "Done:  586                     \n",
      "Done:  985                     \n",
      "Done:  1000                     \n",
      "Mean: 0.5004, Std Dev: 0.119\n"
     ]
    }
   ],
   "source": [
    "model = \"saved_models/target_model_1_1_mn10_ms50_st1.0.pt\"\n",
    "dataset_path = \"/Users/bittu/Desktop/InfoSpread-server/dataset/generate_dataset/deg_dataset/10/10_deg_3.pkl\"\n",
    "\n",
    "mean, std_dev, actions_dict = run_inference(dataset_path, model, 10, 50, 1.0, 1, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, item in actions_dict.items():\n",
    "    actions_dict[key] = [action_item[0] for action_item in item]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# save actions dict as json file with intend\n",
    "\n",
    "with open(\"actions_r3.json\", \"w\") as outfile: \n",
    "    json.dump(actions_dict, outfile, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl_infoSpread",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
