{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from synthetic_datasets import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp1: linear\n",
    "\n",
    "# exp1: d=12, prob=0.5, 'lingam', s=np.ones([d]) (same variance)\n",
    "#       d=12, prob=0.5, 'gaussian', s=np.ones([d]) (same variance)\n",
    "#       d=12, prob=0.5, 'lingam', s=np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) (same variance)\n",
    "#       d=12, prob=0.5, 'gaussian', s=np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) (diff. noise variance)\n",
    "#       d=30, prob=0.2, 'gaussian', s=np.ones([d]) (same variance)\n",
    "\n",
    "seeds = [1, 2, 3, 4, 5, 6, 7, 9, 10]\n",
    "\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_samples=1000\n",
    "    W = generate_W(d=d, prob=0.5) # 0.2 \n",
    "    c = np.zeros(d)\n",
    "    s = np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    xs, b_, c_ = gen_data_given_model(W, s, c, n_samples=n_samples, noise_type='gaussian', permutate=True)\n",
    "    \n",
    "# save your data     \n",
    "    dir_name = os.path.join(os.getcwd(), 'gaussian_same_noise_d{}_size{}_seed{}'.format(d, n_samples, seed))\n",
    "    os.mkdir(dir_name)   \n",
    "    np.save(os.path.join(dir_name, 'data_index.npy'), xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), b_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp2: quadractiv \n",
    "\n",
    "seeds = [8]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 10\n",
    "    W = generate_W(d=d, prob=0.5)\n",
    "    c = np.zeros(d)\n",
    "    #s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1)\n",
    "    s = np.ones([d])\n",
    "    xs, b_, c_ = gen_data_given_model_2nd_order(W, s, c, n_samples=5000, noise_type='lingam', permutate=True)\n",
    "    \n",
    "    # get the first 3000 samples\n",
    "    xs_norm = np.linalg.norm(xs, axis=1)\n",
    "    xs_th = sorted(xs_norm)[3000]\n",
    "    xs = xs[xs_norm < xs_th]\n",
    "    \n",
    "    # save data\n",
    "#     dir_name = os.path.join(os.getcwd(), 'lingam_quad_same_noise_seed{}'.format(seed))\n",
    "#     os.mkdir(dir_name)\n",
    "    \n",
    "#     np.save(os.path.join(dir_name, 'data.npy'), xs)\n",
    "#     np.save(os.path.join(dir_name, 'DAG.npy'), b_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# exp3: gp data\n",
    "\n",
    "The GP datasets can be generated using the R code of the ANM available at http://people.tuebingen.mpg.de/jpeters/onlineCodeANM.zip\n",
    "A python interface is from the GraN-DAG authos, available at datahttps://github.com/kurowasan/GraN-DAG/blob/964b698d49f507eb5d505e4511ed289f9a8ec01b/baselines_and_metrics/rcode/code_GraN-DAG/generateData.R"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# exp4 multiple datasets\n",
    "\n",
    "seeds = [8]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 12\n",
    "    n_datasets = 10\n",
    "    n_samples_each = 1000\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "    total_xs = np.zeros([total_samples, d])\n",
    "    total_xs_index = np.zeros([total_samples, d + 1])\n",
    "    total_xs_one_hot = np.zeros([total_samples, d + n_datasets])\n",
    "    W = generate_W(d=d, prob=0.5) # 0.2\n",
    "    c = np.zeros(d)\n",
    "    s = np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    W_permute, W_domain_related, c_permute = generate_W_domain_related(W, c, n_domain=n_datasets, permute=True)\n",
    "\n",
    "    for domain_index in range(n_datasets):\n",
    "        xs, b_, c_ = gen_data_given_model(W_domain_related[domain_index, :, :], s, c_permute, n_samples=n_samples_each, noise_type='gaussian')\n",
    "        xs_index = np.concatenate((xs, domain_index*np.ones([n_samples_each, 1])), axis=1)\n",
    "        one_hot = np.zeros([n_samples_each, n_datasets])\n",
    "        one_hot[:, domain_index] = 1\n",
    "        xs_one_hot = np.concatenate((xs, one_hot), axis=1)\n",
    "        total_xs[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs\n",
    "        total_xs_index[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_index\n",
    "        total_xs_one_hot[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_one_hot\n",
    "\n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'multiple_gaussian_same_noise_seed{}'.format(seed))\n",
    "    os.mkdir(dir_name)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), W_permute)\n",
    "    np.save(os.path.join(dir_name, 'DAG_multiple_domain.npy'), W_domain_related)\n",
    "    # np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data_one_hot.npy'), total_xs_one_hot)\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "ename": "FileExistsError",
     "evalue": "[Errno 17] File exists: '/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/baseline1_chain_varx1_bigger_varx2_gaussian_same_noise_d4_size10000_seed1'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileExistsError\u001b[0m                           Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-1ef743cc5383>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;31m# save your data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m     \u001b[0mdir_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetcwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'baseline1_chain_varx1_bigger_varx2_gaussian_same_noise_d{}_size{}_seed{}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m     \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmkdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m     \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'data.npy'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'DAG.npy'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mb_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileExistsError\u001b[0m: [Errno 17] File exists: '/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/baseline1_chain_varx1_bigger_varx2_gaussian_same_noise_d4_size10000_seed1'"
     ]
    }
   ],
   "source": [
    "# exp1: baseline1 single dataset, x1->x2->x3->x4, RL cannot determine some directions\n",
    "\n",
    "# exp1: d=12, prob=0.5, 'lingam', s=np.ones([d]) (same variance)\n",
    "#       d=12, prob=0.5, 'gaussian', s=np.ones([d]) (same variance)\n",
    "#       d=12, prob=0.5, 'lingam', s=np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) (same variance)\n",
    "#       d=12, prob=0.5, 'gaussian', s=np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) (diff. noise variance)\n",
    "#       d=30, prob=0.2, 'gaussian', s=np.ones([d]) (same variance)\n",
    "\n",
    "seeds = [1]\n",
    "\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_samples=10000\n",
    "    W = np.array([[0, 0, 0, 0], [0.4, 0, 0, 0], [0, 2, 0, 0,], [0, 0, -1, 0]])\n",
    "    c = np.zeros([d])\n",
    "    s = np.array([1, np.sqrt(0.75), 1, 1])\n",
    "    # s = np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    xs, b_, c_ = gen_data_given_model(W, s, c, n_samples=n_samples, noise_type='gaussian', permutate=False)\n",
    "\n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'baseline1_chain_varx1_bigger_varx2_gaussian_same_noise_d{}_size{}_seed{}'.format(d, n_samples, seed))\n",
    "    os.mkdir(dir_name)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), b_)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1.         0.32851148]\n",
      " [0.32851148 1.        ]]\n"
     ]
    },
    {
     "ename": "FileExistsError",
     "evalue": "[Errno 17] File exists: '/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/baseline2_chain_multiple_theta_c_gaussian_same_noise_d4_size1000_domains10_seed8'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileExistsError\u001b[0m                           Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-f98e5e6e1846>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[0;31m# save your data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m     \u001b[0mdir_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetcwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'baseline2_chain_multiple_theta_c_gaussian_same_noise_d{}_size{}_domains{}_seed{}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_samples_each\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_datasets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m     \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmkdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     38\u001b[0m     \u001b[0;31m# np.save(os.path.join(dir_name, 'data.npy'), total_xs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m     \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'DAG.npy'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_permute\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileExistsError\u001b[0m: [Errno 17] File exists: '/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/baseline2_chain_multiple_theta_c_gaussian_same_noise_d4_size1000_domains10_seed8'"
     ]
    }
   ],
   "source": [
    "# exp2 baseline2 multiple datasets, x1->x2->x3->x4, theta(t) and RL leads to extra error edges\n",
    "\n",
    "seeds = [8]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_datasets = 10\n",
    "    n_samples_each = 1000\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "    total_xs = np.zeros([total_samples, d])\n",
    "    total_xs_index = np.zeros([total_samples, d + 1])\n",
    "    total_xs_one_hot = np.zeros([total_samples, d + n_datasets])\n",
    "    # W = generate_W(d=d, prob=0.5) # 0.2\n",
    "    W = np.array([[0, 0, 0, 0], [-1.0, 0, 0, 0], [0, 2, 0, 0,], [0, 0, -1, 0]])\n",
    "    c = np.zeros([d])\n",
    "    s = np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    index = np.zeros([d, d], dtype=np.bool)\n",
    "    index[1, 0] = index[3, 2] = True\n",
    "    W_permute, W_domain_related, c_permute = generate_W_domain_related_theta_c(W, c, n_domain=n_datasets, permute=False, index=index)\n",
    "    x2_theta = W_domain_related[:, 1, 0]\n",
    "    x4_theta = W_domain_related[:, 3, 2]\n",
    "    x3_theta = W_domain_related[:, 2, 1]\n",
    "    print(np.corrcoef(np.vstack((x2_theta, x4_theta))))\n",
    "    for domain_index in range(n_datasets):\n",
    "        xs, b_, c_ = gen_data_given_model(W_domain_related[domain_index, :, :], s, c_permute, n_samples=n_samples_each, noise_type='gaussian')\n",
    "        xs_index = np.concatenate((xs, domain_index*np.ones([n_samples_each, 1])), axis=1)\n",
    "        one_hot = np.zeros([n_samples_each, n_datasets])\n",
    "        one_hot[:, domain_index] = 1\n",
    "        xs_one_hot = np.concatenate((xs, one_hot), axis=1)\n",
    "        total_xs[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs\n",
    "        total_xs_index[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_index\n",
    "        total_xs_one_hot[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_one_hot\n",
    "\n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'baseline2_chain_multiple_theta_c_gaussian_same_noise_d{}_size{}_domains{}_seed{}'.format(d, n_samples_each, n_datasets, seed))\n",
    "    os.mkdir(dir_name)\n",
    "    # np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), W_permute)\n",
    "    np.save(os.path.join(dir_name, 'DAG_multiple_domain.npy'), W_domain_related)\n",
    "    # np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1.         0.90297274]\n",
      " [0.90297274 1.        ]]\n",
      "[[ 1.         -0.22244743]\n",
      " [-0.22244743  1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed1\n",
      "[[1.         0.92662501]\n",
      " [0.92662501 1.        ]]\n",
      "[[1.         0.34845509]\n",
      " [0.34845509 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed2\n",
      "[[1.         0.87602283]\n",
      " [0.87602283 1.        ]]\n",
      "[[1.         0.10314984]\n",
      " [0.10314984 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed3\n",
      "[[1.         0.86605567]\n",
      " [0.86605567 1.        ]]\n",
      "[[1.         0.15202846]\n",
      " [0.15202846 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed4\n",
      "[[1.         0.82617869]\n",
      " [0.82617869 1.        ]]\n",
      "[[ 1.         -0.33830631]\n",
      " [-0.33830631  1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed5\n",
      "[[1.         0.90919325]\n",
      " [0.90919325 1.        ]]\n",
      "[[ 1.         -0.21965401]\n",
      " [-0.21965401  1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed6\n",
      "[[1.        0.8942736]\n",
      " [0.8942736 1.       ]]\n",
      "[[ 1.       -0.313538]\n",
      " [-0.313538  1.      ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed7\n",
      "[[1.         0.76667481]\n",
      " [0.76667481 1.        ]]\n",
      "[[1.         0.10490998]\n",
      " [0.10490998 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed8\n",
      "[[1.        0.8368398]\n",
      " [0.8368398 1.       ]]\n",
      "[[1.         0.08554362]\n",
      " [0.08554362 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed9\n",
      "[[1.         0.94702969]\n",
      " [0.94702969 1.        ]]\n",
      "[[ 1.         -0.30431895]\n",
      " [-0.30431895  1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed10\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from synthetic_datasets import *\n",
    "# exp3  multiple datasets, x1->x2->x3->x4,\n",
    "import time\n",
    "seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_datasets = 10\n",
    "    n_samples_each = 1000\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "    total_xs = np.zeros([total_samples, d])\n",
    "    total_xs_index = np.zeros([total_samples, d + 1])\n",
    "    total_xs_one_hot = np.zeros([total_samples, d + n_datasets])\n",
    "    # W = generate_W(d=d, prob=0.5) # 0.2\n",
    "    W = np.array([[0, 0, 0, 0], [-0.8, 0, 0, 0], [0, 2, 0, 0,], [0, 0, -1, 0]])\n",
    "    s = np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    s[0] = 5\n",
    "    s[1] = 1\n",
    "    index = np.zeros([d], dtype=np.bool)\n",
    "    index[1] = index[3] = True\n",
    "    intercept_domain_related = generate_intercept_domain_related_g_c_theta_c(d, n_domain=n_datasets, index=index)\n",
    "    index2 = np.zeros([d], dtype=np.bool)\n",
    "    index2[0] = True\n",
    "    intercept_domain_related2 = generate_intercept_domain_related_theta_c(d, n_domain=n_datasets, index=index2)\n",
    "    intercept_domain_related[:, 0] = intercept_domain_related2[:, 0]\n",
    "    x2_theta = intercept_domain_related[:, 1]\n",
    "    x4_theta = intercept_domain_related[:, 3]\n",
    "    x3_theta = intercept_domain_related[:, 0]\n",
    "    print(np.corrcoef(np.vstack((x2_theta, x4_theta))))\n",
    "    print(np.corrcoef(np.vstack((x2_theta, x3_theta))))\n",
    "    for domain_index in range(n_datasets):\n",
    "        xs, b_, c_ = gen_data_given_model(W, s, intercept_domain_related[domain_index], n_samples=n_samples_each, noise_type='gaussian')\n",
    "        xs_index = np.concatenate((xs, domain_index*np.ones([n_samples_each, 1])), axis=1)\n",
    "        one_hot = np.zeros([n_samples_each, n_datasets])\n",
    "        one_hot[:, domain_index] = 1\n",
    "        xs_one_hot = np.concatenate((xs, one_hot), axis=1)\n",
    "        total_xs[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs\n",
    "        total_xs_index[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_index\n",
    "        total_xs_one_hot[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_one_hot\n",
    "    DAG_index = np.array([[0, 0, 0, 0, 1], [-0.8, 0, 0, 0, 1], [0, 5, 0, 0, 0], [0, 0, -1, 0, 1], [0, 0, 0, 0, 0]])\n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d{}_size{}_domains{}_seed{}'.format(d, n_samples_each, n_datasets, seed))\n",
    "    os.mkdir(dir_name)\n",
    "    print(dir_name)\n",
    "    # np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), W)\n",
    "    np.save(os.path.join(dir_name, 'DAG_index.npy'), DAG_index)\n",
    "    np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "    np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    # np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1.         0.87602283]\n",
      " [0.87602283 1.        ]]\n",
      "[[1.         0.10314984]\n",
      " [0.10314984 1.        ]]\n",
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d4_size1000_domains10_seed3_1588053350.053185\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from synthetic_datasets import *\n",
    "# exp3  multiple datasets, x1->x2->x3->x4,\n",
    "import time\n",
    "seeds = [3]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_datasets = 10\n",
    "    n_samples_each = 1000\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "    total_xs = np.zeros([total_samples, d])\n",
    "    total_xs_index = np.zeros([total_samples, d + 1])\n",
    "    total_xs_one_hot = np.zeros([total_samples, d + n_datasets])\n",
    "    # W = generate_W(d=d, prob=0.5) # 0.2\n",
    "    W = np.array([[0, 0, 0, 0], [-0.8, 0, 0, 0], [0, 2, 0, 0,], [0, 0, -1, 0]])\n",
    "    s = 0.1*np.ones([d]) # s = np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) different varicne\n",
    "    s[0] = 1\n",
    "    s[1] = 0.1\n",
    "    index = np.zeros([d], dtype=np.bool)\n",
    "    index[1] = index[3] = True\n",
    "    intercept_domain_related = generate_intercept_domain_related_g_c_theta_c(d, n_domain=n_datasets, index=index)\n",
    "    index2 = np.zeros([d], dtype=np.bool)\n",
    "    index2[0] = True\n",
    "    intercept_domain_related2 = generate_intercept_domain_related_theta_c(d, n_domain=n_datasets, index=index2)\n",
    "    intercept_domain_related[:, 0] = intercept_domain_related2[:, 0]\n",
    "    x2_theta = intercept_domain_related[:, 1]\n",
    "    x4_theta = intercept_domain_related[:, 3]\n",
    "    x3_theta = intercept_domain_related[:, 0]\n",
    "    print(np.corrcoef(np.vstack((x2_theta, x4_theta))))\n",
    "    print(np.corrcoef(np.vstack((x2_theta, x3_theta))))\n",
    "    for domain_index in range(n_datasets):\n",
    "        xs, b_, c_ = gen_data_given_model(W, s, intercept_domain_related[domain_index], n_samples=n_samples_each, noise_type='gaussian')\n",
    "        xs_index = np.concatenate((xs, domain_index*np.ones([n_samples_each, 1])), axis=1)\n",
    "        one_hot = np.zeros([n_samples_each, n_datasets])\n",
    "        one_hot[:, domain_index] = 1\n",
    "        xs_one_hot = np.concatenate((xs, one_hot), axis=1)\n",
    "        total_xs[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs\n",
    "        total_xs_index[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_index\n",
    "        total_xs_one_hot[domain_index * n_samples_each : (domain_index + 1) * n_samples_each, :] = xs_one_hot\n",
    "    DAG_index = np.array([[0, 0, 0, 0, 1], [-0.8, 0, 0, 0, 1], [0, 5, 0, 0, 0], [0, 0, -1, 0, 1], [0, 0, 0, 0, 0]])\n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'exp2_chain_multiple_intercept_g_c_theta_c_varx1_bigger_varx2_x1_theta_c_gaussian_same_noise_d{}_size{}_domains{}_seed{}_{}'.format(d, n_samples_each, n_datasets, seed, time.time()))\n",
    "    os.mkdir(dir_name)\n",
    "    print(dir_name)\n",
    "    # np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), W)\n",
    "    np.save(os.path.join(dir_name, 'DAG_index.npy'), DAG_index)\n",
    "    np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "    np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    # np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/dingchenwei/PycharmProjects/trustworthyAI_test_nonlinear/Causal_Structure_Learning/Causal_Discovery_RL/src/Datasets/nonlinear_d4_size200_domains6_seed4_1589522778.2089448\n"
     ]
    }
   ],
   "source": [
    "# nonlinear\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from synthetic_datasets import *\n",
    "# exp3  multiple datasets, x1->x2->x3->x4,\n",
    "import time\n",
    "seeds = [4]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_datasets = 6\n",
    "    n_samples_each = 200\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "\n",
    "    c = np.zeros([n_datasets, d])\n",
    "    # c[:, 0] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "    # c[:, 1] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "    # c[:, 3] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "\n",
    "    # for i in range(n_datasets):\n",
    "    #     c[i, 0] = c[i, 1] = c[i, 3]= (-1**i)*i*2\n",
    "\n",
    "    s = 1*np.ones([n_datasets, d])\n",
    "\n",
    "    # data 0\n",
    "    x0_0 = np.random.normal(c[0, 0], s[0, 0], [n_samples_each, 1])\n",
    "    x0_1 = 1*x0_0 + np.random.normal(c[0, 1], s[0, 1], [n_samples_each, 1])\n",
    "    x0_2 = 1*x0_1 + np.random.normal(c[0, 2], s[0, 2], [n_samples_each, 1])\n",
    "    x0_3 = 1*x0_2 + np.random.normal(c[0, 3], s[0, 3], [n_samples_each, 1])\n",
    "    total_xs_0 = np.concatenate((x0_0, x0_1, x0_2, x0_3), axis=1)\n",
    "    # data 1\n",
    "    x1_0 = np.random.normal(c[1, 0], s[1, 0], [n_samples_each, 1])\n",
    "    x1_1 = 1*(x1_0**2)  + np.random.normal(c[1, 1], s[1, 1], [n_samples_each, 1])\n",
    "    x1_2 = 1*x1_1 + np.random.normal(c[1, 2], s[1, 2], [n_samples_each, 1])\n",
    "    x1_3 = 1*(x1_2**2)  + np.random.normal(c[1, 3], s[1, 3], [n_samples_each, 1])\n",
    "    total_xs_1 = np.concatenate((x1_0, x1_1, x1_2, x1_3), axis=1)\n",
    "    # data 2\n",
    "    x2_0 = np.random.normal(c[2, 0], s[2, 0], [n_samples_each, 1])\n",
    "    x2_1 = 1*np.sin(x2_0)  + np.random.normal(c[2, 1], s[2, 1], [n_samples_each, 1])\n",
    "    x2_2 = 1*x2_1 + np.random.normal(c[2, 2], s[2, 2], [n_samples_each, 1])\n",
    "    x2_3 = 1*np.sin(x2_2)  + np.random.normal(c[2, 3], s[2, 3], [n_samples_each, 1])\n",
    "    total_xs_2 = np.concatenate((x2_0, x2_1, x2_2, x2_3), axis=1)\n",
    "    # total_xs = np.concatenate((total_xs_0, total_xs_1, total_xs_2), axis=0)\n",
    "    #\n",
    "    # # data 3\n",
    "    x3_0 = np.random.normal(c[3, 0], s[3, 0], [n_samples_each, 1])\n",
    "    x3_1 = 1*np.cos(x3_0)  + np.random.normal(c[3, 1], s[3, 1], [n_samples_each, 1])\n",
    "    x3_2 = 1*x3_1 + np.random.normal(c[3, 2], s[3, 2], [n_samples_each, 1])\n",
    "    x3_3 = 1*np.cos(x3_2)  + np.random.normal(c[3, 3], s[3, 3], [n_samples_each, 1])\n",
    "    total_xs_3 = np.concatenate((x3_0, x3_1, x3_2, x3_3), axis=1)\n",
    "\n",
    "    # # data 4\n",
    "    x4_0 = np.random.normal(c[4, 0], s[4, 0], [n_samples_each, 1])\n",
    "    x4_1 = 1*np.tan(x4_0)  + np.random.normal(c[4, 1], s[4, 1], [n_samples_each, 1])\n",
    "    x4_2 = 1*x4_1 + np.random.normal(c[4, 2], s[4, 2], [n_samples_each, 1])\n",
    "    x4_3 = 1*np.tan(x4_2)  + np.random.normal(c[4, 3], s[4, 3], [n_samples_each, 1])\n",
    "    total_xs_4 = np.concatenate((x4_0, x4_1, x4_2, x4_3), axis=1)\n",
    "\n",
    "    # # data 5\n",
    "    x5_0 = np.random.normal(c[5, 0], s[5, 0], [n_samples_each, 1])\n",
    "    x5_1 = 1*np.sin(x5_0)  + np.random.normal(c[5, 1], s[5, 1], [n_samples_each, 1])\n",
    "    x5_2 = 1*x5_1 + np.random.normal(c[5, 2], s[5, 2], [n_samples_each, 1])\n",
    "    x5_3 = 1*np.sin(x5_2)  + np.random.normal(c[5, 3], s[5, 3], [n_samples_each, 1])\n",
    "    total_xs_5 = np.concatenate((x5_0, x5_1, x5_2, x5_3), axis=1)\n",
    "\n",
    "    total_xs = np.concatenate((total_xs_0, total_xs_1, total_xs_2, total_xs_3, total_xs_4, total_xs_5), axis=0)\n",
    "\n",
    "\n",
    "    # dag\n",
    "    dag = np.zeros([d, d])\n",
    "    dag_index = np.zeros([d+1, d+1])\n",
    "    dag[1, 0] = dag[2, 1] = dag[3, 2] = 1\n",
    "    dag_index[1, 0] = dag_index[2, 1] = dag_index[3, 2] = 1\n",
    "    dag_index[0, -1] = dag_index[1, -1] = dag_index[3, -1] = 1\n",
    "    \n",
    "    # data with index\n",
    "    total_index = np.zeros([total_samples, 1])\n",
    "    total_one_hot = np.zeros([total_samples, n_datasets])\n",
    "    for i in range(n_datasets):\n",
    "        total_index[i*n_samples_each:(i+1)*n_samples_each, 0] = i\n",
    "        total_one_hot[i*n_samples_each:(i+1)*n_samples_each, i] = 1\n",
    "    total_xs_index = np.concatenate((total_xs, total_index), axis=1)\n",
    "    total_xs_one_hot = np.concatenate((total_xs, total_one_hot), axis=1)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "# save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'nonlinear_d{}_size{}_domains{}_seed{}_{}'.format(d, n_samples_each, n_datasets, seed, time.time()))\n",
    "    os.mkdir(dir_name)\n",
    "    print(dir_name)\n",
    "    # np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), dag)\n",
    "    np.save(os.path.join(dir_name, 'DAG_index.npy'), dag_index)\n",
    "    np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "    np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    # np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "output_type": "error",
     "ename": "SyntaxError",
     "evalue": "invalid syntax (<ipython-input-3-c91c60db8aa6>, line 64)",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"<ipython-input-3-c91c60db8aa6>\"\u001b[0;36m, line \u001b[0;32m64\u001b[0m\n\u001b[0;31m    os.mkdir(dir_name)\u001b[0m\n\u001b[0m     ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "# nonlinear\n",
    "# exp3: multiple datasets (domains) that share the chain x1 -> x2 -> x3 -> x4.\n",
    "# For every seed this cell generates the samples of all domains, appends the domain\n",
    "# label (as an integer column and as a one-hot block) and writes everything to a\n",
    "# directory named after the settings and the seed.\n",
    "\n",
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import scipy.io as scio\n",
    "\n",
    "from synthetic_datasets import *\n",
    "\n",
    "seeds = [0, 1, 2, 3, 4, 5]\n",
    "\n",
    "# Settings shared by every seed.\n",
    "d = 4\n",
    "n_datasets = 10\n",
    "n_samples_each = 200\n",
    "total_samples = n_datasets * n_samples_each\n",
    "\n",
    "# Ground-truth graphs (row = child, column = parent), identical for every seed.\n",
    "# dag holds the chain 0 -> 1 -> 2 -> 3; dag_index has one extra node (last row/column)\n",
    "# for the domain index, which points at columns 0, 1 and 3, i.e. the columns whose\n",
    "# noise mean changes with the domain below.\n",
    "dag = np.zeros([d, d])\n",
    "dag_index = np.zeros([d+1, d+1])\n",
    "dag[1, 0] = dag[2, 1] = dag[3, 2] = 1\n",
    "dag_index[1, 0] = dag_index[2, 1] = dag_index[3, 2] = 1\n",
    "dag_index[0, -1] = dag_index[1, -1] = dag_index[3, -1] = 1\n",
    "\n",
    "# Domain label of every row; it does not depend on the seed.\n",
    "domain_ids = np.repeat(np.arange(n_datasets), n_samples_each)\n",
    "total_index = domain_ids.reshape(-1, 1).astype(float)\n",
    "total_one_hot = np.eye(n_datasets)[domain_ids]\n",
    "\n",
    "# automatically construct data\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    # Noise means: columns 0, 1 and 3 get mean 0.2*i in domain i, column 2 stays at 0.\n",
    "    # NOTE(review): the original expression was (1**i)*i*0.2, where 1**i is always 1.\n",
    "    # If an alternating sign was intended it should read (-1)**i -- confirm.\n",
    "    # Alternative used in earlier experiments: random shifts, e.g.\n",
    "    #   c[:, 0] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "    c = np.zeros([n_datasets, d])\n",
    "    for i in range(n_datasets):\n",
    "        c[i, 0] = c[i, 1] = c[i, 3] = i*0.2\n",
    "\n",
    "    # Noise standard deviation is 0.5 for every variable in every domain.\n",
    "    s = 0.5*np.ones([n_datasets, d])\n",
    "\n",
    "    total_xs = np.zeros([total_samples, d])\n",
    "    for i in range(n_datasets):\n",
    "        rows = slice(i*n_samples_each, (i+1)*n_samples_each)\n",
    "        total_xs[rows, 0] = np.random.normal(c[i, 0], s[i, 0], [n_samples_each])\n",
    "        for j in range(d-1):\n",
    "            # Column j+1 is a scaled sine of column j plus Gaussian noise. The 1 -> 2 link\n",
    "            # uses a fixed coefficient of 1; the other links draw a coefficient from\n",
    "            # U(0.5, 1) once per domain. The coefficient is drawn before the noise:\n",
    "            # swapping the two draws would change the data generated for a given seed.\n",
    "            coef = 1 if j == 1 else np.random.uniform(0.5, 1)\n",
    "            total_xs[rows, j+1] = coef*np.sin(total_xs[rows, j]) + np.random.normal(c[i, j+1], s[i, j+1], [n_samples_each])\n",
    "\n",
    "    # data with index\n",
    "    total_xs_index = np.concatenate((total_xs, total_index), axis=1)\n",
    "    total_xs_one_hot = np.concatenate((total_xs, total_one_hot), axis=1)\n",
    "\n",
    "    # save your data\n",
    "    dir_name = os.path.join(os.getcwd(), 'nonlinear_d{}_size{}_domains{}_seed{}'.format(d, n_samples_each, n_datasets, seed))\n",
    "    # exist_ok=True: the directory name is deterministic, so re-running the cell must\n",
    "    # overwrite the files instead of failing with FileExistsError.\n",
    "    os.makedirs(dir_name, exist_ok=True)\n",
    "    print(dir_name)\n",
    "    np.save(os.path.join(dir_name, 'DAG.npy'), dag)\n",
    "    np.save(os.path.join(dir_name, 'DAG_index.npy'), dag_index)\n",
    "    np.save(os.path.join(dir_name, 'data_onehot.npy'), total_xs_one_hot)\n",
    "    np.save(os.path.join(dir_name, 'data_index.npy'), total_xs_index)\n",
    "    np.save(os.path.join(dir_name, 'data.npy'), total_xs)\n",
    "    scio.savemat(os.path.join(dir_name, 'data.mat'), {'data': total_xs})\n",
    "    scio.savemat(os.path.join(dir_name, 'data_index.mat'), {'data_index': total_xs_index})"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}