{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-11T02:56:04.565098Z",
     "start_time": "2019-12-11T02:56:04.542160Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "from scipy.stats import norm, uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-11T02:56:05.512015Z",
     "start_time": "2019-12-11T02:56:05.493067Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def __randnetbalanced(dims, samples, indegree, parminmax, errminmax, random_state):\n",
    "    \"\"\"\n",
    "    create a more balanced random network\n",
    "\n",
    "    Parameter\n",
    "    ---------\n",
    "    dims : int\n",
    "        number of variables\n",
    "    samples : int\n",
    "        number of samples\n",
    "    indegree : int or float('inf')\n",
    "        number of parents of each node (float('inf') = fully connected)\n",
    "    parminmax : dictionary\n",
    "        standard deviation owing to parents \n",
    "    errminmax : dictionary\n",
    "        standard deviation owing to error variable\n",
    "    random_state : np.random.RandomState\n",
    "        random state\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    B : array, shape (dims, dims)\n",
    "        the strictly lower triangular network matrix\n",
    "    errstd : array, shape (dims, 1)\n",
    "        the vector of error (disturbance) standard deviations\n",
    "    \"\"\"\n",
    "\n",
    "    # First, generate errstd\n",
    "    errstd = uniform.rvs(loc=errminmax['min'], scale=errminmax['max'], size=[\n",
    "                         dims, 1], random_state=random_state)\n",
    "\n",
    "    # Initializations\n",
    "    X = np.empty(shape=[dims, samples])\n",
    "    B = np.zeros([dims, dims])\n",
    "\n",
    "    # Go trough each node in turn\n",
    "    for i in range(dims):\n",
    "\n",
    "        # If indegree is finite, randomly pick that many parents,\n",
    "        # else, all previous variables are parents\n",
    "        if indegree == float('inf'):\n",
    "            if i <= indegree:\n",
    "                par = np.arange(i)\n",
    "            else:\n",
    "                par = random_state.permutation(i)[:indegree]\n",
    "        else:\n",
    "            par = np.arange(i)\n",
    "\n",
    "        if len(par) == 0:\n",
    "            # if node has no parents\n",
    "            # Increase errstd to get it to roughly same variance\n",
    "            parent_std = uniform.rvs(\n",
    "                loc=parminmax['min'], scale=parminmax['max'], random_state=random_state)\n",
    "            errstd[i] = np.sqrt(errstd[i]**2 + parent_std**2)\n",
    "\n",
    "            # Set data matrix to empty\n",
    "            X[i] = np.zeros(samples)\n",
    "        else:\n",
    "            # If node has parents, do the following\n",
    "            w = norm.rvs(size=[1, len(par)], random_state=random_state)\n",
    "\n",
    "            # Randomly pick weights\n",
    "            wfull = np.zeros([1, i])\n",
    "            wfull[0, par] = w\n",
    "\n",
    "            # Calculate contribution of parents\n",
    "            X[i] = np.dot(wfull, X[:i, :])\n",
    "\n",
    "            # Randomly select a 'parents std'\n",
    "            parstd = uniform.rvs(\n",
    "                loc=parminmax['min'], scale=parminmax['max'], random_state=random_state)\n",
    "\n",
    "            # Scale w so that the combination of parents has 'parstd' std\n",
    "            scaling = parstd / np.sqrt(np.mean(X[i] ** 2))\n",
    "            w = w * scaling\n",
    "\n",
    "            # Recalculate contribution of parents\n",
    "            wfull = np.zeros([1, i])\n",
    "            wfull[0, par] = w\n",
    "            X[i] = np.dot(wfull, X[:i, :])\n",
    "\n",
    "            # Fill in B\n",
    "            B[i, par] = w\n",
    "\n",
    "        # Update data matrix\n",
    "        X[i] = X[i] + norm.rvs(size=samples,\n",
    "                               random_state=random_state) * errstd[i]\n",
    "\n",
    "    return B, errstd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-11T02:56:05.712307Z",
     "start_time": "2019-12-11T02:56:05.695355Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def generate_data(dims=6, samples=10000, seed=None):\n",
    "    \"\"\"\n",
    "    Generate data from a randomly drawn linear non-Gaussian acyclic model.\n",
    "\n",
    "    Parameter\n",
    "    ---------\n",
    "    dims : int\n",
    "        number of variables\n",
    "    samples : int\n",
    "        number of samples\n",
    "    seed : int\n",
    "        seed for random_state. Default is None.\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    X : array, shape (samples, dims)\n",
    "        the observed data, with the variables in a random order\n",
    "    Bp : array, shape (dims, dims)\n",
    "        the generating network matrix, permuted to match the columns of X\n",
    "    causal_order : array of int, shape (dims,)\n",
    "        causal_order[k] is the column of X that holds the k-th variable\n",
    "        of the generating order\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] S. Shimizu, P. O. Hoyer, A. Hyvärinen, and A. J. Kerminen.\n",
    "       A linear non-gaussian acyclic model for causal discovery.\n",
    "       Journal of Machine Learning Research, 7:2003-2030, 2006.\n",
    "    .. [2] http://www.cs.helsinki.fi/group/neuroinf/lingam/lingam.tar.gz\n",
    "       ./lingam-1.4.2/code/randnetbalanced.m\n",
    "       ./lingam-1.4.2/code/jmlr/plotsjmlr.m\n",
    "    \"\"\"\n",
    "\n",
    "    random_state = np.random.RandomState(seed=seed)\n",
    "\n",
    "    # Randomly select sparse/full connections; a sparse network gets an\n",
    "    # indegree drawn from {1, 2, 3}\n",
    "    sparse = np.floor(uniform.rvs(random_state=random_state) * 2).astype(int)\n",
    "    indegree = (np.floor(uniform.rvs(random_state=random_state) *\n",
    "                         3).astype(int) + 1) if sparse == 1 else float('inf')\n",
    "\n",
    "    # Create the network\n",
    "    B, disturbancestd = __randnetbalanced(dims, samples, indegree, {\n",
    "                                          'min': 0.5, 'max': 1.5}, {'min': 0.5, 'max': 1.5}, random_state)\n",
    "\n",
    "    # constants, giving non-zero means\n",
    "    c = 2 * norm.rvs(size=[dims, 1], random_state=random_state)\n",
    "\n",
    "    # nonlinearity exponent, in [0.5, 0.8] or [1.2, 2.0].\n",
    "    q = uniform.rvs(size=[dims, 1], random_state=random_state) * 1.1 + 0.5\n",
    "    index = q > 0.8\n",
    "    q[index] = q[index] + 0.4\n",
    "\n",
    "    # This generates the disturbance variables, which are mutually\n",
    "    # independent, and non-gaussian (q broadcasts across the samples)\n",
    "    S = norm.rvs(size=[dims, samples], random_state=random_state)\n",
    "    S = np.sign(S) * (np.abs(S) ** q)\n",
    "\n",
    "    # This normalizes the disturbance variables to have the\n",
    "    # appropriate scales. The root mean square is taken per variable\n",
    "    # (axis=1); a mean over the whole matrix would leave each row with\n",
    "    # a scale that depends on its exponent q instead of disturbancestd.\n",
    "    rms = np.sqrt(np.mean(S ** 2, axis=1, keepdims=True))\n",
    "    S = S / (rms / disturbancestd)\n",
    "\n",
    "    # Now we generate the data one component at a time. B is strictly\n",
    "    # lower triangular, so row i only reads rows that are already filled.\n",
    "    Xorig = np.zeros([dims, samples])\n",
    "    for i in range(dims):\n",
    "        Xorig[i] = np.dot(B[i], Xorig) + S[i] + c[i]\n",
    "\n",
    "    # Select a random permutation because we do not assume that\n",
    "    # we know the correct ordering of the variables\n",
    "    p = random_state.permutation(dims)\n",
    "\n",
    "    # Permute the rows of the data matrix, to give us the\n",
    "    # observed data\n",
    "    X = Xorig[p]\n",
    "\n",
    "    # Permute the rows and columns of the original generating\n",
    "    # matrix B so that they correspond to the actual data\n",
    "    Bp = B[p][:, p]\n",
    "\n",
    "    # make causal order: the inverse permutation of p, as integers so\n",
    "    # that it can be used directly for indexing\n",
    "    causal_order = np.argsort(p)\n",
    "\n",
    "    return X.T, Bp, causal_order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-11T02:56:05.843277Z",
     "start_time": "2019-12-11T02:56:05.835332Z"
    }
   },
   "outputs": [],
   "source": [
    "def make_test_data(dims_list, samples_list, data_set_num, set_seed=True):\n",
    "    \"\"\" Build a collection of generated data sets for testing.\n",
    "\n",
    "    For every (dims, samples) pair taken from dims_list x samples_list,\n",
    "    data_set_num data sets are generated and stored as dictionaries with\n",
    "    the keys 'X', 'B' and 'causal_order'. When set_seed is true, the\n",
    "    running index of the data set (0, 1, 2, ...) is used as its seed;\n",
    "    otherwise the seed is None.\n",
    "    \"\"\"\n",
    "    collected = OrderedDict()\n",
    "    running_index = 0\n",
    "\n",
    "    for dims in dims_list:\n",
    "        for samples in samples_list:\n",
    "            entries = []\n",
    "\n",
    "            for _ in range(data_set_num):\n",
    "                seed = running_index if set_seed else None\n",
    "                data, network, order = generate_data(\n",
    "                    dims=dims, samples=samples, seed=seed)\n",
    "                entries.append(\n",
    "                    {'X': data, 'B': network, 'causal_order': order})\n",
    "                running_index += 1\n",
    "\n",
    "            collected[(dims, samples)] = entries\n",
    "\n",
    "    return collected"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-11T02:57:28.130770Z",
     "start_time": "2019-12-11T02:57:27.850790Z"
    }
   },
   "outputs": [],
   "source": [
    "# Grid of problem sizes: every (dims, samples) pair gets its own list of data sets.\n",
    "dims_list = [10, 50, 100]\n",
    "samples_list = [200, 1000, 5000]\n",
    "# Number of data sets generated per (dims, samples) pair.\n",
    "data_set_num = 10\n",
    "\n",
    "# set_seed is left at its default (True), so the seeds are 0, 1, 2, ... in generation order.\n",
    "test_data_set = make_test_data(dims_list, samples_list, data_set_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
