{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6189067f",
   "metadata": {},
   "source": [
    "# Modeling assumptions\n",
    "\n",
    "$$Y_t^{(a)}=\\theta^{(a)\\top }X_t + \\epsilon_t^{(a)}$$ \n",
    "$$\\theta^{(a)}\\sim N(\\mu, \\Sigma) \\text{ iid}$$\n",
    "$$\\epsilon_t^{(a)}\\sim N(0,\\sigma^2) \\text{ iid}$$\n",
    "so that \n",
    "$$Y^{(a)}_t\\mid X_t \\sim N(\\mu^\\top X_t, \\sigma^2+X_t^\\top \\Sigma X_t)$$\n",
    "\n",
    "# Estimating params\n",
    "For each row indexed by $a$, calculate the OLS coefficients, $\\hat\\theta^{(a)}$. Then set $\\mu$ to be the mean of the $\\hat\\theta^{(a)}$, $\\Sigma$ to be the covariance of the $\\hat\\theta^{(a)}$. Let $\\sigma^2$ be the variance of the residuals. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb48583a",
   "metadata": {},
   "source": [
    "## Synthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "37930285",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "c3f163b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "import numpy as np\n",
    "\n",
    "def get_params(Xb, Y, held_out_noise_var=True, alpha=1.0):\n",
    "    with torch.no_grad():\n",
    "        all_solns = []\n",
    "        all_preds = []\n",
    "        cutoff = int(Xb.shape[1] * 0.9)\n",
    "        for i in range(len(Xb)):\n",
    "            clf = Ridge(alpha=alpha, fit_intercept=False)\n",
    "            if held_out_noise_var:\n",
    "                clf.fit(Xb[i,:cutoff].detach().numpy(), Y[i,:cutoff].detach().numpy())\n",
    "            else:\n",
    "                clf.fit(Xb[i].detach().numpy(), Y[i].detach().numpy())\n",
    "            all_preds.append(clf.predict(Xb[i].detach().numpy()))\n",
    "            all_solns.append(clf.coef_)\n",
    "        all_solution = torch.tensor(all_solns)\n",
    "        coef_mean = all_solution.mean(0)\n",
    "        coef_cov = torch.cov(all_solution.T) #+ torch.eye(len(all_solution.T)) * 0.01\n",
    "        preds = torch.tensor(all_preds)\n",
    "        if held_out_noise_var:\n",
    "            noise_var = ((Y[:,:cutoff] - preds[:,:cutoff])**2).mean()\n",
    "        else:\n",
    "            noise_var = ((Y - preds)**2).mean()\n",
    "        synth_params = {'prior_mean':coef_mean, 'prior_cov':coef_cov, 'noise_var':noise_var}\n",
    "    return synth_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "8ac8a525",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '/shared/share_mala/implicitbayes/dataset_files/synthetic_data/binary_context/N=1000,D=10000,D_eval=10000,method=0111_logistic,dimX=5,one_X_per_col=False,flip/'\n",
    "train_data = torch.load(data_dir + '/train_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "cb4f03ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = train_data['X']\n",
    "Y = train_data['Y']\n",
    "Xb = torch.cat([X, torch.ones(X.shape[:-1] + (1,))], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "33c7bce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_params = get_params(Xb, Y, alpha=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "7b1bc001",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0004, 0.0004, 0.0015, 0.0074, 0.0171, 0.0598])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check eigenvalues\n",
    "eigvals, eigvecs = torch.linalg.eigh(synth_params['prior_cov'])\n",
    "eigvals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "485d40e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(synth_params, '../neurips_code/saved_params/0111_logistic_fixed3.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "5e678971",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0004, 0.0004, 0.0015, 0.0074, 0.0171, 0.0598])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_params = get_params(Xb, Y, alpha=0.0)\n",
    "# sanity check eigenvalues\n",
    "eigvals, eigvecs = torch.linalg.eigh(synth_params['prior_cov'])\n",
    "eigvals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "48967911",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0004, 0.0004, 0.0015, 0.0074, 0.0171, 0.0597])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_params = get_params(Xb, Y, alpha=1.0)\n",
    "# sanity check eigenvalues\n",
    "eigvals, eigvecs = torch.linalg.eigh(synth_params['prior_cov'])\n",
    "eigvals"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b83de31",
   "metadata": {},
   "source": [
    "## Semisynthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "3a28402c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/N=1000,D=20000,D_eval=10000,method=0111_logistic_withZ_zero_prodsign_last,dimX=5,one_X_per_col=False,flip/'\n",
    "semi_train_data = torch.load(data_dir + '/train_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "7f9a879c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = semi_train_data['X']\n",
    "Y = semi_train_data['Y']\n",
    "Xb = torch.cat([X, torch.ones(X.shape[:-1] + (1,))], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "5d60a675",
   "metadata": {},
   "outputs": [],
   "source": [
    "semisynth_params = get_params(Xb, Y, alpha=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "4377f854",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0002, 0.0002, 0.0002, 0.0002, 0.0003, 0.0579])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check eigenvalues\n",
    "eigvals, eigvecs = torch.linalg.eigh(semisynth_params['prior_cov'])\n",
    "eigvals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "5705b86b",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(semisynth_params, '../neurips_code/saved_params/0111_logistic_withZ_zero_prodsign_last_fixed3.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e862d77",
   "metadata": {},
   "source": [
    "# Bandit stuff: Bayesian inference\n",
    "\n",
    "Say we know $\\mu,\\sigma,\\Sigma$ as above. Then \n",
    "$$\\theta^{(a)}\\mid \\{ Y_t^{(a)}\\}_{t=1}^{T} \\sim N(m_{a,T},\\Sigma_{a,T})$$\n",
    "where\n",
    "$$\\Sigma_{a,T}=\\left(\\Sigma^{-1} +\\frac{1}{\\sigma^2}\\sum_{t=1}^T X_t^\\top X_t\\right)^{-1}$$\n",
    "$$m_{a,T}=\\Sigma_{a,T}\\left(\\Sigma^{-1}\\mu +\\frac{1}{\\sigma^2}\\sum_{t=1}^T X_t^\\top Y_t^{(a)}\\right)$$"
   ]
  }
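  ,
  {
   "cell_type": "markdown",
   "id": "posterior-sketch-md",
   "metadata": {},
   "source": [
    "A minimal sketch of this posterior update in PyTorch, assuming the contexts are stacked as rows of a `(T, d)` tensor `X`, so that the sum of outer products of the contexts equals `X.T @ X`. The function name `posterior_update` and the toy numbers are illustrative and not part of the original code; in practice the prior would come from the `prior_mean`, `prior_cov` and `noise_var` entries saved above. With many observations, `torch.linalg.solve` would likely be numerically preferable to forming explicit inverses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "posterior-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def posterior_update(prior_mean, prior_cov, noise_var, X, y):\n",
    "    # Hypothetical helper: Gaussian posterior of theta given X (T, d) and y (T,)\n",
    "    prior_prec = torch.linalg.inv(prior_cov)\n",
    "    # Posterior precision = prior precision + (1 / sigma^2) * X^T X\n",
    "    post_cov = torch.linalg.inv(prior_prec + X.T @ X / noise_var)\n",
    "    # Posterior mean combines the prior mean with the data term (1 / sigma^2) * X^T y\n",
    "    post_mean = post_cov @ (prior_prec @ prior_mean + X.T @ y / noise_var)\n",
    "    return post_mean, post_cov\n",
    "\n",
    "# Toy check with made-up numbers: d=2 coefficients, T=500 simulated observations\n",
    "torch.manual_seed(0)\n",
    "d, T = 2, 500\n",
    "prior_mean = torch.zeros(d, dtype=torch.float64)\n",
    "prior_cov = torch.eye(d, dtype=torch.float64)\n",
    "noise_var = 0.25\n",
    "theta = torch.tensor([1.0, -0.5], dtype=torch.float64)\n",
    "X = torch.randn(T, d, dtype=torch.float64)\n",
    "y = X @ theta + noise_var**0.5 * torch.randn(T, dtype=torch.float64)\n",
    "m, S = posterior_update(prior_mean, prior_cov, noise_var, X, y)\n",
    "print(m)  # should land close to theta, since T is large relative to the noise"
   ]
  }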
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
