{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Count vs SP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "#### SP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "PATH = Path(\".../evaluate_models/CORRECTIONS/preds/nd_preds\")\n",
    "\n",
    "\n",
    "def read_stuff(full_path_filename, extra_cl=[]):\n",
    "    cols = [\"G0\", \"G1\", \"G2\", \"G3\", \"G4\", \"G5\", \"G6\", \"G7\", \"GraphName\", \"Type\"]\n",
    "    for i in extra_cl:\n",
    "        cols.append(i)\n",
    "    df = pd.read_csv(\n",
    "        full_path_filename,\n",
    "        delimiter=\",\",\n",
    "        usecols=cols,\n",
    "    )\n",
    "    df = df.set_index(\"Type\")\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPSILON = 0.0001\n",
    "names = []\n",
    "all_cnts = []\n",
    "\n",
    "tot = 0\n",
    "for filename in os.listdir(PATH):\n",
    "    if \".sha256\" in filename:\n",
    "        continue\n",
    "    print(filename)\n",
    "\n",
    "    counts = read_stuff(PATH / filename)\n",
    "    counts.replace(0, EPSILON, inplace=True)\n",
    "    \n",
    "    counts = counts.groupby(\"GraphName\").apply(\n",
    "        lambda x: np.abs(x.iloc[0, :] - x.iloc[1, :])\n",
    "    )\n",
    "    quarts = np.quantile(counts, [0.25, 0.5, 0.75], axis=0)\n",
    "    all_cnts.append(quarts)\n",
    "    names.append(filename)\n",
    "    \n",
    "all_cnts = np.array(all_cnts)\n",
    "print(all_cnts.shape)\n",
    "\n",
    "v,c = np.unique(np.argmin(all_cnts, axis=0), return_counts=True)\n",
    "to_use = v[np.argmax(c)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH2 = Path(\n",
    "    \".../evaluate_models/CORRECTIONS/preds/sreal_preds\"\n",
    ")\n",
    "\n",
    "\n",
    "for filename in os.listdir(PATH2):\n",
    "    if \".sha256\" in filename:\n",
    "        continue\n",
    "    print(filename)\n",
    "\n",
    "    true_sps = read_stuff(PATH2 / filename, [\"DatasetName\"])\n",
    "    true_sps.replace(0, EPSILON, inplace=True)\n",
    "    break\n",
    "\n",
    "true_sps = true_sps[true_sps.index.isin([\"True\"])]\n",
    "true_sps = true_sps.sort_values(by=\"GraphName\", ascending=True)\n",
    "true_sps = true_sps[\n",
    "    ~(\n",
    "        (true_sps[\"DatasetName\"] == \"srealCHEMINFORMATICS\")\n",
    "        | (true_sps[\"DatasetName\"] == \"srealANIMALSOCIAL\")\n",
    "        | (true_sps[\"DatasetName\"] == \"srealBRAIN\")\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_sps.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH3 = Path(\".../sreal/raw_scores\")\n",
    "\n",
    "\n",
    "def read_stuff2(full_path_filename, name, dname):\n",
    "    df = pd.read_csv(\n",
    "        full_path_filename,\n",
    "        skipinitialspace=True,\n",
    "        delimiter=\",\",\n",
    "        usecols=[\"occ_original\", \"avg_random\", \"stdev_random\"],\n",
    "    )\n",
    "    df[\"Graph\"] = [i for i in range(8)]\n",
    "    df[\"GraphName\"] = name\n",
    "    df[\"DatasetName\"] = dname\n",
    "    return df\n",
    "\n",
    "\n",
    "dfs = []\n",
    "\n",
    "for filename in sorted(os.listdir(PATH3)):\n",
    "    counts = read_stuff2(\n",
    "        PATH3 / filename,\n",
    "        filename.split(\"@\")[1].split(\".score\")[0],\n",
    "        filename.split(\"@\")[0],\n",
    "    )\n",
    "    dfs.append(counts)\n",
    "\n",
    "df = (\n",
    "    pd.concat(dfs)\n",
    "    .reset_index(drop=True)\n",
    "    .rename(\n",
    "        columns={\"occ_original\": \"y\", \"avg_random\": \"Ey\", \"stdev_random\": \"sigma_y\"}\n",
    "    )\n",
    ")\n",
    "df = df.sort_values(by=\"GraphName\", ascending=True)\n",
    "df = df[\n",
    "    ~(\n",
    "        (df[\"DatasetName\"] == \"srealCHEMINFORMATICS\")\n",
    "        | (df[\"DatasetName\"] == \"srealANIMALSOCIAL\")\n",
    "        | (df[\"DatasetName\"] == \"srealBRAIN\")\n",
    "    )\n",
    "]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [],
   "source": [
    "ms = [\n",
    "    92.15,\n",
    "    59.08,\n",
    "    140.78,\n",
    "    37.04,\n",
    "    913.2499999999999,\n",
    "    109.96000000000001,\n",
    "    79.50999999999999,\n",
    "    201.59999999999997,\n",
    "]\n",
    "ss = [\n",
    "    104.06789850861793,\n",
    "    96.61549358151622,\n",
    "    252.91898228484155,\n",
    "    89.70138460469828,\n",
    "    1780.922847149758,\n",
    "    252.750031454004,\n",
    "    186.38784804809566,\n",
    "    367.91743639028584,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_scores = []\n",
    "z_scores2 = []\n",
    "for i in range(0, df.shape[0], 8):\n",
    "    z_score_df = df.iloc[i:i+8]\n",
    "    graph = []\n",
    "    graph2 = []\n",
    "    for j in range(8):\n",
    "        vals = z_score_df[z_score_df[\"Graph\"] == j]\n",
    "        _t1 = vals[\"y\"] - vals[\"Ey\"] - ss[j]\n",
    "        _t2 = vals[\"y\"] - vals[\"Ey\"] + ss[j]\n",
    "        _t1 /= np.sqrt(vals[\"sigma_y\"]**2 + ss[j]**2)\n",
    "        _t2 /= np.sqrt(vals[\"sigma_y\"]**2 + ss[j]**2)\n",
    "        graph.append(_t1.item())\n",
    "        graph2.append(_t2.item())\n",
    "    z_scores.append(graph)\n",
    "    z_scores2.append(graph2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_scores = np.array(z_scores)\n",
    "for i in range(z_scores.shape[0]):\n",
    "    z_scores[i, :] = (z_scores[i, :]/np.sqrt(np.sum(z_scores[i, :]**2)))\n",
    "z_scores2 = np.array(z_scores2)\n",
    "for i in range(z_scores2.shape[0]):\n",
    "    z_scores2[i, :] = (z_scores2[i, :]/np.sqrt(np.sum(z_scores2[i, :]**2)))\n",
    "\n",
    "print(np.sum(z_scores**2, axis=1))\n",
    "print(np.sum(z_scores**2, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [],
   "source": [
    "diffs_count = np.abs(true_sps.iloc[:, 0:8].to_numpy() - z_scores)\n",
    "diffs_count2 = np.abs(true_sps.iloc[:, 0:8].to_numpy() - z_scores2)\n",
    "diff_final = np.minimum(diffs_count, diffs_count2)\n",
    "quarts_count = np.quantile(diff_final, [0.25, 0.5, 0.75], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(\n",
    "    np.vstack(\n",
    "        [\n",
    "            all_cnts[to_use, :, :].reshape(-1, 8),\n",
    "            quarts_count.reshape(-1, 8),\n",
    "        ]\n",
    "    ),\n",
    "    columns=[\"G\" + str(i) for i in range(8)],\n",
    ")\n",
    "df[\"Type\"] = [\"SP\", \"SP\", \"SP\", \"Count\", \"Count\", \"Count\"]\n",
    "df[\"QS\"] = [\"0.25\",\"0.5\",\"0.75\"]*2\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_melted = pd.melt(\n",
    "    df,\n",
    "    id_vars=[\"Type\", \"QS\"],\n",
    "    value_vars=[\"G0\", \"G1\", \"G2\", \"G3\", \"G4\", \"G5\", \"G6\", \"G7\"],\n",
    "    var_name=\"Graph\",\n",
    "    value_name=\"Value\",\n",
    ")\n",
    "df_pivot = df_melted.pivot_table(\n",
    "    index=[\"Graph\", \"Type\"], columns=\"QS\", values=\"Value\"\n",
    ").reset_index()\n",
    "df_pivot.columns = [\"Graph\", \"Type\", \"QS = 0.25\", \"QS = 0.5\", \"QS = 0.75\"]\n",
    "\n",
    "np.round(df_pivot, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def diff_val(x, switch=False):\n",
    "    if switch:\n",
    "        ours_vs_theirs = x.iloc[1, 1:] / x.iloc[0, 1:]\n",
    "    else:\n",
    "        ours_vs_theirs = x.iloc[1, 0:] / x.iloc[0, 0:]\n",
    "    res = np.select(\n",
    "        [ours_vs_theirs <= 1, ours_vs_theirs > 1],\n",
    "        [-(1 - ours_vs_theirs) * 100, (ours_vs_theirs - 1) * 100],\n",
    "    )\n",
    "    return pd.Series(res)\n",
    "\n",
    "\n",
    "df_pivot.drop(columns=[\"Type\"]).groupby(\"Graph\").apply(\n",
    "    lambda x: np.round(diff_val(x, switch=False), 3)\n",
    ").transpose()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Single vs Multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\n",
    "    \".../multi-vs-isolated.txt\", skipinitialspace=True\n",
    ")\n",
    "df.drop(columns=[\"type\"]).groupby(\"g\").apply(\n",
    "    lambda x: np.round(diff_val(x, switch=True), 3)\n",
    ").transpose()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_of_truth",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
