import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from EB_TC import EB_TC

# Figure typography: one rcParams update covering the font family and all text
# sizes used by the regret plot (axes.titlesize is left at matplotlib's default).
plt.rcParams.update(
    {
        "font.family": "Times New Roman",
        "font.size": 17,
        "axes.labelsize": 17.5,
        "xtick.labelsize": 17,
        "ytick.labelsize": 17,
        "legend.fontsize": 15.1,
    }
)

# ---- parameters ----
# Bandit instance: Bernoulli reward means per arm; losses are 1 - reward.
true_means = np.array([0.6, 0.55, 0.45, 0.3, 0.2])
true_losses = 1 - true_means

K = true_means.shape[0]         # number of arms
i_star = true_losses.argmin()   # index of the arm with the smallest expected loss

# Simulation size.
n_runs = 1000  # independent repetitions
T = 10000      # rounds per repetition

# Algorithm constants (used as default arguments of pure_mixed_sto).
gap = 0.125    # delta
alpha = 3      # shape of the Pareto perturbation used by FTPL
beta = 2 / 3   # exponent of the Tsallis regularizer used by FTRL
c = 2          # constant factor in the FTPL learning rate
seed = 42      # seed passed to np.random.seed

def pure_mixed_sto(K, seed=seed, T=T, n_runs=n_runs, gap=gap, alpha=alpha, beta=beta, c=c):
    """Simulate four bandit learners and return their cumulative-regret statistics.

    Four experiments are run one after another, each for ``n_runs`` independent
    repetitions of ``T`` rounds, all drawing from the global NumPy RNG seeded
    once with ``seed``:

    1. FTPL (alone)  - perturbed-leader play with its own exploration sampling
       and importance-weighted loss estimates.
    2. FTPL (Mixed)  - same play rule, but the observed arm is chosen by an
       ``EB_TC`` instance and the loss estimate is the plain running sum.
    3. FTRL (alone)  - Tsallis-regularized play distribution obtained from a
       cvxpy problem, exploration from ``p ** (1 - beta / 2)`` (normalized).
    4. FTRL (Mixed)  - same play distribution, observed arm chosen by ``EB_TC``.

    In every round the regret is charged to the *played* arm ``A_t`` while the
    loss is observed on the separately chosen arm ``B_t``.

    Parameters
    ----------
    K : number of arms. Arm indices in ``range(K)`` are used to index the
        module-level ``true_means`` / ``true_losses``, and regret is measured
        against the module-level ``i_star``.
        NOTE(review): presumably ``K == len(true_means)`` is required - confirm.
    seed : value passed to ``np.random.seed`` before the first experiment.
    T : number of rounds per repetition.
    n_runs : number of independent repetitions per experiment.
    gap : only echoed in the start-up log line; not used in the computation.
    alpha : shape of the Pareto perturbation and exponent in the FTPL
        exploration weights.
    beta : exponent of the Tsallis regularizer (FTRL experiments).
    c : constant factor of the FTPL learning rate.

    Returns
    -------
    tuple of eight arrays of shape ``(T,)``:
    ``(avg_ftpl_alone, avg_ftpl_mixed, avg_ftrl_alone, avg_ftrl_mixed,
    std_ftpl_alone, std_ftpl_mixed, std_ftrl_alone, std_ftrl_mixed)`` - the
    mean and ``np.std`` (default ``ddof``) of the cumulative regret over runs.

    Raises
    ------
    RuntimeError
        If the cvxpy solver returns no solution for an FTRL round.
    """
    import cvxpy as cp
    from tqdm import trange

    np.random.seed(seed)
    print(f"[pure_mixed_sto] 🚩 K = {K}, gap = {gap}, alpha = {alpha}, c = {c}")   

    # Loop invariants shared by all four experiments.
    arms = np.arange(K)
    best_loss = true_losses[i_star]

    def generate_loss(reward_mean): # input: reward mean -> output: loss 
        return 1 - np.random.binomial(1, reward_mean)

    def ftpl_learning_rate(t):
        # eta_t = c * K^(1/alpha - 1/2) / sqrt(t)
        return c * K ** (1 / alpha - 0.5) / np.sqrt(t)

    def ftrl_learning_rate(t):
        # eta_t = 2 * K^(1/2 - beta) / sqrt(t)
        return 2 * K ** (0.5 - beta) / np.sqrt(t)

    def perturbed_leader(hat_L, eta_t):
        # np.random.pareto samples the Lomax form, so +1 shifts it to a
        # classical Pareto with minimum value 1.
        r_t = np.random.pareto(alpha, size=K) + 1
        return np.argmin(hat_L - r_t / eta_t)

    def tsallis_distribution(hat_L, eta_t):
        # Minimize <hat_L, p> minus the scaled Tsallis term over the simplex.
        p = cp.Variable(K)
        tsallis = -(1 / eta_t) * (1 / (beta * (1 - beta))) * cp.sum(cp.power(p, beta) - beta * p)
        objective = cp.Minimize(cp.sum(cp.multiply(hat_L, p)) + tsallis)
        constraints = [p >= 0, cp.sum(p) == 1]
        prob = cp.Problem(objective, constraints)
        prob.solve()

        if p.value is None:
            # Without this check np.random.choice(p=None) would silently fall
            # back to uniform sampling.
            raise RuntimeError(
                f"FTRL optimization returned no solution (status: {prob.status})"
            )

        # Solver output satisfies the constraints only up to tolerance; tiny
        # negative entries or a sum slightly off 1 make np.random.choice raise
        # (and negative entries turn the fractional power below into NaN).
        p_t = np.clip(p.value, 0.0, None)
        return p_t / np.sum(p_t)

    # --------------------------------------
    # FTPL (alone)
    # --------------------------------------
    regret_ftpl_a = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTPL]"):
        hat_L = np.zeros(K)
        total_regret = 0.0
        for t in range(1, T + 1):
            # learning rate
            eta_t = ftpl_learning_rate(t)

            # exploitation
            A_t = perturbed_leader(hat_L, eta_t)

            # exploration: 1-based rank of each arm by estimated loss
            sigma_i = np.argsort(np.argsort(hat_L)) + 1
            underline_L = hat_L - np.min(hat_L)

            term1 = 1 / (1 + eta_t * underline_L)
            term2 = 1 / (sigma_i ** (1 / alpha))
            q_t = np.sqrt(np.minimum(term1, term2) ** (alpha + 1))
            p_t = q_t / np.sum(q_t)

            B_t = np.random.choice(arms, p=p_t)
            loss = generate_loss(true_means[B_t])

            # update estimated loss (importance-weighted by the sampling probability)
            hat_L[B_t] += loss / p_t[B_t]

            # regret tracking
            total_regret += true_losses[A_t] - best_loss
            regret_ftpl_a[run, t - 1] = total_regret

    avg_regret_ftpl_a = np.mean(regret_ftpl_a, axis=0)
    std_regret_ftpl_a = np.std(regret_ftpl_a, axis=0)

    # --------------------------------------
    # FTPL (Mixed)
    # --------------------------------------
    regret_ftpl = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTPL(Mixed)]"):
        hat_L = np.zeros(K)
        total_regret = 0.0
        pure_explorer = EB_TC(K=K)

        for t in range(1, T + 1):
            # learning rate
            eta_t = ftpl_learning_rate(t)

            # exploitation
            A_t = perturbed_leader(hat_L, eta_t)

            # exploration (EB_TC is fed the reward, i.e. 1 - loss)
            B_t = pure_explorer.select_next_arm()
            loss = generate_loss(true_means[B_t])
            pure_explorer.update(B_t, 1 - loss)

            # update estimated loss (unweighted running sum)
            hat_L[B_t] += loss

            # regret tracking
            total_regret += true_losses[A_t] - best_loss
            regret_ftpl[run, t - 1] = total_regret

    avg_regret_ftpl = np.mean(regret_ftpl, axis=0)
    std_regret_ftpl = np.std(regret_ftpl, axis=0)

    # --------------------------------------
    # FTRL (alone)
    # --------------------------------------
    regret_ftrl_a = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTRL]"):
        hat_L = np.zeros(K)
        total_regret = 0.0
        for t in range(1, T + 1):
            # learning rate
            eta_t = ftrl_learning_rate(t)

            # solve optimization
            p_t = tsallis_distribution(hat_L, eta_t)

            # exploitation
            A_t = np.random.choice(arms, p=p_t)

            # exploration
            p_power = p_t ** (1 - beta / 2)
            q_t = p_power / np.sum(p_power)
            B_t = np.random.choice(arms, p=q_t)
            loss = generate_loss(true_means[B_t])

            # update estimated loss (importance-weighted by the sampling probability)
            hat_L[B_t] += loss / q_t[B_t]

            # regret tracking
            total_regret += true_losses[A_t] - best_loss
            regret_ftrl_a[run, t - 1] = total_regret

    avg_regret_ftrl_a = np.mean(regret_ftrl_a, axis=0)
    std_regret_ftrl_a = np.std(regret_ftrl_a, axis=0)

    # --------------------------------------
    # FTRL (Mixed)
    # --------------------------------------
    regret_ftrl = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTRL(Mixed)]"):
        hat_L = np.zeros(K)
        total_regret = 0.0
        pure_explorer = EB_TC(K=K)

        for t in range(1, T + 1):
            # learning rate
            eta_t = ftrl_learning_rate(t)

            # solve optimization
            p_t = tsallis_distribution(hat_L, eta_t)

            # exploitation
            A_t = np.random.choice(arms, p=p_t)

            # exploration (EB_TC is fed the reward, i.e. 1 - loss)
            B_t = pure_explorer.select_next_arm()
            loss = generate_loss(true_means[B_t])
            pure_explorer.update(B_t, 1 - loss)

            # update estimated loss (unweighted running sum)
            hat_L[B_t] += loss

            # regret tracking
            total_regret += true_losses[A_t] - best_loss
            regret_ftrl[run, t - 1] = total_regret

    avg_regret_ftrl = np.mean(regret_ftrl, axis=0)
    std_regret_ftrl = np.std(regret_ftrl, axis=0)

    return (avg_regret_ftpl_a, avg_regret_ftpl, avg_regret_ftrl_a, avg_regret_ftrl, 
            std_regret_ftpl_a, std_regret_ftpl, std_regret_ftrl_a, std_regret_ftrl)

# --------------------------------------
# Regret plot (four)
# --------------------------------------
# Run all four experiments, then split the result into means and standard deviations.
_results = pure_mixed_sto(K)
avg_regret_ftpl_a, avg_regret_ftpl, avg_regret_ftrl_a, avg_regret_ftrl = _results[:4]
std_regret_ftpl_a, std_regret_ftpl, std_regret_ftrl_a, std_regret_ftrl = _results[4:]

x = np.arange(1, T + 1)
plt.figure(figsize=(6, 4), dpi=300)

# One row per curve, in drawing (and legend) order: mean, std, label, colour,
# extra line kwargs. Mixed variants are dashed; the standalone ones keep the
# default line style.
_curves = (
    # mixed
    (avg_regret_ftpl, std_regret_ftpl, "FTPL (Mixed)", "blue", {"linestyle": "--"}),
    (avg_regret_ftrl, std_regret_ftrl, "FTRL (Mixed)", "red", {"linestyle": "--"}),
    # alone
    (avg_regret_ftpl_a, std_regret_ftpl_a, "FTPL (Ours)", "blue", {}),
    (avg_regret_ftrl_a, std_regret_ftrl_a, "FTRL", "red", {}),
)

for _mean, _std, _label, _color, _line_kw in _curves:
    # Mean curve with a +/- one standard deviation band in the same colour.
    plt.plot(x, _mean, label=_label, color=_color, **_line_kw)
    plt.fill_between(x, _mean - _std, _mean + _std, color=_color, alpha=0.2)

plt.xlabel("Time Step")
plt.ylabel("Regret")
plt.xlim(0, T)
plt.ylim(bottom=0)
plt.grid(True)
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()