#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  2 14:01:37 2025

Load results generated by ScoringAnalysis.py and generate big FDR table.
"""


import numpy as np
from tabulate import tabulate

import sys, os
if os.getcwd() not in sys.path: sys.path.append(os.getcwd())

from Benjamini_Hochberg import Storeys_correction, Benjamini_Hochberg_procedure


def false_discovery_proportion(reject, is_null):
    """
    Return the false discovery proportion of a set of rejections.

    Parameters
    ----------
    reject : array_like
        Rejection indicator per hypothesis (truthy entries are rejections).
    is_null : array_like of bool
        Mask marking the hypotheses that count as true nulls.

    Returns
    -------
    The number of rejected true nulls divided by the total number of
    rejections, or 0 when nothing is rejected (avoids dividing by zero).
    """
    reject = np.asarray(reject)
    is_null = np.asarray(is_null, dtype=bool)
    num_rejected = np.sum(reject)
    if num_rejected == 0:
        return 0
    return np.sum(reject[is_null]) / num_rejected


def print_fdr_tables(FDR, conformal_scores, names, pi_th_arr, alpha_idx, scenario_indices):
    """
    Print one LaTeX FDR table per conformal score for a group of scenarios.

    Parameters
    ----------
    FDR : ndarray
        Indexed as FDR[scenario, conformal score, pi_th, alpha, test].
    conformal_scores : list of str
        Score names; one table is printed for each.
    names : list of str
        Row labels, one per test (last axis of FDR).
    pi_th_arr : sequence
        Thresholds; each scenario contributes one column per threshold.
    alpha_idx : int
        Index of the significance level to report.
    scenario_indices : iterable of int
        Zero-based scenario indices to include; labelled (S1), (S2), ...
    """
    scenario_indices = list(scenario_indices)
    labels = ", ".join(f"(S{idx+1:d})" for idx in scenario_indices)
    print(f"\n FDR table scenarios {labels}")
    # Headers cover only the value columns (one per scenario/threshold pair);
    # tabulate is relied on to leave the leading row-name column unlabeled.
    headers = [f"(S{idx+1:d})" for idx in scenario_indices for _ in pi_th_arr]
    for cs_idx, cs in enumerate(conformal_scores):
        print(f"\n{cs}\n")
        table = [[name] + [FDR[sc_idx, cs_idx, pi_th_idx, alpha_idx, name_idx]
                           for sc_idx in scenario_indices
                           for pi_th_idx in range(len(pi_th_arr))]
                 for name_idx, name in enumerate(names)]
        print(tabulate(table, headers, tablefmt="latex", floatfmt=".4f"))


if __name__ == "__main__":
    # Each entry is "<directory>/<file prefix>" of the stored results.
    scenarios = ["retina/retina",
                 "retinalOCT_partial/retinalOCT_partial",
                 "WBC_partial/WBC_partial",
                 "FEMNIST_partial/FEMNIST_partial",
                 "MNIST_partial_fn/MNIST_partial_fn",
                 "MNIST_partial_ln/MNIST_partial_ln"]
    # Alternative, longer selection of conformal scores (kept for reference):
    # conformal_scores = ["OCSVM",
    #                     "IF",
    #                     "AE",
    #                     "AdaDetect_LR",
    #                     "AdaDetect_SVC",
    #                     "labelOCSVM",
    #                     "labelIF",
    #                     "labelAE",
    #                     "labelAdaDetect_LR",
    #                     "labelAdaDetect_SVC"]
    conformal_scores = ["AdaDetect_SVC",
                        "labelAdaDetect_LR"]
    contamination_sim = "fixed_balanced"
    names = ["Storey", "Quantile", "Fisher", "Sum"]
    # Column of the last axis of pcon_total read for each entry of `names`
    # (note that the last two columns are swapped relative to `names`).
    pval_columns = [0, 1, 3, 2]
    pi_th_arr = [0.0, 0.1, 0.2]
    # Index into significance_levels reported in the tables;
    # linspace(0, 0.975, 40) has step 0.025, so index 4 is alpha = 0.1.
    alpha_idx = 4
    gamma = 0.5 # Storey's hyperparameter conformal contamination test
    num_rejection_thresholds = 40
    significance_levels = np.linspace(0, 0.975, num_rejection_thresholds)
    scenarios_per_table = 3

    FDR = np.zeros((len(scenarios), len(conformal_scores), len(pi_th_arr),
                    num_rejection_thresholds, len(names)), dtype=np.float32)
    for sc_idx, sc in enumerate(scenarios):
        for cs_idx, cs in enumerate(conformal_scores):
            print(f"{sc_idx+1:d}/{len(scenarios):d}, {cs_idx+1:d}/{len(conformal_scores):d}")
            savename = f"{sc}_{cs}_{contamination_sim}"
            with open(savename+".txt", "r") as file:
                lines = file.readlines()
            # K-1 is the second space-separated token on the third line, with
            # its final character dropped.
            # NOTE(review): eval() executes whatever is stored in the file. It
            # is kept because the exact token format is not visible here;
            # switch to int()/ast.literal_eval() once the format is confirmed.
            Kminus1 = eval(lines[2].split(" ")[1][:-1])
            K = Kminus1 + 1
            # Close the archive as soon as the two arrays have been read.
            with np.load(f"{savename}.npz") as in_data:
                pi_true = in_data["pi_true"]
                pcon_total = in_data["pcon_total"]
            sims = pi_true.shape[0]

            FDP_total = np.zeros((sims, len(pi_th_arr), num_rejection_thresholds, len(names)),
                                 dtype=np.float32)
            for sim in range(sims):
                for pi_th_idx, pi_th in enumerate(pi_th_arr):
                    # Hypotheses with pi_k <= pi_th count as true nulls.
                    pi_k = pi_true[sim, pi_th_idx, :]
                    is_null = pi_k <= pi_th
                    for name_idx, column in enumerate(pval_columns):
                        pvals = pcon_total[sim, pi_th_idx, :, column]
                        # Storey's estimate takes no alpha, so compute it once
                        # per p-value vector rather than once per significance
                        # level (assumes the imported procedures are pure).
                        _, K0_hat = Storeys_correction(gamma, pvals, K-1, False)
                        for idx_alpha, alpha in enumerate(significance_levels):
                            reject = Benjamini_Hochberg_procedure(pvals, alpha, K0_hat, K-1)
                            # Alternative with a fixed null count instead of Storey's estimate:
                            # reject = Benjamini_Hochberg_procedure(pvals, alpha, int(K-1)/2, K-1)
                            FDP_total[sim, pi_th_idx, idx_alpha, name_idx] = \
                                false_discovery_proportion(reject, is_null)
            # FDR is the average FDP over the simulation runs.
            FDR[sc_idx, cs_idx] = np.mean(FDP_total, axis=0)

    # One group of tables per chunk of scenarios: (S1)-(S3), (S4)-(S6), ...
    for first in range(0, len(scenarios), scenarios_per_table):
        group = range(first, min(first + scenarios_per_table, len(scenarios)))
        print_fdr_tables(FDR, conformal_scores, names, pi_th_arr, alpha_idx, group)

