# Code for Fed-DP-OPE-Stoch
# The simulations were run on a 2.3 GHz dual-core Intel Core i5 processor.
import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

def sample_true_distributions(n, d, seed,means,variances):
    """Sample ``n`` probability vectors over ``d`` entries via a softmax of Gaussian draws.

    Column ``i`` of the underlying logits is drawn from a normal distribution
    with mean ``means[i]`` and variance ``variances[i]``; each row is then
    exponentiated and normalised so that it sums to one.

    Args:
        n: Number of probability vectors (rows) to generate.
        d: Number of entries per vector (columns).
        seed: Seed passed to ``np.random.seed``. Note that this reseeds
            NumPy's global random state as a side effect.
        means: Indexable of at least ``d`` per-column means of the logits.
        variances: Indexable of at least ``d`` per-column variances of the
            logits (the standard deviation used is their square root).

    Returns:
        Array of shape ``(n, d)`` whose rows are probability distributions.
    """
    np.random.seed(seed)
    logits = np.zeros((n, d))
    # Draw column by column so the random stream is consumed in the same
    # order for a given seed regardless of how the rows are used later.
    for i in range(d):
        logits[:, i] = np.random.normal(loc=means[i], scale=np.sqrt(variances[i]), size=n)

    if d > 0:
        # Softmax is invariant to a per-row shift. Subtracting the row maximum
        # keeps np.exp from overflowing to inf (or every entry underflowing
        # to 0), either of which would turn the normalisation into NaN.
        logits -= np.max(logits, axis=1, keepdims=True)

    distributions = np.exp(logits)
    # Normalize to form a valid probability distribution
    distributions /= np.sum(distributions, axis=1, keepdims=True)

    return distributions

def cross_entropy_loss(predicted, true):
    """
    Return the cross-entropy of ``predicted`` measured against ``true``.

    Computes ``-sum(true * log(predicted + 1e-9))`` over all entries.
    """
    # The small offset keeps the logarithm finite when a predicted entry is 0.
    log_predicted = np.log(predicted + 1e-9)
    weighted_log = true * log_predicted
    return -np.sum(weighted_log)

def gradient_cross_entropy_loss(predicted, true):
    """
    Return the gradient of the cross-entropy loss with respect to ``predicted``.

    The result is ``-true / (predicted + 1e-9)``, evaluated element-wise.
    """
    # Same smoothing offset as in the loss, so the division never hits zero
    # for a predicted entry that is exactly 0.
    smoothed = predicted + 1e-9
    return -(true / smoothed)

def federated_dp_ope(T,m,d,loss_param,eps,seed):
    """Run the Fed-DP-OPE-Stoch policy updates and return every client's mixed policy.

    All clients start from the uniform distribution over the ``d`` experts.
    The policy is only recomputed at timesteps ``t`` for which ``t + 1`` is a
    power of two; at every other timestep the previous policy is carried
    forward. Each update uses the losses observed in ``[tau, t)`` (``tau`` is
    the previous update time) and performs two Frank-Wolfe-style steps:

    1. Every client averages the loss gradients at the current policy
       (``x_root``), adds Laplace noise, and the server moves towards the
       expert minimising the client-averaged noisy gradient (``x_left``).
    2. Every client corrects its root estimate with the gradient difference
       between ``x_left`` and ``x_root`` on a random subsample of the batch,
       adds Laplace noise, and the server again moves towards the minimising
       expert (``x_right``).

    Args:
        T: Number of timesteps.
        m: Number of clients.
        d: Number of experts.
        loss_param: Indexed as ``loss_param[i][t]``; the length-``d`` vector
            passed as ``true`` to ``gradient_cross_entropy_loss`` for client
            ``i`` at time ``t``.
        eps: Privacy parameter; the Laplace noise scale is inversely
            proportional to it.
        seed: Seed for NumPy's global random state (reseeded once per call,
            as a side effect).

    Returns:
        Array of shape ``(m, T, d)`` holding each client's mixed policy at
        every timestep.
    """
    selected_experts = np.zeros((m, T, d)) # Mixed policy
    if m == 0:
        # No clients: there is nothing to initialise or aggregate.
        return selected_experts

    # Seed exactly once. Reseeding inside the client loop would hand every
    # client (and every update) the same Laplace draws, so averaging on the
    # server could not cancel any of the noise.
    np.random.seed(seed)

    # Initialize for timestep 0: uniform distribution over the experts.
    selected_experts[:, 0, :] = 1.0 / d
    tau = 0

    for t in range(1, T):
        if (t & (t + 1)) != 0:
            # t + 1 is not a power of two: stick with the previous policy.
            selected_experts[:, t, :] = selected_experts[:, t - 1, :]
            continue

        # Update the policy using the loss functions from tau to (t-1).
        batch_size = t - tau
        k = 1
        # Every client holds the same policy, so client 0's copy is the root.
        x_root = selected_experts[0, t - 1, :]

        gradient_sum_root = np.zeros((m, d))
        gradient_sum_root_dp = np.zeros((m, d))
        for i in range(m):
            # Gradient estimate at x_root over the whole batch.
            gradient_sum = np.zeros(d)
            for idx in range(tau, t):
                gradient_sum += gradient_cross_entropy_loss(x_root, loss_param[i][idx])
            gradient_sum_root[i] = gradient_sum / batch_size
            noise = np.random.laplace(0, 1 / (batch_size * eps), d)
            gradient_sum_root_dp[i] = noise + gradient_sum_root[i]

        # Reached left leaf, server side
        aggregated_gradients = np.mean(gradient_sum_root_dp, axis=0)
        x_left_target = np.zeros(d)
        x_left_target[np.argmin(aggregated_gradients)] = 1
        x_left = (1 - 2 / (k + 1)) * x_root + (2 / (k + 1)) * x_left_target
        k += 1

        # Reached right leaf: gradient estimate from a random half of the
        # batch. At least one loss is always sampled; with an empty sample
        # (batch of size 1) the averages below would be 0/0 = NaN and argmin
        # would silently return expert 0.
        sample_size = max(1, batch_size // 2)
        gradient_sum_leaf = np.zeros((m, d))
        gradient_sum_leaf_dp = np.zeros((m, d))
        for i in range(m):
            gradient_sum_child = np.zeros(d)
            gradient_sum_parent = np.zeros(d)
            indices = np.random.choice(range(tau, t), size=sample_size, replace=False)
            for idx in indices:
                true = loss_param[i][idx]
                gradient_sum_parent += gradient_cross_entropy_loss(x_root, true)
                gradient_sum_child += gradient_cross_entropy_loss(x_left, true)
            # Root estimate corrected by the sampled change from x_root to x_left.
            gradient_sum_leaf[i] = (
                gradient_sum_root[i]
                - gradient_sum_parent / sample_size
                + gradient_sum_child / sample_size
            )
            noise = np.random.laplace(0, 2 / (batch_size * eps), d)
            gradient_sum_leaf_dp[i] = noise + gradient_sum_leaf[i]

        # Server side
        aggregated_gradients = np.mean(gradient_sum_leaf_dp, axis=0)
        x_right_target = np.zeros(d)
        x_right_target[np.argmin(aggregated_gradients)] = 1
        x_right = (1 / 3) * x_left + (2 / 3) * x_right_target

        # Broadcast the new policy to every client.
        selected_experts[:, t, :] = (1 - 2 / (k + 1)) * x_root + (2 / (k + 1)) * x_right
        k += 1

        tau = t

    return selected_experts

def run_simulation(m,T,d,eps,distribution_seed,noise_seed,means,variances):
    """Simulate one run and return the per-client regret after every timestep.

    Samples ``m * T`` loss distributions, runs ``federated_dp_ope`` on them,
    and at each timestep lets every client play the expert with the largest
    weight in its mixed policy. The regret appended for timestep ``t`` is the
    cumulative loss of those plays minus the cumulative loss of the best
    single expert in hindsight (both summed over all clients up to ``t``),
    divided by ``m``.

    Args:
        m: Number of clients.
        T: Number of timesteps.
        d: Number of experts.
        eps: Privacy parameter forwarded to ``federated_dp_ope``.
        distribution_seed: Seed used when sampling the loss distributions.
        noise_seed: Seed forwarded to ``federated_dp_ope``.
        means: Per-expert means forwarded to ``sample_true_distributions``.
        variances: Per-expert variances forwarded to
            ``sample_true_distributions``.

    Returns:
        List of length ``T`` with the per-client regret at each timestep.
    """
    # loss_param[i][t] gives the distribution for client i at time t
    loss_param = sample_true_distributions(
        m * T, d, distribution_seed, means, variances
    ).reshape(m, T, d)
    selected_experts = federated_dp_ope(T, m, d, loss_param, eps, noise_seed)

    # Row e is the one-hot policy that always plays expert e.
    one_hot_policies = np.eye(d)
    hindsight_losses = np.zeros(d)  # Cumulative loss of each fixed expert
    policy_loss_total = 0
    regret_seq = []

    for step in range(T):
        for client in range(m):
            target = loss_param[client, step]

            # The client plays the expert carrying the most weight.
            played = np.zeros(d)
            played[np.argmax(selected_experts[client, step])] = 1
            policy_loss_total += cross_entropy_loss(played, target)

            # Track what every fixed expert would have paid on this loss.
            for expert, expert_policy in enumerate(one_hot_policies):
                hindsight_losses[expert] += cross_entropy_loss(expert_policy, target)

        # Regret against the best expert in hindsight, averaged over clients.
        regret_seq.append((policy_loss_total - np.min(hindsight_losses)) / m)

    return regret_seq



if __name__ == '__main__':
    # Experiment configuration.
    T = 2 ** 14  # Number of iterations
    m = 1   # Number of clients (the runs below pass the client count literally)
    d = 100    # Number of experts
    eps = 10
    distribution_seed = 2
    noise_seed = 0

    # Generate the per-expert means and variances of the loss distributions.
    np.random.seed(distribution_seed)
    means = np.random.uniform(low=0.01, high=1, size=d)
    variances = np.random.uniform(low=0.01, high=0.5, size=d)

    # Each entry seeds both the distribution sampling and the noise of one
    # run; add more seeds here to average the curves over several runs.
    run_seeds = (2,)

    # m = 1
    regret_matrix_1 = np.array(
        [run_simulation(1, T, d, eps, s, s, means, variances) for s in run_seeds]
    )
    mean_regrets_1 = regret_matrix_1.mean(axis=0)
    std_dev_regrets_1 = regret_matrix_1.std(axis=0)

    # m = 10
    regret_matrix_10 = np.array(
        [run_simulation(10, T, d, eps, s, s, means, variances) for s in run_seeds]
    )
    mean_regrets_10 = regret_matrix_10.mean(axis=0)
    std_dev_regrets_10 = regret_matrix_10.std(axis=0)

    # Smooth both curves with a cubic spline evaluated on a coarser grid.
    T_original = np.arange(T)
    T_smooth = np.linspace(T_original.min(), T_original.max(), 300)  # Increase 300 for more smoothness
    mean_regrets_smooth_1 = make_interp_spline(T_original, mean_regrets_1, k=3)(T_smooth)
    mean_regrets_smooth_10 = make_interp_spline(T_original, mean_regrets_10, k=3)(T_smooth)

    plt.figure(figsize=(8, 6))
    curves = (
        (mean_regrets_smooth_1, "Limited Updates", "blue"),
        (mean_regrets_smooth_10, "Fed-DP-OPE-Stoch", "red"),
    )
    for curve, curve_label, curve_color in curves:
        plt.plot(T_smooth, curve, label=curve_label, color=curve_color, linewidth=2)

    plt.xlabel('T', fontsize=26)
    plt.ylabel('Per-client Regret', fontsize=26)
    plt.title(r'$m = 10$, $\varepsilon = 10$, $d = 100$', fontsize=26)
    plt.legend(fontsize=24,loc='center right')
    plt.tick_params(axis='both', which='major', labelsize=26)
    plt.grid(True)
    plt.tight_layout()

    # plt.savefig('results/regret_stoch.png')
    plt.show()



