{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exponential decay rate: it selects the saved file below and parameterizes the\n",
    "# reference attention exp(-w * t_diff) used by the analysis cells.\n",
    "w = 0.003\n",
    "attn_score_pt_path = f\"w_{w}_num_layers_1_time_feat_dim_8_linear_attn_score.pt\"\n",
    "# NOTE: torch.load unpickles the file, so only load outputs you produced yourself.\n",
    "# map_location=\"cpu\" lets a file saved from a GPU run load on a CPU-only machine\n",
    "# and keeps the tensors on the host for the matplotlib plots further down.\n",
    "attn_score_dict = torch.load(\n",
    "    os.path.join(\"attn_analysis_output\", attn_score_pt_path), map_location=\"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_attn_score(attn_score, t_diff, w):\n",
    "    \"\"\"\n",
    "    Pair the learned attention of the final query with its exponential-decay reference.\n",
    "\n",
    "    Both outputs drop the last position and are renormalized along dim 1, so the\n",
    "    remaining entries of every row are divided by their own sum.\n",
    "\n",
    "    attn_score: [batch_size, n, n], only the last row (index -1 on dim 1) is used\n",
    "    t_diff: [batch_size, n], values fed to exp(-w * t_diff)\n",
    "    w: float, exponential decay factor\n",
    "\n",
    "    Returns (learned, reference), each covering the first n - 1 positions.\n",
    "    \"\"\"\n",
    "    # Reference weights: exp(-w * t_diff) over every position except the last one.\n",
    "    decay = torch.exp(-w * t_diff)\n",
    "    decay_prev = decay[:, :-1]\n",
    "    reference = decay_prev / decay_prev.sum(dim=1, keepdim=True)\n",
    "\n",
    "    # Learned weights: last row of the score matrix, same n - 1 positions.\n",
    "    last_row_prev = attn_score[:, -1, :-1]\n",
    "    learned = last_row_prev / last_row_prev.sum(dim=1, keepdim=True)\n",
    "\n",
    "    return learned, reference\n",
    "\n",
    "\n",
    "def get_attn_err(attn_score_dict, w):\n",
    "    \"\"\"\n",
    "    Mean squared error between reference and learned attention for each split.\n",
    "\n",
    "    attn_score_dict: mapping holding the \"*_attn_score\" and \"*_t_diff\" entries for\n",
    "        the train, val and test splits\n",
    "    w: float, exponential decay factor forwarded to get_attn_score\n",
    "\n",
    "    Returns (train_attn_err, val_attn_err, test_attn_err).\n",
    "    \"\"\"\n",
    "    split_keys = (\n",
    "        (\"train_attn_score\", \"train_t_diff\"),\n",
    "        (\"val_attn_score\", \"val_t_diff\"),\n",
    "        (\"test_attn_score\", \"test_t_diff\"),\n",
    "    )\n",
    "\n",
    "    errors = []\n",
    "    for score_key, t_diff_key in split_keys:\n",
    "        learned, reference = get_attn_score(\n",
    "            attn_score_dict[score_key], attn_score_dict[t_diff_key], w\n",
    "        )\n",
    "        errors.append(torch.mean((reference - learned) ** 2))\n",
    "\n",
    "    train_err, val_err, test_err = errors\n",
    "    return train_err, val_err, test_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_attn_panel(ax, learned, reference, title, num_runs):\n",
    "    \"\"\"Draw mean learned vs. reference attention with error bars on one axis.\"\"\"\n",
    "    for scores, label in ((learned, \"Learned\"), (reference, \"Truth\")):\n",
    "        # Detach and copy to host memory so matplotlib always receives plain arrays,\n",
    "        # even when the scores sit on a GPU or carry autograd history.\n",
    "        scores = scores.detach().cpu()\n",
    "        mean = scores.mean(dim=0).numpy()\n",
    "        # Half-width of the error bar: 1.96 * std / sqrt(num_runs).\n",
    "        half_width = (1.96 * scores.std(dim=0) / math.sqrt(num_runs)).numpy()\n",
    "        ax.errorbar(range(1, len(mean) + 1), mean, yerr=half_width, label=label)\n",
    "    ax.legend()\n",
    "    ax.set_ylabel(\"Attention Score\")\n",
    "    ax.set_xlabel(\"Event Index\")\n",
    "    ax.set_title(title)\n",
    "\n",
    "\n",
    "def viz_attn_score(attn_score_dict, w, num_runs=10):\n",
    "    \"\"\"\n",
    "    Plot learned vs. reference attention per event index for train, val and test.\n",
    "\n",
    "    One panel is drawn per split. Each panel shows the mean over dim 0 of the scores\n",
    "    returned by get_attn_score, with error bars of 1.96 * std / sqrt(num_runs).\n",
    "\n",
    "    attn_score_dict: mapping holding the \"*_attn_score\" and \"*_t_diff\" entries for\n",
    "        the train, val and test splits\n",
    "    w: float, exponential decay factor forwarded to get_attn_score\n",
    "    num_runs: positive sample count used in the error-bar denominator; defaults to 10,\n",
    "        the value that used to be hard-coded here.\n",
    "        NOTE(review): the std is taken over dim 0 of the scores, so this presumably\n",
    "        should match the number of rows along that axis - confirm.\n",
    "\n",
    "    Raises ValueError if num_runs is not positive. Returns None; the figure is shown\n",
    "    with plt.show().\n",
    "    \"\"\"\n",
    "    if num_runs <= 0:\n",
    "        raise ValueError(f\"num_runs must be positive, got {num_runs}\")\n",
    "\n",
    "    _, axes = plt.subplots(3, 1, figsize=(8, 16))\n",
    "    panels = ((\"train\", \"Train\"), (\"val\", \"Val\"), (\"test\", \"Test\"))\n",
    "\n",
    "    for ax, (split, title) in zip(axes, panels):\n",
    "        learned, reference = get_attn_score(\n",
    "            attn_score_dict[f\"{split}_attn_score\"], attn_score_dict[f\"{split}_t_diff\"], w\n",
    "        )\n",
    "        _plot_attn_panel(ax, learned, reference, title, num_runs)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean squared error between the exp(-w * t_diff) reference and the learned\n",
    "# attention, computed separately for the train, val and test splits.\n",
    "train_attn_err, val_attn_err, test_attn_err = get_attn_err(attn_score_dict, w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare tuple as the last expression so the notebook displays the three error values.\n",
    "train_attn_err, val_attn_err, test_attn_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot learned vs. reference attention per event index, one panel per split.\n",
    "viz_attn_score(attn_score_dict, w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "# import matplotlib.pyplot as plt\n",
    "\n",
    "# # Parameters\n",
    "# a = 0.003       # decay rate\n",
    "# omega = 0.02   # angular frequency\n",
    "\n",
    "# # Time range\n",
    "# t = np.linspace(0, 800, 1000)\n",
    "\n",
    "# # Define the function\n",
    "# f = np.exp(-a * t) * np.cos(omega * t)**2\n",
    "\n",
    "# # Plot\n",
    "# plt.figure(figsize=(8, 4))\n",
    "# plt.plot(t, f, label=r'$e^{-a t}\\cos^2(\\omega t)$')\n",
    "# plt.title('Damped Oscillation: $e^{-a t}\\\\cos^2(\\\\omega t)$')\n",
    "# plt.xlabel('t')\n",
    "# plt.ylabel('f(t)')\n",
    "# plt.ylim(0, 1.1)  # A bit above 1 for clarity\n",
    "# plt.legend()\n",
    "# plt.grid(True)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dg-shift",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
