from all_data import get_all_data
from scipy.stats import tukey_hsd


def tukey_to_string(tukey, labels=("GNN", "MLP", "Prost"), confidence_level=0.95):
    """Format the pairwise comparisons of a Tukey HSD result as LaTeX rows.

    One row is produced for every unordered pair of groups ``(i, j)`` with
    ``i < j``, in row-major order.  Each row has four ``&``-separated cells
    (comparison name, statistic, p-value, confidence interval) and is
    terminated by a LaTeX line break (``\\\\``) followed by a newline.
    All numbers are right-aligned in a width of 10 with two decimals.

    Args:
        tukey: Result object exposing square ``pvalue`` and ``statistic``
            matrices and a ``confidence_interval(confidence_level=...)``
            method whose return value has ``low`` and ``high`` matrices
            (e.g. the object returned by ``scipy.stats.tukey_hsd``).
        labels: Display names of the groups, indexed like the rows and
            columns of the result matrices.  Must contain at least as many
            entries as there are groups, otherwise indexing raises
            ``IndexError``.
        confidence_level: Confidence level forwarded to
            ``tukey.confidence_interval``.

    Returns:
        The concatenated table rows as a single string (empty when there
        are fewer than two groups).
    """
    ci = tukey.confidence_interval(confidence_level=confidence_level)
    pvalue = tukey.pvalue
    statistic = tukey.statistic
    n_groups = pvalue.shape[0]

    rows = []
    for i in range(n_groups):
        # Only the strict upper triangle: each pair is reported once.
        for j in range(i + 1, n_groups):
            high = ci.high[i, j]
            # Non-negative upper bounds get two LaTeX forced spaces ("\ \ ")
            # in front, presumably so they line up with negative bounds that
            # carry a minus sign.  The backslashes are escaped explicitly:
            # a bare "\ " is an invalid escape sequence in a Python string.
            high_str = f"{high:>10.2f}" if high < 0 else f"\\ \\ {high:>10.2f}"
            cells = [
                f"{labels[i]} - {labels[j]}",
                f"{statistic[i, j]:>10.2f}",
                f"{pvalue[i, j]:>10.2f}",
                f"[{ci.low[i, j]:>10.2f},{high_str}]",
            ]
            rows.append(" & ".join(cells) + "\\\\\n")
    return "".join(rows)


# Script body: build a LaTeX (booktabs) table with one section per domain,
# each holding the pairwise Tukey HSD comparisons, and print it to stdout.

# Only rows belonging to this batch are analysed.
BATCH_ID = "37290af5-1d7d-4d35-88fe-6fcf8ffc5868"

df = get_all_data()
df = df[df["batch_id"] == BATCH_ID]

# Group by the column name itself rather than a one-element list: with a list
# key the group name is a 1-tuple on pandas >= 2.0 but a plain string on older
# versions, where ``name[0]`` would silently yield only the first character.
# A scalar key yields the domain string on every pandas version.
grouped = df.groupby("domain")

s = ""
s += "\\toprule\n"
s += " & ".join(["Comparison", "Statistic", "p-value", "CI"])
s += "\\\\ \n"
s += "\\midrule\n"

for name, group in grouped:
    # Section header spanning all four columns.
    s += "\\multicolumn{4}{l}" f"{{{name.replace('_MDP_ippc', ' ')}}} \\\\\n"

    # Keep only the rows where ``is_train`` is False; select them once
    # instead of re-filtering for every column.
    eval_rows = group[~group["is_train"]]

    # The i-th sample is labelled with the i-th default label of
    # tukey_to_string, so "score" is reported as GNN, "mlp" as MLP and
    # "prost" as Prost.
    tukey = tukey_hsd(
        eval_rows["score"],
        eval_rows["mlp"],
        eval_rows["prost"],
    )
    # Append the comparisons as LaTeX table rows.
    s += tukey_to_string(tukey)
    s += "\\midrule\n"

s += "\\bottomrule\n"

print(s)
