from matplotlib import lines
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from torch import lt

# Global matplotlib text styling, applied once to every figure in this script.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 17,
    # "axes.titlesize": 16,  (left disabled)
    "axes.labelsize": 17.5,
    "xtick.labelsize": 17,
    "ytick.labelsize": 17,
    "legend.fontsize": 15.1,
})

# ---- experiment parameters ----
n_runs = 100                            # independent repetitions
Ks = [2 ** k for k in range(1, 10)]     # number of arms: 2, 4, ..., 512
gap = 0.125                             # delta (or gap = 0.0625)

T = 10_000      # rounds per run
alpha = 3       # Pareto shape used by the FTPL perturbation
c = 2           # FTPL learning-rate constant
beta = 2 / 3    # Tsallis exponent used by FTRL
seed = 42       # NumPy global RNG seed

def solve_newton(L, eta, tol=1e-9, max_iter=100):
    """
    Compute the probability vector

        p_i = ((eta / 3) * (L[i] + lambda) + 1) ** (-3)

    where the scalar lambda is found by Newton's method so that sum_i p_i = 1.

    Parameters
    ----------
    L : array-like of float
        Loss values, one per arm (converted with ``np.asarray``).
    eta : float
        Positive learning rate.
    tol : float
        Convergence threshold on the Newton step, applied as
        ``|step| <= tol * max(1, |lambda|)``.
    max_iter : int
        Maximum number of Newton iterations.

    Returns
    -------
    numpy.ndarray
        Probabilities, clipped at zero and renormalised to sum to one.

    Raises
    ------
    ValueError
        If ``L`` is empty or ``eta`` is not positive.
    RuntimeError
        If the iteration does not converge within ``max_iter`` steps.
    """
    L = np.asarray(L, dtype=float)
    if L.size == 0:
        raise ValueError("L must contain at least one loss value")
    if not eta > 0:
        raise ValueError("eta must be positive")

    L_min = np.min(L)
    # Pole of F: at lam_min the smallest-loss arm has vals == 0.
    lam_min = -L_min - 3 / eta

    # Start where the smallest-loss arm has vals == 1 (probability exactly 1),
    # so F(lam) >= 0. F is convex and decreasing to the right of the pole, so
    # Newton steps from here move monotonically towards the root. Starting
    # next to the pole instead only grows the distance by ~4/3 per step and
    # can exhaust max_iter when eta is small.
    lam = -L_min

    for _ in range(max_iter):
        vals = (eta / 3) * (L + lam) + 1

        F = np.sum(vals ** (-3)) - 1
        # d/dlam of sum(vals ** -3) = -3 * (eta / 3) * sum(vals ** -4)
        F_prime = -eta * np.sum(vals ** (-4))

        lam_new = lam - F / F_prime
        # Defensive: never step onto or past the pole.
        lam_new = max(lam_new, lam_min + 1e-12)

        # Relative-or-absolute test: for |lam| >> 1 an absolute threshold of
        # 1e-9 is finer than the spacing of adjacent floats and may never be
        # met.
        if abs(lam_new - lam) <= tol * max(1.0, abs(lam_new)):
            p_t = ((eta / 3) * (L + lam_new) + 1) ** (-3)
            p_t = np.maximum(p_t, 0)
            return p_t / np.sum(p_t)

        lam = lam_new

    raise RuntimeError("Newton method did not converge")

def compare_to_rouyer(K, seed=seed, T=T, n_runs=n_runs, gap=gap, alpha=alpha, beta=beta, c=c):
    """
    Simulate FTPL and FTRL on the same K-armed loss-mean table and report the
    per-step runtime and cumulative regret of both.

    Environment: a (T, K) table of Bernoulli loss means is filled phase by
    phase, with phase lengths int(1.6 ** phase_num). In even phases arm 0 has
    mean 0 and all other arms have mean `gap`; in odd phases arm 0 has mean
    1 - gap and all other arms have mean 1.

    Each round both learners pick two arms: A_t, whose mean loss is charged to
    the regret, and B_t, whose sampled Bernoulli loss is the only feedback fed
    into the importance-weighted loss estimates `hat_L`.

    Parameters
    ----------
    K : int
        Number of arms.
    seed : int
        Seed for NumPy's global RNG (set once, before both simulations).
    T : int
        Number of rounds per run.
    n_runs : int
        Number of independent runs per algorithm.
    gap : float
        Difference between the mean loss of arm 0 and of every other arm.
    alpha : float
        Pareto shape of the FTPL perturbation (FTPL only).
    beta : float
        Exponent of the Tsallis regulariser (FTRL only).
    c : float
        Learning-rate constant of FTPL. FTRL uses the fixed constant 2.

    Returns
    -------
    tuple
        (avg_time_ftpl, avg_time_ftrl, avg_regret_ftpl, avg_regret_ftrl,
        std_regret_ftpl, std_regret_ftrl). The times are scalars: mean seconds
        per step over all runs and rounds. The regret entries are arrays of
        shape (T,): mean / standard deviation over runs of the cumulative
        regret.

    Raises
    ------
    RuntimeError
        If the SCS solver returns no solution for an FTRL step.
    """
    import cvxpy as cp
    import time
    from tqdm import trange

    np.random.seed(seed)
    print(f"[compare_to_rouyer] 🚩 K = {K}, gap = {gap}, alpha = {alpha}, c = {c}, seed = {seed}, n_runs = {n_runs}")

    # --------------------------------------
    # Loss-mean table
    # --------------------------------------
    true_means = np.zeros((T, K))

    phase_start = 0
    phase_num = 0
    growth_factor = 1.6

    while phase_start < T:
        phase_length = int(growth_factor ** phase_num)
        phase_end = min(phase_start + phase_length, T)

        # phase A: optimal = 0, suboptimal = delta
        if phase_num % 2 == 0:
            true_means[phase_start:phase_end, 0] = 0.0
            true_means[phase_start:phase_end, 1:] = gap
        # phase B: optimal = 1 - delta, suboptimal = 1
        else:
            true_means[phase_start:phase_end, 0] = 1 - gap
            true_means[phase_start:phase_end, 1:] = 1.0

        phase_start = phase_end
        phase_num += 1

    i_star = np.argmin(np.sum(true_means, axis=0))  # optimal arm = 0 (fixed)

    def generate_loss(mu):  # loss mean -> Bernoulli loss
        return np.random.binomial(1, mu)

    # Loop invariants shared by both simulations.
    arms = np.arange(K)
    ftpl_rate = c * K ** (1 / alpha - 0.5)  # eta_t = ftpl_rate / sqrt(t)
    ftrl_rate = 2 * K ** (0.5 - beta)       # eta_t = ftrl_rate / sqrt(t)

    # --------------------------------------
    # FTPL
    # --------------------------------------
    regret_ftpl = np.zeros((n_runs, T))
    time_ftpl = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTPL]"):
        hat_L = np.zeros(K)
        running_regret = 0.0
        for t in range(1, T + 1):
            t_start = time.perf_counter()

            # learning rate
            eta_t = ftpl_rate / np.sqrt(t)

            # exploitation: perturbed leader (np.random.pareto + 1 gives
            # Pareto samples with support starting at 1)
            r_t = np.random.pareto(alpha, size=K) + 1
            A_t = np.argmin(hat_L - r_t / eta_t)

            # exploration: sigma_i is the 1-based rank of each arm by hat_L
            sigma_i = np.argsort(np.argsort(hat_L)) + 1
            underline_L = hat_L - np.min(hat_L)

            term1 = 1 / (1 + eta_t * underline_L)
            term2 = 1 / (sigma_i ** (1 / alpha))
            q_t = np.sqrt(np.minimum(term1, term2) ** (alpha + 1))
            p_t = q_t / np.sum(q_t)

            B_t = np.random.choice(arms, p=p_t)
            lossB = generate_loss(true_means[t - 1, B_t])

            # update estimated loss (importance-weighted)
            hat_L[B_t] += lossB / p_t[B_t]

            # regret tracking: charged on the exploitation arm, in mean losses
            running_regret += true_means[t - 1, A_t] - true_means[t - 1, i_star]
            regret_ftpl[run, t - 1] = running_regret

            # time tracking
            time_ftpl[run, t - 1] = time.perf_counter() - t_start

    avg_time_ftpl = np.mean(time_ftpl)
    print(f"[FTPL] Average time per step over all runs: {avg_time_ftpl:.6f} seconds")

    avg_regret_ftpl = np.mean(regret_ftpl, axis=0)
    std_regret_ftpl = np.std(regret_ftpl, axis=0)

    # --------------------------------------
    # FTRL
    # --------------------------------------
    regret_ftrl = np.zeros((n_runs, T))
    time_ftrl = np.zeros((n_runs, T))

    for run in trange(n_runs, desc="[FTRL]"):
        hat_L = np.zeros(K)
        running_regret = 0.0
        for t in range(1, T + 1):
            t_start = time.perf_counter()

            # learning rate
            eta_t = ftrl_rate / np.sqrt(t)

            # convex optimization ----------------------------
            # METHOD 1: splitting conic solver (SCS). The problem is rebuilt
            # and solved inside the timed region on purpose: its cost is part
            # of the per-step runtime being measured.
            p = cp.Variable(K)
            tsallis = -(1 / eta_t) * (1 / (beta * (1 - beta))) * cp.sum(cp.power(p, beta) - beta * p)
            objective = cp.Minimize(cp.sum(cp.multiply(hat_L, p)) + tsallis)
            constraints = [p >= 0, cp.sum(p) == 1]
            prob = cp.Problem(objective, constraints)
            prob.solve(solver = cp.SCS)

            if p.value is None:
                # Fail loudly instead of letting np.clip choke on None.
                raise RuntimeError(
                    f"SCS returned no solution at run {run}, step {t} "
                    f"(status: {prob.status})"
                )

            # The solver output is only approximately feasible: project back
            # onto the simplex by clipping and renormalising.
            p_t = np.clip(p.value, 0, 1)
            p_t = p_t / np.sum(p_t)

            # METHOD 2: Newton's method
            # p_t = solve_newton(hat_L, eta_t)
            # -----------------------------------------------

            # exploitation
            A_t = np.random.choice(arms, p=p_t)

            # exploration
            p_power = p_t ** (1 - beta / 2)
            q_t = p_power / np.sum(p_power)
            B_t = np.random.choice(arms, p=q_t)

            # update estimated loss (importance-weighted)
            loss = generate_loss(true_means[t - 1, B_t])
            hat_L[B_t] += loss / q_t[B_t]

            # regret tracking: charged on the exploitation arm, in mean losses
            running_regret += true_means[t - 1, A_t] - true_means[t - 1, i_star]
            regret_ftrl[run, t - 1] = running_regret

            # time tracking
            time_ftrl[run, t - 1] = time.perf_counter() - t_start

    avg_time_ftrl = np.mean(time_ftrl)
    print(f"[FTRL] Average time per step over all runs: {avg_time_ftrl:.6f} seconds")

    avg_regret_ftrl = np.mean(regret_ftrl, axis=0)
    std_regret_ftrl = np.std(regret_ftrl, axis=0)

    return avg_time_ftpl, avg_time_ftrl, avg_regret_ftpl, avg_regret_ftrl, std_regret_ftpl, std_regret_ftrl

# --------------------------------------
# Plot
# --------------------------------------
num_K = len(Ks)

# One entry per K in each list, in the order compare_to_rouyer returns them.
ftpl_times, ftrl_times = [], []
regret_ftpl_list, regret_ftrl_list = [], []  # [(T,)] * num_K
std_ftpl_list, std_ftrl_list = [], []

_collectors = (
    ftpl_times, ftrl_times,
    regret_ftpl_list, regret_ftrl_list,
    std_ftpl_list, std_ftrl_list,
)

# Run the full FTPL/FTRL comparison once for every number of arms.
for K in Ks:
    for _collector, _value in zip(_collectors, compare_to_rouyer(K)):
        _collector.append(_value)

# --------------------------------------
# FTPL vs FTRL Average per-step runtime
# --------------------------------------
# Timings are recorded in seconds; convert to milliseconds for display.
ftpl_ms = np.array(ftpl_times) * 1000
ftrl_ms = np.array(ftrl_times) * 1000

plt.figure(figsize=(6,4), dpi=300)
plt.plot(Ks, ftpl_ms, marker='o', label='FTPL (Ours)', color='blue')
plt.plot(Ks, ftrl_ms, marker='o', label='FTRL', color='red')
plt.xscale("log", base=2)
plt.xlabel("Number of Arms ($K$)")
plt.ylabel("Average Runtime per Step (ms)")
plt.grid(True)
plt.legend(loc='upper left')
# Same layout pass as the ratio figure: with the enlarged fonts, the axis
# labels do not fit inside the default margins of a 6x4 in canvas.
plt.tight_layout()
plt.show()

# --------------------------------------
# FTPL vs FTRL Ratio
# --------------------------------------
# How many times slower one FTRL step is than one FTPL step, for each K.
time_ratios = np.array(ftrl_times) / np.array(ftpl_times)

plt.figure(figsize=(6,4), dpi=300)
plt.plot(Ks, time_ratios, marker='o', label='FTRL / FTPL', color='purple')
plt.xscale("log", base=2)
plt.xlabel("Number of Arms ($K$)")
plt.ylabel("Time Ratio")
plt.grid(True)
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()