def estimate_ipw(
    self,
    ht_or_hajek: str,
    propensity_estimation: str,  # 'logistic' or 'random_forest' or '1or0'
    propensity_federations: list = ["MW, SW"],
    final_estimators: list = ["Oracle", "Meta-SW*", "Meta-SW", "MW", "DW-param"],
    use_grf: bool = False,
    params_for_federation_omega: dict = None,
    generate_by_center: bool = True,
    print_overlap: bool = False,
    shared_PS_coefs=None,
    rerun_NN: bool = False,
    scale=False,
):
    if "oracle" in propensity_federations:
        self.show_hidden_variables = True

    dict_structure = {
        f"client{i}": [] for i in range(1, len(self.clients_list) + 1)
    } | {"total_data": []}
    local_ipw_estimates = {
        "oracle": copy.deepcopy(dict_structure),
        "pool": copy.deepcopy(dict_structure),
        "local": copy.deepcopy(dict_structure),
        "local_oracle": copy.deepcopy(dict_structure),
        "MW": copy.deepcopy(dict_structure),
        "MW_pool_logistic": copy.deepcopy(dict_structure),
        "MW_logistic": copy.deepcopy(dict_structure),
        "MW_RF": copy.deepcopy(dict_structure),
        "MW_NN": copy.deepcopy(dict_structure),
        "MW_pool": copy.deepcopy(dict_structure),
        "oracle_weights": copy.deepcopy(dict_structure),
        "SW": copy.deepcopy(dict_structure),
        "drkernel": copy.deepcopy(dict_structure),
        "drexponential": copy.deepcopy(dict_structure),
        "drpackage": copy.deepcopy(dict_structure),
        "dwparam": copy.deepcopy(dict_structure),
        "1S-IVW": copy.deepcopy(dict_structure),
    }
    estimators_ipw = {final_estimator: [] for final_estimator in final_estimators}

    X_cols = ["X" + str(i) for i in range(1, self.dim_x + 1)]
    if generate_by_center:
        mu_k = {}
        Sigma_k = {}
        for k in range(1, len(self.clients_list) + 1):
            mu_k["client" + str(k)] = self.client_params_dict["client" + str(k)][
                "mean_covariates"
            ]
            Sigma_k["client" + str(k)] = self.client_params_dict["client" + str(k)][
                "cov_covariates"
            ]

    if use_grf:
        pandas2ri.activate()
        # Import the grf package
        grf = importr("grf")

    for _ in tqdm(range(self.n_simulations)):
        X_cols = ["X" + str(i) for i in range(1, self.dim_x + 1)]

        # Generate the df with membership probabilities
        if generate_by_center == True:
            df_ipw = Simulations_Fed(
                client_params_dict=self.client_params_dict,
                estimator="ipw",
                n_simulations=1,
                fixed_design=False,
                known_sigma2=False,
                rct_binomial_treatment=False,
                estimate_Sigma=False,
                estimate_sigma2=False,
                estime_norm_beta1_minus_beta0=False,
            ).combine_data()
        else:
            df_ipw = Simulations_Fed(
                client_params_dict=self.client_params_dict,
                estimator="ipw",
                n_simulations=1,
                fixed_design=False,
                known_sigma2=False,
                rct_binomial_treatment=False,
                estimate_Sigma=False,
                estimate_sigma2=False,
                estime_norm_beta1_minus_beta0=False,
            ).make_data_by_H_given_X()

        if use_grf:
            # Convert data to R objects
            X_r = pandas2ri.py2rpy(pd.DataFrame(df_ipw[self.treatment_cols].values))
            W_r = ro.FactorVector(pandas2ri.py2rpy(pd.Series(df_ipw["W"].values)))

        # Compute propensity scores
        if "pool" in propensity_federations:
            if propensity_estimation == "logistic":
                hat_gamma_pool = compute_gamma(df_ipw, self.treatment_cols)
                df_ipw["hat_e_pool"] = logistic_function_vectorized(
                    df_ipw, self.treatment_cols, hat_gamma_pool
                )
            elif propensity_estimation == "random_forest":
                if not use_grf:
                    rf_model = RandomForestClassifier(
                        n_estimators=1000,
                        min_samples_leaf=int(np.sqrt(len(df_ipw))),
                        max_depth=2,
                    )
                    df_ipw["hat_e_pool"] = rf_model.fit(
                        df_ipw[self.treatment_cols].values, df_ipw["W"].values
                    ).predict_proba(df_ipw[self.treatment_cols])[:, 1]
                else:
                    # Train a probability forest
                    p_forest = grf.probability_forest(X_r, W_r, num_trees=2000)
                    # Predict using the forest
                    p_hat = grf.predict_probability_forest(p_forest, X_r)
                    predictions = np.array(p_hat.rx2("predictions"))
                    df_ipw["hat_e_pool"] = predictions[:, 1]
            # elif propensity_estimation == "1or0":
            #
        if any(
            [
                "MW" in propensity_federations,
                "local" in propensity_federations,
                "dwparam" in propensity_federations,
            ]
        ):
            cols_local_propensities = [
                "hat_e_local_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            if propensity_estimation == "logistic":
                dict_hat_gammas = {}
                for k in range(0, len(df_ipw["client"].unique())):
                    # Compute local propensity scores
                    dict_hat_gammas["hat_gamma_" + str(k + 1)] = compute_gamma(
                        df_ipw[df_ipw["client"] == "client" + str(k + 1)],
                        self.treatment_cols,
                    )
                    df_ipw["hat_e_local_" + str(k + 1)] = (
                        logistic_function_vectorized(
                            df_ipw,
                            self.treatment_cols,
                            dict_hat_gammas["hat_gamma_" + str(k + 1)],
                        )
                    )
            elif propensity_estimation == "random_forest":
                if not use_grf:
                    for k in range(1, len(self.clients_list) + 1):
                        rf_model = RandomForestClassifier(
                            n_estimators=2000,
                            min_samples_leaf=int(
                                np.sqrt(
                                    len(
                                        df_ipw[
                                            df_ipw["client"] == "client" + str(k)
                                        ]
                                    )
                                )
                            ),
                            max_depth=2,
                        )
                        df_ipw["hat_e_local_" + str(k)] = rf_model.fit(
                            df_ipw[df_ipw["client"] == "client" + str(k)][
                                self.treatment_cols
                            ].values,
                            df_ipw[df_ipw["client"] == "client" + str(k)][
                                "W"
                            ].values,
                        ).predict_proba(df_ipw[self.treatment_cols].values)[:, 1]
                else:
                    for k in range(1, len(self.clients_list) + 1):
                        dfk = df_ipw[df_ipw["client"] == "client" + str(k)]
                        X_rk = pandas2ri.py2rpy(
                            pd.DataFrame(dfk[self.treatment_cols].values)
                        )
                        W_rk = ro.FactorVector(
                            pandas2ri.py2rpy(pd.Series(dfk["W"].values))
                        )
                        # Train probability forest
                        p_forestk = grf.probability_forest(
                            X_rk, W_rk, num_trees=2000
                        )
                        # Predict
                        p_hatk = grf.predict_probability_forest(p_forestk, X_r)
                        predictions = np.array(p_hatk.rx2("predictions"))
                        df_ipw["hat_e_local_" + str(k)] = predictions[:, 1]
            elif propensity_estimation == "1or0":
                for k in range(1, len(self.clients_list) + 1):
                    df_ipw["hat_e_local_" + str(k)] = (
                        1
                        if df_ipw[df_ipw["client"] == "client" + str(k)]["W"].mean()
                        > 0.5
                        else 0
                    )
            elif propensity_estimation == "externalcontrolarm_client2":
                for k in range(1, len(self.clients_list) + 1):
                    if k == 2:
                        df_ipw["hat_e_local_" + str(k)] = 0
                    else:
                        df_ipw["hat_e_local_" + str(k)] = (
                            logistic_function_vectorized(
                                df_ipw,
                                self.treatment_cols,
                                self.client_params_dict["client" + str(k)]["gamma"],
                            )
                        )
        if "oracle" in propensity_federations:
            cols_e_oracle_k = [
                "e_oracle_" + str(k) for k in range(1, len(self.clients_list) + 1)
            ]
            if propensity_estimation == "logistic":
                for k in range(1, len(self.clients_list) + 1):
                    df_ipw["e_oracle_" + str(k)] = logistic_function_vectorized(
                        df_ipw,
                        self.treatment_cols,
                        self.client_params_dict["client" + str(k)]["gamma"],
                    )
            elif propensity_estimation == "random_forest":
                # df_ipw["e_oracle"] = df_ipw["propensity score*"]
                for k in range(1, len(self.clients_list) + 1):
                    df_ipw["e_oracle_" + str(k)] = generate_W_sequential(
                        df_ipw,
                        self.treatment_cols,
                        scenario=self.client_params_dict["client" + str(k)][
                            "sequential_treatment_scenario"
                        ],
                        return_p=True,
                    )
            elif propensity_estimation == "1or0":
                for k in range(1, len(self.clients_list) + 1):
                    df_ipw["e_oracle_" + str(k)] = (
                        1
                        if df_ipw[df_ipw["client"] == "client" + str(k)]["W"].mean()
                        > 0.5
                        else 0
                    )
            elif propensity_estimation == "externalcontrolarm_client2":
                for k in range(1, len(self.clients_list) + 1):
                    if k == 2:
                        df_ipw["e_oracle_" + str(k)] = 0
                    else:
                        df_ipw["e_oracle_" + str(k)] = logistic_function_vectorized(
                            df_ipw,
                            self.treatment_cols,
                            self.client_params_dict["client" + str(k)]["gamma"],
                        )

            cols_true_omegaks = [
                "omega_" + str(k) + "*"
                for k in range(1, len(self.clients_list) + 1)
            ]
            if generate_by_center:
                # get the true omegas with parametric form of f_k(X)/f(X) where f_k is a normal with true parameters
                for k in range(1, len(self.clients_list) + 1):
                    n_k = len(df_ipw[df_ipw["client"] == "client" + str(k)])
                    df_ipw["omega_" + str(k) + "*"] = (
                        normal_density(
                            df_ipw[self.X_cols].values,
                            mu_k["client" + str(k)],
                            Sigma_k["client" + str(k)],
                        )  # defined above
                        * n_k
                        / len(df_ipw)
                    )
                sum_dr_weights = np.sum(df_ipw[cols_true_omegaks].values, axis=1)
                df_ipw[cols_true_omegaks] = df_ipw[cols_true_omegaks].div(
                    sum_dr_weights, axis=0
                )  # Normalize the density ratios
            else:
                df_ipw[cols_true_omegaks] = multi_logistic(
                    df_ipw[self.sorting_columns].values,
                    self.client_params_dict["global_population"][
                        "membership_Thetas"
                    ],
                    return_prob=True,
                )

            df_ipw["e_oracle"] = membership_weighting_vectorized(
                df_ipw[cols_true_omegaks],
                df_ipw[cols_e_oracle_k],
            )
            # Skip iteration if the weak global overlap condition is not satisfied
            # if (np.sum(df_ipw[cols_e_oracle_k] * (1 - df_ipw[cols_e_oracle_k])) < 1e-4).any():
            if (
                np.sum(1 / (df_ipw["e_oracle"] * (1 - df_ipw["e_oracle"])))
                / len(df_ipw)
                > 1000
            ):
                print(
                    f"Skipping simulation {_} due to lack of weak global overlap condition."
                )
                continue

        cols_local_e_k = [
            "hat_e_local_" + str(k) for k in range(1, len(self.clients_list) + 1)
        ]
        if "MW_logistic" in propensity_federations:
            cols_MW_logistic = [
                "mw_logistic" + str(k + 1) for k in range(len(self.clients_list))
            ]
            MW_columns = (
                self.sorting_columns
                if not generate_by_center
                else self.outcome_cols
            )
            if params_for_federation_omega is not None:
                df_ipw[cols_MW_logistic] = MW_estimation(
                    df_ipw, MW_columns, **params_for_federation_omega, scale=scale
                )
            else:
                df_ipw[cols_MW_logistic] = MW_estimation_pooled(
                    df_ipw, MW_columns, scale=scale
                )

            df_ipw["hat_e_MW_logistic"] = membership_weighting_vectorized(
                df_ipw[cols_MW_logistic],
                df_ipw[cols_local_e_k],
            )
        if "MW_pool_logistic" in propensity_federations:
            cols_MW_pool_logistic = [
                "mw_pool_logistic" + str(k + 1)
                for k in range(len(self.clients_list))
            ]
            MW_columns = (
                self.sorting_columns
                if not generate_by_center
                else self.outcome_cols
            )
            df_ipw[cols_MW_pool_logistic] = MW_estimation_pooled(
                df_ipw, MW_columns, scale=scale
            )
            df_ipw["hat_e_MW_pool"] = membership_weighting_vectorized(
                df_ipw[cols_MW_pool_logistic],
                df_ipw[cols_local_e_k],
            )
        if "MW_RF" in propensity_federations:
            cols_MW_rf = [
                "mw_rf" + str(k + 1) for k in range(len(self.clients_list))
            ]
            # standardize
            # scaler = StandardScaler()
            X, y = df_ipw[self.sorting_columns].values, df_ipw["client"].values
            # X_scaled = scaler.fit_transform(X)
            rf_model = RandomForestClassifier(
                n_estimators=500,
                criterion="log_loss",  # Use log loss for probabilities
                max_depth=None,  # Let trees grow fully (or use 10-15)
                min_samples_leaf=5,  # Less regularization
                # min_samples_split=5,  # Default
                max_features="sqrt",  # Helps decorrelate trees
            )
            from sklearn.calibration import CalibratedClassifierCV
            from sklearn.model_selection import cross_val_predict
            calibrated_model = CalibratedClassifierCV(base_estimator=rf_model, method='isotonic', cv=5)
            # Cross-validated, calibrated probability predictions
            probs = cross_val_predict(calibrated_model, X, y, cv=5, method='predict_proba')
            # calibrated_model.fit(X, y)
            # df_ipw[cols_MW_rf] = calibrated_model.predict_proba(X)
            # print(probs)
            df_ipw[cols_MW_rf] = probs
            df_ipw["hat_e_MW_RF"] = membership_weighting_vectorized(
                df_ipw[cols_MW_rf],
                df_ipw[cols_local_e_k],
            )
        if "MW_NN" in propensity_federations:
            cols_MW_nn = [
                "mw_nn" + str(k + 1) for k in range(len(self.clients_list))
            ]
            # # Get probabilities
            X = df_ipw[self.sorting_columns].values
            X_scaled = StandardScaler().fit_transform(X)

            df_ipw[cols_MW_nn] = self.train_NN(df_ipw, rerun_NN=rerun_NN)

            df_ipw["hat_e_MW_NN"] = membership_weighting_vectorized(
                df_ipw[cols_MW_nn],
                df_ipw[cols_local_e_k],
            )

        if "oracle_weights" in propensity_federations:
            cols_true_omegaks = [
                "omega_" + str(k) + "*"
                for k in range(1, len(self.clients_list) + 1)
            ]
            df_ipw["hat_e_oracle_weights"] = membership_weighting_vectorized(
                df_ipw[cols_true_omegaks],
                df_ipw[cols_local_e_k],
            )

        if "drkernel" in propensity_federations:
            X_pool = df_ipw[self.sorting_columns].values
            kde_pool = KernelDensity(kernel="gaussian", bandwidth=0.05).fit(X_pool)

            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                dfk = df_ipw[df_ipw["client"] == "client" + str(k)]
                X_k = dfk[self.sorting_columns].values

                # Fit KDEs for client k and pool
                kde_k = KernelDensity(kernel="gaussian", bandwidth=0.05).fit(X_k)

                # Compute the density ratio for client k
                df_ipw["hat_drkernel_" + str(k)] = (
                    density_ratio_kernel(X_pool, kde_k, kde_pool)
                    * len(X_k)
                    / len(X_pool)
                )

            cols_dr_kernel = [
                "hat_drkernel_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            # Update the density ratio column for client k
            df_ipw["hat_e_drkernel"] = dr_weighting_vectorized(
                df_ipw[cols_dr_kernel], df_ipw[cols_local_propensities]
            )

        if "dwparam" in propensity_federations:
            # We estimate their mean and variance
            hat_mu_k = {}
            hat_Sigma_k = {}
            for k in range(1, len(self.clients_list) + 1):
                df_k = df_ipw[df_ipw["client"] == "client" + str(k)]
                n_k = df_k.shape[0]
                hat_mu_k["client" + str(k)] = np.mean(
                    df_k[self.sorting_columns].values,
                    axis=0,
                )
                hat_Sigma_k["client" + str(k)] = (
                    1
                    / len(df_k)
                    * np.dot(
                        (
                            df_k[self.sorting_columns].values
                            - hat_mu_k["client" + str(k)]
                        ).T,
                        df_k[self.sorting_columns].values
                        - hat_mu_k["client" + str(k)],
                    )
                )
                # n_k/n * f_k(X_i)
                df_ipw["hat_dwparam_" + str(k)] = (
                    normal_density(
                        df_ipw[self.sorting_columns],
                        hat_mu_k["client" + str(k)],
                        hat_Sigma_k["client" + str(k)],
                    )
                    * n_k
                    / len(df_ipw[self.sorting_columns])
                )

            cols_dr_param = [
                "hat_dwparam_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            sum_dr_weights = np.sum(df_ipw[cols_dr_param].values, axis=1)

            df_ipw[cols_dr_param] = df_ipw[cols_dr_param].div(
                sum_dr_weights, axis=0
            )  # Normalize the density ratios

            # Update the density ratio column for client k
            df_ipw["hat_e_dwparam"] = dr_weighting_vectorized(
                df_ipw[cols_dr_param], df_ipw[cols_local_propensities]
            )

        if "drpackage" in propensity_federations:
            X_pool = df_ipw[self.sorting_columns].values
            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                client_mask = df_ipw["client"] == "client" + str(k)
                X_k = df_ipw[client_mask][self.sorting_columns].values

                # Compute the density ratio for client k
                dr_np_k = densratio(
                    X_k, X_pool, lambda_range=[0.01], sigma_range=[1], verbose=False
                )
                df_ipw["hat_drpackage_" + str(k)] = (
                    dr_np_k.compute_density_ratio(X_pool) * len(X_k) / len(X_pool)
                )

            cols_dr_package = [
                "hat_drpackage_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            # Update the density ratio column for client k
            df_ipw["hat_e_drpackage"] = dr_weighting_vectorized(
                df_ipw[cols_dr_package], df_ipw[cols_local_propensities]
            )

        if "drexponential" in propensity_federations:
            skip_simulation = False
            X_pool = df_ipw[self.treatment_cols].values

            # Iterate over each client
            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                dfk = df_ipw[df_ipw["client"] == "client" + str(k)]
                X_k = dfk[self.treatment_cols].values

                exp_tilt_k = solve_gamma(X_k, np.mean(X_pool, axis=0))

                if exp_tilt_k is None:
                    print(f"Skipping client {k} due to non-convergence.")
                    skip_simulation = True  # Set the flag to skip the simulation
                    break  # Exit the client loop

                df_ipw["hat_drexponential_" + str(k)] = (
                    estimate_exponential_dr(X_pool, exp_tilt_k)
                    * len(X_k)
                    / len(X_pool)
                )

            # If any client caused non-convergence, skip the rest of the simulation
            if skip_simulation:
                print("Skipping the entire simulation due to non-convergence.")
                continue  # Exit the simulation loop

            cols_dr_exponential = [
                "hat_drexponential_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            df_ipw[cols_dr_exponential] = df_ipw[cols_dr_exponential].div(
                df_ipw[cols_dr_exponential].sum(axis=1), axis=0
            )  # Normalize the exponential density ratios
            df_ipw["hat_e_drexponential"] = dr_weighting_vectorized(
                df_ipw[cols_dr_exponential], df_ipw[cols_local_propensities]
            )

        # Compute final estimators
        if "Oracle" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["oracle"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="e_oracle",
                        ht_or_hajek="HT",
                    )
                )
            local_ipw_estimates["oracle"]["total_data"].append(
                ipw_function(df_ipw, hat_e_name="e_oracle", ht_or_hajek=ht_or_hajek)
            )
            estimators_ipw["Oracle"].append(
                local_ipw_estimates["oracle"]["total_data"][-1]
            )

        if "Meta-SW" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                hat_tau_k = ipw_function(
                    df_ipw[df_ipw["client"] == "client" + str(k)],
                    hat_e_name="hat_e_local_" + str(k),
                    ht_or_hajek=ht_or_hajek,
                )
                local_ipw_estimates["local"]["client" + str(k)].append(hat_tau_k)
            # Agregate with SW
            estimators_ipw["Meta-SW"].append(
                np.average(
                    [
                        local_ipw_estimates["local"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )
            # if "Meta-RE" in final_estimators:
            #     # Compute the DerSimonian and Laird estimator
            #     sigmas2 = []
            #     for k in range(1, len(self.clients_list) + 1):
            #         sigmak2 = 1 / len(
            #             df_ipw[df_ipw["client"] == "client" + str(k)] - 1
            #         ) * np.sum(

            #         )

            #     estimators_ipw["Meta-RE"].append(
            #         np.average(
            #             [
            #                 local_ipw_estimates["local_ates"]["client" + str(k)][-1]
            #                 for k in range(1, len(self.clients_list) + 1)
            #             ],
            #             weights=[
            #                 ???
            #             ],
            #         )
            #     )

        if "1S-IVW" in final_estimators:
            data_dict = {}
            for k in range(1, len(self.clients_list) + 1):
                data_dict["client" + str(k)] = df_ipw[
                    df_ipw["client"] == "client" + str(k)
                ][self.treatment_cols]
            list_gammas_k = [
                dict_hat_gammas["hat_gamma_" + str(k)]
                for k in range(1, len(self.clients_list) + 1)
            ]
            if shared_PS_coefs is not None:
                shared_PS_coefs = shared_PS_coefs
            else:
                print("Shared PS coefs not provided, using local ones.")
            list_one_shot_gammas_k = OneShot_IVW_logistic_coefficients(
                data_dict,
                list_gammas_k,
                shared_PS_coefs,
            )
            for k in range(1, len(self.clients_list) + 1):
                df_ipw["hat_e_1S-IVW_" + str(k)] = logistic_function_vectorized(
                    df_ipw,
                    self.treatment_cols,
                    list_one_shot_gammas_k[k - 1],
                )
                local_ipw_estimates["1S-IVW"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_1S-IVW_" + str(k),
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["1S-IVW"].append(
                np.average(
                    [
                        local_ipw_estimates["1S-IVW"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "Meta-SW*" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["SW"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="e_oracle_" + str(k),
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["Meta-SW*"].append(
                np.average(
                    [
                        local_ipw_estimates["SW"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "MW_logistic" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["MW_logistic"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_MW_logistic",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["MW_logistic"].append(
                np.average(
                    [
                        local_ipw_estimates["MW_logistic"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )
        if "MW_pool_logistic" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["MW_pool_logistic"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_MW_pool",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["MW_pool_logistic"].append(
                np.average(
                    [
                        local_ipw_estimates["MW_pool_logistic"]["client" + str(k)][
                            -1
                        ]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )
        if "MW_RF" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["MW_RF"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_MW_RF",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["MW_RF"].append(
                np.average(
                    [
                        local_ipw_estimates["MW_RF"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )
        if "MW_NN" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["MW_NN"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_MW_NN",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["MW_NN"].append(
                np.average(
                    [
                        local_ipw_estimates["MW_NN"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )
        # Compute the overlap
        if print_overlap and _ % 50 == 0:
            print(
                f"Global overlap: {1/len(df_ipw) * np.sum(1/((df_ipw['e_oracle']) * (1 - df_ipw['e_oracle'])))}"
            )
            for k in range(1, len(self.clients_list) + 1):
                print(
                    f"Overlap client {k}: {1/len(df_ipw[df_ipw['client'] == 'client' + str(k)]) * np.sum(1/((df_ipw[df_ipw['client'] == 'client' + str(k)]['e_oracle_'+ str(k)]) * (1 - df_ipw[df_ipw['client'] == 'client' + str(k)]['e_oracle_'+ str(k)])))}"
                )

        if "pool" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["pool"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_pool",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["pool"].append(
                np.average(
                    [
                        local_ipw_estimates["pool"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "oracle_weights" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["oracle_weights"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_oracle_weights",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["oracle_weights"].append(
                np.average(
                    [
                        local_ipw_estimates["oracle_weights"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "DW-param" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["dwparam"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_dwparam",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["DW-param"].append(
                np.average(
                    [
                        local_ipw_estimates["dwparam"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_ipw[df_ipw["client"] == "client" + str(k)])
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "DW-kernel" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["drkernel"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_drkernel",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["DW-kernel"].append(
                np.average(
                    [
                        local_ipw_estimates["drkernel"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        self.client_params_dict["client" + str(k)]["sample_size"]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "DW-exp" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["drexponential"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_drexponential",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["DW-exp"].append(
                np.average(
                    [
                        local_ipw_estimates["drexponential"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        self.client_params_dict["client" + str(k)]["sample_size"]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

        if "DW-package" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_ipw_estimates["DW-package"]["client" + str(k)].append(
                    ipw_function(
                        df_ipw[df_ipw["client"] == "client" + str(k)],
                        hat_e_name="hat_e_drpackage",
                        ht_or_hajek=ht_or_hajek,
                    )
                )
            # Agregate with SW
            estimators_ipw["DW-package"].append(
                np.average(
                    [
                        local_ipw_estimates["DW-package"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        self.client_params_dict["client" + str(k)]["sample_size"]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

            # sort the columns according to final_estimators
            # for col in final_estimators:
            #     if col in estimators_ipw:
            #         estimators_ipw[col] = estimators_ipw[col][::-1]

    return estimators_ipw, df_ipw, local_ipw_estimates

def estimate_aipw(
    self,
    propensity_federations: list = [
        "oracle",
        "pool",
        "MW_logistic",
        "MW_NN",
        "dwparam",
        "local",
    ],
    outcome_estimators: list = ["local", "pool", "1SIVW", "GD"],
    final_estimators: list = [
        "Oracle",
        "Meta-SW",
        "GD-MW_logistic",
        "GD-MW_NN",
        "GD-DWparam",
        "1SIVW-MW_logistic",
        "1SIVW-MW_NN",
        "1SIVW-DWparam",
    ],
    propensity_estimation: str = "logistic",
    use_grf: bool = False,
    params_for_federation=None or dict,
    params_for_federation_omega=None or dict,
    generate_by_center: bool = True,
    scale=False,
    print_overlap=True,
    rerun_NN=True,
):
    dict_structure = {
        f"client{i}": [] for i in range(1, len(self.clients_list) + 1)
    } | {"total_data": []}
    local_aipw_estimates = {
        "oracle": copy.deepcopy(dict_structure),
        "Locals": copy.deepcopy(dict_structure),
        "GD-MW_logistic": copy.deepcopy(dict_structure),
        "GD-MW_NN": copy.deepcopy(dict_structure),
        "GD-DWparam": copy.deepcopy(dict_structure),
        "pool": copy.deepcopy(dict_structure),
        "1SIVW-MW_logistic": copy.deepcopy(dict_structure),
        "1SIVW-MW_NN": copy.deepcopy(dict_structure),
        "1SIVW-DWparam": copy.deepcopy(dict_structure),
        "Locals-MW_logistic": copy.deepcopy(dict_structure),
        "Locals-MW_NN": copy.deepcopy(dict_structure),
        "Locals-DWparam": copy.deepcopy(dict_structure),
    }
    estimators_aipw = {final_estimator: [] for final_estimator in final_estimators}

    X_cols = ["X" + str(i) for i in range(1, self.dim_x + 1)]

    if use_grf:
        pandas2ri.activate()
        # Import the grf package
        grf = importr("grf")

    for _ in tqdm(range(self.n_simulations)):
        # Generate the df with membership probabilities
        if generate_by_center == True:
            df_aipw = Simulations_Fed(
                client_params_dict=self.client_params_dict,
                estimator="aipw",
                n_simulations=1,
                fixed_design=False,
                known_sigma2=False,
                rct_binomial_treatment=False,
                estimate_Sigma=False,
                estimate_sigma2=False,
                estime_norm_beta1_minus_beta0=False,
            ).combine_data()
        else:
            df_aipw = Simulations_Fed(
                client_params_dict=self.client_params_dict,
                estimator="aipw",
                n_simulations=1,
                fixed_design=False,
                known_sigma2=False,
                rct_binomial_treatment=False,
                estimate_Sigma=False,
                estimate_sigma2=False,
                estime_norm_beta1_minus_beta0=False,
            ).make_data_by_H_given_X()

        X_cols = ["X" + str(i) for i in range(1, self.dim_x + 1)]
        if generate_by_center:
            mu_k = {}
            Sigma_k = {}
            for k in range(1, len(self.clients_list) + 1):
                mu_k["client" + str(k)] = self.client_params_dict[
                    "client" + str(k)
                ]["mean_covariates"]
                Sigma_k["client" + str(k)] = self.client_params_dict[
                    "client" + str(k)
                ]["cov_covariates"]

        if use_grf:
            # Convert data to R objects
            X_r = pandas2ri.py2rpy(
                pd.DataFrame(df_aipw[self.treatment_cols].values)
            )
            W_r = ro.FactorVector(pandas2ri.py2rpy(pd.Series(df_aipw["W"].values)))
        # Compute propensity scores
        if "pool" in propensity_federations:
            if propensity_estimation == "logistic":
                hat_gamma_pool = compute_gamma(df_aipw, self.treatment_cols)
                df_aipw["hat_e_pool"] = logistic_function_vectorized(
                    df_aipw, self.treatment_cols, hat_gamma_pool
                )
            elif propensity_estimation == "random_forest":
                if not use_grf:
                    rf_model = RandomForestClassifier(
                        n_estimators=2000,
                        min_samples_leaf=int(np.sqrt(len(df_aipw))),
                        max_depth=2,
                    )
                    df_aipw["hat_e_pool"] = rf_model.fit(
                        df_aipw[self.treatment_cols].values, df_aipw["W"].values
                    ).predict_proba(df_aipw[self.treatment_cols])[:, 1]
                else:
                    # Train a probability forest
                    p_forest = grf.probability_forest(X_r, W_r, num_trees=2000)
                    # Predict using the forest
                    p_hat = grf.predict_probability_forest(p_forest, X_r)
                    predictions = np.array(p_hat.rx2("predictions"))
                    df_aipw["hat_e_pool"] = predictions[:, 1]
            # elif propensity_estimation == "1or0":
            #
        if any(
            [
                "MW" in propensity_federations,
                "local" in propensity_federations,
                "dwparam" in propensity_federations,
            ]
        ):
            cols_local_propensities = [
                "hat_e_local_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            if propensity_estimation == "logistic":
                dict_hat_gammas = {}
                for k in range(0, len(df_aipw["client"].unique())):
                    # Compute local propensity scores
                    dict_hat_gammas["hat_gamma_" + str(k + 1)] = compute_gamma(
                        df_aipw[df_aipw["client"] == "client" + str(k + 1)],
                        self.treatment_cols,
                    )
                    df_aipw["hat_e_local_" + str(k + 1)] = (
                        logistic_function_vectorized(
                            df_aipw,
                            self.treatment_cols,
                            dict_hat_gammas["hat_gamma_" + str(k + 1)],
                        )
                    )
            elif propensity_estimation == "random_forest":
                if not use_grf:
                    for k in range(1, len(self.clients_list) + 1):
                        rf_model = RandomForestClassifier(
                            n_estimators=2000,
                            min_samples_leaf=int(
                                np.sqrt(
                                    len(
                                        df_aipw[
                                            df_aipw["client"] == "client" + str(k)
                                        ]
                                    )
                                )
                            ),
                            max_depth=2,
                        )
                        df_aipw["hat_e_local_" + str(k)] = rf_model.fit(
                            df_aipw[df_aipw["client"] == "client" + str(k)][
                                self.treatment_cols
                            ].values,
                            df_aipw[df_aipw["client"] == "client" + str(k)][
                                "W"
                            ].values,
                        ).predict_proba(df_aipw[self.treatment_cols].values)[:, 1]
                else:
                    for k in range(1, len(self.clients_list) + 1):
                        dfk = df_aipw[df_aipw["client"] == "client" + str(k)]
                        X_rk = pandas2ri.py2rpy(
                            pd.DataFrame(dfk[self.treatment_cols].values)
                        )
                        W_rk = ro.FactorVector(
                            pandas2ri.py2rpy(pd.Series(dfk["W"].values))
                        )
                        # Train probability forest
                        p_forestk = grf.probability_forest(
                            X_rk, W_rk, num_trees=2000
                        )
                        # Predict
                        p_hatk = grf.predict_probability_forest(p_forestk, X_r)
                        predictions = np.array(p_hatk.rx2("predictions"))
                        df_aipw["hat_e_local_" + str(k)] = predictions[:, 1]
            elif propensity_estimation == "1or0":
                for k in range(1, len(self.clients_list) + 1):
                    df_aipw["hat_e_local_" + str(k)] = (
                        1
                        if df_aipw[df_aipw["client"] == "client" + str(k)][
                            "W"
                        ].mean()
                        > 0.5
                        else 0
                    )
            elif propensity_estimation == "externalcontrolarm_client2":
                for k in range(1, len(self.clients_list) + 1):
                    if k == 2:
                        df_aipw["e_oracle_" + str(k)] = 0
                    else:
                        df_aipw["e_oracle_" + str(k)] = (
                            logistic_function_vectorized(
                                df_aipw,
                                self.treatment_cols,
                                self.client_params_dict["client" + str(k)]["gamma"],
                            )
                        )

        if "oracle" in propensity_federations:
            cols_e_oracle_k = [
                "e_oracle_" + str(k) for k in range(1, len(self.clients_list) + 1)
            ]
            if propensity_estimation == "logistic":
                for k in range(1, len(self.clients_list) + 1):
                    df_aipw["e_oracle_" + str(k)] = logistic_function_vectorized(
                        df_aipw,
                        self.treatment_cols,
                        self.client_params_dict["client" + str(k)]["gamma"],
                    )
            elif propensity_estimation == "random_forest":
                # df_aipw["e_oracle"] = df_aipw["propensity score*"]
                for k in range(1, len(self.clients_list) + 1):
                    df_aipw["e_oracle_" + str(k)] = generate_W_sequential(
                        df_aipw,
                        self.treatment_cols,
                        scenario=self.client_params_dict["client" + str(k)][
                            "sequential_treatment_scenario"
                        ],
                        return_p=True,
                    )
            elif propensity_estimation == "1or0":
                for k in range(1, len(self.clients_list) + 1):
                    df_aipw["e_oracle_" + str(k)] = (
                        1
                        if df_aipw[df_aipw["client"] == "client" + str(k)][
                            "W"
                        ].mean()
                        > 0.5
                        else 0
                    )
            elif propensity_estimation == "externalcontrolarm_client2":
                for k in range(1, len(self.clients_list) + 1):
                    if k == 2:
                        df_aipw["hat_e_local_" + str(k)] = 0
                    else:
                        df_aipw["hat_e_local_" + str(k)] = (
                            logistic_function_vectorized(
                                df_aipw,
                                self.treatment_cols,
                                self.client_params_dict["client" + str(k)]["gamma"],
                            )
                        )

            cols_true_omegaks = [
                "omega_" + str(k) + "*"
                for k in range(1, len(self.clients_list) + 1)
            ]
            if generate_by_center:
                # get the true omegas with parametric form of f_k(X)/f(X) where f_k is a normal with true parameters
                for k in range(1, len(self.clients_list) + 1):
                    n_k = len(df_aipw[df_aipw["client"] == "client" + str(k)])
                    df_aipw["omega_" + str(k) + "*"] = (
                        normal_density(
                            df_aipw[self.X_cols],
                            mu_k["client" + str(k)],
                            Sigma_k["client" + str(k)],
                        )  # defined above
                        * n_k
                        / len(df_aipw)
                    )
                sum_dr_weights = np.sum(df_aipw[cols_true_omegaks].values, axis=1)
                df_aipw[cols_true_omegaks] = df_aipw[cols_true_omegaks].div(
                    sum_dr_weights, axis=0
                )  # Normalize the density ratios
            else:
                df_aipw[cols_true_omegaks] = multi_logistic(
                    df_aipw[self.sorting_columns].values,
                    self.client_params_dict["global_population"][
                        "membership_Thetas"
                    ],
                    return_prob=True,
                )

            df_aipw["e_oracle"] = membership_weighting_vectorized(
                df_aipw[cols_true_omegaks],
                df_aipw[cols_e_oracle_k],
            )

        cols_local_e_k = [
            "hat_e_local_" + str(k) for k in range(1, len(self.clients_list) + 1)
        ]
        if "MW_logistic" in propensity_federations:
            cols_MW_logistic = [
                "mw_logistic" + str(k + 1) for k in range(len(self.clients_list))
            ]
            MW_columns = (
                self.sorting_columns
            )  # if not generate_by_center else self.treatment_cols
            if params_for_federation_omega is not None:
                df_aipw[cols_MW_logistic] = MW_estimation(
                    df_aipw, MW_columns, **params_for_federation_omega, scale=scale
                )
            else:
                df_aipw[cols_MW_logistic] = MW_estimation_pooled(
                    df_aipw, MW_columns, scale=scale
                )
                # df_aipw[cols_MW_pool] = df_aipw[cols_MW]

            df_aipw["hat_e_MW_logistic"] = membership_weighting_vectorized(
                df_aipw[cols_MW_logistic],
                df_aipw[cols_local_e_k],
            )
        if "MW_RF" in propensity_federations:
            cols_MW_rf = [
                "mw_rf" + str(k + 1) for k in range(len(self.clients_list))
            ]
            X, y = df_aipw[self.sorting_columns].values, df_aipw["client"].values
            rf_model = RandomForestClassifier(
                n_estimators=400,
                min_samples_leaf=int(np.sqrt(len(df_aipw))),
                max_depth=None,
            )
            rf_model.fit(X, y)
            df_aipw[cols_MW_rf] = rf_model.predict_proba(X)
            df_aipw["hat_e_MW_RF"] = membership_weighting_vectorized(
                df_aipw[cols_MW_rf],
                df_aipw[cols_local_e_k],
            )
        if "MW_NN" in propensity_federations:
            cols_MW_nn = [
                "mw_nn" + str(k + 1) for k in range(len(self.clients_list))
            ]
            # Get probabilities
            X = df_aipw[self.sorting_columns].values
            X_scaled = StandardScaler().fit_transform(X)

            df_aipw[cols_MW_nn] = self.train_NN(df_aipw, rerun_NN=rerun_NN)

            df_aipw["hat_e_MW_NN"] = membership_weighting_vectorized(
                df_aipw[cols_MW_nn],
                df_aipw[cols_local_e_k],
            )

        if "oracle_weights" in propensity_federations:
            cols_true_omegaks = [
                "omega_" + str(k) + "*"
                for k in range(1, len(self.clients_list) + 1)
            ]
            df_aipw["hat_e_oracle_weights"] = membership_weighting_vectorized(
                df_aipw[cols_true_omegaks],
                df_aipw[cols_local_e_k],
            )

        if "drkernel" in propensity_federations:
            X_pool = df_aipw[self.sorting_columns].values

            kde_pool = KernelDensity(kernel="gaussian", bandwidth=0.05).fit(X_pool)

            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                dfk = df_aipw[df_aipw["client"] == "client" + str(k)]
                X_k = dfk[self.sorting_columns].values

                # Fit KDEs for client k and pool
                kde_k = KernelDensity(kernel="gaussian", bandwidth=0.05).fit(X_k)

                # Compute the density ratio for client k
                df_aipw["hat_drkernel_" + str(k)] = (
                    density_ratio_kernel(X_pool, kde_k, kde_pool)
                    * len(X_k)
                    / len(X_pool)
                )

            cols_dr_kernel = [
                "hat_drkernel_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            # Update the density ratio column for client k
            df_aipw["hat_e_drkernel"] = dr_weighting_vectorized(
                df_aipw[cols_dr_kernel], df_aipw[cols_local_propensities]
            )

        if "dwparam" in propensity_federations:
            # We estimate their mean and variance
            hat_mu_k = {}
            hat_Sigma_k = {}
            for k in range(1, len(self.clients_list) + 1):
                df_k = df_aipw[df_aipw["client"] == "client" + str(k)]
                n_k = df_k.shape[0]
                hat_mu_k["client" + str(k)] = np.mean(
                    df_k[self.sorting_columns].values,
                    axis=0,
                )
                hat_Sigma_k["client" + str(k)] = (
                    1
                    / len(df_k)
                    * np.dot(
                        (
                            df_k[self.sorting_columns].values
                            - hat_mu_k["client" + str(k)]
                        ).T,
                        df_k[self.sorting_columns].values
                        - hat_mu_k["client" + str(k)],
                    )
                )
                # n_k/n * f_k(X_i)
                df_aipw["hat_dwparam_" + str(k)] = (
                    normal_density(
                        df_aipw[self.sorting_columns],
                        hat_mu_k["client" + str(k)],
                        hat_Sigma_k["client" + str(k)],
                    )
                    * n_k
                    / len(df_aipw[self.sorting_columns])
                )

            cols_dr_param = [
                "hat_dwparam_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            sum_dr_weights = np.sum(df_aipw[cols_dr_param].values, axis=1)

            df_aipw[cols_dr_param] = df_aipw[cols_dr_param].div(
                sum_dr_weights, axis=0
            )  # Normalize the density ratios

            # Update the density ratio column for client k
            df_aipw["hat_e_dwparam"] = dr_weighting_vectorized(
                df_aipw[cols_dr_param], df_aipw[cols_local_propensities]
            )

        if "drpackage" in propensity_federations:
            X_pool = df_aipw[self.sorting_columns].values
            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                client_mask = df_aipw["client"] == "client" + str(k)
                X_k = df_aipw[client_mask][self.sorting_columns].values

                # Compute the density ratio for client k
                dr_np_k = densratio(
                    X_k, X_pool, lambda_range=[0.01], sigma_range=[1], verbose=False
                )
                df_aipw["hat_drpackage_" + str(k)] = (
                    dr_np_k.compute_density_ratio(X_pool) * len(X_k) / len(X_pool)
                )

            cols_dr_package = [
                "hat_drpackage_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            # Update the density ratio column for client k
            df_aipw["hat_e_drpackage"] = dr_weighting_vectorized(
                df_aipw[cols_dr_package], df_aipw[cols_local_propensities]
            )

        if "drexponential" in propensity_federations:
            X_pool = df_aipw[self.sorting_columns].values

            # Iterate over each client
            for k in range(1, len(self.clients_list) + 1):
                # Filter data for client k
                dfk = df_aipw[df_aipw["client"] == "client" + str(k)]
                X_k = dfk[self.sorting_columns].values

                exp_tilt_k = solve_gamma(X_k, np.mean(X_pool, axis=0))
                df_aipw["hat_drexponential_" + str(k)] = (
                    estimate_exponential_dr(X_pool, exp_tilt_k)
                    * len(X_k)
                    / len(X_pool)
                )

            cols_dr_exponential = [
                "hat_drexponential_" + str(k)
                for k in range(1, len(self.clients_list) + 1)
            ]
            df_aipw[cols_dr_exponential] = df_aipw[cols_dr_exponential].div(
                df_aipw[cols_dr_exponential].sum(axis=1), axis=0
            )  # Normalize the exponential density ratios
            df_aipw["hat_e_drexponential"] = dr_weighting_vectorized(
                df_aipw[cols_dr_exponential], df_aipw[cols_local_propensities]
            )

        if print_overlap and _ % 50 == 0:
            print(
                f"Global overlap: {1/len(df_aipw) * np.sum(1/((df_aipw['e_oracle']) * (1 - df_aipw['e_oracle'])))}"
            )
            for k in range(1, len(self.clients_list) + 1):
                print(
                    f"Overlap client {k}: {1/len(df_aipw[df_aipw['client'] == 'client' + str(k)]) * np.sum(1/((df_aipw[df_aipw['client'] == 'client' + str(k)]['e_oracle_'+ str(k)]) * (1 - df_aipw[df_aipw['client'] == 'client' + str(k)]['e_oracle_'+ str(k)])))}"
                )

        # if "oracle" in propensity_federations:
        # if compute_non_param_oracle_weights == True:
        #     cols_true_omegaks = [
        #         "omega" + str(k) + "*"
        #         for k in range(1, len(self.client_params_dict) + 1)
        #     ]
        #     cols_e_oracle_k = [
        #         "e_oracle_" + str(k)
        #         for k in range(1, len(self.clients_list) + 1)
        #     ]
        #     df_aipw[cols_true_omegaks] = oracle_weights.predict_proba(
        #         df_aipw[X_cols]
        #     )  # get the true omegas from KNN
        #     df_aipw["e_oracle"] = membership_weighting_vectorized(
        #         df_aipw[cols_true_omegaks],
        #         df_aipw[cols_e_oracle_k],
        #     )
        # else:
        #     df_aipw["e_oracle"] = membership_weighting_vectorized(
        #         df_aipw[cols_MW],  # estimated omegas
        #         df_aipw[cols_e_oracle_k],
        #     )

        # Compute the outcome estimators
        dict_beta_1 = {
            "local": {
                # "client" + str(k): [] for k in range(1, len(self.clients_list) + 1)
            },
            # "pool": [],
            # "GD": [],
            # '1SIVW': []
        }
        dict_beta_0 = {
            "local": {
                # "client" + str(k): [] for k in range(1, len(self.clients_list) + 1)
            },
            # "pool": [],
            # "GD": [],
            # '1SIVW': []
        }

        if "local" in outcome_estimators:
            for k in range(1, len(self.clients_list) + 1):
                dfk1 = df_aipw[
                    (df_aipw["client"] == "client" + str(k)) & (df_aipw["W"] == 1)
                ]
                print(dfk1.head())
                dict_beta_1["local"]["client" + str(k)] = estimate_beta_MLE(
                    dfk1,
                    self.outcome_cols,
                )
                dfk0 = df_aipw[
                    (df_aipw["client"] == "client" + str(k)) & (df_aipw["W"] == 0)
                ]
                dict_beta_0["local"]["client" + str(k)] = estimate_beta_MLE(
                    dfk0,
                    self.outcome_cols,
                )

        if "pool" in outcome_estimators:
            dict_beta_1["pool"] = estimate_beta_MLE(
                df_aipw[df_aipw["W"] == 1], self.outcome_cols
            )
            dict_beta_0["pool"] = estimate_beta_MLE(
                df_aipw[df_aipw["W"] == 0], self.outcome_cols
            )

        if "GD" in outcome_estimators:
            df_dict = {
                "client" + str(k): df_aipw[df_aipw["client"] == "client" + str(k)]
                for k in range(1, len(self.clients_list) + 1)
            }
            if params_for_federation is not None:
                dict_beta_1["GD"] = fedavg_sgd_server_side(
                    df_dict,
                    self.outcome_cols,
                    beta1_or_beta0_or_gamma="beta1",
                    **params_for_federation,
                )
                dict_beta_0["GD"] = fedavg_sgd_server_side(
                    df_dict,
                    self.outcome_cols,
                    beta1_or_beta0_or_gamma="beta0",
                    **params_for_federation,
                )
            else:
                dict_beta_1["GD"] = (
                    sm.OLS(
                        df_aipw[df_aipw["W"] == 1]["Y"],
                        df_aipw[df_aipw["W"] == 1][self.outcome_cols],
                    )
                    .fit()
                    .params
                )
                dict_beta_0["GD"] = (
                    sm.OLS(
                        df_aipw[df_aipw["W"] == 0]["Y"],
                        df_aipw[df_aipw["W"] == 0][self.outcome_cols],
                    )
                    .fit()
                    .params
                )

        if "1SIVW" in outcome_estimators:
            # One-Shot IVW the beta_1ks
            numerator = np.zeros(self.dim_x + 1)
            denominator = np.zeros((self.dim_x + 1, self.dim_x + 1))

            for k in range(1, len(self.clients_list) + 1):
                beta_k = dict_beta_1["local"]["client" + str(k)]

                # Get the hat Sigma_k_1
                weight_k = (
                    df_aipw[
                        (df_aipw["client"] == "client" + str(k))
                        & (df_aipw["W"] == 1)
                    ][self.outcome_cols].T
                    @ df_aipw[
                        (df_aipw["client"] == "client" + str(k))
                        & (df_aipw["W"] == 1)
                    ][self.outcome_cols]
                )

                numerator += weight_k @ beta_k
                denominator += weight_k
            dict_beta_1["1SIVW"] = np.linalg.inv(denominator) @ numerator

            # One-Shot IVW the beta_1ks
            numerator = np.zeros(self.dim_x + 1)
            denominator = np.zeros((self.dim_x + 1, self.dim_x + 1))
            for k in range(1, len(self.clients_list) + 1):
                beta_k = dict_beta_1["local"]["client" + str(k)]
                # Get the hat Sigma_k_0
                weight_k = (
                    df_aipw[
                        (df_aipw["client"] == "client" + str(k))
                        & (df_aipw["W"] == 0)
                    ][self.outcome_cols].T
                    @ df_aipw[
                        (df_aipw["client"] == "client" + str(k))
                        & (df_aipw["W"] == 0)
                    ][self.outcome_cols]
                )
                numerator += weight_k @ beta_k
                denominator += weight_k
            # Compute the weighted average
            dict_beta_0["1SIVW"] = np.linalg.inv(denominator) @ numerator

        # Compute AIPWs among ["Oracle", "Meta-SW", "GD-MW", "GD-DWparam", "1SIVW-MW", "1SIVW-DWparam"]

        if "Oracle" in final_estimators:
            # df_aipw["Y1*"] = np.where(
            #     df_aipw["W"] == 1, df_aipw["Y"], df_aipw["Y_cf*"]
            # )  # oracle outcomes
            # df_aipw["Y0*"] = np.where(
            #     df_aipw["W"] == 0, df_aipw["Y"], df_aipw["Y_cf*"]
            # )
            for k in range(1, len(self.clients_list) + 1):
                local_aipw_estimates["oracle"]["client" + str(k)].append(
                    aipw_function(
                        df_aipw[df_aipw["client"] == "client" + str(k)],
                        y1_col="Y1*",
                        y0_col="Y0*",
                        e_col="e_oracle",
                    ),
                )
            local_aipw_estimates["oracle"]["total_data"].append(
                aipw_function(
                    df_aipw,
                    y1_col="Y1*",
                    y0_col="Y0*",
                    e_col="e_oracle",
                ),
            )
            estimators_aipw["Oracle"].append(
                local_aipw_estimates["oracle"]["total_data"][-1]
            )

        if any(
            [
                "Meta-SW" in final_estimators,
                "Locals-MW_logistic" in final_estimators,
                "Locals-MW_NN" in final_estimators,
                "Locals-DWparam" in final_estimators,
            ]
        ):
            for k in range(1, len(self.clients_list) + 1):
                df_aipw["hat_mu1_" + str(k)] = np.dot(
                    df_aipw[self.outcome_cols].values,
                    dict_beta_1["local"]["client" + str(k)],
                )
                df_aipw["hat_mu0_" + str(k)] = np.dot(
                    df_aipw[self.outcome_cols].values,
                    dict_beta_0["local"]["client" + str(k)],
                )

            if "Meta-SW" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    # Compute local AIPW estimates
                    local_aipw_estimates["Locals"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_" + str(k),
                            y0_col="hat_mu0_" + str(k),
                            e_col="hat_e_local_" + str(k),
                        )
                    )
                # print(f"Local AIPW estimates: {local_aipw_estimates['Locals']}")
                # Agregate with nk/n
                estimators_aipw["Meta-SW"].append(
                    np.average(
                        [
                            local_aipw_estimates["Locals"]["client" + str(k)][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

            if "Locals-MW_logistic" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["Locals-MW"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_" + str(k),
                            y0_col="hat_mu0_" + str(k),
                            e_col="hat_e_MW_logistic",
                        )
                    )
                # Agregate with nk/n
                estimators_aipw["Locals-MW_logistic"].append(
                    np.average(
                        [
                            local_aipw_estimates["Locals-MW_logistic"][
                                "client" + str(k)
                            ][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )
            if "Locals-MW_NN" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["Locals-MW"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_" + str(k),
                            y0_col="hat_mu0_" + str(k),
                            e_col="hat_e_MW_NN",
                        )
                    )
                # Agregate with nk/n
                estimators_aipw["Locals-MW_NN"].append(
                    np.average(
                        [
                            local_aipw_estimates["Locals-MW_NN"]["client" + str(k)][
                                -1
                            ]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

            if "Locals-DWparam" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["Locals-DWparam"][
                        "client" + str(k)
                    ].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_" + str(k),
                            y0_col="hat_mu0_" + str(k),
                            e_col="hat_e_dwparam",
                        )
                    )
                # Agregate with nk/n
                estimators_aipw["Locals-DWparam"].append(
                    np.average(
                        [
                            local_aipw_estimates["Locals-DWparam"][
                                "client" + str(k)
                            ][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

        if any(
            [
                "GD-MW_logistic" in final_estimators,
                "GD-MW_NN" in final_estimators,
                "GD-DWparam" in final_estimators,
            ]
        ):
            df_aipw["hat_mu1_GD"] = np.dot(
                df_aipw[self.outcome_cols].values,
                dict_beta_1["GD"],
            )
            df_aipw["hat_mu0_GD"] = np.dot(
                df_aipw[self.outcome_cols].values,
                dict_beta_0["GD"],
            )
            if "GD-MW_logistic" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["GD-MW_logistic"][
                        "client" + str(k)
                    ].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_GD",
                            y0_col="hat_mu0_GD",
                            e_col="hat_e_MW_logistic",
                        )
                    )
                estimators_aipw["GD-MW_logistic"].append(
                    np.average(
                        [
                            local_aipw_estimates["GD-MW_logistic"][
                                "client" + str(k)
                            ][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )
            if "GD-MW_NN" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["GD-MW_NN"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_GD",
                            y0_col="hat_mu0_GD",
                            e_col="hat_e_MW_NN",
                        )
                    )
                estimators_aipw["GD-MW_NN"].append(
                    np.average(
                        [
                            local_aipw_estimates["GD-MW_NN"]["client" + str(k)][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

            if "GD-DWparam" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["GD-DWparam"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_GD",
                            y0_col="hat_mu0_GD",
                            e_col="hat_e_dwparam",
                        )
                    )
                estimators_aipw["GD-DWparam"].append(
                    np.average(
                        [
                            local_aipw_estimates["GD-DWparam"]["client" + str(k)][
                                -1
                            ]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

        if any(
            [
                "1SIVW-MW_logistic" in final_estimators,
                "1SIVW-MW_NN" in final_estimators,
                "1SIVW-DWparam" in final_estimators,
            ]
        ):
            df_aipw["hat_mu1_1SIVW"] = np.dot(
                df_aipw[self.outcome_cols].values,
                dict_beta_1["1SIVW"],
            )
            df_aipw["hat_mu0_1SIVW"] = np.dot(
                df_aipw[self.outcome_cols].values,
                dict_beta_0["1SIVW"],
            )
            if "1SIVW-MW_logistic" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["1SIVW-MW_logistic"][
                        "client" + str(k)
                    ].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_1SIVW",
                            y0_col="hat_mu0_1SIVW",
                            e_col="hat_e_MW_logistic",
                        )
                    )
                estimators_aipw["1SIVW-MW_logistic"].append(
                    np.average(
                        [
                            local_aipw_estimates["1SIVW-MW_logistic"][
                                "client" + str(k)
                            ][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )
            if "1SIVW-MW_NN" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["1SIVW-MW_NN"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_1SIVW",
                            y0_col="hat_mu0_1SIVW",
                            e_col="hat_e_MW_NN",
                        )
                    )
                estimators_aipw["1SIVW-MW_NN"].append(
                    np.average(
                        [
                            local_aipw_estimates["1SIVW-MW_NN"]["client" + str(k)][
                                -1
                            ]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

            if "1SIVW-DWparam" in final_estimators:
                for k in range(1, len(self.clients_list) + 1):
                    local_aipw_estimates["1SIVW-DWparam"]["client" + str(k)].append(
                        aipw_function(
                            df_aipw[df_aipw["client"] == "client" + str(k)],
                            y1_col="hat_mu1_1SIVW",
                            y0_col="hat_mu0_1SIVW",
                            e_col="hat_e_dwparam",
                        )
                    )
                estimators_aipw["1SIVW-DWparam"].append(
                    np.average(
                        [
                            local_aipw_estimates["1SIVW-DWparam"][
                                "client" + str(k)
                            ][-1]
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                        weights=[
                            len(df_aipw[df_aipw["client"] == "client" + str(k)])
                            / len(df_aipw)
                            for k in range(1, len(self.clients_list) + 1)
                        ],
                    )
                )

        if "pool" in final_estimators:
            for k in range(1, len(self.clients_list) + 1):
                local_aipw_estimates["pool"]["client" + str(k)].append(
                    aipw_function(
                        df_aipw[df_aipw["client"] == "client" + str(k)],
                        y1_col="hat_mu1_GD",
                        y0_col="hat_mu0_GD",
                        e_col="hat_e_pool",
                    )
                )
            estimators_aipw["pool"].append(
                np.average(
                    [
                        local_aipw_estimates["pool"]["client" + str(k)][-1]
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                    weights=[
                        len(df_aipw[df_aipw["client"] == "client" + str(k)])
                        / len(df_aipw)
                        for k in range(1, len(self.clients_list) + 1)
                    ],
                )
            )

    return estimators_aipw, df_aipw