import os

import matplotlib.pyplot as plt
import numpy as np


############################################################################
# Scoring rule
#########################################################################################
def brier_score(f, y):
    """Return the negated squared difference between forecast `f` and outcome `y`.

    Works element-wise when `f` and/or `y` are NumPy arrays.
    """
    error = f - y
    return -(error ** 2)


#########################################################################################
# IPW
#########################################################################################
def ipw_expectation(FY_A, PY_A, PA_F, scoring_rule):
    """Return the expected value of the IPW score over actions 0 and 1.

    Args:
        FY_A: callable mapping an action to the forecast passed to `scoring_rule`.
        PY_A: callable mapping an action to the weight given to outcome 1
            (outcome 0 is weighted by one minus that value).
        PA_F: callable taking `FY_A` and returning a sequence indexable by action.
        scoring_rule: callable `(forecast, outcome) -> score`.

    Returns:
        The sum over both actions of the negated conditional expected score,
        counting an action only when `PA_F(FY_A)[action]` is strictly positive.
    """
    total = 0
    for action in (0, 1):
        score_given_y0 = (1 - PY_A(action)) * scoring_rule(FY_A(action), 0)
        score_given_y1 = PY_A(action) * scoring_rule(FY_A(action), 1)
        # Indicator that the action is taken with positive probability.
        is_reachable = int(PA_F(FY_A)[action] > 0)
        total -= (score_given_y0 + score_given_y1) * is_reachable
    return total


def ipw_estimator(A, Y, F, scoring_rule):
    """Estimate the IPW score from one sample of actions `A` and outcomes `Y`.

    Args:
        A: NumPy array of actions; entries equal to 0 are weighted by
            `1 - A.mean()`, all other entries by `A.mean()`.
        Y: array of outcomes aligned with `A`.
        F: callable mapping the action array to forecasts.
        scoring_rule: callable `(forecast, outcome) -> score`.

    Returns:
        The mean of the negated score divided by the empirical frequency of
        the action that was observed.
    """
    share_of_ones = A.mean()
    # Empirical probability of the action actually observed in each row.
    propensity = np.where(A == 0, 1 - share_of_ones, share_of_ones)
    weighted_loss = -scoring_rule(F(A), Y) / propensity
    return np.mean(weighted_loss)

#########################################################################################
# Divergence
#########################################################################################
def divergence_expectation(FY_A, PY_A, PA_F, scoring_rule):
    """Return the expected divergence between `PY_A` and the forecast `FY_A`.

    Args:
        FY_A: callable mapping an action to the forecast being evaluated.
        PY_A: callable mapping an action to the weight given to outcome 1;
            it is also scored as a forecast itself.
        PA_F: callable taking `FY_A` and returning per-action weights.
        scoring_rule: callable `(forecast, outcome) -> score`.

    Returns:
        The sum over actions 0 and 1 of (expected score of `PY_A(a)` minus
        expected score of `FY_A(a)`), weighted by `PA_F(FY_A)[a]`.
    """
    total = 0
    for action in (0, 1):
        # Expected score obtained when forecasting PY_A(action) itself.
        score_of_truth = ((1 - PY_A(action)) * scoring_rule(PY_A(action), 0)
                          + PY_A(action) * scoring_rule(PY_A(action), 1))
        # Expected score obtained by the forecast under evaluation.
        score_of_forecast = ((1 - PY_A(action)) * scoring_rule(FY_A(action), 0)
                             + PY_A(action) * scoring_rule(FY_A(action), 1))
        gap = score_of_truth - score_of_forecast
        total += gap * PA_F(FY_A)[action]
    return total

def divergence_estimator(A, Y, F, scoring_rule):
    """Plug-in estimate of the divergence from one sample.

    The mean of `Y` within each action group stands in for the unknown
    conditional probability; the result is the mean difference between the
    score of that plug-in forecast and the score of `F(A)`.

    Args:
        A: NumPy array of actions (rows with `A == 0` use the group-0 mean,
            every other row uses the group-1 mean).
        Y: NumPy array of outcomes aligned with `A`.
        F: callable mapping the action array to forecasts.
        scoring_rule: callable `(forecast, outcome) -> score`.
    """
    is_action0 = A == 0
    is_action1 = A == 1

    # A group with no observations gets NaN; np.where below never selects it
    # for rows that belong to the other group.
    if np.any(is_action0):
        rate_action0 = Y[is_action0].mean()
    else:
        rate_action0 = np.nan
    if np.any(is_action1):
        rate_action1 = Y[is_action1].mean()
    else:
        rate_action1 = np.nan

    plug_in_forecast = np.where(is_action0, rate_action0, rate_action1)
    score_gap = scoring_rule(plug_in_forecast, Y) - scoring_rule(F(A), Y)
    return np.mean(score_gap)

def divergence_estimator_unbiased(A, Y, F, scoring_rule):
    """Divergence estimate with a variance correction for the Brier score.

    For `scoring_rule == brier_score`, each row contributes
    `(phat - F(A))**2 - phat * (1 - phat) / (N - 1)`, where `phat` is the mean
    of `Y` within the row's action group and `N` is that group's size. Rows
    whose group has fewer than two observations contribute 0. The subtracted
    term is presumably what makes the estimator unbiased, as the function
    name suggests. For any other scoring rule the plain plug-in estimate is
    returned.

    Args:
        A: NumPy array of actions (rows with `A == 1` form group 1, all
            other rows form group 0).
        Y: NumPy array of outcomes aligned with `A`.
        F: callable mapping the action array to forecasts.
        scoring_rule: callable `(forecast, outcome) -> score`.

    Returns:
        The mean of the per-row contributions described above.
    """
    P_Y1_given_A0 = Y[A == 0].mean() if np.any(A == 0) else np.nan
    P_Y1_given_A1 = Y[A == 1].mean() if np.any(A == 1) else np.nan
    phat = np.where(A == 1, P_Y1_given_A1, P_Y1_given_A0)

    if scoring_rule == brier_score:
        n = len(Y)
        # Vectorised count instead of the builtin sum() iterating in Python.
        n_action1 = np.count_nonzero(A == 1)
        group_size = np.where(A == 1, n_action1, n - n_action1)
        has_pair = group_size >= 2
        # np.where evaluates both branches, so dividing by (group_size - 1)
        # directly would compute 0/0 for single-observation groups and emit a
        # RuntimeWarning. Those rows are masked to 0 below, so any non-zero
        # placeholder denominator leaves the result unchanged.
        denominator = np.where(has_pair, group_size - 1, 1)
        variance_correction = phat * (1 - phat) / denominator
        return np.mean(np.where(has_pair, (phat - F(A))**2 - variance_correction, 0))

    return np.mean(scoring_rule(phat, Y) - scoring_rule(F(A), Y))

#########################################################################################
# Function to simulate data, estimate the score and return the 5th percentile, mean and 95th percentile over m runs
#########################################################################################
def get_percentiles(F, PY_A, PA_F, estimator, scoring_rule, sample_sizes, m):
    """Simulate `m` datasets per sample size and summarise the estimator.

    For every `n` in `sample_sizes`, actions are drawn as Bernoulli variables
    with success probability `PA_F(F)[1]` and outcomes as Bernoulli variables
    with success probability `PY_A(A)`, both via NumPy's global random state.
    `estimator(A_row, Y_row, F, scoring_rule)` is then evaluated on each of
    the `m` simulated rows.

    Returns:
        A tuple `(lower, centre, upper)` of lists with one entry per sample
        size: the 5th percentile, the mean, and the 95th percentile of the
        `m` estimates.
    """
    lower, centre, upper = [], [], []
    for n in sample_sizes:
        A = np.random.binomial(n=1, p=PA_F(F)[1], size=(m, n))
        Y = np.random.binomial(n=1, p=PY_A(A))
        estimates = [estimator(A[run], Y[run], F, scoring_rule) for run in range(m)]
        quantiles = np.percentile(estimates, [5, 50, 95])
        lower.append(quantiles[0])
        # The central curve is the mean of the estimates, not the median
        # (quantiles[1] is computed but not used).
        centre.append(np.mean(estimates))
        upper.append(quantiles[2])
    return (lower, centre, upper)

#########################################################################################
# Function to generate plot with estimated and expected scores of a correct and incorrect forecast
#########################################################################################
def get_plot(estimate_percentiles_correct, expected_correct, estimate_percentiles_incorrect, expected_incorrect, ylabel):
    """Plot estimated and expected scores for the correct and incorrect forecast.

    Each `estimate_percentiles_*` argument is indexed as `[0]` (lower band),
    `[1]` (central curve) and `[2]` (upper band). The x-values come from the
    module-level `sample_sizes`, which is not a parameter of this function.
    The correct forecast is drawn in green, the incorrect one in red; the
    expected values are dashed horizontal lines. Both series use the same
    legend labels, so each label appears twice in the legend.

    Returns:
        The `matplotlib.pyplot` module, with the new figure as current figure.
    """
    plt.figure(figsize=(4, 4))

    series = (
        (estimate_percentiles_correct, expected_correct, 'green', 'darkgreen'),
        (estimate_percentiles_incorrect, expected_incorrect, 'red', 'darkred'),
    )
    for percentiles, expected, curve_colour, reference_colour in series:
        plt.plot(sample_sizes, percentiles[1], label="Estimated score", color=curve_colour)
        plt.fill_between(sample_sizes, percentiles[0], percentiles[2], color=curve_colour, alpha=0.2)
        plt.axhline(expected, color=reference_colour, linestyle='--', label="True score")

    # Fixed y-range and logarithmic x-axis, identical for every figure.
    plt.ylim(-0.2, 0.9)
    plt.xlabel("Sample size (n)")
    plt.xscale('log')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    return plt


#########################################################################################
# Various mechanisms for P_M(A | F)
#########################################################################################
# Othman and Sandholm (2010), EXAMPLE OF NOT COUNTERFACTUALLY PROPER
def PA_F_argmax(F):
    """Return action weights that put all mass on the action with the largest forecast.

    Args:
        F: either a callable forecast (evaluated at actions 0 and 1, like the
            other mechanisms in this file) or an array-like of forecasts
            indexed by action.

    Returns:
        `[0, 1]` when the forecast for action 1 is strictly the largest,
        otherwise `[1, 0]` (ties go to action 0, as `np.argmax` returns the
        first maximum).
    """
    # np.argmax on a callable never evaluates it, so the result would always
    # be [1, 0]; evaluate the forecast at both actions first.
    forecasts = [F(0), F(1)] if callable(F) else F
    return [0, 1] if (np.argmax(forecasts) == 1) else [1, 0]
# OUR EXAMPLE THAT IS NEITHER OBSERVATIONALLY PROPER NOR COUNTERFACTUALLY PROPER
def PA_F_self_defeating(F):
    """Return `[0, 1]` when `F(1) >= 0.4` and `F(0) <= 0.4`, otherwise `[1, 0]`.

    `F` is a callable forecast evaluated at actions 1 and 0.
    """
    if F(1) >= 0.4 and F(0) <= 0.4:
        return [0, 1]
    return [1, 0]
# Weight of the uniform component in PA_F_mixture.
eps = 1/3


def PA_F_unif(F):
    """Return equal weights for actions 0 and 1, ignoring the forecast `F`."""
    return [1/2, 1/2]


def PA_F_mixture(F):
    """Mix the uniform and self-defeating mechanisms with weights `eps` and `1 - eps`.

    Returns a NumPy array holding `eps * PA_F_unif(F) + (1 - eps) *
    PA_F_self_defeating(F)`, computed element-wise.
    """
    uniform_part = np.multiply(eps, PA_F_unif(F))
    self_defeating_part = np.multiply((1-eps), PA_F_self_defeating(F))
    return np.add(uniform_part, self_defeating_part)

#########################################################################################
# Run experiments
#########################################################################################
# Number of simulated datasets per sample size (passed as `m` to get_percentiles).
m = 10000
# Sample sizes n to simulate; get_plot also reads this module-level name for its x-values.
sample_sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 21, 46]

# Seed NumPy's global random state so the simulations below are reproducible.
np.random.seed(10)
def PY_A(a):
    """Return 0.5 where `a == 0` and 0.25 elsewhere.

    Used below as the Bernoulli success probability when sampling Y given A.
    """
    return np.where(a == 0, 0.5, 0.25)


def F_correct(a):
    """Forecast that coincides with `PY_A`."""
    return PY_A(a)


def F_incorrect(a):
    """Forecast that returns 0.7 where `a == 0` and 0.45 elsewhere."""
    return np.where(a == 0, 0.7, 0.45)

#########################################################################################
## Example without positivity
#########################################################################################
# Dhat_correct = get_percentiles(F_correct, PY_A, PA_F_argmax, divergence_estimator, brier_score, sample_sizes, m)
# expected_divergence_correct = divergence_expectation(F_correct, PY_A, PA_F_argmax, brier_score)
# Dhat_incorrect = get_percentiles(F_incorrect, PY_A, PA_F_argmax, divergence_estimator, brier_score, sample_sizes, m)
# expected_divergence_incorrect = divergence_expectation(F_incorrect, PY_A, PA_F_argmax, brier_score)
# plt = get_plot(Dhat_correct, expected_divergence_correct, Dhat_incorrect, expected_divergence_incorrect,
#                "Estimated and true divergence for correct (green) and incorrect (red) forecast")

#########################################################################################
## Example with positivity, divergence, biased estimator
#########################################################################################
# Simulate the plug-in divergence estimator under the mixture mechanism for
# both forecasts, and compute the matching expected divergences for reference.
Dhat_correct = get_percentiles(F_correct, PY_A, PA_F_mixture, divergence_estimator, brier_score, sample_sizes, m)
expected_divergence_correct = divergence_expectation(F_correct, PY_A, PA_F_mixture, brier_score)
Dhat_incorrect = get_percentiles(F_incorrect, PY_A, PA_F_mixture, divergence_estimator, brier_score, sample_sizes, m)
expected_divergence_incorrect = divergence_expectation(F_incorrect, PY_A, PA_F_mixture, brier_score)
plt = get_plot(Dhat_correct, expected_divergence_correct, Dhat_incorrect, expected_divergence_incorrect, "Divergence")
# savefig does not create missing directories; make sure the target exists so
# the simulation results are not lost to a FileNotFoundError.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/app_variance_of_divergence.pdf", format='pdf', dpi=300, bbox_inches='tight')

#########################################################################################
## Example with positivity, divergence, unbiased estimator
#########################################################################################
# Same experiment as for the plug-in estimator, but using
# divergence_estimator_unbiased.
Dhat_correct = get_percentiles(F_correct, PY_A, PA_F_mixture, divergence_estimator_unbiased, brier_score, sample_sizes, m)
expected_divergence_correct = divergence_expectation(F_correct, PY_A, PA_F_mixture, brier_score)
Dhat_incorrect = get_percentiles(F_incorrect, PY_A, PA_F_mixture, divergence_estimator_unbiased, brier_score, sample_sizes, m)
expected_divergence_incorrect = divergence_expectation(F_incorrect, PY_A, PA_F_mixture, brier_score)
plt = get_plot(Dhat_correct, expected_divergence_correct, Dhat_incorrect, expected_divergence_incorrect, "Divergence")
# savefig does not create missing directories; make sure the target exists so
# the simulation results are not lost to a FileNotFoundError.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/app_variance_of_divergence_unbiased.pdf", format='pdf', dpi=300, bbox_inches='tight')

#########################################################################################
## Example with positivity, IPW
#########################################################################################
# Simulate the IPW estimator under the mixture mechanism for both forecasts,
# and compute the matching expected IPW scores for reference.
IPWhat_correct = get_percentiles(F_correct, PY_A, PA_F_mixture, ipw_estimator, brier_score, sample_sizes, m)
expected_ipw_correct = ipw_expectation(F_correct, PY_A, PA_F_mixture, brier_score)
IPWhat_incorrect = get_percentiles(F_incorrect, PY_A, PA_F_mixture, ipw_estimator, brier_score, sample_sizes, m)
expected_ipw_incorrect = ipw_expectation(F_incorrect, PY_A, PA_F_mixture, brier_score)
plt = get_plot(IPWhat_correct, expected_ipw_correct, IPWhat_incorrect, expected_ipw_incorrect, "IPW score")
# savefig does not create missing directories; make sure the target exists so
# the simulation results are not lost to a FileNotFoundError.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/app_variance_of_IPW.pdf", format='pdf', dpi=300, bbox_inches='tight')
# Display the figures created so far.
plt.show()
