{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from synthetic_datasets import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def calculate_LR(X_train, y_train):\n",
    "    \"\"\"Ordinary least-squares fit of y_train on X_train plus an intercept.\n",
    "\n",
    "    A column of ones is appended to X_train and theta is obtained by solving the\n",
    "    normal equations (X^T X) theta = X^T y_train.\n",
    "\n",
    "    Returns:\n",
    "        y_err: residuals X @ theta - y_train.\n",
    "        theta: fitted coefficients; the last entry is the intercept.\n",
    "    \"\"\"\n",
    "    n_rows = X_train.shape[0]\n",
    "    design = np.hstack((X_train, np.ones((n_rows, 1), dtype=np.float32)))\n",
    "    gram = design.T.dot(design)\n",
    "    moments = design.T.dot(y_train)\n",
    "    theta = np.linalg.solve(gram, moments)\n",
    "    y_err = design.dot(theta) - y_train\n",
    "    return y_err, theta\n",
    "\n",
    "\n",
    "# Columns 0 and 1 of the saved array, each kept as an (n, 1) column vector.\n",
    "data = np.load(\"Datasets/chain_gaussian_same_noise_d4_size10000_seed1/data.npy\")\n",
    "x = data[:, 0:1]\n",
    "y = data[:, 1:2]\n",
    "\n",
    "# Fit both directions: x -> y (labelled true) and then y -> x (labelled false).\n",
    "# Each prints theta, then the residual MSE, the regressor variance, and their sum.\n",
    "for regressor, target in ((x, y), (y, x)):\n",
    "    res, theta = calculate_LR(regressor, target)\n",
    "    print(theta)\n",
    "    mean_sq_err1 = np.mean(np.square(res))\n",
    "    mean_sq_err2 = np.var(regressor)\n",
    "    print(mean_sq_err1, mean_sq_err2, mean_sq_err2 + mean_sq_err1)\n",
    "\n",
    "xy = np.hstack((x, y))\n",
    "print(np.cov(xy.T))\n",
    "print(np.var(x))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "outputs": [
    {
     "data": {
      "text/plain": "<matplotlib.collections.PathCollection at 0x11304b9e8>"
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJztnX+MHOd537/PLofiHp3ojvU1tlaiqDoFCTM0edYhUnNAE8qu6FiWfJFss4qUIikKIkAThLJ6yclWTNJRqwsYWzaaAIWcpE0hwaF++WyZTikbZJCGNlUffUfRjMnUsi3KKxe+gFzZ4q3Ivdunf+zNcnZ23pl3dmd/zX4/ACHd3uzMu3sz33nm+SmqCkIIIekh0+0FEEIISRYKOyGEpAwKOyGEpAwKOyGEpAwKOyGEpAwKOyGEpAwKOyGEpAwKOyGEpAwKOyGEpIw13TjoW9/6Vt20aVM3Dk0IIX3LyZMn/0lVR6O264qwb9q0CXNzc904NCGE9C0i8orNdnTFEEJIykhM2EUkKyLzIvLlpPZJCCEkPkla7L8H4DsJ7o8QQkgTJCLsInI9gDsA/HkS+yOEENI8SVnsnwHw+wAqpg1EZI+IzInI3OLiYkKHJYQQ4qdlYReRDwD4saqeDNtOVR9X1XFVHR8djczWIYQQ0iRJpDtOALhLRN4PYB2AnxWRJ1T1/gT2TQghbWV2voCDR87htWIJ1w3nMLVrMybH8t1eVku0bLGr6kOqer2qbgLwbwEcpagTQvqB2fkCHnruNArFEhRAoVjCQ8+dxux8odtLawnmsRNCBpaDR86hVF6pe61UXsHBI+e6tKJkSLTyVFX/FsDfJrlPQghpF68VS7Fe7xdosRNCBpbrhnOxXu8XKOyEkIFlatdm5Jxs3Ws5J4upXZu7tKJk6EoTMEII6QXc7Je0ZcVQ2AkZcJJM9+vH1MHJsXzPrzEuFHZCBhg33c/NDHHT/QDEFrsk90Vagz52QgaYJNP90po62I/QYidkgEki3c91vxRSmjrYj1DYCRlgrhvOBQqybbrf7HwBU8+cQnlFQ4/h3b6bPvhuH79TUNgJGWCmdm2u84sD8dL9Djx/JlTUvfvqpA8+SMABDEwMgMJOyADjTfcrFEvIitT5xaME7+JS2fi7vM8iDvPBJymsphvINWsyHTl+L8DgKSEDzuRYvlaos6JV6zuJZljHp2+rE0yTr71QLCXadMt0AymWgm9CaYwBUNgJSQmz8wVMzBzFTdOHMTFzNJZYNpvRMpxzrF8P89sn2VExrlD3e/uAICjshKSAVtvPNpsds/+urXAyUveakxHsv2trw7ZB5fsuSaZFmoR6ZMhJZfuAICjshKSAVnPIm22GNTmWx8EPb0d+OAdB1a9+8MPb61ww7pPEA4cWsM4xS05SLhFT/5d9d27Fo3dvq1vro3dvS51/HWDwlJBU0Eo++ux8AUtXlhtedzKCpSvLuGn6cENqoG3aoD+QGRZsTcolEtX/JY1C7ofCTkgKaDYf3S+8Ljkng+WK1oTYmxoI2KcNBj1JBJG0SySN/V/iQFcMISmg2fazJuG9sqwN+eml8goefOoU9h5asHb72DwxZEVS6xLpFi1b7CKyDsDfAbhmdX/PqOq+VvdLCKknzP3RbPtZk/C6aY+2r5v2ZXqS8FJRpagnTBKumMsAblPVN0TEAfD3IvI3qnoigX0TQmBXtdmM+8EkvFmRUBE37cvPzi2jeOLE+cj3DUqpf6do2RWjVd5Y/dFZ/RfvjCCEhNKuzokmF869t9xgTE0MwuT2OXZ2MfJ9O7eMNqRqPnBoAQ/Png59LzGTiI9dRLIisgDgxwC+qqovBmyzR0TmRGRucTH8j00IqafVrBdT4dLkWD4wBfCRyW11r2dFjPs3pQ3OzhdC3TDu+46dXWy4aSmAJ0+cT7QidZAQjfm4FbozkWEAXwDwu6r6bdN24+PjOjc3
l9hxCelnbNwQEzNHA0UyP5zD8enbQvcd1OTrnpvzOHZ20dr1YdqPKehpyrYJWvdN04eNj/hRn2/QEJGTqjoetV2i6Y6qWhSRYwDeB8Ao7IT0Kp329dp2PGy2C6PJhfPkifM1MbXpchg3OHvg+TNGUfevOyzA6n0ioR/eniSyYkYBlFdFPQfg3wD445ZXRkiH6cZoN9uOh0lnvfgtZJsuh0HBWe+QDTfgOjLkhBYiXbOm3gM8tWszHji0EGi1uwFZjt2LRxIW+9sB/JWIZFH12T+lql9OYL+EdJR2tJX1WpnDQw5UgddL5Zowh/nOgyzUKLeE/z3DESLrJSotMehYXrF1s2iijlcsleuKnQ4eORco6l7LvlMtf9NCy8Kuqi8BGEtgLYR0lSTGxHkJK6d3Mz9MvuXhISfSQvWL+M4to3j2ZKHuPXGyI2R1n7ZCaVtVGkSpvIIDz5/Bm+VK3T4E1acJfy/3pP82aYeVp4Ss0mwjLBNRwmcS9ZyThSoCLdT9XzoDILib45Mnzje8pxJjvbq6ZltaFdWLS+XAbBg3YOq9wZj+BhkRZs4EQGEnBFWhvHS5sRFWKz1MmhE+NwXwdcNQiGKpXLPUg0SxVQrFknU/93b1MQ/63kwtf1dUE+3lnhbYBIwMJF43xrU5B5euLDf0RhkZcrDvzq1N+3Btyun9uO4HNyAZhLvuduHt5w6Yg5NBmTomshnBSsXu1jM85GBi5mhgkPjBp041VMTS194ILXYycPjdGMVSOXAg89DaNZFiEVT8475WKJZgLusJxnWFhD0luIJng38IhpeotUVVtrrFTWHFSy4VS1F3soI33lwOHBgyOZY3tjmgr70eCjsZOGyDflFiEeTn3ntoAXsPLdSsbcVVAR0Zcoyj5PzHnBzLY2QoeFvXinWyjYKaWT0OUK0WLVcUuYDhFjkni196x4ZIUbYRTJueMjayPjLkYE2mumYv7g1mdr5gvBkF3ehaGRXY71DYycBha925LgGTMIQV4XhRVEVr/hO3Y2Hf7fjBzB1GgfcK1B3venvgNju3jGJyLI/1axs9qW6w1DuYulSuwMkIRoacWtuAe27O41vnX48U5bAng9n5Ah58+lTo++PwxpvLKJWDw72FYgkPPnUq8AYhaHzCaXVUYL9DYSc9T9KWl40bI8wl4K7JNj8cqGaAeNe9/66tgc23dm4ZrX3WJ18M7or45VM/AgBjgDUo26RcUQytXYPvz9yB49O3BfZnCSLMJfQHz75k7Te3wW+p+zHdhBR2Az6SnKva61DYSU/TDssrKMPCb9GuX7vG6BIA4qUFunjfE9R8656b83j2ZKH2WU3GdHFV0ONmpXiDsa34pGfnCxj75Au4vBwnmbJ9uPn3XqLy3tPupqGwk56mHZZXkKge/PB2zH/i9ppFa7KGXWFoRhgLq9WkQHDfE1sr2sXUctfkh/a+bntT+OhTC3Vr3nHgBew9tBDracUlKCaQBEH592E1CYPgpmG6I+lp2lVxGDWUImqGaDOpjEB1VujcKxcaKkTDqlD9uMFRU/+YvYcWAt+nuGrZFpeuWB2rotVYgrv2ZitNR4YcvPFmY51AUvjPh7CmaYPQnoDCTnqaZoc0xyWqPB+oWrw7t4wCqArH1NOnIv3CfkrlFXz+xVcb/MVx9nJxqYx3PPQVrKjWGm95S/BNwg4ADz33Et4sV2Ifr5X2AUDVrRT3u4qDW4FqMyrwAcP3k6aUSQo76WmabVcbh6DOgc+eLODdG6/F11++UBNBBfDsyQLGb9yAybE8Djx/pimXRNyRc27/lKB9uP/1FhSFYco6iaKZpxMvRYNrKy5DTgYKabjJuBWoQPSowDBjIS2tgeljJz2NacJPkheb6dH869+70CCoXv9+sQlRbwbb20C/Zn2I778mck4W/+XudxmLomw/vyk2ETSir19977TYSc/TzJBmEw/Pnq65QrIiuPeWG8w9yw2KWiiWsGn6cCLrSZpWLetuoKgKa5irx+tq
mp0vtFSBanLTtNP33uknAQo76WmSvCDu+9w3cPzlC7WfV1TxxInzWL82i0tXmvcfk9aJ8t+7fehdt5kJ29hLkLHQLt97N4aEUNhJz2J7QdiI/+x8oU7UvVDUk2X92iyuLFfqgqU5J4t1TqapmIS3tUJYEDcs9mJzjrQrUN+NLBwKO7GiG0ElmwsiTPzdfbxWLCFj0agqiLVZwZWABmHEzPDQ2pprw3u+zL1yAU+cCK6mBYKDxE5WsO/OrbWfw6xn28HaSc+VjaIbQ0KSmHl6A4D/CeDnUP27PK6qn211v6R36Na8SZsLwiT+/pS/uJkoLhT1+Lh5+dcN5/DY7h21c8TNhzfh/6aD2iaH1Q+4gdM47QWSmCsbRadSdr0kYbEvA3hQVb8lIj8D4KSIfFVV/yGBfZMeoFsFHTYXRJpyj9OEv6c7ED0L1c9PSo0FTWE94E0GRxyLOclAvUsnUnb9tJzuqKo/UtVvrf7/TwF8B0D/JX4SI92aN2lKS/NeEO20ekjruAZAlLUehGk60rqANsT+43lJeuRhXDqRsusnUR+7iGxCdbD1i0nul3SXTj9Kev35w0MOBIql1cKaa9bUX9RxpviQ7tBKCqb3ydDvEjQRp71Ap2jHk0AYiQm7iLwFwLMA9qrqTwJ+vwfAHgDYuHFjUoclHaAdF4YpGOu/eP2P78VSue5x2+sX7YUc7oxU+6uQ5HD/rrZtDYZ9A0ra5TvvZUSbDCrV7UTEAfBlAEdU9dNR24+Pj+vc3FzLxyWdI8msmCDLK+dk8ejd26wF2p1kH7Vfkg7uv3VjaEaNFycjOPjh7akUbhE5qarjkdu1KuwiIgD+CsAFVd1r8x4K+2DjzgP1kx/O4bXVcu4oBMD3Z+5oeN29AfWC9U6SIygVMoygG38asBX2JHrFTAD4DQC3icjC6r/3J7BfklLCgrG2fnvTdpNjeRyfvg0T79jQ9PpI7xHX/Bz0bKkksmL+XlVFVd+lqjtW/30licWRdBKWpRCUCeMnqsJw7JMvGKtMSbqwmR07iLDylHScsGCst5jFlPfsZsbY9FAn6SUrgmKp3OCm6XTGSy+SSPA0LvSxE9v+Lq6/3H/xOpnqgAlmoKSb9Wuz+LV35yNv2O75kU95xkvHgqfNQGHvfzrZO8YUbCXpIb/6xHXs7GLgOeU93zKrU6P8ZEVQUU11OiOFnbSNsHTFpC+m2flC6Kg30v/4M1hm5wt1rrjhnIP9d13tGXPT9OHIYGq7zsdu08msGDJghPWOSZKo3tuk//H6w93A995DC3XxlWKpjKmnT9VaC9gERvt1mlRSMHhKYpNU7xivDz1oKHOrA5RJbxLkD48qLitXtNZawLaNxCCnPNJiJ7ExWUyKqj/cZkakeyG7vnP/UOaHZ0/Tr55SXFF33S8TM0ex99CCtVD7m2oFzT8FBjvlkRY7CSQsONpM61T//i5dXjZeyKXyinX5OOlPXiuWYreA8Aq1t0+QKeYzyCmPDJ6SBmyCo1Gl+16LbHa+gKmnT9WNSiODTX5VpJt5KvMHU4HuTPjqBsyKIU0T1svF339j0/Rh435+sNrLZceBF1AsxZ91SfqT4ZyDy8srKK22WvbjGgkPHFqI3SrAJc2NvsJgVgxpmjjBUZN/0/s6RX2wuHRluUHUM6ung3fIRJgPPD+cw2d276hZ9n7cYCoJhj520oDNYA330dc0S7TZGaOk/ykHzIl9+7WNT3um1hJel98DITUMg5z1EgWFnTQQNVjDJuiVH87VxJ+QQrGEiZmjgT7wMN942PDqQc56iYLCPuCEBZ3CXg8T9ZyTxc4toxx6QWoIrgZK/ZlTYX7yqV2bMfXMqcCngJ1bRtuy1jRAYR9g/Ja37QUX9gjsFp2wuGjwCBsL6H/ZO8s0DPf3H//CaVy6Un8+PXuygPEbNwxcANUGBk8HGNvWALPzBUzMHMVN04cxMXMU1xp6YHuh/3PwiJvNanuOTI7lMTy0tuH1QW8bEAYt9gHGJvsl
yKoPo1AsYe+hBQTnyhBylTg+8qTaWAwKiVjsIvKXIvJjEfl2EvsjnSFskpFLsy4V5sSQMPyVof6nQn9bCptzlVwlKVfM/wDwvoT2RRJkdr6AHQdewKbpw9g0fRhjn3yhdtEEjaHzX3Ds10KSxpvLDtT3DVJcjfV4xd3mXCVXScQVo6p/JyKbktgXSY6gUv6LS2VMPXMKAOqyXNwOi36/Zdzp8IRE4c9nD4v1uOeoKVMLgDGNcpChjz3FHDxyLrA/S3nlagtU9yLw+9GnnjmF5YpaiTrFn9gSVEka5T/3p+Q+tntHYKtfUwO6QaRjWTEiskdE5kRkbnFxsVOHHWjCAkve3wVZTOUVhU3xqJMV/NI7NsDQWYCQGgIEuk7C/OdhbppODXzpRzom7Kr6uKqOq+r46CgLCzpBWGDJ+7tWMguWVxTHX75gdRMgg40i2JIO85+HiTczZcwwjz3FTO3aDCfTaEo7WamznIaHovPSTVDPiS2mhl7+wRne4GqYeDNTxkwiPnYR+TyAXwHwVhH5IYB9qvoXSeybNI9rHe3/0plah8WRIQf77txal5HwOrsvEkvcEYZxicpgMVU6hzWki+ppNMgklRVzbxL7Ia0T1PtlYd/txu2YzkhsyYrgUx/Zjr0hHReDyLeQrRIm3jZNxAYVDtroUZqZCGMz+ci0HSFJE3TuucQ5vwdlOpINnKDUx9gKtB/byUem7QhJivxwDju3jOLY2cUGQW72/Cb2ws489h7EpmAjCNssAWYNkHbhCjTQWBvh5pg3e34Te5gV04M0m8ZlygbIOVf/zLPzBWSYdE7agDebhWmK3YUWew9iM5ouiKldm/HRpxYa2qculSt45x/+DZbKFVaJkkQRoFYJ6sXk6isUS8jHGL1Iv3pz0GLvQZpteDQ5ljeK9tLqcGGKOkmS+27d2BCYn5g5GvqenVtGQ89vm6ZgJBwKew8SVrARBStASacYzjl4ZHJb7efZ+QKmnjkVGZg/dnYx9Pxmq4DWoSumRwkbTRf2mNpsAQkhcXAygv13ba177cDzZwJnk/opFEtNjV6kD94eWuw9SNjQgajH1HtvuaFLqyaDRLmiOPD8mbpz8+KSXQVzNiJ4z1YBrUNh7zGihDvqMdX7aExIO7m4VG7K9x31RMmhGq1DYe8xooTb5jHV1GyJkKQplVdw4PkzkQFTL1HnZysxJlKFPvYeI0q4h4ecwEdeBbBp+jCyIrj1X4zgwqUrbBlAOsLFpbK1G8bW8g7zwZNoaLH3GGF+xIdnT+ONN5dD37+i1f7o7954bZ3FQ0inyQ/ncP+tG2l5dwFa7D3G1K7NDXNKgapF/sSJ89b7OfG9i3j50fcDqPrt43bkI4ONAPj5f74e//fHl5p+v3+2KekctNh7jMmxPN6yrvX7rTdAxfxfEhcF8L3FpdBtRoYcmPJbmMHSXSjsPUjR0l8ZhZutwPxf0gxh2Svr12bxZrkSWMnMDJbuQ1dMD2IKkAKI1evlY8+9hINHzrGNAGmKsGK3S1eCA/NZkUA/Onu/dJZELHYReZ+InBOR74rIdBL7TCthxUcuYWm+v/SODdbHWipX2HedNEXOyeLeW24wulpMVFQDRZ29XzpLy8IuIlkAfwbgVwG8E8C9IvLOVvebRmxP8LAZpN86/3qbV0kGFVfE3eyV8Rs31LV8tiHIt87eL50nCYv9FwF8V1W/p6pXAPw1gA8msN/UYXuChwWemJtO2oWiftrWQ8+drnUFtcHkW2fvl86ThLDnAbzq+fmHq68RH7YneFBJNSGdwD0Xg4wQPyNDjlWOOnu/dJ6OBU9FZA+APQCwcePGTh22pzAN0MiIYHa+UHdRXLMmU7uwRNiOl7RGVgQ/s24NiiFuPgC4NudYzcTNOVnsu3OrVQB0atfmwBmnzJxpH0kIewGAt6Xg9auv1aGqjwN4HKgOs07guH1H0AkOVNPK3HmQc69cwJMnztdlslDUSausqEaKupMRXLqyHLldPmZWi7fP
OrNiOoNoi6ohImsA/COA96Aq6N8E8Ouqesb0nvHxcZ2bm2vpuP3K7HwBDz51KjCNjGPrSLfIiuBnc2tCe744WcHBD22nIHcRETmpquNR27Vssavqsoj8DoAjALIA/jJM1AeNoPzdiuFmSlEn3aKiGlkYt37tGop6n5CIj11VvwLgK0nsK0246Y2u68VNb7w250Q+7hLSSdxAZphvPSwNl/QWrDy1pJnKOVN64zong5yTZeoi6Qm8gcywZnHMYukfKOwWmCxv4GpgyC/8O7eMGq0f297VhLSbrAjuuflq7/MDz58JPD8FYBZLH8EmYBZEFRYFVZTGabFLSLdYUcUTJ85j7JMvYHa+gH13bm2ooRAA9926kf71PoLCbkFUYZFNMQchnSKbidvh5er8UgANY+ke272Ds3T7DLpiLDAVFrk+R5ZGk14hv+oGPPTNV1Feqc+zyjkZlEJaBLhPocenb6N13udQ2H24vvJCsVRrWzqcc+Bkpe5C8QacTMJPSCdx+7xMzBxtEHUA2LD+GgDhmS80UtIBXTEevL5y4OqggWKpDOjViTH+vhjs7UJ6gUuXlzE7Xwh1HUadq8x8SQe02D2E+crLFcXQ2jWY/8TtDb8LKpne9M9yOP7yhcB9rV+bxdKVFWRCBhkQEpdiqRxaJ5GRqu/90bu3Yf+XzjRsw/4t6YHC7iHqMdT9vSmn3euXnJg5atzP8NBanPnkbbjvc98wij8hzRBWJ+H2JHr07m1Y2Hc7pxqlGAq7hyhf+XXDOaucdiD8JlEoljA7X8DXKerEgJMBINLgK7fpJ1RcKuOx3TsCexK5AVLXEKGQpxP62D2E+R/dx1TbYRnX5hzjcQTVQhA6YYiJcgWAVt12XmzOmeuGc5gcyxt7EjFAmn4o7Ku4j6Wl8gqyq75I97/eYKntsAwJSSVWsPqURFOuaOjQaAANM0n92VpBMECafuiKQWPLgBVV5Jxs4ESYqJx2l6hOeYS0QkUVP5i5I9RPzgEXg8vACrv3ggjKTvH6Ir3YXizMbSftxDUkwvzkHHAxuAyksAdZ6EEEuV3ci8KbLrYuYJL71K7NmHrmVGChCCGtEMfqZoB0MBlIH7ttb5cwX+Tl5aul2W6fjdn5+omAKxR10ga83RgJCWIghd0mK8BkFc3OF/DRpxYiM2MOHjkHc1cOQprn0DdfbTAiCPHSkrCLyIdF5IyIVEQkcg5frxCVFTCccwIDp7PzBUw9cwoVgyHuvWEwpYy0i/KKNqTXEuKlVYv92wDuBvB3CaylY0T1y1h/TfBsx4NHzoX6zL03DKaUkXZCw4GE0VLwVFW/AwASlrTdg7iibRoDZpur7sfrupnatTl0zBghreAaDq20BWBLgfQykFkxQFXc3fa8fsIKO0wpjE6muk/vxUJIO8hmBFO7Nlu3twiilfeS3ifSFSMiXxORbwf8+2CcA4nIHhGZE5G5xcXF5lecIDu3jMZ6PSzFrFwBHp49XTcij5AwRoaC206YXndxL1rb9hZBtPJe0vtECruqvldVfyHg3xfjHEhVH1fVcVUdHx0NFs5Oc+xs8A3G9HqUJfPEifMckUesyA/nAueL5pws9t25FfffurGhXYBLuaLGp03Azv/erLuR9AcDme7owpObdAM3lXZyLN8wX9TNxnpkchse273DuI9CsWQUfpvAPfvIpJuWfOwi8msA/iuAUQCHRWRBVXclsrIQkgr62PZ98ZIRGNMdCYki7ztfo1oCmCzzrGFIiyDcZejCPjLppiWLXVW/oKrXq+o1qvpznRJ1rx/bDfo0U7ARlPYYdXL/+i0bYx+HDBZhOWJxjRDTOWpqg6GwC36GPS2Q/qfvsmLCgj7ek9LGqrft++LlkclteOLE+SQ+CkkhOSeDy8sVmCYePvjUKTxwaCHWk+Y6J1M754dzDvbftdVoyedjuFLYRya99J2P3cYvbmvVPzx7Gh99aqFu9qOp74uXOBcPGRwyqPYQCnPVrajWzsmpZ06Fnmfu
eezt3e/2KGrmaZMMDn0n7DZBH5tUrodnT+OJE+cDL8KotK+oylUymFQQL/5SXlF87LmXjL+PejqlK4WY6DtXjE3Qx8aq//yLr4YeJywzZnIsj7lXLuDJE+fr8tXdeZSmwBYhfpbKFczOFwIFOeo8piuFmOg7i93GUrGx6qOENyrt69jZxYYiJEVV3D/1ke2hATRCvJieDpmSSJql7yx2INpSsbHqw6a9e7c1BWFNxSEKYOrpU6w8JQCqT2+f+sj2ugC9H5NlzpRE0ix9KexR2IwEG1qbNQ4KdiXf1E9j7pULoTeGMhPdCdAwN9fUFM5kgXO0HWkW0S74gsfHx3Vubq7jx/Vy0/ThUKs6A+DaIacuI4EQW/yFSEA1YO+Py5iGphMShIicVNXI2ReptNhtiBo2XQEo6qTG+tAnvEaOT9/W8Nojk9swfuMGWuCk7QyssHPYNInDlZUK7r91I549WWip0VsnM1nYb31w6busmESx0HTmqxOgmnN+7OxiXUZWGBMzR3HT9GFMzBztynzSJFtvkP5jYIX94JFzkUHOkSGndiET8lqxhMmxPI5P34bvz9yB4Zy5b3q3BZX91gebgRX2qNa8TlZwx7vejv1fOhPqiyfpID+cw8Q7NoRu489e2X/XVjiZ6IqFbggqW1IPNgPrYw8LnuaHc9i5ZRSH/s+rTF1MMSLAfbdsxCOT22qv3fe5b+D4yxcatnWy0pA/7k1HjLr5d1pQm2lJTdLDwFrspiZKn9m9A8enb8Oxs4vWos4q0/5EFXj2ZAGz89V/EzNH8fWXL2BkyEHO0+VzZMjBwQ9tDww8uq6ZqHOg04LKJmGDzcBa7FHFH7YWlpuHDFRbsrJHTO9g07OnVF7BgefP4M1ypeaTvrhUrt3kbbNIwp4AuyGoLG4abPq2QMlN5SoUS7ULOKgoxPS+qJN9x4EXjCXgXoacDErlCq5bdd+0mg5H4pPNCNZmBaVytaXtyJCDfXduDWzUFof8cC4wHz0If5Wyi9s/nYJKkqAjBUoichDAnQCuAHgZwG+parGVfdrgv4hcq8zNQACCp8iYWgT4t5+dL+Cnl5et1rK0KiaFYglPnDiPnJPBCCtWO8b6tVmoau3vAFQt7qmnTwFildFqJI5fnBZ+pqdZAAAOxklEQVQy6SVadcV8FcBDqrosIn8M4CEAf9D6ssIJSuVyCZqmFPY+//az84WWXCpVqzHa655zsnj3xmvx9ZcvsGFYCyxdWQn8/pIIesf1i7ONLukVWhJ2VX3B8+MJAB9qbTl2RFlSrxVLgS4Xkw+0UCxhYuZoNRPmm6+27Ce3ccWsczI489pPKeot0q7vj4FG0s8k5mMXkecBHFLVJ6K2bdXHPjFzNDS9bGTIqQuGAYCTEaYupoiw7pphmAKqI0MOhtauoRuF9DSJ+dhF5GsA3hbwq4+r6hdXt/k4gGUAT4bsZw+APQCwcePGqMOGEtSn2sub5ZVaIM2lW6IuAuNg40EmmxGsNPk3yUc0cAOqN3II6noB5Zws7rk53xDgzjlZ7LuTAU6SHiKFXVXfG/Z7EflNAB8A8B4NMf9V9XEAjwNViz3eMutxL0CTL9wv6t2Eoh5MBoBN7pBrYWdFcO8tN+CRyW2RcRA3EwUIDmaywyJJOy25YkTkfQA+DeCXVXXR9n1J9WOP6qneb9hYooNEULqhKa0QqLpn7ru1vpKUkDRh64pptfL0TwH8DICvisiCiPy3FvcXi7hZC062uRrRz+ze0fbq0jg504OAKXhpyojKiuCx3Tso6oSgRWFX1Z9X1RtUdcfqv99OamE2TO3aHEusD35ou1XLVT+TY/m2loQLgJ1bRjExc7Rtx+gHsiLGAeUupoyoiirdKYSs0tctBSbH8qFDgr0M55y6POOozBovm6YPY2TIQQbVyUpJowAOffPVgRn6sX5tFhVFQwDTZkScqXQ/I4LZ+QLFnRCkoAnY6xaiDlSzU7zEzVG+uFRGNit1zaGSZFBEHQCuLFdw
z8352tNTmIXuJ6i5FVCtPuYgCUKq9L2w27pILi6V66bZTI7lQwclBFFeUWxYfw0Hb7RIuVKdRuQOrDg+fZu1pT05lsejd29D1n+nBgdJEOLS98JusuCCKBRLmHr6VE3c99+1Nfbou9eKJVYkJkAr/cknx/KoGLK5OEiCkBQIu2vBuY/1wzknNKBarij2f+lM3XvjWO5pHVRgMQgoUVr9Hk3vT+vfh5A49L2wA1eHHTy2ewfWX7MG5RUNfFR38QZbJ8fyWNh3Oz6ze0fNxWJ6p5uCl8bH/U9/ZAeGEogf2DwBJdGHJWqQhDs4o5sDpQnpFqkQdqB+KjuAyEZeD8+ervvZvTn8YOYOPOYRefcG4Q3wpfVx/x/+6FcjLfewX7vfUdgTUJxAaRj+JzXvfr3nQjcHShPSLfo63dFLWCvfIJ48cR7jN24IFRgB8LZr1zWUnIdNy+lX3L70Ye1bwipjvdbypSuNveydjODgh4PHy9kOP/FjapNr056ZkDSTCot9dr4QW2gVCHSp2Fh7U7s298ScU9dS9bqRmqVUXsHeQwvG32dFQp9UXOE8eORcYOrmW9atCR1+Ymtd27hYTOtM65MWIX76XthdYTAR4moPvNDDrD2XybF8z/WoiZMd1Az33nIDhofCg8yvFUtG8SwaJkrZfN8utjcBBlbJoNP3wh7mgsk5Wdx3y0ajdR10oZuEqbA6vMPFxkIeGXLamvPuFTcAePTubRiJEN+4ZEVw/60bMX7jBrzxZvi4wOuGc7FFNY51bXsTiAqsEpJ2+t7HHvZ47Q3S+Ycamy70MP+5dz5qVE94t8d3qwOVbSiVV3Dg+TMYWrsGxaUyRoYcqFarcjOGwRJRCIDvz9xR+3li5mhoT3vv9zn1zKk6d4yTFaOomr7vODdd/+ucP0oGnb632E2WYH44V7uQH5ncVst0iSphD3NpeK3DoPz5kSGnbv9A4w2lXVxcKtdcFBeXyri8XMFju3fgUx/ZHmi9RqU2+r/XsBtow/fp/8AhX0Ac6zrO04Cb5RS3spWQNJDYaLw4JNWPHQjuz23bUCpsn6ZAot+SDSNOo7F24LYCDso6eeDQglFvg74/02fxtxs2beeOnisUS7XhGfnVtQB21nU7/taE9BOd6sfedcLymVvZp8k3HicA1+0sjEKxhPs+9w08+NQpFIolZESwc8toZBvidQHWvK1lbfrM7hMFcLXGwBsfsLGu2/G3JiSN9L3F3ixRudNJWIedsNgFVSGOMw7QDYZGxQhcd5L7PQ17fPcmy7qZz8whI4TYkdgw6zTiF22v5egKVVgAzragJirAGoY7ePnY2cVQoXxs9w4AiHWcz7/4am3S0MEj5wL37wZk3yxXavu9uFRGzsnisd07jDe3Zj5zt59sCEkbA2mx2/qLgwiy5AXV+GDeYPm74uluF4Rbhu/2sRkZcrDvzq2YHMtbrTcsLhDEDzxxgrizY6O+J/+N79Ll5dBhKLTYCbGjIxa7iPwRgA+iOljoxwB+U1Vfa2WfJpotOw+ilcrEoFxqVxRNlr/7/7PzBRx4/gwu+op1ck4WH9j+djx78mqe/MWlcm1fQVawrB5vYuZo7bswWd9+/A3S4rZIiPqe/KX+YQOomV9OSPK0Gjw9qKrvUtUdAL4M4BMJrKmBpJs6tVKZGCVqYcMeJsfymP/E1U6S3gDgsbOLof1N3KAhgDrL3/td2Faf3nvLDXWl+ZcuLze0Os45WWMzr7gVnP71BzVWI4QkR0sWu6r+xPPjeoRmLDdP0k2dgixgW8vRxrptxmccVvE6MXO09qQyMuQ0WPzud+G6M8JcP27Ci/fzF0tlOBnByJCD4tLVwKh/O6B5C9vUsIsQkjwtB09F5D8D+HcAXgewM2S7PQD2AMDGjRtjHSPppk7NVibOzhdw6XJ4WT1QFVOvi8S/j6DA7XCAYANXXS7w/DeI11ZbHrifKT+cw84tozj80o/q9luuBBdNlSuKobVrMP+J2xv2zQpOQvqL
yOCpiHwNwNsCfvVxVf2iZ7uHAKxT1X1RB40bPG0l2JkUYX5iE3EKfYZzDi4vVwKDsjYEvT/nZHHNmkxo4NJLnOIrQkjnSaxASVXfq6q/EPDvi75NnwRwT7MLDqMXmjqZmo0N58yNvoL87aanjNdL5YbiG1tRzzlZiCDQXWUr6gC7HxKSFloKnorIv/T8+EEAZ1tbTjDdrDh0g4wmN8jrpTKOT99m7CDpF/KwwK2/v4nphuHeTLzfhaktrgn/epmdQkh6aNXHPiMim1FNd3wFwG+3vqRguhF8s3G/DA85mJg5arSu/UJuE7gNy33POVnsv2trw3dhSnUcGXLqiozcfbjFT7a+8yTTTQkh7aXVrJi2uF56hahxe05W8Maby4FBT5dLl5cxO1+wqmgFGm8mivACKBfTDWPfnVtDj2eDTaUuIaR3GMiWArZEtaqNqqgEqqmEYUVLfkwFUFGB4qgbRisCzBmihPQXFPYQTDnrrsjeNH3Yaj9+EQxza7SS2tkudxVniBLSX/R92952EpWN00wL36gq2marYm2GPDcLZ4gS0l9Q2EOIysYJEv6o+apRczubSe1MuuWCn15INyWE2ENXTARh7o0gv/bOLaN49mTBmPUS5dZopiq23T5wzhAlpL+gsLdIkPCP37jBKII2w5vj+so74QNnrxdC+gcKexsIE8FWGpCZsLlZEEIGB/rYO0yQ3/6em6u91JsNfNIHTgjxQou9C/iHb7Ra/EMfOCHEC4W9yyQV+KQPnBDiQmHvMP7iJFNzMRb/EEKahcLeQYLcLqae6wx8EkKahcHTDmLqA8MWuoSQJKGwdxCTe8Vt8tXpXvOEkHRCV0wHiWoqRgghSUCLvYMw35wQ0glosXcQ5psTQjpBIsIuIg8C+BMAo6r6T0nsM60w35wQ0m5adsWIyA0AbgdwvvXlEEIIaZUkfOyPAfh9BKdjE0II6TAtCbuIfBBAQVVPWWy7R0TmRGRucXGxlcMSQggJIdLHLiJfA/C2gF99HMDHUHXDRKKqjwN4HADGx8dp3RNCSJuIFHZVfW/Q6yKyDcBNAE6JCABcD+BbIvKLqvr/El0lIYQQa0Q1GeNZRH4AYNwmK0ZEFgG8ksiBW+etAAYhk2dQPicwOJ91UD4nwM/qcqOqjkbtoCt57DYL6xQiMqeq491eR7sZlM8JDM5nHZTPCfCzxiUxYVfVTUntixBCSPOwpQAhhKQMCvtqps4AMCifExiczzoonxPgZ41FYsFTQgghvQEtdkIISRkUdg8i8qCIqIi8tdtraQciclBEzorISyLyBREZ7vaakkRE3ici50TkuyIy3e31tAsRuUFEjonIP4jIGRH5vW6vqZ2ISFZE5kXky91eSzsRkWEReWb1Gv2OiPyrZvdFYV9lQJqZfRXAL6jquwD8I4CHuryexBCRLIA/A/CrAN4J4F4ReWd3V9U2lgE8qKrvBHArgP+Y4s8KAL8H4DvdXkQH+CyA/6WqWwBsRwufmcJ+ldQ3M1PVF1R1efXHE6hWC6eFXwTwXVX9nqpeAfDXAD7Y5TW1BVX9kap+a/X/f4qqAKSyF7SIXA/gDgB/3u21tBMRuRbAvwbwFwCgqldUtdjs/ijsiNfMLEX8ewB/0+1FJEgewKuen3+IlIqdFxHZBGAMwIvdXUnb+AyqBlel2wtpMzcBWATw31fdTn8uIuub3dnATFBKqplZrxP2OVX1i6vbfBzVx/knO7k2kiwi8hYAzwLYq6o/6fZ6kkZEPgDgx6p6UkR+pdvraTNrALwbwO+q6osi8lkA0wD+sNmdDQSD0szM9DldROQ3AXwAwHs0XbmuBQA3eH6+fvW1VCIiDqqi/qSqPtft9bSJCQB3icj7AawD8LMi8oSq3t/ldbWDHwL4oaq6T17PoCrsTcE8dh9xmpn1GyLyPgCfBvDLqpqqpvgisgbVgPB7UBX0bwL4dVU909WFtQGpWiB/BeCCqu7t9no6warF/p9U9QPdXku7EJH/DeA/
qOo5EdkPYL2qTjWzr4Gx2AkA4E8BXAPgq6tPJydU9be7u6RkUNVlEfkdAEcAZAH8ZRpFfZUJAL8B4LSILKy+9jFV/UoX10Ra53cBPCkiawF8D8BvNbsjWuyEEJIymBVDCCEpg8JOCCEpg8JOCCEpg8JOCCEpg8JOCCEpg8JOCCEpg8JOCCEpg8JOCCEp4/8DWA+PP68LqwMAAAAASUVORK5CYII=\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter plot: y along the horizontal axis, x along the vertical axis.\n",
    "plt.scatter(y, x)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 1.         -0.05047483]\n",
      " [-0.05047483  1.        ]]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(4)\n",
    "# Ten 4x4 matrices: entries (1, 0) and (3, 2) start at -1, all others at 0;\n",
    "# every entry then has an independent Uniform(0.5, 2) draw subtracted.\n",
    "mat = np.zeros([10, 4, 4])\n",
    "mat[:, 1, 0] = -1\n",
    "mat[:, 3, 2] = -1\n",
    "mat -= np.stack([np.random.uniform(0.5, 2, [4, 4]) for _ in range(10)])\n",
    "# Correlation between the (1, 0) and (3, 2) entries across the ten matrices.\n",
    "print(np.corrcoef(mat[:, 1, 0], mat[:, 3, 2]))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'synthetic_datasets'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-1-791f03340481>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m()\u001B[0m\n\u001B[1;32m      3\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mmatplotlib\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mpyplot\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mplt\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      4\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mos\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 5\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0msynthetic_datasets\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0;34m*\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m      6\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      7\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mgen_data\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'synthetic_datasets'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from synthetic_datasets import *\n",
    "\n",
    "def gen_data(seeds=(3,)):\n",
    "    \"\"\"Generate multi-domain samples from a fixed 4-variable weighted graph.\n",
    "\n",
    "    For each seed, n_datasets domains of n_samples_each samples are drawn with\n",
    "    gen_data_given_model (from synthetic_datasets) using the weight matrix W,\n",
    "    the noise scales s and per-domain intercepts.\n",
    "\n",
    "    Args:\n",
    "        seeds: iterable of seeds passed to np.random.seed; defaults to (3,),\n",
    "            the single seed the original cell hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A list with one (total_xs, total_xs_index, total_xs_one_hot, DAG_index)\n",
    "        tuple per seed. total_xs_index appends the domain index as an extra\n",
    "        column; total_xs_one_hot appends a one-hot domain encoding.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    for seed in seeds:\n",
    "        np.random.seed(seed)\n",
    "        d = 4\n",
    "        n_datasets = 10\n",
    "        n_samples_each = 1000\n",
    "        total_samples = n_datasets * n_samples_each\n",
    "        total_xs = np.zeros([total_samples, d])\n",
    "        total_xs_index = np.zeros([total_samples, d + 1])\n",
    "        total_xs_one_hot = np.zeros([total_samples, d + n_datasets])\n",
    "        # W = generate_W(d=d, prob=0.5) # 0.2\n",
    "        W = np.array([[0, 0, 0, 0], [-0.8, 0, 0, 0], [0, 2, 0, 0], [0, 0, -1, 0]])\n",
    "        # Noise scales; presumably standard deviations inside gen_data_given_model (verify).\n",
    "        s = np.ones([d])  # alternative: np.round(np.random.uniform(low=0.5, high=2, size=[d]), 1) for different variances\n",
    "        s[0] = 5\n",
    "        s[1] = 1\n",
    "        # dtype=bool works on every NumPy version; the np.bool alias is missing in NumPy 1.24-1.26.\n",
    "        index = np.zeros([d], dtype=bool)\n",
    "        index[1] = index[3] = True\n",
    "        intercept_domain_related = generate_intercept_domain_related_g_c_theta_c(d, n_domain=n_datasets, index=index)\n",
    "        index2 = np.zeros([d], dtype=bool)\n",
    "        index2[0] = True\n",
    "        intercept_domain_related2 = generate_intercept_domain_related_theta_c(d, n_domain=n_datasets, index=index2)\n",
    "        # Column 0 intercepts come from the second generator.\n",
    "        intercept_domain_related[:, 0] = intercept_domain_related2[:, 0]\n",
    "        x2_theta = intercept_domain_related[:, 1]\n",
    "        x4_theta = intercept_domain_related[:, 3]\n",
    "        x3_theta = intercept_domain_related[:, 0]  # NOTE(review): named x3 but reads column 0 -- confirm\n",
    "        print(np.corrcoef(np.vstack((x2_theta, x4_theta))))\n",
    "        print(np.corrcoef(np.vstack((x2_theta, x3_theta))))\n",
    "        for domain_index in range(n_datasets):\n",
    "            xs, b_, c_ = gen_data_given_model(W, s, intercept_domain_related[domain_index], n_samples=n_samples_each, noise_type='gaussian')\n",
    "            xs_index = np.concatenate((xs, domain_index*np.ones([n_samples_each, 1])), axis=1)\n",
    "            one_hot = np.zeros([n_samples_each, n_datasets])\n",
    "            one_hot[:, domain_index] = 1\n",
    "            xs_one_hot = np.concatenate((xs, one_hot), axis=1)\n",
    "            rows = slice(domain_index * n_samples_each, (domain_index + 1) * n_samples_each)\n",
    "            total_xs[rows, :] = xs\n",
    "            total_xs_index[rows, :] = xs_index\n",
    "            total_xs_one_hot[rows, :] = xs_one_hot\n",
    "        # Mirrors W, plus a domain-index node (last row/column) that points into\n",
    "        # variables 0, 1 and 3, whose intercepts come from the generators above.\n",
    "        # NOTE(review): entry [2, 1] is 5 here while W[2, 1] is 2 -- confirm which is intended.\n",
    "        DAG_index = np.array([[0, 0, 0, 0, 1], [-0.8, 0, 0, 0, 1], [0, 5, 0, 0, 0], [0, 0, -1, 0, 1], [0, 0, 0, 0, 0]])\n",
    "        results.append((total_xs, total_xs_index, total_xs_one_hot, DAG_index))\n",
    "    return results"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from synthetic_datasets import *\n",
    "# exp3: multiple datasets that share the chain x1 -> x2 -> x3 -> x4 but use a\n",
    "# different functional form for some of the edges in each dataset.\n",
    "import time\n",
    "seeds = [4]\n",
    "for seed in seeds:\n",
    "    np.random.seed(seed)\n",
    "    d = 4\n",
    "    n_datasets = 6\n",
    "    n_samples_each = 200\n",
    "    total_samples = n_datasets * n_samples_each\n",
    "\n",
    "    # Noise means per dataset (rows) and variable (columns); column 2 stays 0.\n",
    "    c = np.zeros([n_datasets, d])\n",
    "    # c[:, 0] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "    # c[:, 1] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "    # c[:, 3] = np.random.uniform(0.5, 10, [n_datasets])\n",
    "\n",
    "    for i in range(n_datasets):\n",
    "        # Alternating sign: 0, -2, 4, -6, 8, -10. The parentheses are required:\n",
    "        # -1**i parses as -(1**i), which is -1 for every i.\n",
    "        c[i, 0] = c[i, 1] = c[i, 3] = ((-1)**i)*i*2\n",
    "\n",
    "    # Noise scale passed to np.random.normal: 5 for every dataset and variable.\n",
    "    s = 5*np.ones([n_datasets, d])\n",
    "\n",
    "    # data 0: all edges linear\n",
    "    x0_0 = np.random.normal(c[0, 0], s[0, 0], [n_samples_each, 1])\n",
    "    x0_1 = 10*x0_0 + np.random.normal(c[0, 1], s[0, 1], [n_samples_each, 1])\n",
    "    x0_2 = 10*x0_1 + np.random.normal(c[0, 2], s[0, 2], [n_samples_each, 1])\n",
    "    x0_3 = 10*x0_2 + np.random.normal(c[0, 3], s[0, 3], [n_samples_each, 1])\n",
    "    total_xs_0 = np.concatenate((x0_0, x0_1, x0_2, x0_3), axis=1)\n",
    "    # data 1: quadratic first and last edges\n",
    "    x1_0 = np.random.normal(c[1, 0], s[1, 0], [n_samples_each, 1])\n",
    "    x1_1 = 10*(x1_0**2)  + np.random.normal(c[1, 1], s[1, 1], [n_samples_each, 1])\n",
    "    x1_2 = 10*x1_1 + np.random.normal(c[1, 2], s[1, 2], [n_samples_each, 1])\n",
    "    x1_3 = 10*(x1_2**2)  + np.random.normal(c[1, 3], s[1, 3], [n_samples_each, 1])\n",
    "    total_xs_1 = np.concatenate((x1_0, x1_1, x1_2, x1_3), axis=1)\n",
    "    # data 2: sine first and last edges\n",
    "    x2_0 = np.random.normal(c[2, 0], s[2, 0], [n_samples_each, 1])\n",
    "    x2_1 = 10*np.sin(x2_0)  + np.random.normal(c[2, 1], s[2, 1], [n_samples_each, 1])\n",
    "    x2_2 = 10*x2_1 + np.random.normal(c[2, 2], s[2, 2], [n_samples_each, 1])\n",
    "    x2_3 = 10*np.sin(x2_2)  + np.random.normal(c[2, 3], s[2, 3], [n_samples_each, 1])\n",
    "    total_xs_2 = np.concatenate((x2_0, x2_1, x2_2, x2_3), axis=1)\n",
    "\n",
    "    # data 3: cosine first and last edges\n",
    "    x3_0 = np.random.normal(c[3, 0], s[3, 0], [n_samples_each, 1])\n",
    "    x3_1 = 10*np.cos(x3_0)  + np.random.normal(c[3, 1], s[3, 1], [n_samples_each, 1])\n",
    "    x3_2 = 10*x3_1 + np.random.normal(c[3, 2], s[3, 2], [n_samples_each, 1])\n",
    "    x3_3 = 10*np.cos(x3_2)  + np.random.normal(c[3, 3], s[3, 3], [n_samples_each, 1])\n",
    "    total_xs_3 = np.concatenate((x3_0, x3_1, x3_2, x3_3), axis=1)\n",
    "\n",
    "    # data 4: tangent first and last edges\n",
    "    x4_0 = np.random.normal(c[4, 0], s[4, 0], [n_samples_each, 1])\n",
    "    x4_1 = 10*np.tan(x4_0)  + np.random.normal(c[4, 1], s[4, 1], [n_samples_each, 1])\n",
    "    x4_2 = 10*x4_1 + np.random.normal(c[4, 2], s[4, 2], [n_samples_each, 1])\n",
    "    x4_3 = 10*np.tan(x4_2)  + np.random.normal(c[4, 3], s[4, 3], [n_samples_each, 1])\n",
    "    total_xs_4 = np.concatenate((x4_0, x4_1, x4_2, x4_3), axis=1)\n",
    "\n",
    "    # data 5: sine first and last edges with amplitude 5\n",
    "    x5_0 = np.random.normal(c[5, 0], s[5, 0], [n_samples_each, 1])\n",
    "    x5_1 = 5*np.sin(x5_0)  + np.random.normal(c[5, 1], s[5, 1], [n_samples_each, 1])\n",
    "    x5_2 = 10*x5_1 + np.random.normal(c[5, 2], s[5, 2], [n_samples_each, 1])\n",
    "    x5_3 = 5*np.sin(x5_2)  + np.random.normal(c[5, 3], s[5, 3], [n_samples_each, 1])\n",
    "    total_xs_5 = np.concatenate((x5_0, x5_1, x5_2, x5_3), axis=1)\n",
    "\n",
    "    total_xs = np.concatenate((total_xs_0, total_xs_1, total_xs_2, total_xs_3, total_xs_4, total_xs_5), axis=0)\n",
    "\n",
    "    # Ground truth: dag[child, parent] = 1 for the chain edges. dag_index adds a\n",
    "    # dataset-index node (last column) as a parent of variables 0, 1 and 3, whose\n",
    "    # noise means c[:, 0], c[:, 1] and c[:, 3] change across datasets.\n",
    "    dag = np.zeros([d, d])\n",
    "    dag_index = np.zeros([d+1, d+1])\n",
    "    dag[1, 0] = dag[2, 1] = dag[3, 2] = 1\n",
    "    dag_index[1, 0] = dag_index[2, 1] = dag_index[3, 2] = 1\n",
    "    dag_index[0, -1] = dag_index[1, -1] = dag_index[3, -1] = 1\n",
    "\n",
    "    # Append the dataset id, either as one integer-valued column or as a one-hot block.\n",
    "    total_index = np.zeros([total_samples, 1])\n",
    "    total_one_hot = np.zeros([total_samples, n_datasets])\n",
    "    for i in range(n_datasets):\n",
    "        total_index[i*n_samples_each:(i+1)*n_samples_each, 0] = i\n",
    "        total_one_hot[i*n_samples_each:(i+1)*n_samples_each, i] = 1\n",
    "    total_xs_index = np.concatenate((total_xs, total_index), axis=1)\n",
    "    total_xs_one_hot = np.concatenate((total_xs, total_one_hot), axis=1)\n",
    "\n",
    "# Same matrix as dag_index above, written out explicitly.\n",
    "graph_batch = np.array( [[0., 0., 0., 0., 1.],\n",
    "         [1., 0., 0., 0., 1.],\n",
    "        [0., 1., 0., 0., 0.],\n",
    "        [0., 0., 1., 0., 1.],\n",
    "        [0., 0., 0., 0., 0.]])\n",
    "\n",
    "# Start a MATLAB session through the MATLAB Engine API for Python\n",
    "# (the engine lives in the matlab.engine subpackage; start_matlab() must be called).\n",
    "import matlab.engine\n",
    "eng = matlab.engine.start_matlab()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}