# %%
import pandas as pd
import wandb
import numpy as np

# Connect to the W&B public API and list the runs of one project.
api = wandb.Api()
entity, project = "???", "weighted-dp"  # set to your entity and project
runs = api.runs("/".join((entity, project)))

# Flatten every run into one record: summary metrics, hyperparameters and the
# run name, all side by side in a single dict.
summary_list = []
for run in runs:
    # Summary metrics. Per the original note, ._json_dict is used to omit
    # large files. Keys starting with "_" are skipped, except "_timestamp".
    summary_dict = {}
    for key, value in run.summary._json_dict.items():
        if key == "_timestamp" or not key.startswith("_"):
            summary_dict[key] = value

    # Hyperparameters, without the special entries that start with "_".
    config_dict = {}
    for key, value in run.config.items():
        if key.startswith("_"):
            continue
        config_dict[key] = value

    # Human-readable name of the run.
    name_dict = {"name": run.name}

    # Keyword-style merge: a key that occurs in more than one of the three
    # dicts raises TypeError instead of being silently overwritten.
    summary_list.append(dict(**summary_dict, **config_dict, **name_dict))

# One row per run, in the order the API returned them.
# NOTE(review): the original comment states that the first observations are
# the most recent ones -- confirm against the API's default ordering.
runs_df = pd.DataFrame(summary_list)

# %%
# Hyperparameter columns that together identify one experimental setting.
configs = [
    "epsilon",
    "epsilon_iw_perc",
    "delta_iw_perc",
    "model_class",
    "fn_csv",
    "seed",
    "reg",
]

# Summary keys of the form "<iw> <metric>".  Every key is listed exactly once:
# duplicated entries would select the same DataFrame column twice and inflate
# len(metrics) wherever the list is used for column selection or counting.
iw_metrics = [
    # Classification AUC per (iw, downstream model) pair.
    "logreg knn classification AUC",
    "none logreg classification AUC",
    "generator knn classification AUC",
    "generator logreg classification AUC",
    "none svm classification AUC",
    "mlp rf classification AUC",
    "beta nn classification AUC",
    "none rf classification AUC",
    "none knn classification AUC",
    "mlp logreg classification AUC",
    "mlp nn classification AUC",
    "beta logreg classification AUC",
    "logreg beta_mse classification AUC",
    "beta_unbiased knn classification AUC",
    "beta_unbiased logreg classification AUC",
    "mlp svm classification AUC",
    "beta knn classification AUC",
    "mlp beta_mse classification AUC",
    "beta_unbiased nn classification AUC",
    "beta rf classification AUC",
    "beta_unbiased rf classification AUC",
    "beta svm classification AUC",
    "generator nn classification AUC",
    "generator svm classification AUC",
    "logreg logreg classification AUC",
    "beta_unbiased beta_mse classification AUC",
    "mlp knn classification AUC",
    "logreg nn classification AUC",
    "logreg rf classification AUC",
    "logreg svm classification AUC",
    "generator rf classification AUC",
    "beta beta_mse classification AUC",
    "beta_unbiased svm classification AUC",
    "none nn classification AUC",
    "none beta_mse classification AUC",
    "generator beta_mse classification AUC",
    # Timing per iw.
    "beta time",
    "none time",
    "beta_unbiased time",
    "mlp time",
    "logreg time",
    "generator time",
    # AUC keys for the "tf_mlp" prefix.
    "tf_mlp logreg classification AUC",
    "tf_mlp nn classification AUC",
    # "test wst" / "test mmd" keys per iw.
    "generator test wst",
    "none test wst",
    "logreg test wst",
    "generator test mmd",
    "none test mmd",
    "mlp test mmd",
    "beta test mmd",
    "logreg test mmd",
    "mlp test wst",
    "beta test wst",
]

# Full list of metric columns; the extra (currently disabled) keys are kept
# here for reference.
metrics = iw_metrics + [
    # "beta_mse",
    # "debiased_beta_mse",
]


# %%
# specific preprocessing
# Keep only the rows that come before the hard-coded cutoff run.  The cutoff
# is looked up as an index label and then used as a position, which agrees as
# long as runs_df still carries the default RangeIndex it was built with.
# NOTE(review): raises IndexError when no run has this name.
cutoff_position = runs_df[runs_df["name"] == "smitten-infatuation-874"].index[0]
this_runs_df = runs_df.iloc[:cutoff_position]

# Drop runs that logged none of the metrics.  Only metric columns that exist
# in the frame are inspected, so a metric that no run ever reported does not
# raise KeyError.
logged_metrics = [m for m in metrics if m in this_runs_df.columns]
this_runs_df = this_runs_df[this_runs_df[logged_metrics].notnull().any(axis=1)]

# Order rows from most to least complete (fewest missing values first).  The
# sort is stable, so rows with the same number of missing values keep their
# current relative order and the de-duplication below is deterministic.
missing_per_row = this_runs_df.isnull().sum(axis=1)
this_runs_df = this_runs_df.loc[
    missing_per_row.sort_values(ascending=True, kind="mergesort").index
]

# One row per configuration: the most complete run wins.
this_runs_df = this_runs_df.drop_duplicates(
    subset=configs, keep="first", ignore_index=True
)

# Remove columns that are empty for every remaining run.
this_runs_df = this_runs_df.dropna(axis=1, how="all")

# Remove the logged "iw" column; the reshape later in this file creates its
# own "iw" column.  The column may be absent or may already have been removed
# by the dropna above, hence errors="ignore".
this_runs_df = this_runs_df.drop(columns=["iw"], errors="ignore")
# %%


def rename_cols_for_pd_wide_to_long(col_names):
    """Rewrite metric column names into the "<stub>;<suffix>" form.

    Each name is processed on its own, in this order:

    1. " classification AUC" becomes "_classification_AUC".
    2. " test " becomes " test_".
    3. For each of the markers "test", "_classification_AUC" and "time" that
       occurs in the name, the name is split on single spaces and the pieces
       are joined back in reverse order with ";".  A name without spaces is
       left as it is by this step.

    Examples: "logreg knn classification AUC" -> "knn_classification_AUC;logreg",
    "generator test wst" -> "test_wst;generator", "beta time" -> "time;beta".

    Args:
        col_names: iterable of column names (strings).

    Returns:
        A new list with the rewritten names, in the input order.
    """
    reversal_markers = ("test", "_classification_AUC", "time")

    renamed = []
    for column in col_names:
        if " classification AUC" in column:
            column = column.replace(" classification AUC", "_classification_AUC")
        if " test " in column:
            column = column.replace(" test ", " test_")
        for marker in reversal_markers:
            if marker in column:
                # Once joined with ";" the name holds no spaces, so a second
                # matching marker leaves it unchanged.
                column = ";".join(reversed(column.split(" ")))
        renamed.append(column)
    return renamed


def print_df_duplicates(df, columns):
    """Return the rows of ``df`` that share their ``columns`` values.

    Despite the name, nothing is printed: the matching rows are returned.
    Every member of a duplicated group is included, not only the repeats.

    Args:
        df: DataFrame to inspect.
        columns: column label(s) used to decide whether rows are duplicates.

    Returns:
        The subset of ``df`` whose ``columns`` values occur more than once.
    """
    is_repeated = df.duplicated(subset=columns, keep=False)
    return df.loc[is_repeated]


def slice_df(df, value_dict):
    """Keep the rows of ``df`` that satisfy every column filter.

    Args:
        df: DataFrame to filter.
        value_dict: mapping from column label to the collection of values
            allowed in that column.

    Returns:
        The rows for which each listed column holds one of its allowed
        values.  With an empty ``value_dict`` the input object itself is
        returned.
    """
    selected = df
    for column, allowed_values in value_dict.items():
        keep = selected[column].isin(allowed_values)
        selected = selected.loc[keep]
    return selected


# %%
# Rewrite the metric columns into the "<stub>;<suffix>" form that
# pd.wide_to_long splits on.
column_names = this_runs_df.columns.to_list()
new_columns = rename_cols_for_pd_wide_to_long(column_names)

# Stub names are the part before ";" of every renamed metric column.
# - Columns without ";" are skipped: a stub equal to a full column name makes
#   pd.wide_to_long raise.
# - The result is sorted so that the stub order (and the order of the report
#   printed from it) does not depend on string hashing.
metrics_set = sorted(
    {
        m.split(";")[0]
        for m in new_columns
        if ";" in m
        and (("_classification_AUC" in m) or ("test_" in m) or ("time;" in m))
    }
)

# Reshape to long format: one row per (run name, iw) pair and one column per
# stub.  The suffix pattern is a raw string so that "\w" reaches the regex
# engine without triggering an invalid-escape warning.
# NOTE(review): wide_to_long requires "name" to identify each row uniquely.
this_runs_df.columns = new_columns
long_runs_df = pd.wide_to_long(
    this_runs_df, stubnames=metrics_set, i="name", j="iw", sep=";", suffix=r"\w+"
).reset_index()

# %%

# For every combination of dataset file and privacy parameters, print one
# table per metric: rows are "iw" values, columns are model classes, cells
# are the metric averaged over the runs in that combination.
for fn_csv in long_runs_df.fn_csv.unique():
    for epsilon in long_runs_df.epsilon.unique():
        for epsilon_iw_perc in long_runs_df.epsilon_iw_perc.unique():
            for delta_iw_perc in long_runs_df.delta_iw_perc.unique():
                this_df = slice_df(
                    long_runs_df,
                    {
                        "epsilon": [epsilon],
                        "epsilon_iw_perc": [epsilon_iw_perc],
                        "delta_iw_perc": [delta_iw_perc],
                        "fn_csv": [fn_csv],
                    },
                ).reset_index()

                # The loops walk the full cartesian product of observed
                # values; most combinations have no runs and print nothing.
                if this_df.empty:
                    continue

                # Average once per combination (it does not depend on the
                # metric).  numeric_only=True skips string columns such as
                # the run name, which would otherwise make mean() raise on
                # pandas >= 2.0.
                ave_df = (
                    this_df.groupby(["iw", "model_class"])
                    .mean(numeric_only=True)
                    .reset_index()
                )

                for metric in metrics_set:
                    # A metric stored with a non-numeric dtype is not part of
                    # the averaged frame and cannot be tabulated.
                    if metric not in ave_df.columns:
                        continue
                    values = pd.pivot_table(
                        ave_df,
                        values=metric,
                        index=["iw"],
                        columns=["model_class"],
                        aggfunc="sum",
                    )
                    if len(values):
                        print(
                            "\n",
                            fn_csv,
                            epsilon,
                            epsilon_iw_perc,
                            delta_iw_perc,
                            metric,
                        )
                        print(values)

# %%
