{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#%matplotlib notebook\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "import json\n",
    "from scipy.stats import wilcoxon\n",
    "from scipy.optimize import least_squares\n",
    "import sys\n",
    "from data_tools import load_data\n",
    "\n",
    "# Path of the pickle file that the next cell opens to read the estimated 'scores' entry.\n",
    "MODELFILE = 'test_pds.p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# load estimated score matrix\n",
    "# NOTE: unpickling can execute arbitrary code; only open model files from a trusted source.\n",
    "# The context manager guarantees the file handle is closed even if loading fails.\n",
    "with open(MODELFILE, 'rb') as model_file:\n",
    "    data = pickle.load(model_file)\n",
    "est_scores = data['scores']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Reference (true) scores are read from 'sim.mat', requesting only the 'scores' feature.\n",
    "tru_scores, labels = load_data('sim.mat', feature_list=['scores'])\n",
    "# Convert to an ndarray and drop any singleton axes.\n",
    "tru_scores = np.asarray(tru_scores).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "def scale_scores(caus_scores, tru_scores):\n",
    "    \"\"\"Reorder and rescale estimated score columns to best match the true scores.\n",
    "\n",
    "    Every ordering of the columns of ``caus_scores`` is tried. For each ordering,\n",
    "    one multiplicative scale per column is fitted with ``least_squares`` so that\n",
    "    the scaled columns are as close as possible to ``tru_scores``; the ordering\n",
    "    with the lowest cost is kept. The number of columns is read from the input,\n",
    "    and the search is exhaustive, so its cost grows factorially with that number.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    caus_scores : ndarray, shape (n_samples, n_factors)\n",
    "        Estimated scores, in arbitrary column order and scale.\n",
    "    tru_scores : ndarray\n",
    "        Reference scores; must broadcast against ``caus_scores``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray\n",
    "        ``caus_scores`` with its columns permuted and scaled by the best fit.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no column ordering yields a finite least-squares cost.\n",
    "    \"\"\"\n",
    "    n_factors = np.shape(caus_scores)[1]\n",
    "\n",
    "    best_err = np.inf\n",
    "    best_perm = None\n",
    "    caus_scale = None\n",
    "    for perm in permutations(range(n_factors)):\n",
    "        cols = list(perm)  # an explicit list keeps the column fancy-indexing unambiguous\n",
    "        these_scores = caus_scores[:, cols]\n",
    "\n",
    "        # Bind the current columns as a default so the residual never sees a later iteration's data.\n",
    "        def caus_resid(scaling, these_scores=these_scores):\n",
    "            return (these_scores*scaling - tru_scores).ravel()\n",
    "\n",
    "        res = least_squares(caus_resid, np.ones(n_factors))\n",
    "        if res.cost < best_err:\n",
    "            best_err = res.cost\n",
    "            caus_scale = res.x\n",
    "            best_perm = cols\n",
    "\n",
    "    if best_perm is None:\n",
    "        raise ValueError('no column permutation produced a finite least-squares fit')\n",
    "\n",
    "    caus_est = caus_scores[:, best_perm]*caus_scale\n",
    "    return caus_est\n",
    "    \n",
    "# Permute and rescale the estimated score columns so they line up with the true scores.\n",
    "scaled_est = scale_scores(est_scores, tru_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Factor 1 - R2=0.96197 low=0.92516 high=0.98085\n",
      "Factor 2 - R2=0.94986 low=0.90206 high=0.97464\n",
      "Factor 3 - R2=0.89731 low=0.80552 high=0.94705\n"
     ]
    }
   ],
   "source": [
    "# Spearman correlation, with a Fisher-z confidence interval, between estimated and true scores for each factor\n",
    "from scipy.stats import spearmanr, norm\n",
    "from math import tanh, atanh\n",
    "\n",
    "def spear_CI(r2, n, alpha=0.05):\n",
    "    \"\"\"Confidence interval for a Spearman correlation via the Fisher z-transform.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    r2 : float\n",
    "        The correlation coefficient itself (not its square; the parameter name\n",
    "        is kept unchanged for backward compatibility).\n",
    "    n : int\n",
    "        Number of paired observations; must be greater than 3.\n",
    "    alpha : float, optional\n",
    "        Two-sided significance level; the default 0.05 gives a 95% interval.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of float\n",
    "        ``(low, up)`` bounds of the interval on the correlation scale.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``n <= 3``, for which the standard error is undefined.\n",
    "    \"\"\"\n",
    "    if n <= 3:\n",
    "        raise ValueError('spear_CI requires n > 3 observations')\n",
    "    z = norm.ppf(1 - alpha/2)\n",
    "    # Bonett & Wright variance of the transformed coefficient is (1 + r**2/2)/(n - 3):\n",
    "    # the correlation has to be squared here.\n",
    "    st_err = np.sqrt((1 + r2**2/2)/(n - 3))\n",
    "    centre = atanh(r2)\n",
    "    low = tanh(centre - z*st_err)\n",
    "    up = tanh(centre + z*st_err)\n",
    "    return (low, up)\n",
    "    \n",
    "# Report one line per factor; the factor count comes from the data rather than a fixed 3.\n",
    "for N in range(scaled_est.shape[1]):\n",
    "    this_est = scaled_est[:,N]\n",
    "    s, _ = spearmanr(this_est, tru_scores[:,N])\n",
    "    low, high = spear_CI(s, len(this_est))\n",
    "\n",
    "    # NOTE: the value printed under the 'R2' label is the Spearman correlation returned by spearmanr.\n",
    "    print(f\"Factor {N+1:d} - R2={s:.5f} low={low:.5f} high={high:.5f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
