import numpy as np
import matplotlib.pyplot as plt


############################################################################
# Scoring rules
#########################################################################################
def brier_score(f, y):
    """Return the negated squared error between forecast ``f`` and outcome ``y``.

    The sign is flipped so that larger values are better; the maximum, 0,
    is reached when ``f == y``.
    """
    error = f - y
    return -(error ** 2)


def log_score(f, y):
    """Return the logarithmic score of forecast ``f`` for outcome ``y``.

    ``f`` is treated as the probability assigned to ``y == 1``; every other
    value of ``y`` is scored with ``log(1 - f)``.
    """
    if y == 1:
        return np.log(f)
    return np.log(1 - f)


def utility_score(FY_A, y):
    """Return the outcome ``y`` itself as the score.

    The first argument (the forecast handed over by the expectation
    functions) is accepted only to match the ``scoring_rule(f, y)`` call
    shape and is never read.
    """
    return y


def utility_plus_brier_score(f, y):
    """Return the outcome ``y`` plus half of ``brier_score(f, y)``."""
    half_brier = brier_score(f, y) / 2
    return y + half_brier


#########################################################################################
# Functions to calculate the expected score / divergence / IPW score
# given a conditional forecast, a model of F -> A -> Y and a scoring rule S(F,y)
#########################################################################################
def score_expectation(FY_A, PY_A, PA_F, scoring_rule):
    """Return the expected score of the conditional forecast ``FY_A``.

    For each action ``a`` in ``{0, 1}`` the score of ``FY_A[a]`` is averaged
    over the binary outcome using ``PY_A[a]`` as the weight of outcome 1,
    and the result is weighted by ``PA_F(FY_A)[a]``.

    Args:
        FY_A: Indexable pair; ``FY_A[a]`` is passed to ``scoring_rule`` as
            the forecast for action ``a``.
        PY_A: Indexable pair; ``PY_A[a]`` weights outcome 1 under action
            ``a`` and ``1 - PY_A[a]`` weights outcome 0.
        PA_F: Callable mapping the full forecast ``FY_A`` to an indexable
            pair of action weights.
        scoring_rule: Callable ``scoring_rule(f, y)`` returning a number.

    Returns:
        The sum over both actions of action weight times outcome-averaged
        score.
    """
    # The mechanism depends only on the forecast, so evaluate it once
    # instead of once per action.
    action_weights = PA_F(FY_A)

    expected_score = 0
    for a in [0, 1]:
        cond_score = ((1 - PY_A[a]) * scoring_rule(FY_A[a], 0)
                      + PY_A[a] * scoring_rule(FY_A[a], 1))
        expected_score += cond_score * action_weights[a]
    return expected_score


def divergence_expectation(FY_A, PY_A, PA_F, scoring_rule):
    """Return the negated expected divergence of ``FY_A`` from ``PY_A``.

    For each action ``a`` in ``{0, 1}`` the divergence is the
    outcome-averaged score of reporting ``PY_A[a]`` minus the
    outcome-averaged score of reporting ``FY_A[a]``, both averaged with
    ``PY_A[a]`` as the weight of outcome 1. The divergences are weighted by
    ``PA_F(FY_A)[a]`` and subtracted from zero, so the value is 0 when
    ``FY_A[a] == PY_A[a]`` for every action with non-zero weight.

    Args:
        FY_A: Indexable pair of forecasts, one per action.
        PY_A: Indexable pair; ``PY_A[a]`` weights outcome 1 under action
            ``a``.
        PA_F: Callable mapping the full forecast ``FY_A`` to an indexable
            pair of action weights.
        scoring_rule: Callable ``scoring_rule(f, y)`` returning a number.

    Returns:
        Minus the action-weighted sum of per-action divergences.
    """
    # The mechanism depends only on the forecast, so evaluate it once
    # instead of once per action.
    action_weights = PA_F(FY_A)

    expected_divergence = 0
    for a in [0, 1]:
        # Outcome-averaged score obtained by reporting PY_A[a] itself.
        cond_entropy = ((1 - PY_A[a]) * scoring_rule(PY_A[a], 0)
                        + PY_A[a] * scoring_rule(PY_A[a], 1))
        # Outcome-averaged score obtained by the actual report FY_A[a].
        cond_score = ((1 - PY_A[a]) * scoring_rule(FY_A[a], 0)
                      + PY_A[a] * scoring_rule(FY_A[a], 1))
        expected_divergence -= (cond_entropy - cond_score) * action_weights[a]
    return expected_divergence


def ipw_expectation(FY_A, PY_A, PA_F, scoring_rule):
    """Return the expected IPW score of the conditional forecast ``FY_A``.

    For each action ``a`` in ``{0, 1}`` the score of ``FY_A[a]`` is averaged
    over the binary outcome using ``PY_A[a]`` as the weight of outcome 1.
    An action contributes its full outcome-averaged score when
    ``PA_F(FY_A)[a] > 0`` and nothing otherwise; the size of a positive
    weight does not matter.

    Args:
        FY_A: Indexable pair of forecasts, one per action.
        PY_A: Indexable pair; ``PY_A[a]`` weights outcome 1 under action
            ``a``.
        PA_F: Callable mapping the full forecast ``FY_A`` to an indexable
            pair of action weights.
        scoring_rule: Callable ``scoring_rule(f, y)`` returning a number.

    Returns:
        The sum of outcome-averaged scores over the actions that have
        strictly positive weight.
    """
    # The mechanism depends only on the forecast, so evaluate it once
    # instead of once per action.
    action_weights = PA_F(FY_A)

    expected_IPW_score = 0
    for a in [0, 1]:
        cond_score = ((1 - PY_A[a]) * scoring_rule(FY_A[a], 0)
                      + PY_A[a] * scoring_rule(FY_A[a], 1))
        # 0/1 indicator: only whether the action can occur matters here.
        expected_IPW_score += cond_score * int(action_weights[a] > 0)
    return expected_IPW_score


#########################################################################################
# Function to generate a 3D plot of the expected score for all possible
# forecasts, given a model for F -> A -> Y, a scoring rule, and a method to
# calculate the expected score.
#########################################################################################
def get_plot(expected_score_fn, PY_A, PA_F, scoring_rule, plot_optima=True,
             *, ngrid=400, tolerance=1e-5):
    """Draw a 3D surface of an expected score over a grid of forecast pairs.

    Every forecast pair ``[f_a0, f_a1]`` on an ``ngrid`` x ``ngrid`` lattice
    over the unit square is passed to
    ``expected_score_fn(FY_A, PY_A, PA_F, scoring_rule)``. The x axis shows
    the forecast for A=1 and the y axis the forecast for A=0, matching the
    axis labels.

    Args:
        expected_score_fn: Callable with the signature
            ``(FY_A, PY_A, PA_F, scoring_rule)`` returning a number.
        PY_A: Indexable pair; forwarded to ``expected_score_fn`` and also
            used to position the "Correct forecast" markers.
        PA_F: Forwarded unchanged to ``expected_score_fn``.
        scoring_rule: Forwarded unchanged to ``expected_score_fn``.
        plot_optima: When true, mark every grid point whose value is within
            ``tolerance`` of the surface maximum, plus grey projections of
            those points onto the floor and the two back walls.
        ngrid: Number of grid points per axis. The cost is ``ngrid ** 2``
            calls to ``expected_score_fn``.
        tolerance: Absolute tolerance used to decide which grid points count
            as global optima (so that plateaus are marked as a whole).

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current
        figure, so that callers can call ``savefig`` / ``show`` on it.

    Raises:
        ValueError: If ``ngrid`` is smaller than 2.
    """
    if ngrid < 2:
        raise ValueError("ngrid must be at least 2 to span the unit square")

    fig = plt.figure(figsize=(5, 3.5))
    ax = fig.add_subplot(111, projection='3d', computed_zorder=False)

    grid = np.linspace(0, 1, ngrid)
    # Default 'xy' meshgrid indexing: x_a1[row, col] == grid[col] and
    # y_a0[row, col] == grid[row].
    x_a1, y_a0 = np.meshgrid(grid, grid)

    # surface[row, col] is the expected score for FY_A = [grid[row],
    # grid[col]], i.e. rows follow the A=0 forecast (y axis) and columns
    # follow the A=1 forecast (x axis).
    surface = np.zeros_like(x_a1)
    for row, f_a0 in enumerate(grid):
        for col, f_a1 in enumerate(grid):
            surface[row, col] = expected_score_fn(
                [f_a0, f_a1], PY_A, PA_F, scoring_rule)

    z_floor = surface.min()
    z_peak = surface.max()

    ax.plot_surface(x_a1, y_a0, surface,
                    cmap='coolwarm',
                    edgecolor='none', alpha=0.8, zorder=0)

    # Mark the global optima and their projections on the floor / back walls.
    if plot_optima:
        rows, cols = np.nonzero(np.abs(surface - z_peak) < tolerance)
        opt_x = x_a1[rows, cols]
        opt_y = y_a0[rows, cols]
        opt_z = np.full(len(rows), z_peak)
        ax.scatter(opt_x, opt_y, opt_z,
                   color='blue', s=2, marker='o', label='Global optima', zorder=1)
        ax.scatter(opt_x, opt_y, z_floor,
                   color='grey', s=2, marker='o', label='Global optima', zorder=0)
        ax.scatter(opt_x, 1, opt_z,
                   color='grey', s=2, marker='o', label='Global optima', zorder=0)
        ax.scatter(1, opt_y, opt_z,
                   color='grey', s=2, marker='o', label='Global optima', zorder=0)

    # Mark the truthful report; its height is read from the nearest grid node.
    i = np.argmin(np.abs(grid - PY_A[0]))
    j = np.argmin(np.abs(grid - PY_A[1]))
    truthful_z = surface[i, j]
    ax.scatter(PY_A[1], PY_A[0], truthful_z, color='green',
               s=30, marker='o', alpha=1, label='Correct forecast', zorder=0)
    ax.scatter(PY_A[1], PY_A[0], z_floor, color='dimgrey',
               s=30, marker='o', alpha=1, label='Correct forecast', zorder=-1)
    ax.scatter(PY_A[1], 1, truthful_z, color='dimgrey', s=30,
               marker='o', alpha=1, label='Correct forecast', zorder=-1)
    ax.scatter(1, PY_A[0], truthful_z, color='dimgrey', s=30,
               marker='o', alpha=1, label='Correct forecast', zorder=-1)

    # Add figure labels
    ax.set_xlabel('F(Y=1|A=1)', labelpad=15)
    ax.set_ylabel('F(Y=1|A=0)', labelpad=15)
    ax.set_zlabel('Expected score', labelpad=15)
    ax.set_zlim(z_floor, z_peak + 0.05)
    ax.view_init(elev=16, azim=-142)
    # ax.legend(fontsize=8, loc='upper right')
    plt.tight_layout(pad=0.1)
    return plt


#########################################################################################
# Various mechanisms for P(A | F)
#########################################################################################
# PY_A[a] is the weight the expectation functions put on outcome Y=1 under
# action a (a in {0, 1}); it is the model every plot below is scored against.
PY_A = [0.5, 0.25]
# Othman and Sandholm (2010), EXAMPLE OF NOT COUNTERFACTUALLY PROPER
def PA_F_argmax(F):
    """Put all weight on the action with the largest forecast.

    Returns ``[0, 1]`` when ``np.argmax(F)`` is 1 and ``[1, 0]`` otherwise,
    so a tie between the first two entries selects action 0.
    """
    if np.argmax(F) == 1:
        return [0, 1]
    return [1, 0]


# OUR EXAMPLE OF NOT OBSERVATIONALLY PROPER NOR COUNTERFACTUALLY PROPER
def PA_F_self_defeating(F):
    """Choose action 1 only when ``F[1] >= 0.4`` and ``F[0] <= 0.4``.

    Returns ``[0, 1]`` in that case and ``[1, 0]`` in every other case.
    """
    second_high_enough = F[1] >= 0.4
    picks_action_one = second_high_enough and F[0] <= 0.4
    if picks_action_one:
        return [0, 1]
    return [1, 0]


# Weight PA_F_mixture gives to PA_F_unif; the remaining 1 - eps goes to
# PA_F_self_defeating.
eps = 1/3
def PA_F_unif(F):
    """Ignore ``F`` and give both actions weight 1/2."""
    half = 1 / 2
    return [half, half]


def PA_F_mixture(F):
    """Blend the uniform and the self-defeating mechanisms.

    Returns ``eps * PA_F_unif(F) + (1 - eps) * PA_F_self_defeating(F)``
    element-wise as a NumPy array; the module-level ``eps`` is read at call
    time.
    """
    uniform_part = np.multiply(eps, PA_F_unif(F))
    self_defeating_part = np.multiply((1 - eps), PA_F_self_defeating(F))
    return np.add(uniform_part, self_defeating_part)

#########################################################################################
# Example of not proper and not observationally proper scoring rules
#########################################################################################
save_figs = False

# One row per figure, rendered in this order:
# (expected-score function, mechanism P(A | F), scoring rule,
#  plot_optima flag, output path used when save_figs is on)
_FIGURE_SPECS = [
    # Example of not proper and not observationally proper scoring rules
    (score_expectation, PA_F_argmax, brier_score, True,
     "figures/1a_expected_brier_score_plot_not_proper.pdf"),
    (score_expectation, PA_F_self_defeating, brier_score, True,
     "figures/1b_expected_brier_score_plot_self_defeating.pdf"),
    (score_expectation, PA_F_mixture, brier_score, True,
     "figures/1c_expected_brier_score_plot_self_defeating_positivity.pdf"),

    # Utility score is proper, and observationally strictly proper when
    # adding delta*Brier
    (score_expectation, PA_F_argmax, utility_score, False,
     "figures/2_expected_utility_score_plot.pdf"),
    (score_expectation, PA_F_argmax, utility_plus_brier_score, True,
     "figures/2b_expected_utility_plus_brier_score_plot.pdf"),

    # Expected score, divergence and IPW score WITHOUT positivity
    (score_expectation, PA_F_self_defeating, brier_score, True,
     "figures/3a_expected_brier_score_plot_no_positivity.pdf"),
    # Divergence score is empirically proper
    (divergence_expectation, PA_F_self_defeating, brier_score, True,
     "figures/3b_expected_brier_divergence_plot_no_positivity.pdf"),
    # IPW score is proper
    (ipw_expectation, PA_F_self_defeating, brier_score, True,
     "figures/3c_expected_brier_ipw_score_plot_no_positivity.pdf"),

    # Expected score, divergence and IPW score WITH positivity
    (score_expectation, PA_F_mixture, brier_score, True,
     "figures/4a_expected_brier_score_plot_positivity.pdf"),
    # Divergence score is empirically proper
    (divergence_expectation, PA_F_mixture, brier_score, True,
     "figures/4b_expected_brier_divergence_plot_positivity.pdf"),
    # IPW score is proper
    (ipw_expectation, PA_F_mixture, brier_score, True,
     "figures/4c_expected_brier_ipw_score_plot_positivity.pdf"),
]

for _score_fn, _mechanism, _rule, _show_optima, _out_path in _FIGURE_SPECS:
    # get_plot opens a new figure and hands back the pyplot module.
    plt = get_plot(_score_fn, PY_A, _mechanism, _rule,
                   plot_optima=_show_optima)
    if save_figs:
        plt.savefig(_out_path, format='pdf', dpi=300, bbox_inches='tight')

plt.show()
