
import sys,os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from parse_args_tr import parse_args
import pickle
from treatment_prediction.process_data import get_treatment_var_id_suffix

import random
import numpy as np
import torch

from utils_treatment import random_split_train_valid_test_ids, calculate_input_size

import GRU_ODE.data_utils as data_utils

from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from baseline_methods.baseline import *
from sklearn.metrics import roc_auc_score
from tabular.tab_models.tab_model import *
from utils_treatment import load_configs, load_dataset_configs
from tabular.tabular_data_utils.tabular_dataset import *
from tabular import synthetic_lang
from sklearn.linear_model import LinearRegression, LogisticRegression
from econml.metalearners import TLearner

# Baseline method names handled by ``classical_baseline_main`` (its
# ``method_name`` dispatch branches on exactly these strings).
classical_baseline_ls = ["Ganite", "dt", "lr", "rf"]

def compute_cont_outcome_error(test_dataset, est, X, A, method_name):
    """Print (and return) the MSE of an estimated average dose-response curve.

    For every treatment value ``t`` on ``test_dataset.t_grid[0]``, every row
    of ``X`` is assigned treatment ``t``, the econml final-stage model
    predicts the outcomes, and their mean is compared with the ground-truth
    curve value ``test_dataset.t_grid[1]``.

    Args:
        test_dataset: object exposing ``t_grid``, indexed as a 2 x n_grid
            array/tensor: row 0 holds treatment values, row 1 the reference
            mean outcomes.
        est: fitted econml estimator exposing the private
            ``_ortho_learner_model_final._model_final`` attribute (with
            ``_combine`` and ``_model.predict``). NOTE(review): presumably only
            DML-style estimators have it — confirm for ``TLearner``.
        X: covariates passed to ``_combine``.
        A: treatment values, one per row of ``X``. Not modified in place.
        method_name: baseline name; ``"Ganite"`` is not supported here.

    Returns:
        The mean squared error (0-d tensor); it is also printed.

    Raises:
        NotImplementedError: for ``method_name == "Ganite"`` (the previous
            code failed on an unbound ``pred_outcome`` in that case).
    """
    if method_name == "Ganite":
        raise NotImplementedError(
            "compute_cont_outcome_error does not support the Ganite baseline")

    t_grid = test_dataset.t_grid
    n_test = t_grid.shape[1]
    t_grid_hat = torch.zeros(2, n_test)
    t_grid_hat[0, :] = t_grid[0, :]

    final_stage = est._ortho_learner_model_final._model_final
    for n in tqdm(range(n_test)):
        # Fresh array/tensor shaped like A and filled with the n-th grid
        # treatment; the caller's A is left untouched.
        A_n = A * 0 + float(t_grid[0, n])
        pred_outcome = final_stage._model.predict(final_stage._combine(X, A_n))
        t_grid_hat[1, n] = pred_outcome.mean()

    mse = ((t_grid_hat[1, :].squeeze() - t_grid[1, :].squeeze()) ** 2).mean().data

    print("outcome loss::%f" % (mse))
    return mse


def compute_outcome_error(est, X, A, Y, method_name):
    """Print (and return) the factual-outcome MSE of a fitted baseline.

    Args:
        est: fitted estimator. For ``method_name == "Ganite"`` it must expose
            ``inference_nets``, called on a tensor built from ``X`` and
            indexed as ``[row, treatment]``. Otherwise it must expose econml's
            private ``_ortho_learner_model_final._model_final`` with
            ``_combine`` and ``_model.predict``.
        X: covariates (numpy array or tensor).
        A: observed treatments. For Ganite they select the output column, so
            they must be integral-valued.
        Y: observed outcomes, one per row of ``X`` (numpy array or tensor).
        method_name: baseline name.

    Returns:
        The mean squared error as a Python float; it is also printed.
    """

    def _flat(values):
        # Normalize numpy arrays / tensors to a flat float64 CPU tensor so the
        # subtraction below never mixes numpy and torch operands.
        tensor = values if torch.is_tensor(values) else torch.as_tensor(np.asarray(values))
        return tensor.detach().cpu().reshape(-1).double()

    if method_name != "Ganite":
        final_stage = est._ortho_learner_model_final._model_final
        pred_outcome = final_stage._model.predict(final_stage._combine(X, A))
    else:
        # Run on the nets' own device instead of hard-coding `.cuda()`, so
        # CPU-only machines work. Falls back to CUDA-if-available when the
        # nets do not expose parameters().
        try:
            device = next(est.inference_nets.parameters()).device
        except (AttributeError, StopIteration):
            device = "cuda" if torch.cuda.is_available() else "cpu"
        all_outcome = est.inference_nets(torch.as_tensor(X).to(device)).detach().cpu()
        treatment_idx = torch.as_tensor(A).long().reshape(-1)
        pred_outcome = all_outcome[torch.arange(len(all_outcome)), treatment_idx]

    mse = ((_flat(pred_outcome) - _flat(Y)) ** 2).mean().item()

    print("outcome loss::%f" % (mse))
    return mse


def random_sampling(X, T, W, Y, ratio=0.2):
    """Draw ``int(ratio * len(X))`` random rows without replacement.

    One random permutation (``torch.randperm``) selects the rows, and the same
    indices are applied to ``X``, ``T``, ``Y`` and, when given, ``W``.

    Returns:
        ``(sub_X, sub_T, sub_W, sub_Y)``; ``sub_W`` is ``None`` when ``W`` is.
    """
    n_keep = int(ratio * len(X))
    keep = torch.randperm(len(X))[:n_keep]

    def pick(arr):
        return arr[keep]

    sub_W = pick(W) if W is not None else None
    return pick(X), pick(T), sub_W, pick(Y)


def compute_eval_metrics_baseline(meta_info, test_dataset, num_treatments, do_prediction, train=False):
    """Evaluate a baseline's dose-response predictions (MISE, DPE, PE, ITE).

    For every patient and treatment, the curve predicted by ``do_prediction``
    on ``2**6 + 1`` dosages spanning (0, 1] is compared with the ground truth
    from ``get_patient_outcome`` / ``get_true_dose_response_curve``.

    Args:
        meta_info: dataset meta information forwarded to the ground-truth helpers.
        test_dataset: iterable yielding 8-tuples; only the 7th element (the
            patient covariate vector) is used.
        num_treatments: number of treatment arms evaluated per patient.
        do_prediction: callable ``(X_D, A) -> array`` with one prediction per
            row. ``X_D`` is the covariates with the dosage appended as the last
            column; ``A`` holds the treatment index per row.
        train: when true, return early once more than 10 per-treatment
            results have been collected (quick training-time check).

    Returns:
        ``(sqrt(mean MISE), sqrt(mean dosage-policy error),
        sqrt(mean policy error), mean ITE)``.
    """
    mises = []
    ites = []
    dosage_policy_errors = []
    policy_errors = []
    pred_best = []
    pred_vals = []
    true_best = []

    samples_power_of_two = 6
    # Romberg integration (romb) needs 2**k + 1 equally spaced samples.
    num_integration_samples = 2 ** samples_power_of_two + 1
    step_size = 1. / num_integration_samples
    treatment_strengths = np.linspace(np.finfo(float).eps, 1, num_integration_samples)
    dosage_column = treatment_strengths.reshape(-1, 1)

    def summarize():
        return (np.sqrt(np.mean(mises)), np.sqrt(np.mean(dosage_policy_errors)),
                np.sqrt(np.mean(policy_errors)), np.mean(ites))

    with torch.no_grad():
        for _, _, _, _, _, _, patient, _ in tqdm(test_dataset):
            if train and len(pred_best) > 10:
                return summarize()

            # Covariates are identical for every treatment and dosage of this patient.
            patient_row = np.expand_dims(patient, axis=0)
            patient_grid = np.repeat(patient_row, num_integration_samples, axis=0)

            for treatment_idx in range(num_treatments):
                X_D = np.concatenate([patient_grid, dosage_column], axis=-1)
                A = np.repeat(treatment_idx, num_integration_samples)
                pred_dose_response = do_prediction(X_D, A)

                true_outcomes = np.asarray([get_patient_outcome(patient, meta_info, treatment_idx, d)
                                            for d in treatment_strengths])
                mises.append(romb(np.square(true_outcomes - pred_dose_response), dx=step_size))
                ites.append(np.mean((true_outcomes - pred_dose_response.squeeze()) ** 2))

                best_encountered_x = treatment_strengths[np.argmax(pred_dose_response)]

                def pred_dose_response_curve(dosage):
                    # Single-row query: this patient's covariates plus one dosage.
                    x_d = np.concatenate([patient_row, np.reshape(dosage, (-1, 1))], axis=-1)
                    return do_prediction(x_d, np.expand_dims(treatment_idx, axis=0))

                true_dose_response_curve = get_true_dose_response_curve(meta_info, patient, treatment_idx)

                min_pred_opt = minimize(lambda x: -1 * pred_dose_response_curve(x),
                                        x0=[best_encountered_x], method="SLSQP", bounds=[(0, 1)])
                max_pred_opt_y = - min_pred_opt.fun
                max_pred_dosage = min_pred_opt.x
                max_pred_y = true_dose_response_curve(max_pred_dosage)

                min_true_opt = minimize(lambda x: -1 * true_dose_response_curve(x),
                                        x0=[0.5], method="SLSQP", bounds=[(0, 1)])
                max_true_y = - min_true_opt.fun

                dosage_policy_errors.append((max_true_y - max_pred_y) ** 2)
                pred_best.append(max_pred_opt_y)
                pred_vals.append(max_pred_y)
                true_best.append(max_true_y)

            # Policy error over this patient's treatments.
            # NOTE(review): the treatment is chosen by pred_vals (true outcome at
            # the predicted optimal dosage) but scored with pred_best (predicted
            # outcome). Confirm this pairing is intended.
            selected_t_pred = np.argmax(pred_vals[-num_treatments:])
            selected_val = pred_best[-num_treatments:][selected_t_pred]
            selected_t_optimal = np.argmax(true_best[-num_treatments:])
            optimal_val = true_best[-num_treatments:][selected_t_optimal]
            policy_errors.append((optimal_val - selected_val) ** 2)

    return summarize()



def _build_econml_estimator(args, method_name, classification):
    """Construct the unfitted econml estimator for the "dt"/"lr"/"rf" baselines.

    The ``tcga`` dataset uses a ``TLearner`` over the outcome model. Every
    other dataset uses ``LinearDML``, whose treatment model is a classifier
    for discrete treatments and a regressor when ``args.cont_treatment`` is set.
    """
    cont = args.cont_treatment
    if method_name == "dt":
        tree_kwargs = dict(max_depth=args.program_max_len)
        y_model = (DecisionTreeClassifier if classification else DecisionTreeRegressor)(**tree_kwargs)
        t_model_cls, t_kwargs = (DecisionTreeRegressor if cont else DecisionTreeClassifier), tree_kwargs
    elif method_name == "lr":
        y_model = LogisticRegression(multi_class="multinomial") if classification else LinearRegression()
        t_model_cls, t_kwargs = (LinearRegression if cont else LogisticRegression), {}
    elif method_name == "rf":
        forest_kwargs = dict(n_estimators=args.topk_act, max_depth=args.program_max_len)
        y_model = (RandomForestClassifier if classification else RandomForestRegressor)(**forest_kwargs)
        t_model_cls, t_kwargs = (RandomForestRegressor if cont else RandomForestClassifier), forest_kwargs
    else:
        raise ValueError(f"Unknown econml baseline method: {method_name!r}")

    if args.dataset_name == "tcga":
        return TLearner(models=y_model)
    return LinearDML(model_y=y_model, model_t=t_model_cls(**t_kwargs))


def classical_baseline_main(args, method_name, X, T, W, Y, valid_X, valid_T, valid_W, valid_Y, test_X, test_T,  test_W, test_Y, dataset, test_dataset, count_Y=None, valid_count_Y=None, test_count_Y=None, classification=False):
    """Fit one classical baseline and print its treatment-effect metrics.

    Supported ``method_name`` values are ``"Ganite"``, ``"dt"`` (decision
    trees), ``"lr"`` (linear/logistic regression) and ``"rf"`` (random
    forests). The sklearn-based ones are wrapped in econml: ``TLearner`` for
    the ``tcga`` dataset and ``LinearDML`` otherwise.

    Args:
        args: run configuration. Reads ``dataset_name``, ``cont_treatment``,
            ``has_dose``, ``num_treatments``, ``program_max_len`` (tree depth)
            and ``topk_act`` (forest size).
        method_name: which baseline to fit.
        X, T, W, Y (and the ``valid_*`` / ``test_*`` counterparts): covariates,
            treatments, optional extra covariate, and outcomes per split. When
            ``W`` is given it is appended to ``X`` as one column; ``X``/``W``
            are torch tensors (``.view`` / ``torch.cat`` / ``.numpy`` are used).
        dataset, test_dataset: datasets used for outcome rescaling and for the
            dose-response / continuous-treatment evaluations.
        count_Y, valid_count_Y, test_count_Y: optional counterfactual outcomes.
            ITE errors are reported only when all three are given.
        classification: use classifiers instead of regressors as outcome models.

    Raises:
        ValueError: for an unknown ``method_name`` (previously an obscure
            UnboundLocalError on ``est``).
    """
    if method_name not in ("Ganite", "dt", "lr", "rf"):
        raise ValueError(f"Unknown classical baseline method: {method_name!r}")

    # if args.dataset_name == "tcga":
    #     X, T, W, Y = random_sampling(X,T, W, Y)
    if W is not None:
        W = W.view(-1, 1)
        valid_W = valid_W.view(-1, 1)
        test_W = test_W.view(-1, 1)
        X = torch.cat([X, W], dim=-1)
        valid_X = torch.cat([valid_X, valid_W], dim=-1)
        test_X = torch.cat([test_X, test_W], dim=-1)

    # NOTE(review): only the training covariates are converted to numpy;
    # valid_X / test_X are passed on as tensors.
    X = X.numpy()
    Y = transform_outcome_by_rescale_back(dataset, Y)
    valid_Y = transform_outcome_by_rescale_back(dataset, valid_Y)
    test_Y = transform_outcome_by_rescale_back(dataset, test_Y)
    if count_Y is not None:
        count_Y = transform_outcome_by_rescale_back(dataset, count_Y)
    if valid_count_Y is not None:
        valid_count_Y = transform_outcome_by_rescale_back(dataset, valid_count_Y)
    if test_count_Y is not None:
        test_count_Y = transform_outcome_by_rescale_back(dataset, test_count_Y)

    if method_name == "Ganite":
        est = ganite(X, T, Y)
        train_ites = est(X).detach().cpu().numpy()
        valid_ites = est(valid_X).detach().cpu().numpy()
        test_ites = est(test_X).detach().cpu().numpy()
    else:
        est = _build_econml_estimator(args, method_name, classification)
        est.fit(Y, T, X=X)
        train_ites = est.effect(X)
        valid_ites = est.effect(valid_X)
        test_ites = est.effect(test_X)

    # Ground-truth ITEs need all three counterfactual splits; with only some of
    # them the previous code hit a NameError on gt_valid_ites / gt_test_ites.
    have_all_counterfactuals = (count_Y is not None and valid_count_Y is not None
                                and test_count_Y is not None)

    if not args.cont_treatment and not args.has_dose and args.num_treatments == 2:
        if have_all_counterfactuals:
            def ite_ate(pred_ites, gt_ites):
                # (mean absolute ITE error, absolute error of the mean effect)
                diff = pred_ites - gt_ites.numpy()
                return np.mean(np.abs(diff)).item(), np.abs(np.mean(diff)).item()

            best_train_ite, best_train_ate = ite_ate(train_ites, count_Y - Y)
            best_val_ite, best_val_ate = ite_ate(valid_ites, valid_count_Y - valid_Y)
            best_test_ite, best_test_ate = ite_ate(test_ites, test_count_Y - test_Y)

            print("best train ite::%f"%(best_train_ite))
            print("best validation ite::%f"%(best_val_ite))
            print("best test ite::%f"%(best_test_ite))

            print("best train ate::%f"%(best_train_ate))
            print("best validation ate::%f"%(best_val_ate))
            print("best test ate::%f"%(best_test_ate))
            print()
        else:
            train_ate = np.abs(np.mean(train_ites)).item()
            valid_ate = np.abs(np.mean(valid_ites)).item()
            test_ate = np.abs(np.mean(test_ites)).item()

            print("best train ate::%f"%(train_ate))
            print("best validation ate::%f"%(valid_ate))
            print("best test ate::%f"%(test_ate))

        compute_outcome_error(est, X, T, Y, method_name)
    elif args.has_dose:

        def tcga_pred_function(X_D, A):
            # Route each row to the TLearner outcome model of its treatment arm
            # (only arms 0 and 1 are handled).
            res = np.zeros(len(X_D))
            if np.sum(A==1) > 0:
                res[A==1] = est.models[1].predict(X_D[A==1]).reshape(-1)
            if np.sum(A==0) > 0:
                res[A==0] = est.models[0].predict(X_D[A==0]).reshape(-1)
            return res

        mise, dpe, pe, ate = compute_eval_metrics_baseline(dataset.metainfo, dataset, args.num_treatments, tcga_pred_function)
        print("Train Mise: %s" % str(mise))
        print("Train DPE: %s" % str(dpe))
        print("Train PE: %s" % str(pe))
        print("Train ATE: %s" % str(ate))

        mise, dpe, pe, ate = compute_eval_metrics_baseline(test_dataset.metainfo, test_dataset, args.num_treatments, tcga_pred_function)
        print("Mise: %s" % str(mise))
        print("DPE: %s" % str(dpe))
        print("PE: %s" % str(pe))
        print("ATE: %s" % str(ate))

    elif args.cont_treatment:
        # NOTE(review): evaluated with the *training* X / T against
        # test_dataset.t_grid. Confirm test_X / test_T were not intended.
        compute_cont_outcome_error(test_dataset, est, X, T, method_name)
        

def mean_impute(X, M, mean_X=None):
    """Replace unobserved entries of ``X`` with per-column means of observed ones.

    Args:
        X: tensor of rows x features. Entries where ``M == 0`` are treated as
            missing and may hold any value, including NaN. ``X`` is NOT
            modified in place (the previous version zeroed it in place).
        M: observation mask broadcast-compatible with ``X`` (1 = observed).
        mean_X: optional precomputed per-column means (e.g. from the training
            split). Computed from ``X`` and ``M`` when ``None``.

    Returns:
        ``(imputed_X, mean_X)``.
    """
    # masked_fill returns a new tensor. Zeroing missing entries first keeps
    # NaNs there from leaking into X * M.
    X = X.masked_fill(M == 0, 0)
    if mean_X is None:
        # The 1e-5 guards against division by zero for never-observed columns.
        mean_X = torch.sum(X * M, dim=0) / (torch.sum(M, dim=0) + 1e-5)
    X = X * M + (1 - M) * mean_X.unsqueeze(0)
    return X, mean_X


def flatten_data(dl, numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding, impute=True, mean_X=None, interval=1, patient_id=None, time=None):
    """Flatten batched patient time series into tabular rows.

    Each batch ``b`` from ``dl`` is a dict with ``"df"``, ``"pid"`` and
    ``"batch_ids"`` entries. It is converted to dense arrays by
    ``data_utils.post_process_batch_normal_time_series``.

    Two modes:

    * Sliding windows (default): for every patient, windows of ``interval``
      consecutive steps starting at step 1 are flattened into one row of
      ``interval * n_features`` values. Each row is paired with the treatment
      and outcome at the window's last step (index ``interval + w`` for
      window ``w``).
    * Single patient (``patient_id`` and ``time`` both given): only the first
      batch containing ``patient_id`` is used, taking steps
      ``[time, time + interval)`` of that patient.

    Args:
        dl: iterable of batches (e.g. a DataLoader).
        numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding:
            feature metadata forwarded to ``data_utils``. ``numer_feat_ls`` and
            ``feat_to_onehot_embedding`` also define the returned feature names.
        impute: mean-impute missing entries via ``mean_impute``.
        mean_X: per-column means to impute with (e.g. from the training
            split). Computed from this data when ``None``.
        interval: window length in time steps.
        patient_id, time: select single-patient mode (see above).

    Returns:
        ``(feat_name_ls, X, M, treatment, outcome, mean_X)`` when ``impute`` is
        true, otherwise ``(feat_name_ls, X, M, treatment, outcome)``. Returns
        ``None`` in single-patient mode when the patient is never found.
    """
    
    X_ls = []
    M_ls = []
    treatment_arr_ls = []
    outcome_arr_ls = []    
    feat_name_ls = []

    # `found` is only defined (and only read) in single-patient mode.
    if patient_id is not None and time is not None:    
        found=False

    for i, b in tqdm(enumerate(dl)):
        # {"df":df, "batch_ids": batch_ids, "pid": list(df["PAT_ID"].values)}
        # data_utils.post_process_batch(b["df"], b["pid"], b["batch_ids"], self.numer_feat_ls, self.cat_feat_ls, self.cat_feat_ls_onehot, self.feat_to_onehot_embedding, ["label"], ["concat_treatment_label_id"], return_cat_numeric_feat=True)
        X, M, outcome_arr, treatment_arr, times, time_ptr = data_utils.post_process_batch_normal_time_series(b["df"], b["pid"], b["batch_ids"], numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding, ["label"], ["concat_treatment_label_id"])
        if patient_id is not None and time is not None:
            if patient_id in b["pid"]:
                idx = b["pid"].index(patient_id)
                # NOTE(review): unlike the sliding-window path, this slice is not
                # flattened into a single row when interval > 1, and treatment /
                # outcome are taken for every step of the window rather than
                # only the last one.
                X_ls.append(X[idx, time:time+interval])
                M_ls.append(M[idx, time:time+interval])
                treatment_arr_ls.append(treatment_arr[idx, time:time+interval])
                outcome_arr_ls.append(outcome_arr[idx, time:time+interval])
                found=True
                break
        else:
            # Assumes X / M are (batch, time, features), as implied by the
            # unfold/permute/reshape below — TODO confirm against data_utils.
            # Windows start at step 1 and become (batch * n_windows, interval * features).
            unfold_X = X[:,1:].unfold(1, interval, 1).permute(0, 1, 3, 2)
            unfold_X = unfold_X.reshape(unfold_X.shape[0] * unfold_X.shape[1], unfold_X.shape[2]*unfold_X.shape[3])
            unfold_M = M[:,1:].unfold(1, interval, 1).permute(0, 1, 3, 2).reshape(unfold_X.shape)
            X_ls.append(unfold_X)
            M_ls.append(unfold_M)
            
            # Window w covers steps 1+w .. interval+w, so the treatment/outcome
            # at step interval+w is the one at the window's last step.
            unfold_tr = treatment_arr[:,interval:]
            unfold_outcome = outcome_arr[:,interval:]
            unfold_tr = unfold_tr.reshape(unfold_tr.shape[0] * unfold_tr.shape[1], unfold_tr.shape[2])
            unfold_outcome = unfold_outcome.reshape(unfold_outcome.shape[0] * unfold_outcome.shape[1], unfold_outcome.shape[2])
            treatment_arr_ls.append(unfold_tr)
            outcome_arr_ls.append(unfold_outcome)
    if patient_id is not None and time is not None:   
        if not found:
            return None

    # NOTE(review): one-hot names carry no step suffix, so they repeat for every
    # k when interval > 1. This also assumes each step's columns are ordered
    # numeric-then-one-hot — TODO confirm against data_utils.
    for k in range(interval):
        feat_name_ls += [f"{feat}_{k}" for feat in numer_feat_ls] + [f"{feat}_{val}" for feat in feat_to_onehot_embedding for val in feat_to_onehot_embedding[feat]]
    
    X_tensor, M_tensor, treatment_tensor, outcome_tensor = torch.cat(X_ls), torch.cat(M_ls), torch.cat(treatment_arr_ls), torch.cat(outcome_arr_ls)

    if impute:
        X_tensor, mean_X = mean_impute(X_tensor, M_tensor, mean_X)
        
    
        return feat_name_ls, X_tensor, M_tensor, treatment_tensor, outcome_tensor, mean_X
    else:
        return feat_name_ls, X_tensor, M_tensor, treatment_tensor, outcome_tensor


def calculate_real_bound(feat_range_mappings,feat_name_ls, single_test_X, idx):
    """Map a normalized feature value back to its original range, print and return it.

    The feature's base name is the part of ``feat_name_ls[idx]`` before the
    first ``"_"`` (e.g. ``"age_0"`` -> ``"age"``). Its ``(low, high)`` entry
    in ``feat_range_mappings`` yields ``(high - low) * x + low``, where ``x``
    is ``single_test_X[0, idx]`` (presumably the inverse of min-max scaling).

    Args:
        feat_range_mappings: mapping from base feature name to ``(low, high)``.
        feat_name_ls: feature names aligned with the columns of ``single_test_X``.
        single_test_X: 2-D array/tensor; only row 0 is read.
        idx: column index of the feature.

    Returns:
        The de-normalized value. Previously it was only printed and then
        discarded; returning it is backward compatible.
    """
    feat_name = feat_name_ls[idx]
    feat_range = feat_range_mappings[feat_name.split("_")[0]]
    low, high = feat_range[0], feat_range[1]
    value = single_test_X[0, idx]
    real_bound_val = (high - low) * value + low
    print(value)
    print(real_bound_val)
    print(feat_name)
    return real_bound_val

def transform_tensor_to_df(train_X, train_treatment, train_outcome, feat_name_ls):
    """Pack feature, treatment and outcome tensors into one DataFrame.

    Columns, in order: ``"id"`` (0-based row position), the names in
    ``feat_name_ls``, ``"treatment"`` and ``"outcome"``. Treatment and
    outcome are reshaped to one value per row.
    """
    columns = [*feat_name_ls, "treatment", "outcome"]
    blocks = (
        train_X.numpy(),
        train_treatment.view(-1, 1).numpy(),
        train_outcome.view(-1, 1).numpy(),
    )
    frame = pd.DataFrame(np.concatenate(blocks, axis=-1), columns=columns)
    return frame.reset_index().rename(columns={"index": "id"})

if __name__ == "__main__":
    # Experiment driver: load preprocessed pickles from args.log_folder, split
    # patients into train/valid/test, flatten the time series into tabular
    # windows, then fit baseline models (causal forest, GANITE).

    # Hard-coded overrides of the parsed command-line arguments.
    args = parse_args()
    args.method = "TransTEE"
    args.tr = False
    args.has_dose= False
        
    # NOTE(review): normalize_y is not read by any live code below.
    normalize_y = False
    args.structured_treatment = False
    args.cont_treatment = False
    args.cat_and_cont_treatment = False
    args.alpha=1.0
    args.p=0
    # parser = argparse.ArgumentParser(description="Running GRUODE on Double OU")
    # parser.add_argument('--model_name', type=str, help="Model to use", default="double_OU_gru_ode_bayes")
    # parser.add_argument('--log_folder', type=str, help="Dataset CSV file", default=None)
    # parser.add_argument('--dataset_name', type=str, help="dataset name", default=None)
    # parser.add_argument('--jitter', type=float, help="Time jitter to add (to split joint observations)", default=0)
    # parser.add_argument('--lr', type=float, help="learning rate", default=2e-3)
    # parser.add_argument('--seed', type=int, help="Seed for data split generation", default=432)
    # parser.add_argument('--full_gru_ode', action="store_true", default=True)
    # parser.add_argument('--solver', type=str, choices=["euler", "midpoint","dopri5"], default="euler")
    # parser.add_argument('--no_impute',action="store_true",default = True)
    # parser.add_argument('--demo', action = "store_true", default = False)
    # parser.add_argument('--treatment_var_ids', nargs='+', type=int, help='List of integers', default=[0,1])

    # args = parser.parse_args()
    # if args.demo:
    #     print(f"Demo Mode - Loading model for double_OU ....")
    #     gru_ode_bayes.paper_plotting.plot_trained_model(model_name = "double_OU_gru_ode_bayes_demo")
    #     exit()

    # Every artifact file name carries a suffix derived from the selected treatment variables.
    treatment_var_ids_str_suffix = get_treatment_var_id_suffix(args)

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # log_folder at trusted, locally produced files.
    with open(os.path.join(args.log_folder, "feat_to_onehot_embedding" + treatment_var_ids_str_suffix), "rb") as f:
        feat_to_onehot_embedding = pickle.load(f)
    # with open(os.path.join(args.log_folder, "reduced_processed_treatment_data" + treatment_var_ids_str_suffix), "rb") as f:
    #     df_by_pat_mapping = pickle.load(f)
    # with open(os.path.join(args.log_folder, "reduced_numer_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
    #     numer_feat_ls = pickle.load(f)
    # with open(os.path.join(args.log_folder, "reduced_cat_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
    #     cat_feat_ls = pickle.load(f)



    with open(os.path.join(args.log_folder, "processed_treatment_data" + treatment_var_ids_str_suffix), "rb") as f:
        df_by_pat_mapping = pickle.load(f)
    with open(os.path.join(args.log_folder, "numer_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
        numer_feat_ls = pickle.load(f)
    with open(os.path.join(args.log_folder, "cat_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
        cat_feat_ls = pickle.load(f)
    # with open(os.path.join(args.log_folder, "processed_treatment_data" + treatment_var_ids_str_suffix), "rb") as f:
    #     df_by_pat_mapping = pickle.load(f)
    # with open(os.path.join(args.log_folder, "numer_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
    #     numer_feat_ls = pickle.load(f)
    # with open(os.path.join(args.log_folder, "cat_feat_ls" + treatment_var_ids_str_suffix), "rb") as f:
    #     cat_feat_ls = pickle.load(f)


    with open(os.path.join(args.log_folder, "treatment_var_ls" + treatment_var_ids_str_suffix), "rb") as f:
        treatment_var_ls = pickle.load(f)

    with open(os.path.join(args.log_folder, "unique_treatment_label_ls" + treatment_var_ids_str_suffix), "rb") as f:
        unique_treatment_label_ls = pickle.load(f)

    random.seed(args.seed)

    # Set the random seed for NumPy
    np.random.seed(args.seed)

    # Set the random seed for PyTorch
    torch.manual_seed(args.seed)

    # If you are using CUDA (GPU), you should also set the seed for CUDA
    torch.cuda.manual_seed_all(args.seed)

    # Additional configurations for reproducibility (if necessary)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


    # Patient-level split (after seeding, so it is reproducible). The training
    # set is built with balance=True; valid/test/all use validation mode.
    val_options = {"T_val": 0.6, "max_val_samples": args.max_val_samples}
    train_ids,  valid_ids, test_ids = random_split_train_valid_test_ids(df_by_pat_mapping)
    data_train = data_utils.ODE_Dataset_medical(df_by_pat_mapping, train_ids, balance=True)
    data_valid = data_utils.ODE_Dataset_medical(df_by_pat_mapping, valid_ids, validation=True, val_options=val_options)
    data_test = data_utils.ODE_Dataset_medical(df_by_pat_mapping, test_ids, validation=True, val_options=val_options)
    all_ids = []
    all_ids = all_ids + train_ids + valid_ids + test_ids
    data_all = data_utils.ODE_Dataset_medical(df_by_pat_mapping,all_ids, validation=True, val_options=val_options)

    cat_feat_ls_onehot = [feat + "_onehot" for feat in cat_feat_ls]


    # Set the random seed for Python's built-in random module

    # NOTE(review): model_name, device and params_dict are only consumed by
    # commented-out code in this block.
    model_name = args.model_name
    params_dict=dict()

    # gpu_num = 2
    # device  = torch.device(f"cuda:{gpu_num}")
    # torch.cuda.set_device(gpu_num)
    device = "cuda" if torch.cuda.is_available() else "cpu"


    #Dataset metadata
    # metadata = np.load(f"{args.dataset[:-4]}_metadata.npy",allow_pickle=True).item()

    delta_t = 0.001#metadata["delta_t"]
    T       = 1#metadata["T"]

    # train_idx, val_idx = train_test_split(np.arange(metadata["N"]),test_size=0.2, random_state=args.seed)
    # val_options = {"T_val": 4, "max_val_samples": 1}
    # data_train = data_utils.ODE_Dataset(csv_file=args.dataset, idx=train_idx, jitter_time=args.jitter)
    # data_val   = data_utils.ODE_Dataset(csv_file=args.dataset, idx=val_idx, jitter_time=args.jitter,validation = True,
    #                                     val_options = val_options )

    input_size = calculate_input_size(feat_to_onehot_embedding, cat_feat_ls, numer_feat_ls)

    #Model parameters.
    params_dict["input_size"]  = input_size
    params_dict["hidden_size"] = 50
    params_dict["p_hidden"]    = 50
    params_dict["prep_hidden"] = 50
    params_dict["logvar"]      = True
    params_dict["mixing"]      = 0.1
    params_dict["delta_t"]     = delta_t
    params_dict["jitter"]      = args.jitter
    #params_dict["gru_bayes"]   = "masked_mlp"
    params_dict["full_gru_ode"] = args.full_gru_ode
    params_dict["solver"]      = args.solver
    params_dict["impute"]      = False

    params_dict["T"]           = T


    #Model parameters and the metadata of the dataset used to train the model are stored as a single dictionnary.
    # summary_dict ={"model_params":params_dict,"metadata":metadata}
    # np.save(f"./../trained_models/{model_name}_params.npy",summary_dict)

    dl     = DataLoader(dataset=data_train, collate_fn=data_utils.custom_collate_fn_medical, shuffle=True, batch_size=32,num_workers=2)
    dl_val = DataLoader(dataset=data_valid, collate_fn=data_utils.custom_collate_fn_medical, shuffle=False, batch_size=32,num_workers=1)
    dl_test = DataLoader(dataset=data_test, collate_fn=data_utils.custom_collate_fn_medical, shuffle=False, batch_size=32,num_workers=1)
    dl_all = DataLoader(dataset=data_all, collate_fn=data_utils.custom_collate_fn_medical, shuffle=False, batch_size=32,num_workers=1)

    # Flatten each split into tabular windows; valid/test are imputed with the
    # training-split column means.
    feat_name_ls, train_X, train_M,train_treatment, train_outcome, mean_X = flatten_data(dl, numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding)
    _, valid_X, valid_M,valid_treatment, valid_outcome,_ = flatten_data(dl_val, numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding, mean_X=mean_X)
    _, test_X, test_M, test_treatment, test_outcome,_ = flatten_data(dl_test, numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding, mean_X=mean_X)
    # NOTE(review): flatten_data returns None when patient "Z3117664" is not in
    # data_all, which makes this unpacking raise TypeError. The single_test_*
    # values are only used by commented-out code below.
    _, single_test_X, single_test_M, single_test_treatment, single_test_outcome,_ = flatten_data(dl_all, numer_feat_ls, cat_feat_ls, cat_feat_ls_onehot, feat_to_onehot_embedding, mean_X=mean_X, patient_id="Z3117664", time=0)
    
    # NOTE(review): feat_range_mappings and adapted_feat_name_ls are only used
    # by commented-out interpretation code.
    with open(os.path.join(args.log_folder, "feat_range_mappings" + treatment_var_ids_str_suffix), "rb") as f:
        feat_range_mappings = pickle.load(f)
    adapted_feat_name_ls = ["x_" + str(k) for k in range(train_X.shape[1])]
    # 
    # interpret_tree(model, test_X, feat_name_ls)

    # policy learning
    # model = DR_policy_learning(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # model.predict_proba(test_X.numpy())
    # draw_policy_fig(model)
    
    # model = DR_policy_learning_forest(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # draw_policy_fig_forest(model)

    # # causal forest
    # Fit on the training split; report AUC and 0.5-threshold accuracy of the
    # predicted outcomes on the test split.
    model = causal_forest(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # interpret_tree(model, torch.cat([test_X, single_test_X.view(1,-1)]).numpy(), adapted_feat_name_ls, max_depth=model.max_depth, min_samples_leaf=model.min_samples_leaf)
    pred_y = causal_forest_predict(model, test_X.numpy(), test_treatment.numpy())
    auc_score = roc_auc_score(test_outcome.view(-1).numpy(), pred_y.reshape(-1))
    accuracy = np.mean((pred_y > 0.5).astype(int).reshape(-1) == test_outcome.numpy().reshape(-1))
    print("test auc score::", auc_score)
    print("test accuracy::", accuracy)
    
    # interpret_tree(model, torch.cat([test_X, single_test_X.view(1,-1)]).numpy(), adapted_feat_name_ls, max_depth=model.max_depth)
    
    # model = ortho_forest(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # pred_y = causal_forest_predict(model, test_X.numpy(), test_treatment.numpy())
    # auc_score = roc_auc_score(test_outcome.view(-1).numpy(), pred_y.reshape(-1))
    # accuracy = np.mean((pred_y > 0.5).astype(int).reshape(-1) == test_outcome.numpy().reshape(-1))
    # print("test auc score::", auc_score)
    # print("test accuracy::", accuracy)
    # interpret_tree(model, torch.cat([test_X, single_test_X.view(1,-1)]).numpy(), adapted_feat_name_ls, max_depth=model.max_depth, output_suffix="_ortho")
    
    # print()
    


    # ganite
    # NOTE(review): the fitted GANITE model is not evaluated by any live code below.
    model = ganite(train_X, train_treatment.view(-1), train_outcome.view(-1))
    # fit_decision_tree(adapted_feat_name_ls, model, torch.cat([test_X, single_test_X.view(1,-1)], dim=0))
    # all_pred_y = model.inference_nets(test_X.cuda()).detach().cpu()
    # pred_y_score = all_pred_y[torch.arange(len(all_pred_y)),single_test_treatment.type(torch.long)]
    # auc_score = roc_auc_score(test_outcome.view(-1).numpy(), pred_y_score.view(-1).detach().numpy())

    # interpret_policy(model, single_test_X, feat_name_ls)

    # model = xlearner(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # pred_y = xlearner_evaluation_by_treatment(model, test_X.numpy(), test_treatment.view(-1).numpy())
    # auc_score = roc_auc_score(test_outcome.view(-1).numpy(), pred_y.reshape(-1))
    # accuracy = np.mean((pred_y > 0.5).astype(int).reshape(-1) == test_outcome.numpy().reshape(-1))
    # print("test auc score::", auc_score)
    # print("test accuracy::", accuracy)
    # print()
    
    
    
    # model = Slearner(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # # model.effect(test_X.numpy())
    # pred_y = slearner_evaluation_by_treatment(model, test_X.numpy(), test_treatment.view(-1).numpy())
    # auc_score = roc_auc_score(test_outcome.view(-1).numpy(), pred_y.reshape(-1))
    # accuracy = np.mean((pred_y > 0.5).astype(int).reshape(-1) == test_outcome.numpy().reshape(-1))
    # print("test auc score::", auc_score)
    # print("test accuracy::", accuracy)
    # print()
    # train_df = transform_tensor_to_df(train_X, train_treatment, train_outcome, feat_name_ls)
    # valid_df = transform_tensor_to_df(valid_X, valid_treatment, valid_outcome, feat_name_ls)
    # test_df = transform_tensor_to_df(test_X, test_treatment, test_outcome, feat_name_ls)
    # all_data = pd.concat([train_df, valid_df, test_df])
    # args.num_treatments=2
    
    # train_dataset, valid_dataset, test_dataset, feat_range_mappings = create_dataset("cancer_data", all_data, train_df, valid_df, test_df, synthetic_lang, "id", "outcome","treatment", synthetic_lang.DROP_FEATS)
    # root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tabular")
    # rl_config, model_config = load_configs(args,root_dir=root_dir)
    # args.epochs = 5
    # trainer = baseline_trainer(args, input_size, model_config, device, outcome_regression=False)
    
    # trainer.run(train_dataset, valid_dataset, test_dataset)
    # with torch.no_grad():
    #     trainer.model.eval()
    #     _,pred_test_y = trainer.model(torch.cat([test_X, single_test_X.view(1,-1)]).to(device), torch.cat([test_treatment.view(-1), single_test_treatment.view(-1)]).to(device))
    # fit_decision_tree(feat_name_ls, trainer.model, torch.cat([test_X, single_test_X.view(1,-1)], dim=0), pred_test_y=pred_test_y.detach().cpu())
    

    # model = DR_policy_learning(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    
    # model = dynamic_dml(train_X.numpy(), train_treatment.view(-1).numpy(), train_outcome.view(-1).numpy())
    # pred_test_Y = xlearner_evaluation_by_treatment(model, test_X, test_treatment.view(-1).numpy())
    # pred_valid_Y = xlearner_evaluation_by_treatment(model, valid_X, valid_treatment.view(-1).numpy())
    # test_auc = roc_auc_score(test_outcome, pred_test_Y)
    # test_acc = np.mean((pred_test_Y > 0.5).astype(int).reshape(-1) == test_outcome.numpy().reshape(-1))
    # print("test auc::", test_auc)
    # print("test accuracy::", test_acc)
    # interpret_shap(model, single_test_X.numpy())
    # draw_policy_fig(model)
    


    