import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from math import sqrt
from scipy.integrate import quad
from scipy.stats import norm
from scipy import stats


def ETC(T, mu1, mu2, sigma1, sigma2, m=4):
    """Simulate one explore-then-commit run on two Gaussian arms.

    Both arms are first sampled ``int(T / 2)`` times.  The arm whose
    exploration sample mean is strictly larger (arm 2 on ties) then receives
    ``int(m * T)`` additional draws, which are pooled with its exploration
    draws before the final mean is taken.  The other arm's estimate uses its
    exploration draws only.

    Parameters
    ----------
    T : int or float
        Horizon parameter; controls both phase lengths.
    mu1, mu2 : float
        Means of arm 1 and arm 2.
    sigma1, sigma2 : float
        Standard deviations of arm 1 and arm 2.
    m : int or float, optional
        Multiplier for the commit-phase sample size (``int(m * T)`` draws).

    Returns
    -------
    theta_hat : float
        Sample mean of every draw taken from arm 1.
    eta_hat : float
        Sample mean of every draw taken from arm 2.
    mode : int
        1 if arm 1 received the commit-phase draws, 2 otherwise.
    """
    n_explore = int(T / 2)
    # Cast explicitly: NumPy rejects a float ``size``, so a fractional ``m``
    # (or a float ``T``) would otherwise raise TypeError.
    n_commit = int(m * T)

    Y11 = np.random.normal(mu1, sigma1, n_explore)
    Y12 = np.random.normal(mu2, sigma2, n_explore)

    if np.mean(Y11) > np.mean(Y12):
        # Arm 1 looked better: commit to it and pool its draws.
        Y2 = np.random.normal(mu1, sigma1, n_commit)
        theta_hat, eta_hat = np.mean(np.concatenate([Y11, Y2])), np.mean(Y12)
        mode = 1
    else:
        # Arm 2 looked at least as good: commit to it and pool its draws.
        Y2 = np.random.normal(mu2, sigma2, n_commit)
        theta_hat, eta_hat = np.mean(Y11), np.mean(np.concatenate([Y12, Y2]))
        mode = 2

    return theta_hat, eta_hat, mode


def test(T, theta_hat, mu10, eta_hat, alpha=0.1, M=1000):
    """Monte Carlo upper-tail probability of ``theta_hat`` under a plug-in null.

    Runs ``ETC`` ``M`` times with arm means ``(mu10, eta_hat)``, the
    module-level ``sigma1`` / ``sigma2`` and ``m=4``, keeping the arm-1
    estimate from each run.  Returns the fraction of those simulated
    estimates that are strictly greater than ``theta_hat``.

    ``alpha`` is accepted but not used here: the caller compares the returned
    fraction against its own thresholds.
    """
    null_draws = np.empty(M)
    for k in range(M):
        null_draws[k] = ETC(
            T, mu1=mu10, mu2=eta_hat, sigma1=sigma1, sigma2=sigma2, m=4
        )[0]
    return np.mean(null_draws > theta_hat)


# Simulation settings: horizon and the true arm parameters (both arms share
# the same mean and standard deviation).
T = 5000
mu1 = mu2 = 0
sigma1 = sigma2 = 1

# Type I error study: draw the data N times under the true parameters and
# record, for each nuisance choice, the Monte Carlo tail probability returned
# by `test` when the hypothesised mu1 equals the true mu1.
N = 1000
m = 1  # NOTE(review): not read anywhere in the visible script; ETC is called with m=4.
rejects, rejects_biased_1, rejects_biased_2, rejects_biased_3 = [], [], [], []
modes, eta_hat_vals, theta_hat_vals = [], [], []

for i in tqdm(range(N)):
    theta_hat, eta_hat, mode = ETC(T, mu1, mu2, sigma1, sigma2, m=4)
    theta_hat_vals.append(theta_hat)
    eta_hat_vals.append(eta_hat)
    modes.append(mode)

    # NOTE(review): 500 / 1000 are hard-coded and do not follow from T and m
    # as used by ETC above -- confirm these are the intended arm-2 counts.
    n2 = 1000 if mode == 2 else 500
    root_n2 = np.sqrt(n2)
    shift_loglog = np.max([0, np.log10(np.log10(n2)) / root_n2])
    shift_log = np.max([0, np.log10(n2) / root_n2])

    # Keep this call order: every `test` call consumes the global RNG stream.
    rejects.append(test(T, theta_hat, mu1, eta_hat))
    rejects_biased_1.append(test(T, theta_hat, mu1, eta_hat + shift_loglog))
    rejects_biased_2.append(test(T, theta_hat, mu1, eta_hat + shift_log))
    rejects_biased_3.append(test(T, theta_hat, mu1, 1))


# Two-sided rejection rate on a grid of nominal levels: a run counts as a
# rejection at level a when its tail probability is <= a/2 or >= 1 - a/2.
alpha = np.linspace(0, 1, 1000)
lower_cut = (alpha / 2)[:, np.newaxis]
upper_cut = (1 - alpha / 2)[:, np.newaxis]
n_runs = len(rejects)  # every curve is normalised by the same run count

# Broadcasting compares each level (rows) with each run (columns); summing
# over columns gives the rejection count per level.
(
    rejection_prob,
    rejection_prob_biased_1,
    rejection_prob_biased_2,
    rejection_prob_biased_3,
) = (
    (
        np.sum(np.array(p_vals) <= lower_cut, axis=1)
        + np.sum(np.array(p_vals) >= upper_cut, axis=1)
    ) / n_runs
    for p_vals in (rejects, rejects_biased_1, rejects_biased_2, rejects_biased_3)
)



import seaborn as sns

# Type I error figure: realised minus nominal rejection rate for the three
# shifted-nuisance variants, each with a normal-approximation band.  The
# plug-in curve (`rejection_prob`) is computed above but not drawn here.
sns.set_context("talk")
sns.set_style("whitegrid")
fig, axes = plt.subplots(1, 1, figsize=(18, 5))
plt.axhline(0, color="black", linestyle="--")

z_band = stats.norm.ppf(0.95)  # 0.95 standard-normal quantile
curves = (
    (rejection_prob_biased_1, "blue", "Simulation (Bias 1)"),
    (rejection_prob_biased_2, "green", "Simulation (Bias 2)"),
    (rejection_prob_biased_3, "purple", "Simulation (Bias 3)"),
)

for rate, colour, tag in curves:
    plt.plot(alpha, rate - alpha, color=colour, label=tag)

plt.xlabel("Nominal Type I Error Rate")

for rate, colour, _ in curves:
    # Binomial standard error of the estimated rate, scaled by z_band.
    half_width = z_band * np.sqrt(rate * (1 - rate) / len(rejects))
    plt.fill_between(
        alpha,
        rate - alpha - half_width,
        rate - alpha + half_width,
        color=colour,
        alpha=0.2,
    )

plt.ylabel("Realized - Nominal")
plt.legend()
plt.tight_layout()
plt.savefig('type_i_error.png', format='png', dpi=300)


# Power study: data are still drawn under (mu1, mu2), but each `test` call
# evaluates a fixed hypothesised value of mu1 instead of the true one.
N = 1000
m = 1  # NOTE(review): not read anywhere in the visible script; ETC is called with m=4.
rejects, rejects_biased_1, rejects_biased_2, rejects_biased_3 = [], [], [], []
modes, eta_hat_vals, theta_hat_vals = [], [], []

for i in tqdm(range(N)):
    theta_hat, eta_hat, mode = ETC(T, mu1, mu2, sigma1, sigma2, m=4)
    theta_hat_vals.append(theta_hat)
    eta_hat_vals.append(eta_hat)
    modes.append(mode)

    # NOTE(review): 500 / 1000 are hard-coded and do not follow from T and m
    # as used by ETC above -- confirm these are the intended arm-2 counts.
    n2 = 1000 if mode == 2 else 500
    root_n2 = np.sqrt(n2)
    shift_loglog = np.max([0, np.log10(np.log10(n2)) / root_n2])
    shift_log = np.max([0, np.log10(n2) / root_n2])

    # NOTE(review): the plug-in test uses 0.1 while the shifted variants use
    # 0.02 -- confirm the difference is intended.
    # Keep this call order: every `test` call consumes the global RNG stream.
    rejects.append(test(T, theta_hat, 0.1, eta_hat))
    rejects_biased_1.append(test(T, theta_hat, 0.02, eta_hat + shift_loglog))
    rejects_biased_2.append(test(T, theta_hat, 0.02, eta_hat + shift_log))
    rejects_biased_3.append(test(T, theta_hat, 0.02, 1))

# Two-sided rejection rate on a grid of nominal levels: a run counts as a
# rejection at level a when its tail probability is <= a/2 or >= 1 - a/2.
alpha = np.linspace(0, 1, 1000)
lower_cut = (alpha / 2)[:, np.newaxis]
upper_cut = (1 - alpha / 2)[:, np.newaxis]
n_runs = len(rejects)  # every curve is normalised by the same run count

# Broadcasting compares each level (rows) with each run (columns); summing
# over columns gives the rejection count per level.
(
    rejection_prob,
    rejection_prob_biased_1,
    rejection_prob_biased_2,
    rejection_prob_biased_3,
) = (
    (
        np.sum(np.array(p_vals) <= lower_cut, axis=1)
        + np.sum(np.array(p_vals) >= upper_cut, axis=1)
    ) / n_runs
    for p_vals in (rejects, rejects_biased_1, rejects_biased_2, rejects_biased_3)
)



import seaborn as sns

# Power figure: rejection rate against nominal level for the three
# shifted-nuisance variants, each with a normal-approximation band.  The
# plug-in curve (`rejection_prob`) is computed above but not drawn here.
sns.set_context("talk")
sns.set_style("whitegrid")
fig, axes = plt.subplots(1, 1, figsize=(18, 5))
plt.axhline(0, color="black", linestyle="--")

z_band = stats.norm.ppf(0.95)  # 0.95 standard-normal quantile
curves = (
    (rejection_prob_biased_1, "blue", "Simulation (Bias 1)"),
    (rejection_prob_biased_2, "green", "Simulation (Bias 2)"),
    (rejection_prob_biased_3, "purple", "Simulation (Bias 3)"),
)

for rate, colour, tag in curves:
    plt.plot(alpha, rate, color=colour, label=tag)

plt.xlabel("Nominal Type I Error Rate")

for rate, colour, _ in curves:
    # Binomial standard error of the estimated rate, scaled by z_band.
    half_width = z_band * np.sqrt(rate * (1 - rate) / len(rejects))
    plt.fill_between(
        alpha,
        rate - half_width,
        rate + half_width,
        color=colour,
        alpha=0.2,
    )

plt.ylabel("Power")
plt.legend()
plt.tight_layout()
plt.savefig('power_bias.png', format='png', dpi=300)





### Testing robustness of sample mean against plug-in nuisances
### Get the 0.1 quantile of the arm-1 sample mean under the true simulations
# NOTE(review): the original read `theta_hat_vals_true`, which is not defined
# anywhere in the visible script (NameError).  `theta_hat_vals` -- the arm-1
# estimates collected by the most recent simulation loop above -- is assumed
# to be the intended series; confirm.
test_stat = theta_hat_vals
# Keep the quantile itself: `print` returns None, so assigning its result
# discarded the value.
true_quantile = np.quantile(test_stat, 0.1)
print(true_quantile)

def get_plugin_quantile(T, theta_hat, mu10, eta_hat, alpha=0.1, M=1000):
    """Monte Carlo upper-tail probability of ``theta_hat`` under a plug-in null.

    Runs ``ETC`` ``M`` times with arm means ``(mu10, eta_hat)``, the
    module-level ``sigma1`` / ``sigma2`` and ETC's default ``m``, and returns
    the fraction of simulated arm-1 estimates strictly greater than
    ``theta_hat`` (an empirical ``1 - F(theta_hat)``).

    ``alpha`` is accepted but not used.
    """
    # Convert to an array before comparing: ``scalar < list`` only broadcasts
    # when the scalar is a NumPy type, and raises TypeError for a plain float.
    theta_hat_sims = np.array(
        [
            ETC(T, mu1=mu10, mu2=eta_hat, sigma1=sigma1, sigma2=sigma2)[0]
            for _ in range(M)
        ]
    )
    return (theta_hat < theta_hat_sims).mean()



# Settings for the plug-in calibration study (mu2, sigma1 and sigma2 keep
# their earlier module-level values).
T = 1000
N = 50000
mu1 = 0

# For each data set simulated under the true parameters, record both
# estimates, the committed arm, and the plug-in tail probability of theta_hat.
quantiles, theta_hats, eta_hats, mode_list = [], [], [], []
for i in tqdm(range(N)):
    theta_hat, eta_hat, mode = ETC(T, mu1, mu2, sigma1, sigma2)
    theta_hats.append(theta_hat)
    eta_hats.append(eta_hat)
    mode_list.append(mode)
    quantiles.append(get_plugin_quantile(T, theta_hat, mu1, eta_hat))



## Calculate the distance from the uniform distribution for this test
## statistic (original author's note: this is meant to show it fails for ETC).
uniform_samples = np.random.uniform(size=50000)
plugin_cdf_vals = 1 - np.array(quantiles)

plt.hist(plugin_cdf_vals, bins=20, alpha=0.5, label="realized")
plt.hist(uniform_samples, alpha=0.5, bins=20, label="uniform")
plt.legend()

# Empirical CDF of each sample on a grid; a uniform sample should track the
# grid itself, so the maximum gap is the quantity printed below.
tested_quantiles = np.linspace(0, 1, 100)
type_i_error = np.array(
    [np.mean(plugin_cdf_vals < level) for level in tested_quantiles]
)
type_i_error_unif = np.array(
    [np.mean(uniform_samples < level) for level in tested_quantiles]
)

print(np.max(np.abs(type_i_error - tested_quantiles)))
print(np.max(np.abs(type_i_error_unif - tested_quantiles)))

# Two-sample KS test against the uniform draws.  It is applied to the raw
# `quantiles`, not to `1 - quantiles` as in the histogram above.
ks_test_for_uniform = stats.ks_2samp(quantiles, uniform_samples)
print(ks_test_for_uniform.pvalue)  # original author's note: clearly not uniform


## Distribution of conditional normals
import numpy as np
import matplotlib.pyplot as plt

n_draws = 10000000
# Six independent standard-normal samples, drawn in the order
# X, Y, Z, Z1, Z2, Z3.
X, Y, Z, Z1, Z2, Z3 = (np.random.normal(size=n_draws) for _ in range(6))

root2_gap = np.sqrt(2) - 1

# Where Z2 > Z1 this is (Z2 + Z3) / (1 + (sqrt(2) - 1)); elsewhere it is Z2.
z2_wins = Z2 > Z1
err_dist = (Z2 + z2_wins * Z3) / (1 + z2_wins * root2_gap)

## Test statistic when the selection event compares X with Y directly ...
x_beats_y = X > Y
test_stat_true = (X + Z * x_beats_y) / (1 + x_beats_y * root2_gap)

## ... and when Y is perturbed by err_dist before the comparison.
x_beats_shifted = X > (Y + err_dist)
test_stat_err = (X + Z * x_beats_shifted) / (1 + x_beats_shifted * root2_gap)

for sample, tag in (
    (test_stat_true, "True Test Stat"),
    (test_stat_err, "Test Stat with Plug-in"),
):
    plt.hist(sample, bins=200, alpha=0.5, label=tag)
plt.legend()

plt.show()

print(np.quantile(test_stat_true, 0.4))
print(np.quantile(test_stat_err, 0.4))



#plt.hist(X[X>Y]/np.sqrt(2) + Z[X>Y]/np.sqrt(2), bins = 100, alpha = 0.5)
#plt.hist(X[X>Y], bins = 100, alpha = 0.5)
#plt.hist(X[X<=Y], bins = 100, alpha = 0.5)



# NOTE(review): the original line here was a dangling `plt.` -- a syntax error
# that stopped the whole file from parsing.  A fresh figure is opened instead
# so the histogram below does not draw over the previous plot; confirm intent.
plt.figure()


# Histogram of the plug-in tail probabilities, with the 0.1 quantile of
# `test_stat` (red) and the median of `quantiles` (yellow) as reference lines.
quantile_10 = np.quantile(test_stat, 0.1)
plt.hist(quantiles, bins = 100)
plt.axvline(quantile_10, color = "red")
plt.axvline(np.median(quantiles), color = "yellow")
print(np.median(quantiles))
print(quantile_10)


# NOTE(review): the original used `eta_hat_vals_true`, which is not defined in
# the visible script (NameError).  `eta_hats` is filled in the same loop as
# `quantiles`, so the two have matching lengths; it is assumed to be the
# intended x-axis -- confirm.
plt.scatter(eta_hats, quantiles, s=1)
