{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.168545Z",
     "start_time": "2024-11-10T14:30:28.549467Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import sys\n",
    "import torch\n",
    "sys.path.append('../')"
   ],
   "id": "95007b1ccc1a771f",
   "execution_count": 1,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Reference",
   "id": "7936179103d86f8d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.173072Z",
     "start_time": "2024-11-10T14:30:29.169357Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.reference import ReferenceSignal\n",
    "\n",
    "refsig = ReferenceSignal(name=\"r\")\n",
    "refsig.set_reference_signal(300)"
   ],
   "id": "db38facef45a843",
   "execution_count": 2,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Aggregator",
   "id": "b2496bda77ac2b7"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.176760Z",
     "start_time": "2024-11-10T14:30:29.173636Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.aggregator import Aggregator\n",
    "\n",
    "class AggregatorLogic1:\n",
    "    def __init__(self):\n",
    "        self.tensors = []\n",
    "        self.result = None\n",
    "        \n",
    "    def forward(self, values):\n",
    "        if type(values) is not list:\n",
    "            values = [values]\n",
    "        self.tensors = values\n",
    "        result = torch.sum(torch.stack([torch.sum(t.detach().clone().requires_grad_()) for t in self.tensors])).unsqueeze(dim=0)\n",
    "        self.result = result\n",
    "        return result\n",
    "\n",
    "class AggregatorLogic2:\n",
    "    def __init__(self):\n",
    "        self.tensors = []\n",
    "        \n",
    "    def forward(self, values):\n",
    "        if type(values) is not list:\n",
    "            values = [values]\n",
    "        self.tensors = values\n",
    "        result = torch.sum(torch.stack([torch.sum(t) for t in self.tensors])).unsqueeze(dim=0) \n",
    "        return result\n",
    "    \n",
    "agg1 = Aggregator(name=\"A1\", logic=AggregatorLogic1())\n",
    "agg2 = Aggregator(name=\"A2\", logic=AggregatorLogic2())"
   ],
   "id": "7c781b720e6b7099",
   "execution_count": 3,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Controller",
   "id": "53cd141aae7c6c4f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.180716Z",
     "start_time": "2024-11-10T14:30:29.177695Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.controller import Controller\n",
    "\n",
    "class PiControllerLogic(torch.nn.Module):\n",
    "    def __init__(self, kp, ki):\n",
    "        super().__init__()\n",
    "        self.tensors = {\n",
    "                        \"e\": torch.tensor([0.0], requires_grad=True),\n",
    "                        \"x\": torch.tensor([0.0], requires_grad=True)}\n",
    "        self.kp = torch.nn.Parameter(torch.tensor([kp], dtype=torch.float32))\n",
    "        self.ki = torch.nn.Parameter(torch.tensor([ki], dtype=torch.float32))\n",
    "        self.variables = [\"e\"]\n",
    "    \n",
    "    def forward(self, values):\n",
    "        self.tensors[\"e\"] = values[\"e\"]\n",
    "        result = ( (self.kp * self.tensors[\"e\"]) +\n",
    "                   (self.ki * (self.tensors[\"x\"] + self.tensors[\"e\"]) ) )\n",
    "        self.tensors[\"x\"] = self.tensors[\"x\"] + self.tensors[\"e\"]\n",
    "        return result\n",
    "\n",
    "cont = Controller(name=\"C\", logic=PiControllerLogic(kp=0.5, ki=0.5))\n",
    "# cont = Controller(name=\"C\", logic=PiControllerLogic(kp=0.5197, ki=0.8018))"
   ],
   "id": "4dd86f4ca900d720",
   "execution_count": 4,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Population",
   "id": "b90896499df0f11d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.184730Z",
     "start_time": "2024-11-10T14:30:29.181219Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.population import Population\n",
    "\n",
    "class AgentLogic:\n",
    "    def __init__(self, s_const1 = 1.0, s_const2=0.0):\n",
    "        self.tensors = {\"x\": torch.tensor([0.0], requires_grad=True),\n",
    "                        \"pi\": torch.tensor([0.0], requires_grad=True),\n",
    "                        \"s_const1\": torch.tensor([s_const1], requires_grad=True, dtype=torch.float),\n",
    "                        \"s_const2\": torch.tensor([s_const2], requires_grad=True, dtype=torch.float)}\n",
    "        self.variables = [\"p\"]\n",
    "    \n",
    "    def probability_function(self, x):\n",
    "        return self.tensors[\"s_const1\"] / (1 + torch.exp(-x+self.tensors[\"s_const2\"]))\n",
    "    \n",
    "    def f1(self):\n",
    "        return self.tensors[\"x\"] / 2\n",
    "    \n",
    "    def f2(self):\n",
    "        return self.tensors[\"x\"] * 2 + 1\n",
    "\n",
    "    def forward(self, values, number_of_agents):\n",
    "        self.tensors[\"p\"] = values[\"p\"]\n",
    "        #f1_ind = torch.bernoulli(torch.ones(1, 1)*0.5)\n",
    "        f1_part = torch.bernoulli(torch.ones(1, 1)*0.5) / number_of_agents\n",
    "        f2_part = 1 - f1_part\n",
    "        # print(f\"f1: {f1_part}, f2: {f2_part}, x: {self.tensors['x']}\\n\")\n",
    "        #self.tensors[\"x\"] = f1_part * self.f1(self.tensors[\"pi\"]) + f2_part * self.f2(self.tensors[\"pi\"])\n",
    "        # return f1_part *  self.tensors[\"s_const1\"] / (1 + torch.exp(-self.tensors[\"pi\"]+self.tensors[\"s_const2\"])) \\\n",
    "        # + f2_part *  self.tensors[\"s_const1\"] / (1.5 + torch.exp(-self.tensors[\"pi\"]+self.tensors[\"s_const2\"]))\n",
    "        # return f1_part *  torch.exp(-self.tensors[\"p\"]) + f2_part *  torch.exp(self.tensors[\"p\"])\n",
    "        # print(self.tensors[\"p\"])\n",
    "        #-self.tensors[\"p\"]))\n",
    "        # result = torch.sum(torch.bernoulli(torch.ones(1, number_of_agents)*torch.min(self.tensors[\"p\"], torch.ones(1,1)))) / number_of_agents\n",
    "        # result = self.tensors[\"p\"] + torch.bernoulli(torch.ones(1, 1)) / number_of_agents\n",
    "        result = f1_part * self.tensors[\"p\"] + (1 - f1_part) * (self.tensors[\"p\"] - 10.0)\n",
    "        return result\n",
    "        \n",
    "\n",
    "\n",
    "pop1 = Population(name=\"P1\",\n",
    "                  logic=AgentLogic(),\n",
    "                  number_of_agents=1000)\n",
    "\n",
    "pop2 = Population(name=\"P2\",\n",
    "                  logic=AgentLogic(s_const2=-1.0),\n",
    "                  number_of_agents=1000)"
   ],
   "id": "8c6facf8b142d7a1",
   "execution_count": 5,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Delay",
   "id": "b1fae83027206048"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.187257Z",
     "start_time": "2024-11-10T14:30:29.185336Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.delay import Delay\n",
    "delay = Delay(name=\"Z\", time=1)"
   ],
   "id": "35b8ed327fb97a46",
   "execution_count": 6,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Filter",
   "id": "8f65b6469cec849a"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.190325Z",
     "start_time": "2024-11-10T14:30:29.187796Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.nodes.filter import Filter\n",
    "\n",
    "class FilterLogic:\n",
    "    def __init__(self):\n",
    "        self.tensors = {\"S\": torch.tensor([0.0], requires_grad=True),\n",
    "                        \"K\": torch.tensor([2.0], requires_grad=True),}\n",
    "        self.variables = [\"S\"]\n",
    "        self.result = None\n",
    "    \n",
    "    def forward(self, values):\n",
    "        self.tensors[\"S\"] = values[\"S\"]\n",
    "        result = - self.tensors[\"S\"] / self.tensors[\"K\"]\n",
    "        self.result = result\n",
    "        # result = torch.tensor([result.item()], requires_grad=False)\n",
    "        return result\n",
    "\n",
    "fil = Filter(name=\"F\", logic=FilterLogic())"
   ],
   "id": "db495002a2734d06",
   "execution_count": 7,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Simulation",
   "id": "de2f0f9d2f24d0e6"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:29.825194Z",
     "start_time": "2024-11-10T14:30:29.190940Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from humancompatible.interconnect.simulators.simulation import Simulation\n",
    "sim = Simulation()\n",
    "\n",
    "sim.system.add_nodes([refsig, agg1, agg2, cont, pop1, pop2, delay, fil])\n",
    "sim.system.connect_nodes(refsig, agg1)\n",
    "sim.system.connect_nodes(agg1, cont)\n",
    "sim.system.connect_nodes(cont, pop1)\n",
    "sim.system.connect_nodes(cont, pop2)\n",
    "sim.system.connect_nodes(pop1, agg2)\n",
    "sim.system.connect_nodes(pop2, agg2)\n",
    "sim.system.connect_nodes(agg2, delay)\n",
    "sim.system.connect_nodes(delay, fil)\n",
    "sim.system.connect_nodes(fil, agg1)\n",
    "\n",
    "sim.system.set_start_node(refsig)\n",
    "sim.system.set_checkpoint_node(agg1)\n",
    "\n",
    "sim.plot.render_graph()"
   ],
   "id": "bdee5edf04695b85",
   "execution_count": 8,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.075492Z",
     "start_time": "2024-11-10T14:30:29.827570Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch.nn as nn\n",
    "model = sim.system.get_node(\"C\").logic\n",
    "sim.system.set_learning_model(model)\n",
    "sim.system.set_loss_function(nn.L1Loss())\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)\n",
    "sim.system.set_optimizer(optimizer)"
   ],
   "id": "539af94d9ee17d6a",
   "execution_count": 9,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.133732Z",
     "start_time": "2024-11-10T14:30:30.077173Z"
    }
   },
   "cell_type": "code",
   "source": "sim.plot.population_probabilities(xMin=-10, xMax=10)",
   "id": "9951462bbfb9f68d",
   "execution_count": 10,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.529843Z",
     "start_time": "2024-11-10T14:30:30.134281Z"
    }
   },
   "cell_type": "code",
   "source": [
    "sim.system.run(100, show_trace=False, show_loss=False)\n",
    "# sim.system.run(20)"
   ],
   "id": "c51ac3742a34fbce",
   "execution_count": 11,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.533405Z",
     "start_time": "2024-11-10T14:30:30.530565Z"
    }
   },
   "cell_type": "code",
   "source": [
    "cont.logic.eval()\n",
    "cont.logic.state_dict(), cont.logic.kp.grad, cont.logic.ki.grad"
   ],
   "id": "22c448148b7f00fa",
   "execution_count": 12,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.605849Z",
     "start_time": "2024-11-10T14:30:30.534208Z"
    }
   },
   "cell_type": "code",
   "source": "sim.plot.node_outputs(agg1)",
   "id": "f5730bd67b86f4e9",
   "execution_count": 13,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.814721Z",
     "start_time": "2024-11-10T14:30:30.608219Z"
    }
   },
   "cell_type": "code",
   "source": "sim.plot.runtimes()",
   "id": "73cb472fac4ac03c",
   "execution_count": 14,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.818493Z",
     "start_time": "2024-11-10T14:30:30.815280Z"
    }
   },
   "cell_type": "code",
   "source": "refsig.outputValue, agg1.outputValue, cont.outputValue, pop1.outputValue, agg2.outputValue, delay.outputValue, fil.outputValue",
   "id": "eb4123f56561ff9c",
   "execution_count": 15,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:30.822321Z",
     "start_time": "2024-11-10T14:30:30.819112Z"
    }
   },
   "cell_type": "code",
   "source": [
    "inputValue = (cont.logic.kp, cont.logic.ki, cont.logic.tensors[\"e\"])\n",
    "outputValue = fil.outputValue\n",
    "\n",
    "torch.autograd.grad(inputs=inputValue, outputs=outputValue)"
   ],
   "id": "5b33cf7422b806d2",
   "execution_count": 16,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:31.080025Z",
     "start_time": "2024-11-10T14:30:30.822852Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchviz import make_dot\n",
    "\n",
    "graph = make_dot(fil.logic.result, params=dict(refsig=refsig.outputValue, \n",
    "                                          result=agg1.logic.result,\n",
    "                                          kp=cont.logic.kp,\n",
    "                                          ki=cont.logic.ki,\n",
    "                                          x=cont.logic.tensors[\"x\"],\n",
    "                                          s1=pop1.logic.tensors[\"s_const1\"],\n",
    "                                          s2=pop1.logic.tensors[\"s_const2\"],\n",
    "                                          K=fil.logic.tensors[\"K\"],\n",
    "                                          output=fil.logic.result))\n",
    "graph.render(\"computation_graph\", view=False)"
   ],
   "id": "b1f10fcc3019697",
   "execution_count": 17,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-10T14:30:31.082241Z",
     "start_time": "2024-11-10T14:30:31.080936Z"
    }
   },
   "cell_type": "code",
   "source": "",
   "id": "98c529854599a559",
   "execution_count": 17,
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
