{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3c2f71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import sys, os\n",
    "sys.path.append(\"../..\")\n",
    "from ecit import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bb67f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_relation_from_txt(file_path, label=\"CI\"):\n",
    "    \"\"\"Parse conditional-(in)dependence rules from a text file.\n",
    "\n",
    "    Only lines containing ';' are considered. An optional leading\n",
    "    enumeration such as '12. ' is stripped. Two layouts are recognised:\n",
    "\n",
    "    * ``X; Y (Z1, Z2, ...)`` -- conditioning set in parentheses (only the\n",
    "      first parenthesised group is used; all groups are removed);\n",
    "    * ``X; Y; Z1, Z2, ...``  -- conditioning set as a third field, which\n",
    "      takes precedence over any parenthesised group.\n",
    "\n",
    "    Lines with fewer than two variables, two-field lines without a\n",
    "    non-empty parenthesised conditioning set, and lines with more than\n",
    "    three fields are skipped.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    file_path : str or path-like\n",
    "        Text file to read.\n",
    "    label : str\n",
    "        Tag copied onto every parsed rule (this notebook uses 'CI' / 'NI').\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of dict\n",
    "        One dict per rule with keys 'X', 'Y' (str), 'Z' (list of str), 'label'.\n",
    "    \"\"\"\n",
    "    rules = []\n",
    "\n",
    "    def normalize(var):\n",
    "        return var.strip()\n",
    "\n",
    "    def split_vars(text):\n",
    "        # Comma-separated names; empty entries (e.g. from '()' or a trailing\n",
    "        # comma) are dropped so they never become bogus column lookups.\n",
    "        return [normalize(v) for v in text.split(',') if v.strip()]\n",
    "\n",
    "    with open(file_path, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if ';' not in line:\n",
    "                continue\n",
    "\n",
    "            # Strip an optional leading enumeration such as '12. '.\n",
    "            line = re.sub(r'^\\d+\\.\\s*', '', line)\n",
    "\n",
    "            Z_vars = []\n",
    "            Z_match = re.search(r'\\((.*?)\\)', line)\n",
    "            if Z_match:\n",
    "                Z_vars = split_vars(Z_match.group(1))\n",
    "                line = re.sub(r'\\(.*?\\)', '', line)\n",
    "\n",
    "            parts = [normalize(p) for p in line.split(';') if p.strip()]\n",
    "            if len(parts) == 2 and Z_vars:\n",
    "                X, Y = parts\n",
    "            elif len(parts) == 3:\n",
    "                X, Y = parts[0], parts[1]\n",
    "                # Split like the parenthesised form so 'C, D' yields two names.\n",
    "                Z_vars = split_vars(parts[2])\n",
    "            else:\n",
    "                # Fewer than two variables (indexing them would raise\n",
    "                # IndexError), an unconditional pair, or an unknown layout.\n",
    "                continue\n",
    "\n",
    "            rules.append({\n",
    "                \"X\": X,\n",
    "                \"Y\": Y,\n",
    "                \"Z\": Z_vars,\n",
    "                \"label\": label\n",
    "            })\n",
    "\n",
    "    return rules\n",
    "\n",
    "# Labelled ground truth: 'CI' rules are scored as independent, 'NI' rules as dependent.\n",
    "ci_r = parse_relation_from_txt(\"CI.txt\", label=\"CI\")\n",
    "ni_r = parse_relation_from_txt(\"NI.txt\", label=\"NI\")\n",
    "all_r = [*ci_r, *ni_r]  # CI rules first, then NI rules\n",
    "all_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4df7f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack both Excel files row-wise (concat's default axis) and renumber rows 0..n-1.\n",
    "df1 = pd.read_excel(\"Data Files/1. cd3cd28.xls\")\n",
    "df2 = pd.read_excel(\"Data Files/2. cd3cd28icam2.xls\")\n",
    "df = pd.concat((df1, df2), ignore_index=True)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "254a36a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def simu(cit, p_ensemble_list, k, alpha=0.05, show=True, rules=None, data=None):\n",
    "    \"\"\"Score a CI test (wrapped in ``ECIT``) against the labelled rules.\n",
    "\n",
    "    The positive class is conditional independence: a rule is predicted\n",
    "    independent when its p-value exceeds ``alpha``. With that convention\n",
    "\n",
    "    * TP: label 'CI' and p >  alpha   (independence retained)\n",
    "    * FN: label 'CI' and p <= alpha   (independence wrongly rejected)\n",
    "    * TN: other label and p <= alpha  (dependence detected)\n",
    "    * FP: other label and p >  alpha  (dependence missed)\n",
    "\n",
    "    so precision = TP / (TP + FP) and recall = TP / (TP + FN) follow the\n",
    "    standard definitions. (Earlier revisions swapped FP and FN, which also\n",
    "    swapped the reported precision and recall; F1 was unaffected.)\n",
    "\n",
    "    Rules on which ``ECIT`` raises twice in a row are skipped.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cit : callable\n",
    "        Base CI test passed to ``ECIT``.\n",
    "    p_ensemble_list : list\n",
    "        Passed to ``ECIT``; ``ECIT`` is expected to return one p-value per\n",
    "        entry -- TODO confirm.\n",
    "    k : int\n",
    "        Forwarded unchanged to ``ECIT``.\n",
    "    alpha : float\n",
    "        Significance level.\n",
    "    show : bool\n",
    "        Show the progress bar, retry messages and the final table.\n",
    "    rules : list of dict, optional\n",
    "        Rules with keys 'X', 'Y', 'Z', 'label'; defaults to the global ``all_r``.\n",
    "    data : pandas.DataFrame, optional\n",
    "        Table whose columns the rules name; defaults to the global ``df``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of shape (len(p_ensemble_list), 7)\n",
    "        Columns TP, TN, FP, FN, precision, recall, F1; a ratio is NaN when\n",
    "        its denominator is zero.\n",
    "    \"\"\"\n",
    "    # Fall back to the notebook-level globals for backward compatibility.\n",
    "    rules = all_r if rules is None else rules\n",
    "    data = df if data is None else data\n",
    "\n",
    "    n_methods = len(p_ensemble_list)\n",
    "    TP = np.zeros(n_methods, dtype=int)\n",
    "    TN = np.zeros(n_methods, dtype=int)\n",
    "    FP = np.zeros(n_methods, dtype=int)\n",
    "    FN = np.zeros(n_methods, dtype=int)\n",
    "\n",
    "    for rule in tqdm(rules, disable=not show):\n",
    "        X = data[[rule[\"X\"]]].to_numpy()\n",
    "        Y = data[[rule[\"Y\"]]].to_numpy()\n",
    "        Z = data[rule[\"Z\"]].to_numpy()\n",
    "        z_idx = list(range(2, Z.shape[1] + 2))  # Z columns follow X (0) and Y (1)\n",
    "        obj_ECIT = ECIT(np.hstack((X, Y, Z)), cit, p_ensemble_list, k)\n",
    "        try:\n",
    "            ps = obj_ECIT([0], [1], z_idx)\n",
    "        except Exception as e:\n",
    "            # Retry once in case the failure is transient.\n",
    "            if show:\n",
    "                print(f\"First attempt failed on rule {rule}, retrying... Error: {e}\")\n",
    "            try:\n",
    "                ps = obj_ECIT([0], [1], z_idx)\n",
    "            except Exception as e:\n",
    "                if show:\n",
    "                    print(f\"Second attempt failed, skipping rule. Error: {e}\")\n",
    "                continue\n",
    "        ps = np.asarray(ps)\n",
    "        # Two explicit comparisons (not ~) so a NaN p-value counts as neither.\n",
    "        kept = ps > alpha\n",
    "        rejected = ps <= alpha\n",
    "        if rule[\"label\"] == \"CI\":\n",
    "            TP += kept\n",
    "            FN += rejected\n",
    "        else:\n",
    "            TN += rejected\n",
    "            FP += kept\n",
    "\n",
    "    with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n",
    "        pre = TP / (TP + FP)\n",
    "        rec = TP / (TP + FN)\n",
    "        f1 = 2 * pre * rec / (pre + rec)\n",
    "\n",
    "    results = np.array([TP, TN, FP, FN, pre, rec, f1]).T\n",
    "    if show:\n",
    "        print(results)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f64abd01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _n_repeats(cit_name, k, t, rcit_reps=100):\n",
    "    \"\"\"Return how many times ``simu`` is repeated for one (test, k) setting.\n",
    "\n",
    "    ``rcit`` always uses ``rcit_reps``; other tests run once when ``k == 1``\n",
    "    (except ``lpcit``) and ``t`` times otherwise.\n",
    "    \"\"\"\n",
    "    if cit_name == \"rcit\":\n",
    "        return rcit_reps\n",
    "    # NOTE(review): presumably a k == 1 run is deterministic here, so one\n",
    "    # repetition suffices; confirm whether other tests (e.g. ccit, cmiknn)\n",
    "    # are deterministic, since they also run only once in that case.\n",
    "    if k == 1 and cit_name != \"lpcit\":\n",
    "        return 1\n",
    "    return t\n",
    "\n",
    "\n",
    "def run_simu(cit_list, ens_list, t=10, alpha=0.05, rcit_reps=100):\n",
    "    \"\"\"Average ``simu`` metrics over repeated runs for every test and setting.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cit_list : list of callable\n",
    "        Base CI tests; results are keyed by each test's ``__name__``.\n",
    "    ens_list : list of (k, p_ensemble_list) tuples\n",
    "        Settings forwarded to ``simu``.\n",
    "    t : int\n",
    "        Default number of repetitions (see ``_n_repeats`` for exceptions).\n",
    "    alpha : float\n",
    "        Significance level forwarded to ``simu``.\n",
    "    rcit_reps : int\n",
    "        Repetitions used for ``rcit`` regardless of ``k``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{test name: rows}`` where each row is the mean over repetitions of\n",
    "        [TP, TN, FP, FN, precision, recall, F1]; rows follow ``ens_list``\n",
    "        order, then the order of that entry's p_ensemble_list.\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "    for cit in cit_list:\n",
    "        name = cit.__name__\n",
    "        table = []\n",
    "        for k, p_ensemble in ens_list:\n",
    "            reps = _n_repeats(name, k, t, rcit_reps)\n",
    "            total = np.zeros((len(p_ensemble), 7))\n",
    "            for _ in tqdm(range(reps), desc=name + str(k)):\n",
    "                total += simu(cit, p_ensemble, k, alpha, show=False)\n",
    "            table.extend(list(row) for row in total / reps)\n",
    "        results[name] = table\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51bf1fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "warnings.simplefilter(\"ignore\", category=ConvergenceWarning)\n",
    "\n",
    "# Seed NumPy's global RNG; reproducibility assumes the tests draw from it.\n",
    "np.random.seed(1)\n",
    "\n",
    "# Base CI tests to benchmark and the (k, p_ensemble_list) settings for run_simu.\n",
    "cit_list = [rcit, kcit, lpcit, cmiknn, ccit, fisherz]\n",
    "ens_list = [\n",
    "    (1, [p_alpha2]),\n",
    "    (5, [p_alpha175, p_alpha2]),\n",
    "]\n",
    "results = run_simu(cit_list, ens_list, t=10, alpha=0.05)\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
