
import scipy.special
from scipy.stats import norm, binom_test,sem
from statsmodels.stats.proportion import proportion_confint
import math
import numpy as np
import torch
import matplotlib.pyplot as plt

def get_cvar_cert_time_t(estimate, t, eps, sigma):
    """Return ``estimate`` capped at an erf-based threshold for step ``t``.

    The threshold is ``erf(sqrt(t + 1) * eps / (2 * sqrt(2) * sigma))`` and the
    result is ``min(estimate, threshold)``.

    The previous implementation computed the same quantity as
    ``(estimate / threshold) * threshold``, which produced NaN (0/0) when the
    threshold was zero (``eps == 0``) and ``estimate <= 0``, and could perturb
    ``estimate`` by a rounding error otherwise. Comparing directly avoids both.

    Args:
        estimate: Scalar estimate to be capped.
        t: Step index; enters only through ``sqrt(t + 1)``.
        eps: Perturbation size, in the same units as ``sigma``.
        sigma: Noise standard deviation; must be non-zero.

    Returns:
        ``min(estimate, threshold)`` as a NumPy float scalar.

    Raises:
        ZeroDivisionError: If ``sigma`` is a plain Python zero.
    """
    erf = scipy.special.erf(math.sqrt(t + 1) * eps / (2 * math.sqrt(2) * sigma))
    if estimate > erf:
        return erf
    # Wrap so both branches return a NumPy scalar, as the original product did.
    return np.float64(estimate)

def get_exact_time_t(estimate, t, eps, sigma):
    """Shift ``estimate`` down in Gaussian quantile space for step ``t``.

    Computes ``Phi(Phi^-1(estimate) - sqrt(t + 1) * eps / sigma)``, where
    ``Phi`` is the standard normal CDF.

    Args:
        estimate: Probability to be shifted (fed to ``norm.ppf``).
        t: Step index; enters only through ``sqrt(t + 1)``.
        eps: Perturbation size, in the same units as ``sigma``.
        sigma: Noise standard deviation used as the divisor.

    Returns:
        The shifted probability.
    """
    shift = math.sqrt(t + 1) * eps / sigma
    quantile = norm.ppf(estimate)
    return norm.cdf(quantile - shift)

def _lower_confidence_bound(NA: int, N: int, alpha: float) -> float:
    """Return a one-sided lower confidence bound on a binomial proportion.

    A two-sided interval at level ``2 * alpha`` is requested from statsmodels'
    ``proportion_confint`` with ``method="beta"``, and only its lower endpoint
    is kept.

    Args:
        NA: Number of successes.
        N: Number of trials.
        alpha: One-sided significance level.

    Returns:
        The lower endpoint of the interval.
    """
    lower, _upper = proportion_confint(NA, N, alpha=2 * alpha, method="beta")
    return lower
def get_exact_total(estimate, eps, sigma):
    """Shift ``estimate`` down in Gaussian quantile space by ``eps / sigma``.

    Computes ``Phi(Phi^-1(estimate) - eps / sigma)``, where ``Phi`` is the
    standard normal CDF.

    Args:
        estimate: Probability to be shifted (fed to ``norm.ppf``).
        eps: Perturbation size, in the same units as ``sigma``.
        sigma: Noise standard deviation used as the divisor.

    Returns:
        The shifted probability.
    """
    quantile = norm.ppf(estimate)
    shift = eps / sigma
    return norm.cdf(quantile - shift)
import argparse



# Plot one certified-lower-bound curve per smoothing sigma.
plt.figure(figsize=(8.4, 4.8))

# Perturbation budgets 0.00, 0.01, ..., 0.40 (as fractions of 255).
# Built from an integer index so the x grid and the curve always have the same
# length; float-step np.arange can be off by one element due to rounding.
PIXEL_SCALE = 255
budgets = np.arange(41) * 0.01

for sigma, color, linestyle in [(12.75, 'blue', '-'), (25.5, 'cornflowerblue', '--')]:
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    data = torch.load(f'pong_1r_sigma_{sigma}/best_model.zip_evals_10000.pth')
    # Use the actual sample count rather than assuming it from the file name.
    num_episodes = len(data)
    # Maps the first entry of each record from x to x/2 + 0.5 and averages;
    # presumably records hold a +/-1 outcome -- TODO confirm.
    prob = (torch.tensor([x[0] for x in data]) / 2. + .5).mean()
    # int() truncates, so the success count is never rounded up.
    lcb = _lower_confidence_bound(int(prob * num_episodes), num_episodes, 0.05)
    # Budget 0 uses the bound itself; the rest are shifted by eps = budget * 255.
    vals = [lcb] + [
        get_exact_total(lcb, budget * PIXEL_SCALE, sigma) for budget in budgets[1:]
    ]
    plt.plot(budgets, vals, color=color, linestyle=linestyle, label="σ = " + str(sigma / 255))


# attack_mags_nonzero = [12.75,25.5, 38.25,51.0]

# attack_vals =  [torch.tensor(torch.load('pong_1r_sigma_0.0/best_model.zip_evals_10000.pth')).mean().item()/2. + .5]
# attack_sems = [sem(torch.tensor(torch.load('pong_1r_sigma_0.0/best_model.zip_evals_10000.pth'))/2. + .5)]
# for attack_mag in attack_mags_nonzero:
# 	attack_val = None
# 	attack_sem = None
# 	for i,thresh in enumerate([0.1,0.3,0.5]):
# 		cur_val = (torch.tensor( torch.load('pong_1r_sigma_0.0/best_model.zip_evals_1000_attack_eps_'+str(attack_mag)+'_attack_step_count_multiplier_2_attack_step_2.5500000000000003_threshold_'+str(thresh)+'.pth')).mean().item()/2. + .5)
# 		if (attack_val is None or cur_val < attack_val):
# 			attack_val = cur_val
# 			attack_sem = (sem(torch.tensor( torch.load('pong_1r_sigma_0.0/best_model.zip_evals_1000_attack_eps_'+str(attack_mag)+'_attack_step_count_multiplier_2_attack_step_2.5500000000000003_threshold_'+str(thresh)+'.pth'))/2. + .5))
# 	attack_vals.append(attack_val)
# 	attack_sems.append(attack_sem)
# attack_mags = [0] + attack_mags_nonzero
# plt.errorbar([x/255 for x in attack_mags],attack_vals,  yerr= attack_sems, color='red',  linestyle ="--",label="Undefended"  )



# styles = ['-.','--',":"]
# attack_mags_nonzero = [25.5, 51.0, 76.5,102.0]
# for j,sigma in enumerate([12.75]):#,25.6]):
# 	attack_vals =  [torch.tensor(torch.load('pong_1r_sigma_'+str(sigma)+'/best_model.zip_evals_10000.pth')).mean().item()/2. + .5]
# 	attack_sems = [sem(torch.tensor(torch.load('pong_1r_sigma_'+str(sigma)+'/best_model.zip_evals_10000.pth'))/2. + .5)]
# 	for attack_mag in attack_mags_nonzero:
# 		attack_val = None
# 		attack_sem = None
# 		for i,thresh in enumerate([0.1,0.3,0.5]):
# 			cur_val = (torch.tensor( torch.load('pong_1r_sigma_'+str(sigma)+'/best_model.zip_evals_1000_smooth_attack_eps_'+str(attack_mag)+'_attack_step_count_multiplier_2_attack_step_2.5500000000000003_threshold_'+str(thresh)+'_num_smoothing_points_128.pth')).mean().item()/2. + .5)
# 			if (attack_val is None or cur_val < attack_val):
# 				attack_val = cur_val
# 				attack_sem = (sem(torch.tensor( torch.load('pong_1r_sigma_'+str(sigma)+'/best_model.zip_evals_1000_smooth_attack_eps_'+str(attack_mag)+'_attack_step_count_multiplier_2_attack_step_2.5500000000000003_threshold_'+str(thresh)+'_num_smoothing_points_128.pth'))/2. + .5))
# 		attack_vals.append(attack_val)
# 		attack_sems.append(attack_sem)
# 	attack_mags = [0] + attack_mags_nonzero
# 	plt.errorbar([x/255 for x in attack_mags],attack_vals,  yerr= attack_sems, color='blue',  linestyle ="-",label="Policy Smoothing (σ = " + str(sigma/255.) + ')')


# Axis ranges: budget from 0 to 0.4, win rate from 0 to 1.
plt.xlim(0, .4)
plt.ylim(0, 1)

# Axis labels and panel title.
plt.xlabel('Perturbation Budget', fontsize=14)
plt.ylabel('Certified Win Rate', fontsize=14)
plt.title('(b) Pong (One-Round)', fontsize=18)

# Legend entries come from the `label=` of each plotted curve.
plt.legend()

# Write the figure to disk with a tight bounding box.
plt.savefig('pong_1r_certs.png', dpi=400, bbox_inches='tight')
