import os
import pandas as pd
import wandb

from utils_plots import generate_plot, generate_plot_all_attributes

# The 40 attribute labels, in the order in which they are later turned into
# per-attribute metric column names and handed to the attribute bar plot.
attribute_names = """
    5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes
    Bald Bangs Big_Lips Big_Nose
    Black_Hair Blond_Hair Blurry Brown_Hair
    Bushy_Eyebrows Chubby Double_Chin Eyeglasses
    Goatee Gray_Hair Heavy_Makeup High_Cheekbones
    Male Mouth_Slightly_Open Mustache Narrow_Eyes
    No_Beard Oval_Face Pale_Skin Pointy_Nose
    Receding_Hairline Rosy_Cheeks Sideburns Smiling
    Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat
    Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
""".split()

NUM_VIEWS = 2.0

# Axis limits for the plots, written as (min, max) pairs per quantity.
FID_X_MIN, FID_X_MAX = 25, 275  # FID axis
REC_X_MIN, REC_X_MAX = 5500, 12500  # reconstruction-error axis
COH_Y_MIN, COH_Y_MAX = 0.0, 0.6  # coherence axis
DS_Y_MIN, DS_Y_MAX = 0.0, 0.6  # downstream-performance axis
xlim_min_cond, xlim_max_cond = 5000, 10000  # conditional reconstruction-error axis

# When True, a SOCKS5 proxy is exported and an explicit login is performed
# before the API is used (presumably to reach a self-hosted W&B instance --
# confirm). When False, neither step is applied.
use_local_wandb = True
if use_local_wandb:
    # Exported before the login call below.
    os.environ["HTTP_PROXY"] = "socks5h://localhost:10080"
    # NOTE(review): the placeholder mentions an "entity", but it is passed as
    # `host`; presumably the URL of the W&B server belongs here -- confirm.
    wandb.login(host="PUT YOUR WANDB ENTITY HERE")
# Alternative login that reads the host from the environment:
# wandb.login(host=os.getenv("WANDB_LOCAL_URL"))
api = wandb.Api()
# Project is specified by <entity/project-name>
runs = api.runs("PUT YOUR WANDB PROJECT HERE")


# Gather, per run: its summary metrics, its hyperparameters and its name.
summary_list, config_list, name_list = [], [], []
for run in runs:
    # .summary contains the output keys/values for metrics like accuracy.
    #  ._json_dict is used to omit large files.
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  Special values whose key starts with "_" are removed.
    config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})

    # .name is the human-readable name of the run.
    name_list.append(run.name)

# Config entries that hold a nested dict; each is expanded into
# "<section>.<key>" columns, in this order.
_NESTED_SECTIONS = ("dataset", "model", "log")


def _flatten_section(config, section):
    """Return the entries of ``config[section]`` keyed as ``"<section>.<key>"``.

    A run whose config lacks the section (or stores ``None`` there) yields an
    empty dict, i.e. a row of NaNs in the resulting frame, instead of raising
    ``KeyError`` and aborting the export of all other runs.
    """
    nested = config.get(section) or {}
    return {f"{section}.{k}": v for k, v in nested.items()}


# One row per run. Column groups, left to right: summary metrics, top-level
# config values, flattened nested config sections, run name (column label 0).
# All frames are built from equally long lists, so their default integer
# indices line up row by row.
runs_df = pd.concat(
    [
        pd.DataFrame(summary_list),
        pd.DataFrame(
            [
                {k: v for k, v in cfg.items() if k not in _NESTED_SECTIONS}
                for cfg in config_list
            ]
        ),
        *(
            pd.DataFrame([_flatten_section(cfg, section) for cfg in config_list])
            for section in _NESTED_SECTIONS
        ),
        pd.DataFrame(name_list),
    ],
    axis=1,
)


# Directory that receives every plot and CSV file written by this script.
DIR_OUT = "PUT PLOT SAVE DIR HERE"
# exist_ok=True replaces the separate existence check, so there is no window
# between checking and creating; a non-directory at that path fails here
# rather than at the first file write.
os.makedirs(DIR_OUT, exist_ok=True)

# --- Selection of the runs to analyse -------------------------------------
# Runs are kept only if they logged at least this many epochs.
n_epochs_threshold = 10
latent_dim = 128
dataset = "celeba"
num_views = 2
# Modality keys used to build the metric column names.
modalities = ["img", "text"]
lr = 2e-4
alpha_annealing = True
# NOTE: "split" used to be listed here as well; that first assignment was
# immediately overwritten, so only the effective list is kept.
model_names = ["unimodal", "joint", "mixedprior"]
# Display labels handed to the plotting helpers. For "joint" the label is
# looked up per aggregation key in the nested dict.
model_names_vis = {
    "unimodal": "independent",
    "joint": {
        "avg": "AVG",
        "moe": "MoE",
        "poe": "PoE",
        "mopoe": "MoPoE",
    },
    "mixedprior": "MMVM",
}
# Only runs logged under this W&B group are analysed.
wandb_group = "20240513"
dataset_names = ["celeba"]


# Narrow the full run table down to the configuration under study. The
# equality filters are applied one after another, in the listed order,
# followed by the minimum-epoch filter.
# (A filter on "model.beta" is intentionally not part of this selection.)
equality_filters = {
    "model.lr": lr,
    "model.latent_dim": latent_dim,
    "model.alpha_annealing": alpha_annealing,
    "log.wandb_group": wandb_group,
}
agg_df = runs_df
for column, wanted in equality_filters.items():
    agg_df = agg_df.loc[agg_df[column] == wanted]
agg_df = agg_df.loc[agg_df["epoch"] >= n_epochs_threshold]

# Columns kept for the analysis: identifier columns first, metric columns are
# appended below. The list_str_* lists record which metric columns belong to
# which quantity so they can be averaged per run later on.
col_names = [
    "model.name",
    "model.aggregation",
    "dataset.name",
    "model.beta",
    "model.seed",
]
list_str_coh = []
list_str_ds_agg = []
list_str_ds_uni = []
list_str_fid = []
for m_key in modalities:
    for m_tilde_key in modalities:
        str_coh = f"val/coherence/{m_key}_to_{m_tilde_key}"
        col_names.append(str_coh)
        list_str_coh.append(str_coh)
        # No FID column is requested when the target modality is text.
        if m_tilde_key != "text":
            str_fid = f"val/fid/{m_key}_to_{m_tilde_key}"
            col_names.append(str_fid)
            # Recorded inside the branch: appending outside of it re-added
            # the previous FID column for text targets (duplicates) and would
            # hit an unbound name if the first target modality were "text".
            list_str_fid.append(str_fid)
    str_ds_agg = f"val/downstream/aggregated/{m_key}"
    col_names.append(str_ds_agg)
    list_str_ds_agg.append(str_ds_agg)
    str_ds_uni = f"val/downstream/unimodal/{m_key}"
    col_names.append(str_ds_uni)
    list_str_ds_uni.append(str_ds_uni)
# Reconstruction losses used as x-axes of the plots.
str_cond_rec = "val/condition_generation/avg_rec_loss"
str_rec = "val/loss/avg_rec_loss_epoch"
col_names.append(str_rec)
col_names.append(str_cond_rec)

# Runs of the selected dataset. The explicit .copy() makes this an independent
# frame, so the column assignment below is not a chained assignment on a
# slice of agg_df (which triggers pandas' SettingWithCopyWarning).
dataset_df = agg_df.loc[agg_df["dataset.name"] == dataset].copy()
# Rows without an aggregation value are labelled "avg"; rows with a NaN group
# key would otherwise be dropped by the groupby calls further down.
dataset_df["model.aggregation"] = dataset_df["model.aggregation"].fillna("avg")
print(dataset_df.shape)

# Columns that identify one model configuration; rows sharing all four values
# (e.g. runs that differ only in their seed) are aggregated together.
_GROUP_KEYS = ["model.name", "model.aggregation", "dataset.name", "model.beta"]


def _mean_std_by_config(frame, sort_keys):
    """Aggregate the runs of each configuration.

    Args:
        frame: one row per run, holding the ``_GROUP_KEYS`` columns plus the
            metric columns to aggregate.
        sort_keys: column names by which both result frames are ordered.

    Returns:
        ``(mean_df, std_df)``: per-configuration mean and standard deviation,
        with the group keys restored as regular columns and rows sorted by
        ``sort_keys``.
    """
    grouped = frame.groupby(_GROUP_KEYS)
    mean_df = grouped.mean().reset_index().sort_values(by=sort_keys)
    std_df = grouped.std().reset_index().sort_values(by=sort_keys)
    return mean_df, std_df


def _plot_file(stem):
    """Return the PNG path inside DIR_OUT for ``stem`` and the current dataset."""
    return os.path.join(DIR_OUT, stem + "_" + str(dataset) + ".png")


single_plots_datasets = True
if single_plots_datasets:
    # One row per run, restricted to the identifier and metric columns. Every
    # section derives its own frame from it via .assign(), which returns a new
    # frame instead of writing a column into a slice of dataset_df.
    metrics_df = dataset_df[col_names]
    # Per-run averages over the modality-specific metric columns.
    coherence = metrics_df[list_str_coh].mean(axis=1)
    downstream_agg = metrics_df[list_str_ds_agg].mean(axis=1)
    downstream_uni = metrics_df[list_str_ds_uni].mean(axis=1)
    fid = metrics_df[list_str_fid].mean(axis=1)

    # NOTE(review): x-limits are passed as [max, min] throughout; presumably
    # generate_plot draws the x-axis reversed -- confirm in utils_plots.

    # --- coherence vs. (conditional) reconstruction error -------------------
    coh_mean_sorted, coh_std_sorted = _mean_std_by_config(
        metrics_df.assign(coherence=coherence), ["model.beta"]
    )
    print(coh_mean_sorted.shape)
    print(coh_mean_sorted["model.name"].unique())

    generate_plot(
        coh_mean_sorted,
        coh_std_sorted,
        str_rec,
        "coherence",
        _plot_file("plot_coherence_recloss"),
        "Reconstruction Error",
        "Average Precision",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[COH_Y_MIN, COH_Y_MAX],
        # Only this call also writes the stand-alone legend image.
        fn_legend=os.path.join(DIR_OUT, "legend_coherence_recloss.png"),
    )
    coh_mean_sorted.to_csv(os.path.join(DIR_OUT, "celeba_test_results.csv"))
    generate_plot(
        coh_mean_sorted,
        coh_std_sorted,
        str_cond_rec,
        "coherence",
        _plot_file("plot_coherence_condrecloss"),
        "Conditional Reconstruction Error",
        "Average Precision",
        "model.aggregation",
        model_names_vis,
        xlim=[xlim_max_cond, xlim_min_cond],
        ylim=[COH_Y_MIN, COH_Y_MAX],
    )

    # --- downstream performance of the aggregated representation ------------
    ds_mean_sorted, ds_std_sorted = _mean_std_by_config(
        metrics_df.assign(downstream_lr_agg=downstream_agg), ["model.beta"]
    )
    # NOTE(review): this plot used to be generated twice with identical
    # arguments and file name; the redundant second call was dropped. If a
    # conditional-reconstruction variant was intended, add it explicitly.
    generate_plot(
        ds_mean_sorted,
        ds_std_sorted,
        str_rec,
        "downstream_lr_agg",
        _plot_file("plot_downstream_agg_recloss"),
        "Reconstruction Error",
        "Average Precision",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- downstream performance of the unimodal representations -------------
    ds_uni_mean_sorted, ds_uni_std_sorted = _mean_std_by_config(
        metrics_df.assign(downstream_lr_uni=downstream_uni),
        ["model.name", "model.beta"],
    )
    ds_uni_mean_sorted.to_csv(
        os.path.join(DIR_OUT, "results_downstream_uni_" + str(dataset) + ".csv")
    )
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        str_cond_rec,
        "downstream_lr_uni",
        _plot_file("plot_downstream_uni_condrecloss"),
        "Conditional Reconstruction Error",
        "Average Precision",
        "model.aggregation",
        model_names_vis,
        xlim=[xlim_max_cond, xlim_min_cond],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        str_rec,
        "downstream_lr_uni",
        _plot_file("plot_downstream_uni_recloss"),
        "Reconstruction Error",
        "Average Precision",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- unimodal downstream performance vs. FID -----------------------------
    # Built separately from the frame above so that the CSV written there
    # keeps its original set of columns (no "fid" column).
    ds_uni_mean_sorted, ds_uni_std_sorted = _mean_std_by_config(
        metrics_df.assign(downstream_lr_uni=downstream_uni, fid=fid),
        ["model.name", "model.beta"],
    )
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        "fid",
        "downstream_lr_uni",
        _plot_file("plot_downstream_uni_fid"),
        "FID",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[FID_X_MAX, FID_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- coherence vs. FID ---------------------------------------------------
    ds_coh_mean_sorted, ds_coh_std_sorted = _mean_std_by_config(
        metrics_df.assign(coherence=coherence, fid=fid),
        ["model.name", "model.beta"],
    )
    generate_plot(
        ds_coh_mean_sorted,
        ds_coh_std_sorted,
        "fid",
        "coherence",
        _plot_file("plot_coherence_fid"),
        "FID",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[FID_X_MAX, FID_X_MIN],
        ylim=[COH_Y_MIN, COH_Y_MAX],
    )


# Bar plots of the per-attribute downstream performance, one figure per
# modality, restricted to the runs with model.beta == 1.
dataset_df = dataset_df.loc[dataset_df["model.beta"] == 1]

for key in modalities:
    # Metric columns of this modality, one per attribute, in attribute order.
    attr_columns = [
        f"val/downstream/unimodal/{key}/{attr}" for attr in attribute_names
    ]
    # Shorten the per-attribute column labels to the bare attribute names.
    ds_df = dataset_df[col_names + attr_columns].rename(
        columns=dict(zip(attr_columns, attribute_names))
    )

    # Mean and standard deviation per configuration, group keys as columns.
    per_config = ds_df.groupby(
        ["model.name", "model.aggregation", "dataset.name", "model.beta"]
    )
    ds_mean = per_config.mean().reset_index()
    ds_std = per_config.std().reset_index()

    generate_plot_all_attributes(
        ds_mean,
        ds_std,
        os.path.join(DIR_OUT, f"plot_downstream_uni_{key}_attrs_celeba.png"),
        "Attributes",
        "Average Precision",
        attribute_names,
        model_names_vis,
    )


# Undo the proxy setup and switch the login back to the public W&B API.
if use_local_wandb:
    # Remove the variable through os.environ: os.unsetenv() does not update
    # the os.environ mapping, so the proxy would still be visible to Python
    # code during the login below. Deleting the mapping entry also unsets the
    # variable in the process environment.
    os.environ.pop("HTTP_PROXY", None)
    wandb.login(host="https://api.wandb.ai")
