{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import ortho_group\n",
    "import time\n",
    "import datetime\n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import scipy.linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enumerate CUDA devices in PCI bus order and expose only device 1.\n",
    "# This notebook itself only uses numpy/scipy/matplotlib; the variables matter only\n",
    "# to libraries that read them when they initialise CUDA.\n",
    "os.environ.update(\n",
    "    CUDA_DEVICE_ORDER=\"PCI_BUS_ID\",\n",
    "    CUDA_VISIBLE_DEVICES=\"1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size and Monte Carlo budget.\n",
    "n = 50        # samples: rows of each design matrix X\n",
    "p = 2 * n     # features; p > n, so the least-squares problem is overparameterized\n",
    "niter = 100   # noise draws per design matrix X\n",
    "niterX = 100  # design matrices X per (rho, sigma) grid point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of AR(1) noise parameters: rho**2 spans [0, 0.6], sigma spans [0.1, 1].\n",
    "rhos = np.linspace(start=0., stop=np.sqrt(0.6), num=10)\n",
    "sigmas = np.linspace(start=0.1, stop=1., num=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagonal jitter added to matrices before their eigenvalues are computed.\n",
    "EPS = 1.0e-12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte Carlo experiment. For each AR(1) coefficient rho and noise scale sigma, draw\n",
    "# niterX design matrices X and, for each X, niter noise vectors; compare the simulated\n",
    "# variance of the minimum-norm least-squares estimator betahat = X^T (X X^T)^+ y\n",
    "# with several trace formulas.\n",
    "SEED = 0  # fixed seed so that Restart & Run All reproduces the figures below\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# One row per rho, one entry per sigma.\n",
    "double_Vs = []   # simulated E_X[Tr(Cov(betahat | X) Sigma)]\n",
    "double_ts = []   # E_X[Tr((X X^T)^+ Omega)]\n",
    "double_t2s = []  # E_X[Tr((X^T X)^+ Sigma)]\n",
    "double_t3s = []  # Tr(Omega)/n * E_X[Tr((X^T X)^+ Sigma)]\n",
    "double_ls = []   # Tr(Omega)/n * E_X[sum of eigenvalues of (X^T X)^+ Sigma + EPS*I]\n",
    "double_Zls = []  # Tr(Omega)/n * E_X[sum of the inverse eigenvalues of Z Z^T that are <= 10]\n",
    "\n",
    "nprocess = 0\n",
    "ntotal = len(rhos)*len(sigmas)*niter*niterX\n",
    "\n",
    "# Fixed population covariance Sigma = p * U D U^T with Tr(D) = 1 (so Tr(Sigma) = p),\n",
    "# its symmetric square root, and a unit-norm true coefficient vector beta.\n",
    "U = ortho_group.rvs(dim=p)\n",
    "UT = np.transpose(U)\n",
    "diag = np.abs(np.random.randn(p))\n",
    "diag /= diag.sum()\n",
    "D = np.diag(diag)\n",
    "sqrt_D = np.diag(np.sqrt(diag))\n",
    "Sigma = U@D@UT*p\n",
    "sqrt_Sigma = U@sqrt_D@UT*np.sqrt(p)\n",
    "beta = np.random.randn(p)\n",
    "beta /= np.linalg.norm(beta)\n",
    "\n",
    "lags = np.subtract.outer(np.arange(n), np.arange(n))  # lags[i, j] = i - j\n",
    "\n",
    "for rho in rhos:\n",
    "    Rho2 = rho ** np.abs(lags)  # AR(1) correlation matrix: rho**|i-j|\n",
    "    Rho = np.tril(Rho2)  # eps = Rho @ eta unrolls eps_i = rho * eps_{i-1} + eta_i\n",
    "\n",
    "    Vs, ts, t2s, t3s, ls, Zls = [], [], [], [], [], []\n",
    "    for sigma in sigmas:\n",
    "        Omega = (sigma**2/(1-rho**2))*Rho2  # stationary AR(1) noise covariance\n",
    "        tr_omega = np.trace(Omega)\n",
    "\n",
    "        V = t = t2 = t3 = l = Z_l = 0.\n",
    "        for _ in range(niterX):\n",
    "            Z = np.random.standard_normal((n, p))\n",
    "            Z_inv_evals = 1/np.linalg.eigvalsh(Z@Z.T + EPS*np.eye(n))\n",
    "            Z_l += Z_inv_evals[Z_inv_evals <= 10].sum()\n",
    "\n",
    "            X = Z@sqrt_Sigma\n",
    "            XXTP = np.linalg.pinv(X@X.T)\n",
    "            XTXPS = np.linalg.pinv(X.T@X)@Sigma\n",
    "            tr_XTXPS = np.trace(XTXPS)\n",
    "            t += np.trace(XXTP@Omega)\n",
    "            t2 += tr_XTXPS\n",
    "            t3 += tr_XTXPS*tr_omega/n\n",
    "            # XTXPS is not symmetric, so eigvals returns complex values; their sum is\n",
    "            # real up to rounding, so keep the real part (otherwise ls becomes complex).\n",
    "            l += np.linalg.eigvals(XTXPS + EPS*np.eye(p)).sum().real\n",
    "\n",
    "            # All niter noise draws at once. Dividing eta_0 by sqrt(1 - rho^2) starts the\n",
    "            # AR(1) recursion in its stationary law, so Cov(eps) = Omega exactly; without it\n",
    "            # the early entries of eps have variance below sigma^2/(1-rho^2).\n",
    "            eta = np.random.randn(niter, n)*sigma\n",
    "            eta[:, 0] /= np.sqrt(1 - rho**2)\n",
    "            Y = X@beta + eta@Rho.T  # row k: y_k = X beta + Rho eta_k\n",
    "            betahats = Y@XXTP.T@X   # row k: (X^T (X X^T)^+ y_k)^T\n",
    "            Cov = np.cov(betahats, rowvar=False)  # unbiased (ddof=1) covariance across draws\n",
    "            V += np.trace(Cov@Sigma)\n",
    "\n",
    "            nprocess += niter\n",
    "            print(\"process: %.3f\"%(100*nprocess/ntotal), end='\\r')\n",
    "\n",
    "        Vs.append(V/niterX)\n",
    "        ts.append(t/niterX)\n",
    "        t2s.append(t2/niterX)\n",
    "        t3s.append(t3/niterX)\n",
    "        ls.append(tr_omega/n*l/niterX)\n",
    "        Zls.append(tr_omega/n*Z_l/niterX)\n",
    "\n",
    "    double_Vs.append(Vs)\n",
    "    double_ts.append(ts)\n",
    "    double_t2s.append(t2s)\n",
    "    double_t3s.append(t3s)\n",
    "    double_ls.append(ls)\n",
    "    double_Zls.append(Zls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of ls against ts, pooled over every (rho, sigma) grid point except the\n",
    "# largest rho (the last row). np.real drops an all-zero imaginary part in case ls was\n",
    "# accumulated from complex eigenvalue sums.\n",
    "ts_flat = np.asarray(double_ts)[:-1].ravel()\n",
    "ls_flat = np.real(np.asarray(double_ls)[:-1]).ravel()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(ts_flat, ls_flat)\n",
    "ax.plot([0, 4], [0, 4])  # y = x reference line\n",
    "ax.set_xlabel(r'$E_X[Tr((XX^T)^+\\Omega)]$')\n",
    "ax.set_ylabel(r'$\\frac{Tr(\\Omega)}{n}\\,E_X[Tr((X^TX)^+\\Sigma)]$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated variance Vs against ts, pooled over grid points excluding the two\n",
    "# largest rho values (the last two rows).\n",
    "ts_flat = np.asarray(double_ts)[:-2].ravel()\n",
    "Vs_flat = np.asarray(double_Vs)[:-2].ravel()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(ts_flat, Vs_flat)\n",
    "ax.plot([0, 3], [0, 3])  # y = x reference line\n",
    "ax.set_xlabel(r'$E_X[Tr((XX^T)^+\\Omega)]$')\n",
    "ax.set_ylabel(r'$E_X[Tr(Cov(\\hat\\beta|X)\\Sigma)]$ (simulated)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated variance Vs against ts, pooled over grid points excluding the three\n",
    "# largest rho values (the last three rows).\n",
    "ts_flat = np.asarray(double_ts)[:-3].ravel()\n",
    "Vs_flat = np.asarray(double_Vs)[:-3].ravel()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(ts_flat, Vs_flat)\n",
    "ax.plot([0, 3], [0, 3])  # y = x reference line\n",
    "ax.set_xlabel(r'$E_X[Tr((XX^T)^+\\Omega)]$')\n",
    "ax.set_ylabel(r'$E_X[Tr(Cov(\\hat\\beta|X)\\Sigma)]$ (simulated)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contour map over (sigma^2, rho^2): simulated variance Vs (solid, semi-transparent)\n",
    "# versus Tr(Omega)/n * E_X[Tr((X^T X)^+ Sigma)] (t3s, dashed and labelled).\n",
    "SAVE_FIG = True  # write a timestamped PDF into FIG_DIR\n",
    "FIG_DIR = '../../fig/'\n",
    "\n",
    "levels = np.linspace(0, 3, 11)\n",
    "sigma2_axis = np.asarray(sigmas)**2\n",
    "rho2_axis = np.asarray(rhos)**2\n",
    "V_grid = np.asarray(double_Vs)    # rows follow rhos, columns follow sigmas\n",
    "t3_grid = np.asarray(double_t3s)\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 3))\n",
    "# Dashing is requested via contour's `linestyles` argument; `label`/`linestyle` are not\n",
    "# contour arguments, and ContourSet.collections is deprecated in newer matplotlib.\n",
    "im0 = ax.contour(sigma2_axis, rho2_axis, V_grid,\n",
    "                 vmin=0, vmax=3, levels=levels, alpha=0.5)\n",
    "im1 = ax.contour(sigma2_axis, rho2_axis, t3_grid,\n",
    "                 vmin=0, vmax=3, levels=levels, linestyles='dashed')\n",
    "ax.clabel(im1, inline=1, fontsize=8, colors='k')\n",
    "\n",
    "ax.set_xlabel(r'$\\sigma^2$')\n",
    "ax.set_ylabel(r'$\\rho^2$')\n",
    "ax.set_xticks(np.linspace(0, 1, 5))\n",
    "ax.set_title('Prediction Risk - AR(1) Errors')\n",
    "\n",
    "# Single panel, so the colorbar belongs to `ax`; add it before tight_layout.\n",
    "fig.colorbar(im0, ax=ax)\n",
    "fig.tight_layout()\n",
    "\n",
    "if SAVE_FIG:\n",
    "    name = 'pr_ar_linear_50_100_100'\n",
    "    savename = datetime.datetime.now().strftime('%Y%m%d%H%M%S')\n",
    "    print(name+savename)\n",
    "    os.makedirs(FIG_DIR, exist_ok=True)  # savefig fails if the directory is missing\n",
    "    fig.savefig(FIG_DIR+name+savename+'.pdf')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ffcv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
