# Code for Fed-SVT
# We run the simulations on a 2.3 GHz Dual-Core Intel Core i5 processor.

import numpy as np
import random
import matplotlib.pyplot as plt

def laplace_mechanism(scale, seed):
    """Draw one zero-centred Laplace sample after reseeding NumPy's global RNG.

    Args:
        scale: Scale parameter handed to ``np.random.laplace``.
        seed: Value handed to ``np.random.seed`` right before the draw, so
            calling again with the same ``seed`` and ``scale`` reproduces the
            same sample.

    Returns:
        The single value returned by ``np.random.laplace`` with location 0.
    """
    # Reseeding mutates the module-level NumPy generator, not a private one.
    np.random.seed(seed)
    sample = np.random.laplace(0, scale)
    return sample

def above_threshold(query, epsilon, L, seed):
    """Noisy comparison of ``query`` against the threshold ``L``.

    Laplace noise of scale ``2 / epsilon`` is added to the threshold and
    Laplace noise of scale ``4 / epsilon`` is added to the query; the function
    reports whether the noisy query exceeds the noisy threshold.

    The global NumPy RNG is seeded once with ``seed`` and the two noise terms
    are then drawn consecutively from that stream. Reseeding before each draw
    would make the query noise exactly twice the threshold noise, i.e. the two
    perturbations would not be independent.

    Args:
        query: Value to test.
        epsilon: Privacy parameter; the noise scales are ``2 / epsilon`` and
            ``4 / epsilon``.
        L: Threshold before noise is added.
        seed: Value handed to ``np.random.seed`` before drawing the noise.

    Returns:
        ``True`` if the noisy query is above the noisy threshold (switch),
        ``False`` otherwise (do not switch).
    """
    np.random.seed(seed)
    # Two consecutive draws from one seeded stream: reproducible for a given
    # seed, but not copies of each other.
    threshold_noise = np.random.laplace(loc=0, scale=2 / epsilon)
    query_noise = np.random.laplace(loc=0, scale=4 / epsilon)

    noisy_threshold = L + threshold_noise
    return not (query + query_noise <= noisy_threshold)

def exponential_mechanism(losses, epsilon, seed):
    """Sample an index with probability proportional to ``exp(-epsilon * loss / 2)``.

    Lower losses receive higher probability. The global NumPy RNG is reseeded
    with ``seed`` before sampling.

    Args:
        losses: One-dimensional sequence of per-expert losses.
        epsilon: Privacy parameter controlling how sharply low losses are
            favoured.
        seed: Value handed to ``np.random.seed`` before sampling.

    Returns:
        The index drawn by ``np.random.choice`` from ``range(len(losses))``.

    Raises:
        ValueError: If ``losses`` is empty.
    """
    np.random.seed(seed)

    exponents = -epsilon * np.asarray(losses, dtype=float) / 2
    # Shift so the largest exponent is 0. The shift cancels in the
    # normalisation, so the distribution is unchanged, but np.exp can no
    # longer overflow to inf (which would turn the probabilities into NaN).
    exponents = exponents - exponents.max()

    scores = np.exp(exponents)
    probabilities = scores / scores.sum()
    return np.random.choice(len(probabilities), p=probabilities)

def Fed_SVT(T, m, d, N, loss_seq, eps, noise_seed):
    """Simulate Fed-SVT expert selection for ``m`` clients over ``T`` steps.

    All clients follow one shared expert. Every ``N`` steps the summed loss of
    the current expert since the last switch is tested with
    ``above_threshold``; if the test fires, a new expert is drawn with
    ``exponential_mechanism`` from the losses accumulated over all clients.
    At most ``K = ceil(10 * log(d))`` switches are made; afterwards the last
    expert is kept for the remaining steps.

    Args:
        T: Number of time steps.
        m: Number of clients.
        d: Number of experts.
        N: Communication interval; the switch test runs when ``t % N == 0``.
        loss_seq: Array indexed as ``loss_seq[client, step, expert]``.
        eps: Total privacy budget. Each mechanism call uses ``eps / (2 * K)``.
        noise_seed: Seed for the privacy noise. The same value reproduces the
            same run.

    Returns:
        Integer array of shape ``(m, T)`` with the expert used by each client
        at each step. Column 0 is always expert 0.
    """
    # Set K to be number of expert switching allowed in order to control DP budget.
    # Each switch spends eps/(2K) on the threshold test plus eps/(2K) on the
    # exponential mechanism, so exactly K switches exhaust eps.
    K = int(np.ceil(10 * np.log(d)))

    np.random.seed(noise_seed)
    # The mechanisms reseed the global RNG with whatever seed they are given.
    # Handing them the same seed every round would replay identical noise, so
    # a dedicated stream supplies a fresh, reproducible seed for each call.
    seed_stream = np.random.RandomState(noise_seed)

    # Initialize to expert 0 by default
    selected_experts = np.zeros((m, T), dtype=int)
    current_expert = 0

    # Cumulative loss of all experts from step 0
    total_cumulative_losses = np.zeros((m, d))
    # Cumulative loss of the current expert since the last switch; this is
    # the quantity tested when deciding whether to switch.
    last_switch_cumulative_losses = np.zeros(m)

    switches = 0
    for t in range(1, T):
        # Budget exhausted (also covers K == 0): keep the current expert for
        # every remaining step instead of spending more privacy budget.
        if switches >= K:
            selected_experts[:, t:] = current_expert
            break

        # Losses revealed at the previous step, for all clients at once.
        step_losses = loss_seq[:m, t - 1, :]
        total_cumulative_losses += step_losses
        last_switch_cumulative_losses += step_losses[:, current_expert]

        if t % N == 0:  # Communicate every N timestep
            per_call_eps = eps / (2 * K)
            threshold = (8 * np.log(2 * T**2 / N**2)) / eps + (8 * K) / eps
            query = np.sum(last_switch_cumulative_losses)

            test_seed = int(seed_stream.randint(0, 2**31 - 1))
            if above_threshold(query, per_call_eps, threshold, test_seed):
                # If switch, use exponential mechanism to decide a new expert
                losses_all_clients = np.sum(total_cumulative_losses, axis=0)
                pick_seed = int(seed_stream.randint(0, 2**31 - 1))
                current_expert = exponential_mechanism(
                    losses_all_clients, per_call_eps, pick_seed
                )
                switches += 1
                # Reset last switch cumulative losses for the new expert
                last_switch_cumulative_losses = np.zeros(m)

        # Every client uses the shared expert at this step.
        selected_experts[:, t] = current_expert

    return selected_experts


def run_simulation(m,T,d,N,eps,distribution_seed,noise_seed):
    """Generate random losses, run ``Fed_SVT`` on them and return the regret curve.

    Losses are uniform on [0, 1) for every (client, step, expert) and the last
    expert's losses are set to zero, so the best achievable loss at every step
    is zero and the regret equals the incurred loss.

    Args:
        m: Number of clients.
        T: Number of time steps.
        d: Number of experts.
        N: Communication interval forwarded to ``Fed_SVT``.
        eps: Privacy budget forwarded to ``Fed_SVT``.
        distribution_seed: Seed for generating the loss sequence.
        noise_seed: Seed forwarded to ``Fed_SVT`` for the privacy noise.

    Returns:
        Array of length ``T`` whose entry ``t`` is the cumulative regret up to
        and including step ``t``, divided by ``m``.
    """
    # Adversary chooses the loss functions. With sparsity_factor = 1 the mask
    # below keeps every entry, but it is still drawn from the RNG.
    sparsity_factor = 1
    np.random.seed(distribution_seed)
    raw_losses = np.random.rand(m, T, d)
    keep_mask = np.random.rand(m, T, d) < sparsity_factor
    loss_seq = raw_losses * keep_mask
    loss_seq[:, :, -1] = 0  # the d-th expert is the optimal one

    selected_experts = Fed_SVT(T, m, d, N, loss_seq, eps, noise_seed)
    print(selected_experts[0, :])

    # Loss summed over clients at each step for the experts they followed.
    # The optimal loss is zero, so this is also the per-step regret.
    client_ids = np.arange(m)
    regret_per_step = [
        np.sum(loss_seq[client_ids, step, selected_experts[:, step]])
        for step in range(T)
    ]

    # Running total over time, reported per client.
    return np.cumsum(regret_per_step) / m


if __name__ == '__main__':
    # Experiment configuration. The runs, legend labels and title below are
    # all derived from these values so they cannot drift apart.
    T = 2**9  # Number of iterations
    m = 10  # Number of clients in the federated runs
    d = 100  # Number of experts
    eps = 10  # Privacy budget
    distribution_seed = 0
    noise_seed = 0

    # Single-client baseline that runs the switch test at every step.
    reg_1 = run_simulation(1, T, d, 1, eps, distribution_seed, noise_seed)

    # Federated runs as (communication interval N, line colour), in plot order.
    fed_settings = [(50, "purple"), (30, "green"), (1, "orange")]
    fed_regrets = [
        run_simulation(m, T, d, N, eps, distribution_seed, noise_seed)
        for N, _ in fed_settings
    ]

    time_steps = np.arange(T)

    plt.figure(figsize=(8, 6))

    plt.plot(time_steps, reg_1, label="Sparse-Vector", color="blue", linewidth=2)
    for (N, colour), regret in zip(fed_settings, fed_regrets):
        plt.plot(time_steps, regret, label=f"Fed-SVT ($N = {N}$)",
                 color=colour, linewidth=2)

    plt.xlabel('T', fontsize=26)
    plt.ylabel('Per-client Regret', fontsize=26)
    plt.title(rf'$m = {m}$, $\varepsilon = {eps}$, $d = {d}$', fontsize=26)
    plt.legend(fontsize=24, loc='center right')
    plt.tick_params(axis='both', which='major', labelsize=26)
    plt.grid(True)
    plt.tight_layout()

    # plt.savefig('results/regret_svt.png')
    plt.show()


    