#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 29 14:58:55 2025

Load the results generated by ScoringAnalysis.py and print the TDR tables in LaTeX format.
"""


import numpy as np
from tabulate import tabulate

import sys, os
if os.getcwd() not in sys.path: sys.path.append(os.getcwd())

from Benjamini_Hochberg import Storeys_correction, Benjamini_Hochberg_procedure


def read_run_metadata(path):
    """Read the header fields this script needs from a ScoringAnalysis.py ``.txt`` file.

    Parameters
    ----------
    path : str
        Path to the text summary written alongside the matching ``.npz`` file.

    Returns
    -------
    tuple
        ``(K_minus_1, pi_model)``. ``K_minus_1`` is the second space-separated
        token of line 3, parsed as a Python literal. ``pi_model`` is the last
        space-separated token of line 11, with surrounding whitespace removed.

    Raises
    ------
    ValueError
        If line 11 does not start with the field name ``pi``.
    """
    import ast  # local import keeps this helper self-contained

    with open(path, "r") as file:
        lines = file.readlines()
    # literal_eval instead of eval: the file content is parsed, never executed.
    # strip() instead of [:-1]: a line without a trailing newline (e.g. the
    # last line of the file) must not lose its final real character.
    K_minus_1 = ast.literal_eval(lines[2].split(" ")[1].strip())
    pi_fields = lines[10].split(" ")
    if pi_fields[0] != "pi":
        raise ValueError(
            f"{path}: expected line 11 to start with 'pi', got {lines[10]!r}")
    return K_minus_1, pi_fields[-1].strip()


def true_discovery_proportion(reject, is_signal):
    """Return the fraction of flagged hypotheses that were rejected.

    ``reject`` and ``is_signal`` are equal-length boolean arrays. The result is
    ``sum(reject[is_signal]) / sum(is_signal)``, or 0.0 when ``is_signal``
    selects nothing (avoids a 0/0 division).
    """
    num_signal = np.sum(is_signal)
    if num_signal == 0:
        return 0.0
    return np.sum(reject[is_signal]) / num_signal


def compute_tdr(pi_true, pcon_total, K_minus_1, pi_th_arr, significance_levels,
                gamma, method_order=(0, 1, 3, 2)):
    """Average the true discovery proportion over the simulations of one run.

    Parameters
    ----------
    pi_true : ndarray
        Indexed as ``pi_true[sim, pi_th_idx, k]``. Hypothesis ``k`` counts
        toward the TDP denominator when its value exceeds ``pi_th``.
    pcon_total : ndarray
        p-values indexed as ``pcon_total[sim, pi_th_idx, k, method]``.
    K_minus_1 : int
        ``K - 1`` as read from the run's metadata. It is forwarded unchanged to
        ``Storeys_correction`` and ``Benjamini_Hochberg_procedure``.
    pi_th_arr : sequence of float
        Thresholds on ``pi_true`` (second axis of both arrays).
    significance_levels : sequence of float
        Levels passed to the Benjamini-Hochberg procedure.
    gamma : float
        Hyperparameter forwarded to ``Storeys_correction``.
    method_order : sequence of int
        Entry ``j`` of the output's last axis is computed from column
        ``method_order[j]`` of ``pcon_total``.

    Returns
    -------
    ndarray (float32)
        Shape ``(len(pi_th_arr), len(significance_levels), len(method_order))``.
    """
    sims = pi_true.shape[0]
    tdp = np.zeros((sims, len(pi_th_arr), len(significance_levels),
                    len(method_order)), dtype=np.float32)
    for sim in range(sims):
        for pi_th_idx, pi_th in enumerate(pi_th_arr):
            is_signal = pi_true[sim, pi_th_idx, :] > pi_th
            for out_idx, method_idx in enumerate(method_order):
                pvals = pcon_total[sim, pi_th_idx, :, method_idx]
                # Storey's correction takes no alpha argument, so compute it
                # once per p-value vector instead of once per significance level.
                _, K0_hat = Storeys_correction(gamma, pvals, K_minus_1, False)
                for alpha_idx, alpha in enumerate(significance_levels):
                    reject = Benjamini_Hochberg_procedure(pvals, alpha, K0_hat, K_minus_1)
                    tdp[sim, pi_th_idx, alpha_idx, out_idx] = \
                        true_discovery_proportion(reject, is_signal)
    return np.mean(tdp, axis=0)


def print_tdr_tables(TDR, scenario_indices, conformal_scores, names,
                     pi_th_arr, alpha_idx):
    """Print one LaTeX table per conformal score for the given scenarios.

    Each row is a method from ``names``. There is one column per
    (scenario, pi threshold) pair, all taken at significance index ``alpha_idx``.
    """
    scenario_indices = list(scenario_indices)
    labels = [f"(S{idx+1:d})" for idx in scenario_indices]
    print(f"\n TDR table scenarios {', '.join(labels)}")
    # Each scenario label spans len(pi_th_arr) columns (one per threshold).
    headers = [label for label in labels for _ in pi_th_arr]
    for cs_idx in range(len(conformal_scores)):
        print(f"\n(C{cs_idx+1:d})\n")
        table = [[name] + [TDR[sc_idx, cs_idx, pi_th_idx, alpha_idx, method_idx]
                           for sc_idx in scenario_indices
                           for pi_th_idx in range(len(pi_th_arr))]
                 for method_idx, name in enumerate(names)]
        print(tabulate(table, headers, tablefmt="latex", floatfmt=".4f"))


def main():
    """Compute the TDR of every (scenario, conformal score) pair and print the tables."""
    scenarios = ["retina/retina",
                 "retinalOCT_partial/retinalOCT_partial",
                 "WBC_partial/WBC_partial",
                 "FEMNIST_partial/FEMNIST_partial",
                 "MNIST_partial_fn/MNIST_partial_fn",
                 "MNIST_partial_ln/MNIST_partial_ln"]
    # Alternative (full) list of conformal scores:
    # conformal_scores = ["OCSVM",
    #                     "IF",
    #                     "AE",
    #                     "AdaDetect_LR",
    #                     "AdaDetect_SVC",
    #                     "labelOCSVM",
    #                     "labelIF",
    #                     "labelAE",
    #                     "labelAdaDetect_LR",
    #                     "labelAdaDetect_SVC"]
    conformal_scores = ["AdaDetect_SVC",
                        "labelAdaDetect_LR"]
    names = ["Storey", "Quantile", "Fisher", "Sum"]
    # Column of pcon_total (last axis) that feeds each entry of `names`.
    method_order = (0, 1, 3, 2)
    pi_th_arr = [0.0, 0.1, 0.2]
    gamma = 0.5  # Storey's hyperparameter conformal contamination test
    num_rejection_thresholds = 40
    significance_levels = np.linspace(0, 0.975, num_rejection_thresholds)
    alpha_idx = 4  # significance_levels[4] == 0.1 (step 0.975 / 39 = 0.025)
    scenarios_per_table = 3

    TDR = np.zeros((len(scenarios), len(conformal_scores), len(pi_th_arr),
                    num_rejection_thresholds, len(names)), dtype=np.float32)
    for sc_idx, sc in enumerate(scenarios):
        for cs_idx, cs in enumerate(conformal_scores):
            print(f"{sc_idx+1:d}/{len(scenarios):d}, {cs_idx+1:d}/{len(conformal_scores):d}")
            savename = f"{sc}_{cs}_discrete_uniform"
            K_minus_1, pi_model = read_run_metadata(savename + ".txt")
            if pi_model != "discrete_uniform":
                raise ValueError(
                    f"{savename}.txt: expected pi model 'discrete_uniform', got {pi_model!r}")
            # Context manager closes the underlying .npz file handle.
            with np.load(f"{savename}.npz") as in_data:
                pi_true = in_data["pi_true"]
                pcon_total = in_data["pcon_total"]
            TDR[sc_idx, cs_idx] = compute_tdr(pi_true, pcon_total, K_minus_1,
                                              pi_th_arr, significance_levels,
                                              gamma, method_order)

    for start in range(0, len(scenarios), scenarios_per_table):
        stop = min(start + scenarios_per_table, len(scenarios))
        print_tdr_tables(TDR, range(start, stop), conformal_scores, names,
                         pi_th_arr, alpha_idx)


if __name__ == "__main__":
    main()

