{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import ortho_group\n",
    "import time\n",
    "import datetime\n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import scipy.linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: every tunable constant of the simulation lives here.\n",
    "n = 50              # number of samples (rows of each design matrix X)\n",
    "niterX = 50         # design matrices drawn per feature count p (author's note: 100)\n",
    "niter = 50          # noise draws per design matrix (author's note: 100)\n",
    "\n",
    "npoint = 40         # points on the log-spaced gamma grid used for the simulation\n",
    "npoint_dense = 200  # points on the dense gamma grid used for the closed-form curves\n",
    "xlogmax = 2         # both gamma grids end at 10**xlogmax\n",
    "EPS = 1e-12         # tiny offset: moves gamma off 1 and shifts eigenvalue computations\n",
    "\n",
    "# Seed NumPy's global generator so that a fresh Restart and Run All reproduces the\n",
    "# same numbers. The np.random.* calls and scipy's ortho_group.rvs (with its default\n",
    "# random_state) all draw from this global generator.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log-spaced grid of aspect ratios gamma = p/n from 1 to 10**xlogmax;\n",
    "# the [1:] slice discards the first grid point (gamma = 1).\n",
    "gamma = np.logspace(start=0, stop=xlogmax, num=npoint)[1:]\n",
    "# Shift any entry that is exactly 1 by EPS.\n",
    "gamma = gamma + (gamma == 1) * EPS\n",
    "\n",
    "# Integer feature counts p = round(gamma * n), one per grid point.\n",
    "ps = np.round(n * gamma).astype(int)\n",
    "# Replace the target ratios by the ratios actually realised by the integer p's.\n",
    "gamma = ps / n\n",
    "\n",
    "ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the realised aspect ratios gamma = ps/n (used as x-coordinates in the final plot).\n",
    "gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_symmetric(a, rtol=1e-05, atol=1e-08):\n",
    "    \"\"\"Return True when `a` is a square 2-D array equal to its transpose within tolerance.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a : numpy.ndarray\n",
    "        Array to test.\n",
    "    rtol, atol : float\n",
    "        Relative and absolute tolerances forwarded to `np.allclose`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    bool\n",
    "        False for anything that is not a square 2-D array; otherwise the result of\n",
    "        comparing `a` with `a.T`.\n",
    "    \"\"\"\n",
    "    # np.allclose broadcasts its operands, so without this guard a (1, n) array would be\n",
    "    # compared against its (n, 1) transpose on an n x n grid and could be reported as\n",
    "    # symmetric, while other non-square shapes would raise a broadcasting error.\n",
    "    if a.ndim != 2 or a.shape[0] != a.shape[1]:\n",
    "        return False\n",
    "    return bool(np.allclose(a, a.T, rtol=rtol, atol=atol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo estimate of the bias and variance of the least-squares solution returned by\n",
    "# scipy.linalg.lstsq, with anisotropic features (Sigma) and correlated noise (Omega).\n",
    "#\n",
    "# For every noise scale `multiple`, total_list receives [Vs, ts, ls, Zls, Bs], each with one\n",
    "# entry per feature count p in `ps`:\n",
    "#   Vs  - empirical variance term tr(Cov(betahat) @ Sigma), averaged over designs\n",
    "#   ts  - placeholder, always 0.0 (the corresponding trace diagnostic is disabled)\n",
    "#   ls  - tr(Omega)/n times the design-averaged eigenvalue sum of\n",
    "#         sqrt_Sigma @ pinv(X) @ pinv(X).T @ sqrt_Sigma (eigenvalues above 100 are zeroed)\n",
    "#   Zls - placeholder, always empty\n",
    "#   Bs  - empirical squared bias ||sqrt_Sigma @ (mean(betahat) - beta)||^2, averaged over designs\n",
    "total_list = []\n",
    "\n",
    "for multiple in [1,2,4]:\n",
    "    print(multiple)\n",
    "\n",
    "    # Noise covariance: random orthogonal eigenbasis with |N(0,1)| eigenvalues,\n",
    "    # rescaled so that tr(Omega) = n*multiple.\n",
    "    U = ortho_group.rvs(dim=n)\n",
    "    D = np.diag(np.abs(np.random.randn(n)))\n",
    "    Omega = U@D@np.transpose(U)/D.sum()*n*multiple\n",
    "    tr_omega = np.trace(Omega)\n",
    "    print(tr_omega)\n",
    "\n",
    "    # Bookkeeping for the status line printed in the innermost loop.\n",
    "    nprocess = 0\n",
    "    ntotal = niter*niterX*len(ps)\n",
    "\n",
    "    Vs = []\n",
    "    ts = []\n",
    "    ls = []\n",
    "    Zls = []\n",
    "    Bs = []\n",
    "\n",
    "    for p in ps:\n",
    "        # Feature covariance: random eigenbasis, eigenvalues normalised so that tr(Sigma) = p.\n",
    "        U = ortho_group.rvs(dim=p)\n",
    "        UT = np.transpose(U)\n",
    "        diag = np.abs(np.random.randn(p))\n",
    "        diag /= diag.sum()\n",
    "        Sigma = U@np.diag(diag)@UT*p\n",
    "        sqrt_Sigma = U@np.diag(np.sqrt(diag))@UT*np.sqrt(p)\n",
    "\n",
    "        eig_acc = 0.   # running sum over designs of the truncated eigenvalue sum\n",
    "        var_acc = 0.   # running sum over designs of tr(Cov(betahat) @ Sigma)\n",
    "        bias_acc = 0.  # running sum over designs of the squared bias\n",
    "        for indX in range(niterX):\n",
    "            # Rows of Z are i.i.d. N(0, I_p). Drawing them directly has the same distribution\n",
    "            # as multivariate_normal(zeros(p), eye(p), size=n) but skips the factorisation of\n",
    "            # a p x p covariance matrix on every call.\n",
    "            Z = np.random.randn(n, p)\n",
    "            X = Z@sqrt_Sigma\n",
    "\n",
    "            XP = np.linalg.pinv(X)\n",
    "            SXP = sqrt_Sigma@XP\n",
    "\n",
    "            # Unit-norm ground-truth coefficient vector, redrawn for every design.\n",
    "            beta = np.random.randn(p)\n",
    "            beta /= np.linalg.norm(beta)\n",
    "\n",
    "            # SXP@SXP.T (p x p) and SXP.T@SXP (n x n) share their non-zero eigenvalues, so the\n",
    "            # small Gram matrix is decomposed instead. The only difference from the p x p\n",
    "            # version is the EPS shift of the extra zero eigenvalues, i.e. at most about\n",
    "            # max(n, p)*EPS in the sum.\n",
    "            evals = np.linalg.eigvalsh(np.transpose(SXP)@SXP+EPS*np.eye(n))\n",
    "            # Eigenvalues above 100 are dropped from the sum.\n",
    "            evals_cut = evals*(evals<=100)\n",
    "            eig_acc += evals_cut.sum()\n",
    "\n",
    "            # tr(Cov @ Sigma) = mean_k(b_k' Sigma b_k) - m' Sigma m with m = mean_k(b_k), so only\n",
    "            # the quadratic forms are accumulated instead of p x p outer products.\n",
    "            quad_sum = 0.\n",
    "            B = np.zeros(p)\n",
    "            for _ in range(niter):\n",
    "                nprocess += 1\n",
    "                rate = nprocess / ntotal\n",
    "                print(\"process: %d, %.3f\"%(p,100*rate), end='\\r')\n",
    "                epsilon = np.random.multivariate_normal(np.zeros(n),Omega)\n",
    "                y = np.matmul(X,beta) + epsilon\n",
    "                betahat = scipy.linalg.lstsq(X, y, lapack_driver='gelsy')[0]\n",
    "                quad_sum += betahat@Sigma@betahat\n",
    "                B += betahat\n",
    "\n",
    "            mean_betahat = B/niter\n",
    "            bias_acc += np.linalg.norm(sqrt_Sigma@(mean_betahat-beta))**2\n",
    "            var_acc += quad_sum/niter - mean_betahat@Sigma@mean_betahat\n",
    "\n",
    "        Bs.append(bias_acc/niterX)\n",
    "        Vs.append(var_acc/niterX)\n",
    "        ts.append(0.)  # placeholder kept so that the layout of total_list is unchanged\n",
    "        ls.append(tr_omega/n*eig_acc/niterX)\n",
    "\n",
    "    total_list.append([Vs,ts,ls,Zls,Bs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense grid of aspect ratios for the closed-form curves, from 10**-0.1 to 10**xlogmax.\n",
    "dense_gamma = np.logspace(start=-0.1, stop=xlogmax, num=npoint_dense)\n",
    "# Shift any grid point lying exactly on gamma = 1 by EPS; the expressions below divide\n",
    "# by (1 - gamma) and (gamma - 1).\n",
    "dense_gamma = dense_gamma + (dense_gamma == 1) * EPS\n",
    "\n",
    "# Piecewise closed forms, written as indicator * value sums over the two regimes\n",
    "# gamma < 1 and gamma > 1.\n",
    "bias = (\n",
    "    (dense_gamma < 1) * (0)\n",
    "    + (dense_gamma > 1) * (1 - 1 / dense_gamma)\n",
    ")\n",
    "var = (\n",
    "    (dense_gamma < 1) * (dense_gamma / (1 - dense_gamma))\n",
    "    + (dense_gamma > 1) * (1 / (dense_gamma - 1))\n",
    ")\n",
    "risk = (\n",
    "    (dense_gamma < 1) * (dense_gamma / (1 - dense_gamma))\n",
    "    + (dense_gamma > 1) * (1 - 1 / dense_gamma + 1 / (dense_gamma - 1))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib colour codes, indexed like total_list (one entry per noise multiple).\n",
    "colors = [\n",
    "    \"k\",\n",
    "    \"b\",\n",
    "    \"r\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated variance and bias against the closed-form curves, as a function of gamma = p/n.\n",
    "# total_list[i] = [Vs, ts, ls, Zls, Bs] for the i-th noise multiple; colors[i] is its colour.\n",
    "fig, ax = plt.subplots(figsize=(12,4))\n",
    "\n",
    "# Only the first noise multiple carries legend labels, so each marker type appears once.\n",
    "ax.scatter(gamma,total_list[0][0],c=colors[0],marker='x',s=80,label='variance')\n",
    "ax.scatter(gamma,total_list[0][2],c=colors[0],marker='o',s=50,label='variance (theory)')\n",
    "ax.plot(dense_gamma,var,label='variance (theory, iso.)',c=colors[0],linestyle=':',alpha=1)\n",
    "\n",
    "for i in range(1,3):\n",
    "    # The isotropic curve is scaled by 2**i, i.e. by the noise multiples 2 and 4.\n",
    "    ax.plot(dense_gamma,var*2**i,c=colors[i],linestyle=':',alpha=1)\n",
    "    ax.scatter(gamma,total_list[i][0],c=colors[i],marker='x',s=80)\n",
    "    ax.scatter(gamma,total_list[i][2],c=colors[i],marker='o',s=50)\n",
    "\n",
    "# The bias is shown for the first noise multiple only.\n",
    "ax.scatter(gamma,total_list[0][4],c='c',marker='v',s=5,label='bias')\n",
    "ax.plot(dense_gamma,bias,label='bias (theory)',c='k',linestyle='--')\n",
    "\n",
    "ax.axhline(1,c='k',alpha=0.1)\n",
    "ax.set_xscale('log')\n",
    "# symlog keeps values below 10**(-xlogmax) (including zero) on a linear segment.\n",
    "ax.set_yscale('symlog', linthresh=10**(-xlogmax))\n",
    "ax.set_xlim(1,10**xlogmax)\n",
    "ax.set_ylim(10**(-xlogmax),100)\n",
    "\n",
    "ax.set_xlabel(r'$\\gamma=p/n$')\n",
    "plt.tight_layout()\n",
    "ax.legend(bbox_to_anchor=(1., 1.0))\n",
    "\n",
    "ax.grid(True,axis='both')\n",
    "\n",
    "# Timestamped file name so that repeated runs never overwrite an earlier figure.\n",
    "name = 'new_pr_asymptotic_'\n",
    "now = datetime.datetime.now()\n",
    "savename = now.strftime('%Y%m%d%H%M%S')\n",
    "print(name+savename)\n",
    "\n",
    "save_fig = True  # set to False to skip writing the PDF\n",
    "if save_fig:\n",
    "    fig_dir = '../../fig/'\n",
    "    # savefig does not create missing directories; without this the cell would fail with\n",
    "    # FileNotFoundError after the whole simulation has already run.\n",
    "    os.makedirs(fig_dir, exist_ok=True)\n",
    "    plt.savefig(fig_dir+name+savename+'.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the raw simulation results (total_list) as a pickle.\n",
    "# Note: the payload is binary pickle data even though the file name ends in .txt.\n",
    "with open('asymp_PR.txt', 'wb') as results_file:\n",
    "    pickle.dump(total_list, results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ffcv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
