{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b81091fa-7a81-4908-89cc-e47999a1a280",
   "metadata": {},
   "source": [
    "# Plot AMA Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba52b74-5eb1-404a-bf1f-f7a12eb58d91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import t\n",
    "import pickle\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "989792b6-4dcf-486c-8135-9dccaccc3af1",
   "metadata": {},
   "source": [
    "## NtD Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b1a1da0-9079-40d7-a05e-6d8ed567cec0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the pickled NtD results; the cells below index this object once per trial.\n",
    "with open('NtD_results.pkl', mode='rb') as pkl_file:\n",
    "    results = pickle.load(pkl_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69034be7-ce5b-4c32-a191-4fe90801ef05",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stack the per-trial curves into (number of trials x number of points) arrays.\n",
    "# The widths come from the data itself instead of assuming that exactly half of\n",
    "# the rows carry an OOD error; np.stack raises if the trials disagree in length.\n",
    "OODerrors = np.stack([\n",
    "    trial.dropna()['Relative OOD Error'].to_numpy().astype(float)\n",
    "    for trial in results\n",
    "])  # Number of trials x Number of AMA iterations\n",
    "AMAloss = np.stack([\n",
    "    trial['AMA Loss'].to_numpy().astype(float)\n",
    "    for trial in results\n",
    "])\n",
    "\n",
    "# Both x-axes are read from the first trial.\n",
    "# NOTE(review): assumes every trial shares the same iteration grid -- confirm.\n",
    "results_dropna = results[0].dropna()\n",
    "OODiterations = results_dropna['Iteration'].to_numpy()+.5  # OOD points are drawn half a step after the iteration\n",
    "AMAiterations = results[0]['Iteration'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69a2d33-7fc1-4407-9bd7-23418506a8f5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean OOD error across trials with a 95% Student-t confidence half-width per point\n",
    "n_runs = OODerrors.shape[0]\n",
    "avgOODerrors = OODerrors.mean(axis=0)\n",
    "stdOODerrors = OODerrors.std(axis=0, ddof=1)   # sample standard deviation\n",
    "se = stdOODerrors / np.sqrt(n_runs)            # standard error of the mean\n",
    "tcrit = t.ppf(0.975, df=n_runs-1)              # two-sided 95% critical value\n",
    "ci = tcrit * se\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 4))\n",
    "ax.errorbar(\n",
    "    OODiterations, avgOODerrors, yerr=ci,\n",
    "    fmt='o--',        # dashed line through circle markers\n",
    "    color='red',      # line and markers\n",
    "    ecolor='black',   # error bars\n",
    "    capsize=5,        # length of the error-bar caps\n",
    ")\n",
    "\n",
    "ax.grid(True, which='both')\n",
    "ax.set_xlim(.5, 10.5)\n",
    "ax.set_xticks(np.arange(0, 11, 1))\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Average Relative OOD Error')\n",
    "ax.set_yscale('log')\n",
    "fig.tight_layout()\n",
    "fig.savefig('NtD_OODresult.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6dcd5f-ff74-46e1-9bae-b63f4a0b61f2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Express each trial's loss relative to its own first recorded value.\n",
    "# The ratio goes into a new array: the raw AMAloss values stay intact, and the\n",
    "# divisor is never a view into the array that is being overwritten.\n",
    "n_runs = AMAloss.shape[0]\n",
    "relAMAloss = AMAloss / AMAloss[:, :1]\n",
    "avgAMAloss = np.mean(relAMAloss, axis=0)\n",
    "\n",
    "# Spread across trials (95% Student-t half-width). Computed for reference only;\n",
    "# the figure below shows just the mean curve.\n",
    "stdAMAloss = np.std(relAMAloss, axis=0, ddof=1)\n",
    "se = stdAMAloss / np.sqrt(n_runs)\n",
    "tcrit = t.ppf(0.975, df=n_runs-1)\n",
    "ci = tcrit * se\n",
    "\n",
    "plt.figure(figsize=(7, 4))\n",
    "plt.semilogy(AMAiterations, avgAMAloss,'o--',color='red')\n",
    "\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Average Relative AMA Loss')\n",
    "plt.grid(True,which='both')\n",
    "plt.xlim(0, 10.5)\n",
    "plt.xticks(np.arange(0, 11, 1))\n",
    "plt.tight_layout()\n",
    "plt.savefig('NtD_AMAresult.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88447326-6b4e-46f1-961d-62e010a22c24",
   "metadata": {},
   "source": [
    "## Darcy Flow Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984f2724-4ba0-488d-8479-829410519716",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code stored in the file -- only open result files you trust.\n",
    "with open('DarcyFlow_results.pkl', 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "def detach_df(df):\n",
    "    \"\"\"Return a new DataFrame in which every torch.Tensor cell is a NumPy array on the CPU.\n",
    "\n",
    "    Cells that are not tensors are passed through unchanged.\n",
    "    \"\"\"\n",
    "    def _to_numpy(x):\n",
    "        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x\n",
    "    # DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map.\n",
    "    # The method is looked up on the class so a column named 'map' cannot shadow it.\n",
    "    frame_cls = type(df)\n",
    "    elementwise = getattr(frame_cls, 'map', None) or frame_cls.applymap\n",
    "    return elementwise(df, _to_numpy)\n",
    "\n",
    "# Change to cpu\n",
    "results = [detach_df(df) for df in results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56be59bd-9252-49bb-9a29-6c1eaec8430d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stack the per-trial curves into (number of trials x number of points) arrays.\n",
    "# The widths come from the data itself instead of assuming that exactly half of\n",
    "# the rows carry an OOD error; np.stack raises if the trials disagree in length.\n",
    "OODerrors = np.stack([\n",
    "    trial.dropna()['Relative OOD Error'].to_numpy().astype(float)\n",
    "    for trial in results\n",
    "])  # Number of trials x Number of AMA iterations\n",
    "AMAloss = np.stack([\n",
    "    trial['AMA Loss'].to_numpy().astype(float)\n",
    "    for trial in results\n",
    "])\n",
    "\n",
    "# Both x-axes are read from the first trial.\n",
    "# NOTE(review): assumes every trial shares the same iteration grid -- confirm.\n",
    "results_dropna = results[0].dropna()\n",
    "OODiterations = results_dropna['Iteration'].to_numpy()+.5  # OOD points are drawn half a step after the iteration\n",
    "AMAiterations = results[0]['Iteration'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a3d7cce-a336-4f2e-8eb5-9c72468413fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean OOD error across trials with a 95% Student-t confidence half-width per point\n",
    "n_runs = OODerrors.shape[0]\n",
    "avgOODerrors = OODerrors.mean(axis=0)\n",
    "stdOODerrors = OODerrors.std(axis=0, ddof=1)   # sample standard deviation\n",
    "se = stdOODerrors / np.sqrt(n_runs)            # standard error of the mean\n",
    "tcrit = t.ppf(0.975, df=n_runs-1)              # two-sided 95% critical value\n",
    "ci = tcrit * se\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 4))\n",
    "ax.errorbar(\n",
    "    OODiterations, avgOODerrors, yerr=ci,\n",
    "    fmt='o--',        # dashed line through circle markers\n",
    "    color='red',      # line and markers\n",
    "    ecolor='black',   # error bars\n",
    "    capsize=5,        # length of the error-bar caps\n",
    ")\n",
    "\n",
    "ax.grid(True, which='both')\n",
    "ax.set_xlim(.5, 10.5)\n",
    "ax.set_xticks(np.arange(0, 11, 1))\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Average Relative OOD Error')\n",
    "ax.set_yscale('log')\n",
    "fig.tight_layout()\n",
    "fig.savefig('Darcy_OODresult.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66154afd-a2e2-47fe-b89d-0db62552b436",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Express each trial's loss relative to its own first recorded value.\n",
    "# The ratio goes into a new array: the raw AMAloss values stay intact, and the\n",
    "# divisor is never a view into the array that is being overwritten.\n",
    "n_runs = AMAloss.shape[0]\n",
    "relAMAloss = AMAloss / AMAloss[:, :1]\n",
    "avgAMAloss = np.mean(relAMAloss, axis=0)\n",
    "\n",
    "# Spread across trials (95% Student-t half-width). Computed for reference only;\n",
    "# the figure below shows just the mean curve.\n",
    "stdAMAloss = np.std(relAMAloss, axis=0, ddof=1)\n",
    "se = stdAMAloss / np.sqrt(n_runs)\n",
    "tcrit = t.ppf(0.975, df=n_runs-1)\n",
    "ci = tcrit * se\n",
    "\n",
    "plt.figure(figsize=(7, 4))\n",
    "plt.semilogy(AMAiterations, avgAMAloss,'o--',color='red')\n",
    "\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Average Relative AMA Loss')\n",
    "plt.grid(True,which='both')\n",
    "plt.xlim(0, 10.5)\n",
    "plt.xticks(np.arange(0, 11, 1))\n",
    "plt.tight_layout()\n",
    "plt.savefig('Darcy_AMAresult.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b80f29-bd65-4906-908b-bd7a30fc4c5e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
