import json
import pandas as pd
import numpy as np
from scipy.stats import ttest_ind, tukey_hsd, ttest_rel, pearsonr, PermutationMethod


# Run metadata table, loaded once at import time. With orient="index" the
# top-level keys of the JSON object become the row labels; get_param() looks
# rows up by run id and reads their "params.<name>" entries.
runs_data = pd.read_json("data/runs.json", orient="index")


def calculate_stats(dist1, dist2, batch_id, question, domain):
    t, p = ttest_ind(dist1, dist2, equal_var=False, method=PermutationMethod())
    if np.isnan(t) or np.isnan(p):
        return None

    row = pd.Series(
        {
            "question": question,
            "domain": domain,
            "batch_id": batch_id,
            "t": t,
            "p": p,
        }
    )
    return row


def get_param(run_id, param):
    """Return the value stored under ``params.<param>`` for run ``run_id``.

    The lookup goes through the module-level ``runs_data`` table: the row
    labelled ``run_id`` is selected first, then its ``"params.<param>"``
    entry is read.
    """
    column = f"params.{param}"
    run_row = runs_data.loc[run_id]
    return run_row[column]


def _load_json(path):
    """Return the parsed JSON document stored at ``path``."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _instance_rows(record, prost_data, min_data, mlp_data):
    """Expand one evaluation record into one row dict per instance.

    Args:
        record: One parsed line of the evaluation results; must provide
            ``domain``, ``run_id`` and ``instance_returns`` and may provide
            ``batch_id`` and ``train_instances``.
        prost_data: Mapping from domain to per-instance PROST scores.
        min_data: Mapping from domain to per-instance ``mins`` scores.
        mlp_data: Mapping from domain to per-instance MLP scores.

    Returns:
        A list of dicts, one per instance, in instance order.
    """
    domain = record["domain"]
    run_id = record["run_id"]
    # Fall back to the logged run parameters when the record has no (truthy)
    # value of its own.
    batch_id = record.get("batch_id") or get_param(run_id, "batch_id")
    train_instances = record.get("train_instances") or json.loads(
        get_param(run_id, "instance")
    )
    # Training ids are shifted down by one so they can be compared with the
    # 0-based positions produced by enumerate(); a set keeps each membership
    # test O(1) instead of scanning the array for every instance.
    train_positions = set((np.array(train_instances) - 1).ravel().tolist())

    prost = np.array(prost_data[domain])
    # Average over the last axis of the recorded returns: one score per instance.
    gnn = np.array(record["instance_returns"]).mean(-1)
    mins = np.array(min_data[domain])
    mlp = mlp_data[domain]

    return [
        {
            "run_id": run_id,
            "batch_id": batch_id,
            "domain": domain,
            "instance": i,
            "score": gnn_score,
            "prost": prost_score,
            "is_train": i in train_positions,
            "mins": mins[i],
            "mlp": mlp[i],
        }
        for i, (gnn_score, prost_score) in enumerate(zip(gnn, prost))
    ]


def main():
    """Build the per-instance score table and write the question statistics.

    Reads ``data/eval_results.jsonl``, ``data/mins.json``, ``data/prost.json``
    and ``data/mlps.json``, groups the per-instance rows by
    ``(domain, batch_id)``, runs ``calculate_stats`` for each question on
    every group, and writes the non-``None`` results to
    ``data/question_data.csv``.
    """
    with open("data/eval_results.jsonl", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        data = [json.loads(line) for line in f if line.strip()]

    min_data = _load_json("data/mins.json")
    prost_data = _load_json("data/prost.json")["prost"]
    mlp_data = _load_json("data/mlps.json")["mlp400k"]

    rows = []
    for record in data:
        rows.extend(_instance_rows(record, prost_data, min_data, mlp_data))

    # Explicit columns keep set_index/groupby working when there are no rows.
    columns = [
        "run_id",
        "batch_id",
        "domain",
        "instance",
        "score",
        "prost",
        "is_train",
        "mins",
        "mlp",
    ]
    df = pd.DataFrame(rows, columns=columns).set_index("run_id")
    grouped = df.groupby(["domain", "batch_id"])

    # Every question compares the held-out (non-train) GNN scores against a
    # reference sample: (question label, reference column, use train rows?).
    comparisons = (
        # question 1. Are the models better than the threshold policies?
        # subquestion: is prost better than the threshold policies?
        ("q1", "mins", False),
        # question 2. Is the train score better than the eval score?
        ("q2", "score", True),
        # question 3. Are the eval scores the same for the GNN and PROST policies?
        ("q3", "prost", False),
        # question 4. Are the eval scores the same as the MLP scores?
        ("q4", "mlp", False),
    )

    question_data = []
    # Questions form the outer loop so the output stays ordered by question
    # first and by group second.
    for question, column, use_train in comparisons:
        for (domain, batch_id), group in grouped:
            is_train = group["is_train"]
            held_out = group[~is_train]
            reference = group[is_train] if use_train else held_out
            stats_row = calculate_stats(
                held_out["score"],
                reference[column],
                batch_id,
                question,
                domain,
            )
            # calculate_stats returns None when the test produced NaN.
            if stats_row is not None:
                question_data.append(stats_row)

    # Passing the path lets pandas open the file with the newline handling the
    # CSV writer needs; a plain text handle can double line endings on Windows.
    pd.DataFrame(question_data).to_csv("data/question_data.csv", index=False)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
