{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6189067f",
   "metadata": {},
   "source": [
    "# Modeling assumptions\n",
    "\n",
    "Let $S_t^{(a)}=\\phi_\\theta(Z^{(a)},X_t).$\n",
    "$$Y_t^{(a)}=\\theta^{(a)\\top }S_t^{(a)} + \\epsilon_t^{(a)}$$ \n",
    "$$\\theta^{(a)}\\sim N(\\mu, \\Sigma) \\text{ iid}$$\n",
    "$$\\epsilon_t^{(a)}\\sim N(0,\\sigma^2) \\text{ iid}$$\n",
    "so that \n",
    "$$Y^{(a)}_t\\mid S_t^{(a)} \\sim N(\\mu^\\top S_t^{(a)}, \\sigma^2+S_t^{(a)\\top} \\Sigma S_t^{(a)})$$\n",
    "\n",
    "# Estimating params\n",
    "For each row indexed by $a$, calculate the OLS coefficients, $\\hat\\theta^{(a)}$. Then set $\\mu$ to be the mean of the $\\hat\\theta^{(a)}$, $\\Sigma$ to be the covariance of the $\\hat\\theta^{(a)}$. Let $\\sigma^2$ be the variance of the residuals. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f184583",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fbbe8128",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import Ridge\n",
    "import numpy as np\n",
    "\n",
    "def get_params(Xb, Y, alpha=1.0):\n",
    "    with torch.no_grad():\n",
    "        all_solns = []\n",
    "        all_preds = []\n",
    "        cutoff = int(Xb.shape[1] * 0.9)\n",
    "        for i in range(len(Xb)):\n",
    "            clf = Ridge(alpha=alpha, fit_intercept=False)\n",
    "            # fit on \"training\" set\n",
    "            clf.fit(Xb[i,:cutoff].detach().numpy(), Y[i,:cutoff].detach().numpy())\n",
    "            all_preds.append(clf.predict(Xb[i].detach().numpy()))\n",
    "            all_solns.append(clf.coef_)\n",
    "        all_solution = torch.tensor(all_solns)\n",
    "        coef_mean = all_solution.mean(0)\n",
    "        coef_cov = torch.cov(all_solution.T) #+ torch.eye(len(all_solution.T)) * 0.01\n",
    "        preds = torch.tensor(all_preds)\n",
    "        # calculate noise var on held-out set\n",
    "        noise_var = ((Y[:,cutoff:] - preds[:,cutoff:])**2).mean()\n",
    "        synth_params = {'prior_mean':coef_mean, 'prior_cov':coef_cov, 'noise_var':noise_var}\n",
    "    return synth_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bbe3d4ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from postprocessing import load_old_model\n",
    "\n",
    "def get_model_featurizer(model_dir):\n",
    "    model_path = model_dir + \"/best_loss.pt\"\n",
    "    check = torch.load(model_path, map_location=torch.device('cpu'))\n",
    "    config_path = model_dir + \"/config.pt\"\n",
    "    config = torch.load(config_path, map_location=torch.device('cpu'))\n",
    "    config.device = 'cpu'\n",
    "    model = load_old_model(config, check['state_dict'], check)\n",
    "\n",
    "    # replace last part of the model with \n",
    "    top_layer_model = model.top_layer.model\n",
    "    print('old top layer')\n",
    "    print(top_layer_model)\n",
    "    # Create a new sequential model with the first 6 layers (indices 0 to 5)\n",
    "    new_top_layer_model = torch.nn.Sequential(*list(top_layer_model.children())[:5])\n",
    "\n",
    "    print('new top layer')\n",
    "    print(new_top_layer_model)\n",
    "    # Replace the original top_layer.model with the new one\n",
    "    model.top_layer.model = new_top_layer_model\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "54f6dfae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(self, Zkwargs, X): \n",
    "    with torch.no_grad():\n",
    "        embed = self.z_encoder(Zkwargs)\n",
    "\n",
    "        bs, ncol, xdim = X.shape\n",
    "        bsZ, zdim = embed.shape\n",
    "        assert bs == bsZ \n",
    "        embedZ = embed.unsqueeze(1).repeat((1,ncol,1))\n",
    "        X_enc = self.x_suff_encoder(X)\n",
    "        state = self.init_model_states(X.shape[0])\n",
    "        embedState = state.unsqueeze(1).repeat((1,ncol,1))\n",
    "\n",
    "        input_ = torch.cat([embedZ, X_enc, embedState], 2)\n",
    "        return self.top_layer(input_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1233ad45",
   "metadata": {},
   "source": [
    "## Synthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "8ac8a525",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '/shared/share_mala/implicitbayes/dataset_files/synthetic_data/binary_context/N=1000,D=10000,D_eval=10000,method=0111_logistic,dimX=5,one_X_per_col=False,flip/'\n",
    "train_data = torch.load(data_dir + '/train_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "5b1c2eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = '/shared/share_mala/implicitbayes/dataset_files/synthetic_data/binary_context/N=1000,D=10000,D_eval=10000,method=0111_logistic,dimX=5,one_X_per_col=False,flip//models//use_Y_linear_comparisons_flip_0111_logistic_sequential_dimX=5/fs_sequential:epochs=100,bs=500,lr=0.01,wd=0.01,MLP_layers=3,MLP_width=100,weight_factor=1,max_obs=500,repeat_suffstat=1,Zdim=2,sched=constant,suffstat_eps=1.0,useY,seed=2340923'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4f8caed3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:root:IS SEQUENTIAL True\n",
      "old top layer\n",
      "Sequential(\n",
      "  (0): Linear(in_features=37, out_features=100, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (3): ReLU()\n",
      "  (4): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (5): ReLU()\n",
      "  (6): Linear(in_features=100, out_features=1, bias=True)\n",
      "  (7): Sigmoid()\n",
      ")\n",
      "new top layer\n",
      "Sequential(\n",
      "  (0): Linear(in_features=37, out_features=100, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (3): ReLU()\n",
      "  (4): Linear(in_features=100, out_features=100, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model = get_model_featurizer(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "dbee8168",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = train_data['X']\n",
    "Y = train_data['Y']\n",
    "Z = train_data['Z']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4ab863a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2000, 1000, 101])\n"
     ]
    }
   ],
   "source": [
    "# we only use the first max_rows rows to avoid memory issues\n",
    "max_rows = 2000\n",
    "phi_embeddings = forward(model, Z[:max_rows], X[:max_rows])\n",
    "Y_flat = Y[:max_rows]\n",
    "\n",
    "X = phi_embeddings\n",
    "X_flat = X.reshape(-1, X.shape[-1])\n",
    "Xb = torch.cat([X, torch.ones(X.shape[:-1] + (1,))], dim=-1)\n",
    "print(Xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "4a987e76",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_params(Xb, Y_flat, alpha=0.1)\n",
    "cov = params['prior_cov']\n",
    "params['prior_cov'] += torch.eye(len(params['prior_cov'])) * 0.0001\n",
    "torch.save(params, '../neurips_code/saved_params/0111_logistic_feats_estimated_reg0.0001_fixed6.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "485d40e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_params = {\n",
    "    'prior_mean':torch.zeros(101), \n",
    "    'prior_cov':torch.eye(101),\n",
    "    'noise_var':0.25,\n",
    "}\n",
    "torch.save(synth_params, 'saved_params/0111_logistic_feats_uninformed.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b83de31",
   "metadata": {},
   "source": [
    "# Semisynthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3a28402c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/N=1000,D=20000,D_eval=10000,method=0111_logistic_withZ_zero_prodsign_last,dimX=5,one_X_per_col=False,flip/bert_Z/'\n",
    "semi_train_data = torch.load(data_dir + '/train_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "75f07556",
   "metadata": {},
   "outputs": [],
   "source": [
    "semi_model_dir = '/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/N=1000,D=20000,D_eval=10000,method=0111_logistic_withZ_zero_prodsign_last,dimX=5,one_X_per_col=False,flip//bert_Z//models/even_use_Y_0111_20000_1_context_semisynthetic_fixed_bert_0111_logistic_withZ_zero_prodsign_last_sequential_context_dimX=5/fs_sequential:epochs=40,bs=500,lr=0.01,wd=0.01,MLP_layers=3,MLP_width=100,weight_factor=1,max_obs=500,repeat_suffstat=100,Zdim=768,sched=constant,X_MLP_layer=2,X_MLP_width=100,suffstat_eps=1.0,useY,seed=2340923'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4eeb6530",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:root:IS SEQUENTIAL True\n",
      "old top layer\n",
      "Sequential(\n",
      "  (0): Linear(in_features=3773, out_features=100, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (3): ReLU()\n",
      "  (4): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (5): ReLU()\n",
      "  (6): Linear(in_features=100, out_features=1, bias=True)\n",
      "  (7): Sigmoid()\n",
      ")\n",
      "new top layer\n",
      "Sequential(\n",
      "  (0): Linear(in_features=3773, out_features=100, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (3): ReLU()\n",
      "  (4): Linear(in_features=100, out_features=100, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "semi_model = get_model_featurizer(semi_model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7f830ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get feats\n",
    "X = semi_train_data['X']\n",
    "Y = semi_train_data['Y']\n",
    "Z = semi_train_data['Z']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "87f24f3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5d0739ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71c7131a58d14aada1e89415c82e44d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "max_rows = 100\n",
    "all_embeds = []\n",
    "for i in tqdm(range(20)):\n",
    "    embeds = forward(semi_model, Z[max_rows*i:max_rows*(i+1)], X[max_rows*i:max_rows*(i+1)])\n",
    "    all_embeds.append(embeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "667c0228",
   "metadata": {},
   "outputs": [],
   "source": [
    "semi_phi_embeddings = torch.cat(all_embeds, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "af6ff844",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_flat = Y[:2000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1e3ba92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(Y_flat) == len(semi_phi_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "4923382c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2000, 1000, 101])\n"
     ]
    }
   ],
   "source": [
    "X = semi_phi_embeddings\n",
    "X_flat = X.reshape(-1, X.shape[-1])\n",
    "Xb = torch.cat([X, torch.ones(X.shape[:-1] + (1,))], dim=-1)\n",
    "print(Xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b90e2fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_params(Xb, Y_flat, alpha=0.1)\n",
    "cov = params['prior_cov']\n",
    "params['prior_cov'] += torch.eye(len(params['prior_cov'])) * 0.0001\n",
    "torch.save(params, 'saved_params/0111_logistic_withZ_zero_prodsign_last_feats_estimated_reg0.0001_fixed2_alpha0.1.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5420cff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_params = {\n",
    "    'prior_mean':torch.zeros(101), \n",
    "    'prior_cov':torch.eye(101),\n",
    "    'noise_var':0.25,\n",
    "}\n",
    "torch.save(synth_params, 'saved_params/0111_logistic_withZ_zero_prodsign_last_feats_uninformed.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
