import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from math import sqrt
from scipy.integrate import quad
from scipy.stats import norm
import scipy
import scipy.stats as stats
import random
import copy
import pandas as pd


# Jupyter-friendly plots (IPython magic: this file only runs as a notebook / under IPython)
%matplotlib inline

## TODO: get the two-point and stable-weight baselines running for this
## procedure and compare their confidence intervals.


## import dataset: load the R data frame stored in study_1.rds and convert it to pandas
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
readRDS = robjects.r['readRDS']
df = readRDS('study_1.rds')
df = pandas2ri.rpy2py_dataframe(df)


# number of rows with Z_mw == 2
np.sum(df.Z_mw==2)


# scratch check of the size of an exploration bonus similar to the one added to
# nuisance-arm means in simulation_inference (that one uses log10(log10(n) + 0.01))
0.1* np.log10(np.log10(100))/np.sqrt(100)


# minimum-wage (mw) outcome and arm; arm labels shifted down by one
# (presumably 1-based labels -> 0-based indices; confirm)
rewards = df["Y_mw"]
arms = df["Z_mw"]-1


df


# read results of previously computed confidence intervals
# NOTE(review): later cells use columns `estimate`, `term` and positional column 3
# (treated as a standard error) -- confirm against the CSV
past_cis = pd.read_csv('conf_ints_study.csv')


past_cis


class BatchThompsonSampling:
    """Batched Beta-Bernoulli Thompson sampling with a uniform exploration floor.

    Batch 0 is assigned uniformly at random. Every later batch is drawn i.i.d.
    from a propensity vector that mixes a uniform floor (total mass 0.1) with
    a Monte Carlo estimate, from 5000 posterior draws, of the probability that
    each arm has the largest Beta(successes + 1, failures + 1) sample.
    """

    def __init__(self, n_arms, batch_sizes):
        self.n_arms = n_arms
        self.batch_sizes = batch_sizes
        # running per-arm sums of rewards (successes) and of 1 - rewards (failures)
        self.successes = np.zeros(n_arms)
        self.failures = np.zeros(n_arms)

    def select_arm(self, t):
        """Return an array of ``batch_sizes[t]`` arm indices for batch ``t``."""
        arm_ids = range(self.n_arms)
        size = self.batch_sizes[t]
        if t != 0:
            posterior_draws = np.random.beta(self.successes + 1, self.failures + 1,
                                             size=(5000, self.n_arms))
            winners = np.argmax(posterior_draws, axis=1)
            # fraction of posterior draws in which each arm is the best
            win_share = np.bincount(winners, minlength=self.n_arms) / len(winners)
            probs = 0.1 / self.n_arms + (1 - 0.1) * win_share
            return np.random.choice(arm_ids, size, p=probs)
        # first batch: no data yet, so sample uniformly
        return np.random.choice(arm_ids, size)

    def update(self, arms, rewards):
        """Add the rewards observed for ``arms`` to each arm's tallies."""
        for arm in range(self.n_arms):
            pulled = arms == arm
            self.successes[arm] += np.sum(pulled * rewards)
            self.failures[arm] += np.sum(pulled * (1 - rewards))



### get propensity calculation function

def propensity_calculation(df, arm_col="Z_rtw", reward_col="Y_rtw", n_draws=100000, floor=0.1):
    """Reconstruct per-batch assignment probabilities of batched Thompson sampling.

    The first batch is uniform. The batch labelled ``b + 1`` is assigned from
    Beta(1 + successes, 1 + failures) posteriors accumulated over batch labels
    ``1..b``, mixed with a uniform floor of total mass ``floor``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``batch`` column with labels 1..B, an arm column with labels
        1..K and a reward column (rows with reward 1 count as successes and
        rows with reward 0 as failures).
    arm_col, reward_col : str
        Column names for the arm and reward. The defaults are the
        right-to-work columns used previously, so existing calls are unchanged.
    n_draws : int
        Monte Carlo posterior draws used to estimate each arm's win probability.
    floor : float
        Probability mass spread uniformly over the arms.

    Returns
    -------
    numpy.ndarray of shape (K, B); column ``j`` holds the propensities of the
    batch labelled ``j + 1``.
    """
    batch = df["batch"]
    arm = df[arm_col]
    reward = df[reward_col]

    n_batches = len(set(batch))
    n_arms = len(set(arm))
    prop_matrix = np.zeros((n_arms, n_batches))
    prop_matrix[:, 0] = 1 / n_arms

    successes = np.zeros(n_arms)
    failures = np.zeros(n_arms)

    for b in range(1, n_batches):
        # fold in the batch labelled b, i.e. the batch right before column b
        in_batch = batch == b
        for k in range(n_arms):
            batch_rewards = reward[in_batch & (arm == k + 1)]
            successes[k] += np.sum(batch_rewards == 1)
            failures[k] += np.sum(batch_rewards == 0)

        theta = np.random.beta(successes + 1, failures + 1, size=(n_draws, n_arms))
        win_share = np.bincount(np.argmax(theta, axis=1), minlength=n_arms) / n_draws
        prop_matrix[:, b] = floor / n_arms + (1 - floor) * win_share

    return prop_matrix

### get function to create propensity function for each time step
def prop_matrix_preprocess(df, props):
    """Broadcast per-batch propensities to one column per observation.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``batch`` column (labels 1..B) and a ``Z_rtw`` column. The
        number of distinct ``Z_rtw`` values fixes the number of output rows.
    props : numpy.ndarray
        Per-batch propensities; column ``b - 1`` belongs to batch label ``b``.

    Returns
    -------
    numpy.ndarray of shape (n_arms, len(df)). Column ``j`` is the propensity
    vector of the batch of the ``j``-th row (matched by position).
    """
    n_rows = len(set(df.Z_rtw))
    expanded = np.zeros((n_rows, len(df)))

    for label in set(df.batch):
        in_batch = (df.batch == label).to_numpy()
        if not in_batch.any():
            # a label that matches no row (e.g. NaN) leaves its columns at 0
            continue
        expanded[:, in_batch] = props[:, int(label - 1)][:, np.newaxis]

    return expanded



### methods for conducting inference on arms

def hadad_et_al(rewards, arms, propensity, target_arm = 0, alpha = 0.1, decay_rate = 0):
    """AIPW estimate of one arm's mean with constant and two-point adaptive weights.

    Scores are augmented inverse-propensity-weighted (AIPW) terms whose outcome
    model at step ``i`` is the sample mean of ``target_arm`` rewards observed
    strictly before ``i`` (0 if there are none). They are aggregated in two ways:

    * constant ("stable") allocation: ``h_t = sqrt(e_t / T)``;
    * two-point allocation: ``h_t = sqrt(e_t * (1 - sum_{s<t} h_s^2 / e_s) * lambda_t)``,
      with ``lambda_t`` set by ``decay_rate`` and ``lambda_T = 1``. With
      ``decay_rate = 0`` this reduces to the constant allocation.

    Parameters
    ----------
    rewards, arms : array-like of length T
        Observed rewards and the arm label pulled at each step.
    propensity : array-like of shape (n_arms, T)
        Assignment probabilities; row ``target_arm - 1`` is used.
    target_arm : int
        Arm label whose mean is estimated.
    alpha : float
        Two-sided miscoverage of the normal-approximation intervals.
    decay_rate : float in [0, 1)
        Exponent of the two-point allocation schedule (1 would divide by zero).

    Returns
    -------
    (Q_const, [lower, upper], Q_twopoint, [lower, upper])
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards, dtype=float)
    # NOTE(review): row target_arm - 1 assumes 1-based arm labels; with 0-based
    # labels (e.g. the default target_arm = 0) this reads row -1, the last arm.
    e_t = np.asarray(propensity)[target_arm - 1, :]
    T = len(arms)

    is_target = arms == target_arm
    target_rewards = np.where(is_target, rewards, 0.0)
    # count / sum of target-arm rewards seen strictly before each step
    prior_count = np.cumsum(is_target) - is_target
    prior_sum = np.cumsum(target_rewards) - target_rewards
    prior_mean = np.divide(prior_sum, prior_count, out=np.zeros(T), where=prior_count > 0)

    scores = rewards * is_target / e_t + (1 - is_target / e_t) * prior_mean

    # two-point weights: each step spends a fraction of the remaining budget
    # 1 - sum_{s<i} h_s^2 / e_s, so that sum_t h_t^2 / e_t == 1 at the end
    weights = np.zeros(T)
    spent = 0.0
    for i in range(T):
        if i == T - 1:
            allocation = 1.0
        else:
            t = i + 1
            decay = t ** (-decay_rate)
            allocation = (e_t[i] / (T - i)
                          + (1 - e_t[i]) * decay
                          / (decay + (T ** (1 - decay_rate) - t ** (1 - decay_rate)) / (1 - decay_rate)))
        weights[i] = np.sqrt(e_t[i] * max(1.0 - spent, 0.0) * allocation)
        spent += weights[i] ** 2 / e_t[i]

    # constant-allocation estimate and standard error
    const_weights = np.sqrt(e_t / T)
    Q_const = np.sum(const_weights * scores) / np.sum(const_weights)
    sd_const = np.sqrt(np.sum(const_weights**2 * (scores - Q_const)**2) / (np.sum(const_weights)**2))

    # two-point estimate; its variance is centred at its own estimate
    Q_twopoint = np.sum(weights * scores) / np.sum(weights)
    sd_twopoint = np.sqrt(np.sum(weights**2 * (scores - Q_twopoint)**2) / (np.sum(weights)**2))

    crit = stats.norm.ppf(1 - alpha / 2)

    return (Q_const, [Q_const - crit * sd_const, Q_const + crit * sd_const],
            Q_twopoint, [Q_twopoint - crit * sd_twopoint, Q_twopoint + crit * sd_twopoint])


def av_conf_int(rewards, arms, target_arm, alpha = 0.1):
    """Anytime-valid CI for the mean of ``target_arm``'s rewards (values in [0, 1]).

    Implements the predictable plug-in empirical-Bernstein (PrPl-EB) confidence
    sequence with truncation c = 1/2. Each step's bet ``lambda_t`` and variance
    proxy use only earlier observations (priors: mean 1/2, variance 1/4). The
    function returns the running intersection of the sequence, clipped to
    [0, 1], after the last observation.

    Parameters
    ----------
    rewards, arms : array-like of equal length
        Observed rewards and the arm pulled at each step.
    target_arm : arm label
        Only rewards with ``arms == target_arm`` are used, in order.
    alpha : float
        Two-sided miscoverage level.

    Returns
    -------
    (center, [lower, upper]). ``center`` is the lambda-weighted mean of all
    target-arm rewards. With no target-arm observations this is
    (0.5, [0, 1]).
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards, dtype=float)
    x = rewards[arms == target_arm]
    n = len(x)
    if n == 0:
        return 1/2, [0, 1]

    log_term = np.log(2 / alpha)
    t = np.arange(1, n + 1)

    # regularised running mean / variance after t observations, then shifted by
    # one so that entry t only depends on x[:t-1] (predictable)
    mu_hat = (1/2 + np.cumsum(x)) / (t + 1)
    sigma2_hat = (1/4 + np.cumsum((x - mu_hat) ** 2)) / (t + 1)
    mu_prev = np.concatenate(([1/2], mu_hat[:-1]))
    sigma2_prev = np.concatenate(([1/4], sigma2_hat[:-1]))

    lambdas = np.minimum(1/2, np.sqrt(2 * log_term / (sigma2_prev * t * np.log(1 + t))))
    v = 4 * (x - mu_prev) ** 2
    psi = (-np.log(1 - lambdas) - lambdas) / 4

    lam_sum = np.cumsum(lambdas)
    centers = np.cumsum(lambdas * x) / lam_sum
    radii = (log_term + np.cumsum(v * psi)) / lam_sum

    # running intersection of the confidence sequence, clipped to [0, 1]
    lower = max(0.0, np.max(centers - radii))
    upper = min(1.0, np.min(centers + radii))

    return centers[-1], [lower, upper]
  


# reward streams generated for each arm

def simulate(bandit, reward_streams, batch_sizes, target_arm = 0, baseline_arm = None):
    """Run a batched bandit on pre-drawn reward streams and return an arm-mean statistic.

    ``reward_streams[k, s]`` is the reward arm ``k`` yields if it is pulled at
    global step ``s`` (0-based, counted across all batches). The array must
    therefore have at least ``sum(bandit.batch_sizes)`` columns; successive
    batches consume fresh columns.

    Parameters
    ----------
    bandit : object
        Provides ``select_arm(t)`` (returns one arm index per assignment in
        batch ``t``), ``update(arms, rewards)``, ``batch_sizes``, ``successes``
        and ``failures``.
    reward_streams : array-like of shape (n_arms, horizon)
    batch_sizes : sequence
        Only its length (the number of batches) is read here.
    target_arm, baseline_arm : int, optional

    Returns
    -------
    The target arm's observed success rate or, if ``baseline_arm`` is given,
    target minus baseline success rate (nan if an arm was never pulled).
    """
    streams = np.asarray(reward_streams, dtype=float)

    offset = 0
    for t in range(len(batch_sizes)):
        chosen = np.asarray(bandit.select_arm(t))  # one arm per assignment in batch t
        n_t = int(bandit.batch_sizes[t])
        # index by global step so each batch reads unused stream entries
        reward = streams[chosen, offset + np.arange(n_t)]
        offset += n_t
        bandit.update(chosen, reward)

    def _arm_mean(k):
        return bandit.successes[k] / (bandit.successes[k] + bandit.failures[k])

    target_mean = _arm_mean(target_arm)
    if baseline_arm is None:
        return target_mean
    return target_mean - _arm_mean(baseline_arm)








from joblib import Parallel, delayed

def single_simulation(theta_i, means, n_arms, batch_sizes, target_arm, baseline_arm = None):
    """Draw fresh Bernoulli reward streams and run one batched-TS replication.

    ``theta_i`` is accepted for the caller's convenience but not read here; the
    hypothesised target-arm mean must already be stored in ``means``. Each arm
    gets ``sum(batch_sizes)`` Bernoulli(means[k]) draws, generated arm by arm.

    Returns the statistic computed by ``simulate`` for a fresh
    ``BatchThompsonSampling(n_arms, batch_sizes)``.
    """
    horizon = int(np.sum(batch_sizes))
    streams = np.zeros((len(means), horizon))
    for row, p in enumerate(means):
        streams[row] = np.random.binomial(n=1, p=p, size=horizon)

    agent = BatchThompsonSampling(n_arms, batch_sizes)
    return simulate(agent, streams, batch_sizes,
                    target_arm=target_arm, baseline_arm=baseline_arm)

def simulation_inference(rewards, arms, n_arms, batch_sizes, target_arm, test_stat_obs, baseline_arm = None, grid=None, grid_fidelity=200, B=500, alpha=0.1, n_jobs=-1):
    """Simulation-based confidence interval obtained by inverting a parametric test.

    Each candidate value ``theta`` of the target arm's mean is tested as
    follows. The batched Thompson-sampling experiment is re-simulated ``B``
    times (in parallel via joblib), with the target arm set to ``theta`` and
    the other arms at plug-in values. The one-sided p-value is
    ``P(simulated statistic <= test_stat_obs)``. A candidate is retained when
    that p-value lies strictly inside (alpha/2, 1 - alpha/2), and the interval
    is the min/max of the retained candidates.

    Plug-in (nuisance) means: every non-target arm uses its sample mean plus
    ``0.1 * max(0, log10(log10(n) + 0.01) / sqrt(n))``, where n is its number
    of pulls. ``baseline_arm``, if given, uses its plain sample mean.

    Parameters
    ----------
    rewards, arms : array-like
        Observed data; arms are 0-based indices.
    n_arms : int
    batch_sizes : sequence of int
        Batch sizes of the original experiment.
    target_arm : int
    test_stat_obs : float
        Observed statistic: the target-arm mean, or target minus baseline mean.
    baseline_arm : int, optional
    grid : sequence of float, optional
        Candidate target-arm means. If None or empty, ``grid_fidelity`` evenly
        spaced points on [0, 1] are used.
    B : int
        Simulations per candidate.
    alpha : float
        Two-sided miscoverage.
    n_jobs : int
        Passed to ``joblib.Parallel``.

    Returns
    -------
    (point_est, [lower, upper], p_values). ``point_est`` is the candidate whose
    p-value is closest to 0.5. With a baseline arm, the point estimate and
    bounds are shifted down by the baseline plug-in mean. If every candidate is
    rejected, both bounds equal the point estimate.
    """
    arms = np.array(arms)
    rewards = np.array(rewards)

    means = np.zeros(n_arms)
    for i in range(n_arms):
        if i == target_arm:
            continue  # overwritten with each candidate value below
        counts = np.sum(arms == i)
        bonus = np.nanmax([0, np.log10(np.log10(counts) + 0.01) / np.sqrt(counts)])
        means[i] = np.mean(rewards[arms == i]) + 0.1 * bonus

    if baseline_arm is not None:
        means[baseline_arm] = np.mean(rewards[arms == baseline_arm])

    # np.asarray so that boolean indexing below also works for list grids
    if grid is not None and len(grid) != 0:
        theta = np.asarray(grid, dtype=float)
    else:
        theta = np.linspace(0, 1, grid_fidelity)
    grid_fidelity = len(theta)

    p_values = np.zeros(grid_fidelity)
    reject = np.zeros(grid_fidelity)

    for i in range(grid_fidelity):
        theta_i = theta[i]
        means[target_arm] = theta_i
        test_stats = Parallel(n_jobs=n_jobs)(
            delayed(single_simulation)(theta_i = theta_i,
                                       means = means,
                                       n_arms = n_arms,
                                       batch_sizes=batch_sizes,
                                       target_arm=target_arm,
                                       baseline_arm = baseline_arm)
            for _ in range(B)
        )

        p_values[i] = np.mean(np.array(test_stats) <= test_stat_obs)
        reject[i] = int(p_values[i] <= alpha / 2 or p_values[i] >= 1 - alpha / 2)

    point_est = theta[np.argmin(np.abs(p_values - 0.5))]
    if baseline_arm is not None:
        point_est = point_est - means[baseline_arm]

    if np.sum(reject == 0) == 0:
        lower, upper = point_est, point_est
    else:
        interval_values = theta[reject == 0]
        if baseline_arm is not None:
            interval_values = interval_values - means[baseline_arm]
        lower, upper = np.min(interval_values), np.max(interval_values)

    return point_est, [lower, upper], p_values



### simulation-based CI for the mean of each arm in the right-to-work (rtw) batch experiment
rewards = df.Y_rtw
arms = df.Z_rtw - 1  # arm labels shifted down by one (0-based indices)
# batch sizes: number of rows carrying batch label i+1 (labels assumed 1..B)
batch_sizes = np.zeros(len(set(df.batch)))
for i in range(len(batch_sizes)):
    batch_sizes[i] = np.sum(df.batch == i+1)
batch_sizes = batch_sizes.astype(int)


### right to work simulation-based CIs, one per arm
target_arms = range(8)
cis_list = []

with tqdm(total=len(target_arms)) as pbar:
    
    for target_arm in target_arms:
        
        # observed sample mean of the target arm
        test_stat_obs = np.sum(rewards * (arms == target_arm))/np.sum(arms == target_arm)
        
        # invert the simulation test for the target arm's mean (alpha = 0.1, i.e. 90%)
        x = simulation_inference(rewards = rewards,
                                arms = arms,
                                n_arms = len(target_arms),
                                batch_sizes = batch_sizes,
                                target_arm = target_arm,
                                test_stat_obs = test_stat_obs,
                                grid_fidelity = 100,
                                B=100, alpha = 0.1)
    
        cis_list.append(x)
        
        pbar.update(1)

    


### unpack point estimates and bounds; each entry of cis_list is
### (point_est, [lower, upper], p_values)
sim_ci_upper = np.zeros(8)
sim_ci_est = np.zeros(8)
sim_ci_lower = np.zeros(8)
for i in range(8):
    sim_ci_upper[i] = cis_list[i][1][1]
    sim_ci_lower[i] = cis_list[i][1][0]
    sim_ci_est[i] = cis_list[i][0]


#plt.plot(range(8),past_cis.estimate[10:] - stats.norm.ppf(0.95) * past_cis.iloc[10:,3])
#plt.plot(range(8), past_cis.estimate[10:] + stats.norm.ppf(0.95) * past_cis.iloc[10:,3])


## past normal-approximation 90% CIs for the eight rtw arms (rows 10-17 of past_cis)
## NOTE(review): positional column 3 is used as the standard error -- confirm
past_ci_upper, past_ci_lower = past_cis.estimate[10:] + stats.norm.ppf(0.95) * past_cis.iloc[10:,3], past_cis.estimate[10:] - stats.norm.ppf(0.95) * past_cis.iloc[10:,3]
past_ci_est = past_cis.estimate[10:]


## alternative methods: anytime-valid, stable-weight and two-point AIPW CIs per arm
arm_labels = range(1, 9)  # Z_rtw labels 1..8; results stored at index label - 1
n_labels = len(arm_labels)
stable_ci_upper, stable_ci_lower, stable_ci_est = np.zeros(n_labels), np.zeros(n_labels), np.zeros(n_labels)
twop_ci_upper, twop_ci_lower, twop_ci_est = np.zeros(n_labels), np.zeros(n_labels), np.zeros(n_labels)
av_ci_upper, av_ci_lower, av_ci_est = np.zeros(n_labels), np.zeros(n_labels), np.zeros(n_labels)

rewards = df.Y_rtw
arms = df.Z_rtw
# per-batch Thompson-sampling propensities, expanded to one column per observation
props = propensity_calculation(df)
propensity = prop_matrix_preprocess(df, props)

# Loop over the distinct arm labels. `arms` holds the per-row Z_rtw column;
# iterating over it would redo every label once per observation.
for arm in arm_labels:
    # anytime-valid CI bounds
    est, bounds = av_conf_int(rewards, arms, target_arm = arm)
    av_ci_upper[arm-1], av_ci_lower[arm-1] = bounds[1], bounds[0]
    av_ci_est[arm-1] = est

    # AIPW with constant ("stable") and two-point weights
    result = hadad_et_al(rewards, arms, propensity, target_arm = arm, alpha = 0.1, decay_rate = 0)
    stable_ci_upper[arm-1], stable_ci_lower[arm-1] = result[1][1], result[1][0]
    stable_ci_est[arm-1] = result[0]
    twop_ci_upper[arm-1], twop_ci_lower[arm-1] = result[3][1], result[3][0]
    twop_ci_est[arm-1] = result[2]
    
    
    


import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_context("talk")
sns.set_style("whitegrid")

# x-axis labels: `term` values of the eight rtw rows of past_cis (rows 10-17),
# with their first 14 characters stripped
categories = past_cis.term[10:]
cat = categories.to_list()
for i in range(len(cat)):
    cat[i] = cat[i][14:]

# Convert categories to x-axis positions
x = np.arange(len(categories))

# Despite its name, `past_widths` is the elementwise minimum width over the
# previous, anytime-valid, stable-weight and two-point intervals
past_widths = np.min([past_ci_upper - past_ci_lower, av_ci_upper - av_ci_lower,
                      stable_ci_upper - stable_ci_lower, twop_ci_upper - twop_ci_lower], 
                    axis = 0)

sim_widths = (sim_ci_upper - sim_ci_lower)
width_diffs = (sim_widths - past_widths)  # simulation width minus the narrowest competing width

# Create the plot: five estimators per arm, offset horizontally
plt.figure(figsize=(15, 6))

offset = 0.1
plt.errorbar(x - 2 * offset, y=past_ci_est,
             yerr=[past_ci_est - past_ci_lower, past_ci_upper - past_ci_est],
             fmt='o', capsize=5, label='Previous', color="black")
plt.errorbar(x - offset, y=av_ci_est,
             yerr=[av_ci_est - av_ci_lower, av_ci_upper - av_ci_est],
             fmt='o', capsize=5, label='Anytime Valid', color="green")
plt.errorbar(x + 0.0, y=stable_ci_est,
             yerr=[stable_ci_est - stable_ci_lower, stable_ci_upper - stable_ci_est],
             fmt='o', capsize=5, label='Stable Weights', color="maroon")
plt.errorbar(x + offset, y=twop_ci_est,
             yerr=[twop_ci_est - twop_ci_lower, twop_ci_upper - twop_ci_est],
             fmt='o', capsize=5, label='Two-Point', color="grey")
plt.errorbar(x + 2 * offset, y=sim_ci_est,
             yerr=[sim_ci_est - sim_ci_lower, sim_ci_upper - sim_ci_est],
             fmt='o', capsize=5, label='Simulation-Based', color="teal")

# Annotate each arm with the width difference
for i in range(len(x)):
    # Format the difference with 3 decimals
    diff_text = rf"$\Delta W$ = {width_diffs[i]:.3f}"
    # Vertical anchor: the anytime-valid lower bound; the text sits 0.05 below it
    max_y = av_ci_lower[i]
    plt.text(x[i], max_y - 0.05, diff_text, ha='center', fontsize=14, color='black', weight = 'bold')

# Customize x-axis with categorical labels
plt.xticks(x, cat, rotation=0)

# Labels and legend
plt.ylabel('Proportion')
plt.xlabel('Measure')
plt.title('Average Proportion of Respondents Supporting the Measure')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.ylim((.1, 1.05))
plt.savefig('single_arm_ci_plot.png', format='png', dpi=600)
#plt.savefig('single_arm_ci_plot_replacement.png', format='png', dpi=600)

plt.show()



# simulation CI widths, and their difference from the previous CIs' widths
print(sim_ci_upper - sim_ci_lower)
print(sim_ci_upper - sim_ci_lower  - (past_ci_upper - past_ci_lower))





# paired differences in the rtw experiment: arm 2k (target) minus arm 2k+1 (baseline)
rewards = df.Y_rtw
arms = df.Z_rtw - 1
# batch sizes: rows per batch label i+1
batch_sizes = np.zeros(len(set(df.batch)))
for i in range(len(batch_sizes)):
    batch_sizes[i] = np.sum(df.batch == i+1)
batch_sizes = batch_sizes.astype(int)


### right to work simulation-based CIs for the four paired differences
target_arms = range(4)
cis_differences_list = []


with tqdm(total=len(target_arms)) as pbar:
    
    for target_arm in target_arms:
        
        # pair (2k, 2k+1)
        target = 2*target_arm
        baseline = 2*target_arm + 1
        
        # observed difference in sample means, target minus baseline
        test_stat_obs = np.sum(rewards * (arms == target))/np.sum(arms == target) - np.sum(rewards * (arms == baseline))/np.sum(arms == baseline)
        
        # invert the simulation test for the difference
        x = simulation_inference(rewards = rewards,
                                arms = arms,
                                n_arms = 2*len(target_arms),
                                batch_sizes = batch_sizes,
                                target_arm = target,
                                baseline_arm = baseline,
                                test_stat_obs = test_stat_obs,
                                grid_fidelity = 200,
                                B=1000)
    
        cis_differences_list.append(x)
        
        pbar.update(1)


cis_differences_list


# past normal-approximation 90% CIs for the same paired differences
# (rows 10, 12, 14, 16 minus rows 11, 13, 15, 17; SEs combined in quadrature)

mean_differences = np.array(past_cis.estimate.to_list())[10::2] - np.array(past_cis.estimate.to_list())[11::2]
std_errors = np.sqrt(np.array(past_cis.iloc[10:,3][::2].to_list())**2 + np.array(past_cis.iloc[11:,3][::2].to_list())**2)

past_ci_upper_diff, past_ci_lower_diff = mean_differences + stats.norm.ppf(0.95) * std_errors, mean_differences- stats.norm.ppf(0.95) * std_errors
past_ci_est_diff = mean_differences


### simulation point estimates and bounds for the four differences
sim_ci_upper_diff = np.zeros(4)
sim_ci_est_diff = np.zeros(4)
sim_ci_lower_diff = np.zeros(4)
for i in range(4):
    sim_ci_upper_diff[i] = cis_differences_list[i][1][1]
    sim_ci_lower_diff[i] = cis_differences_list[i][1][0]
    sim_ci_est_diff[i] = cis_differences_list[i][0]


#plt.plot(range(8),past_cis.estimate[10:] - stats.norm.ppf(0.95) * past_cis.iloc[10:,3])
#plt.plot(range(8), past_cis.estimate[10:] + stats.norm.ppf(0.95) * past_cis.iloc[10:,3])


# width difference: simulation minus previous (negative = simulation narrower)
(sim_ci_upper_diff-sim_ci_lower_diff)-(past_ci_upper_diff-past_ci_lower_diff)


import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_context("talk")
sns.set_style("whitegrid")
# x-axis labels: `term` of rows 10, 12, 14, 16, with 5 characters stripped from each end
categories = past_cis.term[10::2]
cat = categories.to_list()
for i in range(len(cat)):
    cat[i] = cat[i][5:len(cat[i])-5]

# Convert categories to x-axis positions
x = np.arange(len(categories))

# Create the plot: previous vs simulation-based CIs for the paired differences
plt.figure(figsize=(10, 10))
#plt.scatter(x, past_ci_est, marker='o', label='Previous')
#plt.fill_between(x, past_ci_lower, past_ci_upper, color='red', alpha=0.2)
offset = 0.1
plt.errorbar(x - offset, y = past_ci_est_diff, yerr = [past_ci_est_diff - past_ci_lower_diff, past_ci_upper_diff-past_ci_est_diff], 
             fmt='o', capsize=10, label='Previous', color = "maroon")
plt.errorbar(x + offset, y = sim_ci_est_diff, yerr = [sim_ci_est_diff - sim_ci_lower_diff, sim_ci_upper_diff-sim_ci_est_diff], 
             fmt='o', capsize=10, label='Simulation-Based', color = "teal")


# Create the plot for sim
#plt.scatter(x, sim_ci_est, marker='o', label='Simulation')
#plt.fill_between(x, sim_ci_lower, sim_ci_upper, color='blue', alpha=0.2)


# Customize x-axis with categorical labels
plt.xticks(x, cat)

# Labels and legend
#plt.xlabel('Category')
plt.ylabel('Estimate')
plt.title('Difference In Response for CA vs. BA')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()



## differences between arms against a common baseline arm
## TODO: implement final baselines for these figures (estimate the propensities for these approaches)
rewards = df.Y_rtw
arms = df.Z_rtw - 1
# batch sizes: rows per batch label i+1
batch_sizes = np.zeros(len(set(df.batch)))
for i in range(len(batch_sizes)):
    batch_sizes[i] = np.sum(df.batch == i+1)
batch_sizes = batch_sizes.astype(int)


### right to work simulation-based CIs for arm indices 2, 4, 6 versus arm 0
# NOTE(review): this cell previously ended in an unfinished `baseline = `
# assignment (a SyntaxError). The comparisons were chosen to match the next
# cell, which differences estimates against row 10 (arm index 0) and keeps
# `[2::2]` (arm indices 2, 4, 6), and to match the display of the last three
# results below. Confirm this is the intended design.
target_arms = [2, 4, 6]
baseline = 0


with tqdm(total=len(target_arms)) as pbar:
    
    for target in target_arms: 
        # observed difference in sample means, target minus baseline
        test_stat_obs = np.sum(rewards * (arms == target))/np.sum(arms == target) - np.sum(rewards * (arms == baseline))/np.sum(arms == baseline)
        
        # invert the simulation test for the difference
        x = simulation_inference(rewards = rewards,
                                arms = arms,
                                n_arms = 8,
                                batch_sizes = batch_sizes,
                                target_arm = target,
                                baseline_arm = baseline,
                                test_stat_obs = test_stat_obs,
                                grid_fidelity = 200,
                                B=1000)
    
        cis_differences_list.append(x)
        
        pbar.update(1)


set(df.Z_rtw)


# the three results just appended
cis_differences_list[len(cis_differences_list)-3:len(cis_differences_list)]


# past normal-approximation 90% CIs for each rtw arm minus arm index 0 (row 10 of past_cis),
# standard errors combined in quadrature (covariance ignored)
mean_differences = np.array(past_cis.estimate.to_list())[10:] - np.array(past_cis.estimate.to_list())[10]
std_errors = np.sqrt(np.array(past_cis.iloc[10:,3].to_list())**2 + np.array(past_cis.iloc[10,3])**2)

past_ci_upper_diff, past_ci_lower_diff = mean_differences + stats.norm.ppf(0.95) * std_errors, mean_differences- stats.norm.ppf(0.95) * std_errors
past_ci_est_diff = mean_differences


# keep arm indices 2, 4, 6 only
#print(mean_differences[2::2])
past_ci_upper_diff, past_ci_lower_diff = past_ci_upper_diff[2::2], past_ci_lower_diff[2::2]
past_ci_est_diff = mean_differences[2::2]
print(past_ci_lower_diff)
print(past_ci_upper_diff)





### second experiment (minimum wage, mw): simulation-based CI for each arm's mean
### NOTE(review): reuses df.batch for the mw batch structure -- confirm both
### experiments share the batch column
rewards = df.Y_mw
arms = df.Z_mw - 1
# batch sizes: rows per batch label i+1
batch_sizes = np.zeros(len(set(df.batch)))
for i in range(len(batch_sizes)):
    batch_sizes[i] = np.sum(df.batch == i+1)
batch_sizes = batch_sizes.astype(int)


### minimum wage simulation-based CIs, one per arm (ten arms)
target_arms = range(10)
cis_list_mw = []

with tqdm(total=len(target_arms)) as pbar:
    
    for target_arm in target_arms:
        
        # observed sample mean of the target arm
        test_stat_obs = np.sum(rewards * (arms == target_arm))/np.sum(arms == target_arm)
        
        # invert the simulation test for the target arm's mean
        x = simulation_inference(rewards = rewards,
                                arms = arms,
                                n_arms = len(target_arms),
                                batch_sizes = batch_sizes,
                                target_arm = target_arm,
                                test_stat_obs = test_stat_obs,
                                grid_fidelity = 200,
                                B=1000)
    
        cis_list_mw.append(x)
        
        pbar.update(1)

    


cis_list_mw


# past normal-approximation 90% CIs for the ten mw arms (rows 0-9 of past_cis)
past_ci_upper_mw, past_ci_lower_mw = past_cis.estimate[:10] + stats.norm.ppf(0.95) * past_cis.iloc[:10,3], past_cis.estimate[:10] - stats.norm.ppf(0.95) * past_cis.iloc[:10,3]
past_ci_est_mw = past_cis.estimate[:10]

### simulation point estimates and bounds, sized from cis_list_mw so that every
### simulated mw arm (ten of them) is kept and lines up with the past CIs
n_mw_arms = len(cis_list_mw)
sim_ci_upper_mw = np.zeros(n_mw_arms)
sim_ci_est_mw = np.zeros(n_mw_arms)
sim_ci_lower_mw = np.zeros(n_mw_arms)
for i in range(n_mw_arms):
    sim_ci_upper_mw[i] = cis_list_mw[i][1][1]
    sim_ci_lower_mw[i] = cis_list_mw[i][1][0]
    sim_ci_est_mw[i] = cis_list_mw[i][0]




# interval widths: previous vs simulation-based
print(past_ci_upper_mw-past_ci_lower_mw)
print(sim_ci_upper_mw - sim_ci_lower_mw)


### methods for conducting inference with AIPW reweighted

def hadad_raw(rewards, arms, target, propensity, decay_rate = 0.7):
    """AIPW estimates and standard errors for one arm under two weighting schemes.

    Scores are augmented inverse-propensity-weighted (AIPW) terms whose outcome
    model at step ``i`` is the sample mean of ``target`` rewards observed
    strictly before ``i`` (0 if there are none). They are aggregated with
    constant weights ``sqrt(e_t / T)`` and with two-point adaptive weights
    ``h_t = sqrt(e_t * (1 - sum_{s<t} h_s^2 / e_s) * lambda_t)``, where
    ``lambda_T = 1``.

    Parameters
    ----------
    rewards, arms : array-like of length T
    target : arm label whose mean is estimated
    propensity : array-like of length T
        Probability that ``target`` was assigned at each step.
    decay_rate : float in [0, 1)
        Exponent of the two-point allocation schedule (1 would divide by zero).

    Returns
    -------
    (Q_const, sd_const, Q_twopoint, sd_twopoint)
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards, dtype=float)
    e_t = np.asarray(propensity, dtype=float)
    T = len(arms)

    is_target = arms == target
    target_rewards = np.where(is_target, rewards, 0.0)
    # count / sum of target rewards seen strictly before each step
    prior_count = np.cumsum(is_target) - is_target
    prior_sum = np.cumsum(target_rewards) - target_rewards
    prior_mean = np.divide(prior_sum, prior_count, out=np.zeros(T), where=prior_count > 0)

    scores = rewards * is_target / e_t + (1 - is_target / e_t) * prior_mean

    # two-point weights spend a share of the remaining budget so that sum h^2 / e == 1
    weights = np.zeros(T)
    spent = 0.0
    for i in range(T):
        if i == T - 1:
            allocation = 1.0
        else:
            t = i + 1
            decay = t ** (-decay_rate)
            allocation = (e_t[i] / (T - i)
                          + (1 - e_t[i]) * decay
                          / (decay + (T ** (1 - decay_rate) - t ** (1 - decay_rate)) / (1 - decay_rate)))
        weights[i] = np.sqrt(e_t[i] * max(1.0 - spent, 0.0) * allocation)
        spent += weights[i] ** 2 / e_t[i]

    # constant-allocation estimate and standard error
    const_weights = np.sqrt(e_t / T)
    Q_const = np.sum(const_weights * scores) / np.sum(const_weights)
    sd_const = np.sqrt(np.sum(const_weights**2 * (scores - Q_const)**2) / (np.sum(const_weights)**2))

    # two-point estimate; its variance is centred at its own estimate
    Q_twopoint = np.sum(weights * scores) / np.sum(weights)
    sd_twopoint = np.sqrt(np.sum(weights**2 * (scores - Q_twopoint)**2) / (np.sum(weights)**2))

    return Q_const, sd_const, Q_twopoint, sd_twopoint

def hadad_diff_in_means(rewards, arms, target, baseline, propensity_target, propensity_baseline, alpha = 0.1, decay_rate = 0):
    """Difference of two arms' AIPW means (target - baseline) with normal-approximation CIs.

    Calls ``hadad_raw`` separately for ``target`` and ``baseline`` and combines
    the standard errors as sqrt(sd_target^2 + sd_baseline^2), i.e. ignoring any
    covariance between the two estimates.

    Returns
    -------
    (est_const, [lower, upper], est_twopoint, [lower, upper])
    """
    Q_const_target, sd_const_target, Q_twopoint_target, sd_twopoint_target = hadad_raw(
        rewards, arms, target, propensity_target, decay_rate = decay_rate)
    Q_const_baseline, sd_const_baseline, Q_twopoint_baseline, sd_twopoint_baseline = hadad_raw(
        rewards, arms, baseline, propensity_baseline, decay_rate = decay_rate)

    crit_val = stats.norm.ppf(1 - alpha / 2)

    point_est_const = Q_const_target - Q_const_baseline
    sd_const = np.sqrt(sd_const_target**2 + sd_const_baseline**2)

    point_est_twopoint = Q_twopoint_target - Q_twopoint_baseline
    sd_twopoint = np.sqrt(sd_twopoint_target**2 + sd_twopoint_baseline**2)

    return (point_est_const, [point_est_const - crit_val * sd_const, point_est_const + crit_val * sd_const],
            point_est_twopoint, [point_est_twopoint - crit_val * sd_twopoint, point_est_twopoint + crit_val * sd_twopoint])

    





# inspect the simulation-based upper and lower bounds for the mw arms
sim_ci_upper_mw


sim_ci_lower_mw


def av_conf_int(rewards, arms, alpha = 0.1, target_arm = 0):
    """Anytime-valid CI for the mean of ``target_arm``'s rewards (values in [0, 1]).

    Implements the predictable plug-in empirical-Bernstein (PrPl-EB) confidence
    sequence with truncation c = 1/2. Each step's bet ``lambda_t`` and variance
    proxy use only earlier observations (priors: mean 1/2, variance 1/4). The
    function returns the running intersection of the sequence, clipped to
    [0, 1], after the last observation.

    Parameters
    ----------
    rewards, arms : array-like of equal length
    alpha : float
        Two-sided miscoverage level.
    target_arm : arm label, default 0
        Only rewards with ``arms == target_arm`` are used, in order. The
        default keeps the previous arm-0 behaviour.

    Returns
    -------
    (center, [lower, upper]). ``center`` is the lambda-weighted mean of all
    target-arm rewards. With no target-arm observations this is
    (0.5, [0, 1]).
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards, dtype=float)
    x = rewards[arms == target_arm]
    n = len(x)
    if n == 0:
        return 1/2, [0, 1]

    log_term = np.log(2 / alpha)
    t = np.arange(1, n + 1)

    # regularised running mean / variance after t observations, then shifted by
    # one so that entry t only depends on x[:t-1] (predictable)
    mu_hat = (1/2 + np.cumsum(x)) / (t + 1)
    sigma2_hat = (1/4 + np.cumsum((x - mu_hat) ** 2)) / (t + 1)
    mu_prev = np.concatenate(([1/2], mu_hat[:-1]))
    sigma2_prev = np.concatenate(([1/4], sigma2_hat[:-1]))

    lambdas = np.minimum(1/2, np.sqrt(2 * log_term / (sigma2_prev * t * np.log(1 + t))))
    v = 4 * (x - mu_prev) ** 2
    psi = (-np.log(1 - lambdas) - lambdas) / 4

    lam_sum = np.cumsum(lambdas)
    centers = np.cumsum(lambdas * x) / lam_sum
    radii = (log_term + np.cumsum(v * psi)) / lam_sum

    # running intersection of the confidence sequence, clipped to [0, 1]
    lower = max(0.0, np.max(centers - radii))
    upper = min(1.0, np.min(centers + radii))

    return centers[-1], [lower, upper]



def sample_mean(rewards, arms, alpha = 0.1, target_arm = 0):
    """Sample mean of one arm's rewards with a normal-approximation CI.

    The standard error uses the population (ddof=0) standard deviation and
    ignores the adaptive assignment.

    Parameters
    ----------
    rewards, arms : array-like of equal length
    alpha : float
        Two-sided miscoverage level.
    target_arm : arm label, default 0 (the previously hard-coded arm)

    Returns
    -------
    (mean, [lower, upper])
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards)
    rel_rewards = rewards[arms == target_arm]
    mean = np.mean(rel_rewards)
    sd = np.sqrt(np.var(rel_rewards))
    crit = stats.norm.ppf(1 - alpha / 2)
    half_width = crit * sd / np.sqrt(len(rel_rewards))

    return mean, [mean - half_width, mean + half_width]

        


def AIPW(rewards, arms, propensity, alpha = 0.1, target_arm = 0):
    """Unweighted AIPW estimate of one arm's mean with a normal-approximation CI.

    Each score is ``W*Y/e + (1 - W/e) * m``, where ``W`` indicates a pull of
    ``target_arm``, ``e`` is its propensity and ``m`` is the sample mean of
    ``target_arm`` rewards observed strictly before the current step (0 if
    none). The estimate is the plain average of the scores, with standard
    error sd(scores) / sqrt(T).

    Parameters
    ----------
    rewards, arms : array-like of length T
    propensity : array-like of shape (n_arms, T)
        Row ``target_arm`` (0-based) is used.
    alpha : float
        Two-sided miscoverage level.
    target_arm : int, default 0 (the previously hard-coded arm)

    Returns
    -------
    (center, [lower, upper])
    """
    arms = np.asarray(arms)
    rewards = np.asarray(rewards, dtype=float)
    e_t = np.asarray(propensity)[target_arm, :]
    T = len(arms)

    is_target = arms == target_arm
    target_rewards = np.where(is_target, rewards, 0.0)
    # count / sum of target rewards seen strictly before each step
    prior_count = np.cumsum(is_target) - is_target
    prior_sum = np.cumsum(target_rewards) - target_rewards
    prior_mean = np.divide(prior_sum, prior_count, out=np.zeros(T), where=prior_count > 0)

    scores = rewards * is_target / e_t + (1 - is_target / e_t) * prior_mean

    crit = stats.norm.ppf(1 - alpha / 2)
    center = np.mean(scores)
    sderr = np.sqrt(np.var(scores)) / np.sqrt(T)

    return center, [center - crit * sderr, center + crit * sderr]

    


# code for seeing if confidence interval is in
def in_interval(ci, val):
    """Return a float array with 1.0 where ``ci[0] <= v <= ci[1]`` and 0.0 elsewhere.

    ``val`` is any sized iterable; NaN values (and NaN bounds) give 0.0.
    """
    lo, hi = ci[0], ci[1]
    hits = [1.0 if lo <= v <= hi else 0.0 for v in val]
    return np.array(hits, dtype=float)

def in_interval_pvalues(p_values, alpha):
    """Mark candidates not rejected by a two-sided test: alpha/2 <= p <= 1 - alpha/2."""
    above_lower = p_values >= alpha / 2
    below_upper = p_values <= 1 - alpha / 2
    return above_lower * below_upper




# result containers: one (estimate, [lower, upper]) entry per replication,
# appended in T-major order (all nsim replications of T_list[0] first, ...)
sample_mean_cis = []
av_conf_int_cis = []
hadad_stable_cis = []
hadad_twopoint_cis = []
simulation_cis = []


aipw_cis = []


# simulation design: three Bernoulli arms; inference targets arm 0 (true mean 0.45)
n_arms = 3
true_means = [0.45, 0.5, 0.55]
T_list = [200, 400, 800, 1600]
nsim = 200
alpha = 0.1
B = 200


# NOTE(review): this cell does not match the definitions visible in this file.
# `UCB1` is not defined here. `simulate` takes (bandit, reward_streams,
# batch_sizes, ...) and returns a single statistic, not a 4-tuple.
# `simulation_inference` has no `agent_type` parameter and requires
# `target_arm`. `hadad_et_al` reads `propensity[target_arm - 1]`, so the
# default target_arm = 0 uses the last row. Confirm the intended versions.
with tqdm(total=int(nsim * len(T_list))) as pbar:
    for j in range(len(T_list)):
        T = T_list[j]
        for i in range(nsim):
            # generate one stream of T Bernoulli rewards per arm
            reward_streams = [
                list(np.random.binomial(n=1, p=true_means[arm], size=T))
                for arm in range(n_arms)
            ]
            # run the UCB1 bandit on copies of the streams
            ucb_bandit = UCB1(n_arms, T=T)
            ts_rewards, ts_arms, ts_mean0, ts_propensitymatrix = simulate(ucb_bandit, copy.deepcopy(reward_streams), T)

    
            # get confidence intervals for arm 0 from each method
            aipw_cis.append(AIPW(ts_rewards, ts_arms, ts_propensitymatrix, alpha = alpha))
            sample_mean_cis.append(sample_mean(ts_rewards, ts_arms, alpha = alpha))
            av_conf_int_cis.append(av_conf_int(ts_rewards, ts_arms, alpha = alpha))
            x, ci_hadad_stable, y, ci_hadad_twopoint = hadad_et_al(ts_rewards, ts_arms, ts_propensitymatrix, alpha = alpha)
            hadad_stable_cis.append((x,ci_hadad_stable))
            hadad_twopoint_cis.append((y, ci_hadad_twopoint))
            simulation_cis.append(simulation_inference(ts_rewards,
                                                       ts_arms,
                                                       n_arms,
                                                       T, 
                                                       agent_type = "UCB", 
                                                       test_stat_obs = ts_mean0, 
                                                       grid_fidelity = 100, 
                                                       B = B)[0:2])
            pbar.update(1)
            

        
        
        
        


cis_list = [sample_mean_cis, av_conf_int_cis, hadad_stable_cis, hadad_twopoint_cis, simulation_cis]

# summaries indexed by (horizon T, method); *_sd are across-replication standard deviations
coverage = np.zeros((len(T_list), len(cis_list)))
width = np.zeros((len(T_list), len(cis_list)))
rmse = np.zeros((len(T_list), len(cis_list)))  # NOTE: holds the mean *squared* error (no square root)

coverage_sd = np.zeros((len(T_list), len(cis_list)))
width_sd = np.zeros((len(T_list), len(cis_list)))
rmse_sd = np.zeros((len(T_list), len(cis_list)))

for j in range(len(cis_list)):
    cis = cis_list[j]
    for i in range(len(T_list)):
        # replications for T_list[i] occupy positions i*nsim .. (i+1)*nsim - 1;
        # each entry is (estimate, [lower, upper]) for arm 0
        covi = np.zeros(nsim)
        widthi = np.zeros(nsim)
        rmsei = np.zeros(nsim)

        for k in range(i*nsim, (i+1)*nsim):
            covi[k - i*nsim] = in_interval(cis[k][1], [true_means[0]])[0]
            widthi[k - i*nsim] = cis[k][1][1] - cis[k][1][0]
            rmsei[k - i*nsim] = (cis[k][0] - true_means[0])**2

        # each metric's spread is computed from its own per-replication values
        coverage[i,j], coverage_sd[i,j] = np.mean(covi), np.sqrt(np.var(covi))
        width[i,j], width_sd[i,j] = np.mean(widthi), np.sqrt(np.var(widthi))
        rmse[i,j], rmse_sd[i,j] = np.mean(rmsei), np.sqrt(np.var(rmsei))

# transpose to (method, T) for plotting; convert SDs to standard errors of the mean

# Coverage data
coverage = coverage.T
coverage_se = coverage_sd.T / np.sqrt(nsim)

# Average CI width data
ci_width = width.T
ci_width_se = width_sd.T / np.sqrt(nsim)

# MSE data (named rmse)
rmse = rmse.T
rmse_se = rmse_sd.T / np.sqrt(nsim)
    
            


import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
sns.set_context("talk")
sns.set_style("whitegrid")

# Data: horizons, matching T_list of the simulation study
T_values = np.array([200, 400, 800, 1600])

# Assumes coverage, ci_width, rmse and their *_se arrays (shape: methods x T) are already defined

# Methods, in the order of cis_list
methods = ["Sample Mean", "Anytime Valid", "Stable Weights", "Two Point Weights", "Simulation"]

# Colors (optional: for prettier matching)
colors = plt.cm.tab10(np.linspace(0, 1, len(methods)))

# Create the figure with subplots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot Coverage with +/- 1 SE bands. Loops start at index 1, so the
# "Sample Mean" method is not drawn in any panel.
for i in range(1, coverage.shape[0]):
    axes[0].plot(T_values, coverage[i], label=methods[i], marker='o', color=colors[i])
    axes[0].fill_between(
        T_values,
        coverage[i] - coverage_se[i],
        coverage[i] + coverage_se[i],
        color=colors[i],
        alpha=0.2
    )
axes[0].set_title('Coverage (Nominal 90%)')
axes[0].set_xlabel('T')
axes[0].set_ylabel('Coverage')
axes[0].axhline(y=0.9, color='black', linestyle='dotted', label='Nominal 90%')

# Plot Average CI Width with +/- 1 SE bands
for i in range(1, ci_width.shape[0]):
    axes[1].plot(T_values, ci_width[i], label=methods[i], marker='o', color=colors[i])
    axes[1].fill_between(
        T_values,
        ci_width[i] - ci_width_se[i],
        ci_width[i] + ci_width_se[i],
        color=colors[i],
        alpha=0.2
    )
axes[1].set_title('Average CI Width')
axes[1].set_xlabel('T')
axes[1].set_ylabel('CI Width')

# Plot mean squared error (stored in `rmse`) with +/- 1 SE bands
for i in range(1, rmse.shape[0]):
    axes[2].plot(T_values, rmse[i], label=methods[i], marker='o', color=colors[i])
    axes[2].fill_between(
        T_values,
        rmse[i] - rmse_se[i],
        rmse[i] + rmse_se[i],
        color=colors[i],
        alpha=0.2
    )
axes[2].set_title('MSE')
axes[2].set_xlabel('T')
axes[2].set_ylabel('MSE')

# Only one legend
axes[2].legend()

# Adjust layout, save, and show
plt.tight_layout()
plt.savefig('bad_arm_plot.png', format='png', dpi=300)
plt.show()



# save all summary arrays (shape: methods x T) to .npy files
np.save('coverage_ucb_bad_arm.npy', coverage)
np.save('coverage_se_ucb_bad_arm.npy', coverage_se)
np.save('ci_width_ucb_bad_arm.npy', ci_width)
np.save('ci_width_se_ucb_bad_arm.npy', ci_width_se)
np.save('rmse_ucb_bad_arm.npy', rmse)
np.save('rmse_se_ucb_bad_arm.npy', rmse_se)   


# dump the raw per-replication (estimate, [lower, upper]) results as JSON and
# read them back to check the round trip (tuples come back as lists)
import json
with open("ucb_bad_arm_cis.txt", 'w') as fp:
    json.dump(cis_list, fp)

with open("ucb_bad_arm_cis.txt") as fp:
    a = json.load(fp)
    print(a)


a
