{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# LiM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Import and settings\n",
    "In this example, we need to import `numpy` and `random`, in addition to `lingam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:32.471816Z",
     "start_time": "2021-06-25T07:00:26.396417Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1.20.3', '1.6.0']\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import lingam\n",
    "import lingam.utils as ut\n",
    "\n",
    "print([np.__version__, lingam.__version__])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Test data\n",
    "First, we generate a causal structure with two variables, one of which is randomly chosen to be discrete. The indicator `dis_con` marks each variable as continuous (1) or discrete (0)."
   ]
  },
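  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "To make the idea of mixed data concrete, here is a minimal hand-written sketch of a two-variable structure in which a continuous cause drives a binary effect. It is only an illustration under stated assumptions: Laplace noise for the continuous variable, a logistic link for the binary one, and the hypothetical helper name `simulate_toy_mixed_pair`. It is not how `lingam.utils` generates data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def simulate_toy_mixed_pair(n_samples, weight, seed=0):\n",
    "    # Hypothetical helper (not part of lingam): continuous x0 causes binary x1.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Continuous cause with non-Gaussian (Laplace) noise.\n",
    "    x0 = rng.laplace(loc=0.0, scale=1.0, size=n_samples)\n",
    "    # Binary effect: Bernoulli draw whose success probability is a logistic function of the parent.\n",
    "    prob = 1.0 / (1.0 + np.exp(-weight * x0))\n",
    "    x1 = rng.binomial(1, prob)\n",
    "    X = np.column_stack([x0, x1])\n",
    "    # Indicator row in this notebook's convention: 1 = continuous, 0 = discrete.\n",
    "    dis_con = np.array([[1.0, 0.0]])\n",
    "    return X, dis_con\n",
    "\n",
    "\n",
    "X_toy, dis_con_toy = simulate_toy_mixed_pair(1000, weight=1.0)\n",
    "print(X_toy.shape, np.unique(X_toy[:, 1]))\n",
    "```\n",
    "\n",
    "The test data used in the rest of this notebook comes from `ut.simulate_linear_mixed_sem`; the toy function only shows the two inputs involved, a data matrix of shape `(n_samples, n_features)` and an indicator of shape `(1, n_features)`."
   ]
  },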
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:32.548612Z",
     "start_time": "2021-06-25T07:00:32.474811Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 1 discrete variable(s).\n",
      "The true adjacency matrix is:\n",
      " [[0.         0.        ]\n",
      " [1.07516228 0.        ]]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "n_samples, n_features, n_edges, graph_type, sem_type = 1000, 2, 1, 'ER', 'mixed_random_i_dis'\n",
    "B_true = ut.simulate_dag(n_features, n_edges, graph_type)\n",
    "W_true = ut.simulate_parameter(B_true)  # row to column\n",
    "\n",
    "# Number of discrete variables; randint's upper bound is exclusive, so this is 1 when n_features == 2.\n",
    "no_dis = np.random.randint(1, n_features)\n",
    "print('There are %d discrete variable(s).' % no_dis)\n",
    "nodes = list(range(n_features))\n",
    "# Randomly select no_dis discrete variables (Python's `random` module is not seeded here).\n",
    "dis_var = random.sample(nodes, no_dis)\n",
    "# Indicator row: 1 = continuous, 0 = discrete.\n",
    "dis_con = np.full((1, n_features), np.inf)\n",
    "for iii in range(n_features):\n",
    "    if iii in dis_var:\n",
    "        dis_con[0, iii] = 0\n",
    "    else:\n",
    "        dis_con[0, iii] = 1\n",
    "\n",
    "X = ut.simulate_linear_mixed_sem(W_true, n_samples, sem_type, dis_con)\n",
    "\n",
    "print('The true adjacency matrix is:\\n', W_true)"
   ]
  },
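  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The loop that fills `dis_con` can also be written with NumPy indexing. The sketch below is an equivalent construction under the same convention (1 = continuous, 0 = discrete); the fixed `dis_var = [1]` and the name `dis_con_alt` are hypothetical stand-ins so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n_features = 2\n",
    "dis_var = [1]  # hypothetical choice: treat variable 1 as discrete\n",
    "\n",
    "# Start from all ones (continuous) and zero out the discrete columns.\n",
    "dis_con_alt = np.ones((1, n_features))\n",
    "dis_con_alt[0, dis_var] = 0\n",
    "\n",
    "print(dis_con_alt)\n",
    "```\n",
    "\n",
    "Either form yields a `(1, n_features)` array, which is the shape passed to `ut.simulate_linear_mixed_sem` above."
   ]
  },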
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Causal Discovery for Linear Mixed Data\n",
    "To run causal discovery, we create a `LiM` object and call its `fit` method with the data `X` and the indicator `dis_con`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:35.346987Z",
     "start_time": "2021-06-25T07:00:33.763696Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- One iteration passed..... 0.13316067647747082\n",
      "------- rho  is: 1.0\n",
      "--- One iteration passed..... 0.1331815102989981\n",
      "--- One iteration passed..... 0.13327526249587057\n",
      "--- One iteration passed..... 0.13421278446459756\n",
      "--- One iteration passed..... 0.07364525286249445\n",
      "------- rho  is: 1000.0\n",
      "--- One iteration passed..... 0.07372097475667573\n",
      "--- One iteration passed..... 0.07406172328049165\n",
      "--- One iteration passed..... 0.07714547946102278\n",
      "--- One iteration passed..... 0.08194416313287393\n",
      "------- rho  is: 1000000.0\n",
      "--- One iteration passed..... 0.08215911651048043\n",
      "--- One iteration passed..... 0.08097828664201549\n",
      "------- rho  is: 10000000.0\n"
     ]
    },
    {
     "data": {
      "text/plain": "<lingam.lim.LiM at 0x179b92824c0>"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = lingam.LiM()\n",
    "model.fit(X, dis_con)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Using the `adjacency_matrix_` property, we can see the estimated adjacency matrix between the mixed variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:35.378981Z",
     "start_time": "2021-06-25T07:00:35.367012Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The estimated adjacency matrix is:\n",
      " [[0.         0.        ]\n",
      " [1.24859572 0.        ]]\n"
     ]
    }
   ],
   "source": [
    "print('The estimated adjacency matrix is:\\n', model.adjacency_matrix_)"
   ]
  },
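  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "One simple way to check the result is to compare the edge structure (the pattern of nonzero entries) and the weight error of the two matrices. The sketch below hard-codes the values printed in this notebook so that it runs on its own; the names `W_est`, `same_structure`, and `max_abs_error` are illustrative, and in practice the fitted model's matrix would replace the hard-coded estimate.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Values copied from the printed outputs in this notebook.\n",
    "W_true = np.array([[0.0, 0.0], [1.07516228, 0.0]])\n",
    "W_est = np.array([[0.0, 0.0], [1.24859572, 0.0]])\n",
    "\n",
    "# Compare the edge sets (nonzero patterns) rather than the raw weights.\n",
    "same_structure = np.array_equal(W_true != 0, W_est != 0)\n",
    "# Largest absolute difference between true and estimated weights.\n",
    "max_abs_error = np.max(np.abs(W_true - W_est))\n",
    "\n",
    "print('Same edge structure:', same_structure)\n",
    "print('Largest absolute weight error: %.4f' % max_abs_error)\n",
    "```\n",
    "\n",
    "For these values the single edge is recovered in the correct position, and the estimated weight differs from the true one by about 0.17."
   ]
  },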
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}