{
 "cells": [
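  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup: import path\n",
    "\n",
    "Adds the parent directory to `sys.path` so the experiment cells below can import `utils` and `module.*`. Run this cell first."
   ]
  },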
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from sys import path\n",
    "\n",
    "# Make modules in the parent directory (`utils`, `module.*`) importable from the cells below.\n",
    "# Resolve to an absolute path once so imports keep working if the working directory changes,\n",
    "# and only append when missing so re-running this cell does not keep growing sys.path.\n",
    "_parent_dir = os.path.abspath(os.pardir)\n",
    "if _parent_dir not in path:\n",
    "    path.append(_parent_dir)"
   ]
  },
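  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Syn-0-4-4-2-2\n",
    "\n",
    "Section names follow `Syn-{mV}-{mZ}-{mC}-{mA}-{mU}`; this one uses mV=0, mZ=4, mC=4, mA=2, mU=2 with `module.SDD`. Each experiment cell runs `num_reps` replications and writes `SDD_result.csv` with one row per replication, followed by a mean row and a std row."
   ]
  },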
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils import log, CausalDataset\n",
    "from module.SDD import run as run_SynSDD\n",
    "\n",
    "# Only takes effect if CUDA has not been initialised yet in this kernel (e.g. by an earlier cell).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "# Variable groups passed to CausalDataset for every split.\n",
    "VARIABLES = ['u','x','v','z','a','p','s','m','t','g','y','f','c']\n",
    "\n",
    "def str2bool(value):\n",
    "    \"\"\"Parse a boolean option; argparse's `type=bool` would turn the string 'False' into True.\"\"\"\n",
    "    if isinstance(value, bool):\n",
    "        return value\n",
    "    lowered = value.strip().lower()\n",
    "    if lowered in ('1', 'true', 't', 'yes', 'y'):\n",
    "        return True\n",
    "    if lowered in ('0', 'false', 'f', 'no', 'n'):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')\n",
    "\n",
    "def get_args(argv=None):\n",
    "    \"\"\"Return the experiment configuration for this section.\n",
    "\n",
    "    With `argv=None` no command-line arguments are parsed (the kernel's own argv is ignored),\n",
    "    so the defaults below define the experiment; pass a list of CLI-style strings to override them.\n",
    "    \"\"\"\n",
    "    argparser = argparse.ArgumentParser(description=__doc__)\n",
    "    # About run setting !!!!\n",
    "    argparser.add_argument('--seed',default=2022,type=int,help='The random seed')\n",
    "    argparser.add_argument('--mode',default='vx',type=str,help='The choice of v/x/vx/xx')\n",
    "    argparser.add_argument('--rewrite_log',default=False,type=str2bool,help='Whether rewrite log file')\n",
    "    argparser.add_argument('--use_gpu',default=True,type=str2bool,help='The use of GPU')\n",
    "    # About data setting ~~~~\n",
    "    argparser.add_argument('--num',default=10000,type=int,help='The num of train/val/test dataset')\n",
    "    argparser.add_argument('--num_reps',default=10,type=int,help='The number of replicated datasets to run')\n",
    "    argparser.add_argument('--ate',default=0,type=float,help='The ate of constant')\n",
    "    argparser.add_argument('--by',default=1,type=float,help='if the outcome is binary')\n",
    "    argparser.add_argument('--sc',default=1,type=float,help='The sc')\n",
    "    argparser.add_argument('--sh',default=0,type=float,help='The sh')\n",
    "    argparser.add_argument('--one',default=1,type=int,help='the coefs between VZCU when computing t')\n",
    "    argparser.add_argument('--VX',default=0,type=int,help='nonlinear relation between v and x')\n",
    "    argparser.add_argument('--mV',default=0,type=int,help='The dim of Outside Instrumental variables V')\n",
    "    argparser.add_argument('--mX',default=10,type=int,help='The dim of Observed variables X')\n",
    "    argparser.add_argument('--mZ',default=4,type=int,help='The dim of Inside Instrumental Variables Representations Z')\n",
    "    argparser.add_argument('--mC',default=4,type=int,help='The dim of Confounder Variables Representations C')\n",
    "    argparser.add_argument('--mA',default=2,type=int,help='The dim of Adjustable Variables Representations A')\n",
    "    argparser.add_argument('--mU',default=2,type=int,help='The dim of Unobserved confounding variables U')\n",
    "    argparser.add_argument('--storage_path',default='../Data/',type=str,help='The dir of data storage')\n",
    "    # Syn\n",
    "    argparser.add_argument('--syn_alpha',default=1,type=float,help='The ratio of treatment prediction loss')\n",
    "    argparser.add_argument('--syn_beta',default=1,type=float,help='The ratio of disentangle loss')\n",
    "    argparser.add_argument('--syn_gamma',default=1,type=float,help='The ratio of imb loss')\n",
    "    argparser.add_argument('--syn_lambda',default=0.0001,type=float,help='The ratio of regularization loss')\n",
    "    argparser.add_argument('--syn_twoStage',default=True,type=str2bool,help='whether use twostage method')\n",
    "    argparser.add_argument('--lrate',default=5e-4,type=float,help='lrate')\n",
    "    argparser.add_argument('--iteration',default=2000,type=int,help='iteration')\n",
    "    argparser.add_argument('--output-delay',default=100,type=int,help='output-delay')\n",
    "    # About Debug or Show\n",
    "    argparser.add_argument('--verbose',default=1,type=int,help='The level of verbose')\n",
    "    argparser.add_argument('--epoch_show',default=5,type=int,help='The epochs of show time')\n",
    "    return argparser.parse_args(args=[] if argv is None else argv)\n",
    "\n",
    "args = get_args()\n",
    "\n",
    "# Fall back to CPU when the GPU is disabled or unavailable, so `device` is always defined.\n",
    "device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# set path\n",
    "which_benchmark = 'Syn_'+'_'.join(str(item) for item in [args.by,args.sc, args.sh, args.one, args.mV,args.mX,args.mU,args.VX])\n",
    "which_dataset = '_'.join(str(item) for item in [args.mZ, args.mC, args.mA])\n",
    "# The trailing '' keeps a trailing separator, as before, in case run_SynSDD appends file names directly.\n",
    "resultDir = os.path.join(args.storage_path, 'results', f'{which_benchmark}_{which_dataset}', '')\n",
    "dataDir = os.path.join(args.storage_path, 'data', which_benchmark, which_dataset, '')\n",
    "os.makedirs(resultDir, exist_ok=True)\n",
    "logfile = os.path.join(resultDir, 'log_SDD.txt')\n",
    "\n",
    "if args.rewrite_log:\n",
    "    # Truncate the log file.\n",
    "    with open(logfile, 'w'):\n",
    "        pass\n",
    "\n",
    "if args.num_reps < 1:\n",
    "    raise ValueError(f'num_reps must be >= 1, got {args.num_reps}')\n",
    "\n",
    "ate_rows = []\n",
    "for exp in range(args.num_reps):\n",
    "    # load this replication's train/val/test splits\n",
    "    exp_dir = os.path.join(dataDir, str(exp))\n",
    "    train, val, test = (\n",
    "        CausalDataset(pd.read_csv(os.path.join(exp_dir, f'{split}.csv')), variables=VARIABLES)\n",
    "        for split in ('train', 'val', 'test')\n",
    "    )\n",
    "\n",
    "    obj_val, final = run_SynSDD(exp, args, dataDir, resultDir, train, val, test, device)\n",
    "    ate_rows.append([obj_val['ate_train'], obj_val['ate_test']])\n",
    "\n",
    "ate_reps = np.array(ate_rows)\n",
    "# Summarise over the replications only, so the std is not computed over the appended mean row.\n",
    "results = np.vstack([ate_reps, ate_reps.mean(axis=0), ate_reps.std(axis=0)])\n",
    "\n",
    "# Rows: one per replication, then the mean, then the std (ddof=0) across replications.\n",
    "res_df = pd.DataFrame(results, columns=[f'SDD{suffix}' for suffix in ['_train', '_test']]).round(4)\n",
    "res_df.to_csv(os.path.join(resultDir, 'SDD_result.csv'), index=False)\n",
    "res_df"
   ]
  },
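  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Syn-2-4-4-2-2\n",
    "\n",
    "mV=2, mZ=4, mC=4, mA=2, mU=2, using `module.SDD_IV` (this setting has mV=2 outside instrumental variables V)."
   ]
  },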
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils import log, CausalDataset\n",
    "from module.SDD_IV import run as run_SynSDD\n",
    "\n",
    "# Only takes effect if CUDA has not been initialised yet in this kernel (e.g. by an earlier cell).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "# Variable groups passed to CausalDataset for every split.\n",
    "VARIABLES = ['u','x','v','z','a','p','s','m','t','g','y','f','c']\n",
    "\n",
    "def str2bool(value):\n",
    "    \"\"\"Parse a boolean option; argparse's `type=bool` would turn the string 'False' into True.\"\"\"\n",
    "    if isinstance(value, bool):\n",
    "        return value\n",
    "    lowered = value.strip().lower()\n",
    "    if lowered in ('1', 'true', 't', 'yes', 'y'):\n",
    "        return True\n",
    "    if lowered in ('0', 'false', 'f', 'no', 'n'):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')\n",
    "\n",
    "def get_args(argv=None):\n",
    "    \"\"\"Return the experiment configuration for this section.\n",
    "\n",
    "    With `argv=None` no command-line arguments are parsed (the kernel's own argv is ignored),\n",
    "    so the defaults below define the experiment; pass a list of CLI-style strings to override them.\n",
    "    \"\"\"\n",
    "    argparser = argparse.ArgumentParser(description=__doc__)\n",
    "    # About run setting !!!!\n",
    "    argparser.add_argument('--seed',default=2022,type=int,help='The random seed')\n",
    "    argparser.add_argument('--mode',default='vx',type=str,help='The choice of v/x/vx/xx')\n",
    "    argparser.add_argument('--rewrite_log',default=False,type=str2bool,help='Whether rewrite log file')\n",
    "    argparser.add_argument('--use_gpu',default=True,type=str2bool,help='The use of GPU')\n",
    "    # About data setting ~~~~\n",
    "    argparser.add_argument('--num',default=10000,type=int,help='The num of train/val/test dataset')\n",
    "    argparser.add_argument('--num_reps',default=10,type=int,help='The number of replicated datasets to run')\n",
    "    argparser.add_argument('--ate',default=0,type=float,help='The ate of constant')\n",
    "    argparser.add_argument('--by',default=1,type=float,help='if the outcome is binary')\n",
    "    argparser.add_argument('--sc',default=1,type=float,help='The sc')\n",
    "    argparser.add_argument('--sh',default=0,type=float,help='The sh')\n",
    "    argparser.add_argument('--one',default=1,type=int,help='the coefs between VZCU when computing t')\n",
    "    argparser.add_argument('--VX',default=0,type=int,help='nonlinear relation between v and x')\n",
    "    argparser.add_argument('--mV',default=2,type=int,help='The dim of Outside Instrumental variables V')\n",
    "    argparser.add_argument('--mX',default=10,type=int,help='The dim of Observed variables X')\n",
    "    argparser.add_argument('--mZ',default=4,type=int,help='The dim of Inside Instrumental Variables Representations Z')\n",
    "    argparser.add_argument('--mC',default=4,type=int,help='The dim of Confounder Variables Representations C')\n",
    "    argparser.add_argument('--mA',default=2,type=int,help='The dim of Adjustable Variables Representations A')\n",
    "    argparser.add_argument('--mU',default=2,type=int,help='The dim of Unobserved confounding variables U')\n",
    "    argparser.add_argument('--storage_path',default='../Data/',type=str,help='The dir of data storage')\n",
    "    # Syn\n",
    "    argparser.add_argument('--syn_alpha',default=1,type=float,help='The ratio of treatment prediction loss')\n",
    "    argparser.add_argument('--syn_beta',default=1,type=float,help='The ratio of disentangle loss')\n",
    "    argparser.add_argument('--syn_gamma',default=1,type=float,help='The ratio of imb loss')\n",
    "    argparser.add_argument('--syn_lambda',default=0.0001,type=float,help='The ratio of regularization loss')\n",
    "    argparser.add_argument('--syn_twoStage',default=True,type=str2bool,help='whether use twostage method')\n",
    "    argparser.add_argument('--lrate',default=5e-4,type=float,help='lrate')\n",
    "    argparser.add_argument('--iteration',default=2000,type=int,help='iteration')\n",
    "    argparser.add_argument('--output-delay',default=100,type=int,help='output-delay')\n",
    "    # About Debug or Show\n",
    "    argparser.add_argument('--verbose',default=1,type=int,help='The level of verbose')\n",
    "    argparser.add_argument('--epoch_show',default=5,type=int,help='The epochs of show time')\n",
    "    return argparser.parse_args(args=[] if argv is None else argv)\n",
    "\n",
    "args = get_args()\n",
    "\n",
    "# Fall back to CPU when the GPU is disabled or unavailable, so `device` is always defined.\n",
    "device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# set path\n",
    "which_benchmark = 'Syn_'+'_'.join(str(item) for item in [args.by,args.sc, args.sh, args.one, args.mV,args.mX,args.mU,args.VX])\n",
    "which_dataset = '_'.join(str(item) for item in [args.mZ, args.mC, args.mA])\n",
    "# The trailing '' keeps a trailing separator, as before, in case run_SynSDD appends file names directly.\n",
    "resultDir = os.path.join(args.storage_path, 'results', f'{which_benchmark}_{which_dataset}', '')\n",
    "dataDir = os.path.join(args.storage_path, 'data', which_benchmark, which_dataset, '')\n",
    "os.makedirs(resultDir, exist_ok=True)\n",
    "logfile = os.path.join(resultDir, 'log_SDD.txt')\n",
    "\n",
    "if args.rewrite_log:\n",
    "    # Truncate the log file.\n",
    "    with open(logfile, 'w'):\n",
    "        pass\n",
    "\n",
    "if args.num_reps < 1:\n",
    "    raise ValueError(f'num_reps must be >= 1, got {args.num_reps}')\n",
    "\n",
    "ate_rows = []\n",
    "for exp in range(args.num_reps):\n",
    "    # load this replication's train/val/test splits\n",
    "    exp_dir = os.path.join(dataDir, str(exp))\n",
    "    train, val, test = (\n",
    "        CausalDataset(pd.read_csv(os.path.join(exp_dir, f'{split}.csv')), variables=VARIABLES)\n",
    "        for split in ('train', 'val', 'test')\n",
    "    )\n",
    "\n",
    "    obj_val, final = run_SynSDD(exp, args, dataDir, resultDir, train, val, test, device)\n",
    "    ate_rows.append([obj_val['ate_train'], obj_val['ate_test']])\n",
    "\n",
    "ate_reps = np.array(ate_rows)\n",
    "# Summarise over the replications only, so the std is not computed over the appended mean row.\n",
    "results = np.vstack([ate_reps, ate_reps.mean(axis=0), ate_reps.std(axis=0)])\n",
    "\n",
    "# Rows: one per replication, then the mean, then the std (ddof=0) across replications.\n",
    "res_df = pd.DataFrame(results, columns=[f'SDD{suffix}' for suffix in ['_train', '_test']]).round(4)\n",
    "res_df.to_csv(os.path.join(resultDir, 'SDD_result.csv'), index=False)\n",
    "res_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Syn-0-6-2-2-2\n",
    "\n",
    "mV=0, mZ=6, mC=2, mA=2, mU=2, using `module.SDD`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils import log, CausalDataset\n",
    "from module.SDD import run as run_SynSDD\n",
    "\n",
    "# Only takes effect if CUDA has not been initialised yet in this kernel (e.g. by an earlier cell).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "# Variable groups passed to CausalDataset for every split.\n",
    "VARIABLES = ['u','x','v','z','a','p','s','m','t','g','y','f','c']\n",
    "\n",
    "def str2bool(value):\n",
    "    \"\"\"Parse a boolean option; argparse's `type=bool` would turn the string 'False' into True.\"\"\"\n",
    "    if isinstance(value, bool):\n",
    "        return value\n",
    "    lowered = value.strip().lower()\n",
    "    if lowered in ('1', 'true', 't', 'yes', 'y'):\n",
    "        return True\n",
    "    if lowered in ('0', 'false', 'f', 'no', 'n'):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')\n",
    "\n",
    "def get_args(argv=None):\n",
    "    \"\"\"Return the experiment configuration for this section.\n",
    "\n",
    "    With `argv=None` no command-line arguments are parsed (the kernel's own argv is ignored),\n",
    "    so the defaults below define the experiment; pass a list of CLI-style strings to override them.\n",
    "    \"\"\"\n",
    "    argparser = argparse.ArgumentParser(description=__doc__)\n",
    "    # About run setting !!!!\n",
    "    argparser.add_argument('--seed',default=2022,type=int,help='The random seed')\n",
    "    argparser.add_argument('--mode',default='vx',type=str,help='The choice of v/x/vx/xx')\n",
    "    argparser.add_argument('--rewrite_log',default=False,type=str2bool,help='Whether rewrite log file')\n",
    "    argparser.add_argument('--use_gpu',default=True,type=str2bool,help='The use of GPU')\n",
    "    # About data setting ~~~~\n",
    "    argparser.add_argument('--num',default=10000,type=int,help='The num of train/val/test dataset')\n",
    "    argparser.add_argument('--num_reps',default=10,type=int,help='The number of replicated datasets to run')\n",
    "    argparser.add_argument('--ate',default=0,type=float,help='The ate of constant')\n",
    "    argparser.add_argument('--by',default=1,type=float,help='if the outcome is binary')\n",
    "    argparser.add_argument('--sc',default=1,type=float,help='The sc')\n",
    "    argparser.add_argument('--sh',default=0,type=float,help='The sh')\n",
    "    argparser.add_argument('--one',default=1,type=int,help='the coefs between VZCU when computing t')\n",
    "    argparser.add_argument('--VX',default=0,type=int,help='nonlinear relation between v and x')\n",
    "    argparser.add_argument('--mV',default=0,type=int,help='The dim of Outside Instrumental variables V')\n",
    "    argparser.add_argument('--mX',default=10,type=int,help='The dim of Observed variables X')\n",
    "    argparser.add_argument('--mZ',default=6,type=int,help='The dim of Inside Instrumental Variables Representations Z')\n",
    "    argparser.add_argument('--mC',default=2,type=int,help='The dim of Confounder Variables Representations C')\n",
    "    argparser.add_argument('--mA',default=2,type=int,help='The dim of Adjustable Variables Representations A')\n",
    "    argparser.add_argument('--mU',default=2,type=int,help='The dim of Unobserved confounding variables U')\n",
    "    argparser.add_argument('--storage_path',default='../Data/',type=str,help='The dir of data storage')\n",
    "    # Syn\n",
    "    argparser.add_argument('--syn_alpha',default=1,type=float,help='The ratio of treatment prediction loss')\n",
    "    argparser.add_argument('--syn_beta',default=1,type=float,help='The ratio of disentangle loss')\n",
    "    argparser.add_argument('--syn_gamma',default=1,type=float,help='The ratio of imb loss')\n",
    "    argparser.add_argument('--syn_lambda',default=0.0001,type=float,help='The ratio of regularization loss')\n",
    "    argparser.add_argument('--syn_twoStage',default=True,type=str2bool,help='whether use twostage method')\n",
    "    argparser.add_argument('--lrate',default=5e-4,type=float,help='lrate')\n",
    "    argparser.add_argument('--iteration',default=2000,type=int,help='iteration')\n",
    "    argparser.add_argument('--output-delay',default=100,type=int,help='output-delay')\n",
    "    # About Debug or Show\n",
    "    argparser.add_argument('--verbose',default=1,type=int,help='The level of verbose')\n",
    "    argparser.add_argument('--epoch_show',default=5,type=int,help='The epochs of show time')\n",
    "    return argparser.parse_args(args=[] if argv is None else argv)\n",
    "\n",
    "args = get_args()\n",
    "\n",
    "# Fall back to CPU when the GPU is disabled or unavailable, so `device` is always defined.\n",
    "device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# set path\n",
    "which_benchmark = 'Syn_'+'_'.join(str(item) for item in [args.by,args.sc, args.sh, args.one, args.mV,args.mX,args.mU,args.VX])\n",
    "which_dataset = '_'.join(str(item) for item in [args.mZ, args.mC, args.mA])\n",
    "# The trailing '' keeps a trailing separator, as before, in case run_SynSDD appends file names directly.\n",
    "resultDir = os.path.join(args.storage_path, 'results', f'{which_benchmark}_{which_dataset}', '')\n",
    "dataDir = os.path.join(args.storage_path, 'data', which_benchmark, which_dataset, '')\n",
    "os.makedirs(resultDir, exist_ok=True)\n",
    "logfile = os.path.join(resultDir, 'log_SDD.txt')\n",
    "\n",
    "if args.rewrite_log:\n",
    "    # Truncate the log file.\n",
    "    with open(logfile, 'w'):\n",
    "        pass\n",
    "\n",
    "if args.num_reps < 1:\n",
    "    raise ValueError(f'num_reps must be >= 1, got {args.num_reps}')\n",
    "\n",
    "ate_rows = []\n",
    "for exp in range(args.num_reps):\n",
    "    # load this replication's train/val/test splits\n",
    "    exp_dir = os.path.join(dataDir, str(exp))\n",
    "    train, val, test = (\n",
    "        CausalDataset(pd.read_csv(os.path.join(exp_dir, f'{split}.csv')), variables=VARIABLES)\n",
    "        for split in ('train', 'val', 'test')\n",
    "    )\n",
    "\n",
    "    obj_val, final = run_SynSDD(exp, args, dataDir, resultDir, train, val, test, device)\n",
    "    ate_rows.append([obj_val['ate_train'], obj_val['ate_test']])\n",
    "\n",
    "ate_reps = np.array(ate_rows)\n",
    "# Summarise over the replications only, so the std is not computed over the appended mean row.\n",
    "results = np.vstack([ate_reps, ate_reps.mean(axis=0), ate_reps.std(axis=0)])\n",
    "\n",
    "# Rows: one per replication, then the mean, then the std (ddof=0) across replications.\n",
    "res_df = pd.DataFrame(results, columns=[f'SDD{suffix}' for suffix in ['_train', '_test']]).round(4)\n",
    "res_df.to_csv(os.path.join(resultDir, 'SDD_result.csv'), index=False)\n",
    "res_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Syn-0-4-4-2-10\n",
    "\n",
    "mV=0, mZ=4, mC=4, mA=2, mU=10, using `module.SDD`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils import log, CausalDataset\n",
    "from module.SDD import run as run_SynSDD\n",
    "\n",
    "# Only takes effect if CUDA has not been initialised yet in this kernel (e.g. by an earlier cell).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "# Variable groups passed to CausalDataset for every split.\n",
    "VARIABLES = ['u','x','v','z','a','p','s','m','t','g','y','f','c']\n",
    "\n",
    "def str2bool(value):\n",
    "    \"\"\"Parse a boolean option; argparse's `type=bool` would turn the string 'False' into True.\"\"\"\n",
    "    if isinstance(value, bool):\n",
    "        return value\n",
    "    lowered = value.strip().lower()\n",
    "    if lowered in ('1', 'true', 't', 'yes', 'y'):\n",
    "        return True\n",
    "    if lowered in ('0', 'false', 'f', 'no', 'n'):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')\n",
    "\n",
    "def get_args(argv=None):\n",
    "    \"\"\"Return the experiment configuration for this section.\n",
    "\n",
    "    With `argv=None` no command-line arguments are parsed (the kernel's own argv is ignored),\n",
    "    so the defaults below define the experiment; pass a list of CLI-style strings to override them.\n",
    "    \"\"\"\n",
    "    argparser = argparse.ArgumentParser(description=__doc__)\n",
    "    # About run setting !!!!\n",
    "    argparser.add_argument('--seed',default=2022,type=int,help='The random seed')\n",
    "    argparser.add_argument('--mode',default='vx',type=str,help='The choice of v/x/vx/xx')\n",
    "    argparser.add_argument('--rewrite_log',default=False,type=str2bool,help='Whether rewrite log file')\n",
    "    argparser.add_argument('--use_gpu',default=True,type=str2bool,help='The use of GPU')\n",
    "    # About data setting ~~~~\n",
    "    argparser.add_argument('--num',default=10000,type=int,help='The num of train/val/test dataset')\n",
    "    argparser.add_argument('--num_reps',default=10,type=int,help='The number of replicated datasets to run')\n",
    "    argparser.add_argument('--ate',default=0,type=float,help='The ate of constant')\n",
    "    argparser.add_argument('--by',default=1,type=float,help='if the outcome is binary')\n",
    "    argparser.add_argument('--sc',default=1,type=float,help='The sc')\n",
    "    argparser.add_argument('--sh',default=0,type=float,help='The sh')\n",
    "    argparser.add_argument('--one',default=1,type=int,help='the coefs between VZCU when computing t')\n",
    "    argparser.add_argument('--VX',default=0,type=int,help='nonlinear relation between v and x')\n",
    "    argparser.add_argument('--mV',default=0,type=int,help='The dim of Outside Instrumental variables V')\n",
    "    argparser.add_argument('--mX',default=10,type=int,help='The dim of Observed variables X')\n",
    "    argparser.add_argument('--mZ',default=4,type=int,help='The dim of Inside Instrumental Variables Representations Z')\n",
    "    argparser.add_argument('--mC',default=4,type=int,help='The dim of Confounder Variables Representations C')\n",
    "    argparser.add_argument('--mA',default=2,type=int,help='The dim of Adjustable Variables Representations A')\n",
    "    argparser.add_argument('--mU',default=10,type=int,help='The dim of Unobserved confounding variables U')\n",
    "    argparser.add_argument('--storage_path',default='../Data/',type=str,help='The dir of data storage')\n",
    "    # Syn\n",
    "    argparser.add_argument('--syn_alpha',default=1,type=float,help='The ratio of treatment prediction loss')\n",
    "    argparser.add_argument('--syn_beta',default=1,type=float,help='The ratio of disentangle loss')\n",
    "    argparser.add_argument('--syn_gamma',default=1,type=float,help='The ratio of imb loss')\n",
    "    argparser.add_argument('--syn_lambda',default=0.0001,type=float,help='The ratio of regularization loss')\n",
    "    argparser.add_argument('--syn_twoStage',default=True,type=str2bool,help='whether use twostage method')\n",
    "    argparser.add_argument('--lrate',default=5e-4,type=float,help='lrate')\n",
    "    argparser.add_argument('--iteration',default=2000,type=int,help='iteration')\n",
    "    argparser.add_argument('--output-delay',default=100,type=int,help='output-delay')\n",
    "    # About Debug or Show\n",
    "    argparser.add_argument('--verbose',default=1,type=int,help='The level of verbose')\n",
    "    argparser.add_argument('--epoch_show',default=5,type=int,help='The epochs of show time')\n",
    "    return argparser.parse_args(args=[] if argv is None else argv)\n",
    "\n",
    "args = get_args()\n",
    "\n",
    "# Fall back to CPU when the GPU is disabled or unavailable, so `device` is always defined.\n",
    "device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# set path\n",
    "which_benchmark = 'Syn_'+'_'.join(str(item) for item in [args.by,args.sc, args.sh, args.one, args.mV,args.mX,args.mU,args.VX])\n",
    "which_dataset = '_'.join(str(item) for item in [args.mZ, args.mC, args.mA])\n",
    "# The trailing '' keeps a trailing separator, as before, in case run_SynSDD appends file names directly.\n",
    "resultDir = os.path.join(args.storage_path, 'results', f'{which_benchmark}_{which_dataset}', '')\n",
    "dataDir = os.path.join(args.storage_path, 'data', which_benchmark, which_dataset, '')\n",
    "os.makedirs(resultDir, exist_ok=True)\n",
    "logfile = os.path.join(resultDir, 'log_SDD.txt')\n",
    "\n",
    "if args.rewrite_log:\n",
    "    # Truncate the log file.\n",
    "    with open(logfile, 'w'):\n",
    "        pass\n",
    "\n",
    "if args.num_reps < 1:\n",
    "    raise ValueError(f'num_reps must be >= 1, got {args.num_reps}')\n",
    "\n",
    "ate_rows = []\n",
    "for exp in range(args.num_reps):\n",
    "    # load this replication's train/val/test splits\n",
    "    exp_dir = os.path.join(dataDir, str(exp))\n",
    "    train, val, test = (\n",
    "        CausalDataset(pd.read_csv(os.path.join(exp_dir, f'{split}.csv')), variables=VARIABLES)\n",
    "        for split in ('train', 'val', 'test')\n",
    "    )\n",
    "\n",
    "    obj_val, final = run_SynSDD(exp, args, dataDir, resultDir, train, val, test, device)\n",
    "    ate_rows.append([obj_val['ate_train'], obj_val['ate_test']])\n",
    "\n",
    "ate_reps = np.array(ate_rows)\n",
    "# Summarise over the replications only, so the std is not computed over the appended mean row.\n",
    "results = np.vstack([ate_reps, ate_reps.mean(axis=0), ate_reps.std(axis=0)])\n",
    "\n",
    "# Rows: one per replication, then the mean, then the std (ddof=0) across replications.\n",
    "res_df = pd.DataFrame(results, columns=[f'SDD{suffix}' for suffix in ['_train', '_test']]).round(4)\n",
    "res_df.to_csv(os.path.join(resultDir, 'SDD_result.csv'), index=False)\n",
    "res_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
