{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "30bcee80-0133-4636-b36f-ed74bc23b490",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f575311d370>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import models as mdl\n",
    "import utility as util\n",
    "import constraint as cnst\n",
    "import L0 as lreg\n",
    "import torch\n",
    "from torch import optim\n",
    "from torch import nn\n",
    "from torch.autograd import Function\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from tqdm import tqdm\n",
    "import heapq\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "38b75bf8-434b-4054-8326-1114e508f18c",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "epochs=8000 #5500\n",
    "batch_size=500\n",
    "hid_y=100\n",
    "beta=1#0.03,0.5\n",
    "BCE=nn.BCELoss()\n",
    "MSE=nn.MSELoss(reduction='mean')\n",
    "MSE_2=nn.MSELoss(reduction='sum')\n",
    "threshold=0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cdb4daa1-a5db-4b93-84d9-3a20ec99358d",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_loss=[]\n",
    "tr_pehe=[]\n",
    "val_pehe=[]\n",
    "reg_loss_tr=[]\n",
    "classT_loss_tr=[]\n",
    "classW_loss_tr=[]\n",
    "wasser_loss_tr=[]\n",
    "rc_loss=[]\n",
    "OR_reg_li=[]\n",
    "\n",
    "val_reg_li=[]\n",
    "val_tloss_li=[]\n",
    "val_wasser_li=[]\n",
    "val_rec_li=[]\n",
    "kl_l=[]\n",
    "snr_gamma=[]\n",
    "snr_delta=[]\n",
    "snr_upsilon=[]\n",
    "snr_omega=[]\n",
    "\n",
    "gamma_mean=[]\n",
    "gamma_var=[]\n",
    "delta_mean=[]\n",
    "delta_var=[]\n",
    "upsilon_mean=[]\n",
    "upsilon_var=[]\n",
    "omega_mean=[]\n",
    "omega_var=[]\n",
    "total_mean=[]\n",
    "total_var=[]\n",
    "lambd_li=[]\n",
    "kl_total=[]\n",
    "pen_li=[]\n",
    "gamma_count=[]\n",
    "delta_count=[]\n",
    "upsilon_count=[]\n",
    "omega_count=[]\n",
    "mse_li=[]\n",
    "l0_li=[]\n",
    "lambd2_li=[]\n",
    "ex_cl_lis=[]\n",
    "def train_CFR(x_data,y_data,t):\n",
    "   \n",
    "    torch.manual_seed(0)\n",
    "    #device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    \n",
    "    patience = 500 # Number of epochs without improvement before stopping\n",
    "    max_saved_models = 5  # Store the best 10 models\n",
    "    \n",
    "    # Tracking best models (min-heap)\n",
    "    best_models = []\n",
    "    best_val_loss = float('inf')\n",
    "    epochs_without_improvement = 0\n",
    "    saved_model_count=0\n",
    "    tol=t\n",
    "    \n",
    "    \n",
    "    iter_num = 0\n",
    "    alpha=0.99\n",
    "    pretrain=1\n",
    "    lbd_step=30\n",
    "    kl_flag=True\n",
    "\n",
    "    #x_data,y_data=get_data('train',1)\n",
    "    X_train, X_val,y_train, y_val = train_test_split(x_data,y_data ,\n",
    "                                       random_state=42, \n",
    "                                       test_size=0.2, shuffle=True, stratify=x_data.iloc[:,0].values)\n",
    "    # write input_feature instead of 25\n",
    "    net=mdl.net(input_features,hid_enc,lat_dim,.1)\n",
    "    #wnet=Wclassifier(lat_dim,.1)\n",
    "    tnet=mdl.Tclassifier(lat_dim*2,.1)\n",
    "    rnet=mdl.Regressors(lat_dim*2,hid_y,.1)\n",
    "    mse_constraint=cnst.Constraint(tol, device, lambda_min=0.0, lambda_max=200.0, lambda_init=1.0, alpha=0.99)\n",
    "    #dnet=Decoder_VAE(lat_dim*4,input_features,.1)\n",
    "    dnet=mdl.Decoder(lat_dim*4,input_features,.1)\n",
    "    sparse_model1 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model2 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model3 = lreg.SparseNet(lat_dim,1)\n",
    "    sparse_model4 = lreg.SparseNet(lat_dim,1)\n",
    "    net.to(device)\n",
    "    tnet.to(device)\n",
    "    rnet.to(device)\n",
    "    dnet.to(device)\n",
    "    sparse_model1.to(device)\n",
    "    sparse_model2.to(device)\n",
    "    sparse_model3.to(device)\n",
    "    sparse_model4.to(device)\n",
    "    \n",
    "    opt_net = torch.optim.Adam(net.parameters(), lr=0.00005,weight_decay=1e-3)#1e-4\n",
    "    opt_nett = torch.optim.Adam(tnet.parameters(), lr=0.0008, weight_decay=1e-3) #000008\n",
    "    opt_netr = torch.optim.Adam(rnet.parameters(), lr=0.001,weight_decay=1e-3) #0.0008\n",
    "    opt_netd = torch.optim.Adam(dnet.parameters(), lr=0.00002,weight_decay=1e-3) #00002,weight_decay=1e-4\n",
    "    opt_lmd = torch.optim.Adam(mse_constraint.parameters(), lr=0.004,weight_decay=1e-3)\n",
    "    sparse_optimizer1 = optim.Adam(sparse_model1.parameters(), lr=0.001) #0.001\n",
    "    sparse_optimizer2 = optim.Adam(sparse_model2.parameters(), lr=0.001) #0.001\n",
    "    sparse_optimizer3 = optim.Adam(sparse_model3.parameters(), lr=0.001) #0.001\n",
    "    sparse_optimizer4 = optim.Adam(sparse_model4.parameters(), lr=0.001) #0.001\n",
    "    \n",
    "   \n",
    "    \n",
    "    \n",
    "    mask_gamma = torch.ones(dim_range*4).to(device)\n",
    "    mask_delta = torch.ones(dim_range*4).to(device)\n",
    "    mask_upsilon = torch.ones(dim_range*4).to(device)\n",
    "    mask_omega = torch.ones(dim_range*4).to(device)\n",
    "    milestones = {int(epochs * ratio): False for ratio in [0.99]} #0.4, 0.6, 0.99\n",
    "\n",
    "    for ep in tqdm(range(1,epochs+1 )):\n",
    " \n",
    "        \n",
    "        train_dataloader_sr, train_dataloader_tr=util.get_dataloader(X_train,y_train,batch_size)\n",
    "        tot_los=0\n",
    "        reg_los=0\n",
    "        t_los=0\n",
    "        w_los=0\n",
    "        wa_los=0\n",
    "        rc_los=0\n",
    "        rereg_los=0\n",
    "        cnt=0\n",
    "        kl=0\n",
    "        kl_t=0\n",
    "        mse_er=0\n",
    "        pen=0\n",
    "        gc=0\n",
    "        dc=0\n",
    "        uc=0\n",
    "        oc=0\n",
    "        l0=0\n",
    "        ex=0\n",
    " \n",
    "        for batch_idx, (train_source_data, train_target_data) in enumerate(zip(train_dataloader_sr, train_dataloader_tr)):\n",
    "            \n",
    "            xs,ys=train_source_data\n",
    "            xt,yt=train_target_data\n",
    "            # Replace 30 with :\n",
    "            xs_train=xs[:,5:].to(device)\n",
    "            xt_train=xt[:,5:].to(device)\n",
    "            ys.to(device)\n",
    "            yt.to(device)\n",
    "            train_x=torch.cat((xs_train,xt_train),0)\n",
    "            train_y=torch.unsqueeze(torch.cat((ys,yt),0), dim=1).to(device)\n",
    "            true_t=torch.unsqueeze(torch.cat((xs[:,0],xt[:,0]),0), dim=1).to(device)\n",
    "            concat_true=torch.cat((train_y,true_t),1)\n",
    "            prop_t1=(xt_train.shape[0]/train_x.shape[0])\n",
    "            phi, phi_mean, phi_var=net(train_x,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "\n",
    "            \n",
    "                \n",
    "\n",
    "\n",
    "            \n",
    "            phi_gamma=phi[:,fstart:fend]\n",
    "            phi_delta=phi[:,sstart:send]\n",
    "            phi_upsilon=phi[:,tstart:tend]\n",
    "            phi_irr=phi[:,frstart:frend]\n",
    "            \n",
    "      \n",
    "            phi_gamma, penalty_gamma,mask_gamma_d = sparse_model1(phi)\n",
    "            phi_delta, penalty_delta,mask_delta_d = sparse_model2(phi)\n",
    "            phi_upsilon, penalty_upsilon,mask_upsilon_d = sparse_model3(phi)\n",
    "            phi_omega, penalty_omega,mask_omega_d = sparse_model4(phi)\n",
    "           \n",
    "           \n",
    "            \n",
    "            \n",
    "            all_mask=torch.vstack((mask_gamma_d,mask_delta_d,mask_upsilon_d,mask_omega_d))\n",
    "          \n",
    "            L_excl,soft_masks = util.exclusivity_loss(all_mask, temperature=0.01) #0.01\n",
    "            \n",
    "        \n",
    "            mask_gamma=soft_masks[0, :]\n",
    "            mask_delta=soft_masks[1, :]\n",
    "            mask_upsilon=soft_masks[2, :]\n",
    "            mask_omega=soft_masks[3, :]\n",
    "\n",
    "            gamma_c=util.get_dim_count(mask_gamma,0.5)\n",
    "            delta_c=util.get_dim_count(mask_delta,0.5)\n",
    "            upsilon_c=util.get_dim_count(mask_upsilon,0.5)\n",
    "            omega_c=util.get_dim_count(mask_omega,0.5)\n",
    "            \n",
    "            penalty=penalty_delta+penalty_upsilon+penalty_gamma+penalty_omega#+penalty_omega++\n",
    "          \n",
    "            phi_all=torch.cat((phi_gamma*mask_gamma,phi_delta*mask_delta,phi_upsilon*mask_upsilon,phi_omega*mask_omega), 1)\n",
    "            del_ups=torch.cat((phi_delta*mask_delta, phi_upsilon*mask_upsilon), 1)\n",
    "            gam_del=torch.cat((phi_gamma*mask_gamma,phi_delta*mask_delta), 1)\n",
    "            \n",
    "            \n",
    "            concat_pred=rnet(del_ups)\n",
    "            #w_f=wnet(phi_delta)\n",
    "            w_t=tnet(gam_del)\n",
    "            decoded_space=dnet(phi_all)\n",
    "            \n",
    "            predicted_y=torch.unsqueeze(torch.where(true_t.squeeze() == 0, concat_pred[:,0], concat_pred[:,1]),dim=1)\n",
    "           \n",
    "            opt_net.zero_grad()\n",
    "           \n",
    "            opt_nett.zero_grad()\n",
    "            opt_netr.zero_grad()\n",
    "            opt_netd.zero_grad()\n",
    "            sparse_optimizer1.zero_grad()\n",
    "            sparse_optimizer2.zero_grad()\n",
    "            sparse_optimizer3.zero_grad()\n",
    "            sparse_optimizer4.zero_grad()\n",
    "            opt_lmd.zero_grad()\n",
    "           \n",
    "           \n",
    "            Reconstruction_loss=(MSE(decoded_space[:,0:6],train_x[:,0:6])+BCE(decoded_space[:,6:25],train_x[:,6:25])+MSE(decoded_space[:,25:],train_x[:,25:]))\n",
    "            #Reconstruction_loss=MSE(decoded_space,train_x)\n",
    "            y_LOSS,mse_cons=util.Y_GECO(predicted_y,train_y, tol)\n",
    "        \n",
    "            mask_g_active = (mask_gamma > threshold).float()\n",
    "            mask_d_active = (mask_delta > threshold).float()\n",
    "            mask_u_active = (mask_upsilon > threshold).float()\n",
    "            mask_o_active = (mask_omega > threshold).float()\n",
    "            \n",
    "         \n",
    "            \n",
    "            kl_div_g=util.kl_loss_2(phi_mean*mask_gamma, phi_var*mask_gamma,beta)\n",
    "            kl_div_d=util.kl_loss_2(phi_mean*mask_delta, phi_var*mask_delta,beta)\n",
    "            kl_div_u=util.kl_loss_2(phi_mean*mask_upsilon, phi_var*mask_upsilon,beta)\n",
    "            kl_div_o=util.kl_loss_2(phi_mean*mask_omega, phi_var*mask_omega,beta)\n",
    "            kl_div=((kl_div_g+kl_div_d+kl_div_u+kl_div_o)/4)\n",
    "            \n",
    "            Wassloss,dist=util.wasserstein(phi_upsilon*mask_upsilon,true_t)\n",
    "           \n",
    "            Tloss=BCE(w_t, true_t)\n",
    "            cnt=cnt+1\n",
    "            \n",
    "            \n",
    "            \n",
    "           \n",
    "            Rloss,lambd=mse_constraint(y_LOSS)\n",
    "\n",
    "           \n",
    "\n",
    "            combined_loss=(Rloss)+(Wassloss)+(Tloss)+(Reconstruction_loss)+(2*kl_div)\\\n",
    "            +(5*penalty)+L_excl\n",
    "            \n",
    "           \n",
    "            \n",
    "            \n",
    "            combined_loss.backward()\n",
    "            tot_los=tot_los+combined_loss.item()\n",
    "            reg_los=reg_los+Rloss.item()\n",
    "            t_los=t_los+Tloss.item()\n",
    "            #w_los=w_los+OR.item()\n",
    "            wa_los=wa_los+Wassloss.item()\n",
    "            rc_los=rc_los+Reconstruction_loss.item()\n",
    "            #rereg_los=rereg_los+OR_re.item()\n",
    "            kl=kl+kl_div.item()\n",
    "            #kl_t=kl_t+torch.mean(kl_loss_3(phi_mean, phi_var,1)).item()\n",
    "            mse_er=mse_er+mse_cons\n",
    "            pen=pen+penalty.item()\n",
    "            gc=gc+gamma_c\n",
    "            dc=dc+delta_c\n",
    "            uc=uc+upsilon_c\n",
    "            oc=oc+omega_c\n",
    "            l0=l0+(penalty_delta.detach().item()+penalty_upsilon.detach().item()+penalty_gamma.detach().item())\n",
    "            ex=ex+L_excl.item()\n",
    "            \n",
    "            opt_net.step()\n",
    "           \n",
    "            opt_nett.step()\n",
    "            opt_netr.step()\n",
    "            opt_netd.step()\n",
    "            sparse_optimizer1.step()\n",
    "            sparse_optimizer2.step()\n",
    "            sparse_optimizer3.step()\n",
    "            opt_lmd.step()\n",
    "            sparse_optimizer4.step()\n",
    "\n",
    "           \n",
    "\n",
    "          \n",
    "        \n",
    "        val_PEHE_nn=util.cal_pehe_nn(X_val, y_val,rnet,net,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "        tr_PEHE=util.cal_pehe(X_train,y_train,rnet,net,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend)\n",
    "\n",
    "\n",
    "         \n",
    "        # Evaluation of validation\n",
    "        X_train_t= X_train.to_numpy()\n",
    "        X_train_t=torch.from_numpy(X_train_t.astype(np.float32)).to(device)\n",
    "        \n",
    "        X_val_t=X_val.to_numpy()\n",
    "        X_val_t=torch.from_numpy(X_val_t.astype(np.float32)).to(device)\n",
    "        \n",
    "        y_val_t=y_val.to_numpy()\n",
    "        y_val_t=torch.from_numpy(y_val_t.astype(np.float32)).to(device)\n",
    "\n",
    "     \n",
    "\n",
    "        if ep > 3500:\n",
    "                if val_PEHE_nn < best_val_loss:\n",
    "                    best_val_loss = val_PEHE_nn\n",
    "                    epochs_without_improvement = 0  # Reset patience counter\n",
    "                else:\n",
    "                    epochs_without_improvement += 1\n",
    "                \n",
    "                # Save the model if it's among the top 10 and hasn't improved for 50 epochs\n",
    "                if epochs_without_improvement >= patience:\n",
    "                    epochs_without_improvement = 0 \n",
    "                    if len(best_models) < max_saved_models or val_PEHE_nn < best_models[-1][0]:\n",
    "                        # Save the model\n",
    "                        model_state = {\n",
    "                            'encoder': net.state_dict(),\n",
    "                            'regressor': rnet.state_dict(),\n",
    "                            'classifier': tnet.state_dict(),\n",
    "                            'val_loss': val_PEHE_nn,\n",
    "                            'epoch': ep\n",
    "                        }\n",
    "                        model_filename = f'Geco_1/model_{ep}_val{val_PEHE_nn:.4f}.pth'\n",
    "                        torch.save(model_state, model_filename)\n",
    "        \n",
    "                        # Add to best models list\n",
    "                        heapq.heappush(best_models, (val_PEHE_nn, model_filename))\n",
    "                        saved_model_count += 1\n",
    "        \n",
    "                        # Keep only the best 10 models\n",
    "                        if len(best_models) > max_saved_models:\n",
    "                            worst_model = heapq.heappop(best_models)  # Remove the worst model\n",
    "                            torch.cuda.empty_cache()  # Free memory if needed\n",
    "        \n",
    "                # Stop training after saving 10 models\n",
    "        if saved_model_count >= max_saved_models:\n",
    "                print(f\"Early stopping at epoch {ep}: 10 best models saved.\")\n",
    "                break\n",
    "       \n",
    "           \n",
    "        \n",
    "        reg_loss_tr.append(reg_los/cnt)\n",
    "        classT_loss_tr.append(t_los/cnt)\n",
    "        classW_loss_tr.append(w_los/cnt)\n",
    "        wasser_loss_tr.append(wa_los/cnt)\n",
    "        total_loss.append(tot_los/cnt)\n",
    "        tr_pehe.append(tr_PEHE)\n",
    "        val_pehe.append(val_PEHE_nn)\n",
    "        #print(rc_los)\n",
    "        rc_loss.append(rc_los/cnt)\n",
    "        OR_reg_li.append(rereg_los)\n",
    "        kl_l.append(kl)\n",
    "        lambd_li.append(lambd)\n",
    "        kl_total.append(kl_t)\n",
    "        pen_li.append(pen/cnt)\n",
    "        gamma_count.append(gc/cnt)\n",
    "        delta_count.append(dc/cnt)\n",
    "        upsilon_count.append(uc/cnt)\n",
    "        omega_count.append(oc/cnt)\n",
    "        mse_li.append(mse_er/cnt)\n",
    "        l0_li.append(l0/cnt)\n",
    "        ex_cl_lis.append(ex/cnt)\n",
    "        \n",
    "\n",
    "    return rnet,net,X_train_t,mask_gamma,mask_delta,mask_upsilon,mask_omega,best_models\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2cc6716-56af-4ee8-9376-663fc5a779aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler_x = MinMaxScaler()\n",
    "scaler_y = MinMaxScaler()\n",
    "\n",
    "\n",
    "\n",
    "lat_dim_li=[15*4] #15*4\n",
    "dummy_list=[5,10,15]\n",
    "tol_li=[0.5] \n",
    "gamma_dim_count=[]\n",
    "delta_dim_count=[]\n",
    "upsilon_dim_count=[]\n",
    "omega_dim_count=[]\n",
    "column_name = ['0', '5', '10', '15','20']\n",
    "for ldi,ld in enumerate(lat_dim_li):\n",
    "    dim_range=ld//4 #5,10,15\n",
    "    lat_dim=dim_range*4\n",
    "    fstart=0\n",
    "    fend=fstart+dim_range\n",
    "    sstart=fend\n",
    "    send=sstart+dim_range\n",
    "    tstart=send\n",
    "    tend=tstart+dim_range\n",
    "    frstart=tend\n",
    "    frend=frstart+dim_range\n",
    "    lat_dim=ld\n",
    "    hid_enc=ld\n",
    "    print('latent dimensions', lat_dim)\n",
    "    \n",
    "    for di,d in enumerate(dummy_list):\n",
    "    \n",
    "        \n",
    "        pehe_l=[]\n",
    "        dummies=d\n",
    "        pehe_final=[]\n",
    "        input_features=25+dummies\n",
    "        kl_avg=[]\n",
    "        reg_loss_li=[]\n",
    "        tol=0.5 #0.5\n",
    "        \n",
    "        for i in range(1,31):\n",
    "            #clean()\n",
    "            x_data_tr,y_data_tr=util.get_data('train',i)\n",
    "       \n",
    "            x_data_n=util.add_dummy_features_shuffle(x_data_tr,dummies)\n",
    "           \n",
    "            x_data_n_nor= pd.DataFrame(scaler_x.fit_transform(x_data_n.iloc[:,5:]))\n",
    "            data_train_tran=pd.concat([x_data_n.iloc[:,0:5], x_data_n_nor], axis=1)\n",
    "            Regressor,Encoder,X_train_t,mask_gamma,mask_delta,mask_upsilon,mask_omega,best_models=train_CFR(data_train_tran,y_data_tr,tol)\n",
    "          \n",
    "            \n",
    "          \n",
    "            Enc=mdl.net(input_features,hid_enc,lat_dim,.1)\n",
    "            Enc.to(device)\n",
    "            Reg=mdl.Regressors(lat_dim*2,hid_y,.1)\n",
    "            Reg.to(device)\n",
    "            tclas=mdl.Tclassifier(lat_dim*2,.1)\n",
    "            tclas.to(device)\n",
    "            #____________________________\n",
    "    \n",
    "            \n",
    "            for idx, (val_loss, filename) in enumerate(best_models):\n",
    "            #print(f\"Loading model {idx+1}: {filename} (val_loss={val_loss:.4f})\")\n",
    "        \n",
    "                # Load saved state dict\n",
    "                model_state = torch.load(filename) #map_location='cpu'\n",
    "                Enc.load_state_dict(model_state['encoder'])\n",
    "                Reg.load_state_dict(model_state['regressor'])\n",
    "                tclas.load_state_dict(model_state['classifier'])\n",
    "            \n",
    "                data,y=util.get_data('test',i)\n",
    "                data_n=util.add_dummy_features_shuffle(data,dummies)\n",
    "                \n",
    "                \n",
    "                data_n_nor= pd.DataFrame(scaler_x.transform(data_n.iloc[:,5:]))\n",
    "                data_test_tran=pd.concat([data_n.iloc[:,0:5], data_n_nor], axis=1)\n",
    "               \n",
    "    \n",
    "                pehe_l.append(util.cal_pehe(data_test_tran,y,Reg,Enc,mask_gamma,mask_delta,mask_upsilon,fstart,fend,sstart,send,tstart,tend,frstart,frend))\n",
    "            ##############################\n",
    "           \n",
    "            \n",
    "            print('files completed: ',i)\n",
    "            gamma_dim_count.append(np.mean(gamma_count[-5:]))\n",
    "            delta_dim_count.append(np.mean(delta_count[-5:]))\n",
    "            upsilon_dim_count.append(np.mean(upsilon_count[-5:]))\n",
    "            omega_dim_count.append(np.mean(omega_count[-5:]))\n",
    "            #clear()\n",
    "            torch.cuda.empty_cache()\n",
    "            pehe_final.append(np.mean(pehe_l))\n",
    "            #clear()\n",
    "    \n",
    "    \n",
    "        print('Tolerance',tol)\n",
    "        print('Dummies shuffling: ',d)\n",
    "        print('PEHE mean', np.mean(pehe_final))\n",
    "        print('PEHE std', np.std(pehe_final))\n",
    "        #print('KL count', np.mean(kl_avg))\n",
    "        \n",
    "        \n",
    "\n",
    "        print('Gamma dim on avg',np.mean(gamma_dim_count))\n",
    "        print('Delta dim on avg',np.mean(delta_dim_count))\n",
    "        print('Upsilon dim on avg',np.mean(upsilon_dim_count))\n",
    "        print('Omega dim on avg',np.mean(omega_dim_count))\n",
    "        \n",
    "       \n",
    "\n",
    "        \n",
    "        \n",
    "      \n",
    "        print('*************************************************')\n",
    "    print('___________________________Next Tolerance_______________________________')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c323ee0-4858-4615-820f-5a81065d1682",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e1d2dc5-0a0e-49e0-ad04-b7903c9bd4a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "VAEITE",
   "language": "python",
   "name": "vaeite"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
