{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import junctiontree as jt\n",
    "from tests.util import assert_potentials_equal\n",
    "from junctiontree import computation as comp\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first six entries look like cliques and the last five like the separators\n",
    "# between them (each is the intersection of two of those cliques). Not referenced by\n",
    "# later cells; presumably kept for comparison with the tree built by\n",
    "# `jt.create_junction_tree` below - TODO confirm.\n",
    "node_list = [\n",
    "    [\"A\", \"D\", \"E\"],\n",
    "    [\"A\", \"B\", \"D\"],\n",
    "    [\"D\", \"E\", \"F\"],\n",
    "    [\"A\", \"C\", \"E\"],\n",
    "    [\"C\", \"E\", \"G\"],\n",
    "    [\"E\", \"G\", \"H\"],\n",
    "    [\"A\", \"D\"],\n",
    "    [\"D\", \"E\"],\n",
    "    [\"A\", \"E\"],\n",
    "    [\"C\", \"E\"],\n",
    "    [\"E\", \"G\"],\n",
    "]\n",
    "\n",
    "# Cardinality of every variable (all binary).\n",
    "var_sizes = {\n",
    "    \"A\": 2,\n",
    "    \"B\": 2,\n",
    "    \"C\": 2,\n",
    "    \"D\": 2,\n",
    "    \"E\": 2,\n",
    "    \"F\": 2,\n",
    "    \"G\": 2,\n",
    "    \"H\": 2,\n",
    "}\n",
    "\n",
    "# Factor scopes. values[i] is the table for factors[i], with one axis per variable\n",
    "# in the order listed (enforced by the checks at the end of this cell).\n",
    "factors = [\n",
    "    [\"A\"],\n",
    "    [\"A\", \"B\"],\n",
    "    [\"A\", \"C\"],\n",
    "    [\"B\", \"D\"],\n",
    "    [\"C\", \"E\"],\n",
    "    [\"C\", \"G\"],\n",
    "    [\"D\", \"E\", \"F\"],\n",
    "    [\"E\", \"G\", \"H\"],\n",
    "]\n",
    "\n",
    "values = [\n",
    "    np.array([0.5, 0.5]),\n",
    "    np.array([[0.6, 0.4],\n",
    "              [0.5, 0.5]]),\n",
    "    np.array([[0.8, 0.2],\n",
    "              [0.3, 0.7]]),\n",
    "    np.array([[0.5, 0.5],\n",
    "              [0.1, 0.9]]),\n",
    "    np.array([[0.4, 0.6],\n",
    "              [0.7, 0.3]]),\n",
    "    np.array([[0.9, 0.1],\n",
    "              [0.8, 0.2]]),\n",
    "    np.array([[[0.01, 0.99],\n",
    "               [0.99, 0.01]],\n",
    "              [[0.99, 0.01],\n",
    "               [0.99, 0.01]]]),\n",
    "    np.array([[[0.05, 0.95],\n",
    "               [0.05, 0.95]],\n",
    "              [[0.05, 0.95],\n",
    "               [0.95, 0.05]]]),\n",
    "]\n",
    "\n",
    "# Catch pairing mistakes early: one table per factor, each shaped by its scope and\n",
    "# normalized over its last axis (so each table reads as P(last var | other vars)).\n",
    "assert len(values) == len(factors), \"need exactly one table per factor\"\n",
    "for scope, table in zip(factors, values):\n",
    "    assert table.shape == tuple(var_sizes[var] for var in scope), scope\n",
    "    np.testing.assert_allclose(table.sum(axis=-1), 1.0, err_msg=str(scope))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the junction tree from the factor scopes and the variable cardinalities.\n",
    "_tree = jt.create_junction_tree(\n",
    "    factors,\n",
    "    var_sizes,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, (9, [4, (10, [5, (6, [0]), (8, [2, (7, [1])])])])]\n",
      "[['E', 'D'], ['E', 'G'], ['E', 'C'], ['B', 'C'], ['D', 'C']]\n"
     ]
    }
   ],
   "source": [
    "# `tree` is a nested structure of node indices. Judging from the stored outputs, 0-5\n",
    "# index `clique_tree.maxcliques` and 6-10 index `separators` offset by the clique\n",
    "# count - inferred, confirm against the junctiontree docs before relying on it.\n",
    "print(_tree.tree, _tree.separators, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FactorGraph(factors=[['A'], ['A', 'B'], ['A', 'C'], ['B', 'D'], ['C', 'E'], ['C', 'G'], ['D', 'E', 'F'], ['E', 'G', 'H']], sizes={'A': 2, 'B': 2, 'C': 2, 'D': 2, 'E': 2, 'F': 2, 'G': 2, 'H': 2})\n",
      "[3, 3, 3, 4, 5, 2, 0, 1]\n"
     ]
    }
   ],
   "source": [
    "# The factor graph the tree was built from, and what appears to be, for each factor\n",
    "# in order, the index of the max clique it was assigned to.\n",
    "print(_tree.clique_tree.factor_graph)\n",
    "print(_tree.clique_tree.factor_to_maxclique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "prop_values = _tree.propagate(values)\n",
    "\n",
    "# The inputs are a normalized prior times conditional tables, so the joint sums to 1\n",
    "# and so must every propagated table (one per factor).\n",
    "assert len(prop_values) == len(values)\n",
    "np.testing.assert_allclose([table.sum() for table in prop_values], 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([0.5, 0.5]),\n",
       " array([[0.3 , 0.2 ],\n",
       "        [0.25, 0.25]]),\n",
       " array([[0.4 , 0.1 ],\n",
       "        [0.15, 0.35]]),\n",
       " array([[0.275, 0.275],\n",
       "        [0.045, 0.405]]),\n",
       " array([[0.22 , 0.33 ],\n",
       "        [0.315, 0.135]]),\n",
       " array([[0.495, 0.055],\n",
       "        [0.36 , 0.09 ]]),\n",
       " array([[[0.001697, 0.168003],\n",
       "         [0.148797, 0.001503]],\n",
       " \n",
       "        [[0.361647, 0.003653],\n",
       "         [0.311553, 0.003147]]]),\n",
       " array([[[0.0225 , 0.4275 ],\n",
       "         [0.00425, 0.08075]],\n",
       " \n",
       "        [[0.02025, 0.38475],\n",
       "         [0.057  , 0.003  ]]])]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One table per factor, in the order of `factors`; e.g. the second is the joint over (A, B).\n",
    "prop_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "sizes = _tree.clique_tree.factor_graph.sizes\n",
    "# An all-ones table per separator, with one axis per separator variable\n",
    "# (presumably the initial separator potentials - confirm before relying on it).\n",
    "separator_values = [np.ones([sizes[name] for name in sep]) for sep in _tree.separators]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[1., 1.],\n",
       "        [1., 1.]]),\n",
       " array([[1., 1.],\n",
       "        [1., 1.]]),\n",
       " array([[1., 1.],\n",
       "        [1., 1.]]),\n",
       " array([[1., 1.],\n",
       "        [1., 1.]]),\n",
       " array([[1., 1.],\n",
       "        [1., 1.]])]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One all-ones table per separator, in the order of `_tree.separators`.\n",
    "separator_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variable sets of the max cliques. Their positions appear to be the clique indices\n",
    "# used in `_tree.tree` and `factor_to_maxclique` (inferred from the stored outputs).\n",
    "max_cliques = _tree.clique_tree.maxcliques"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['D', 'E', 'F'],\n",
       " ['E', 'G', 'H'],\n",
       " ['C', 'E', 'G'],\n",
       " ['A', 'B', 'C'],\n",
       " ['B', 'C', 'D'],\n",
       " ['C', 'D', 'E']]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the max cliques found while building the tree.\n",
    "max_cliques"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
