{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "39b8268c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yugui/opt/anaconda3/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import random\n",
    "import scipy.stats\n",
    "from scipy.stats import norm\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from tqdm.auto import tqdm\n",
    "import time\n",
    "import scipy.linalg as scilinalg\n",
    "import seaborn as sns\n",
    "from scipy.stats import ortho_group\n",
    "import pandas as pd\n",
    "import warnings\n",
    "# NOTE(review): this silences *all* warnings, including numerical/convergence ones;\n",
    "# consider narrowing it to the specific warning categories that are noisy.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import multiprocessing as mp\n",
    "from joblib import Parallel, delayed\n",
    "from utils import *\n",
    "\n",
    "from conf_simu import *\n",
    "\n",
    "# Seed every RNG module imported above so single-process runs are reproducible.\n",
    "# joblib worker processes (default loky backend) are separate processes and are not\n",
    "# guaranteed to inherit this state, so parallel runs must seed inside each task.\n",
    "np.random.seed(1234)\n",
    "random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4843ec4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simu_hetero(d1, d2, alpha, het, sd, tail, pr, k_star, rk, M_mean, mis_set, full_exp=False):\n",
    "    \"\"\"Run one replicate of the heterogeneous-missingness simulation.\n",
    "\n",
    "    Data are drawn with `gen_data`; `cfmc_simu_hetero` is then evaluated with the\n",
    "    \"als\" base learner and, when `full_exp` is truthy, also with the \"cvx\" one.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        If `full_exp` is falsy, the 4-tuple\n",
    "        (coverage_als, coverage_als_hat, length_als, length_als_hat).\n",
    "        Otherwise the 8-tuple\n",
    "        (coverage_cvx, coverage_als, coverage_cvx_hat, coverage_als_hat,\n",
    "         length_cvx, length_als, length_cvx_hat, length_als_hat).\n",
    "        Downstream cells label the `_hat` entries \"cmc-\"/\"cf-\" and the others\n",
    "        \"cmc*-\"/\"cf*-\" (presumably estimated vs. true propensities -- confirm in\n",
    "        cfmc_simu_hetero).\n",
    "    \"\"\"\n",
    "    # gen_data yields (M_star, A, P, S), S being the observed indices; A is not used here.\n",
    "    m_star, _a, p_mat, s_obs = gen_data(d1, d2, het, sd, tail, pr, M_mean, mis_set, k_star)\n",
    "\n",
    "    cov_als, cov_als_hat, len_als, len_als_hat = cfmc_simu_hetero(m_star, s_obs, p_mat, \"als\", rk, alpha, het)\n",
    "    if not full_exp:\n",
    "        return cov_als, cov_als_hat, len_als, len_als_hat\n",
    "\n",
    "    cov_cvx, cov_cvx_hat, len_cvx, len_cvx_hat = cfmc_simu_hetero(m_star, s_obs, p_mat, \"cvx\", rk, alpha, het)\n",
    "    return (cov_cvx, cov_als, cov_cvx_hat, cov_als_hat,\n",
    "            len_cvx, len_als, len_cvx_hat, len_als_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1c6d4c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 1\n",
      "iter: 2\n",
      "iter: 3\n",
      "iter: 4\n",
      "iter: 5\n",
      "iter: 6\n",
      "iter: 7\n",
      "Function value changing by less than progTol\n",
      "iter: 1\n",
      "iter: 2\n",
      "iter: 3\n",
      "iter: 4\n",
      "iter: 5\n",
      "iter: 6\n",
      "iter: 7\n",
      "Function value changing by less than progTol\n",
      "coverage rate: \n",
      "\t cmc*-als 0.8995330454045611\n",
      "\t cmc*-cvx 0.8950069326023555\n",
      "\t cmc-als 0.8992302420832868\n",
      "\t cmc-cvx 0.8989991553381038\n",
      "Average length: \n",
      "\t cmc*-als 3.4331\n",
      "\t cmc*-cvx 3.474\n",
      "\t cmc-als 3.4376\n",
      "\t cmc-cvx 3.5119\n"
     ]
    }
   ],
   "source": [
    "# Single illustrative run of simu_hetero; a truthy full_exp also runs the cvx base learner.\n",
    "d1 = d2 = 500\n",
    "alpha = 0.1\n",
    "sd = 1\n",
    "het = 'logis2'\n",
    "pr = 0.8\n",
    "rk = 8\n",
    "M_mean = 1\n",
    "mis_set = 0\n",
    "k_star = 8\n",
    "full_exp = True\n",
    "tail = 'gaussian'\n",
    "if not full_exp:\n",
    "    coverage_cmc_als, coverage_cmc_als_hat, length_cmc_als, length_cmc_als_hat = simu_hetero(\n",
    "        d1=d1, d2=d2, alpha=alpha, het=het, sd=sd, tail=tail, pr=pr, k_star=k_star,\n",
    "        rk=rk, M_mean=M_mean, mis_set=mis_set, full_exp=full_exp)\n",
    "    print('coverage rate: \\n\\t cmc*-als {}\\n\\t cmc-als {}'.format(coverage_cmc_als, coverage_cmc_als_hat))\n",
    "    print('Average length: \\n\\t cmc*-als {}\\n\\t cmc-als {}'.format(length_cmc_als, length_cmc_als_hat))\n",
    "else:\n",
    "    (coverage_cmc_cvx, coverage_cmc_als, coverage_cmc_cvx_hat, coverage_cmc_als_hat,\n",
    "     length_cmc_cvx, length_cmc_als, length_cmc_cvx_hat, length_cmc_als_hat) = simu_hetero(\n",
    "        d1=d1, d2=d2, alpha=alpha, het=het, sd=sd, tail=tail, pr=pr, k_star=k_star,\n",
    "        rk=rk, M_mean=M_mean, mis_set=mis_set, full_exp=full_exp)\n",
    "    print('coverage rate: \\n\\t cmc*-als {}\\n\\t cmc*-cvx {}\\n\\t cmc-als {}\\n\\t cmc-cvx {}'.format(coverage_cmc_als, coverage_cmc_cvx, coverage_cmc_als_hat, coverage_cmc_cvx_hat))\n",
    "    print('Average length: \\n\\t cmc*-als {}\\n\\t cmc*-cvx {}\\n\\t cmc-als {}\\n\\t cmc-cvx {}'.format(length_cmc_als, length_cmc_cvx, length_cmc_als_hat, length_cmc_cvx_hat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf22178",
   "metadata": {},
   "outputs": [],
   "source": [
    "# repeated\n",
    "# remember to disable plotting\n",
    "alpha = 0.1\n",
    "sd0 = 1\n",
    "base1 = 'cvx'\n",
    "base2 = base = 'als'\n",
    "repN = 100\n",
    "pr = 0.8\n",
    "M_mean = 1\n",
    "k_star = 8\n",
    "tail = 'gaussian'\n",
    "sigma_true=False\n",
    "full_exp=False\n",
    "num_cores = mp.cpu_count()\n",
    "# Directory for the per-configuration CSV files; created below if it does not exist.\n",
    "results_dir = '../results'\n",
    "# Replicate i of every configuration is run with seed base_seed + i, making the study\n",
    "# reproducible and letting different ranks rk share the same random draws.\n",
    "# NOTE(review): assumes gen_data / cfmc_simu_hetero draw from the global numpy/random RNGs -- confirm.\n",
    "base_seed = 1234\n",
    "# Noise tail for each misspecification setting; any setting not listed uses 'gaussian'.\n",
    "tail_by_mis_set = {4: 'het', 5: 'het1'}\n",
    "\n",
    "\n",
    "def run_replicate(seed, *args, **kwargs):\n",
    "    \"\"\"Seed the global RNGs with `seed`, then run one `simu_hetero` replicate.\n",
    "\n",
    "    Seeding happens inside the task because joblib runs tasks in worker processes\n",
    "    that do not share the notebook's RNG state. `args`/`kwargs` are forwarded\n",
    "    unchanged to `simu_hetero`, whose result is returned as-is.\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    return simu_hetero(*args, **kwargs)\n",
    "\n",
    "\n",
    "def results_to_frames(res_mat, labels):\n",
    "    \"\"\"Turn a replicate-by-statistic matrix into long-format coverage and length frames.\n",
    "\n",
    "    `res_mat` has one row per replicate and 2 * len(labels) columns: first the coverage\n",
    "    of each method, then its average length, both in the order of `labels`.\n",
    "    Returns (cov_df, len_df) with columns ['coverage', 'approach'] and\n",
    "    ['length', 'approach']; all rows of labels[0] come first, then labels[1], etc.\n",
    "    \"\"\"\n",
    "    n_methods = len(labels)\n",
    "    n_rep = res_mat.shape[0]\n",
    "    approach = [label for label in labels for _ in range(n_rep)]\n",
    "    # ravel(order='F') lays the selected columns end to end (same as np.hstack of the columns).\n",
    "    cov_df = pd.DataFrame({'coverage': res_mat[:, :n_methods].ravel(order='F'), 'approach': approach})\n",
    "    len_df = pd.DataFrame({'length': res_mat[:, n_methods:2 * n_methods].ravel(order='F'), 'approach': approach})\n",
    "    return cov_df, len_df\n",
    "\n",
    "\n",
    "# All side-effecting work sits under the main guard, so the post-processing below can\n",
    "# never run on a missing or stale `results` left over from a previous configuration.\n",
    "if __name__ == \"__main__\":\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    for d in [500]:\n",
    "        d1 = d2 = d\n",
    "        for full_exp in [False,True]:\n",
    "            if full_exp:\n",
    "                # Full experiments also run the cvx base learner, on a coarser rank grid.\n",
    "                rk_seq = range(4,25,4)\n",
    "                labels = ['cf*-'+base1, 'cf*-'+base2, 'cf-'+base1, 'cf-'+base2]\n",
    "            else:\n",
    "                rk_seq = range(2,41,2)\n",
    "                labels = ['cf*-'+base2, 'cf-'+base2]\n",
    "            for het in ['logis2','logis1']:\n",
    "                for mis_set in [4,5,0]:\n",
    "                    tail = tail_by_mis_set.get(mis_set, 'gaussian')\n",
    "                    sd = sd0\n",
    "                    for rk in rk_seq:\n",
    "                        print([mis_set, tail, sd])\n",
    "                        results = Parallel(n_jobs=num_cores)(\n",
    "                            delayed(run_replicate)(base_seed + i, d1, d2, alpha, het, sd, tail, pr, k_star, rk,\n",
    "                                                   M_mean, mis_set, full_exp=full_exp)\n",
    "                            for i in range(repN)\n",
    "                        )\n",
    "                        # One row per replicate, one column per value returned by simu_hetero.\n",
    "                        res_mat = np.array(results).reshape(repN, 2 * len(labels))\n",
    "                        cov_df, len_df = results_to_frames(res_mat, labels)\n",
    "\n",
    "                        suffix = '_'.join(str(v) for v in [d1, mis_set, k_star, rk, het, sd, tail, full_exp])\n",
    "                        file_dir1 = os.path.join(results_dir, 'hetero_cov_' + suffix + '.csv')\n",
    "                        file_dir2 = os.path.join(results_dir, 'hetero_len_' + suffix + '.csv')\n",
    "                        cov_df.to_csv(file_dir1, index=False)\n",
    "                        len_df.to_csv(file_dir2, index=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
