{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pickle\n",
    "\n",
    "# Note that this data is pre-filtered by df[ df['Treatment'] == 'T2a' ]\n",
    "df = pd.read_csv('./sangiovese.csv')\n",
    "# Discard the 'Unnamed: 0' column (presumably a saved row index -- verify against the CSV).\n",
    "df = df.drop(columns=['Unnamed: 0'])\n",
    "\n",
    "# SCM specification, keyed by variable name.\n",
    "with open('./sangiovese-scm.json') as scm_file:\n",
    "    scm_json = json.load(scm_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binarise the target: y is True when GrapeW is positive.\n",
    "# The guard keeps this cell re-runnable: once GrapeW has been replaced by y\n",
    "# there is nothing left to convert (the unguarded version raised KeyError).\n",
    "if 'GrapeW' in df.columns:\n",
    "    df['y'] = df['GrapeW'] > 0\n",
    "    df = df.drop(columns='GrapeW')\n",
    "\n",
    "# Hold out 20% of the rows, then halve the hold-out into validation and test.\n",
    "train, holdout = train_test_split(df, test_size=0.2, random_state=0)\n",
    "valid, test = train_test_split(holdout, test_size=0.5, random_state=0)\n",
    "assert len(train) + len(valid) + len(test) == len(df)\n",
    "\n",
    "train.to_csv('sangiovese_data_train.csv', index=False)\n",
    "valid.to_csv('sangiovese_data_valid.csv', index=False)\n",
    "test.to_csv('sangiovese_data_test.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every feature column (all columns except the target 'y') to its position.\n",
    "name2idx = {name: i for i, name in enumerate(df.columns.drop('y'))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position to keep inside the per-treatment 'weight' / 'sd' lists\n",
    "# (presumably the 'T2a' level the data was filtered on -- TODO confirm against the JSON).\n",
    "TREATMENT_LEVEL_IDX = 2\n",
    "\n",
    "# pop(..., None) rather than del, so re-running the cell does not raise\n",
    "# KeyError once these nodes have already been removed.\n",
    "scm_json.pop('Treatment', None)\n",
    "scm_json.pop('GrapeW', None)\n",
    "\n",
    "# Drop 'Treatment' as a parent and keep only the parameters of the chosen level.\n",
    "for node in scm_json.values():\n",
    "    if 'Treatment' in node['parent']:\n",
    "        node['parent'].remove('Treatment')\n",
    "        node['weight'] = node['weight'][TREATMENT_LEVEL_IDX]\n",
    "        node['sd'] = node['sd'][TREATMENT_LEVEL_IDX]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func_gen(w):\n",
    "    \"\"\"Return a function computing the dot product of ``w`` with ``[1, *args]``.\"\"\"\n",
    "    def func(*args):\n",
    "        # The leading constant 1 makes w[0] act as the intercept term.\n",
    "        inputs = np.array([1, *args])\n",
    "        return np.array(w).dot(inputs)\n",
    "    return func\n",
    "\n",
    "# Re-key the SCM by column index instead of column name.\n",
    "# The structural function itself is not stored; it can be rebuilt with\n",
    "# func_gen(w=<node>['weight']).\n",
    "scm = {}\n",
    "for col_name in df.drop('y', axis=1).columns:\n",
    "    spec = scm_json[col_name]\n",
    "    scm[name2idx[col_name]] = {\n",
    "        'name': col_name,\n",
    "        'input': [name2idx[pa] for pa in spec['parent']],\n",
    "        'weight': spec['weight'],\n",
    "        'std': spec['sd'][0],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the index-keyed SCM dictionary.\n",
    "with open(\"scm.pickle\", \"wb\") as scm_out:\n",
    "    pickle.dump(scm, scm_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: {'name': 'SproutN', 'input': [], 'weight': [-0.1475826], 'std': 0.1361936},\n",
       " 1: {'name': 'BunchN',\n",
       "  'input': [0],\n",
       "  'weight': [-0.04485152, 1.21770061],\n",
       "  'std': 0.3242622},\n",
       " 2: {'name': 'WoodW',\n",
       "  'input': [0, 1, 3, 5, 6],\n",
       "  'weight': [-0.01059014,\n",
       "   0.20866888,\n",
       "   0.10551576,\n",
       "   1.32071982,\n",
       "   1.17073084,\n",
       "   0.69730125],\n",
       "  'std': 0.2683441},\n",
       " 3: {'name': 'SPAD06',\n",
       "  'input': [0],\n",
       "  'weight': [0.0768126, 0.497339],\n",
       "  'std': 0.09603947},\n",
       " 4: {'name': 'NDVI06',\n",
       "  'input': [0, 3],\n",
       "  'weight': [0.0009060401, 0.0524537942, 0.4391847164],\n",
       "  'std': 0.09162183},\n",
       " 5: {'name': 'SPAD08',\n",
       "  'input': [3, 4],\n",
       "  'weight': [0.003413511, 0.657232651, 0.347169697],\n",
       "  'std': 0.08702186},\n",
       " 6: {'name': 'NDVI08',\n",
       "  'input': [0, 4, 5],\n",
       "  'weight': [-0.004858778, 0.104257079, 0.13598597, 0.431695004],\n",
       "  'std': 0.1123024},\n",
       " 7: {'name': 'Acid',\n",
       "  'input': [0, 1, 3, 4, 6, 9, 11, 12],\n",
       "  'weight': [0.0009448236,\n",
       "   -0.0859213515,\n",
       "   0.0745025088,\n",
       "   -0.3684175757,\n",
       "   -0.1935515016,\n",
       "   -0.2441153549,\n",
       "   -0.6116201324,\n",
       "   -0.0684739505,\n",
       "   0.1803238117],\n",
       "  'std': 0.1208581},\n",
       " 8: {'name': 'Potass',\n",
       "  'input': [1, 3, 11],\n",
       "  'weight': [-0.005045565, -0.07180271, 0.391809727, 0.06120008],\n",
       "  'std': 0.1454156},\n",
       " 9: {'name': 'Brix',\n",
       "  'input': [11],\n",
       "  'weight': [-0.0387757, 0.0993189],\n",
       "  'std': 0.05455801},\n",
       " 10: {'name': 'pH',\n",
       "  'input': [0, 2, 3, 7, 8, 9, 11, 12],\n",
       "  'weight': [-0.0004402041,\n",
       "   0.0115191464,\n",
       "   0.0061100089,\n",
       "   0.0407614454,\n",
       "   -0.1814311205,\n",
       "   0.0569802619,\n",
       "   0.1606425478,\n",
       "   -0.0239600155,\n",
       "   0.0210687973],\n",
       "  'std': 0.01680764},\n",
       " 11: {'name': 'Anthoc',\n",
       "  'input': [1, 2, 6],\n",
       "  'weight': [0.00682654, -0.13613185, -0.33030867, -0.45297454],\n",
       "  'std': 0.3176561},\n",
       " 12: {'name': 'Polyph',\n",
       "  'input': [1, 4, 6, 9, 11],\n",
       "  'weight': [0.002074334,\n",
       "   0.055373115,\n",
       "   -0.373441738,\n",
       "   0.23435951,\n",
       "   0.286760526,\n",
       "   0.550064939],\n",
       "  'std': 0.1565978}}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the index-keyed SCM (name, parent indices, weights, std per node).\n",
    "scm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for OrdCE\n",
    "# Standardise every column of the full data set with the TRAIN-split statistics.\n",
    "train_mean = train.mean()\n",
    "train_std = train.std()\n",
    "df_norm = df.sub(train_mean).div(train_std)\n",
    "# The label must stay as it is, so restore the original column.\n",
    "df_norm['y'] = df['y']\n",
    "df_norm.to_csv('sangiovese_all_norm.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The scm for ordce should also be on the normalized scale\n",
    "# (weight_norm = weight * std_parent / std_child).\n",
    "stds = train.std().to_numpy()\n",
    "\n",
    "# Size the matrix from the SCM instead of hard-coding the number of nodes.\n",
    "n_nodes = len(scm)\n",
    "adj_matrix = np.zeros((n_nodes, n_nodes))\n",
    "for child, node in scm.items():\n",
    "    parents = node['input']\n",
    "    # weight[0] is skipped: it multiplies the constant term, not a parent.\n",
    "    coefs = np.asarray(node['weight'][1:], dtype=float)\n",
    "    # Orientation: row = child node, column = parent node.\n",
    "    adj_matrix[child][parents] = coefs * stds[parents] / stds[child]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.45405802,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.06090121,  0.08258758,  0.        ,  0.32921053,  0.        ,\n",
       "         0.33741771,  0.20435138,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.58231467,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.06716605,  0.        ,  0.        ,  0.48030256,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ,  0.        ,  0.5684226 ,  0.27455309,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.10382852,  0.        ,  0.        ,  0.        ,  0.10576265,\n",
       "         0.42455182,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [-0.07700321,  0.17906381,  0.        , -0.28199595, -0.13546634,\n",
       "         0.        , -0.21968062,  0.        ,  0.        , -0.26941677,\n",
       "         0.        , -0.17377169,  0.32460609],\n",
       "       [ 0.        , -0.17071359,  0.        ,  0.29666622,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.15363705,  0.        ],\n",
       "       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.57219318,  0.        ],\n",
       "       [ 0.03973078,  0.        ,  0.07220719,  0.12007461,  0.        ,\n",
       "         0.        ,  0.        , -0.69824987,  0.2216834 ,  0.27233441,\n",
       "         0.        , -0.2340129 ,  0.14596287],\n",
       "       [ 0.        , -0.12892675, -0.39967396,  0.        ,  0.        ,\n",
       "         0.        , -0.16062636,  0.        ,  0.        ,  0.        ,\n",
       "         0.        ,  0.        ,  0.        ],\n",
       "       [ 0.        ,  0.07393197,  0.        ,  0.        , -0.14519583,\n",
       "         0.        ,  0.117159  ,  0.        ,  0.        ,  0.07017116,\n",
       "         0.        ,  0.77546824,  0.        ]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the normalised adjacency matrix (row = child, column = parent).\n",
    "adj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'dag.png'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from lingam.utils import make_dot\n",
    "import graphviz\n",
    "\n",
    "# Node labels: every column of the train split except the last one.\n",
    "node_labels = list(train.columns[:-1])\n",
    "dot = make_dot(adj_matrix, labels=node_labels)\n",
    "\n",
    "# Render the graph to a PNG file named after 'dag'.\n",
    "dot.format = 'png'\n",
    "dot.render('dag')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the normalised adjacency matrix.\n",
    "with open(\"Sangiovese_adjM.pickle\", \"wb\") as adj_out:\n",
    "    pickle.dump(adj_matrix, adj_out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "13f752eec1c22bf877ed13574aae0c5036489a901cc325afaf3d1c548bba661a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
