import os
import pandas as pd
import wandb

from utils_plots import generate_plot


# NOTE(review): NUM_VIEWS and NCOLS_L are not referenced anywhere in the
# visible part of this script -- confirm whether they are still needed.
NUM_VIEWS = 3.0
NCOLS_L = 4

# Axis bounds handed to generate_plot() as xlim / ylim further down.
# x-axis bounds for the FID plots and for the reconstruction-error plots.
FID_X_MIN, FID_X_MAX = 25, 275
REC_X_MIN, REC_X_MAX = 2900, 5000
# y-axis bounds for the coherence plots and the downstream-accuracy plots.
COH_Y_MIN, COH_Y_MAX = 0.0, 0.7
DS_Y_MIN, DS_Y_MAX = 0.0, 1.0

# When enabled, route HTTP traffic through a SOCKS proxy on localhost and log
# in against a non-default W&B host before querying the API.
use_local_wandb = True
if use_local_wandb:
    os.environ.update({"HTTP_PROXY": "socks5h://localhost:10080"})
    # NOTE(review): the placeholder text says "entity", but it is passed as the
    # `host` argument of wandb.login -- confirm what value belongs here.
    wandb.login(host="PUT YOUR WANDB ENTITY HERE")
# Alternative kept from the original: take the host from the environment.
# wandb.login(host=os.getenv("WANDB_LOCAL_URL"))

api = wandb.Api()
# The project is addressed as "<entity>/<project-name>".
runs = api.runs("PUT YOUR WANDB PROJECT HERE")


# Collect one entry per run, in iteration order, so that the three lists stay
# index-aligned: summary metrics, public hyperparameters, and display name.
summary_list = []
config_list = []
name_list = []
for wandb_run in runs:
    # `.summary` holds the logged output metrics; `_json_dict` is used so that
    # large files are left out (per the original author's note).
    summary_list.append(wandb_run.summary._json_dict)

    # `.config` holds the hyperparameters; keys starting with "_" are skipped.
    public_config = {}
    for key, value in wandb_run.config.items():
        if key.startswith("_"):
            continue
        public_config[key] = value
    config_list.append(public_config)

    # `.name` is the human-readable name of the run.
    name_list.append(wandb_run.name)


# Build one wide table with a row per run. Column order, left to right:
# summary metrics, top-level config values, then the nested "dataset", "model"
# and "log" config sections flattened to "<section>.<key>", then the run name.
summary_frame = pd.DataFrame(summary_list)

# Top-level config entries; the nested sections are handled separately below.
flat_config_frame = pd.DataFrame(
    [
        {
            key: value
            for key, value in cfg.items()
            if key not in ("log", "model", "dataset")
        }
        for cfg in config_list
    ]
)

# Every run's config is expected to contain all three nested sections; a
# missing section raises KeyError here.
section_frames = []
for section in ("dataset", "model", "log"):
    section_rows = [
        {f"{section}.{key}": value for key, value in cfg[section].items()}
        for cfg in config_list
    ]
    section_frames.append(pd.DataFrame(section_rows))

# The names are passed as a plain list, so the resulting column is labelled 0.
name_frame = pd.DataFrame(name_list)

runs_df = pd.concat(
    [summary_frame, flat_config_frame, *section_frames, name_frame],
    axis=1,
)


# Directory that receives every plot and CSV file written by this script.
DIR_OUT = "PUT PLOT SAVE DIR HERE"
# exist_ok=True replaces the separate exists() check: the directory is created
# when missing and left alone when already present, without a window between
# the check and the creation.
os.makedirs(DIR_OUT, exist_ok=True)

# --- Run selection -----------------------------------------------------------
# Lower bounds applied when filtering runs: minimum logged epoch and minimum
# beta value.
n_epochs_threshold = 10
beta_threshold = 0.0
# Exact-match values applied when filtering runs.
latent_dim = 512
dataset = "PM_translated75"
lr = 5e-4
alpha_annealing = True
wandb_group = "20240513"

# Modality keys "m0", "m1", ... used to build the metric column names.
num_views = 3
modalities = [f"m{view}" for view in range(num_views)]

# NOTE(review): batch_size, split_type, model_names and dataset_names are not
# referenced in the visible part of this script -- confirm they are needed.
batch_size = 256
split_type = "plus"
dataset_names = ["PM_translated75"]

# "split" is intentionally left out here (and in model_names_vis below).
# Previously a four-entry list including "split" was assigned and then
# immediately overwritten by this one; only the effective value is kept.
model_names = ["unimodal", "joint", "mixedprior"]

# Display labels handed to generate_plot(). "joint" maps each aggregation
# function to its own label.
model_names_vis = {
    "unimodal": "independent",
    "joint": {
        "avg": "AVG",
        "moe": "MoE",
        "poe": "PoE",
        "mopoe": "MoPoE",
    },
    # "split": "split",
    "mixedprior": "MMVM",
}

# Narrow the run table to the experiment of interest. Filters are applied one
# after another in the order listed; rows whose value is missing (NaN) fail
# every comparison and are dropped.
# (An exact-match filter on "model.beta" is intentionally not applied; only
# the lower bound below is used.)
exact_match_filters = {
    "model.lr": lr,
    "model.latent_dim": latent_dim,
    "model.alpha_annealing": alpha_annealing,
    "log.wandb_group": wandb_group,
}
lower_bound_filters = {
    "model.beta": beta_threshold,
    "epoch": n_epochs_threshold,
}

agg_df = runs_df.copy()
for column, wanted in exact_match_filters.items():
    agg_df = agg_df.loc[agg_df[column].eq(wanted)]
for column, minimum in lower_bound_filters.items():
    agg_df = agg_df.loc[agg_df[column].ge(minimum)]

# Names of the columns pulled out of the run table for plotting. `col_names`
# starts with the identifying columns and is extended below with every metric
# column; the `list_str_*` lists keep the metric columns grouped by kind so a
# per-run average can be taken over each group later.
col_names = [
    "model.name",
    "model.aggregation",
    "dataset.name",
    "model.beta",
    "model.seed",
]
list_str_coh, list_str_fid = [], []
list_str_ds_agg, list_str_ds_uni = [], []

for source_mod in modalities:
    # Pairwise metrics: one coherence and one FID column per
    # "<source>_to_<target>" combination (including source == target).
    for target_mod in modalities:
        pair_tag = f"{source_mod}_to_{target_mod}"
        coh_column = f"val/coherence/{pair_tag}"
        fid_column = f"val/fid/{pair_tag}"
        list_str_coh.append(coh_column)
        list_str_fid.append(fid_column)
        col_names.extend([coh_column, fid_column])

    # Per-modality downstream metrics, aggregated and unimodal.
    # (A "val/likelihood/<modality>" column is deliberately not collected.)
    agg_column = f"val/downstream/aggregated/{source_mod}"
    uni_column = f"val/downstream/unimodal/{source_mod}"
    list_str_ds_agg.append(agg_column)
    list_str_ds_uni.append(uni_column)
    col_names.extend([agg_column, uni_column])

# Reconstruction-loss columns used as x-axes; they go last in `col_names`.
str_cond_rec = "val/condition_generation/avg_rec_loss"
str_rec = "val/loss/avg_rec_loss_epoch"
col_names.extend([str_rec, str_cond_rec])

# Restrict the filtered runs to the single dataset being plotted.
dataset_df = agg_df.loc[agg_df["dataset.name"] == dataset]

# Columns identifying one configuration. Rows that share these values are
# reduced to their mean and standard deviation before plotting.
GROUP_KEYS = ["model.name", "model.aggregation", "dataset.name", "model.beta"]
# Left x-limit used for the conditional-reconstruction-error plots.
COND_REC_X_MAX = 8000


def _aggregate_runs(frame, columns, derived, sort_by):
    """Return per-group mean and std tables for the selected columns.

    Args:
        frame: Run table with one row per run.
        columns: Column names to keep from ``frame``.
        derived: Mapping ``new column name -> list of source columns``; each
            new column is the row-wise mean of its source columns. Columns are
            added in the mapping's order.
        sort_by: Column names the two result tables are sorted by.

    Returns:
        ``(mean_sorted, std_sorted)``: group means and group standard
        deviations over ``GROUP_KEYS``, each with the group keys restored as
        regular columns and sorted by ``sort_by``.
    """
    # Work on an explicit copy: adding columns to a bare column selection
    # triggers pandas' SettingWithCopyWarning.
    work = frame[columns].copy()
    row_means = {name: work[sources].mean(axis=1) for name, sources in derived.items()}
    for name, values in row_means.items():
        work[name] = values
    grouped = work.groupby(GROUP_KEYS)
    mean_sorted = grouped.mean().reset_index().sort_values(by=sort_by)
    std_sorted = grouped.std().reset_index().sort_values(by=sort_by)
    return mean_sorted, std_sorted


def _plot_path(stem):
    """Return ``DIR_OUT/<stem><dataset>.png``."""
    return os.path.join(DIR_OUT, stem + str(dataset) + ".png")


single_plots_datasets = True
if single_plots_datasets:
    # Every xlim below is passed as [larger, smaller].
    # NOTE(review): presumably this draws the x-axis reversed -- confirm
    # against generate_plot().

    # --- coherence vs. (conditional) reconstruction error --------------------
    coh_mean_sorted, coh_std_sorted = _aggregate_runs(
        dataset_df, col_names, {"coherence": list_str_coh}, ["model.beta"]
    )

    str_fn_legend = os.path.join(DIR_OUT, "legend_coherence_recloss.png")
    str_filename_coh = _plot_path("plot_coherence_recloss_")
    generate_plot(
        coh_mean_sorted,
        coh_std_sorted,
        str_rec,
        "coherence",
        str_filename_coh,
        "Reconstruction Error",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[COH_Y_MIN, COH_Y_MAX],
        fn_legend=str_fn_legend,
    )
    str_filename_coh_condrec = _plot_path("plot_coherence_condrecloss_")
    generate_plot(
        coh_mean_sorted,
        coh_std_sorted,
        str_cond_rec,
        "coherence",
        str_filename_coh_condrec,
        "Conditional Reconstruction Error",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[COND_REC_X_MAX, REC_X_MIN],
        ylim=[COH_Y_MIN, COH_Y_MAX],
    )

    # --- downstream accuracy on aggregated latent representations ------------
    ds_mean_sorted, ds_std_sorted = _aggregate_runs(
        dataset_df, col_names, {"downstream_lr_agg": list_str_ds_agg}, ["model.beta"]
    )

    # The same data is plotted twice, keyed once by aggregation and once by
    # model name. Both plots used to be written to the same file, so the first
    # was overwritten. The model-name plot keeps the established filename (it
    # is the one that previously survived on disk); the aggregation plot gets
    # its own file.
    str_filename_lr_by_agg = _plot_path("plot_downstream_agg_recloss_by_aggregation_")
    generate_plot(
        ds_mean_sorted,
        ds_std_sorted,
        str_rec,
        "downstream_lr_agg",
        str_filename_lr_by_agg,
        "Reconstruction Error",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )
    str_filename_lr = _plot_path("plot_downstream_agg_recloss_")
    generate_plot(
        ds_mean_sorted,
        ds_std_sorted,
        str_rec,
        "downstream_lr_agg",
        str_filename_lr,
        "Reconstruction Error",
        "Accuracy",
        "model.name",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- downstream accuracy on unimodal representations ---------------------
    ds_uni_mean_sorted, ds_uni_std_sorted = _aggregate_runs(
        dataset_df,
        col_names,
        {"downstream_lr_uni": list_str_ds_uni},
        ["model.name", "model.beta"],
    )
    # The group means are also exported as a table.
    ds_uni_mean_sorted.to_csv(
        os.path.join(DIR_OUT, "results_downstream_uni_" + str(dataset) + ".csv")
    )
    str_filename_lr_uni_condrec = _plot_path("plot_downstream_uni_condrecloss_")
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        str_cond_rec,
        "downstream_lr_uni",
        str_filename_lr_uni_condrec,
        "Conditional Reconstruction Error",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[COND_REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )
    str_filename_lr_uni_rec = _plot_path("plot_downstream_uni_recloss_")
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        str_rec,
        "downstream_lr_uni",
        str_filename_lr_uni_rec,
        "Reconstruction Error",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[REC_X_MAX, REC_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- downstream accuracy (unimodal) vs. FID ------------------------------
    # Re-aggregated with the extra "fid" column; kept separate from the table
    # above so the exported CSV keeps its original columns.
    ds_uni_mean_sorted, ds_uni_std_sorted = _aggregate_runs(
        dataset_df,
        col_names,
        {"downstream_lr_uni": list_str_ds_uni, "fid": list_str_fid},
        ["model.name", "model.beta"],
    )

    str_filename_lr_uni_fid = _plot_path("plot_downstream_uni_fid_")
    generate_plot(
        ds_uni_mean_sorted,
        ds_uni_std_sorted,
        "fid",
        "downstream_lr_uni",
        str_filename_lr_uni_fid,
        "FID",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[FID_X_MAX, FID_X_MIN],
        ylim=[DS_Y_MIN, DS_Y_MAX],
    )

    # --- coherence vs. FID ---------------------------------------------------
    ds_coh_mean_sorted, ds_coh_std_sorted = _aggregate_runs(
        dataset_df,
        col_names,
        {"coherence": list_str_coh, "fid": list_str_fid},
        ["model.name", "model.beta"],
    )

    str_filename_coh_fid = _plot_path("plot_coherence_fid_")
    generate_plot(
        ds_coh_mean_sorted,
        ds_coh_std_sorted,
        "fid",
        "coherence",
        str_filename_coh_fid,
        "FID",
        "Accuracy",
        "model.aggregation",
        model_names_vis,
        xlim=[FID_X_MAX, FID_X_MIN],
        ylim=[COH_Y_MIN, COH_Y_MAX],
    )

# Undo the proxy setup made at the top of the script and log in against the
# public W&B API again.
if use_local_wandb:
    # Remove the variable through os.environ rather than os.unsetenv():
    # unsetenv() does not update the os.environ mapping, so code that reads
    # os.environ would still see the proxy. Deleting from os.environ also
    # unsets the variable for the process. The default keeps this a no-op
    # when the variable is already gone.
    os.environ.pop("HTTP_PROXY", None)
    wandb.login(host="https://api.wandb.ai")
