import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import brentq


def compute_threshold(alpha, base, power):
    """Return the critical value ``th`` solving ``2 * Phi(-th) = alpha / base**power``.

    ``Phi`` is the standard normal CDF. Both sides are compared in log-space
    (via ``norm.logcdf``) so that extremely small corrected levels, e.g.
    ``alpha / 2**4096``, do not underflow to zero. The root is searched for
    with Brent's method on the bracket ``[0, 1000]``.

    Args:
        alpha: Nominal significance level.
        base: Base of the correction factor.
        power: Exponent of the correction factor; the corrected level is
            ``alpha / base**power``.

    Returns:
        The threshold ``th`` as a float.

    Raises:
        ValueError: from ``brentq`` when the bracket contains no sign change,
            e.g. when the corrected level is >= 1.
    """
    # log(alpha / base**power), independent of th, so evaluate it only once.
    log_corrected_alpha = np.log(alpha) - np.log(base) * power

    def log_gap(candidate):
        # log(2 * Phi(-candidate)) - log(corrected alpha); decreasing in candidate.
        return np.log(2.0) + norm.logcdf(-candidate) - log_corrected_alpha

    return brentq(log_gap, 0.0, 1000.0)

def plot_fpr(alpha=0.05):
    """Plot the empirical type I error rate of each test against image size.

    For each category ("iid", "corr") and each image side length in
    ``image_sizes``, results are pooled from the pickles
    ``../results/{category}/fpr/{category}_size{S}_signal0_seed{seed}.pickle``
    for seeds 0-9. The first 100 entries are taken from each pickle, capped at
    1000 overall. Missing pickles are reported and skipped. The fraction of
    rejections at level ``alpha`` is plotted for every method. One figure per
    category is saved as ``figures/syn/{category}_fpr_alpha{alpha}.pdf``.

    Args:
        alpha: Significance level. It is used for the p-value based tests,
            for the Bonferroni threshold and for the dashed reference line.

    Returns:
        dict mapping each category to a list (one entry per image size) of
        the mean of the pooled ``"time"`` values read from the pickles.
    """
    image_sizes = [8, 16, 32, 64]
    # One x position per image size; ticks are labelled with the pixel count.
    x_positions = list(range(1, len(image_sizes) + 1))
    x_labels = [str(size**2) for size in image_sizes]
    result_dict = {}

    def fpr(p_values):
        # Rejection rate at level alpha (nan, with a numpy warning, if empty).
        return np.sum(np.asarray(p_values) < alpha) / len(p_values)

    # savefig does not create missing directories.
    os.makedirs("figures/syn", exist_ok=True)

    for category in ["iid", "corr"]:
        naive_fpr_list, selective_fpr_list, oc_fpr_list = [], [], []
        bonferroni_fpr_list, permutation_fpr_list, calc_t_list = [], [], []
        for image_size in image_sizes:
            naive_p_values, selective_p_values, oc_p_values = [], [], []
            z_list, permutation_p_values, calc_t = [], [], []
            for seed in range(10):
                pkl_path = f"../results/{category}/fpr/{category}_size{image_size}_signal0_seed{seed}.pickle"
                if not os.path.exists(pkl_path):
                    print("pickle not found")
                    continue

                # NOTE: pickle.load can execute arbitrary code; only load
                # result files produced by a trusted experiment run.
                with open(pkl_path, "rb") as f:
                    result_pickle = pickle.load(f)
                naive_p_values.extend(result_pickle["naive_p_values"][:100])
                selective_p_values.extend(result_pickle["selective_p_values"][:100])
                oc_p_values.extend(result_pickle["oc_p_values"][:100])
                z_list.extend(result_pickle["z"][:100])
                permutation_p_values.extend(result_pickle["permutation_p_values"][:100])
                calc_t.extend(result_pickle["time"][:100])

            # Bonferroni baseline: reject when z exceeds the critical value of
            # compute_threshold at level alpha / 2**(image_size**2).
            # NOTE(review): the threshold is two-sided (factor 2 in
            # compute_threshold) but z is compared one-sided; confirm whether
            # abs(z) is intended.
            threshold = compute_threshold(alpha, 2, image_size**2)

            # Pooled samples are capped at 1000 entries (10 seeds x 100).
            naive_fpr_list.append(fpr(naive_p_values[:1000]))
            selective_fpr_list.append(fpr(selective_p_values[:1000]))
            oc_fpr_list.append(fpr(oc_p_values[:1000]))
            bonferroni_fpr_list.append(np.mean(np.asarray(z_list[:1000]) > threshold))
            permutation_fpr_list.append(fpr(permutation_p_values[:1000]))
            calc_t_list.append(np.mean(calc_t))
        result_dict[category] = calc_t_list

        plt.rcParams["font.size"] = 14
        fig = plt.figure(figsize=(4, 4), dpi=100)
        ax = fig.add_subplot(111)
        # Dashed reference line at the nominal level alpha.
        ax.plot(x_positions, [alpha] * len(x_positions), linestyle='dashed', color='black')
        ax.plot(x_positions, selective_fpr_list, "x-", label="DMAD-test", zorder=10)
        ax.plot(x_positions, oc_fpr_list, "x-", label="DMAD-test-oc", zorder=9)
        ax.plot(x_positions, bonferroni_fpr_list, "x-", label="bonferroni", zorder=8)
        ax.plot(x_positions, permutation_fpr_list, "x-", label="permutation", zorder=7)
        ax.plot(x_positions, naive_fpr_list, "x-", label="naive", zorder=6)
        ax.grid()
        ax.set_yticks(np.arange(0.0, 0.51, 0.1))

        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels)
        ax.set_xlabel('Image size')
        ax.set_ylabel('Type I Error Rate')
        ax.legend(loc='upper left')
        fig.savefig(f"figures/syn/{category}_fpr_alpha{alpha}.pdf", bbox_inches="tight", pad_inches=0.0)
        # Release the figure; otherwise each call leaks one figure per category.
        plt.close(fig)

    return result_dict




def plot_power(image_size=64, alpha=0.05):
    """Plot the empirical power of the DMAD and Bonferroni tests against signal strength.

    For each category ("iid", "corr") and each signal level in ``signals``,
    results are pooled from the pickles
    ``../results/{category}/power/{category}_size{image_size}_signal{signal}_seed{seed}.pickle``
    for seeds 0-9. The first 100 entries are taken from each pickle, capped at
    1000 overall. Missing pickles are reported and skipped. One figure per
    category is saved as ``figures/syn/{category}_power_alpha{alpha}.pdf``.

    Args:
        image_size: Image side length used in the result file names. It also
            sets the Bonferroni correction exponent ``image_size**2``.
        alpha: Significance level for the p-value based tests and for the
            Bonferroni threshold.
    """
    exp = "power"
    signals = [1, 2, 3, 4]
    x_positions = list(range(1, len(signals) + 1))

    def reject_rate(p_values):
        # Fraction of p-values below alpha (nan, with a numpy warning, if empty).
        return np.sum(np.asarray(p_values) < alpha) / len(p_values)

    # Bonferroni critical value for level alpha / 2**(image_size**2). It does
    # not depend on category or signal, so compute it once.
    # NOTE(review): this threshold is two-sided (factor 2 in compute_threshold)
    # but z is compared one-sided below; confirm whether abs(z) is intended.
    threshold = compute_threshold(alpha, 2, image_size**2)

    # savefig does not create missing directories.
    os.makedirs("figures/syn", exist_ok=True)

    for category in ["iid", "corr"]:
        selective_reject_rate_list, oc_reject_rate_list, bonferroni_reject_rate_list = [], [], []
        for signal in signals:
            selective_p_values, oc_p_values, z_list = [], [], []
            for seed in range(10):
                pkl_path = f"../results/{category}/{exp}/{category}_size{image_size}_signal{signal}_seed{seed}.pickle"
                if not os.path.exists(pkl_path):
                    print("pickle not found")
                    continue

                # NOTE: pickle.load can execute arbitrary code; only load
                # result files produced by a trusted experiment run.
                with open(pkl_path, "rb") as f:
                    result_pickle = pickle.load(f)
                selective_p_values.extend(result_pickle["selective_p_values"][:100])
                oc_p_values.extend(result_pickle["oc_p_values"][:100])
                z_list.extend(result_pickle["z"][:100])

            # Pooled samples are capped at 1000 entries (10 seeds x 100).
            selective_reject_rate_list.append(reject_rate(selective_p_values[:1000]))
            oc_reject_rate_list.append(reject_rate(oc_p_values[:1000]))
            bonferroni_reject_rate_list.append(np.mean(np.asarray(z_list[:1000]) > threshold))

        plt.rcParams["font.size"] = 14
        fig = plt.figure(figsize=(4, 4), dpi=100)
        ax = fig.add_subplot(111)

        ax.plot(x_positions, selective_reject_rate_list, "x-", label="DMAD-test")
        ax.plot(x_positions, oc_reject_rate_list, "x-", label="DMAD-test-oc")
        ax.plot(x_positions, bonferroni_reject_rate_list, "x-", label="bonferroni")
        ax.grid()
        ax.set_yticks(np.arange(0, 1.1, 0.1))

        ax.set_xticks(x_positions)
        ax.set_xticklabels([str(signal) for signal in signals])
        ax.set_xlabel(r"$\Delta$")
        ax.set_ylabel('Power')
        ax.legend(loc='upper left')
        fig.savefig(f"figures/syn/{category}_{exp}_alpha{alpha}.pdf", bbox_inches="tight", pad_inches=0.0)
        # Release the figure; otherwise each call leaks one figure per category.
        plt.close(fig)


def plot_robust():
    """Plot the DMAD-test type I error rate under non-Gaussian noise.

    For each noise family in ``categories`` and each Wasserstein distance in
    ``distances``, the function pools ``"selective_p_values"`` from
    ``../results/robust/fpr/{category}_size16_distance{d}_seed{seed}.pickle``.
    It uses seeds 0-9 and takes the first 100 values of each, keeping at most
    1000. It then plots the rejection rate at every level in ``alphas``
    against the distance, with a dashed line at each nominal level. One PDF
    per family is written to ``figures/syn/robust_{category}.pdf``.
    """
    image_size = 16
    categories = ["SND", "EMG", "GND", "t-distribution"]
    alphas = [0.05, 0.1]
    distances = [0.01, 0.02, 0.03, 0.04]

    for category in categories:
        rates_by_alpha = {level: [] for level in alphas}

        for distance in distances:
            pooled = []
            for seed in range(10):
                path = f"../results/robust/fpr/{category}_size{image_size}_distance{distance}_seed{seed}.pickle"
                if os.path.exists(path):
                    with open(path, "rb") as handle:
                        pooled.extend(pickle.load(handle)["selective_p_values"][:100])
                else:
                    print(f"Pickle not found for seed {seed}")
            pooled = np.array(pooled[:1000])

            # Rejection rate at each level (nan, with a numpy warning, if nothing was loaded).
            for level in alphas:
                rates_by_alpha[level].append(np.sum(np.array(pooled) < level) / len(pooled))

        plt.rcParams["font.size"] = 14
        fig = plt.figure(figsize=(4, 4), dpi=100)
        ax = fig.add_subplot(111)

        for level, rates in rates_by_alpha.items():
            if rates:
                # Dashed reference at the nominal level, then the observed rates.
                ax.plot(distances, [level] * 4, linestyle='dashed', color='black')
                ax.plot(distances, rates, '-x', label=rf"$\alpha$ = {level}")
        ax.set_yticks(np.arange(0.0, 0.51, 0.1))
        ax.set_xticks(distances)
        ax.set_xticklabels([str(d) for d in distances])
        ax.grid()
        ax.set_xlabel('Wasserstein Distance')
        ax.set_ylabel('Type I Error Rate')
        plt.legend(loc='upper left')
        plt.savefig(f"figures/syn/robust_{category}.pdf", bbox_inches="tight", pad_inches=0.0)
        plt.close()


if __name__ == "__main__":
    # Regenerate every synthetic-experiment figure in turn:
    # type I error, power, then robustness.
    for make_figures in (plot_fpr, plot_power, plot_robust):
        make_figures()
