{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "from ecit import *\n",
    "\n",
    "\n",
    "\n",
    "def _ecit_pvalue_pair(cit, p_ensemble, n, n_subsets, zDis, noiseDis, multi, max_retries=5):\n",
    "    \"\"\"Draw one dataset with indp='C' and one with indp='N', and return their ECIT p-values.\n",
    "\n",
    "    The whole draw-and-test step is repeated when it raises, up to `max_retries`\n",
    "    attempts in total; after the last failed attempt the exception is re-raised.\n",
    "\n",
    "    Returns:\n",
    "        (pI, pII): ECIT p-value(s) for the indp='C' dataset (used for Type I error)\n",
    "        and for the indp='N' dataset (used for power); one value per combination\n",
    "        rule when `multi` is True.\n",
    "    \"\"\"\n",
    "    for attempt in range(1, max_retries + 1):\n",
    "        try:\n",
    "            dataI = np.hstack((generate_samples(n=n, indp='C', z_dis=zDis, noise_dis=noiseDis, noise_std=1)))\n",
    "            dataII = np.hstack((generate_samples(n=n, indp='N', z_dis=zDis, noise_dis=noiseDis, noise_std=1)))\n",
    "            pI = ECIT(dataI, cit, p_ensemble, n_subsets)([0], [1], [2], multi=multi)\n",
    "            pII = ECIT(dataII, cit, p_ensemble, n_subsets)([0], [1], [2], multi=multi)\n",
    "            return pI, pII\n",
    "        except Exception:\n",
    "            print(f\"Retries times {attempt}\")\n",
    "            if attempt >= max_retries:\n",
    "                raise\n",
    "\n",
    "\n",
    "def ecitSimulateAlpha(methods,\n",
    "                      zDis_list=('gaussian',),\n",
    "                      noiseDis_list=('t', 'laplace', 'cauchy'),\n",
    "                      n_list=(2000,),\n",
    "                      t=1000,\n",
    "                      sig_level=0.05,\n",
    "                      max_retries=5,\n",
    "                      seed=None):\n",
    "    \"\"\"Monte-Carlo estimate of ECIT Type I error and power.\n",
    "\n",
    "    For each sample size, Z distribution, method and noise distribution, runs `t`\n",
    "    trials. Each trial tests column [0] vs [1] given [2] on a fresh indp='C'\n",
    "    dataset (a rejection counts towards Type I error) and on a fresh indp='N'\n",
    "    dataset (a rejection counts towards power). A test rejects iff p < sig_level.\n",
    "\n",
    "    Args:\n",
    "        methods: iterable of (cit, k, p_ensemble) triples. `cit` is the base CI test.\n",
    "            `k` is passed to ECIT unchanged when k < 100 and as int(n / k) otherwise,\n",
    "            i.e. k >= 100 is read as a subset size n_k (confirm against ECIT).\n",
    "            `p_ensemble` is one p-value combination function or a list of them; a\n",
    "            list makes ECIT run with multi=True so every rule sees the same draws.\n",
    "        zDis_list: Z distributions, passed to generate_samples as z_dis.\n",
    "        noiseDis_list: noise distributions, passed to generate_samples as noise_dis.\n",
    "        n_list: sample sizes.\n",
    "        t: number of Monte-Carlo trials per (n, zDis, method, noiseDis) cell.\n",
    "        sig_level: significance level used for both rates.\n",
    "        max_retries: attempts per trial before the last exception is re-raised.\n",
    "        seed: if not None, np.random.seed(seed) is called once before simulating.\n",
    "            This only makes runs reproducible if generate_samples and the CI test\n",
    "            draw from NumPy's global RNG (not verified here).\n",
    "\n",
    "    Returns:\n",
    "        List of rows [n, zDis.capitalize(), label, typeI_1, power_1, typeI_2, power_2, ...]\n",
    "        with one (typeI, power) pair per entry of noiseDis_list. Each rate is a float,\n",
    "        or a float array with one entry per rule when p_ensemble is a list.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if max_retries < 1.\n",
    "    \"\"\"\n",
    "    if max_retries < 1:\n",
    "        raise ValueError(\"max_retries must be >= 1\")\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    results_table = []\n",
    "\n",
    "    for n in n_list:\n",
    "        for zDis in zDis_list:\n",
    "            for cit, k, p_ensemble in methods:\n",
    "                multi = isinstance(p_ensemble, list)\n",
    "                first_rule = p_ensemble[0] if multi else p_ensemble\n",
    "                n_rules = len(p_ensemble) if multi else 1\n",
    "                # Small k is used as is; k >= 100 is converted to int(n / k) (see docstring).\n",
    "                n_subsets = k if k < 100 else int(n / k)\n",
    "                row = [n, zDis.capitalize(), cit.__name__ + str(k) + first_rule.__name__]\n",
    "\n",
    "                desc = f\"{n:>5}, {zDis:>8}, {k:>3},{first_rule.__name__:>11}\"\n",
    "                for noiseDis in tqdm(noiseDis_list, desc=desc):\n",
    "                    # One rejection counter per combination rule (a single slot otherwise).\n",
    "                    rejections_I = np.zeros(n_rules, dtype=int)\n",
    "                    rejections_II = np.zeros(n_rules, dtype=int)\n",
    "                    for _ in range(t):\n",
    "                        pI, pII = _ecit_pvalue_pair(cit, p_ensemble, n, n_subsets,\n",
    "                                                    zDis, noiseDis, multi, max_retries)\n",
    "                        # Both rates use the same rule (reject iff p < sig_level), so\n",
    "                        # p == sig_level and NaN count as non-rejections in both.\n",
    "                        rejections_I += np.atleast_1d(np.asarray(pI) < sig_level)\n",
    "                        rejections_II += np.atleast_1d(np.asarray(pII) < sig_level)\n",
    "                    typeI = rejections_I / t\n",
    "                    power = rejections_II / t\n",
    "                    if not multi:\n",
    "                        # A single rule reports plain floats rather than length-1 arrays.\n",
    "                        typeI, power = float(typeI[0]), float(power[0])\n",
    "                    row.append(typeI)\n",
    "                    row.append(power)\n",
    "                results_table.append(row)\n",
    "\n",
    "    return results_table\n",
    "\n",
    "\n",
    "\n",
    "def show_nk(table, x=(200, 285, 333, 400, 500, 666, 1000), OrgiI=0.056, OrgiP=0.9, lower=0.77, upper=1, save=False, path=\"nk.pdf\"):\n",
    "    \"\"\"Plot Type I error (left panel) and power (right panel) against the subset size n_k.\n",
    "\n",
    "    Args:\n",
    "        table: indexable by '175I', '175P', '2I', '2P', each giving one value per\n",
    "            entry of `x` (Type I error / power of the alpha = 1.75 and alpha = 2 rules).\n",
    "        x: n_k values, used both as x positions and as tick labels.\n",
    "        OrgiI, OrgiP: reference Type I error / power, drawn as dashed lines labelled\n",
    "            'Orig.' (presumably the original, non-ensembled test -- confirm).\n",
    "        lower, upper: y-limits of the power panel.\n",
    "        save: if True, also save the figure as PDF to `path`.\n",
    "        path: output file used when `save` is True (default 'nk.pdf').\n",
    "    \"\"\"\n",
    "    sns.set()\n",
    "\n",
    "    alpha = 0.95\n",
    "    markersize = 3.8\n",
    "    linewidth = 1.45\n",
    "    ref_color = \"#cf444d\"\n",
    "    # (table key prefix, legend label, marker, colour) for each plotted rule.\n",
    "    series = [\n",
    "        ('175', r'$\\alpha = 1.75$', 's', sns.color_palette(\"muted\")[9]),\n",
    "        ('2', r'$\\alpha = 2$', 'D', sns.color_palette(\"muted\")[0]),\n",
    "    ]\n",
    "\n",
    "    fig, (ax_eI, ax_eII) = plt.subplots(1, 2, figsize=(8, 2.4), dpi=500)\n",
    "\n",
    "    ax_eI.axhline(y=OrgiI, color=ref_color, linestyle='--', linewidth=0.8, label='Orig.')\n",
    "    ax_eII.axhline(y=OrgiP, color=ref_color, linestyle='--', linewidth=0.8, label='Orig.')\n",
    "    # Created before the rule curves are added, so the power legend lists only the reference line.\n",
    "    ax_eII.legend(loc='upper left', fontsize=7.8, ncol=1)\n",
    "\n",
    "    for key, label, marker, color in series:\n",
    "        style = dict(alpha=alpha, label=label, linestyle='-', marker=marker,\n",
    "                     markersize=markersize, linewidth=linewidth, color=color)\n",
    "        ax_eI.plot(x, table[key + 'I'], **style)\n",
    "        ax_eII.plot(x, table[key + 'P'], **style)\n",
    "\n",
    "    for ax in (ax_eI, ax_eII):\n",
    "        ax.set_xticks(x)\n",
    "        ax.tick_params(axis='x', labelsize=7.5, pad=-2, rotation=45)\n",
    "        ax.tick_params(axis='y', labelsize=7.5, pad=-2)\n",
    "        ax.set_xlabel(r\"$n_k$\", fontsize=10)\n",
    "    ax_eI.set_ylim(0, 0.22)\n",
    "    ax_eII.set_ylim(lower, upper)\n",
    "\n",
    "    ax_eI.axhline(y=0.05, color='black', linestyle='-', alpha=0.6, linewidth=0.5, label='Significance level')\n",
    "\n",
    "    ax_eI.set_title(\"Type I Error\", fontsize=11)\n",
    "    ax_eI.set_ylabel(\"Error Rate\", fontsize=10)\n",
    "    ax_eII.set_title(\"Power\", fontsize=11)\n",
    "    ax_eII.set_ylabel(\"Power\", fontsize=10)\n",
    "\n",
    "    ax_eI.legend(loc='upper right', fontsize=8.5, ncol=1)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    if save:\n",
    "        fig.savefig(path, format='pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
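    "# Combination Methods\n",
    "\n",
    "Type I error and power of ECIT with the `kcit` base test (`n = 2000`, `k = 400`, `t = 1000` trials) when the subset p-values are aggregated by each rule in `enplist`, under t, Laplace and Cauchy noise."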
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enplist = [\n",
    "    p_alpha2, tippett_method, edgington_method,\n",
    "    fisher_method, pearson_method, mudholkar_method,\n",
    "]\n",
    "# One method with `kcit` as the base test and k = 400; passing a list of rules\n",
    "# evaluates all of them on the same simulated draws.\n",
    "ensCIT = [(kcit, 400, enplist)]\n",
    "results = ecitSimulateAlpha(\n",
    "    ensCIT,\n",
    "    n_list=[2000],\n",
    "    noiseDis_list=['t', 'laplace', 'cauchy'],\n",
    "    t=1000,\n",
    ")\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
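    "# Subset Size\n",
    "\n",
    "Effect of the subset size $n_k$ with `n = 2000`: first a `k = 1` run with `p_alpha2`, then `p_alpha2` and `p_alpha175` for $n_k \\in \\{200, 285, 333, 400, 500, 666, 1000\\}$."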
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# k = 1 with a single rule (p_alpha2): one scalar Type I error / power per noise distribution.\n",
    "ensCIT = [(kcit, 1, p_alpha2)]\n",
    "results = ecitSimulateAlpha(\n",
    "    ensCIT,\n",
    "    n_list=[2000],\n",
    "    noiseDis_list=['t', 'laplace', 'cauchy'],\n",
    "    t=1000,\n",
    ")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enplist = [p_alpha2, p_alpha175]\n",
    "# Sweep k over subset sizes n_k; with n = 2000, each k >= 100 is passed to ECIT as int(n / k).\n",
    "ensCIT = [(kcit, n_k, enplist) for n_k in (200, 285, 333, 400, 500, 666, 1000)]\n",
    "results = ecitSimulateAlpha(\n",
    "    ensCIT,\n",
    "    n_list=[2000],\n",
    "    noiseDis_list=['t', 'laplace', 'cauchy'],\n",
    "    t=1000,\n",
    ")\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "share": {
   "datetime": "2025-01-16T05:48:24.740Z",
   "image": {
    "name": "modelscope:1.18.0-pytorch2.3.0-cpu-py310-ubuntu22.04",
    "url": "dsw-registry-vpc.cn-hangzhou.cr.aliyuncs.com/pai/modelscope:1.18.0-pytorch2.3.0-cpu-py310-ubuntu22.04"
   },
   "instance": "dsw-03a689ba3735b16d",
   "spec": {
    "id": "ecs.g6.xlarge",
    "type": "CPU"
   },
   "uid": "1260733139507565"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
