import pandas as pd
import json
import numpy as np

# Run metadata table. With orient="index" each top-level key of the JSON file
# becomes a row label, which is what get_param() looks runs up by.
runs_data = pd.read_json("data/runs.json", orient="index")


def get_param(run_id, param):
    """Return the ``params.<param>`` entry recorded for run *run_id*.

    The run is selected by label from the module-level ``runs_data`` table,
    and the value is read from the column named ``params.`` + *param*.
    """
    run_row = runs_data.loc[run_id]
    column = f"params.{param}"
    return run_row[column]


def _load_json(path):
    """Parse the file at *path* as a single JSON document and return it."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _load_jsonl(path):
    """Parse the file at *path* as JSON Lines and return the records as a list.

    Each non-blank line is decoded as one JSON document.
    """
    with open(path, encoding="utf-8") as f:
        # Iterate the file directly instead of materialising readlines(), and
        # skip blank lines, which json.loads would reject.
        return [json.loads(line) for line in f if line.strip()]


# only the content under the top-level "prost" key is used
prost_data = _load_json("data/prost100.json")["prost"]

mins_data = _load_json("data/mins100.json")

# MLP evaluation records keyed by domain; if a domain occurs more than once,
# the last record in the file wins.
mlp_data = {d["domain"]: d for d in _load_jsonl("data/eval_400k_0.jsonl")}


# evaluation records of the run under analysis, one dict per JSONL line
data = _load_jsonl("data/eval_37290af5-1d7d-4d35-88fe-6fcf8ffc5868.jsonl")

def _normalized(value, lower, upper):
    """Rescale *value* so that *lower* maps to 0 and *upper* maps to 1.

    Values below *lower* are clamped to 0. When ``upper - lower`` is not
    strictly positive there is no usable range and 0 is returned.
    """
    span = upper - lower
    if span > 0:
        return max(value - lower, 0) / span
    return 0


rows = []
for d in data:
    domain = d["domain"]
    if domain == "TriangleTireworld_MDP_ippc2014":
        continue
    # presumably a JSON-encoded list of the instance indices used for
    # training; verify against how the "instance" param is logged
    train_instances = json.loads(get_param(d["run_id"], "instance"))
    mlp_returns = mlp_data[domain]["instance_returns"]
    mins_returns = np.array(mins_data[domain])
    prost_returns = np.array(prost_data[domain])

    # only the first 10 instances of each domain are scored
    for i in range(10):
        # GNN and MLP entries are averaged with np.mean; the min and PROST
        # entries are used as they are
        avg_gnn_score = np.mean(d["instance_returns"][i])
        avg_mlp_score = np.mean(mlp_returns[i])
        min_score = mins_returns[i]
        prost_score = prost_returns[i]

        # the best of the three methods defines the top of the scale
        max_score = max([avg_gnn_score, avg_mlp_score, prost_score])

        # Plain dicts give the same DataFrame as per-row pd.Series objects
        # without constructing a Series for every row.
        rows.append(
            {
                "run_id": d["run_id"],
                "domain": domain,
                "instance": str(i),
                "score": _normalized(avg_gnn_score, min_score, max_score),
                "prost": _normalized(prost_score, min_score, max_score),
                "is_train": i in train_instances,
                "mlp": _normalized(avg_mlp_score, min_score, max_score),
            }
        )


full_df = pd.DataFrame(rows)

# df = df.set_index(["domain", "instance"])

# plot histogram over scores


def statistic(x, y, axis):
    """Return the difference of means ``mean(x) - mean(y)`` taken along *axis*."""
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y


from scipy.stats import permutation_test

import matplotlib.pyplot as plt
import seaborn as sns

# for domain in full_df["domain"].unique():
#     df = full_df[full_df["domain"] == domain]
#     df = df[~df["is_train"]]

#     res = permutation_test(
#         (df["score"], df["mlp"]),
#         statistic,
#         vectorized=True,
#         n_resamples=99999,
#         alternative="two-sided",
#     )

#     print(
#         f"Domain: {domain}, p-value: {res.pvalue:.4f}, statistic: {res.statistic:.4f}"
#     )

#     plt.figure(figsize=(10, 6))
#     sns.histplot(
#         df["score"],
#         kde=True,
#         color="blue",
#         stat="density",
#         bins=30,
#     )
#     plt.xlabel("Score")
#     plt.ylabel("Density")
#     plt.title("Score Distribution for GNN")
#     plt.grid(True)
#     plt.savefig(
#         f"gnn_{domain}_score_distribution_histogram_37290af5-1d7d-4d35-88fe-6fcf8ffc5868.png"
#     )


# keep only the rows that are not flagged as training instances
df = full_df[~full_df["is_train"]]


# (display labels, column names in df) for every pair of methods to compare
comparisons = [
    (("Vejde", "MLP"), ("score", "mlp")),
    (("Vejde", "Prost"), ("score", "prost")),
    (("MLP", "Prost"), ("mlp", "prost")),
]


# NOTE(review): both samples of a comparison come from the same rows of df,
# so they look paired; confirm that permutation_type="independent" is the
# intended test.
results = {}
for labels, (first_col, second_col) in comparisons:
    first_sample = df[first_col].values
    second_sample = df[second_col].values
    # batch is left at its default; passing e.g. batch=int(1e5) limits how
    # many resamples are handed to `statistic` per call.
    results[labels] = permutation_test(
        (first_sample, second_sample),
        statistic,
        vectorized=True,
        permutation_type="independent",
        n_resamples=int(1e6),
        alternative="two-sided",
    )

for comp, res in results.items():
    name_a, name_b = comp
    null_samples = res.null_distribution
    observed = res.statistic
    p_value = res.pvalue

    # histogram of the null distribution, in percent of resamples per bin
    plt.figure(figsize=(10, 6))
    sns.histplot(
        null_samples,
        kde=False,
        color="blue",
        stat="percent",
        bins=100,
    )

    # mark the observed statistic; the legend entry carries its value and p-value
    legend_text = f"Test Statistic: {observed:.4f}, $p$-value: {p_value:.4f}"
    plt.axvline(observed, color="red", linestyle="--", label=legend_text)
    plt.legend()
    plt.xlabel("Value of statistic")
    plt.ylabel("Density (%)")
    plt.grid(False)
    plt.ylim(0, 4)

    out_path = (
        f"{name_a}-{name_b}_null_distribution_37290af5-1d7d-4d35-88fe-6fcf8ffc5868.pdf"
    )
    plt.savefig(out_path)
    plt.close()

    # one LaTeX table row per comparison
    print(f"{name_a} - {name_b} & {observed:.2f} & {p_value:.2f} \\\\")

    # Entries at the 2.5% and 97.5% positions of the sorted null distribution.
    # NOTE(review): this is the central 95% range of the null distribution,
    # not an interval around the observed statistic; confirm the "CI" label.
    sorted_null = np.sort(null_samples)
    n_null = len(sorted_null)
    lower_bound = sorted_null[int(0.025 * n_null)]
    upper_bound = sorted_null[int(0.975 * n_null)]
    print(f"95% CI: [{lower_bound:.2f}, {upper_bound:.2f}]")

# no-op
pass
