import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pareto

def sample_strong_inlier(alpha, epislon, k, setting, i_arm):
    """Draw one inlier reward for the two-armed strong-corruption instance.

    Both arms pay ``1 / gamma`` or ``0``. Arm 0 pays ``1 / gamma`` with
    probability ``gamma ** k`` and arm 1 with half that probability, so arm 0
    has the larger mean (``gamma ** (k - 1)`` versus half of it).

    Args:
        alpha: Corruption rate; sets the scale of ``gamma``.
        epislon: Privacy parameter (sic); only used by the ``'LTC'`` setting.
        k: Exponent used to build ``gamma`` and the success probabilities.
        setting: ``'LTC'`` or ``'CTL'``.
        i_arm: Arm index, 0 or 1.

    Returns:
        ``1 / gamma`` or ``0.0`` as a NumPy scalar.

    Raises:
        ValueError: If ``setting`` is not ``'LTC'``/``'CTL'`` or ``i_arm`` is
            not 0 or 1 (previously these surfaced as ``UnboundLocalError``).
    """
    c = 0.1
    if setting == 'LTC':
        gamma = c * (alpha / epislon) ** (1 / k)
    elif setting == 'CTL':
        gamma = c * alpha ** (1 / k)
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")

    if i_arm == 0:
        p_success = gamma ** k
    elif i_arm == 1:
        p_success = gamma ** k / 2   # the worse arm succeeds half as often
    else:
        raise ValueError("i_arm must be 0 or 1.")
    return np.random.choice([1 / gamma, 0], p=[p_success, 1 - p_success])

def strong_corruption(input, M, alpha, epislon, k, setting, i_arm):
    """Apply arm-dependent (strong) corruption to a single observation.

    With probability ``alpha`` the observation is replaced by an adversarial
    value drawn from an arm-specific distribution. Otherwise it is returned
    unchanged. One uniform draw decides for the whole ``input``; callers in
    this file pass a single observation.

    * ``'LTC'``: the replacement is ``+M / r`` or ``-M / r`` with
      ``r = (e^eps - 1) / (e^eps + 1)``. This is the same magnitude as the
      outputs of ``epsilon_ldp_mechanism``. The positive sign has probability
      ``0.5 + gamma**k / 2 * r`` (arm 0) or ``0.5 + gamma**k / 4 * r`` (arm 1).
    * ``'CTL'``: the replacement is ``1 / gamma`` with probability
      ``gamma**k / 2`` (arm 0) or ``gamma**k / (2 * alpha)`` (arm 1), else 0.
      ``M`` is unused in this setting.

    Args:
        input: Observation to (possibly) corrupt.
        M: Truncation level used by the LDP mechanism.
        alpha: Corruption probability.
        epislon: Privacy parameter (sic).
        k: Exponent used to build ``gamma``.
        setting: ``'LTC'`` or ``'CTL'``.
        i_arm: Arm index, 0 or 1.

    Returns:
        ``np.where`` result: the replacement or ``input``.

    Raises:
        ValueError: If ``setting`` or ``i_arm`` is not recognised.
    """
    c = 0.1
    if i_arm not in (0, 1):
        raise ValueError("i_arm must be 0 or 1.")

    if setting == 'LTC':
        gamma = c * (alpha / epislon) ** (1 / k)
        # Use the ``epislon`` argument, not a module-level ``epsilon``.
        inv_scale_response = (np.exp(epislon) - 1) / (np.exp(epislon) + 1)
        sign_bias = gamma ** k / 2 if i_arm == 0 else gamma ** k / 4
        p_plus = 0.5 + sign_bias * inv_scale_response
        magnitude = M / inv_scale_response
        cor_value = np.random.choice([magnitude, -magnitude], p=[p_plus, 1 - p_plus])
    elif setting == 'CTL':
        gamma = c * alpha ** (1 / k)
        if i_arm == 0:
            p_big = gamma ** k / 2
        else:
            p_big = gamma ** k / (2 * alpha)
        cor_value = np.random.choice([1 / gamma, 0], p=[p_big, 1 - p_big])
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")

    # Replace the observation with probability alpha.
    return np.where(np.random.rand() < alpha, cor_value, input)

def sample_pareto(i_arm, s, xm, k):
    """Draw one Pareto reward for ``i_arm`` and normalize it.

    The sample has shape ``s`` and scale ``xm[i_arm]``. It is divided by the
    distribution's k-th raw moment ``s * xm**k / (s - k)`` (finite for
    ``s > k``). The normalized reward therefore has mean
    ``(s - k) / (xm**(k - 1) * (s - 1))``.

    Returns:
        NumPy array of shape ``(1,)``.
    """
    x_min = xm[i_arm]
    kth_moment = s * x_min ** k / (s - k)
    draw = pareto.rvs(s, scale=x_min, size=1)
    return draw / kth_moment

def epsilon_ldp_mechanism(U, M, epsilon):
    """Privatize ``U`` with truncation, stochastic binarization and randomized response.

    1. Entries with ``|U| > M`` are replaced by 0.
    2. Each entry becomes ``+M`` with probability ``(1 + U/M) / 2`` and ``-M``
       otherwise, which is unbiased for the truncated value.
    3. Each sign is kept with probability ``e^eps / (e^eps + 1)`` and flipped
       otherwise. The result is then scaled by ``(e^eps + 1) / (e^eps - 1)``
       so its expectation equals the truncated input.

    Every entry of an array ``U`` is perturbed independently.

    Args:
        U: Scalar or array of raw values.
        M: Positive truncation level.
        epsilon: Privacy parameter.

    Returns:
        Array with the shape of ``U`` whose entries are
        ``+/- M * (e^eps + 1) / (e^eps - 1)``.
    """
    shape = np.shape(U)
    U_tilde = np.where(np.abs(U) > M, 0, U)
    p_positive = (1 + U_tilde / M) / 2
    # One uniform per element. The previous single scalar draw made every
    # entry of an array share the same coin.
    U_prime = np.where(np.random.rand(*shape) < p_positive, M, -M)

    exp_eps = np.exp(epsilon)
    p_response = exp_eps / (exp_eps + 1)
    scale_response = (exp_eps + 1) / (exp_eps - 1)
    keep_sign = np.random.rand(*shape) < p_response
    return np.where(keep_sign, scale_response * U_prime, -scale_response * U_prime)

def compute_M(epsilon, alpha, k, i_arm, N, delta, setting):
    """Return the truncation level M used for arm ``i_arm``.

    M is the smaller of two caps:

    * a corruption cap: ``(epsilon / alpha) ** (1/k) / 3`` for ``'LTC'`` or
      ``(1 / alpha) ** (1/k)`` for ``'CTL'``;
    * a sample-size cap: ``(epsilon * sqrt(n) / sqrt(log(1/delta))) ** (1/k)``
      with ``n = N[i_arm]``.

    Raises:
        ValueError: If ``setting`` is neither ``'LTC'`` nor ``'CTL'``.
    """
    pulls = N[i_arm]
    if setting not in ('LTC', 'CTL'):
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    sample_cap = (epsilon * np.sqrt(pulls) / np.sqrt(np.log(1 / delta))) ** (1 / k)
    corruption_cap = ((epsilon / alpha) ** (1 / k) * (1 / 3)
                      if setting == 'LTC'
                      else (1 / alpha) ** (1 / k))
    return min(corruption_cap, sample_cap)

def huber_corruption(Input, M, alpha, setting):
    """Huber-contaminate ``Input``: each entry is replaced independently with probability ``alpha``.

    * ``'LTC'``: a corrupted entry becomes its absolute value.
    * ``'CTL'``: a corrupted entry becomes ``M``.

    Works for scalars and arrays of any shape; the original ``len(Input)``
    failed on scalars and only sized the first axis.

    Raises:
        ValueError: If ``setting`` is neither ``'LTC'`` nor ``'CTL'``.
    """
    if setting not in ('LTC', 'CTL'):
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    corrupted = np.random.rand(*np.shape(Input)) < alpha
    replacement = np.abs(Input) if setting == 'LTC' else M
    return np.where(corrupted, replacement, Input)

def Whole_Procedure(x_i, alpha, epsilon, i_arm, N, setting, is_strong_cor=False):
    """Pass one raw reward through corruption and the epsilon-LDP channel.

    The order of the two stages depends on ``setting``:

    * ``'CTL'``: corrupt the raw reward, then privatize it.
    * ``'LTC'``: privatize the raw reward, then corrupt the private output.

    Corruption uses ``strong_corruption`` when ``is_strong_cor`` is true and
    ``huber_corruption`` otherwise. The truncation level comes from
    ``compute_M`` with ``k=2`` and ``delta=0.05``. ``compute_M`` raises
    ``ValueError`` for an unknown ``setting``.
    """
    M = compute_M(epsilon=epsilon, alpha=alpha, k=2, i_arm=i_arm, N=N, delta=0.05, setting=setting)

    def corrupt(value, mode):
        # Choose the corruption model once; apply it in the given setting.
        if is_strong_cor:
            return strong_corruption(value, M, alpha, epsilon, k=2, setting=mode, i_arm=i_arm)
        return huber_corruption(value, M, alpha, setting=mode)

    if setting == 'CTL':
        x_i = corrupt(x_i, 'CTL')
    privatized = epsilon_ldp_mechanism(x_i, M, epsilon)
    if setting == 'LTC':
        privatized = corrupt(privatized, 'LTC')
    return privatized

# Weak (Huber) corruption experiment.
def UCB_Pareto(n_arms, T, c, epsilon, alpha, s, xm, setting, k=2, means=None):
    """Run private, corruption-robust UCB on Pareto arms and return per-round regret.

    In each round the algorithm does one of two things:

    * While any arm has at most ``6 * log(t) / alpha`` pulls, it pulls the
      least-pulled arm.
    * Otherwise it pulls the arm maximizing ``mu_hat + beta``, where
      ``beta = c * bias + c * gamma_t``. Here ``bias`` depends on
      corruption/privacy and ``gamma_t`` shrinks with the pull count.

    The reward comes from ``sample_pareto`` and is passed through
    ``Whole_Procedure`` (Huber corruption plus the LDP mechanism) before it
    updates the empirical mean.

    Args:
        n_arms: Number of arms.
        T: Horizon (number of rounds).
        c: Multiplier on both terms of the exploration bonus.
        epsilon: Privacy parameter.
        alpha: Corruption rate.
        s: Pareto shape parameter.
        xm: Per-arm Pareto scale parameters.
        setting: ``'LTC'`` or ``'CTL'``.
        k: Exponent used by ``sample_pareto`` and the bonus. Defaults to 2,
            the value the module-level script uses.
        means: True (normalized) arm means used for regret. Defaults to
            ``(s - k) / (xm ** (k - 1) * (s - 1))``, the mean of
            ``sample_pareto`` rewards.

    Returns:
        List of length ``T``; each entry is ``max(means) - means[a_t]``.

    Raises:
        ValueError: If ``setting`` is neither ``'LTC'`` nor ``'CTL'``.
    """
    if setting == 'LTC':
        bias = (alpha / epsilon) ** (1 - 1 / k)
    elif setting == 'CTL':
        bias = alpha ** (1 - 1 / k)
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    if means is None:
        means = (s - k) / (np.asarray(xm) ** (k - 1) * (s - 1))
    means = np.asarray(means, dtype=float)
    best_mean = means.max()

    N = np.zeros(n_arms, dtype=int)     # pulls per arm
    reward_sums = np.zeros(n_arms)      # running sum of observed rewards
    mu_hat = np.zeros(n_arms)           # empirical mean of observed rewards
    regrets = []
    for t in range(1, T + 1):
        if np.any(N <= 6 * np.log(t) / alpha):
            a_t = np.argmin(N)    # forced exploration: least-pulled arm
        else:
            gamma_t = ((1 / epsilon) * np.sqrt(4 * np.log(t) / N)) ** (1 - 1 / k)
            beta_t = c * bias + c * gamma_t
            a_t = np.argmax(mu_hat + beta_t)

        regrets.append(best_mean - means[a_t])
        # Increment first: Whole_Procedure sizes M from the updated count.
        N[a_t] += 1
        reward = sample_pareto(a_t, s, xm, k)
        observed = Whole_Procedure(reward, alpha, epsilon, a_t, N, setting)
        # sample_pareto draws a single value, so ``observed`` holds one
        # element; .item() extracts it explicitly.
        reward_sums[a_t] += np.asarray(observed, dtype=float).item()
        mu_hat[a_t] = reward_sums[a_t] / N[a_t]
    return regrets

# Strong corruption experiment on a two-armed instance.
def UCB_Strong_Corruption(T, c, epsilon, alpha, setting, k=2, means=None):
    """Run private, corruption-robust UCB on the two-armed strong-corruption instance.

    Arm selection matches ``UCB_Pareto``: forced exploration while any arm has
    at most ``6 * log(t) / alpha`` pulls, then the UCB index
    ``mu_hat + c * bias + c * gamma_t``. Rewards come from
    ``sample_strong_inlier`` and pass through ``Whole_Procedure`` with strong
    corruption.

    Args:
        T: Horizon (number of rounds).
        c: Multiplier on both terms of the exploration bonus.
        epsilon: Privacy parameter.
        alpha: Corruption rate.
        setting: ``'LTC'`` or ``'CTL'``.
        k: Exponent used by the inlier distribution and the bonus. Defaults
            to 2, the value the module-level script uses.
        means: Inlier arm means used for regret. Defaults to the means implied
            by ``sample_strong_inlier``. Arm 0 pays ``1 / gamma`` with
            probability ``gamma ** k`` and arm 1 with half that, giving
            ``gamma ** (k - 1)`` and half of it. Previously regret was taken
            from the unrelated Pareto ``xm_means``.

    Returns:
        List of length ``T``; each entry is ``max(means) - means[a_t]``.

    Raises:
        ValueError: If ``setting`` is neither ``'LTC'`` nor ``'CTL'``.
    """
    # 0.1 mirrors the constant ``c`` inside sample_strong_inlier.
    if setting == 'LTC':
        bias = (alpha / epsilon) ** (1 - 1 / k)
        gamma = 0.1 * (alpha / epsilon) ** (1 / k)
    elif setting == 'CTL':
        bias = alpha ** (1 - 1 / k)
        gamma = 0.1 * alpha ** (1 / k)
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    n_arms = 2
    if means is None:
        p_success = gamma ** k
        means = np.array([p_success, p_success / 2]) / gamma
    means = np.asarray(means, dtype=float)
    best_mean = means.max()

    N = np.zeros(n_arms, dtype=int)     # pulls per arm
    reward_sums = np.zeros(n_arms)      # running sum of observed rewards
    mu_hat = np.zeros(n_arms)           # empirical mean of observed rewards
    regrets = []
    for t in range(1, T + 1):
        if np.any(N <= 6 * np.log(t) / alpha):
            a_t = np.argmin(N)    # forced exploration: least-pulled arm
        else:
            gamma_t = ((1 / epsilon) * np.sqrt(4 * np.log(t) / N)) ** (1 - 1 / k)
            beta_t = c * bias + c * gamma_t
            a_t = np.argmax(mu_hat + beta_t)

        regrets.append(best_mean - means[a_t])
        # Increment first: Whole_Procedure sizes M from the updated count.
        N[a_t] += 1
        reward = sample_strong_inlier(alpha, epsilon, k, setting, a_t)
        observed = Whole_Procedure(reward, alpha, epsilon, a_t, N, setting, is_strong_cor=True)
        reward_sums[a_t] += np.asarray(observed, dtype=float).item()
        mu_hat[a_t] = reward_sums[a_t] / N[a_t]
    return regrets

# ---------------------------------------------------------------------------
# Module-level experiment configuration.
# ---------------------------------------------------------------------------
alpha = 0.02          # probability that an observation is corrupted
epsilon = 0.2         # privacy parameter passed to the LDP mechanism
k = 2                 # exponent k used in normalization and bonus widths
s = 11                # Pareto shape parameter shared by all arms
n_arms = 10
xm = np.arange(n_arms) + 1                       # Pareto scale per arm: 1..n_arms
# Mean of each arm's normalized reward as produced by sample_pareto.
xm_means = (s - k) / ((s - 1) * xm ** (k - 1))
print(xm)
print(xm_means)
T = 100               # horizon
c = 0.5               # UCB bonus multiplier
# Example runs:
# regrets = UCB_Pareto(n_arms, T, c, epsilon, alpha, s, xm, setting='LTC')
# regrets = UCB_Strong_Corruption(T, c, epsilon, alpha, setting='LTC')
