import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from EB_TC import EB_TC

# Global matplotlib text styling for every figure created by this script.
# (axes.titlesize is not overridden here.)
plt.rcParams.update(
    {
        "font.family": "Times New Roman",
        "font.size": 17,
        "axes.labelsize": 17.5,
        "xtick.labelsize": 17,
        "ytick.labelsize": 17,
        "legend.fontsize": 15.1,
    }
)

# ---- Experiment parameters ----
n_runs = 1000  # number of independent repetitions averaged in the plot
# delta. NOTE(review): in this file `gap` is only echoed in the log line of
# pure_alone_sto; it is not used to construct `true_means` -- confirm intent.
gap = 0.125

T = 10000  # number of rounds per repetition
alpha = 3  # Pareto shape used by FTPL (perturbation and learning rate)
beta = 2 / 3  # exponent of the FTRL regularizer
c = 2  # constant factor in the FTPL learning rate
seed = 42  # value passed to np.random.seed

# Bernoulli reward mean of each arm; the loss mean is its complement.
true_means = np.array([0.6, 0.55, 0.45, 0.3, 0.2])
true_losses = 1.0 - true_means

K = true_means.size  # number of arms
i_star = true_losses.argmin()  # index of the arm with the smallest mean loss

def pure_alone_sto(K, seed=seed, T=T, n_runs=n_runs, gap=gap, alpha=alpha, beta=beta, c=c):
    """Simulate EB-TC, FTPL and FTRL on a Bernoulli bandit and return regret curves.

    Each algorithm is run ``n_runs`` times for ``T`` rounds. The arm means are
    read from the module-level ``true_means`` / ``true_losses`` / ``i_star``
    (they are not parameters), so ``K`` must equal ``len(true_means)``.

    Args:
        K: Number of arms.
        seed: Value passed to ``np.random.seed`` before any simulation starts.
        T: Number of rounds per repetition.
        n_runs: Number of independent repetitions per algorithm.
        gap: Only echoed in the log line; not used in the simulation.
        alpha: Pareto shape for the FTPL perturbation; also enters the FTPL
            learning rate and sampling distribution.
        beta: Exponent of the FTRL regularizer; also enters the FTRL learning
            rate and the FTRL sampling distribution.
        c: Constant factor in the FTPL learning rate.

    Returns:
        A 6-tuple of length-``T`` arrays holding the per-round mean / standard
        deviation (across repetitions) of cumulative regret, in this order:
        ``(avg_ebtc, std_ebtc, avg_ftpl, avg_ftrl, std_ftpl, std_ftrl)``.
        Note that both averages of FTPL/FTRL come before their std arrays.

    Raises:
        RuntimeError: If the FTRL optimisation returns no solution.
    """
    import cvxpy as cp
    from tqdm import trange

    np.random.seed(seed)
    print(f"[pure_alone_sto] 🚩 K = {K}, gap = {gap}, alpha = {alpha}, c = {c}")   

    arms = np.arange(K)
    best_loss = true_losses[i_star]

    def generate_loss(reward_mean):
        """Draw a Bernoulli(reward_mean) reward and return the loss 1 - reward."""
        return 1 - np.random.binomial(1, reward_mean)

    # --------------------------------------
    # pure exploration (alone)
    # --------------------------------------
    regret_ebtc_a = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[EBTC]"):
        algo = EB_TC(K=K)
        inst_regret = np.empty(T)
        for t in range(T):
            arm = algo.select_next_arm()
            loss = generate_loss(true_means[arm])
            algo.update(arm, 1 - loss)  # EB_TC is fed the reward, not the loss

            inst_regret[t] = true_losses[arm] - best_loss

        regret_ebtc_a[run, :] = np.cumsum(inst_regret)

    avg_regret_ebtc_a = np.mean(regret_ebtc_a, axis=0)
    std_regret_ebtc_a = np.std(regret_ebtc_a, axis=0)

    # --------------------------------------
    # FTPL (alone)
    # --------------------------------------
    regret_ftpl_a = np.zeros((n_runs, T))
    # Round-independent part of the learning rate: eta_t = scale / sqrt(t).
    ftpl_eta_scale = c * K ** (1 / alpha - 0.5)

    for run in trange(n_runs, desc="[FTPL]"):
        hat_L = np.zeros(K)  # importance-weighted cumulative loss estimates
        inst_regret = np.empty(T)
        for t in range(1, T + 1):
            # learning rate
            eta_t = ftpl_eta_scale / np.sqrt(t)

            # exploitation: perturbed leader; this arm is the one scored for regret
            r_t = np.random.pareto(alpha, size=K) + 1
            A_t = np.argmin(hat_L - r_t / eta_t)

            # exploration: sampling distribution built from the rank of each arm
            # (1 = smallest estimated loss) and its estimated gap to the leader
            sigma_i = np.argsort(np.argsort(hat_L)) + 1
            underline_L = hat_L - np.min(hat_L)

            term1 = 1 / (1 + eta_t * underline_L)
            term2 = 1 / (sigma_i ** (1 / alpha))
            q_t = np.sqrt(np.minimum(term1, term2) ** (alpha + 1))
            p_t = q_t / np.sum(q_t)

            # Only the sampled arm B_t is observed.
            B_t = np.random.choice(arms, p=p_t)
            loss = generate_loss(true_means[B_t])

            # update estimated loss
            hat_L[B_t] += loss / p_t[B_t]

            # regret tracking
            inst_regret[t - 1] = true_losses[A_t] - best_loss

        regret_ftpl_a[run, :] = np.cumsum(inst_regret)

    avg_regret_ftpl_a = np.mean(regret_ftpl_a, axis=0)
    std_regret_ftpl_a = np.std(regret_ftpl_a, axis=0)

    # --------------------------------------
    # FTRL (alone)
    # --------------------------------------
    regret_ftrl_a = np.zeros((n_runs, T))
    # Round-independent part of the learning rate: eta_t = scale / sqrt(t).
    ftrl_eta_scale = 2 * K ** (0.5 - beta)

    for run in trange(n_runs, desc="[FTRL]"):
        hat_L = np.zeros(K)  # importance-weighted cumulative loss estimates
        inst_regret = np.empty(T)
        for t in range(1, T + 1):
            # learning rate
            eta_t = ftrl_eta_scale / np.sqrt(t)

            # solve optimization: minimise <hat_L, p> + regulariser over the simplex
            # NOTE(review): the problem is rebuilt and re-canonicalised every round;
            # a cp.Parameter-based formulation may be considerably faster.
            p = cp.Variable(K)
            tsallis = -(1 / eta_t) * (1 / (beta * (1 - beta))) * cp.sum(cp.power(p, beta) - beta * p)
            objective = cp.Minimize(cp.sum(cp.multiply(hat_L, p)) + tsallis)
            constraints = [p >= 0, cp.sum(p) == 1]
            prob = cp.Problem(objective, constraints)
            prob.solve()

            if p.value is None:
                raise RuntimeError(
                    f"[FTRL] optimisation returned no solution at t={t} "
                    f"(run={run}, status={prob.status})"
                )

            # The solver satisfies the constraints only up to its tolerance, so the
            # raw solution can contain tiny negative entries or not sum exactly to 1.
            # Project it back to a valid distribution before sampling / taking a
            # fractional power (a negative base would give NaN).
            p_t = np.clip(p.value, 0.0, None)
            p_t = p_t / np.sum(p_t)

            # exploitation: this arm is the one scored for regret
            A_t = np.random.choice(arms, p=p_t)

            # exploration: only the arm sampled from q_t is observed
            p_power = p_t ** (1 - beta / 2)
            q_t = p_power / np.sum(p_power)
            B_t = np.random.choice(arms, p=q_t)
            loss = generate_loss(true_means[B_t])

            # update estimated loss
            hat_L[B_t] += loss / q_t[B_t]

            # regret tracking
            inst_regret[t - 1] = true_losses[A_t] - best_loss

        regret_ftrl_a[run, :] = np.cumsum(inst_regret)

    avg_regret_ftrl_a = np.mean(regret_ftrl_a, axis=0)
    std_regret_ftrl_a = np.std(regret_ftrl_a, axis=0)

    # Order is part of the caller contract: both FTPL/FTRL averages precede their stds.
    return (avg_regret_ebtc_a, std_regret_ebtc_a, avg_regret_ftpl_a, 
            avg_regret_ftrl_a, std_regret_ftpl_a, std_regret_ftrl_a)

# --------------------------------------
# Regret plot
# --------------------------------------
# Run all three simulations; the unpacking order mirrors the return order of
# pure_alone_sto (FTPL/FTRL averages first, then their standard deviations).
(avg_regret_ebtc_a, std_regret_ebtc_a, avg_regret_ftpl_a, 
avg_regret_ftrl_a, std_regret_ftpl_a, std_regret_ftrl_a) = pure_alone_sto(K)

x = np.arange(1, T+1)
plt.figure(figsize=(6, 4), dpi=300)

# One mean curve per algorithm with a +/- one-standard-deviation band in the
# same colour as the curve, so each band can be matched to its legend entry.
plt.plot(x, avg_regret_ebtc_a, label="EB-TC", color='purple')
plt.fill_between(x, avg_regret_ebtc_a - std_regret_ebtc_a, avg_regret_ebtc_a + std_regret_ebtc_a, color='purple', alpha=0.2)
plt.plot(x, avg_regret_ftpl_a, label="FTPL", color='blue')
plt.fill_between(x, avg_regret_ftpl_a - std_regret_ftpl_a, avg_regret_ftpl_a + std_regret_ftpl_a, color='blue', alpha=0.2)
plt.plot(x, avg_regret_ftrl_a, label="FTRL", color='red')
plt.fill_between(x, avg_regret_ftrl_a - std_regret_ftrl_a, avg_regret_ftrl_a + std_regret_ftrl_a, color='red', alpha=0.2)

plt.xlabel("Time Step")
plt.ylabel("Regret")
plt.xlim(0, T)
plt.ylim(bottom=0)  # the lower std band may dip below zero; it is cut off here
plt.grid(True)
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()