import json
import pandas as pd
import numpy as np
from scipy.stats import ttest_ind, tukey_hsd

# Run metadata table, read with ``orient="index"``. ``get_param`` looks a row
# up by run id via ``.loc`` and reads columns named ``params.<name>``.
runs_data = pd.read_json("data/runs.json", orient="index")


def get_param(run_id, param):
    """Return the logged value of parameter ``param`` for run ``run_id``.

    The lookup selects the ``run_id`` row of the module-level ``runs_data``
    table and reads its ``params.<param>`` column.

    Raises:
        KeyError: if ``run_id`` is not in the index of ``runs_data`` or the
            ``params.<param>`` column does not exist.
    """
    column = f"params.{param}"
    run_record = runs_data.loc[run_id]
    return run_record[column]


# Evaluation records, one JSON document per line (JSON Lines).
# The file is iterated lazily instead of via ``readlines()``, and blank lines
# (e.g. a trailing empty line) are skipped so ``json.loads`` is never handed
# an empty string. JSON is UTF-8, so the encoding is pinned rather than left
# to the platform locale.
with open("data/eval_results.jsonl", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

# Per-domain reference values; indexed by domain name further down.
with open("data/mins.json", encoding="utf-8") as f:
    min_data = json.load(f)

# Only the "prost" entry of this file is used; indexed by domain name.
with open("data/prost.json", encoding="utf-8") as f:
    prost_data = json.load(f)["prost"]

# Only the "mlp400k" entry of this file is used; indexed by domain name.
with open("data/mlps.json", encoding="utf-8") as f:
    mlp_data = json.load(f)["mlp400k"]


def _scale_scores(values, mins, scale):
    """Return ``max(values - mins, 0) / scale`` with non-finite entries zeroed.

    Args:
        values: raw per-instance scores.
        mins: per-instance lower reference subtracted before scaling.
        scale: per-instance denominator (best score minus ``mins``).

    Returns:
        The scaled scores; NaN and +/-inf results are replaced by ``0.0``.
    """
    # A zero ``scale`` is an expected case whose result is overwritten right
    # below, so silence the divide/invalid RuntimeWarnings instead of
    # letting them flood the output.
    with np.errstate(divide="ignore", invalid="ignore"):
        scaled = np.maximum(values - mins, 0) / scale
    return np.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)


# One output row per (evaluation record, instance position). Rows are plain
# dicts: building a ``pd.Series`` per row is far slower and yields the same
# DataFrame once the list is passed to ``pd.DataFrame``.
rows = []
for d in data:
    domain = d["domain"]
    r = d["run_id"]

    # Prefer values stored on the record; fall back to the run's parameters.
    batch_id = d.get("batch_id") or get_param(r, "batch_id")
    train_instances = d.get("train_instances") or json.loads(get_param(r, "instance"))
    # Shift the ids down by one so they can be compared with the zero-based
    # position used for ``instance`` below (presumably the stored ids are
    # 1-based -- confirm against the data).
    train_instances = np.array(train_instances) - 1
    # Built once per record so the per-instance membership test is O(1).
    train_positions = set(train_instances.ravel().tolist())

    prost = np.array(prost_data[domain])
    # Average the returns over the last axis to get one score per instance.
    gnn = np.array(d["instance_returns"]).mean(-1)
    mins = np.array(min_data[domain])
    mlp = np.array(mlp_data[domain])

    # Normalise against the best of the three methods on each instance.
    highest_scores = np.maximum(np.maximum(gnn, prost), mlp)
    scale = highest_scores - mins

    scaled_gnn = _scale_scores(gnn, mins, scale)
    scaled_prost = _scale_scores(prost, mins, scale)
    scaled_mlp = _scale_scores(mlp, mins, scale)

    # Sanity check on the normalisation itself (an invariant, not input
    # validation).
    assert (scaled_gnn <= 1.0).all() and (scaled_gnn >= 0.0).all()

    for i, (gnn_score, prost_score, mlp_score) in enumerate(
        zip(scaled_gnn, scaled_prost, scaled_mlp)
    ):
        rows.append(
            {
                "run_id": r,
                "batch_id": batch_id,
                "domain": domain,
                "instance": str(i),
                "score": gnn_score,
                "prost": prost_score,
                "is_train": i in train_positions,
                "mlp": mlp_score,
            }
        )

# Assemble the per-instance table, index it by run id and order it
# descending by batch, run, domain and train flag.
df = (
    pd.DataFrame(rows)
    .set_index("run_id")
    .sort_values(by=["batch_id", "run_id", "domain", "is_train"], ascending=False)
)


# Restrict to a single batch, then compare the three score columns on the
# rows whose ``is_train`` flag is False using Tukey's HSD test.
grouped = df[df["batch_id"] == "37290af5-1d7d-4d35-88fe-6fcf8ffc5868"]
non_train_rows = grouped[~grouped["is_train"]]

tukey = tukey_hsd(
    non_train_rows["score"],
    non_train_rows["mlp"],
    non_train_rows["prost"],
)

print(tukey)


# Dump the full per-instance table as Markdown.
with open("all_results.md", "w") as f:
    f.write(df.to_markdown())

# Let pandas open the CSV file itself. ``df.to_csv()`` returns text that
# already uses ``os.linesep``; writing that through a text-mode handle
# translates the newline a second time on Windows ("\r\r\n"), which shows up
# as an empty line after every record.
df.to_csv("all_results.csv")


# Mean and standard deviation of the numeric columns for every
# (batch_id, domain, is_train) group. One groupby object serves both.
_per_group = df.groupby(["batch_id", "domain", "is_train"])
df_mean = _per_group.mean(numeric_only=True)
# Suffix the std columns so they can sit next to the means.
df_std = _per_group.std(numeric_only=True).add_suffix("_std")

df = pd.concat((df_mean, df_std), axis=1)

df = df.sort_values(by=["is_train", "domain", "score"], ascending=False)
# Drop groups with any missing statistic (for instance, the standard
# deviation of a single-row group).
df = df.dropna(axis=0)
# Pass the path straight to pandas: writing ``df.to_csv()`` through a
# text-mode handle doubles the carriage returns on Windows.
df.to_csv("train_eval_all_results.csv")

# Re-order by domain first, then show every (domain, is_train) row of one
# batch. The three-element key addresses the three index levels.
df = df.sort_values(by=["domain", "is_train", "score"], ascending=False)
print(df.loc[("37290af5-1d7d-4d35-88fe-6fcf8ffc5868", slice(None), slice(None))])

# Collapse the domain level: average the per-domain statistics for each
# (batch_id, is_train) pair and order them descending.
batch_split_means = df.groupby(["batch_id", "is_train"]).mean(numeric_only=True)
df = batch_split_means.sort_values(by=["is_train", "score"], ascending=False)
# print(df)

# eval_df = df[df["instance"] > 5]
# train_df = df[df["instance"] <= 5]
# eval_df = eval_df.rename(columns={"score": "eval_score", "prost": "eval_prost"}).drop(
#     columns="instance"
# )
# train_df = train_df.rename(
#     columns={"score": "train_score", "prost": "train_prost"}
# ).drop(columns="instance")

# eval_df = eval_df.groupby(["domain", "batch_id"]).mean(numeric_only=True)
# train_df = train_df.groupby(["domain", "batch_id"]).mean(numeric_only=True)

# combined_df = train_df.join(eval_df, how="inner").sort_values(
#     by=["domain", "eval_score"], ascending=False
# )

# with open("train_eval.md", "w") as f:
#     f.write(combined_df.to_markdown())

# # summary = df.groupby(["domain", "run_id", "batch_id"]).sum()

# # summary2 = summary.groupby(["domain"]).idxmax()

# summary3 = combined_df.groupby(["batch_id"]).sum().sort_values(by="eval_score")

# # with open("means_per_model.md", "w") as f:
# # 	f.write(summary.to_markdown())


# # with open("max_per_domain.md", "w") as f:
# # 	f.write(summary2.to_markdown())

# with open("score_per_model.md", "w") as f:
#     f.write(summary3.to_markdown())
