{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa667681",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fd4ddbe6370>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import models as mdl\n",
    "import utility as util\n",
    "import constraint as cnst\n",
    "import L0 as lreg\n",
    "import torch\n",
    "from torch import optim\n",
    "from torch import nn\n",
    "from torch.autograd import Function\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from tqdm import tqdm\n",
    "import heapq\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6e016a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#lat_dim=120\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "epochs=1500#4000\n",
    "batch_size=512#512\n",
    "hid_y=100\n",
    "BCE=nn.BCELoss() #reduction='mean'\n",
    "MSE=nn.MSELoss(reduction='mean')\n",
    "MSE_2=nn.MSELoss(reduction='sum')\n",
    "beta=0.2\n",
    "threshold=0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "20519458",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_loss=[]\n",
    "tr_pehe=[]\n",
    "val_pehe=[]\n",
    "reg_loss_tr=[]\n",
    "classT_loss_tr=[]\n",
    "classW_loss_tr=[]\n",
    "wasser_loss_tr=[]\n",
    "rc_loss=[]\n",
    "OR_re_los_li=[]\n",
    "\n",
    "snr_gamma=[]\n",
    "snr_delta=[]\n",
    "snr_upsilon=[]\n",
    "snr_omega=[]\n",
    "val_reg_li=[]\n",
    "val_tloss_li=[]\n",
    "val_wasser_li=[]\n",
    "val_rec_li=[]\n",
    "kl_l=[]\n",
    "gamma_mean=[]\n",
    "gamma_var=[]\n",
    "\n",
    "lambd_li=[]\n",
    "lambd_li2=[]\n",
    "kl_total=[]\n",
    "pen_li=[]\n",
    "gamma_count=[]\n",
    "delta_count=[]\n",
    "upsilon_count=[]\n",
    "omega_count=[]\n",
    "bce_li=[]\n",
    "mse_li=[]\n",
    "reg_const=[]\n",
    "bce_const=[]\n",
    "def train_CFR(x_data,y_data,t):\n",
    "   \n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    patience = 100 # Number of epochs without improvement before stopping\n",
    "    max_saved_models = 5  # Store the best 10 models\n",
    "    \n",
    "    # Tracking best models (min-heap)\n",
    "    best_models = []\n",
    "    best_val_loss = float('inf')\n",
    "    epochs_without_improvement = 0\n",
    "    saved_model_count=0\n",
    "    tol=t\n",
    "    mse_constraint=cnst.Constraint(tol, device, lambda_min=0.0, lambda_max=200.0, lambda_init=1.0, alpha=0.99)\n",
    "    iter_num = 0\n",
    "    alpha=0.99\n",
    "    pretrain=1\n",
    "    lbd_step=30\n",
    "    #x_data,y_data=get_data('train',1)\n",
    "    X_train, X_val,y_train, y_val = train_test_split(x_data,y_data ,\n",
    "                                       random_state=42, \n",
    "                                       test_size=0.01)\n",
    "    # write input_feature instead of 25\n",
    "    net=mdl.net(input_features,hid_enc,lat_dim,.1)\n",
    "    #wnet=mdl.Wclassifier(lat_dim,.1)\n",
    "    tnet=mdl.Tclassifier(lat_dim*2,.1)\n",
    "    rnet=mdl.Regressors(lat_dim*2,hid_y,.1)\n",
    "    #lmd_net=lamd()\n",
    "    #dnet=Decoder(lat_dim*4,input_features,.01)\n",
    "    dnet=mdl.Decoder_2(lat_dim*4,input_features,.1)\n",
    "\n",
    "    sparse_model1 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model2 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model3 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model4 = lreg.SparseNet(lat_dim,1)\n",
    "\n",
    "    net.to(device)\n",
    "    tnet.to(device)\n",
    "    rnet.to(device)\n",
    "    dnet.to(device)\n",
    "    #lmd_net.to(device)\n",
    "    sparse_model1.to(device)\n",
    "    sparse_model2.to(device)\n",
    "    sparse_model3.to(device)\n",
    "    sparse_model4.to(device)\n",
    "    #print(input_features)\n",
    "    #rnet=Regressors_2(lat_dim*2,.1)\n",
    "    opt_net = torch.optim.Adam(net.parameters(), lr=0.0001,weight_decay=1e-3)#\n",
    "    opt_nett = torch.optim.Adam(tnet.parameters(), lr=0.00002, weight_decay=1e-3)\n",
    "    opt_netr = torch.optim.Adam(rnet.parameters(), lr=0.00002,weight_decay=1e-2) #0.00003\n",
    "    opt_netd = torch.optim.Adam(dnet.parameters(), lr=0.0001,weight_decay=1e-3)\n",
    "    opt_lmd = torch.optim.Adam(mse_constraint.parameters(), lr=0.001,weight_decay=1e-3)\n",
    "    #opt_lmd2 = torch.optim.Adam(bce_constraint.parameters(), lr=0.0001,weight_decay=1e-3)\n",
    "    sparse_optimizer1 = optim.Adam(sparse_model1.parameters(), lr=0.00022) #0.0002,0.00013\n",
    "    sparse_optimizer2 = optim.Adam(sparse_model2.parameters(), lr=0.00022) # 0.0002,0.00015\n",
    "    sparse_optimizer3 = optim.Adam(sparse_model3.parameters(), lr=0.00022) #0.00014,0.00015\n",
    "    sparse_optimizer4 = optim.Adam(sparse_model4.parameters(), lr=0.00022) # 0.0002,0.00018\n",
    "\n",
    "\n",
    "    mask_gamma = torch.ones(dim_range*4).to(device)\n",
    "    mask_delta = torch.ones(dim_range*4).to(device)\n",
    "    mask_upsilon = torch.ones(dim_range*4).to(device)\n",
    "    mask_omega = torch.ones(dim_range*4).to(device)\n",
    "   \n",
    "    milestones = {int(epochs * ratio): False for ratio in [0.05]} #0.4, 0.6, 0.99\n",
    "\n",
    "    for ep in tqdm(range(1,epochs+1 )):\n",
    "                             \n",
    "            #print(torch.sum(mask_gamma).item(),torch.sum(mask_delta).item(),torch.sum(mask_upsilon).item(),torch.sum(mask_omega).item())\n",
    "        train_dataloader_sr, train_dataloader_tr=util.get_dataloader(X_train,y_train,batch_size)\n",
    "        tot_los=0\n",
    "        reg_los=0\n",
    "        t_los=0\n",
    "        w_los=0\n",
    "        wa_los=0\n",
    "        rc_los=0\n",
    "        OR_re_los=0\n",
    "        cnt=0\n",
    "        kl=0\n",
    "        mse_er=0\n",
    "        pen=0\n",
    "        gc=0\n",
    "        dc=0\n",
    "        uc=0\n",
    "        oc=0\n",
    "        bce_er=0\n",
    "        mse_lamba=0\n",
    " \n",
    "        for batch_idx, (train_source_data, train_target_data) in enumerate(zip(train_dataloader_sr, train_dataloader_tr)):\n",
    "            \n",
    "            xs,ys=train_source_data\n",
    "            xt,yt=train_target_data\n",
    "            # Replace 30 with :\n",
    "            xs_train=xs[:,5:].to(device)\n",
    "            xt_train=xt[:,5:].to(device)\n",
    "\n",
    "\n",
    "            train_x=torch.cat((xs_train,xt_train),0)\n",
    "            train_y=torch.unsqueeze(torch.cat((ys,yt),0), dim=1).to(device)\n",
    "            true_t=torch.unsqueeze(torch.cat((xs[:,0],xt[:,0]),0), dim=1).to(device)\n",
    "            concat_true=torch.cat((train_y,true_t),1)\n",
    "            prop_t1=(xt_train.shape[0]/train_x.shape[0])\n",
    "\n",
    "            \n",
    "            \n",
    "           \n",
    "            #phi, phi_mean, phi_var=net(train_x,mask_gamma,mask_delta,mask_upsilon)\n",
    "            phi, phi_mean, phi_var=net(train_x,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "            \n",
    "                \n",
    "\n",
    "\n",
    "            \n",
    "            phi_gamma=phi[:,fstart:fend]\n",
    "            phi_delta=phi[:,sstart:send]\n",
    "            phi_upsilon=phi[:,tstart:tend]\n",
    "            phi_irr=phi[:,frstart:frend]\n",
    "            \n",
    "            #print(phi.shape)            \n",
    "            phi_gamma, penalty_gamma,mask_gamma_d = sparse_model1(phi)\n",
    "            phi_delta, penalty_delta,mask_delta_d = sparse_model2(phi)\n",
    "            phi_upsilon, penalty_upsilon,mask_upsilon_d = sparse_model3(phi)\n",
    "            phi_omega, penalty_omega,mask_omega_d = sparse_model4(phi)\n",
    "           \n",
    "            \n",
    "            all_mask=torch.vstack((mask_gamma_d,mask_delta_d,mask_upsilon_d,mask_omega_d))\n",
    "            \n",
    "            #soft_masks= F.softmax(all_mask/0.3, dim=0) \n",
    "\n",
    "            L_excl,soft_masks = util.exclusivity_loss(all_mask, temperature=0.1)\n",
    "            \n",
    "            \n",
    "\n",
    "        \n",
    "            mask_gamma=soft_masks[0, :]\n",
    "            mask_delta=soft_masks[1, :]\n",
    "            mask_upsilon=soft_masks[2, :]\n",
    "            mask_omega=soft_masks[3, :]\n",
    "            \n",
    "           \n",
    "            gamma_c=util.get_dim_count(mask_gamma,0.5)\n",
    "            delta_c=util.get_dim_count(mask_delta,0.5)\n",
    "            upsilon_c=util.get_dim_count(mask_upsilon,0.5)\n",
    "            omega_c=util.get_dim_count(mask_omega,0.5)\n",
    "            \n",
    "            penalty=penalty_delta+penalty_upsilon+penalty_gamma+penalty_omega\n",
    "\n",
    "            phi_all=torch.cat((phi_gamma*mask_gamma,phi_delta*mask_delta,phi_upsilon*mask_upsilon,phi_omega*mask_omega), 1)\n",
    "            del_ups=torch.cat((phi_delta*mask_delta, phi_upsilon*mask_upsilon), 1)\n",
    "            gam_del=torch.cat((phi_gamma*mask_gamma,phi_delta*mask_delta), 1)\n",
    "\n",
    "\n",
    "\n",
    "            \n",
    "            \n",
    "            concat_pred=rnet(del_ups)\n",
    "           \n",
    "            \n",
    "            #w_f=wnet(phi_delta)\n",
    "            w_t=tnet(gam_del)\n",
    "            \n",
    "            decoded_space=dnet(phi_all)\n",
    "            \n",
    "            predicted_y=torch.unsqueeze(torch.where(true_t.squeeze() == 0, concat_pred[:,0], concat_pred[:,1]),dim=1)\n",
    "            opt_net.zero_grad()\n",
    "           \n",
    "            opt_nett.zero_grad()\n",
    "            opt_netr.zero_grad()\n",
    "            opt_netd.zero_grad()\n",
    "            opt_lmd.zero_grad()\n",
    "            #opt_lmd2.zero_grad()\n",
    "\n",
    "\n",
    "\n",
    "            sparse_optimizer1.zero_grad()\n",
    "            sparse_optimizer2.zero_grad()\n",
    "            sparse_optimizer3.zero_grad()\n",
    "            sparse_optimizer4.zero_grad()\n",
    "            \n",
    "            y_LOSS,mse_cons=util.Y_GECO(predicted_y,train_y, tol)\n",
    "            kl_div_g=util.kl_loss_2(phi_mean*mask_gamma, phi_var*mask_gamma,beta)\n",
    "            kl_div_d=util.kl_loss_2(phi_mean*mask_delta, phi_var*mask_delta,beta)\n",
    "            kl_div_u=util.kl_loss_2(phi_mean*mask_upsilon, phi_var*mask_upsilon,beta)\n",
    "            kl_div_o=util.kl_loss_2(phi_mean*mask_omega, phi_var*mask_omega,beta)\n",
    "            kl_div=((kl_div_g+kl_div_d+kl_div_u+kl_div_o)/4)\n",
    "            \n",
    "            Reconstruction_loss=(MSE(decoded_space,train_x))\n",
    "            \n",
    "            Wassloss,dist=util.wasserstein(phi_upsilon*mask_upsilon,true_t)\n",
    "            \n",
    "            #t_man_loss,bce_cons=binary_cross_entropy_manual(w_t, true_t,tol2)\n",
    "            \n",
    "            Tloss=BCE(w_t, true_t)\n",
    "            Rloss,lambd=mse_constraint(y_LOSS)\n",
    "            #Rloss=lambd*y_LOSS\n",
    "\n",
    "            combined_loss=(Rloss)+(0.1*kl_div)+penalty+(Tloss)+(Reconstruction_loss)+(0.1*Wassloss)+L_excl\n",
    "            \n",
    "            #+100*ME_loss\n",
    "            \n",
    "            \n",
    "            \n",
    "            combined_loss.backward() #retain_graph=True\n",
    "            tot_los=tot_los+combined_loss.item()\n",
    "            reg_los=reg_los+Rloss.item()\n",
    "            t_los=t_los+Tloss.item()\n",
    "            #w_los=w_los+OR.item()\n",
    "            wa_los=wa_los+Wassloss.item()\n",
    "            rc_los=rc_los+Reconstruction_loss.item()\n",
    "            #OR_re_los=OR_re_los+OR_re.item()\n",
    "            kl=kl+kl_div.item()\n",
    "            cnt=cnt+1\n",
    "            mse_er=mse_er+mse_cons\n",
    "            pen=pen+penalty.item()\n",
    "            gc=gc+gamma_c\n",
    "            dc=dc+delta_c\n",
    "            uc=uc+upsilon_c\n",
    "            oc=oc+omega_c\n",
    "            \n",
    "            #bce_er=bce_er+bce_cons\n",
    "            \n",
    "            # optimize\n",
    "            \n",
    "            \n",
    "            opt_net.step()\n",
    "           \n",
    "            opt_nett.step()\n",
    "            opt_netr.step()\n",
    "            opt_netd.step()\n",
    "            sparse_optimizer1.step()\n",
    "            sparse_optimizer2.step()\n",
    "            sparse_optimizer3.step()\n",
    "            sparse_optimizer4.step()\n",
    "            opt_lmd.step()\n",
    "        \n",
    "          \n",
    "       \n",
    "        val_PEHE_nn=util.cal_pehe_nn(X_val, y_val,rnet,net,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "        tr_PEHE=util.cal_pehe(X_train,y_train,rnet,net,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "\n",
    "      \n",
    "        \n",
    "        \n",
    "        # Evaluation of validation\n",
    "        X_train_t= X_train.to_numpy()\n",
    "        X_train_t=torch.from_numpy(X_train_t.astype(np.float32)).to(device)\n",
    "        \n",
    "        \n",
    "        X_val_t=X_val.to_numpy()\n",
    "        X_val_t=torch.from_numpy(X_val_t.astype(np.float32)).to(device)\n",
    "        \n",
    "        y_val_t=y_val.to_numpy()\n",
    "        y_val_t=torch.from_numpy(y_val_t.astype(np.float32)).to(device)\n",
    "\n",
    "        \n",
    "\n",
    "        \n",
    "        if ep > 500:\n",
    "                if val_PEHE_nn < best_val_loss:\n",
    "                    best_val_loss = val_PEHE_nn\n",
    "                    epochs_without_improvement = 0  # Reset patience counter\n",
    "                else:\n",
    "                    epochs_without_improvement += 1\n",
    "                \n",
    "                # Save the model if it's among the top 10 and hasn't improved for 50 epochs\n",
    "                if epochs_without_improvement >= patience:\n",
    "                    epochs_without_improvement = 0 \n",
    "                    if len(best_models) < max_saved_models or val_PEHE_nn < best_models[-1][0]:\n",
    "                        # Save the model\n",
    "                        model_state = {\n",
    "                            'encoder': net.state_dict(),\n",
    "                            'regressor': rnet.state_dict(),\n",
    "                            'classifier': tnet.state_dict(),\n",
    "                            'val_loss': val_PEHE_nn,\n",
    "                            'epoch': ep\n",
    "                        }\n",
    "                        model_filename = f'model_{ep}_val{val_PEHE_nn:.4f}.pth'\n",
    "                        torch.save(model_state, model_filename)\n",
    "        \n",
    "                        # Add to best models list\n",
    "                        heapq.heappush(best_models, (val_PEHE_nn, model_filename))\n",
    "                        saved_model_count += 1\n",
    "        \n",
    "                        # Keep only the best 10 models\n",
    "                        if len(best_models) > max_saved_models:\n",
    "                            worst_model = heapq.heappop(best_models)  # Remove the worst model\n",
    "                            torch.cuda.empty_cache()  # Free memory if needed\n",
    "        \n",
    "                # Stop training after saving 10 models\n",
    "        if saved_model_count >= max_saved_models:\n",
    "            print(f\"Early stopping at epoch {ep}: 10 best models saved.\")\n",
    "            break\n",
    "\n",
    "    \n",
    "        \n",
    "        reg_loss_tr.append(reg_los/cnt)\n",
    "        classT_loss_tr.append(t_los/cnt)\n",
    "        #classW_loss_tr.append(w_los)\n",
    "        wasser_loss_tr.append(wa_los/cnt)\n",
    "        total_loss.append(tot_los/cnt)\n",
    "        tr_pehe.append(tr_PEHE)\n",
    "        val_pehe.append(val_PEHE_nn)\n",
    "        rc_loss.append(rc_los/cnt)\n",
    "        #OR_re_los_li.append(OR_re_los)\n",
    "        kl_l.append(kl/cnt)\n",
    "\n",
    "        lambd_li.append(lambd)\n",
    "        #lambd_li2.append(lambd2)\n",
    "      \n",
    "        pen_li.append(pen/cnt)\n",
    "        gamma_count.append(gc/cnt)\n",
    "        delta_count.append(dc/cnt)\n",
    "        upsilon_count.append(uc/cnt)\n",
    "        omega_count.append(oc/cnt)\n",
    "        #bce_li.append(bce_er/cnt)\n",
    "        mse_li.append(mse_er/cnt)\n",
    "        #total_loss.append(tot_los/cnt)\n",
    "            \n",
    "\n",
    "    return rnet,net,X_train_t,mask_gamma,mask_delta,mask_upsilon,mask_omega,mse_er,best_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "513e9486-2beb-4aa1-82ed-4db2d2cfce49",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "scaler_x = MinMaxScaler()\n",
    "\n",
    "lat_dim_li=[15*4] #20,40,60\n",
    "dummy_list=[5]\n",
    "tol_li=[0.01] #1,3,5,7,9,11,13,15\n",
    "gamma_dim_count=[]\n",
    "delta_dim_count=[]\n",
    "upsilon_dim_count=[]\n",
    "omega_dim_count=[]\n",
    "bce_list=[]\n",
    "mse_list=[]\n",
    "pehe_list=[]\n",
    "dim_range=15\n",
    "lat_dim=dim_range*4\n",
    "fstart=0\n",
    "fend=fstart+dim_range\n",
    "sstart=fend\n",
    "send=sstart+dim_range\n",
    "tstart=send\n",
    "tend=tstart+dim_range\n",
    "frstart=tend\n",
    "frend=frstart+dim_range\n",
    "\n",
    "hid_enc=lat_dim\n",
    "\n",
    "N = 1000  # Sample size\n",
    "m_gamma = 8    # Dimensionality of Γ\n",
    "m_delta = 8    # Dimensionality of ∆\n",
    "m_upsilon = 8 \n",
    " # 0,5,10,15,20,25\n",
    "for di,dum in enumerate(dummy_list):\n",
    "\n",
    "    if((dim_range>=50) or (dum>=10)) :\n",
    "        beta=0.1\n",
    "    else:\n",
    "        beta=0.2  \n",
    "        \n",
    "    dummies=dum\n",
    "    reg_loss_li=[]\n",
    "    tol=0.01\n",
    "\n",
    "    rel_index=m_gamma+m_delta+m_upsilon\n",
    "\n",
    "    input_features=rel_index+dummies\n",
    "       # Dimensionality of Υ\n",
    "\n",
    "    #t_index=m_gamma+m_delta+m_upsilon+dummies\n",
    "\n",
    "    pehe_l=[]\n",
    "    ate_l=[]\n",
    "    pehenn_l=[]\n",
    "    pehe_final=[]\n",
    "    gamma_arr=np.array([0]*input_features)\n",
    "    delta_arr=np.array([0]*input_features)\n",
    "    upsilon_arr=np.array([0]*input_features)\n",
    "    irrel_arr=np.array([0]*input_features)\n",
    "    mig_gamma=[0]*input_features\n",
    "    mig_delta=[0]*input_features\n",
    "    mig_upsilon=[0]*input_features\n",
    "    mig_omega=[0]*input_features\n",
    "    \n",
    "    col =  [\"treatment\", \"y_factual\", \"y_cfactual\",\"mu0\",\"mu1\"]\n",
    "    for j in range(rel_index+dummies):\n",
    "        col.append(\"x\"+str(j))\n",
    "\n",
    "        #clean()\n",
    "\n",
    "\n",
    "    \n",
    "    for fl in range(1,5):\n",
    "\n",
    "        datapre=pd.read_csv(f\"./Dataset/Syn_1.0_1.0_0/8_8_8/4_{fl}.csv\")\n",
    "      \n",
    "\n",
    "        x_data_n=util.add_dummy_features_shuffle_syn(datapre,dummies)\n",
    "        x_data_n.columns=col\n",
    "        \n",
    "        x_data_n_nor= pd.DataFrame(scaler_x.fit_transform(x_data_n.iloc[:,5:]))\n",
    "        data_train_tran=pd.concat([x_data_n.iloc[:,0:5], x_data_n_nor], axis=1)\n",
    "        data_train_tran.columns=col\n",
    "        Regressor,Encoder,X_train_t,mask_gamma,mask_delta,mask_upsilon,mask_omega,reg_los,best_models=train_CFR(x_data_n,x_data_n.iloc[:,1],tol)\n",
    "\n",
    "\n",
    "         #_______________loading best models\n",
    "        Enc=mdl.net(input_features,hid_enc,lat_dim,.1)\n",
    "        Enc.to(device)\n",
    "        Reg=mdl.Regressors(lat_dim*2,hid_y,.1)\n",
    "        Reg.to(device)\n",
    "        tnet=mdl.Tclassifier(lat_dim*2,.1)\n",
    "        tnet.to(device)\n",
    "\n",
    "        for idx, (val_loss, filename) in enumerate(best_models):\n",
    "            #print(f\"Loading model {idx+1}: {filename} (val_loss={val_loss:.4f})\")\n",
    "        \n",
    "            # Load saved state dict\n",
    "            model_state = torch.load(filename) #map_location='cpu'\n",
    "            Enc.load_state_dict(model_state['encoder'])\n",
    "            Reg.load_state_dict(model_state['regressor'])\n",
    "            tnet.load_state_dict(model_state['classifier'])\n",
    "        \n",
    "            x_data_test=pd.read_csv(f\"./Dataset/Syn_1.0_1.0_0/8_8_8/4_test.csv\")\n",
    "            #x_data_test=pd.read_csv(f\"./data/Syn_1.0_1.0_0/16_16_16/4_test.csv\")\n",
    "            #x_data_test=pd.read_csv(f\"./data/Syn_1.0_1.0_0/4_4_4/4_test.csv\")\n",
    "            data_n=util.add_dummy_features_shuffle_syn(x_data_test,dummies)\n",
    "            data_n.columns=col\n",
    "          \n",
    "            data_n_nor= pd.DataFrame(scaler_x.transform(data_n.iloc[:,5:]))\n",
    "            data_test_tran=pd.concat([data_n.iloc[:,0:5], data_n_nor], axis=1)\n",
    "            data_test_tran.columns=col\n",
    "            \n",
    "    \n",
    "        #save_plot(Encoder,dummies)\n",
    "            tepehe=util.cal_pehe(data_n,data_n.iloc[:,1],Reg,Enc,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "            tepehenn=util.cal_pehe_nn(data_n,data_n.iloc[:,1],Reg,Enc,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "            pehe_l.append(tepehe)\n",
    "            #ate_l.append(teate)\n",
    "            pehenn_l.append(tepehenn)\n",
    "                \n",
    "        \n",
    "        \n",
    "        \n",
    "\n",
    "\n",
    "        \n",
    "        reg_loss_li.append(reg_los)\n",
    "        gamma_dim_count.append(np.mean(gamma_count[-5:]))\n",
    "        delta_dim_count.append(np.mean(delta_count[-5:]))\n",
    "        upsilon_dim_count.append(np.mean(upsilon_count[-5:]))\n",
    "        omega_dim_count.append(np.mean(omega_count[-5:]))\n",
    "        torch.cuda.empty_cache()\n",
    "        pehe_final.append(np.mean(pehe_l))\n",
    "        #clear()\n",
    "        \n",
    "    print('Data : 8_8_8_',dum)\n",
    "    print('Dummies: ',dum)\n",
    "    print('PEHE mean: ', np.mean(pehe_final))\n",
    "    print('PEHE std: ', np.std(pehe_final))\n",
    "    #print('PEHE_nn mean: ', np.mean(pehenn_l))\n",
    "    print('******************************')\n",
    "    print('Regression loss',np.mean(reg_loss_li))\n",
    "\n",
    "    print('Gamma dim on avg',np.mean(gamma_dim_count))\n",
    "    print('Delta dim on avg',np.mean(delta_dim_count))\n",
    "    print('Upsilon dim on avg',np.mean(upsilon_dim_count))\n",
    "    print('Omega dim on avg',np.mean(omega_dim_count))\n",
    "    \n",
    "   \n",
    "    print('*************************************************')\n",
    "    print('___________________________Next Tolerance_______________________________')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "167f490c-d58d-4c3c-bdd9-026d71ecb42d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "VAEITE",
   "language": "python",
   "name": "vaeite"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
