{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from causallearn.utils.cit import CIT\n",
    "import numpy as np\n",
    "\n",
    "def _binary_all(X_true, verbose=True):\n",
    "    \"\"\"Binarize every column of a 2-D array at its median.\n",
    "\n",
    "    Each column is mapped to 1 where a value is strictly greater than the\n",
    "    column median and to 0 otherwise (values equal to the median become 0).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_true : np.ndarray\n",
    "        2-D array, one variable per column; it is not modified.\n",
    "    verbose : bool, default True\n",
    "        If True, print the median threshold chosen for each column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Copy of ``X_true`` (same shape and dtype) containing only 0/1 values.\n",
    "    \"\"\"\n",
    "    X = X_true.copy()\n",
    "    thresholds = np.median(X, axis=0)\n",
    "    if verbose:\n",
    "        for i, c in enumerate(thresholds):\n",
    "            print(f\"Random variable: {i}, threshold: {c}\")\n",
    "    # Assign into the copy instead of returning np.where(...) directly so the\n",
    "    # result keeps X_true's dtype, exactly as the column-wise assignment did.\n",
    "    X[:, :] = np.where(X > thresholds, 1, 0)\n",
    "    return X\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic ground truth: x2 and x3 each depend on x1 plus independent noise,\n",
    "# so x2 and x3 are conditionally independent given x1 in the continuous data.\n",
    "# NOTE(review): after median-binarizing x1 this independence need not hold for\n",
    "# an ordinary discrete CI test; presumably that is what `dis_test` addresses --\n",
    "# confirm against the causal-learn documentation.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)  # fixed seed so the p-value is reproducible\n",
    "\n",
    "test = 'dis_test'\n",
    "samples = 1000\n",
    "\n",
    "x1 = rng.normal(5, 1, samples)\n",
    "x2 = rng.normal(-2, 1, samples) + 2 * x1\n",
    "x3 = x1 + rng.normal(0, 1, samples)\n",
    "\n",
    "X_true = np.array([x1, x2, x3]).T  # one row per sample; columns x1, x2, x3\n",
    "data_bin = _binary_all(X_true)\n",
    "\n",
    "dist_test_obj = CIT(data=data_bin, method=test)\n",
    "# Column 1 (x2) vs column 2 (x3), conditioning set = [column 0] (x1).\n",
    "p_value = dist_test_obj(1, 2, [0])\n",
    "p_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causal_learn_pycharm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
