#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 27 12:27:56 2025

Load the results generated by ScoringAnalysis.py and print LaTeX tables
summarising the resulting AUC values.

"""


# import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ecdf
from tabulate import tabulate
from sklearn.metrics import auc


def read_n_from_summary(txt_path):
    """Validate a run-summary text file and return the integer ``n`` it records.

    The file is expected to hold ``n`` as the second space-separated token of
    its 5th line and the contamination model as the last token of its 11th
    line (which must start with ``pi``).  Only runs produced with the
    ``discrete_uniform`` model are accepted.

    Parameters
    ----------
    txt_path : str
        Path of the ``.txt`` summary written next to the ``.npz`` results.

    Returns
    -------
    int
        The value of ``n``.  It defines the p-value grid ``j/(n+1)`` used by
        :func:`outlier_cdf_auc` (presumably the calibration-set size -- confirm
        against ScoringAnalysis.py).

    Raises
    ------
    ValueError
        If the file is too short, the expected line labels are missing, the
        model is not ``discrete_uniform`` or ``n`` is not an integer literal.
    """
    with open(txt_path, "r") as file:
        lines = file.readlines()
    if len(lines) < 11:
        raise ValueError(f"{txt_path}: expected at least 11 lines, got {len(lines)}")

    n_fields = lines[4].split(" ")
    pi_fields = lines[10].split(" ")
    # Check the labels *before* trusting the values (and with real exceptions:
    # ``assert`` statements disappear under ``python -O``).
    if not n_fields[0].startswith("n") or len(n_fields) < 2:
        raise ValueError(f"{txt_path}: line 5 does not describe 'n': {lines[4]!r}")
    if pi_fields[0] != "pi":
        raise ValueError(f"{txt_path}: line 11 does not describe 'pi': {lines[10]!r}")
    pi_model = pi_fields[-1][:-1]  # [:-1] drops the trailing newline
    if pi_model != "discrete_uniform":
        raise ValueError(f"{txt_path}: unexpected pi model {pi_model!r}")

    # NOTE: this used to be ``eval(...)``.  ``n`` is used as an array size and a
    # range bound below, so it has to be a plain integer; ``int`` parses it
    # without executing arbitrary text read from the file.
    return int(n_fields[1][:-1])


def outlier_cdf_auc(outlier_pvals, n):
    """Return the normalised sum of the discrete CDF of ``outlier_pvals``.

    The p-values are binned on the grid ``j/(n+1)``, ``j = 1, ..., n+1``
    (values not close to any grid point, in the sense of ``np.isclose``, are
    ignored), the bin counts are normalised to a probability mass function,
    and the cumulative sums of that mass function are added up and divided by
    ``n + 2``.

    Parameters
    ----------
    outlier_pvals : array_like
        1-D collection of p-values.
    n : int
        Grid parameter; there are ``n + 1`` grid points.

    Returns
    -------
    float
        The resulting score.  It is ``nan`` when no p-value lies on the grid
        (the mass function cannot be normalised in that case).
    """
    grid = np.arange(1, n + 2) / (n + 1)
    # One vectorised comparison per grid point; ``np.count_nonzero`` replaces a
    # Python-level ``sum`` over the boolean array, which was the slow part.
    mass = np.array([np.count_nonzero(np.isclose(outlier_pvals, point)) for point in grid],
                    dtype=float)
    mass = mass / np.sum(mass)
    return np.sum(np.cumsum(mass)) / (n + 2)


def ecdf_auc(sample):
    """Return the area under the empirical CDF of ``sample``.

    The curve is anchored at ``(0, 0)`` and ``(1, 1)`` before integrating with
    the trapezoidal rule (``sklearn.metrics.auc``), i.e. the sample is assumed
    to take values in ``[0, 1]``.

    Parameters
    ----------
    sample : array_like
        1-D sample.

    Returns
    -------
    float
        Area under the anchored empirical CDF.
    """
    cdf = ecdf(sample).cdf  # fit once and reuse for both coordinates
    quantiles = np.hstack((0, cdf.quantiles, 1))
    probabilities = np.hstack((0, cdf.probabilities, 1))
    return auc(quantiles, probabilities)


def print_cdct_tables(title, auc_cdct, scenario_ids, method_names, pi_pos=None):
    """Print one LaTeX table per conformal score for a group of scenarios.

    Each table has one row per entry of ``method_names`` and, for every
    scenario in ``scenario_ids``, one column per threshold (axis 2 of
    ``auc_cdct``).

    Parameters
    ----------
    title : str
        Line printed once before the tables.
    auc_cdct : numpy.ndarray
        Array indexed as ``[scenario, score, threshold, pi, method]``.
    scenario_ids : iterable of int
        Scenario indices (axis 0) to include, in column order.
    method_names : sequence of str
        Row labels; entry ``k`` labels ``auc_cdct[..., k]``.
    pi_pos : int, optional
        Index along the ``pi`` axis used for every column.  When ``None``, the
        column for threshold position ``t`` uses ``pi`` index ``t``
        (presumably the contamination level equal to the threshold -- confirm
        that ``pis`` and the stored thresholds are aligned).
    """
    scenario_ids = list(scenario_ids)
    n_scores = auc_cdct.shape[1]
    n_thresholds = auc_cdct.shape[2]

    print(title)
    headers = [f"(S{sc_idx+1:d})" for sc_idx in scenario_ids for _ in range(n_thresholds)]
    for cs_idx in range(n_scores):
        print(f"\n(C{cs_idx+1:d})\n")
        table = []
        for method_idx, method_name in enumerate(method_names):
            row = [f"{method_name}"]
            for sc_idx in scenario_ids:
                for th_pos in range(n_thresholds):
                    pi_sel = th_pos if pi_pos is None else pi_pos
                    row.append(auc_cdct[sc_idx, cs_idx, th_pos, pi_sel, method_idx])
            table.append(row)
        print(tabulate(table, headers, tablefmt="latex", floatfmt=".4f"))


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    scenarios = ["retina/retina",
                 "retinalOCT_partial/retinalOCT_partial",
                 "WBC_partial/WBC_partial",
                 "FEMNIST_partial/FEMNIST_partial",
                 "MNIST_partial_fn/MNIST_partial_fn",
                 "MNIST_partial_ln/MNIST_partial_ln"]
    # Other score names that can be listed here:
    # conformal_scores = ["OCSVM",
    #                     "IF",
    #                     "AE",
    #                     "AdaDetect_LR",
    #                     "AdaDetect_SVC",
    #                     "labelOCSVM",
    #                     "labelIF",
    #                     "labelAE",
    #                     "labelAdaDetect_LR",
    #                     "labelAdaDetect_SVC"]
    conformal_scores = ["AdaDetect_SVC",
                        "labelAdaDetect_LR"]
    names = ["Storey", "Quantile", "Fisher", "Sum"]
    # Column of the last axis of "pcon_total" shown under each entry of
    # `names`: columns 2 and 3 are swapped so that the rows read
    # Storey, Quantile, Fisher, Sum (presumably the stored order is
    # Storey, Quantile, Sum, Fisher -- confirm against ScoringAnalysis.py).
    pcon_columns = [0, 1, 3, 2]
    pi_th_indices = [0, 1, 2]   # indices along axis 1 of "pi_true"/"pcon_total"
    pis = [0.0, 0.1, 0.2, 0.3]  # contamination levels to evaluate
    pi_idx = 3                  # entry of `pis` reported in the power tables
    pi_tol = 0.01               # half-width of the bin around each level in `pis`
    scenarios_per_table = 3

    AUC_cod = np.zeros((len(scenarios), len(conformal_scores)), dtype=np.float32)
    AUC_cdct = np.zeros((len(scenarios), len(conformal_scores),
                         len(pi_th_indices), len(pis), len(names)), dtype=np.float32)

    # ------------------------------------------------------------------
    # Compute the AUC values
    # ------------------------------------------------------------------
    for sc_idx, sc in enumerate(scenarios):
        for cs_idx, cs in enumerate(conformal_scores):
            savename = f"{sc}_{cs}_discrete_uniform"
            n = read_n_from_summary(savename + ".txt")

            # Arrays in an .npz archive are read from disk on every key
            # access, so fetch each one exactly once and close the archive.
            with np.load(f"{savename}.npz") as in_data:
                pi_true_all = in_data["pi_true"]
                pvals_all = in_data["pvals_all"]
                indicator_all = in_data["indicator_all"]
                pcon_total = in_data["pcon_total"]

            # Distribution of conformal p-values of OOD samples
            # (`indicator_all` is used as a mask selecting them).
            AUC_cod[sc_idx, cs_idx] = outlier_cdf_auc(pvals_all[indicator_all], n)

            # Distribution of conformal data contamination p-values
            for th_pos, pi_th_idx in enumerate(pi_th_indices):
                pi_true_flat = pi_true_all[:, pi_th_idx, :].flatten()
                pcon_flat = pcon_total[:, pi_th_idx, :, :].reshape(-1, pcon_total.shape[-1])
                # The loop variable is deliberately NOT called `pi_idx`: that
                # name holds the configured power-table column.
                for pi_pos, pi in enumerate(pis):
                    ind = np.logical_and(pi_true_flat >= pi - pi_tol,
                                         pi_true_flat < pi + pi_tol)
                    for method_idx, column in enumerate(pcon_columns):
                        AUC_cdct[sc_idx, cs_idx, th_pos, pi_pos, method_idx] = \
                            ecdf_auc(pcon_flat[ind, column])

    # ------------------------------------------------------------------
    # Print the tables
    # ------------------------------------------------------------------
    print("\n AUC COD table\n")
    headers = [f"(S{idx+1:d})" for idx in range(len(scenarios))]
    table = [[f"(C{cs_idx+1:d})"] + [AUC_cod[sc_idx, cs_idx] for sc_idx in range(len(scenarios))]
             for cs_idx in range(len(conformal_scores))]
    print(tabulate(table, headers, tablefmt="latex"))

    # Scenarios are split into groups so that each table stays narrow.
    scenario_groups = [range(start, min(start + scenarios_per_table, len(scenarios)))
                       for start in range(0, len(scenarios), scenarios_per_table)]
    for group in scenario_groups:
        print_cdct_tables("\n AUC CDCT Power table", AUC_cdct, group, names, pi_pos=pi_idx)
    for group in scenario_groups:
        print_cdct_tables("\n AUC CDCT Error table", AUC_cdct, group, names)