{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Project-local\n",
    "from path_constant import project_root\n",
    "\n",
    "# Directory that the later cells read the tvd/ and kl/ tensors and intv_prob.pickle from.\n",
    "path = f'{project_root}/XrayLLM/trained_models'"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Load the saved TVD and KL tensors for each distribution as plain Python lists.\n",
    "# Every metric file is checked on its own: guarding the KL load with the TVD\n",
    "# file's existence raised FileNotFoundError whenever only the KL file was missing.\n",
    "tvd_diff = {}\n",
    "kl_diff = {}\n",
    "for dist in ['joint']:\n",
    "    for metric, store in (('tvd', tvd_diff), ('kl', kl_diff)):\n",
    "        metric_file = path + '/' + metric + '/' + dist\n",
    "        if not os.path.exists(metric_file):\n",
    "            # Say what was skipped instead of leaving a silently missing key.\n",
    "            print(f'[skip] {metric_file} not found')\n",
    "            continue\n",
    "        # map_location='cpu' lets a tensor saved on a GPU load on a CPU-only machine;\n",
    "        # the device does not matter because the values become a list right away.\n",
    "        store[dist] = torch.load(metric_file, map_location='cpu').tolist()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "{'joint': [0.10440000146627426,\n  0.06239999830722809,\n  0.062199998646974564,\n  0.06949999928474426,\n  0.0885000005364418,\n  0.08089999854564667,\n  0.06260000169277191,\n  0.0544000007212162,\n  0.06019999831914902,\n  0.06729999929666519,\n  0.07970000058412552,\n  0.0625,\n  0.049300000071525574,\n  0.06589999794960022,\n  0.06889999657869339,\n  0.05009999871253967,\n  0.05990000069141388,\n  0.03779999911785126,\n  0.06729999929666519,\n  0.04430000111460686,\n  0.055399999022483826,\n  0.05510000139474869,\n  0.06310000270605087,\n  0.05869999900460243,\n  0.05480000004172325,\n  0.051899999380111694,\n  0.06440000236034393,\n  0.053700000047683716,\n  0.05380000174045563,\n  0.06069999933242798,\n  0.045499999076128006,\n  0.03920000046491623,\n  0.060100000351667404,\n  0.039000000804662704,\n  0.04050000011920929,\n  0.039000000804662704,\n  0.059700001031160355,\n  0.0430000014603138,\n  0.05620000138878822,\n  0.05139999836683273,\n  0.035999998450279236,\n  0.04580000042915344,\n  0.07980000227689743,\n  0.07490000128746033,\n  0.06019999831914902,\n  0.06669999659061432,\n  0.06459999829530716,\n  0.043800000101327896,\n  0.054999999701976776,\n  0.051100000739097595,\n  0.06539999693632126,\n  0.05559999868273735,\n  0.04010000079870224,\n  0.062199998646974564,\n  0.04729999974370003,\n  0.05770000070333481,\n  0.04919999837875366]}"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tvd_diff"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(200, 16)\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dict_mean(ara, top_k=5):\n",
    "    \"\"\"Return the entries with the largest column means, as ``{key: (mean, std)}``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ara : sequence of dict\n",
    "        Rows to aggregate. The keys of ``ara[0]`` define the columns; every\n",
    "        other row is read through those same keys.\n",
    "    top_k : int or None, default 5\n",
    "        How many entries to keep, largest mean first. ``None`` keeps all.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps key -> ``(mean, std)`` over the rows (``np.mean`` / ``np.std`` along\n",
    "        axis 0), ordered by descending mean.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``ara`` is empty.\n",
    "    KeyError\n",
    "        If a row lacks one of the keys of ``ara[0]``.\n",
    "    \"\"\"\n",
    "    if len(ara) == 0:\n",
    "        raise ValueError('dict_mean needs at least one row')\n",
    "\n",
    "    combs = list(ara[0].keys())\n",
    "    # Look every row up by the first row's keys so the columns stay aligned even\n",
    "    # when a row was built with a different insertion order; taking row.values()\n",
    "    # positionally would silently mix the columns up in that case.\n",
    "    values = np.array([[row[comb] for comb in combs] for row in ara])\n",
    "    # Shape print kept on purpose: the saved cell outputs show it.\n",
    "    print(values.shape)\n",
    "\n",
    "    mean = np.mean(values, axis=0)\n",
    "    std = np.std(values, axis=0)\n",
    "\n",
    "    # sorted() is stable, so keys with equal means keep the first row's order.\n",
    "    order = sorted(range(len(combs)), key=lambda idx: -mean[idx])\n",
    "    if top_k is not None:\n",
    "        order = order[:top_k]\n",
    "\n",
    "    return {combs[idx]: (mean[idx], std[idx]) for idx in order}\n",
    "\n",
    "\n",
    "# NOTE(security): pickle.load can execute arbitrary code while unpickling;\n",
    "# only open files that were produced by this project.\n",
    "with open(path + \"/intv_prob.pickle\", 'rb') as handle:\n",
    "    b = pickle.load(handle)\n",
    "\n",
    "result = dict_mean(b['Atelectasis'])\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(200, 16)\n"
     ]
    },
    {
     "data": {
      "text/plain": "{(1, 1, 0, 0): (0.45208349999999997, 0.017538707984056295),\n (1, 1, 0, 1): (0.25289100000000014, 0.013053571120578457),\n (1, 1, 1, 0): (0.21069549999999992, 0.01250053317862882),\n (1, 1, 1, 1): (0.08433000000000003, 0.010563739868058093),\n (0, 0, 0, 0): (9.999999999999978e-07, 2.117582368135751e-21)}"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result =dict_mean(b['Pleural Effusion'])\n",
    "result"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "data": {
      "text/plain": "dict_keys(['Pleural Effusion', 'Atelectasis'])"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.keys()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
