{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from utils import Config\n",
    "from simulation import Generator\n",
    "\n",
    "cfg = Config()\n",
    "for exp in range(cfg.exps):\n",
    "    print(\"Gen - exp: {}.\".format(exp))\n",
    "    cfg.seed = cfg.seed_base + cfg.seed_mul * exp\n",
    "    gen = Generator(cfg)\n",
    "    data_GT, data_train = gen.training(show=False)\n",
    "    test_GT = gen.testing(num=cfg.tnum, show=False)\n",
    "\n",
    "    data_setting = f\"{cfg.num}_{cfg.dim}_{cfg.y0_add}_{cfg.y1_add}_{cfg.noise_scale}\"\n",
    "    os.makedirs(os.path.dirname(f'./data/{data_setting}/{exp}/'), exist_ok=True)\n",
    "    np.savez(f'./data/{data_setting}/{exp}/train.npz', X=data_train[0], W=data_train[1], T=data_train[2], \n",
    "             D=data_train[3], Y=data_train[4], G=data_train[5], P=data_train[6], L=data_train[7])\n",
    "    np.savez(f'./data/{data_setting}/{exp}/valid.npz', X=data_GT[0], Lam0=data_GT[1], Lam1=data_GT[2], \n",
    "             P0=data_GT[3], P1=data_GT[4], Y0=data_GT[5], Y1=data_GT[6])\n",
    "    np.savez(f'./data/{data_setting}/{exp}/test.npz', X=test_GT[0], Lam0=test_GT[1], Lam1=test_GT[2], \n",
    "             P0=test_GT[3], P1=test_GT[4], Y0=test_GT[5], Y1=test_GT[6])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "from cfr_df import \n",
    "from utils import Config, set_seed, trainData, testData, batchData, saveResult\n",
    "\n",
    "cfg = Config()\n",
    "\n",
    "batch_size = cfg.batch_size\n",
    "epochs = cfg.iterations\n",
    "lrate = cfg.lrate\n",
    "w_decay = 0.0001\n",
    "step_put = cfg.step_put\n",
    "seed = 2023\n",
    "\n",
    "epochs_draw = [600, 800, 1000, 1200]\n",
    "fig2_flag = True\n",
    "\n",
    "set_seed(seed)\n",
    "\n",
    "if cfg.batch_flag:\n",
    "    method_name = f'{cfg.mode}-{cfg.alpha}-{cfg.beta}-{cfg.batch_size}'\n",
    "else:\n",
    "    method_name = f'{cfg.mode}-{cfg.alpha}-{cfg.beta}'\n",
    "for exp in range(cfg.exps):\n",
    "    print(f\"This is the {exp}-th experiments ---- {method_name}. \")\n",
    "    data_setting = f\"{cfg.num}_{cfg.dim}_{cfg.y0_add}_{cfg.y1_add}_{cfg.noise_scale}\"\n",
    "    os.makedirs(os.path.dirname(f'./results/{data_setting}/{method_name}/figure/'), exist_ok=True)\n",
    "    data = np.load(f'./data/{data_setting}/{exp}/train.npz')\n",
    "    ttrain = trainData(data)\n",
    "\n",
    "    X, w, t, d, y, g, pY, lam = ttrain.all()\n",
    "    trains = batchData([X, w, t, d, y, g, pY, lam], batch_size)\n",
    "\n",
    "\n",
    "    valid = np.load(f'./data/{data_setting}/{exp}/valid.npz')\n",
    "    teval = testData(valid)\n",
    "\n",
    "    test = np.load(f'./data/{data_setting}/{exp}/test.npz')\n",
    "    ttest = testData(test)\n",
    "\n",
    "    p_treated = torch.mean(w.float())\n",
    "\n",
    "    dim = X.shape[1]\n",
    "    net = Nets(dim, p_treated, cfg.mode, cfg.alpha, cfg.beta)\n",
    "    optimizer = optim.Adam(net.parameters(), lr=lrate, weight_decay=w_decay)\n",
    "\n",
    "    result_trt = saveResult()\n",
    "    result_tst = saveResult()\n",
    "    for epoch in range(epochs):\n",
    "\n",
    "        if cfg.batch_flag:\n",
    "            X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_batch()\n",
    "        else:\n",
    "            X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_all()\n",
    "\n",
    "        p, lamb, loss = net(X_b, w_b, d_b, t_b, y_b)\n",
    "        \n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if epoch % step_put == 0:\n",
    "            X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_all()\n",
    "            p, lamb, loss = net(X_b, w_b, d_b, t_b, y_b)\n",
    "\n",
    "            hat_d0, hat_y0, hat_d1, hat_y1 = net.predict(teval.X)\n",
    "            result_trt.one(epoch, loss, ((p-pY_b)**2).mean(), ((lamb*lam_b - 1)**2).mean(), hat_d0, hat_y0, \n",
    "                    hat_d1, hat_y1, teval.Lam0, teval.Y0, teval.P0, teval.Lam1, teval.Y1, teval.P1)\n",
    "            \n",
    "            hat_d0, hat_y0, hat_d1, hat_y1 = net.predict(ttest.X)\n",
    "            result_tst.one(epoch, loss, ((p-pY_b)**2).mean(), ((lamb*lam_b - 1)**2).mean(), hat_d0, hat_y0, \n",
    "                    hat_d1, hat_y1, ttest.Lam0, ttest.Y0, ttest.P0, ttest.Lam1, ttest.Y1, ttest.P1)\n",
    "            \n",
    "            print(\"Epoch-{}: {:.2f}, {:.2f}.\".format(epoch, ((p-pY_b)**2).mean(), ((lamb*lam_b - 1)**2).mean()))\n",
    "        \n",
    "        if epoch in epochs_draw and fig2_flag:\n",
    "            hat_d0, hat_y0, _, _ = net.predict(ttest.X)\n",
    "            fig2_lams = torch.cat([hat_d0, hat_y0, ttest.Lam0, ttest.Lam1, ttest.Y0, ttest.Y1], axis=1).detach().numpy()\n",
    "            fig2_path = f'./results/{data_setting}/{method_name}/figure/lam{exp}_{epoch}.npy'\n",
    "            np.save(fig2_path, fig2_lams)\n",
    "\n",
    "    \n",
    "    result_trt.full = result_trt.full.round(4)\n",
    "    result_trt.full.to_csv(f'./results/{data_setting}/{method_name}/re{exp}_trt.csv', index=False)\n",
    "\n",
    "    result_tst.full = result_tst.full.round(4)\n",
    "    result_tst.full.to_csv(f'./results/{data_setting}/{method_name}/re{exp}_tst.csv', index=False)\n",
    "    print(f\"Result save to: /results/{data_setting}/{method_name}/re{exp}_*.csv.\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
