#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  2 08:16:29 2025

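Load the results generated by ScoringAnalysis.py and, for each scenario, plot
the empirical CDFs of the outliers' conformal p-values for every conformal
score against the CDF of the discrete uniform distribution (the y = x line).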
"""

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np


def read_header(path):
    """Read the calibration-set size ``n`` and the pi model from a summary file.

    The summary is the ``.txt`` file that accompanies each ``.npz`` archive.
    Line 5 (index 4) must start with ``n`` and carry the integer size as its
    second whitespace-separated field. Line 11 (index 10) must start with the
    token ``pi`` and carry the pi-model name as its last field.

    Parameters
    ----------
    path : str or os.PathLike
        Path to the summary text file.

    Returns
    -------
    tuple[int, str]
        ``(n, pi_model)``.

    Raises
    ------
    ValueError
        If the file has fewer than 11 lines or either header line is malformed.
    """
    with open(path, "r", encoding="utf-8") as file:
        lines = file.readlines()
    if len(lines) < 11:
        raise ValueError(f"{path}: expected at least 11 lines, got {len(lines)}")
    # str.split() drops the trailing newline without chopping a real character
    # when the last line has no newline.
    n_fields = lines[4].split()
    pi_fields = lines[10].split()
    # Validate the headers *before* interpreting their values.
    if len(n_fields) < 2 or not n_fields[0].startswith("n"):
        raise ValueError(f"{path}: line 5 is not an 'n' header: {lines[4]!r}")
    if not pi_fields or pi_fields[0] != "pi":
        raise ValueError(f"{path}: line 11 is not a 'pi' header: {lines[10]!r}")
    # int() rather than eval(): never execute text read from a file.
    n = int(n_fields[1])
    return n, pi_fields[-1]


def discrete_outlier_cdf(pvals, n):
    """Return the empirical CDF of p-values on the grid ``j/(n+1)``, ``j = 1..n+1``.

    Parameters
    ----------
    pvals : array_like
        Conformal p-values, already divided by ``n + 1``.
    n : int
        Calibration-set size; fixes the grid.

    Returns
    -------
    numpy.ndarray
        Length ``n + 1`` array. Entry ``j - 1`` is the fraction of grid-valued
        p-values that are ``<= j/(n+1)``. Values that do not lie on the grid
        (within ``np.isclose`` tolerance) are ignored. If no value lies on the
        grid, the result is all-NaN.
    """
    pvals = np.asarray(pvals, dtype=float).ravel()
    # Snap each p-value to its nearest grid index in one vectorised pass,
    # instead of n+1 separate isclose scans; each value is counted at most once.
    ranks = np.rint(pvals * (n + 1))
    on_grid = ((ranks >= 1) & (ranks <= n + 1)
               & np.isclose(pvals, ranks / (n + 1)))
    mass = np.bincount(ranks[on_grid].astype(int) - 1,
                       minlength=n + 1).astype(float)
    total = mass.sum()
    if total == 0:
        # Nothing to normalise by; make the "no data" case explicit instead of
        # producing NaNs through a 0/0 RuntimeWarning.
        return np.full(n + 1, np.nan)
    return np.cumsum(mass / total)


if __name__ == "__main__":
    import warnings

    plt.style.use("seaborn-v0_8-whitegrid")
    fsize = (9.6, 5.76)
    mpl.rcParams['figure.dpi'] = 600

    scenarios = ["retina/retina",
                 "retinalOCT_partial/retinalOCT_partial",
                 "WBC_partial/WBC_partial",
                 "FEMNIST_partial/FEMNIST_partial",
                 "MNIST_partial_fn/MNIST_partial_fn",
                 "MNIST_partial_ln/MNIST_partial_ln"]
    conformal_scores = ["OCSVM",
                        "IF",
                        "AE",
                        "AdaDetect_LR",
                        "AdaDetect_SVC",
                        "labelOCSVM",
                        "labelIF",
                        "labelAE",
                        "labelAdaDetect_LR",
                        "labelAdaDetect_SVC"]

    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
              "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]

    for sc in scenarios:
        plt.figure(figsize=fsize)
        plt.title(sc)
        for cs_idx, cs in enumerate(conformal_scores):
            savename = f"{sc}_{cs}_discrete_uniform"
            n, pi_model = read_header(savename + ".txt")
            # Explicit check instead of assert, which is stripped under -O.
            if pi_model != "discrete_uniform":
                raise ValueError(f"{savename}.txt: expected pi model "
                                 f"'discrete_uniform', got {pi_model!r}")
            # NpzFile keeps the archive open; the context manager closes it.
            with np.load(f"{savename}.npz") as in_data:
                pvals_all = in_data["pvals_all"] / (n + 1)
                indicator_all = in_data["indicator_all"].astype(bool)

            outlier_cdf = discrete_outlier_cdf(pvals_all[indicator_all], n)
            if np.isnan(outlier_cdf).all():
                warnings.warn(f"{savename}: no outlier p-values on the grid; "
                              f"curve for {cs} skipped.")
                continue

            grid = np.arange(1, n + 2) / (n + 1)
            plt.plot(grid, outlier_cdf, color=colors[cs_idx], label=cs)

        # Reference: on the grid j/(n+1) the discrete-uniform CDF is the
        # identity. Uses n from the last score read for this scenario.
        grid = np.arange(1, n + 2) / (n + 1)
        plt.plot(grid, grid, color="k")
        plt.legend()
        plt.ylabel("Cumulative distribution function")
        plt.xlabel("Conformal p-value")
        plt.show()

            