import argparse
from email.policy import default
import utils
import pandas as pd
import os
import pickle as pkl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import wandb
import hydra
import os
import omegaconf
import time

# Set the WANDB_SILENT environment variable for this process before the
# project modules below are imported and before main() calls wandb.init().
# NOTE(review): presumably this suppresses wandb's console output - confirm
# against the wandb environment-variable documentation.
os.environ["WANDB_SILENT"] = "true"

from models.uniform import sample_uniform
from models.dp_cgan import DP_CGAN
from models.dpgan import fit_DPGAN
from models.cgan import CGAN
from models.gan import GAN

# from models.mwem import fit_MWEM
# from models.pategan import fit_PATEGAN
from models.privbayes import PrivBayes_runner
import dp_weights
import evaluation
from pprint import pprint
import plots
from evaluation import (
    compute_mmd,
    wasserstein_distance,
)

# Registry of synthetic-data generators, keyed by the upper-case name accepted
# through ``--model-class`` / ``args.model_class``. ``main`` invokes the chosen
# entry as ``model_map[args.model_class](**model_params)``.
model_map = dict(
    UNIFORM=sample_uniform,
    DPCGAN=DP_CGAN,
    DPGAN=fit_DPGAN,
    GAN=GAN,
    PRIVBAYES=PrivBayes_runner,
    # PATEGAN=fit_PATEGAN,
    CGAN=CGAN,
    # MWEM=fit_MWEM,
)

# Names of the datasets that ``main`` generates through ``load_data`` (and
# caches as pickles) instead of reading from a CSV file. Despite the name this
# is a set, not a mapping; nothing in the visible part of this file reads it.
data_map = set(
    (
        "gmm_grid",
        "gmm",
        "triangle",
    )
)


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True`` for
    every non-empty string (including ``"False"``). This converter accepts the
    usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string keeps its previous meaning (bool("") is False).
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the argparse-style command line used as a fallback to hydra.

    The ``__main__`` guard of this file calls this function when the hydra
    entry point exits immediately, and passes the result to ``main``.

    Returns:
        argparse.Namespace: one attribute per option below (dashes in option
        names become underscores, e.g. ``--fn-csv`` -> ``fn_csv``).
    """
    parser = argparse.ArgumentParser()

    # region setup args
    parser.add_argument("--seed", type=int, default=0)
    # Booleans go through _str2bool so that "--reload False" really is False.
    parser.add_argument("--reload", type=_str2bool, default=False)
    parser.add_argument("--log-file", type=str, default="test")
    parser.add_argument("--plot", type=_str2bool, default=False)
    parser.add_argument("--wandb-entity", type=str, default="???")
    # endregion

    # region private data args
    parser.add_argument("--fn-csv", default="breast", help="input file")
    parser.add_argument(
        "--fn-target-csv", default=None, help="target input file or column name"
    )
    parser.add_argument("--train-ratio", type=float, default=0.8)
    # endregion

    # region privatisation args
    parser.add_argument("--epsilon", type=float, default=1.0)
    parser.add_argument("--epsilon-iw-perc", type=float, default=0.5)
    parser.add_argument("--delta-iw-perc", type=float, default=0.0)
    model_names = list(model_map)
    parser.add_argument(
        "--model-class",
        nargs="?",
        default="DPCGAN",
        choices=model_names,
        help="model class, choices: " + " ".join(model_names),
    )
    # endregion

    # region privbayes
    # Integer defaults get an explicit type so values given on the command line
    # are not silently left as strings.
    parser.add_argument("--category_threshold", type=int, default=10)
    parser.add_argument("--num_parents", type=int, default=3)
    # NOTE(review): the three options below are still passed through as raw
    # strings when given on the command line; only their defaults have the
    # intended type. Confirm the expected format before adding converters.
    parser.add_argument("--histogram_bins", default="auto")
    parser.add_argument("--other_categorical", default=[])
    parser.add_argument("--keys", default=[])
    # endregion

    # region synthetic data args
    parser.add_argument(
        "--rel-size-synth",
        type=float,
        default=1.0,
        help="choose this relatively large for MH chains",
    )
    # endregion

    # region gan args
    parser.add_argument(
        "--gan-hidden-dim",
        type=int,
        default=None,
        help="hidden dimension of differentially private GAN",
    )
    # endregion

    # region iw args
    parser.add_argument("--iw", type=str, default="none")
    parser.add_argument("--main-dir", type=str, default=".")
    parser.add_argument(
        "--reg",
        type=float,
        default=1,
        help="regularization parameter for logistic regression",
    )
    parser.add_argument(
        "--alpha", type=float, default=1.0, help="regularization parameter for weights"
    )
    # endregion

    return parser.parse_args()


@hydra.main(config_path="conf", config_name="config")
def main(args):
    """Run one weighted differentially-private synthetic-data experiment.

    Steps, in order:

    1. load (or generate and cache under ``{main_dir}/data``) the private data;
    2. split it into train/test and min-max scale the features to [-1, 1];
    3. fit/sample the generator selected by ``args.model_class`` (the samples
       are cached under ``{main_dir}/dp_data``);
    4. compute one vector of log importance weights per entry of ``args.iw``;
    5. evaluate downstream models plus Wasserstein/MMD distances on the
       weighted synthetic data and log everything to wandb.

    Args:
        args: configuration object exposing the run settings as attributes
            (``iw``, ``epsilon``, ``epsilon_iw_perc``, ``delta_iw_perc``,
            ``model_class``, ``fn_csv``, ``fn_target_csv``, ``seed``,
            ``reload``, ``plot``, ``main_dir``, ``train_ratio``,
            ``rel_size_synth``, ``reg``, ``alpha``, ``wandb_entity``). It is
            the hydra config when started through ``@hydra.main`` or the
            namespace returned by ``parse_args`` in the argparse fallback.
            ``args.reload`` and ``args.delta_iw_perc`` may be overwritten.

    Raises:
        NotImplementedError: if ``args.iw`` names an unknown weighting method.
    """
    # Work from the project directory. The original hard-coded checkout path is
    # kept as first choice; on machines where it does not exist, fall back to
    # the directory the job was launched from instead of crashing.
    project_dir = "/home/ma/s/sg20/projects/weighted_dp"
    if os.path.isdir(project_dir):
        os.chdir(project_dir)
    else:
        try:
            os.chdir(hydra.utils.get_original_cwd())
        except ValueError:
            # Raised when hydra is not initialised (argparse fallback path):
            # stay in the current working directory.
            pass
    CURR_DIR = os.getcwd()
    print(CURR_DIR)

    iw_list = args.iw.split(",")
    epsilon_iw = args.epsilon_iw_perc * args.epsilon

    # Drop the "generator" weighting for MWEM/PRIVBAYES, and for UNIFORM on
    # every dataset other than "gmm" and "triangle".
    if args.model_class in ["MWEM", "PRIVBAYES"] or (
        args.model_class == "UNIFORM"
        and ((args.fn_csv != "gmm") and (args.fn_csv != "triangle"))
    ):
        iw_list = [iw for iw in iw_list if iw != "generator"]
    # For these two models the whole delta budget goes to the weights.
    if (args.model_class == "MWEM") or (args.model_class == "PRIVBAYES"):
        args.delta_iw_perc = 1.0

    np.random.seed(args.seed)

    # region set up wandb
    try:
        wandb.init(
            project="weighted-dp",
            entity=args.wandb_entity,
            reinit=True,
            settings=wandb.Settings(start_method="thread"),
            config=omegaconf.OmegaConf.to_container(
                args, resolve=True, throw_on_missing=True
            ),
        )
    except ValueError:
        # Fallback when the config cannot be converted by OmegaConf
        # (presumably the argparse namespace path - confirm).
        wandb.init(
            project="weighted-dp",
            entity=args.wandb_entity,
            reinit=True,
            settings=wandb.Settings(start_method="thread"),
        )
        wandb.config.update(args)
    # endregion

    # region load data
    print("Load private data...")
    if not os.path.exists(f"{args.main_dir}/data"):
        os.makedirs(f"{args.main_dir}/data")

    if args.fn_csv in ["gmm_grid", "gmm"]:
        data_file_name = utils.turn_dict_to_file_name(args.__dict__, ["fn_csv", "seed"])
        data_file_name = f"{args.main_dir}/data/{data_file_name}"
        if os.path.exists(data_file_name) and not args.reload:
            with open(data_file_name, "rb") as handle:
                features, mu_gmm, sigma_gmm, target, beta = pkl.load(handle)
        else:
            from load_data import generate_grid_gmm_samples

            features, mu_gmm, sigma_gmm, target, beta = generate_grid_gmm_samples(
                args.seed,
                # 1 mixture per axis for "gmm", 5 for "gmm_grid".
                sqrt_mixtures=1 + (args.fn_csv == "gmm_grid") * 4,
                target=args.fn_target_csv,
            )
            with open(data_file_name, "wb") as handle:
                pkl.dump(
                    (features, mu_gmm, sigma_gmm, target, beta),
                    handle,
                    protocol=pkl.HIGHEST_PROTOCOL,
                )
    elif args.fn_csv in ["triangle"]:
        data_file_name = utils.turn_dict_to_file_name(args.__dict__, ["fn_csv", "seed"])
        data_file_name = f"{args.main_dir}/data/{data_file_name}.pkl"
        if os.path.exists(data_file_name) and not args.reload:
            with open(data_file_name, "rb") as handle:
                features, target, beta = pkl.load(handle)
        else:
            from load_data import generate_triangle

            features, target, beta = generate_triangle(
                args.seed,
                target=args.fn_target_csv,
            )
            with open(data_file_name, "wb") as handle:
                pkl.dump(
                    (features, target, beta),
                    handle,
                    protocol=pkl.HIGHEST_PROTOCOL,
                )
        mu_gmm, sigma_gmm = None, None

    elif args.fn_csv in ["mnist"]:
        import tensorflow as tf

        (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
        train_images = (train_images.astype("float32") - 127.5).reshape(
            (-1, 784)
        ) / 127.5  # Normalize the images to [-1, 1]
        train_labels = train_labels.reshape((-1, 1))
        features = pd.DataFrame(train_images)
        target = pd.DataFrame(train_labels)
        mu_gmm, sigma_gmm, beta = None, None, None
    else:
        data_file_name = f"{args.main_dir}/data/{args.fn_csv}.csv"
        features = pd.read_csv(data_file_name, header=None)
        if args.fn_target_csv is not None and os.path.exists(args.fn_target_csv):
            target = pd.read_csv(args.fn_target_csv, header=None)
        else:
            # No separate target file: the last CSV column is the target.
            target = features.iloc[:, -1]
            features = features.iloc[:, :-1]
        mu_gmm, sigma_gmm, beta = None, None, None
    # endregion

    # region process data
    if target is not None:
        df = pd.concat([features, target], axis=1)
        df.columns = list(df.columns[:-1]) + ["target"]
        # At most 10 distinct target values counts as classification.
        target_type = (
            "classification" if np.unique(target).size <= 10 else "regression"
        )
    else:
        df = features
        # Bind the name so the checks below do not raise NameError.
        target_type = None

    # split data
    df_train_old, df_test_old = train_test_split(
        df, train_size=args.train_ratio, shuffle=True, random_state=args.seed
    )
    delta_all = 1 / len(df_train_old)
    delta_iw = delta_all * args.delta_iw_perc
    delta_data = delta_all - delta_iw

    # normalize data
    scaler = MinMaxScaler(
        feature_range=(-1, 1),
    )
    if target_type == "classification":
        # Scale the features only; class labels are appended unscaled.
        df_train_features = pd.DataFrame(
            scaler.fit_transform(df_train_old.iloc[:, :-1].copy()),
            columns=df_train_old.columns[:-1],
        )
        df_test_features = pd.DataFrame(
            scaler.transform(df_test_old.iloc[:, :-1]),
            columns=df_test_old.iloc[:, :-1].columns,
        )
        df_train = pd.concat(
            [df_train_features, df_train_old.iloc[:, -1:].reset_index(drop=True)],
            axis=1,
        )
        df_test = pd.concat(
            [df_test_features, df_test_old.iloc[:, -1:].reset_index(drop=True)], axis=1
        )
    else:
        df_train = pd.DataFrame(
            scaler.fit_transform(df_train_old),
            columns=df_train_old.columns,
        )
        # Only transform here: refitting on the test split would scale it
        # differently from the training data and would leave `scaler` (passed
        # to the model below) fitted on the test set.
        df_test = pd.DataFrame(
            scaler.transform(df_test_old),
            columns=df_test_old.columns,
        )
    # endregion

    # region model parameters
    n_d = len(df_train)
    n_s = min(int(n_d * args.rel_size_synth), 5000)
    data_name = args.fn_csv.split("/")[-1].split(".")[0].split("_")[0]

    model_params = utils.load_params(
        args=args,
        wandb=wandb,
        n_s=n_s,
        df_train=df_train,
        scaler=scaler,
        mu_gmm=mu_gmm,
        sigma_gmm=sigma_gmm,
        beta=beta,
        delta_data=delta_data,
    )
    # Parameters that must not become part of the model identifier.
    excluded_from_id = {
        "n_s",
        "x_train",
        "df",
        "df_test",
        "df_",
        "scaler",
        "mu_gmm",
        "sigma_gmm",
        "gmm",
        "test",
        "out_dir",
        "category_threshold",
        "histogram_bins",
        "other_categorical",
        "keys",
        "in_path",
        "wandb",
        "config",
        "wandb_entity",
        "reload",
    }
    model_name_params = {
        k: v
        for k, v in {
            **model_params,
            "data": data_name,
            "model": args.model_class,
        }.items()
        if k not in excluded_from_id
    }
    # Identifier built from the first 3 characters of each key and the first 5
    # characters of each value.
    model_id = "_".join([f"{k[:3]}_{str(v)[:5]}" for k, v in model_name_params.items()])
    model_params.update({"model_id": model_id})

    if not os.path.exists(f"{args.main_dir}/dp_data"):
        os.makedirs(f"{args.main_dir}/dp_data")
    sample_path = f"{args.main_dir}/dp_data/" + data_name + "_" + model_id + ".pkl"

    print(model_name_params)
    # endregion

    # region simulate dp data
    print("Privatise data...")
    if (
        os.path.exists(sample_path)
        and os.path.getsize(sample_path) > 0
        and not args.reload
    ):
        with open(sample_path, "rb") as file:
            try:
                save_var = pkl.load(file)
            except Exception:
                # Name the offending cache file before propagating the error.
                print(sample_path)
                raise
            synth_data, log_model_iw = (
                save_var["x"],
                save_var["w"],
            )
            # A cache with too few samples is regenerated.
            if len(synth_data) < n_s:
                args.reload = True
    else:
        args.reload = True

    if args.reload:
        synth_data, log_model_iw = model_map[args.model_class](**model_params)
        synth_data = pd.DataFrame(synth_data)
        if target is not None and target_type == "classification":
            synth_data.iloc[:, -1] = synth_data.iloc[:, -1].astype("int")
        print(target_type)
        print("########## TRUE DATA ##########")
        print(df_train.describe())
        print("########## SYNTH DATA ##########")
        print(synth_data.describe())
        # Keep every column except the last inside the scaler's [-1, 1] range.
        synth_data.iloc[:, :-1] = np.clip(synth_data.iloc[:, :-1].values, -1, 1)

        with open(sample_path, "wb") as handle:
            pkl.dump(
                {"x": synth_data, "w": log_model_iw},
                handle,
                protocol=pkl.HIGHEST_PROTOCOL,
            )
    # endregion

    # region importance weighting
    print("Computing dp weights...")
    times_dict = {}

    if len(synth_data.iloc[:, -1].unique()) > 2:
        # More than two distinct values in the last synthetic column: one-hot
        # encode the last column of all three frames.
        dummies = pd.get_dummies(df_train.iloc[:, -1], prefix="t_")
        df_train_dummy = pd.concat([df_train.iloc[:, :-1], dummies], axis=1)

        dummies = pd.get_dummies(df_test.iloc[:, -1], prefix="t_")
        df_test_dummy = pd.concat([df_test.iloc[:, :-1], dummies], axis=1)

        dummies = pd.get_dummies(synth_data.iloc[:, -1], prefix="t_")
        synth_data_dummy = pd.concat([synth_data.iloc[:, :-1], dummies], axis=1)
    else:
        df_train_dummy = df_train
        df_test_dummy = df_test
        synth_data_dummy = synth_data

    log_weights_dict = {}
    for iw in iw_list:
        iw_model = f"{model_id}_{n_s}_{iw}"
        weights_path = f"{args.main_dir}/weights/{iw_model}.pkl"

        if not os.path.exists(f"{args.main_dir}/weights"):
            os.makedirs(f"{args.main_dir}/weights")

        # NOTE: nothing in this function writes weights_path, so this cache is
        # only hit by files that were produced elsewhere.
        if (
            os.path.exists(weights_path)
            and os.path.getsize(weights_path) > 0
            and not args.reload
        ):
            with open(weights_path, "rb") as file:
                log_weights_dict[iw] = pkl.load(file)
            continue

        start = time.time()
        if iw == "generator":
            # Weights returned by the generator itself (first column if 2-D).
            log_weights_dict["generator"] = (
                log_model_iw[:, 0] if len(log_model_iw.shape) == 2 else log_model_iw
            )

        elif iw == "none":
            # Log-weight 0 for every sample, i.e. unweighted.
            log_weights_dict["none"] = np.zeros(len(synth_data))

        elif iw == "logreg":
            log_weights_dict["logreg"], _ = dp_weights.dp_class_logreg(
                df_train_dummy,
                df_test_dummy,
                synth_data_dummy,
                epsilon_iw,
                args.reg,
                n_d,
            )

        elif iw == "beta":
            _, log_weights_dict["beta"] = dp_weights.dp_class_logreg(
                df_train_dummy,
                df_test_dummy,
                synth_data_dummy,
                epsilon_iw,
                args.reg,
                n_d,
                unbiased=False,
            )

        elif iw == "beta_unbiased":
            _, log_weights_dict["beta_unbiased"] = dp_weights.dp_class_logreg(
                df_train_dummy,
                df_test_dummy,
                synth_data_dummy,
                epsilon_iw,
                args.reg,
                n_d,
                unbiased=True,
                args=args,
                wandb=wandb,
            )

        elif iw == "tf_mlp":
            # use weights of noisy classification network
            log_weights_dict["tf_mlp"] = dp_weights.dp_class_adam_mlp(
                df_train_dummy,
                df_test_dummy,
                synth_data_dummy,
                epsilon_iw,
                delta_iw,
                model_id,
                args.reload,
                seed=args.seed,
                private=True,
            )

        elif iw == "mlp":
            log_weights_dict["mlp"], _ = dp_weights.class_mlp(
                df_train_dummy,
                df_test_dummy,
                synth_data_dummy,
                10000,
                10000,
                model_id,
                args.reload,
                private=False,
                seed=args.seed,
            )

        else:
            # Fail loudly instead of hitting an unbound or stale `start` below.
            raise NotImplementedError(f"iw input {iw} is not implemented")

        elapsed = time.time() - start
        wandb.log({f"{iw} time": elapsed})
        times_dict[iw] = elapsed

    if args.plot:
        plots.kde2d_plot_all_weights(
            synth_data.iloc[:, :2],
            # `target` may be a pandas object, whose truth value is ambiguous,
            # so test explicitly against None.
            synth_data.iloc[:, synth_data.shape[1] - 1] if target is not None else None,
            log_weights_dict=log_weights_dict,
            name=f"{wandb.run.id}_synth_all_weighted",
        )
    # endregion

    # region evaluation metrics
    print("Evaluate...")
    eval_metrics = {}

    if args.fn_csv == "mnist":
        methods = ["logreg", "nn"]
    elif target_type == "classification":
        methods = ["logreg", "rf", "nn", "svm", "beta_mse", "knn"]
    else:
        methods = ["linear", "nn", "beta_mse"]

    for iw, log_weights in log_weights_dict.items():
        # Temper the weights with alpha, then normalise them to mean 1.
        scaled_weights = np.exp(log_weights) ** args.alpha
        scaled_weights = scaled_weights / sum(scaled_weights) * len(scaled_weights)

        if target is not None:
            if target_type == "classification":
                downstream = evaluation.downstream_classification
            else:
                downstream = evaluation.downstream_regression
            # The metric key says "classification AUC" for regression targets
            # as well; it is kept unchanged so existing logs stay comparable.
            for method in methods:
                print(f"{iw} {method}")
                metric_key = f"{iw} {method} classification AUC"
                eval_metrics[metric_key] = downstream(
                    synth_data.copy(),
                    df_test.copy(),
                    method=method,
                    iw_weight=scaled_weights,
                )
                wandb.log({f"{method} classification AUC": eval_metrics[metric_key]})

        eval_metrics[f"{iw} test wst"] = wasserstein_distance(
            df_test, synth_data, scaled_weights
        )
        eval_metrics[f"{iw} test mmd"] = compute_mmd(
            df_test, synth_data, scaled_weights
        )

    # logging
    wandb.log(eval_metrics)
    pprint(eval_metrics)
    wandb.finish()
    # endregion


if __name__ == "__main__":
    # Try the hydra entry point first. If it exits within half a second, retry
    # with the argparse command line from parse_args().
    # NOTE(review): the quick exit presumably means hydra rejected
    # argparse-style flags - the time threshold is a heuristic; confirm.
    start = time.monotonic()
    try:
        main()
    except SystemExit:
        if time.monotonic() - start >= 0.5:
            # A late exit comes from the run itself: propagate it (and its exit
            # code) instead of silently finishing with status 0.
            raise
        main(parse_args())
