import matplotlib.pyplot as plt
import json
import pandas as pd
import numpy as np

# Load the run table and index it by "run_id" so that a single run can be
# fetched with ``runs_data.loc[run_id]``.
runs_data = pd.read_csv("/home/jakob/rddleval/runs.csv").set_index("run_id")


def get_param(run_id, param):
    """Return the ``params.<param>`` entry of the run identified by *run_id*.

    The value is read from the module-level ``runs_data`` frame, which is
    indexed by run id. Lookup errors raised by pandas (for example a
    ``KeyError`` for an unknown run id or column) propagate unchanged.
    """
    return runs_data.loc[run_id][f"params.{param}"]


# Input files; the bare names are resolved against the current working directory.
mlp_data_file = "mlps.json"
prost_data_file = "prost.json"
min_data_file = "mins.json"
gnn_data_file = "evaluation_results.json"


def _load_json(path):
    """Parse the JSON document stored at *path* and return the result."""
    with open(path) as handle:
        return json.load(handle)


# The comparison further down reads the MLP results from the "mlp400k" entry.
mlp_data = _load_json(mlp_data_file)
best_mlp_data = mlp_data["mlp400k"]

# Only the content of the top-level "prost" key is kept from this file.
prost_data = _load_json(prost_data_file)["prost"]

min_data = _load_json(min_data_file)

gnn_data = _load_json(gnn_data_file)

# Tag every GNN record with the batch id stored for its run in the run table.
for run_id, record in gnn_data.items():
    record["batch_id"] = get_param(run_id, "batch_id")


# Keep only the records belonging to one batch and key them by domain. If two
# selected records share a domain, the one appearing later in ``gnn_data`` wins.
selected_batch_id = "09e8bcad-221f-4cf0-8c78-6b2743783394"

best_gnn_data = {
    record["domain"]: record
    for record in gnn_data.values()
    if record["batch_id"] == selected_batch_id
}

# Number of leading instances of every domain left out of the averages.
# NOTE(review): why the first five are skipped is not visible here — confirm.
SKIPPED_INSTANCES = 5


def _mean_gain(returns, mins):
    """Return the mean of ``returns - mins`` past the skipped instances, floored at 0.

    The subtraction is done out of place on purpose: an in-place ``-=`` on an
    integer array raises a NumPy casting error as soon as ``mins`` holds
    floats, which can happen when the JSON input stores whole numbers.
    """
    gain = np.asarray(returns) - np.asarray(mins)
    return max(gain[SKIPPED_INSTANCES:].mean(), 0)


results = []

for domain in best_gnn_data:
    mins = min_data[domain]
    # Collapse the last axis of the GNN returns before comparing
    # (presumably repeated episodes per instance — TODO confirm).
    gnn_returns = np.array(best_gnn_data[domain]["instance_returns"]).mean(-1)
    results.append(
        {
            "domain": domain,
            "gnn": _mean_gain(gnn_returns, mins),
            "prost": _mean_gain(prost_data[domain], mins),
            "mlp": _mean_gain(best_mlp_data[domain], mins),
        }
    )

# One row per domain; the remaining columns are the "gnn", "prost" and "mlp"
# values collected in ``results``.
df = pd.DataFrame.from_records(results).set_index("domain")


# Horizontal bar chart with one group of bars per domain, written to "test.png".
df.plot.barh(figsize=(20, 6))
plt.tight_layout()
plt.savefig("test.png")
