import os
import sys
import logging
import time

from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import torch
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt

from statsmodels.stats.multitest import fdrcorrection

from pt_fmvhg import pmf_noncentral_fmvhg
from pt_fmvhg import pmf_noncentral_fmvhg_rsample_w
from pt_fmvhg import get_logits_fisher


def test_fisher_2c(rvs, m1, m2, n, w):
    """Run a one-sample KS test against Fisher's noncentral hypergeometric.

    Tests the samples ``rvs`` against the CDF of
    ``scipy.stats.nchypergeom_fisher`` for a two-class population.

    Args:
        rvs: 1-D array-like of observed counts drawn from the first class.
        m1: Number of elements in the first class (SciPy's ``n``).
        m2: Number of elements in the second class.
        n: Number of draws (SciPy's ``N``).
        w: Odds of the first class relative to the second (SciPy's ``odds``).

    Returns:
        The result object of ``scipy.stats.kstest`` (statistic and p-value).
    """
    m1_int = int(m1)
    M = m1_int + int(m2)
    # ``kstest`` only forwards distribution parameters through ``args``; its
    # own ``N`` keyword is a sample size for callable ``rvs`` and must not be
    # confused with the number of draws.  Order follows the shape parameters
    # of ``nchypergeom_fisher``: (M, n, N, odds).
    return stats.kstest(rvs, "nchypergeom_fisher", args=(M, m1_int, n, w))


# Despite its ``test_`` prefix this is a helper taking plain arguments, not a
# pytest test; opt out of collection so pytest does not look for fixtures
# named after its parameters.
test_fisher_2c.__test__ = False


def test_dist_mvnchypg(m_all, n, w_all, temperature, dist_type, n_rep, device):
    """Compare the proposed multivariate sampler with a SciPy-based reference.

    Draws ``n_rep`` samples with ``pmf_noncentral_fmvhg_rsample_w`` and builds
    ``n_rep`` reference samples by chaining univariate
    ``scipy.stats.nchypergeom_fisher`` draws class by class.  For every class
    the two sample sets are compared with a two-sample Kolmogorov-Smirnov
    test, and PMF/CDF and histogram figures are written as PNG files to
    ``./results`` (created if missing).

    Args:
        m_all: 1-D tensor; ``m_all[c]`` is used as the size of class ``c``.
        n: Tensor holding the number of draws.  It is unsqueezed and repeated
            to ``n_rep`` rows (presumably a 0-d tensor, as passed by the
            script below -- confirm for other callers).
        w_all: 1-D tensor of per-class weights; its logarithm is passed to the
            sampler.
        temperature: Forwarded unchanged to the proposed sampler.
        dist_type: Only ``"fisher"`` is supported; any other value prints a
            message and terminates via ``sys.exit()``.
        n_rep: Number of samples drawn from each implementation.
        device: Forwarded unchanged to the proposed sampler.

    Returns:
        A list of four two-element lists:
        ``[[ref_hists, test_hists], [ref_cdfs, test_cdfs],
        [ref_pmfs, test_pmfs], [ks_values, corr]]``.  The first three hold one
        array per class; ``ks_values`` holds the per-class KS statistics and
        ``corr`` is the return value of ``fdrcorrection`` applied to the
        per-class p-values.
    """
    if dist_type == "fisher":
        logits_func = get_logits_fisher
        scipy_func = stats.nchypergeom_fisher.rvs
    else:
        print("distribution type unavailable...exit")
        sys.exit()

    # ---- proposed implementation ------------------------------------------
    start_time = time.time()
    n_classes = m_all.shape[0]
    # One row per requested sample.
    n = n.unsqueeze(0).repeat(n_rep, 1)
    w_all = w_all.unsqueeze(0).repeat(n_rep, 1)
    rvs_all_classes = np.zeros((n_rep, n_classes))
    rvs_all_classes_ref = np.zeros((n_rep, n_classes))
    # Alternative sampler taking the same arguments:
    # y, x, _ = pmf_noncentral_fmvhg(
    #     m_all, n, torch.log(w_all), temperature, logits_func, device=device
    # )
    _, x, _ = pmf_noncentral_fmvhg_rsample_w(
        m_all, n, torch.log(w_all), temperature, logits_func, device=device
    )
    time_test = time.time() - start_time
    print("our time: ", time_test)
    for c in range(n_classes):
        rvs_all_classes[:, c] = x[c].squeeze().cpu().numpy()

    # ---- scipy: reference implementation ----------------------------------
    # Classes are sampled one after another: class c against the merged
    # remainder c+1.., using the draws that are still left.
    start_time = time.time()
    n_out = np.zeros((n_rep, 1))  # draws already assigned, per sample
    for c in range(n_classes - 1):
        m_i = m_all[c].squeeze(0).repeat(n_rep, 1)
        m_i = m_i.cpu().numpy()
        m_rest = m_all[c + 1 :].sum().squeeze(0).repeat(n_rep, 1)
        m_rest = m_rest.cpu().numpy()
        n_i = n.cpu().numpy() - n_out
        w_i = w_all[:, c].cpu().numpy()
        # Weight of the merged remainder: size-weighted mean of its weights.
        w_rest_enum = (m_all[c + 1 :] * w_all[:, c + 1 :]).sum(dim=1, keepdims=True)
        w_rest_enum = w_rest_enum.cpu().numpy()
        w_rest_denom = m_rest
        w_rest = (w_rest_enum / w_rest_denom).flatten()
        w = w_i / w_rest
        M = m_i + m_rest
        x_i = scipy_func(M.flatten(), m_i.flatten(), n_i.flatten(), w, size=n_rep)
        n_out += np.expand_dims(x_i, axis=1)
        rvs_all_classes_ref[:, c] = x_i
    # The last class receives whatever draws remain.
    rvs_all_classes_ref[:, -1] = (n.cpu().numpy() - n_out).flatten()
    time_ref = time.time() - start_time
    print("ref time: ", time_ref)

    # ---- per-class comparison and plots -----------------------------------
    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)
    str_ws = [str(w_j) for w_j in list(w_all[0].cpu().numpy().flatten())]
    str_weights = "_".join(str_ws)
    pre_fn_hist = "./results/hist_ks_test_mvnchypg_" + dist_type + "_" + str_weights
    pre_fn_cdf = "./results/cdf_ks_test_mvnchypg_" + dist_type + "_" + str_weights
    ref_hists = []
    test_hists = []
    ref_cdfs = []
    test_cdfs = []
    ref_pmfs = []
    test_pmfs = []
    ks_values = []
    p_values = []

    for c in range(n_classes):
        rvs_test_i = rvs_all_classes[:, c]
        rvs_ref_i = rvs_all_classes_ref[:, c]
        m_i = int(m_all[c].cpu().numpy())
        # One bin per possible count 0..m_i.
        hist_test_i, _ = np.histogram(
            rvs_test_i, bins=(m_i + 1), range=(0.0, m_i), density=True
        )
        hist_ref_i, _ = np.histogram(
            rvs_ref_i, bins=(m_i + 1), range=(0.0, m_i), density=True
        )
        ref_hists.append(hist_ref_i)
        test_hists.append(hist_test_i)

        # Normalise each histogram by its own total to obtain a PMF, then
        # accumulate it into a CDF.
        pmf_ref_i = hist_ref_i / hist_ref_i.sum()
        cdf_ref_i = np.cumsum(pmf_ref_i)
        pmf_test_i = hist_test_i / hist_test_i.sum()
        cdf_test_i = np.cumsum(pmf_test_i)
        ref_cdfs.append(cdf_ref_i)
        test_cdfs.append(cdf_test_i)
        ref_pmfs.append(pmf_ref_i)
        test_pmfs.append(pmf_test_i)

        # Two-sample KS test between proposed and reference samples.
        res_i = stats.kstest(rvs_test_i, rvs_ref_i)
        ks_val = res_i[0]
        p_val = res_i[1]
        ks_values.append(ks_val)
        p_values.append(p_val)
        title_str = "class " + str(c).zfill(2) + ": k-s-test: "
        title_str = title_str + str(round(ks_val, 3)) + ", p-value: "
        title_str = title_str + str(round(p_val, 3))

        # PMF/CDF figure: proposed on top, reference below.
        fn_plot_cdf = pre_fn_cdf + "_c" + str(c) + ".png"
        fig_cdf = plt.figure(figsize=(8, 4))
        ax_cdf = fig_cdf.add_subplot(2, 1, 1)
        ax_cdf.plot(pmf_test_i, alpha=0.4, label="test pmf")
        ax_cdf.plot(cdf_test_i, alpha=0.4, label="test cdf")
        ax_cdf.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))
        ax_cdf = fig_cdf.add_subplot(2, 1, 2)
        ax_cdf.plot(pmf_ref_i, alpha=0.4, label="ref pmf")
        ax_cdf.plot(cdf_ref_i, alpha=0.4, label="ref cdf")
        ax_cdf.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))
        plt.suptitle(str_weights + ": " + title_str)
        plt.tight_layout(rect=[0, 0, 0.95, 1])
        plt.draw()
        plt.savefig(fn_plot_cdf, format="png")
        plt.close()

        # Histogram figure: proposed on top, reference below.
        fn_plot_hist = pre_fn_hist + "_c" + str(c) + ".png"
        fig_hist = plt.figure(figsize=(8, 4))
        ind = np.arange(0, m_i + 1)
        ax_hist = fig_hist.add_subplot(2, 1, 1)
        ax_hist.bar(ind, hist_test_i, width=1.0, alpha=0.3, label="test")
        ax_hist.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))
        ax_hist = fig_hist.add_subplot(2, 1, 2)
        ax_hist.bar(ind, hist_ref_i, width=1.0, alpha=0.3, label="ref")
        ax_hist.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))
        plt.suptitle(str_weights + ": " + title_str)
        plt.tight_layout(rect=[0, 0, 0.95, 1])
        plt.draw()
        plt.savefig(fn_plot_hist, format="png")
        plt.close()

    # Correct the per-class p-values for multiple testing.
    p_vals = np.array(p_values)
    corr = fdrcorrection(p_vals, alpha=0.05, method="indep", is_sorted=False)
    return [
        [ref_hists, test_hists],
        [ref_cdfs, test_cdfs],
        [ref_pmfs, test_pmfs],
        [ks_values, corr],
    ]


# Despite its ``test_`` prefix this is an experiment driver taking plain
# arguments, not a pytest test; opt out of collection so pytest does not look
# for fixtures named after its parameters.
test_dist_mvnchypg.__test__ = False


if __name__ == "__main__":
    # Active experiment: sweep the size of the second class (m2) and compare
    # the proposed sampler with the SciPy reference for every setting.
    dir_res = "./results_rsample_omega"
    # savefig does not create missing directories; make sure the summary
    # figures can be written before the long sampling runs start.
    os.makedirs(dir_res, exist_ok=True)
    device = "cpu"
    weights_3 = np.arange(1, 11)  # only used by the disabled weight sweep below
    dist_type = "fisher"
    temperature = 1.0
    hists_weights = []
    cdfs_weights = []
    pmfs_weights = []
    ks_weights = []

    # ---- disabled experiment: sweep over the weight of the second class ----
    # for w_c in weights_3:
    #     print("w_c: " + str(w_c))
    #     m = torch.tensor([200, 200, 200]).to(device)
    #     n = torch.tensor(180).to(device)
    #     w = torch.tensor([1.0, float(w_c), 1.0]).to(device)
    #     n_rvs = 50000
    #     hists, cdfs, pmfs, ks = test_dist_mvnchypg(
    #         m, n, w, temperature, dist_type, n_rvs, device
    #     )
    #     hists_weights.append(hists)
    #     cdfs_weights.append(cdfs)
    #     pmfs_weights.append(pmfs)
    #     ks_weights.append(ks)

    # plt.rc("legend", fontsize=14)
    # plt.rc("xtick", labelsize=14)
    # plt.rc("ytick", labelsize=14)
    # plt.rc("axes", titlesize=14)
    # plt.rc("axes", labelsize=14)
    # num_classes = m.shape[0]
    # pre_fn_hist = "./results/hist_ks_test_mvnchypg_" + dist_type + "_all_weights"
    # for c in range(num_classes):
    #     fig_hist = plt.figure(figsize=(8, 4))
    #     ax_hist = fig_hist.add_subplot(2, 1, 1)
    #     ax_hist = fig_hist.add_subplot(2, 1, 2)
    #     fn_plot_hist = pre_fn_hist + "_c" + str(c) + ".png"
    #     m_i = m[c]
    #     ind = np.arange(0, m_i + 1)
    #     for i_w, w_c in enumerate(weights_3):
    #         hist_test_i = hists_weights[i_w][1][c]
    #         hist_ref_i = hists_weights[i_w][0][c]
    #         axes_hist = fig_hist.get_axes()
    #         ax_hist = axes_hist[0]
    #         ax_hist.bar(ind, hist_test_i, width=1.0, alpha=0.3, label="w2=" + str(w_c))
    #         ax_hist.set_title("proposed")
    #         ax_hist = axes_hist[1]
    #         ax_hist.bar(ind, hist_ref_i, width=1.0, alpha=0.3, label="w2=" + str(w_c))
    #         ax_hist.set_title("ref")
    #         ax_hist.set_xlabel("Number of elements")
    #     handles, labels = ax_hist.get_legend_handles_labels()
    #     fig_hist.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
    #     # plt.suptitle(str_weights + ': ' + title_str)
    #     plt.tight_layout(rect=[0, 0, 0.825, 1])
    #     plt.draw()
    #     plt.savefig(fn_plot_hist, format="png")
    #     plt.close()

    # ks_values_all = []
    # p_values_all = []
    # for c in range(num_classes):
    #     m_i = m[c]
    #     ind = np.arange(0, m_i + 1)
    #     ks_values_c = []
    #     p_values_c = []
    #     for i_w, w_c in enumerate(weights_3):
    #         ks_i = ks_weights[i_w][0][c]
    #         p_i = ks_weights[i_w][1][1][c]
    #         ks_values_c.append(ks_i)
    #         p_values_c.append(p_i)
    #     ks_values_all.append(ks_values_c)
    #     p_values_all.append(p_values_c)

    # linestyles = ["-^", "-o", "-*", "--"]
    # pre_fn_ks = "./results/ks_test_mvnchypg_" + dist_type + "_all_weights"
    # fig_ks = plt.figure(figsize=(8, 4))
    # ax_ks = fig_ks.add_subplot(2, 1, 1)
    # ax_p = fig_ks.add_subplot(2, 1, 2)
    # fn_plot_ks = pre_fn_ks + ".png"
    # for c in range(num_classes):
    #     ax_ks.plot(
    #         weights_3, ks_values_all[c], linestyles[c], label="class " + str(c + 1)
    #     )
    #     ax_p.plot(
    #         weights_3, p_values_all[c], linestyles[c], label="class " + str(c + 1)
    #     )
    # ax_p.hlines(0.05, weights_3[0], weights_3[-1], color="r", label="0.05")
    # ax_ks.set_ylim(0, 0.025)
    # ax_p.set_ylim(0.035, 1.5)
    # ax_p.set_xticks(weights_3)
    # ax_ks.set_xticks(weights_3)
    # ax_p.set_xticklabels(weights_3)
    # ax_ks.set_xticklabels(weights_3)
    # ax_p.set_yscale("log")
    # ax_ks.set_title("KS test")
    # ax_p.set_title("p values")
    # ax_p.set_xlabel("w2")
    # ax_p.set_ylabel("Value")
    # ax_ks.set_ylabel("Distance")
    # handles, labels = ax_p.get_legend_handles_labels()
    # fig_ks.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
    # plt.tight_layout(rect=[0, 0, 0.825, 1])
    # plt.draw()
    # plt.savefig(fn_plot_ks, format="png")
    # plt.close()

    # ---- disabled experiment: sweep over the number of draws n -------------
    # n_vec = np.array([1, 10, 100, 200, 400, 600])
    # for n in n_vec:
    #     print("n: " + str(n))
    #     m = torch.tensor([200, 200, 200]).to(device)
    #     n_run = torch.tensor(n).to(device)
    #     w = torch.tensor([1.0, 5.0, 1.0]).to(device)
    #     n_rvs = 50000
    #     hists, cdfs, pmfs, ks = test_dist_mvnchypg(
    #         m, n_run, w, temperature, dist_type, n_rvs, device
    #     )
    #     hists_weights.append(hists)
    #     cdfs_weights.append(cdfs)
    #     pmfs_weights.append(pmfs)
    #     ks_weights.append(ks)

    # plt.rc("legend", fontsize=14)
    # plt.rc("xtick", labelsize=14)
    # plt.rc("ytick", labelsize=14)
    # plt.rc("axes", titlesize=14)
    # plt.rc("axes", labelsize=14)
    # num_classes = m.shape[0]
    # pre_fn_hist = "./results/hist_ks_test_mvnchypg_" + dist_type + "_all_n"
    # for c in range(num_classes):
    #     fig_hist = plt.figure(figsize=(8, 4))
    #     ax_hist = fig_hist.add_subplot(2, 1, 1)
    #     ax_hist = fig_hist.add_subplot(2, 1, 2)
    #     fn_plot_hist = pre_fn_hist + "_c" + str(c) + ".png"
    #     m_i = m[c]
    #     ind = np.arange(0, m_i + 1)
    #     for i_n, n in enumerate(n_vec):
    #         hist_test_i = hists_weights[i_n][1][c]
    #         hist_ref_i = hists_weights[i_n][0][c]
    #         axes_hist = fig_hist.get_axes()
    #         ax_hist = axes_hist[0]
    #         ax_hist.bar(ind, hist_test_i, width=1.0, alpha=0.3, label="n=" + str(n))
    #         ax_hist.set_title("proposed")
    #         ax_hist = axes_hist[1]
    #         ax_hist.bar(ind, hist_ref_i, width=1.0, alpha=0.3, label="n=" + str(n))
    #         ax_hist.set_title("ref")
    #         ax_hist.set_xlabel("Number of elements")
    #     handles, labels = ax_hist.get_legend_handles_labels()
    #     fig_hist.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
    #     # plt.suptitle(str_weights + ': ' + title_str)
    #     plt.tight_layout(rect=[0, 0, 0.825, 1])
    #     plt.draw()
    #     plt.savefig(fn_plot_hist, format="png")
    #     plt.close()

    # ks_values_all = []
    # p_values_all = []
    # for c in range(num_classes):
    #     m_i = m[c]
    #     ind = np.arange(0, m_i + 1)
    #     ks_values_c = []
    #     p_values_c = []
    #     for i_n, n in enumerate(n_vec):
    #         ks_i = ks_weights[i_n][0][c]
    #         p_i = ks_weights[i_n][1][1][c]
    #         ks_values_c.append(ks_i)
    #         p_values_c.append(p_i)
    #     ks_values_all.append(ks_values_c)
    #     p_values_all.append(p_values_c)

    # linestyles = ["-^", "-o", "-*", "--"]
    # pre_fn_ks = "./results/ks_test_mvnchypg_" + dist_type + "_all_n"
    # fig_ks = plt.figure(figsize=(8, 4))
    # ax_ks = fig_ks.add_subplot(2, 1, 1)
    # ax_p = fig_ks.add_subplot(2, 1, 2)
    # fn_plot_ks = pre_fn_ks + ".png"
    # for c in range(num_classes):
    #     ax_ks.plot(n_vec, ks_values_all[c], linestyles[c], label="class " + str(c + 1))
    #     ax_p.plot(n_vec, p_values_all[c], linestyles[c], label="class " + str(c + 1))
    # ax_p.hlines(0.05, n_vec[0], n_vec[-1], color="r", label="0.05")
    # ax_ks.set_ylim(0, 0.025)
    # ax_p.set_ylim(0.035, 1.5)
    # ax_p.set_xticks(n_vec)
    # ax_ks.set_xticks(n_vec)
    # ax_p.set_xticklabels(n_vec)
    # ax_ks.set_xticklabels(n_vec)
    # ax_p.set_yscale("log")
    # ax_ks.set_title("KS test")
    # ax_p.set_title("p values")
    # ax_p.set_xlabel("n")
    # ax_p.set_ylabel("Value")
    # ax_ks.set_ylabel("Distance")
    # handles, labels = ax_p.get_legend_handles_labels()
    # fig_ks.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
    # plt.tight_layout(rect=[0, 0, 0.825, 1])
    # plt.draw()
    # plt.savefig(fn_plot_ks, format="png")
    # plt.close()

    # ---- active experiment: sweep over the size of the second class -------
    m_vec_2 = np.array([1, 10, 100, 200, 300])
    for m2 in m_vec_2:
        print("m2: " + str(m2))
        m = torch.tensor([200, m2, 200]).to(device)
        n_run = torch.tensor(200).to(device)
        w = torch.tensor([1.0, 5.0, 1.0]).to(device)
        n_rvs = 50000
        hists, cdfs, pmfs, ks = test_dist_mvnchypg(
            m, n_run, w, temperature, dist_type, n_rvs, device
        )
        hists_weights.append(hists)
        cdfs_weights.append(cdfs)
        pmfs_weights.append(pmfs)
        ks_weights.append(ks)

    plt.rc("legend", fontsize=14)
    plt.rc("xtick", labelsize=14)
    plt.rc("ytick", labelsize=14)
    plt.rc("axes", titlesize=14)
    plt.rc("axes", labelsize=14)
    num_classes = m.shape[0]

    # One figure per class: histograms of all m2 settings overlaid, proposed
    # sampler in the upper panel and reference in the lower panel.
    pre_fn_hist = os.path.join(
        dir_res, "hist_ks_test_mvnchypg_" + dist_type + "_all_m2"
    )
    for c in range(num_classes):
        fig_hist = plt.figure(figsize=(8, 4))
        ax_test = fig_hist.add_subplot(2, 1, 1)
        ax_ref = fig_hist.add_subplot(2, 1, 2)
        fn_plot_hist = pre_fn_hist + "_c" + str(c) + ".png"
        for m2, hists_run in zip(m_vec_2, hists_weights):
            # hists_run is [ref_hists, test_hists] for this m2 setting.
            hist_ref_i = hists_run[0][c]
            hist_test_i = hists_run[1][c]
            # Each histogram has one bin per possible count of its own run,
            # so the bar positions follow directly from its length.
            ind = np.arange(hist_test_i.shape[0])
            ax_test.bar(ind, hist_test_i, width=1.0, alpha=0.3, label="m2=" + str(m2))
            ax_ref.bar(ind, hist_ref_i, width=1.0, alpha=0.3, label="m2=" + str(m2))
        ax_test.set_title("proposed")
        ax_ref.set_title("ref")
        ax_ref.set_xlabel("Number of elements")
        # Both panels carry the same labels; one shared legend is enough.
        handles, labels = ax_ref.get_legend_handles_labels()
        fig_hist.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
        plt.tight_layout(rect=[0, 0, 0.825, 1])
        plt.draw()
        plt.savefig(fn_plot_hist, format="png")
        plt.close()

    # Regroup the results per class.  Each entry of ks_weights is
    # [ks_values, corr], where corr is the fdrcorrection result, so
    # corr[1][c] is the corrected p-value of class c.
    ks_values_all = [
        [ks_run[0][c] for ks_run in ks_weights] for c in range(num_classes)
    ]
    p_values_all = [
        [ks_run[1][1][c] for ks_run in ks_weights] for c in range(num_classes)
    ]

    # Summary figure: KS distance (top) and corrected p-values (bottom) vs m2.
    linestyles = ["-^", "-o", "-*", "--"]
    pre_fn_ks = os.path.join(dir_res, "ks_test_mvnchypg_" + dist_type + "_all_m2")
    fig_ks = plt.figure(figsize=(8, 4))
    ax_ks = fig_ks.add_subplot(2, 1, 1)
    ax_p = fig_ks.add_subplot(2, 1, 2)
    fn_plot_ks = pre_fn_ks + ".png"
    for c in range(num_classes):
        ax_ks.plot(
            m_vec_2, ks_values_all[c], linestyles[c], label="class " + str(c + 1)
        )
        ax_p.plot(m_vec_2, p_values_all[c], linestyles[c], label="class " + str(c + 1))
    # Significance level used for the FDR correction.
    ax_p.hlines(0.05, m_vec_2[0], m_vec_2[-1], color="r", label="0.05")
    ax_ks.set_ylim(0, 0.025)
    ax_p.set_ylim(0.035, 1.5)
    ax_p.set_xticks(m_vec_2)
    ax_ks.set_xticks(m_vec_2)
    ax_p.set_xticklabels(m_vec_2)
    ax_ks.set_xticklabels(m_vec_2)
    ax_p.set_yscale("log")
    ax_ks.set_title("KS test")
    ax_p.set_title("p values")
    ax_p.set_xlabel("m2")
    ax_p.set_ylabel("Value")
    ax_ks.set_ylabel("Distance")
    handles, labels = ax_p.get_legend_handles_labels()
    fig_ks.legend(handles, labels, bbox_to_anchor=(1, 0.54), loc="center right")
    plt.tight_layout(rect=[0, 0, 0.825, 1])
    plt.draw()
    plt.savefig(fn_plot_ks, format="png")
    plt.close()
