import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from collections import defaultdict
import statsmodels.api as sm
import seaborn as sns
import numpy as np
import pandas as pd
from itertools import product
from scipy.stats import pearsonr, spearmanr
from notation_config import NotationConfig
from operator import itemgetter

sns.set_style("darkgrid")

# Global matplotlib defaults applied on import: sans-serif 16pt text, 300-dpi figures.
font = dict(family="sans-serif", size=16)
plt.rc("font", **font)
plt.rcParams.update({"figure.dpi": 300})

# Accepted values for the `additional_plot` argument of plot_error_matrices_and_histograms.
ADDITIONAL_PLOT_OPTIONS = ["limited", "counts", None]

"""
This function plots multiple subplots of data. The data for each subplot is defined by an element of x_data and an element of y_data. 
Each element is a list of integers.
The subplots are arranged in a grid defined by the number of arrays in x_data (rows) and in y_data (columns).

The function is useful for plotting the relationship between multiple variables. 

Parameters:
    x_data: a list of arrays where each array contains the x data for each subplot. Each array should have the same length.
    x_labels: a list of strings containing the x labels for each subplot. 
    y_data: a list of arrays where each array contains the y data for each subplot. Each array should have the same length.
    y_labels: a list of strings containing the y labels for each subplot
    c (optional): The array used to color each x,y point in every subplot. It should have the same length as each array in x_data.
    figtitle (optional): a string containing the title for the figure. If not provided, no title will be displayed.
    slopes_to_plot (optional): a list of floats containing the slopes to plot for each column of subplots. If not provided, no slopes will be plotted.
    figsize (optional): a tuple containing the width and height of the figure in inches. Default is (20, 15).
    regress (optional): a boolean indicating whether or not to plot a linear regression line for each subplot. Default is False.
    clusters (optional): a single array of cluster labels (one per point) that is passed unchanged to every subplot. Clusters are used only when regress = True.

Returns:
A defaultdict containing the residuals for each subplot (None when regress is False), with the keys being a tuple of the form (i, j) where i is the index of the x_data element and j is the index of the y_data element.

"""


def plot_all_relationships(
    x_data,
    x_labels,
    y_data,
    y_labels,
    c=None,
    figtitle=None,
    slopes_to_plot=None,
    figsize=(20, 15),
    regress=False,
    clusters=None,
):
    """Scatter every series in ``x_data`` against every series in ``y_data``.

    Subplot (i, j) sits in row i / column j of a ``len(x_data)`` x
    ``len(y_data)`` grid and is drawn by ``plot_single_plot``.

    Parameters
    ----------
    x_data, y_data : sequences of array-likes
        Every x series must have the same length as every y series.
    x_labels, y_labels : sequences of str
        One label per series in ``x_data`` / ``y_data``.
    c : array-like, optional
        Hue values shared by every subplot; the legend is drawn only on the
        top-right subplot.
    figtitle : str, optional
        Figure title passed to ``fig.suptitle``.
    slopes_to_plot : sequence of float, optional
        ``slopes_to_plot[j]`` is forwarded as ``slope_to_plot`` to every
        subplot of column j.
    figsize : tuple, default (20, 15)
    regress : bool, default False
        Fit and draw an OLS line in every subplot.
    clusters : array-like, optional
        Cluster labels passed unchanged to every subplot (used only when
        ``regress`` is True).

    Returns
    -------
    defaultdict
        Maps ``(i, j)`` to the residuals returned by ``plot_single_plot``
        (None when ``regress`` is False).
    """
    assert len(x_data) == len(x_labels)  # one label per x series
    assert len(y_data) == len(y_labels)  # one label per y series

    n_rows, n_cols = len(x_data), len(y_data)
    # squeeze=False always returns a 2-D array of Axes, so 1x1, 1xN and Nx1
    # grids need no special-casing.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    all_residuals = defaultdict(int)
    for i, x_i_data in enumerate(x_data):
        for j, y_j_data in enumerate(y_data):
            assert len(x_i_data) == len(y_j_data)
            # Only the top-right subplot gets a legend, and only when points are colored.
            legend = c is not None and i == 0 and j == n_cols - 1
            slope_to_plot = None if slopes_to_plot is None else slopes_to_plot[j]
            all_residuals[(i, j)] = plot_single_plot(
                x_i_data,
                "i == {}: {}".format(i, x_labels[i]),
                y_j_data,
                "j == {}: {}".format(j, y_labels[j]),
                axs[i, j],
                legend=legend,
                regress=regress,
                c=c,
                clusters=clusters,
                slope_to_plot=slope_to_plot,
            )

    fig.suptitle(figtitle)
    return all_residuals


"""
This function plots a single scatter plot of y_j_data versus x_i_data. It can also optionally plot a linear regression line, a line with a specified slope, and/or color the points according to the values in c.

Parameters:
    x_i_data: an array containing the x data for the plot.
    x_label: a string containing the x label for the plot.
    y_j_data: an array containing the y data for the plot.
    y_label: a string containing the y label for the plot.
    ax: the axis object for the plot.
    legend (optional): a boolean indicating whether or not to display a legend. Default is True.
    regress (optional): a boolean indicating whether or not to plot a linear regression line. Default is False.
    c (optional): an array containing the colors for the points. If not provided, all points will use the same color.
    clusters (optional): an array containing the cluster labels for the points. If provided, the linear regression will be adjusted for clustering.
    slope_to_plot (optional): a float containing the slope to plot. If not provided, no slope will be plotted.
    plot_title (optional): a string containing the title for the plot. If not provided, no title will be displayed.
    palette (optional): A string corresponding to a matplotlib palette. If not provided, default matplotlib color will be used.

Returns: The residuals of the linear regression, if regress is True. Otherwise, returns None. """


def plot_single_plot(
    x_i_data,
    x_label,
    y_j_data,
    y_label,
    ax,
    legend=True,
    regress=False,
    c=None,
    clusters=None,
    slope_to_plot=None,
    plot_title=None,
    palette=None,
):
    """Scatter ``y_j_data`` against ``x_i_data`` on ``ax`` with optional overlays.

    Parameters
    ----------
    x_i_data, y_j_data : array-like
        Point coordinates (same length).
    x_label, y_label : str
        Axis labels.
    ax : matplotlib Axes
        Target axes.
    legend : bool, default True
        Show the hue legend (only drawn when ``c`` is given).
    regress : bool, default False
        Fit ``y = a + b * x`` by OLS, draw it as a dashed black line and
        annotate Pearson/Spearman correlations, their p-values and |R|.
    c : array-like, optional
        Hue values for the points.
    clusters : array-like, optional
        Groups for a cluster-robust covariance (``cov_type="cluster"``);
        only used when ``regress`` is True.
    slope_to_plot : float, optional
        Draw the red reference line ``y = slope_to_plot * x`` from x = 0 to
        ``max(x_i_data)``.
    plot_title : str, optional
        Axes title.
    palette : str, optional
        Palette passed to seaborn.

    Returns
    -------
    The OLS residuals (``regress_results.resid``) when ``regress`` is True,
    otherwise None.
    """
    scatter = sns.scatterplot(
        x=x_i_data, y=y_j_data, hue=c, ax=ax, legend=legend, palette=palette
    )
    # Legend entries only exist for a hue variable; without one, legend()
    # would produce an empty box and a "no artists" warning.
    if legend and c is not None:
        scatter.legend(bbox_to_anchor=(1.03, 1))

    residuals = None
    if regress:
        x_with_const = sm.tools.add_constant(x_i_data)
        model = sm.OLS(y_j_data, x_with_const)
        if clusters is not None:
            regress_results = model.fit(
                cov_type="cluster", cov_kwds={"groups": clusters}
            )
        else:
            regress_results = model.fit()

        # params is a label-indexed Series for pandas input; index positionally
        # (integer keys on a labelled Series are deprecated in pandas).
        a, b = np.asarray(regress_results.params)[:2]
        residuals = regress_results.resid
        r = round(np.sqrt(regress_results.rsquared), 2)  # |R|: the sqrt drops the sign

        # Line of best fit over the observed x range
        xseq = np.linspace(min(x_i_data), max(x_i_data), num=100)
        ax.plot(xseq, a + b * xseq, color="k", lw=1.5, linestyle="--")

        pearsonC, pearsonP = pearsonr(x_i_data, y_j_data)
        spearmanC, spearmanP = spearmanr(x_i_data, y_j_data)
        s = "Pearson Corr: {pc}, P val: {pp} \nSpearman Corr: {sc}, P val: {sp}\n R: {r}".format(
            pc=round(pearsonC, 2),
            pp=round(pearsonP, 2),
            sc=round(spearmanC, 2),
            sp=round(spearmanP, 2),
            r=r,
        )
        # A negative slope starts high on the left, so shift the text right to
        # avoid overlapping the line.
        locx, locy = (0.40, 0.85) if b < 0 else (0.05, 0.85)
        ax.text(locx, locy, s=s, transform=ax.transAxes)

    if slope_to_plot is not None:
        # Reference line through the origin
        xseq = np.linspace(0, max(x_i_data), num=100)
        ax.plot(xseq, 0 + slope_to_plot * xseq, color="red")

    if plot_title is not None:
        ax.set_title(plot_title)

    ax.set_ylabel(y_label)
    ax.set_xlabel(x_label)

    return residuals


"""
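This function plots multiple versions of a scatter plot of y_data versus x_data, each with a different color scheme specified by c_lists. The figure has len(x_data) * len(y_data) rows (one per x/y pair) and one column per entry of c_lists.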

Parameters:
    x_data: an array of arrays. each subarray represents the x data for a plot
    x_labels: an array containing strings containing the x label for the plots.
    y_data: an array of arrays. each subarray containing the y data for the plots.
    y_labels: a list of strings containing the y label for each array in y_data.
    figsize (optional): a tuple containing the width and height of the figure in inches. Default is (20, 7.5).
    palette (optional): a string corresponding to a matplotlib palette.
    regress (optional): a boolean indicating whether or not to plot a linear regression line for each plot. Default is False.
    c_lists: a list of arrays, where each array specifies the colors for one plot. Needs to be provided.
    title_list (optional): a list of strings containing the titles for the plots. If not provided, no titles will be displayed. """


def plot_multiple_color_schemes(
    x_data,
    x_labels,
    y_data,
    y_labels,
    figsize=(20, 7.5),
    palette=None,
    regress=False,
    c_lists=None,
    title_list=None,
):
    """Draw every (x, y) scatter once per color scheme in ``c_lists``.

    The figure has one row per (x_data[i], y_data[j]) pair, at row
    ``i * len(y_data) + j``, and one column per entry of ``c_lists``.

    Parameters
    ----------
    x_data, y_data : sequences of array-likes
    x_labels, y_labels : sequences of str, one label per series
    figsize : tuple, default (20, 7.5)
    palette : str, optional
        Palette forwarded to every subplot.
    regress : bool, default False
        Draw an OLS fit in every subplot.
    c_lists : sequence of array-likes (required)
        ``c_lists[k]`` colors the points of column k.
    title_list : sequence of str, optional
        ``title_list[k]`` titles every subplot in column k.
    """
    assert c_lists is not None
    n_y = len(y_data)
    n_rows = len(x_data) * n_y
    # squeeze=False keeps axs 2-D regardless of the grid shape.
    _, axs = plt.subplots(
        nrows=n_rows, ncols=len(c_lists), figsize=figsize, squeeze=False
    )
    for i, x_i in enumerate(x_data):
        x_label = x_labels[i]
        for j, y_j in enumerate(y_data):
            y_label = y_labels[j]
            # Row-major index of the (i, j) pair, so every pair gets its own row.
            row = i * n_y + j
            for k, c in enumerate(c_lists):
                plot_title = None if title_list is None else title_list[k]
                plot_single_plot(
                    x_i,
                    x_label,
                    y_j,
                    y_label,
                    axs[row, k],
                    palette=palette,
                    regress=regress,
                    c=c,
                    plot_title=plot_title,
                )


"""
The sort_col_row function takes a NumPy array R_np as input and returns a tuple of three values: R_rows_sorted, error_rates_sorted, and ordering.

The R_rows_sorted array is a sorted version of the input array R_np. Its rows are ordered (top to bottom, ascending) by the product of I(failure) * error_rate over the models, i.e. the product of the error rates of the models that failed on that row, so rows with more or rarer failures come first.
If col_sorting_order is None, the columns are sorted from least errorful model to most errorful model; otherwise, they are sorted according to col_sorting_order.

The error_rates_sorted array is a sorted version of the error rates of the models in the input array R_np, where the error rates are sorted in ascending order.

Parameters: 
- R_np (np.array): A NumPy array representing the predictions made by different models for different users. The rows represent the users and the columns represent the models.
- col_sorting_order (list): a list that specifies the ordering of the columns in R_rows_sorted. IF none, uses the error rates of the model.

Returns:
- R_rows_sorted (np.array): A NumPy array containing R_np with its columns reordered and its rows sorted by the product of the error rates of the models that failed on that row (a sort key, not the probability of that combination of outcomes occurring).

- error_rates_sorted (1D NumPy array): A 1D NumPy array representing the sorted error rates of the models in the input array R_np, where the error rates are sorted in ascending order.

- ordering: Returns the col_sorting_order that was used to sort the columns of R_np
"""


def sort_col_row(R_np, col_sorting_order=None):
    """Reorder the columns (models) and rows (samples) of an outcome matrix.

    Columns follow ``col_sorting_order`` when given, otherwise ascending
    per-model error rate. Rows are then sorted by the ascending product of
    the error rates of the models that failed on that row (entries of 1), so
    rows with more or rarer failures come first. NaN entries (non
    interactions) contribute a factor of ``1 + epsilon``, which places a row
    with NaNs just after the same pattern without them.

    Parameters
    ----------
    R_np : array-like, shape (n_samples, n_models)
        0 = correct, 1 = failure, NaN = no interaction.
    col_sorting_order : sequence of int, optional
        Column order to apply (e.g. the ``ordering`` returned for another
        matrix). If None, columns are sorted by error rate.

    Returns
    -------
    R_rows_sorted : float ndarray
        R_np with columns reordered and rows sorted as described above.
    error_rates_sorted : ndarray
        Per-model error rates (NaN-ignoring column means) in column order.
    ordering : ndarray or sequence
        The column order used (``col_sorting_order`` itself when given).
    """
    R_np = np.asarray(R_np, dtype=float)
    # Per-model error rate, ignoring NaN (non-interaction) entries
    error_rates = np.nansum(R_np, axis=0) / np.count_nonzero(~np.isnan(R_np), axis=0)

    if col_sorting_order is None:
        # No order supplied: least errorful model -> most errorful model
        ordering = np.argsort(error_rates)
    else:
        # Reuse the supplied order (keeps columns aligned across matrices)
        ordering = col_sorting_order
    error_rates_sorted = error_rates[ordering]
    R_cols_sorted = R_np[:, ordering]

    # Failure indicator times error rate: 0 for correct, the model's error rate for a failure
    error = R_cols_sorted * error_rates_sorted

    epsilon = 0.000005  # This very small number is used to sort nan (non interactions) values below failures,
    error = np.nan_to_num(error, nan=(1 + epsilon))
    # Product over the positive entries only; a row with no failure and no NaN gets 1
    row_keys = np.nanprod(error, axis=1, where=error > 0)

    R_rows_sorted = R_cols_sorted[np.argsort(row_keys)]
    return R_rows_sorted, error_rates_sorted, ordering


"""
The plot_error_matrix function plots two heatmaps of the input array R_rows_sorted, with the error rates of the models in the input array error_rates_sorted displayed along the x-axis. The heatmaps are displayed in the Axes specified by the input array axs_row.
The first heatmap displays all the rows of R_rows_sorted. The second heatmap displays only those rows of R_rows_sorted that have at least one failure.
The rows in both heatmaps are sorted according to the product of the error rates of the models in the corresponding columns. The error rates are displayed in the x-axis tick labels. In the heatmap, black represents a correct model prediction and red represents an incorrect prediction.

Parameters
R_rows_sorted (2D NumPy array): A 2D NumPy array representing the sorted version of the input array R_np, where each row is sorted according to the error rates of the models in the corresponding columns.
error_rates_sorted (1D NumPy array): A 1D NumPy array representing the sorted error rates of the models in the input array R_np, where the error rates are sorted in ascending order.
axs_row (array): An array of Axes objects, where the heatmaps will be plotted.
ncols: Either 1 or 2; 1 if only plotting a single error matrix: 2 if plotting error matrix and the subset of error matrix where there is at least 1 failure.
prefix (string, optional): A string that will be displayed as the title of the heatmaps. Default is an empty string."""


def plot_error_matrix(R_rows_sorted, error_rates_sorted, axs_row, ncols, prefix=""):
    """Draw ``R_rows_sorted`` as a black/red heatmap, plus (ncols > 1) its failing rows.

    Black cells are correct predictions (0) and red cells are incorrect
    predictions (1); the x tick labels are the per-model error rates rounded
    to 3 decimals.

    Parameters
    ----------
    R_rows_sorted : 2-D ndarray
        Outcome matrix (rows = samples, columns = models), e.g. as returned
        by ``sort_col_row``.
    error_rates_sorted : 1-D ndarray
        Error rates aligned with the columns of ``R_rows_sorted``.
    axs_row : Axes or sequence of Axes
        A single Axes when ``ncols == 1``, otherwise at least two Axes.
    ncols : int
        1 to draw only the full matrix; 2 to also draw the rows with at least
        one failure on ``axs_row[1]``.
    prefix : optional
        Title prefix (converted with ``str``).
    """
    ax = axs_row if ncols == 1 else axs_row[0]
    prefix = str(prefix)  # ensure prefix is string
    tick_labels = error_rates_sorted.round(3)

    ax.set_title(prefix + " All rows")
    sns.heatmap(
        R_rows_sorted,
        cmap=["black", "indianred"],
        cbar=False,
        ax=ax,
        xticklabels=tick_labels,
    )
    if ncols > 1:
        ax = axs_row[1]
        ax.set_title(prefix + " Rows with at least 1 failure")
        # Select the failing rows directly (keeping their order) instead of
        # slicing up to the first all-correct row: np.argmax returns 0 when no
        # row is all-correct, which would drop every row.
        has_failure = np.any(R_rows_sorted > 0, axis=1)
        failure_rows = R_rows_sorted[has_failure]
        if len(failure_rows) > 0:
            sns.heatmap(
                failure_rows,
                cmap=["black", "indianred"],
                cbar=False,
                ax=ax,
                xticklabels=tick_labels,
            )


"""
The plot_error_matrices_and_histograms function plots heatmaps for each matrix in the dictionary R_all (when plot_matrices is True), followed by a row of bar plots.
For each matrix, the first heatmap displays all of its rows and the second heatmap (drawn when additional_plot is not None) displays only the rows that have at least one failure (value of 1).
The rows in all the heatmaps are sorted according to the product of the error rates of the models that failed on that row. Black represents correct predictions and red represents incorrect predictions.
The function also plots bar plots (vertical by default; see orientation) showing, for every matrix in R_all, the proportion of rows for each combination of model outcomes. The left bar plot uses all the rows; the right bar plot (present when additional_plot is not None) shows either the proportions among rows with at least one failure ('limited') or the raw counts ('counts').

Parameters
R_all (dictionary): A dictionary of {string: 2D NumPy array} pairs, where the keys are the names of the arrays and the values are the matrices that we will plot. The key is used to label the matrix plots and the bar plots.
title (string): A string containing the title of the entire figure.
additional_plot (string, optional): A string that specifies which additional plots to include. Options are 'limited' (default), 'counts', and None. If 'limited', the right-hand bar plot only includes inputs with at least one failure. If 'counts', it plots the unnormalized counts of each combination of model outcomes. If None, only a single column of plots is drawn.
palette (string or list, optional): A string or list of colors that will be used to color the bar plots. Default is 'Set1'.
plot_matrices (bool, optional): If true, plots the red and black error matrices in addition to histograms. If false, just plots histograms
display_title: If True, displays title for the figure
orientation: Orientation of the bar plots, either 'vertical' (default) or 'horizontal'.
y_max: If not None, used as the upper limit of the left bar plot's values (the y axis for vertical bars).

Output
fig (matplotlib figure): The figure containing the heatmaps and bar plots.
axs (array of array of Axes): A 2D NumPy array of subplots, where the heatmaps and bar plots are plotted.
all_left_plot_data: Returns all data included in the left plot
all_right_plot_data: Returns all data included in the right plot
"""


def plot_error_matrices_and_histograms(
    R_all,
    title,
    additional_plot="limited",
    palette="Set1",
    plot_matrices=True,
    display_title=True,
    orientation="vertical",
    y_max=None,
):
    """Plot optional error-matrix heatmaps and outcome histograms for several scenarios.

    The first scenario in ``R_all`` fixes the model (column) order; every
    later scenario reuses it so all plots stay aligned. Note: this sets the
    global ``axes.prop_cycle`` rcParam to the chosen palette.

    Parameters
    ----------
    R_all : dict[str, array-like]
        Scenario name -> outcome matrix (rows = samples, columns = models).
    title : str
        Figure title (also used in error messages).
    additional_plot : {"limited", "counts", None}
        Right-hand column: proportions among rows with at least one failure,
        raw counts, or no right-hand column.
    palette : str or list
        Colormap name, or an explicit list of colors.
    plot_matrices : bool
        Also draw one heatmap row per scenario above the histograms.
    display_title : bool
        Show ``title`` as the figure suptitle.
    orientation : {"vertical", "horizontal"}
        Bar orientation.
    y_max : float, optional
        Upper limit of the left histogram's values.

    Returns
    -------
    fig, axs, all_left_plot_data, all_right_plot_data
        The figure, its Axes, and per-scenario dicts of the plotted values.
    """
    assert additional_plot in ADDITIONAL_PLOT_OPTIONS
    BASE_HEIGHT = 7.5

    # One histogram row, plus one row per scenario when matrices are drawn
    n_plot_rows = 1 + (len(R_all) if plot_matrices else 0)
    height = BASE_HEIGHT * n_plot_rows
    height_ratios = [1] * n_plot_rows
    # heuristic that's designed to enlarge the histogram plot if plotting many matrices
    height_ratios[-1] = max(1, n_plot_rows / 2)

    if isinstance(palette, str):
        cmap = plt.get_cmap(palette)
        # Qualitative (listed) colormaps such as "Set1" expose discrete colors;
        # continuous colormaps have no `.colors`, so sample them instead.
        color = getattr(cmap, "colors", None)
        if color is None:
            color = [
                tuple(rgba) for rgba in cmap(np.linspace(0, 1, max(len(R_all), 1)))
            ]
    else:
        color = palette  # if passed a list
    plt.rcParams["axes.prop_cycle"] = plt.cycler(color=color)

    ncols = 1 if additional_plot is None else 2
    fig_width = 10 * ncols
    fig, axs = plt.subplots(
        nrows=n_plot_rows,
        ncols=ncols,
        gridspec_kw={"height_ratios": height_ratios},
        figsize=(fig_width, height),
    )

    if display_title:
        fig.suptitle(title)

    row = 0
    model_ordering = None
    R_all_sorted = {}  # sorted copies, so the caller's dict is not modified
    for scenario_name, R in R_all.items():
        R_rows_sorted, error_rates_sorted, model_ordering = sort_col_row(
            np.array(R), col_sorting_order=model_ordering
        )
        R_all_sorted[scenario_name] = R_rows_sorted
        if plot_matrices:
            plot_error_matrix(
                R_rows_sorted,
                error_rates_sorted,
                axs[row],
                ncols=ncols,
                prefix=scenario_name,
            )
            row += 1

    # The histograms go in the last row
    axs_row = axs[row] if n_plot_rows > 1 else axs

    # Narrower bars when more scenarios share each group; test the larger
    # threshold first so every branch is reachable.
    n_scenarios = len(R_all)
    if n_scenarios > 5:
        bar_width = 0.1
    elif n_scenarios > 3:
        bar_width = 0.15
    else:
        bar_width = 0.2

    error_rates = None
    all_left_plot_data = {}
    all_right_plot_data = {}
    bar_index = 0
    for scenario_name, R in R_all_sorted.items():
        if R.size == 0:
            continue  # skip without consuming a bar slot
        scenario_title = f"{scenario_name} (n={R.shape[0]})"
        try:
            # error_rates from the first scenario keeps the outcome order fixed
            error_rates, left_plot_data, right_plot_data = plot_error_histograms(
                R,
                bar_index,
                axs_row,
                scenario_title,
                width=bar_width,
                error_rates=error_rates,
                additional_plot=additional_plot,
                orientation=orientation,
                y_max=y_max,
            )
        except AssertionError as ae:
            print(ae)
            print(f"Error plotting histogram for {title} -- {scenario_name}")
            break
        all_left_plot_data[scenario_name] = left_plot_data
        all_right_plot_data[scenario_name] = right_plot_data
        bar_index += 1
    return fig, axs, all_left_plot_data, all_right_plot_data


"""
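Adds one scenario's outcome-frequency bars to the histogram axes: the left plot shows the proportion of rows for each combination of model outcomes, and the optional right plot shows either the proportions among rows with at least one failure or the raw counts.

Parameters: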
R (2D NumPy array): A matrix of binary outcomes, where 0 represents a correct prediction and 1 represents an incorrect prediction.
i (int): The index of the bar plot that is being plotted. Used to offset the bars so that they don't overlap.
axs_row (array of Axes): An array of two subplots, where the bar plots will be plotted.
scenario_name (string): A string containing the name of the scenario that is being plotted.
width (float, optional): The width of the bars in the bar plot. Default is 0.25.
error_rates (array-like, optional): An array-like object containing error rates associated with each model. Default is None. If None, calculate error rates based on passed in R. 
    Because plot_error_histograms can be called multiple times in a loop (hence the use of `i`), error_rates is sometimes needed to ensure that the grouping of the bars
    is consistent when the relative ordering of model accuracies changes over time (e.g. if the 3rd most accurate model becomes the second most accurate model and multiple years are being plotted)
additional_plot (string, optional): Specifies the type of additional plot to be shown. Options are 'limited', 'counts', or None. Default is 'limited'.
semantic (bool, optional): Whether to semantically convert the values in the plot. Default is True.
orientation (string, optional): The orientation of the bar plot. Options are 'vertical' or 'horizontal'. Default is 'vertical'.
y_max (float, optional): The maximum value for the y-axis. Default is None.

Returns:
error_rates (array-like): An array-like object containing the error rates associated with each model (as returned by count_combinations).
left_plot_data (dict): A dictionary containing the left plot data with unique values as keys and proportions as values.
right_plot_data (dict): A dictionary containing the right plot data with unique values as keys and proportions or counts as values.
"""


def plot_error_histograms(
    R,
    i,
    axs_row,
    scenario_name,
    width=0.25,
    error_rates=None,
    additional_plot="limited",
    semantic=True,
    orientation="vertical",
    y_max=None,
):
    """Add one scenario's outcome-frequency bars to the histogram axes.

    Called once per scenario with an increasing ``i``; each call adds a new
    set of bars offset by ``width * i`` so scenarios appear side by side.

    Parameters
    ----------
    R : 2-D ndarray
        Outcome matrix (rows = samples, columns = models); 0 = correct,
        1 = incorrect. For ``additional_plot == "limited"`` it is expected to
        be row-sorted by ``sort_col_row`` (failing rows first).
    i : int
        Index of this scenario; offsets its bars.
    axs_row : Axes or sequence of two Axes
        A single Axes when ``additional_plot`` is None, else ``[left, right]``.
    scenario_name : str
        Legend label for this scenario's bars.
    width : float
        Bar width.
    error_rates : array-like, optional
        Per-model error rates used to order the outcomes; computed from ``R``
        when None. Pass the value returned by a previous call to keep the
        outcome order consistent across scenarios.
    additional_plot : {"limited", "counts", None}
        Right-hand plot: proportions among rows with at least one failure,
        raw counts, or no right-hand plot.
    semantic : bool
        Pass the outcome labels through
        ``NotationConfig.semantically_convert_list``.
    orientation : {"vertical", "horizontal"}
    y_max : float, optional
        Upper limit of the left plot's value axis.

    Returns
    -------
    error_rates, left_plot_data, right_plot_data
        ``left_plot_data`` maps outcome label -> proportion;
        ``right_plot_data`` maps outcome label -> limited proportion or count
        (``{}`` when ``additional_plot`` is None).

    Raises
    ------
    AssertionError
        If ``R`` is empty or ``orientation`` / ``additional_plot`` is not a
        recognised option (the caller catches AssertionError).
    """
    assert R.size > 0  # if empty, avoid doing stuff
    Notation = NotationConfig()

    if orientation == "horizontal":
        flip = True
        bar_slot = i
    elif orientation == "vertical":
        flip = False
        # We want the observed bar to be on the left, so flip the sign of i so
        # that (-width * bar_slot) is positive. Recall that matplotlib indexes
        # from top to bottom and left to right.
        bar_slot = -i
    else:
        raise AssertionError("Orientation must be horizontal or vertical")

    if additional_plot is not None:
        ax1, ax2 = axs_row[0], axs_row[1]
    else:
        ax1, ax2 = axs_row, None

    unique_vals, inds, counts, error_rates = count_combinations(
        R, error_rates=error_rates
    )
    # count_combinations returns a tuple that mixes numpy ints and plain 0s
    counts = np.asarray(counts)
    plot_ind = np.arange(len(counts))
    proportions = np.round(counts / counts.sum(), 3)

    if semantic:
        # NOTE(review): count_combinations already applies
        # semantically_convert_list with its default NotationConfig, so the
        # labels are converted a second time here -- confirm this is intended.
        unique_vals = Notation.semantically_convert_list(unique_vals)

    def _setters(ax):
        # (bar fn, tick setter, outcome-axis label setter, value-axis label setter);
        # horizontal bars put outcomes on y and values on x.
        if flip:
            return ax.barh, ax.set_yticks, ax.set_ylabel, ax.set_xlabel
        return ax.bar, ax.set_xticks, ax.set_xlabel, ax.set_ylabel

    padding = 5

    # ---- Left plot: proportion of rows for each outcome ----
    left_labels, left_values = unique_vals, proportions
    if flip:
        # reverse so [all wrong] is at the top of the graph and [all right] at the bottom
        left_labels, left_values = np.flip(left_labels), np.flip(left_values)
    ax1_plot, ax1_ticks, ax1_outcome_label, ax1_value_label = _setters(ax1)
    ax1_plot(plot_ind - width * bar_slot, left_values, width, label=scenario_name)
    if y_max is not None:
        # y_max bounds the value axis, which is x for horizontal bars
        if flip:
            ax1.set_xlim(right=y_max)
        else:
            ax1.set_ylim(top=y_max)
    # The bars just drawn are always the most recent container
    ax1.bar_label(ax1.containers[-1], label_type="edge", padding=padding)
    ax1_ticks(plot_ind - (bar_slot / 2) * width, left_labels)
    ax1.legend(fontsize="medium")
    ax1_outcome_label(Notation.outcome_name, fontsize="x-large")
    ax1_value_label(f"Proportion of {Notation.input_name}s ", fontsize="x-large")
    left_plot_data = dict(zip(left_labels, left_values))

    right_plot_data = {}  # will be populated if additional_plot is not None
    if additional_plot is None:
        return error_rates, left_plot_data, right_plot_data

    # ---- Right plot, built from the un-flipped labels and counts ----
    if additional_plot == "limited":
        # R is row-sorted, so every failing row precedes the first all-correct row.
        all_correct_rows = np.all(R == 0, axis=1)
        first_row_all_correct = (
            int(np.argmax(all_correct_rows)) if all_correct_rows.any() else R.shape[0]
        )
        # inds holds the sorted first-occurrence rows of the distinct observed
        # outcomes; count those occurring before the all-correct block.
        inds_limit = int(np.count_nonzero(np.sort(inds) < first_row_all_correct))
        # NOTE(review): count_combinations also returns never-observed outcomes
        # (count 0) interleaved by its sort key, so the first inds_limit entries
        # are not guaranteed to be exactly the observed failure outcomes.
        right_labels = unique_vals[:inds_limit]
        counts_limited = counts[:inds_limit]
        total_limited = counts_limited.sum()
        if total_limited > 0:
            right_values = np.round(counts_limited / total_limited, 3)
        else:
            right_values = np.zeros(len(counts_limited))
        additional_y_label = (
            f"Proportion of {Notation.input_name}s With at Least 1 Failure"
        )
    elif additional_plot == "counts":
        right_labels = unique_vals
        right_values = counts
        additional_y_label = f"Number of {Notation.input_name}s"
    else:
        raise AssertionError(
            f"additional_plot must be 'limited', 'counts' or None, got {additional_plot!r}"
        )

    plot_ind_additional = np.arange(len(right_values))
    if flip:
        # flip labels and values together so they stay aligned
        right_labels, right_values = np.flip(right_labels), np.flip(right_values)
    ax2_plot, ax2_ticks, ax2_outcome_label, ax2_value_label = _setters(ax2)
    ax2_plot(
        plot_ind_additional - width * bar_slot,
        right_values,
        width,
        label=scenario_name,
    )
    ax2.bar_label(ax2.containers[-1], padding=padding)
    ax2_ticks(plot_ind_additional - (bar_slot / 2) * width, right_labels)
    ax2.legend(fontsize="large")
    ax2_outcome_label(Notation.outcome_name)
    ax2_value_label(additional_y_label)
    right_plot_data = dict(zip(right_labels, right_values))

    return error_rates, left_plot_data, right_plot_data


"""Takes a list of error rates (list of bernoullis) and generates pmf of their joint distribution
under assumption of independence.
Return type is a list of tuples where each tuple is of form (outcome, probability)."""


def calculate_independent_pmf(error_rates):
    """Joint pmf of independent Bernoulli failures.

    For every binary outcome vector (1 = model ``idx`` failed, with
    probability ``error_rates[idx]``), in ``itertools.product`` order, yield
    ``(outcome, probability)`` where probability is the product of
    ``error_rates[idx]`` for failures and ``1 - error_rates[idx]`` otherwise.
    """
    n_models = len(error_rates)
    pmf = []
    for bits in product([0, 1], repeat=n_models):
        factors = [
            error_rates[idx] if bit == 1 else 1 - error_rates[idx]
            for idx, bit in enumerate(bits)
        ]
        pmf.append((list(bits), np.prod(factors)))
    return pmf


def calculate_prob_error(error_rates):
    """Return ``(outcome, prod(outcome * error_rates))`` for every binary outcome.

    Outcomes are enumerated in ``itertools.product`` order (all zeros first);
    ``outcome`` is a list of 0/1 failure indicators. ``error_rates`` may be
    any array-like (lists are accepted as well as arrays).

    NOTE(review): the product runs over *all* models, so every outcome
    containing a 0 yields 0.0 and only the all-failure outcome gets a non-zero
    value. count_combinations / sort_col_row instead restrict the product to
    the failed models (``where=... > 0``) -- confirm which is intended.
    """
    rates = np.asarray(error_rates)
    error_pmf = []
    for seq in product([0, 1], repeat=len(rates)):
        outcome = list(seq)
        error_prob = outcome * rates  # outcome is indicator variable
        error_pmf.append((outcome, np.prod(error_prob)))
    return error_pmf


"""The count_combinations function returns the unique combinations of values in the input array R_np, their corresponding indices in the original array, and their counts.

Parameters
R_np (2D NumPy array): A 2D NumPy array representing the predictions made by different models for different samples. The rows represent the samples and the columns represent the models.
error rates: If None, will calculate error rates off of R_np. If passed in as an array, will order combinations according to these passed in error rates.

Output
unique_vals_str (1D NumPy array): Labels for every binary combination of model outcomes (observed or not), ordered by the product of the error rates of the failed models; passed through Notation.semantically_convert_list when Notation is not None.
inds (1D NumPy array): A sorted 1D NumPy array of the row indices in R_np where each distinct observed combination first occurs. Because it is sorted, it is NOT aligned with unique_vals_str.
counts (tuple): Counts aligned with unique_vals_str; combinations that never occur in R_np have a count of 0.
error rates: Returns the error rates of the models passed in, in the same column order as R_np.
"""


_DEFAULT_NOTATION = object()  # sentinel: build a fresh NotationConfig() per call


def count_combinations(R_np, error_rates=None, Notation=_DEFAULT_NOTATION):
    """Count every combination of model outcomes and order them.

    Every binary combination over the models is reported: combinations that
    occur in ``R_np`` get their observed count, all others a count of 0. They
    are ordered (stable sort) by the product of the error rates of the models
    that failed in the combination; this is only a sort key, not the
    probability of the combination.

    Parameters
    ----------
    R_np : array-like, shape (n_samples, n_models)
        0/1 outcomes (1 = failure); NaN marks a missing entry.
    error_rates : array-like, optional
        Per-model error rates used for ordering. When None they are computed
        from ``R_np`` as NaN-ignoring column means.
    Notation : NotationConfig or None, optional
        Used to convert the labels via ``semantically_convert_list``. A fresh
        ``NotationConfig()`` is created when omitted; pass None to keep the
        raw ``"[0 1 ...]"`` labels.

    Returns
    -------
    unique_vals_str : labels of all combinations, in sort order
    inds : ndarray
        Sorted first-occurrence row indices of the distinct observed rows
        (NOT aligned with ``unique_vals_str``).
    counts : tuple
        Counts aligned with ``unique_vals_str``.
    error_rates : ndarray
        Per-model error rates in ``R_np``'s column order.
    """
    if Notation is _DEFAULT_NOTATION:
        Notation = NotationConfig()

    # Work in float so NaN can always be written back below.
    R_np = np.asarray(R_np, dtype=float)

    # NaN != NaN, so np.unique cannot group rows containing NaN; substitute a
    # placeholder for the grouping and restore NaN afterwards.
    nan_placeholder = 8
    unique_vals, inds, counts = np.unique(
        np.nan_to_num(R_np, nan=nan_placeholder),
        return_index=True,
        return_counts=True,
        axis=0,
    )
    unique_vals[unique_vals == nan_placeholder] = np.nan

    if error_rates is None:
        error_rates = np.nansum(R_np, axis=0) / np.count_nonzero(
            ~np.isnan(R_np), axis=0
        )
    else:
        error_rates = np.asarray(error_rates, dtype=float)

    def _failure_rate_product(outcome):
        # I(failure) * error_rate, multiplied over the entries > 0 (NaN > 0 is
        # False, so missing entries are skipped). Only used as a sort key.
        weighted = outcome * error_rates
        return np.nanprod(weighted, where=weighted > 0)

    # Assumption that outcomes are binary
    possible_outcomes = set(product([0, 1], repeat=len(error_rates)))
    outcomes_information = []
    for outcome, count in zip(unique_vals, counts):
        possible_outcomes.discard(tuple(outcome))
        outcomes_information.append((outcome, count, _failure_rate_product(outcome)))
    # Any outcome still remaining never occurs, so its count is 0
    for remaining_outcome in possible_outcomes:
        outcome = np.array(remaining_outcome)
        outcomes_information.append((outcome, 0, _failure_rate_product(outcome)))
    outcomes_information.sort(key=itemgetter(2))
    unique_vals, counts, _ = zip(*outcomes_information)

    # String form so the combinations can be printed / used as tick labels
    unique_vals_str = np.array(
        [str(unique_val.astype(int)) for unique_val in unique_vals]
    )

    if Notation is not None:
        unique_vals_str = Notation.semantically_convert_list(unique_vals_str)

    return unique_vals_str, np.sort(inds), counts, error_rates


def stack_helper(data, index_name, column_name):
    """Stack ``data`` and name the new index level and the value column.

    ``data.stack()`` does not give the new innermost index level or the
    values the names we want, so this wrapper renames that level to
    ``index_name`` and names the values ``column_name``: the Series name for
    a Series result, or the last column for a DataFrame result.

    Assumption: the data should already have index names.
    """
    stacked_data = data.stack()
    stacked_data.index = stacked_data.index.set_names(index_name, level=-1)

    if isinstance(stacked_data, pd.Series):
        stacked_data.name = column_name
    elif isinstance(stacked_data, pd.DataFrame):
        stacked_data = stacked_data.rename(
            columns={stacked_data.columns[-1]: column_name}
        )
    return stacked_data
