{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "from ecit import *\n",
    "\n",
    "\n",
    "\n",
    "def _ecit_pvalues(data, cit, p_ensemble, k, n, multi):\n",
    "    \"\"\"Build an ECIT object on `data` and return the result of calling it on columns [0], [1], [2].\n",
    "\n",
    "    `k` is forwarded to ECIT unchanged when k < 100; otherwise int(n / k) is forwarded\n",
    "    instead (presumably k >= 100 is a per-split sample size and the forwarded value a\n",
    "    split count -- TODO confirm against ECIT). The caller treats the return value as a\n",
    "    single p-value, or as one p-value per combiner when multi=True.\n",
    "    \"\"\"\n",
    "    k_arg = k if k < 100 else int(n / k)\n",
    "    obj_ECIT = ECIT(data, cit, p_ensemble, k_arg)\n",
    "    return obj_ECIT([0], [1], [2], multi=multi)\n",
    "\n",
    "\n",
    "def _simulate_trial(n, zDis, noiseDis, cit, k, p_ensemble, multi, max_retries):\n",
    "    \"\"\"Run one Monte-Carlo trial and return (pI, pII).\n",
    "\n",
    "    pI is computed on data generated with indp='C' (used for the Type I error rate) and\n",
    "    pII on data generated with indp='N' (used for power). Both datasets are regenerated\n",
    "    on every attempt; any exception triggers a retry, and the exception raised by the\n",
    "    final attempt propagates to the caller.\n",
    "    \"\"\"\n",
    "    for attempt in range(1, max_retries + 1):\n",
    "        try:\n",
    "            # Same order as before: generate both datasets, then test I, then test II.\n",
    "            dataI = np.hstack(generate_samples(n=n, indp='C', z_dis=zDis, noise_dis=noiseDis, noise_std=1))\n",
    "            dataII = np.hstack(generate_samples(n=n, indp='N', z_dis=zDis, noise_dis=noiseDis, noise_std=1))\n",
    "            pI = _ecit_pvalues(dataI, cit, p_ensemble, k, n, multi)\n",
    "            pII = _ecit_pvalues(dataII, cit, p_ensemble, k, n, multi)\n",
    "            return pI, pII\n",
    "        except Exception:\n",
    "            print(f\"Retries times {attempt}\")\n",
    "            if attempt >= max_retries:\n",
    "                raise\n",
    "\n",
    "\n",
    "def ecitSimulateAlpha(methods,\n",
    "                      zDis_list=('gaussian',),\n",
    "                      noiseDis_list=('laplace',),\n",
    "                      n_list=(1200,),\n",
    "                      t=10000,\n",
    "                      alpha=0.05,\n",
    "                      max_retries=5):\n",
    "    \"\"\"Estimate the Type I error rate and power of ECIT configurations by simulation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    methods : iterable of (cit, k, p_ensemble)\n",
    "        `cit` is passed to ECIT (its `__name__` is used in the row label), `k` is\n",
    "        forwarded as described in `_ecit_pvalues`, and `p_ensemble` is either a single\n",
    "        p-value combiner or a list of combiners evaluated together (multi=True).\n",
    "    zDis_list, noiseDis_list : iterable of str\n",
    "        Values passed to `generate_samples` as `z_dis` and `noise_dis`.\n",
    "    n_list : iterable of int\n",
    "        Sample sizes.\n",
    "    t : int\n",
    "        Number of Monte-Carlo trials per (n, zDis, method, noiseDis) combination.\n",
    "    alpha : float\n",
    "        Significance level; a test is counted as rejecting when p < alpha.\n",
    "    max_retries : int\n",
    "        Attempts per trial before the last exception is re-raised.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of list\n",
    "        One row per (n, zDis, method): [n, zDis.capitalize(), label] followed by\n",
    "        `eI, power` for each noise distribution. For a list `p_ensemble` both are\n",
    "        NumPy arrays with one entry per combiner; otherwise both are floats.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If t < 1 or max_retries < 1.\n",
    "    \"\"\"\n",
    "    if t < 1:\n",
    "        raise ValueError(\"t must be a positive number of trials\")\n",
    "    if max_retries < 1:\n",
    "        raise ValueError(\"max_retries must be at least 1\")\n",
    "\n",
    "    results_table = []\n",
    "    for n in n_list:\n",
    "        for zDis in zDis_list:\n",
    "            for cit, k, p_ensemble in methods:\n",
    "                multi = isinstance(p_ensemble, list)\n",
    "                # An ensemble list is labelled by its first combiner.\n",
    "                p_name = (p_ensemble[0] if multi else p_ensemble).__name__\n",
    "                row = [n, zDis.capitalize(), cit.__name__ + str(k) + p_name]\n",
    "                n_combiners = len(p_ensemble) if multi else 1\n",
    "\n",
    "                for noiseDis in tqdm(noiseDis_list, desc=f\"{n:>5}, {zDis:>8}, {k:>3},{p_name:>11}\"):\n",
    "                    rejI = np.zeros(n_combiners, dtype=int)\n",
    "                    rejII = np.zeros(n_combiners, dtype=int)\n",
    "                    for _ in range(t):\n",
    "                        pI, pII = _simulate_trial(n, zDis, noiseDis, cit, k, p_ensemble, multi, max_retries)\n",
    "                        # One rule for both cases: reject iff p < alpha, so p == alpha and NaN\n",
    "                        # p-values never count as rejections in either estimate.\n",
    "                        rejI += np.atleast_1d(np.asarray(pI) < alpha).astype(int)\n",
    "                        rejII += np.atleast_1d(np.asarray(pII) < alpha).astype(int)\n",
    "                    eI = rejI / t\n",
    "                    power = rejII / t\n",
    "                    if not multi:\n",
    "                        eI, power = float(eI[0]), float(power[0])\n",
    "                    row.append(eI)\n",
    "                    row.append(power)\n",
    "                results_table.append(row)\n",
    "\n",
    "    return results_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this seeds NumPy's legacy global RNG; it makes the run reproducible\n",
    "# only if generate_samples / ECIT draw from np.random -- TODO confirm in ecit.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# p-value combiners evaluated together in one ensemble run; none are defined in this\n",
    "# notebook, so presumably they come from `from ecit import *`.\n",
    "enplist = [p_alpha2, p_alpha175, p_alpha15, p_alpha125, p_alpha1,\n",
    "           p_alpha075, p_alpha05, p_alpha025, p_alpha01, p_mean]\n",
    "# (cit, k, p_ensemble): k = 400 is >= 100, so int(n / k) = int(1200 / 400) = 3 is passed to ECIT.\n",
    "ensCIT = [(kcit, 400, enplist)]\n",
    "# Potentially long-running: t = 10000 trials, each building and calling ECIT twice.\n",
    "results = ecitSimulateAlpha(ensCIT, n_list=[1200], t=10000)\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "share": {
   "datetime": "2025-01-16T05:48:24.740Z",
   "image": {
    "name": "modelscope:1.18.0-pytorch2.3.0-cpu-py310-ubuntu22.04",
    "url": "dsw-registry-vpc.cn-hangzhou.cr.aliyuncs.com/pai/modelscope:1.18.0-pytorch2.3.0-cpu-py310-ubuntu22.04"
   },
   "instance": "dsw-03a689ba3735b16d",
   "spec": {
    "id": "ecs.g6.xlarge",
    "type": "CPU"
   },
   "uid": "1260733139507565"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
