from itertools import combinations

from scipy.stats import tukey_hsd

from all_data import get_all_data


def tukey_to_string(tukey, labels=None, confidence_level=0.95):
    r"""Render pairwise Tukey HSD results as LaTeX table rows.

    One row is emitted per unordered pair of groups ``(i, j)`` with ``i < j``,
    ordered by ``i`` then ``j``. Each row has the columns
    ``Comparison & Statistic & p-value & Lower CI & Upper CI`` (numbers
    right-aligned, two decimals) and ends with ``\\`` plus a newline. The
    caller is responsible for the surrounding rule/header lines.

    Args:
        tukey: Result object of ``scipy.stats.tukey_hsd``; its ``statistic``
            and ``pvalue`` are square ``(k, k)`` arrays and it provides
            ``confidence_interval(confidence_level=...)``.
        labels: Display names of the ``k`` groups, in the order the samples
            were passed to ``tukey_hsd``. Defaults to
            ``["Mimic Train", "Mimic Test", "PROST"]``.
        confidence_level: Confidence level of the reported interval.

    Returns:
        The concatenated rows; an empty string when there are fewer than two
        groups.

    Raises:
        ValueError: If fewer labels than groups are supplied.
    """
    if labels is None:
        labels = ["Mimic Train", "Mimic Test", "PROST"]

    pvalue = tukey.pvalue
    statistic = tukey.statistic
    n_groups = pvalue.shape[0]
    if len(labels) < n_groups:
        raise ValueError(
            f"need a label for each of the {n_groups} groups, "
            f"got {len(labels)}: {labels!r}"
        )

    ci = tukey.confidence_interval(confidence_level=confidence_level)

    rows = []
    # combinations() yields (0, 1), (0, 2), ..., (1, 2), ... i.e. exactly the
    # upper triangle in row-major order, so every pair is reported once.
    for i, j in combinations(range(n_groups), 2):
        cells = [
            f"{labels[i]} - {labels[j]}",
            f"{statistic[i, j]:>10.2f}",
            f"{pvalue[i, j]:>10.2f}",
            f"{ci.low[i, j]:>10.2f}",
            f"{ci.high[i, j]:>10.2f}",
        ]
        # "\\\\" is the LaTeX row terminator "\\".
        rows.append(" & ".join(cells) + "\\\\\n")
    return "".join(rows)


# Batch whose runs are tabulated. Previously analysed batches:
#   "a4e5a47f-d10a-4d5b-959c-1a64d986350f"
#   "1ce31e2d-45e2-432b-baec-db8e35fadfdd"
BATCH_ID = "0ddf02f4-2f5a-4d67-9679-358502133a5b"

df = get_all_data()
df = df[df["batch_id"] == BATCH_ID]

# Group by the scalar key (not ["domain"]): iterating a groupby over a
# one-element list yields 1-tuples on pandas >= 2 but scalars on older
# versions, whereas a scalar key always yields the plain domain value.
grouped = df.groupby("domain")

s = ""
s += "\\toprule\n"
s += " & ".join(["Comparison", "Statistic", "p-value", "Lower CI", "Upper CI"])
s += "\\\\ \n"
s += "\\midrule\n"

for domain, group in grouped:
    is_train = group["is_train"]
    title = domain.replace("_MDP_ippc", " ")
    # Section header spanning all five columns: \multicolumn{5}{l}{<title>} \\
    s += f"\\multicolumn{{5}}{{l}}{{{title}}} \\\\\n"
    # Sample order must match tukey_to_string's default labels:
    # 0 = "Mimic Train", 1 = "Mimic Test", 2 = "PROST" (on the test rows).
    tukey = tukey_hsd(
        group[is_train]["score"],
        group[~is_train]["score"],
        group[~is_train]["prost"],
    )
    s += tukey_to_string(tukey)
    s += "\\midrule\n"

s += "\\bottomrule\n"

print(s)
