#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 23 10:59:48 2025

Load results generated by ScoringAnalysis.py and generate plots/table.
"""

import ast

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ecdf
from sklearn.metrics import auc


def _read_run_config(path):
    """
    Parse the plain-text run description written next to the ``.npz`` results.

    The file is read by fixed line position (0-based): line 2 holds ``K - 1``,
    line 4 holds ``n``, line 5 holds ``m`` and line 10 holds the pi model name.
    Numeric values are taken from the second space-separated token of their
    line with its final character dropped; the model name is the last token
    of line 10 with its final character dropped.

    Parameters
    ----------
    path : str
        Path of the text file.

    Returns
    -------
    dict
        Keys ``"K"``, ``"n"``, ``"m"`` (parsed Python literals, with
        ``K = (K - 1) + 1``) and ``"pi_model"`` (str).

    Raises
    ------
    ValueError
        If the file is too short, the expected line labels are missing, or a
        value is not a plain Python literal.
    """
    with open(path, "r") as file:
        lines = file.readlines()

    # Validate the layout *before* trusting any positional field.
    if len(lines) < 11:
        raise ValueError(f"{path}: expected at least 11 lines, found {len(lines)}")
    if not lines[4].startswith("n"):
        raise ValueError(f"{path}: line 5 should start with 'n', got {lines[4]!r}")
    if lines[10].split(" ")[0] != "pi":
        raise ValueError(f"{path}: line 11 should start with 'pi', got {lines[10]!r}")

    def literal_field(line):
        fields = line.split(" ")
        if len(fields) < 2:
            raise ValueError(f"{path}: no value found on line {line!r}")
        token = fields[1][:-1]
        try:
            # literal_eval only accepts literals, so text from the file can
            # never be executed as code (unlike eval).
            return ast.literal_eval(token)
        except (ValueError, SyntaxError) as err:
            raise ValueError(f"{path}: cannot parse value {token!r}") from err

    return {
        "K": literal_field(lines[2]) + 1,
        "n": literal_field(lines[4]),
        "m": literal_field(lines[5]),
        "pi_model": lines[10].split(" ")[-1][:-1],
    }


def _grid_cdf(values, n):
    """
    Empirical CDF of ``values`` evaluated on the grid ``j/(n+1)``, ``j = 1..n+1``.

    Each grid point receives the number of entries of ``values`` that are
    ``np.isclose`` to it; the counts are normalised to sum to one and
    accumulated. Entries that match no grid point are ignored. If nothing
    matches at all, the normalisation divides by zero and the result is NaN.

    Parameters
    ----------
    values : array_like
        Values to bin.
    n : int
        Grid parameter; the grid has ``n + 1`` points.

    Returns
    -------
    numpy.ndarray
        Array of length ``n + 1`` with the cumulative probabilities.
    """
    grid = np.arange(1, n + 2) / (n + 1)
    mass = np.array(
        [np.count_nonzero(np.isclose(values, point)) for point in grid],
        dtype=float,
    )
    return np.cumsum(mass / np.sum(mass))


if __name__ == "__main__":
    plt.style.use("seaborn-v0_8-whitegrid")
    fsize = (9.6, 5.76)
    mpl.rcParams['figure.dpi'] = 600
    color_list = ["tab:red", "tab:olive", "tab:purple", "tab:orange"]
    name_list = ["Storey", "Quantile", "Summation", "Fisher"]
    pi_th_idx = 1
    pi = 0.1

    savename = "retina/retina_AdaDetect_SVC_discrete_uniform"

    config = _read_run_config(savename + ".txt")
    n = config["n"]
    pi_model = config["pi_model"]

    # NpzFile keeps the archive open; pull out what is needed and close it.
    with np.load(f"{savename}.npz") as in_data:
        pi_true = in_data["pi_true"]
        # Stored values are rescaled by 1/(n+1)
        # (presumably integer ranks -> p-values; confirm against ScoringAnalysis.py).
        pvals_all = in_data["pvals_all"] / (n + 1)
        indicator_all = in_data["indicator_all"].astype(bool)
        pcon_total = in_data["pcon_total"]

    # =============================================================================
    # Distribution of conformal p-values of OOD samples
    # =============================================================================
    outliers = pvals_all[indicator_all]
    outlier_cdf = _grid_cdf(outliers, n)
    grid = np.arange(1, n + 2) / (n + 1)

    _ = plt.figure(figsize=fsize)
    plt.plot(grid, outlier_cdf, color="tab:blue")
    plt.plot(grid, grid, color="k")  # diagonal reference line
    plt.ylabel("Cumulative distribution function")
    plt.xlabel("Conformal p-value")
    plt.show()

    # Sum of the n+1 CDF values divided by n+2.
    AUC = np.sum(outlier_cdf) / (n + 2)
    print(f"{AUC:.4f}")

    # =============================================================================
    # Distribution of conformal data contamination p-values
    # =============================================================================
    if pi_model == "discrete_uniform":
        # Restrict to the selected threshold index and flatten the rest.
        pi_true_flat = pi_true[:, pi_th_idx, :].flatten()
        pcon_flat = pcon_total[:, pi_th_idx, :, :].reshape(-1, 4)

        _ = plt.figure(figsize=fsize)
        # Keep only entries whose pi_true lies within +/- 0.001 of ``pi``.
        ind = np.logical_and(pi_true_flat >= pi - 0.001, pi_true_flat < pi + 0.001)
        auc_data = np.zeros(4, dtype=np.float32)
        for idx2, (name, color) in enumerate(zip(name_list, color_list)):
            # Build the ECDF once and reuse it for plotting and for the area.
            cdf = ecdf(pcon_flat[ind, idx2]).cdf
            cdf.plot(linestyle="solid", color=color, label=name)
            # Pad with (0, 0) and (1, 1) before integrating.
            quantiles = np.hstack((0, cdf.quantiles, 1))
            probabilities = np.hstack((0, cdf.probabilities, 1))
            auc_data[idx2] = auc(quantiles, probabilities)

        plt.plot(np.linspace(0, 1, 11), np.linspace(0, 1, 11), "-k")
        plt.ylabel("Cumulative distribution function")
        plt.xlabel("Conformal contamination test p-value")
        plt.legend()
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.show()

        print(", ".join(f"{value:.4f}" for value in auc_data))

