{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import logging\n",
    "from src.exp.gen.generate import NoiseType, IvType, FunType, DagType, IvMode, gen_data_type\n",
    "from src.mixtures.mixing.mixing import MixingType\n",
    "from src.examples.util import demo_clustering\n",
    "from src.exp.algos import CD\n",
    "from src.exp.gen.generate import GSType"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "### Fig. 1: Mixing for a causal relationship X -> Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 1 data: the settings below are handed to gen_data_type together with the seed.\n",
    "seed = 42\n",
    "params = {\n",
    "    'N': 2,\n",
    "    'S': 1000,\n",
    "    'P': 1,\n",
    "    'K': 2,\n",
    "    'C': 5,\n",
    "    'PZ': 0.5,\n",
    "    'NZ': 2,\n",
    "    'NS': NoiseType.GAUSS,\n",
    "    'F': FunType.LIN,\n",
    "    'DG': DagType.ERDOS,\n",
    "    'IVT': IvType.FLIP,\n",
    "    'IVM': IvMode.MIXING,\n",
    "    'GS': GSType.BIV_CAUSAL_CHANGEY,\n",
    "}\n",
    "\n",
    "# gen_data_type returns the samples plus a `truths` mapping, indexed below by 't_n_Z' and '_dg'.\n",
    "data, truths = gen_data_type(params, seed)\n",
    "print(f\"\\tMixed Nodes: {truths['t_n_Z']}\")\n",
    "\n",
    "# Plot the generated samples through the object stored under '_dg'.\n",
    "truths[\"_dg\"].plot_X(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discover the mixture on the Fig. 1 data; only the first returned value (the model) is kept.\n",
    "ours, _ = demo_clustering(\n",
    "    data,\n",
    "    truths,\n",
    "    params,\n",
    "    mixing_ty=MixingType.MIX_LIN,\n",
    "    causaldiscovery_method_ty=CD.SKIP,\n",
    "    ORACLE_G=True,\n",
    "    ORACLE_K=False,\n",
    "    KMAX=5,\n",
    "    ret_model=True,\n",
    ")\n",
    "print(f\"\\tDiscovered Nodes: {ours.e_n_Z}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from src.mixtures.util.utils_idl import pi_join\n",
    "\n",
    "\n",
    "def min_max_scale(values):\n",
    "    \"\"\"Linearly rescale a pandas Series to the [0, 1] range.\n",
    "\n",
    "    A column with zero range is returned as all zeros rather than the\n",
    "    NaNs a 0/0 division would give.\n",
    "    \"\"\"\n",
    "    lowest = values.min()\n",
    "    span = values.max() - lowest\n",
    "    if span == 0:\n",
    "        return values - lowest\n",
    "    return (values - lowest) / span\n",
    "\n",
    "\n",
    "# Node to export and its predecessors in the graph stored on truths['_dg'].\n",
    "# NOTE(review): the Fig. 1 params set 'N': 2 -- confirm that node index 2 exists in the generated graph.\n",
    "i = 2\n",
    "pa_i = list(truths[\"_dg\"].G.predecessors(i))\n",
    "\n",
    "# Start from an all-zero label vector and pi_join it with Zs[zi] for every\n",
    "# set in conf_ind_sets that contains node i.\n",
    "true_labels = np.zeros(data.shape[0])\n",
    "for zi, node_set in enumerate(truths[\"_dg\"].conf_ind_sets):\n",
    "    if i in node_set:\n",
    "        true_labels = pi_join(true_labels, truths[\"_dg\"].Zs[zi])\n",
    "\n",
    "# NOTE(review): the output file names do not depend on the parent, so with\n",
    "# several parents only the last one survives on disk -- confirm this is intended.\n",
    "for pa in pa_i:\n",
    "    df = pd.DataFrame({\n",
    "        'x': data[:, pa],\n",
    "        'y': data[:, i],\n",
    "        'c': true_labels,\n",
    "    })\n",
    "    df.to_csv('illustration_1.tsv', sep='\\t', index=False)\n",
    "\n",
    "    # Normalise the frame that was just written. Previously the data was re-read\n",
    "    # from '../../results_paper/illustration_1.tsv', which is not the file written\n",
    "    # above, so the normalised output could come from stale data or the read could fail.\n",
    "    df['x'] = min_max_scale(df['x'])\n",
    "    df['y'] = min_max_scale(df['y'])\n",
    "    df.to_csv('illustration_1_nm.tsv', sep='\\t', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the data again, this time passing the fitted model's e_Z_n attribute to plot_X_idls.\n",
    "truths[\"_dg\"].plot_X_idls(\n",
    "    data,\n",
    "    ours.e_Z_n,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "### Synthetic data generation and CMM fitting as in Fig. 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic data parameters.\n",
    "# NOTE: this cell rebinds params, data, truths and ours from the Fig. 1 cells above;\n",
    "# rerun those cells before going back to the Fig. 1 plots or exports.\n",
    "params = {\n",
    "    'N': 10,\n",
    "    'S': 1000,\n",
    "    'P': 1,\n",
    "    'K': 2,\n",
    "    'C': 5,\n",
    "    'PZ': 0.5,\n",
    "    'NZ': 2,\n",
    "    'NS': NoiseType.GAUSS,\n",
    "    'F': FunType.LIN,\n",
    "    'DG': DagType.ERDOS,\n",
    "    'IVT': IvType.FLIP,\n",
    "    'IVM': IvMode.MIXING,\n",
    "    'GS': GSType.GRAPH,\n",
    "}\n",
    "\n",
    "data, truths = gen_data_type(params, 42)\n",
    "print(f\"\\tMixed Nodes: {truths['t_n_Z']}\")\n",
    "\n",
    "# Discover the mixture; KMAX is forwarded to demo_clustering unchanged.\n",
    "KMAX = 5\n",
    "ours, _ = demo_clustering(\n",
    "    data,\n",
    "    truths,\n",
    "    params,\n",
    "    mixing_ty=MixingType.MIX_LIN,\n",
    "    causaldiscovery_method_ty=CD.SKIP,\n",
    "    ORACLE_G=True,\n",
    "    ORACLE_K=False,\n",
    "    KMAX=KMAX,\n",
    "    ret_model=True,\n",
    ")\n",
    "print(f\"\\tDiscovered Nodes: {ours.e_n_Z}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call visu_pproba_dens once for every node of the fitted model's topic graph.\n",
    "for node in ours.topic_graph.nodes:\n",
    "    ours.visu_pproba_dens(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node-pair heat matrix of the fitted model (presumably mutual information, going by\n",
    "# the method name -- confirm), with the hide_singleclus option switched on.\n",
    "ours.visu_heatmatrix_nodepair_MI(\n",
    "    hide_singleclus=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
