{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from opt import *\n",
    "import pickle\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#load the data \n",
    "with open('workspace_v1.pkl', 'rb') as file:\n",
    "    lw = pickle.load(file)\n",
    "\n",
    "# # Access the variables from the loaded workspace\n",
    "All_A_hat = lw['All_A_hat']\n",
    "All_A_hat_bin = lw['All_A_hat_bin']\n",
    "est_error = lw['est_error']\n",
    "est_fsc = lw['est_fsc']\n",
    "est_bias = lw['est_bias']\n",
    "est_bias_bin = lw['est_bias_bin']\n",
    "model_fit = lw['model_fit']\n",
    "true_bias = lw['true_bias']\n",
    "true_model_fit = lw['true_model_fit']\n",
    "dnames = lw['dnames']\n",
    "models = lw['models']\n",
    "mus2 = lw['mus2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GLASSO\n",
      "0.2661 & 0.3111 & 7.1250      & 0.1995 & 12.6102 & 0.0436      & 0.1121 & 0.9529 & 0.0223      & 0.6477 & 0.4068 & 51.0340\n",
      "FGLASSO_mu20\n",
      "0.1417 & 0.4825 & 6.8099      & 0.1895 & 11.8476 & 0.0421      & 0.1119 & 0.8177 & 0.0253      & 0.0505 & 0.1657 & 50.4863\n",
      "FGLASSO_dp\n",
      "0.1417 & 0.4824 & 6.8099      & 0.1896 & 10.3317 & 0.0422      & 0.1119 & 0.8177 & 0.0253      & 0.0505 & 0.1657 & 50.4863\n",
      "FGLASSO_nw\n",
      "0.1417 & 0.4824 & 6.8099      & 0.1895 & 11.8391 & 0.0421      & 0.1119 & 0.8177 & 0.0253      & 0.0505 & 0.1657 & 50.4863\n",
      "FGSR_dp\n",
      "0.4383 & 0.8942 & 7.0651      & 0.4188 & 0.5754 & 1.4283      & 1.0260 & 4.0568 & 0.1724      & 1.0606 & 0.3787 & 51.0690\n",
      "FGSR_nw\n",
      "0.4386 & 0.8924 & 7.0652      & 0.4068 & 6.6887 & 1.4221      & 1.0260 & 4.0568 & 0.1724      & 1.1149 & 0.3734 & 51.0663\n"
     ]
    }
   ],
   "source": [
    "def print_vector_elements(res):\n",
    "    # Ensure res is a flat array\n",
    "    res_flat = res.flatten()\n",
    "    \n",
    "    # Create a formatted output string with six spaces after each group of three elements\n",
    "    output = \"\"\n",
    "    for i in range(len(res_flat)):\n",
    "        if (i + 1) % 3 == 0 and i != len(res_flat) - 1:\n",
    "            output += f\"{res_flat[i]:.4f}      & \"\n",
    "        else:\n",
    "            output += f\"{res_flat[i]:.4f} & \"\n",
    "    \n",
    "    # Remove the trailing ampersand and space\n",
    "    output = output.rstrip(\" & \")\n",
    "    \n",
    "    # Print the formatted string\n",
    "    np.set_printoptions(suppress=True,precision=4)\n",
    "    print(output)\n",
    "\n",
    "#get results for each model \n",
    "mdl_idx = 0 \n",
    "mu2_idx = 0 # for mu = 1 and mu = 1e6\n",
    "nmodels=len(models)\n",
    "res = np.zeros((12,nmodels))\n",
    "for k, dn in enumerate(dnames):\n",
    "    idx = 3*k \n",
    "    for m, model in enumerate(models):\n",
    "        res[idx,m] = est_error[k,m,mu2_idx]\n",
    "        res[idx+1,m] = est_bias[k,m,mu2_idx,4]\n",
    "        res[idx+2,m] = model_fit[k,m,mu2_idx]     \n",
    "\n",
    "\n",
    "\n",
    "for m, model in enumerate(models):\n",
    "    print(model)\n",
    "    print_vector_elements(res[:,m])\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.2679 0.3454]\n",
      " [0.2379 0.2921]\n",
      " [0.1306 0.1388]\n",
      " [0.6494 0.6646]]\n",
      "[[ 0.331   0.4518]\n",
      " [12.6107 12.5812]\n",
      " [ 0.9425  0.9171]\n",
      " [ 0.4306  0.4072]]\n",
      "GL rw 150\n",
      "0.2679 & 0.3310 & 7.0282      & 0.2379 & 12.6107 & 1.8449      & 0.1306 & 0.9425 & 2.1914      & 0.6494 & 0.4306 & 51.0369\n",
      "GL rw 300\n",
      "0.3454 & 0.4518 & 7.0327      & 0.2921 & 12.5812 & 1.8302      & 0.1388 & 0.9171 & 2.1814      & 0.6646 & 0.4072 & 51.0378\n"
     ]
    }
   ],
   "source": [
    "# Compute the metrics for the rewiting baselines\n",
    "# Define rewiring \n",
    "rw1 = 150\n",
    "rw2 = 300\n",
    "\n",
    "nd = len(dnames)\n",
    "nmetrics = 5\n",
    "\n",
    "rw_er = np.zeros((nd, 2))\n",
    "mf = np.zeros((nd, 2))\n",
    "rw_bias = np.zeros((nd, 2, nmetrics)) # 4 bias metrics\n",
    "results = np.zeros((12,2))\n",
    "m = 0\n",
    "\n",
    "for k, dn in enumerate(dnames):\n",
    "    dAhat = All_A_hat[dnames[k]]\n",
    "    dAhat_bin = All_A_hat_bin[dnames[k]]\n",
    "    A_true, A_true_bin, C_est, C_est_norm, Z, z = load_datasets(dn)\n",
    "    N = A_true.shape[0]\n",
    "    #extract each estimated matrix and then rewire links \n",
    "    A = dAhat[k,0,0]\n",
    "    Arw1 = rewire_precmat(A, rewire=rw1, replace=False)\n",
    "    Arw2 = rewire_precmat(A, rewire=rw2, replace=False)\n",
    "\n",
    "    #bias,error,modelfit  [0,1,2] [3,4,5]\n",
    "    rw_er[k,0] = compute_frob_err(Arw1, A_true)\n",
    "    rw_bias[k,0,:] = compute_all_bias(Arw1,Z)\n",
    "    rw_er[k,1] = compute_frob_err(Arw2, A_true)\n",
    "    rw_bias[k,1,:] = compute_all_bias(Arw2,Z)\n",
    "    mf[k,0] = np.linalg.norm(Arw1 @ C_est_norm - np.eye(N), 'fro')/np.linalg.norm(C_est_norm, 'fro')\n",
    "    mf[k,1] = np.linalg.norm(Arw2 @ C_est_norm - np.eye(N), 'fro')/np.linalg.norm(C_est_norm, 'fro')\n",
    "\n",
    "    idx = 3*k        \n",
    "    results[idx,0] = rw_er[k,0]\n",
    "    results[idx+1,0] = rw_bias[k,0,4]\n",
    "    results[idx+2,0] = mf[k,0] \n",
    "\n",
    "    results[idx,1] = rw_er[k,1]\n",
    "    results[idx+1,1] = rw_bias[k,1,4]\n",
    "    results[idx+2,1] = mf[k,1] \n",
    "print(rw_er)\n",
    "print(rw_bias[:,:,4])\n",
    "\n",
    "print('GL rw', rw1)\n",
    "print_vector_elements(results[:,0])\n",
    "print('GL rw', rw2)\n",
    "print_vector_elements(results[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "schoolCs4 : Nodes= [126.] , Edges= [959.] , Groups= [2.] , Signals= 28561 , Sensitive atribute =  Gender\n",
      "coautorship130 : Nodes= [130.] , Edges= [525.] , Groups= [6.] , Signals= 1903 , Sensitive atribute =  Conference\n",
      "movielens : Nodes= [200.] , Edges= [665.] , Groups= [2.] , Signals= 943 , Sensitive atribute =  Year\n",
      "contact311Cs4 : Nodes= [311.] , Edges= [1009.] , Groups= [2.] , Signals= 47127 , Sensitive atribute =  Gender\n"
     ]
    }
   ],
   "source": [
    "n_nodes = np.zeros((4,1))\n",
    "n_edges = np.zeros((4,1))\n",
    "n_groups = np.zeros((4,1))\n",
    "n_signals = [28561,1903,943,47127]\n",
    "sens_Atr= ['Gender','Conference','Year','Gender']\n",
    "for k, dn in enumerate(dnames):\n",
    "    A_true, A_true_bin, C_est, C_est_norm, Z, z = load_datasets(dn)\n",
    "    n_nodes[k] = A_true.shape[0]\n",
    "    n_edges[k] = np.sum(A_true_bin)/2\n",
    "    n_groups[k] = Z.shape[0]\n",
    "    print(dn,': Nodes=',n_nodes[k], ', Edges=', n_edges[k], ', Groups=',n_groups[k], ', Signals=', n_signals[k], ', Sensitive atribute = ', sens_Atr[k])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
