{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import logging\n",
    "\n",
    "from src.examples.util import demo_clustering, demo_causal_discovery\n",
    "from src.exp.algos import CD\n",
    "from src.exp.gen.generate import NoiseType, IvType, FunType, DagType, IvMode, GSType, gen_data_type\n",
    "from src.mixtures.mixing.mixing import MixingType"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give the root logger a default stream handler (a no-op if it already has handlers).\n",
    "logging.basicConfig()\n",
    "\n",
    "# Notebook-local logger; records below INFO are filtered out by its level.\n",
    "lg = logging.getLogger(name=\"EXAMPLE\")\n",
    "lg.setLevel(\"INFO\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "### Synthetic data generation and causal graph discovery as in Fig. 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random seed passed to `gen_data_type`; change it here to generate a different dataset.\n",
    "SEED = 42\n",
    "\n",
    "# Data parameters, consumed by `gen_data_type`\n",
    "# (see src.exp.gen.generate for the meaning of each key).\n",
    "params = { 'N': 10, 'S': 1000, 'P': 0.3, 'K': 2, 'C': 5,  'PZ': 0.4, 'NZ': 2,\n",
    "           'NS': NoiseType.GAUSS, 'F': FunType.LIN, 'DG': DagType.ERDOS,\n",
    "           'IVT': IvType.SHIFT, 'IVM': IvMode.MIXING, 'IMAX': 3, 'GS': GSType.GRAPH}\n",
    "data, truths = gen_data_type(params, seed=SEED)\n",
    "\n",
    "print(f\"\\tMixed Nodes: {truths['t_n_Z']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to `demo_causal_discovery` as `KMAX`.\n",
    "# NOTE(review): presumably the maximum number of mixture components considered - confirm in src.examples.util.\n",
    "KMAX = 5\n",
    "res = demo_causal_discovery(data, truths, params, causaldiscovery_method_ty=CD.CausalMixtures, KMAX=KMAX, vb=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
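    "### Baselines\n",
    "\n",
    "Run the remaining causal discovery methods in `CD` on the same data for comparison (`CausalMixtures`, `SKIP` and `MixtureUTIGSP` are excluded)."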
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods excluded from the baseline sweep: CausalMixtures (run above), SKIP and MixtureUTIGSP.\n",
    "# NOTE(review): membership is checked by `.value` rather than identity, presumably so it keeps\n",
    "# working if `%autoreload` re-creates the `CD` enum - keep it that way.\n",
    "EXCLUDED_CD_VALUES = (CD.CausalMixtures.value, CD.SKIP.value, CD.MixtureUTIGSP.value)\n",
    "\n",
    "# Keep every baseline's result (keyed by method) instead of assigning it to `_`:\n",
    "# that discarded the results and stopped IPython from updating its `_` output history.\n",
    "baseline_results = {}\n",
    "for cd_algo in CD:\n",
    "    if cd_algo.value in EXCLUDED_CD_VALUES:\n",
    "        continue\n",
    "    baseline_results[cd_algo] = demo_causal_discovery(data, truths, params, causaldiscovery_method_ty=cd_algo)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
