import torch
import os
import pandas as pd
from models import test_domains


def parse_results(ckpt_folder, job_ids, ckpt_name):
    """Collect hyperparameters and accuracies for a set of training jobs.

    For each job id, loads ``<ckpt_folder>/<job_id>/<ckpt_name>`` and reads
    its ``"hyper_parameters"`` entry. The entry is then augmented with:

    - the checkpoint path;
    - the value returned by ``test_domains.load_val_accuracy``;
    - the per-domain values found in the ``"domain_accuracies"`` record
      loaded by ``test_domains.load_values``.

    Args:
        ckpt_folder: Directory containing one sub-folder per job.
        job_ids: Iterable of job identifiers. Each is converted with ``str``
            to form the sub-folder name.
        ckpt_name: File name of the checkpoint inside each job folder.

    Returns:
        A ``pandas.DataFrame`` with one row per job. The columns are the
        checkpoint's hyperparameters plus ``model_file``, ``val_acc``, and
        one ``test_acc_<domain>`` column for each of ``id``, ``ood_b``,
        ``ood_r`` and ``ood_br``.

    Raises:
        KeyError: if the checkpoint lacks ``"hyper_parameters"`` or the loaded
            accuracies lack one of the expected domains.
    """
    domains = ["id", "ood_b", "ood_r", "ood_br"]

    all_results = []
    for job_id in job_ids:
        job_folder = os.path.join(ckpt_folder, str(job_id))
        model_file = os.path.join(job_folder, ckpt_name)
        # Only the hyperparameters are needed. Mapping to CPU lets checkpoints
        # saved on a GPU be read on CPU-only machines and avoids allocating
        # device memory for tensors that are discarded immediately.
        model_info = torch.load(model_file, map_location="cpu")["hyper_parameters"]
        model_info["model_file"] = str(model_file)
        model_info["val_acc"] = test_domains.load_val_accuracy(job_folder, ckpt_name)
        # NOTE(review): assumes test_acc is a mapping keyed by domain name.
        test_acc = test_domains.load_values("domain_accuracies", job_folder)
        for d in domains:
            model_info["test_acc_{0}".format(d)] = test_acc[d]
        all_results.append(model_info)
    return pd.DataFrame(all_results)


def print_perf(
    df,
    sort_by_metric,
    output_dir="tmp",
    save_name="results.csv",
    sorting_hparams=["learning_rate"],
    to_print=["val_acc"],
):

    for h in sorting_hparams:
        unique_values = df[h].unique()
        for u in unique_values:
            print("\n")
            # get rows corresponding to value u of sorting_hparams
            df_hvalue = df[df[h] == u]
            # get index of row with best performance according to sort_by_metric
            best_row = df_hvalue["{}".format(sort_by_metric)].argmax()
            hparams = {}
            # save hyperparameters associated with best row in dict hparams
            for h_2 in df.iloc[best_row].to_frame().index:
                if h_2 not in ["model_file"]:
                    if "val_acc" not in h_2 and "test_acc" not in h_2:
                        hparams[h_2] = df_hvalue.iloc[best_row][h_2]
            # get all runs with identical hyperparameters
            new_df = df
            for h_3 in hparams:
                new_df = new_df.loc[new_df[h_3] == hparams[h_3]]
            new_df.index.names = [u]
