import pandas as pd
import json
from IPython.display import HTML, display

def side_by_side(*dfs):
    """Display DataFrames next to each other in a notebook, two per row.

    Args:
        *dfs: ``(DataFrame, title)`` pairs. ``title`` is an HTML/text string
            inserted verbatim above the table (it is not escaped).

    Each group of up to two pairs is emitted as its own flex-row ``<div>``
    through ``IPython.display.display``; nothing is returned.
    """
    per_row = 2
    for start in range(0, len(dfs), per_row):
        cells = []
        for frame, heading in dfs[start:start + per_row]:
            cells.append(
                '<div style="margin-right: 2em">' + heading + frame.to_html() + '</div>'
            )
        display(HTML('<div style="display:flex">' + "".join(cells) + '</div>'))
    
def read_json(path, verbose=False):
    """Load a JSON results file, optionally keeping only headline metrics.

    Args:
        path: Path (str or path-like) of the JSON file to read.
        verbose: If True, return the full decoded JSON object. Otherwise
            return only the entries whose key is one of "accuracy",
            "many_acc", "med_acc" or "few_acc" (in file order; missing keys
            are simply absent).

    Returns:
        The decoded object when ``verbose`` is True, else the filtered dict.
        The non-verbose path assumes the top-level JSON value is an object.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        rt = json.load(f)
    if verbose:
        return rt
    metric_keys = ("accuracy", "many_acc", "med_acc", "few_acc")
    return {k: v for k, v in rt.items() if k in metric_keys}

def average_pd(pds):
    """Return the element-wise mean of ``pds`` rounded to one decimal.

    Args:
        pds: Non-empty sized sequence of values supporting ``+`` and ``/``,
            e.g. pandas DataFrames with aligned index/columns, or numbers.

    Returns:
        ``round(mean, 1)``. For DataFrames the built-in ``round`` dispatches
        to ``DataFrame.__round__``.

    Raises:
        ZeroDivisionError: if ``pds`` is empty.
    """
    # sum() starts from 0 just like a manual accumulator, and using it avoids
    # a loop variable named `pd` that would shadow the pandas module alias.
    total = sum(pds)
    return round(total / len(pds), 1)

def read_pd(results, keep=None):
    """Collect seed-0 run results into a DataFrame with one row per run.

    Args:
        results: Iterable of path objects for per-run JSON result files.
            Only ``r.parent.name`` (the run directory name) is used to derive
            the run name and seed, so pathlib.Path-like objects are expected.
        keep: Optional substring; runs whose parsed name does not contain it
            are skipped.

    Returns:
        A DataFrame indexed by run name (sorted), whose columns are the keys
        returned by ``read_json`` for each run, passed through
        ``average_pd`` (i.e. rounded to one decimal).
    """
    seed_results = {0: {}}

    for r in results:
        name = r.parent.name
        seed = name.split("_")[-1]

        # How many leading "_"-separated tokens to strip from the directory
        # name. Each alternative must be a membership test; a bare string
        # literal here would always be truthy and make the else-branch dead.
        # NOTE(review): presumably these datasets use a one-token prefix and
        # all others a two-token prefix — confirm against the run naming.
        if 'inat2018' in name or 'minicifar100' in name:
            name_ix = 1
        else:
            name_ix = 2

        # A trailing all-digit token is taken as the seed; otherwise seed 0.
        if seed.isdigit():
            seed = int(seed)
            name = "_".join(name.split("_")[name_ix:-1])
        else:
            seed = 0
            name = "_".join(name.split("_")[name_ix:])

        if keep is not None and keep not in name:
            continue

        # Only seed-0 runs are aggregated; other seeds are skipped unread.
        if seed > 0:
            continue

        seed_results[seed][name] = read_json(r)

    pd_0 = pd.DataFrame(seed_results[0]).T.sort_index()
    return average_pd([pd_0])