import os
import pickle

import numpy as np

from algorithms import *

# ---------------------------------------------------------------------------
# Federated setup.
# NOTE(review): M, n and d are not referenced elsewhere in the visible script;
# they presumably record how simulated_data was generated — confirm against
# the data-generation code.
# ---------------------------------------------------------------------------
M: int = 100      # number of clients in total
n: int = 100      # number of samples held by each client
d: int = 10       # dimensionality of the model parameter
K: int = 5        # number of clients selected in every round
B_bar: int = 10   # mini-batch size used for the local gradient estimates

# Hyper-parameter of the Avare method; passed as the last argument to train_Avare.
alpha_Avare: float = 0.4

# SGD / experiment schedule.
n_iter: int = 1000     # SGD iterations per run
eta_SGD: float = 0.1   # SGD step size
n_repeat: int = 10     # independent repetitions of the whole experiment

# Heterogeneity levels; entry i is paired with simulated_data[i].
sigma_list: list = [1.0, 3.0, 10.0]

# I/O locations. They are used as plain string prefixes, so keep the trailing "/".
result_path: str = "result/"
data_path: str = "data/"

# Deserialize the pre-generated datasets (one entry per value in sigma_list).
# NOTE: pickle.load can execute arbitrary code — only point this at trusted files.
with open(f"{data_path}simulated_data.txt", "rb") as data_file:
    simulated_data = pickle.load(data_file)

# Create the output directory before any training. Otherwise a missing
# result_path folder would only surface as FileNotFoundError after every
# (expensive) training run had already finished, losing all the results.
os.makedirs(result_path, exist_ok=True)

for i_data, sigma in enumerate(sigma_list):
    print(f"sigma={sigma}")
    # Reset NumPy's global RNG for each sigma, so each heterogeneity level
    # starts from the same random state regardless of the levels run before it.
    np.random.seed(111)

    # simulated_data is indexed in the same order as sigma_list.
    data_list = simulated_data[i_data]
    loss_list_Avare = np.zeros((n_repeat, n_iter))
    regret_list_Avare = np.zeros((n_repeat, n_iter))
    for k_rep in range(n_repeat):
        # train_Avare presumably returns (per-iteration loss, per-iteration
        # regret), each assignable to a length-n_iter row — confirm in algorithms.
        loss_list_Avare[k_rep, :], regret_list_Avare[k_rep, :] = train_Avare(
            K, B_bar, data_list, n_iter, eta_SGD, alpha_Avare
        )
        print(f"Repeat: {k_rep + 1} finished!")

    # Statistics are taken across repeats (axis 0): one value per iteration.
    log_loss = np.log(loss_list_Avare)
    log_loss_list_Avare_mean = log_loss.mean(0)
    log_loss_list_Avare_std = log_loss.std(0)

    # Regret is accumulated along iterations (axis 1) before taking the log.
    # NOTE(review): a zero cumulative regret produces -inf here; confirm acceptable.
    log_cum_regret = np.log(regret_list_Avare.cumsum(1))
    log_regret_list_Avare_mean = log_cum_regret.mean(0)
    log_regret_list_Avare_std = log_cum_regret.std(0)

    # Pickle each summary array to "<result_path><name>_sigma=<sigma>.txt".
    # Filenames and write order match the original four separate dumps.
    summaries = {
        "log_loss_list_Avare_mean": log_loss_list_Avare_mean,
        "log_loss_list_Avare_std": log_loss_list_Avare_std,
        "log_regret_list_Avare_mean": log_regret_list_Avare_mean,
        "log_regret_list_Avare_std": log_regret_list_Avare_std,
    }
    for name, values in summaries.items():
        with open(result_path + name + "_sigma=" + str(sigma) + ".txt", "wb") as f:
            pickle.dump(values, f)
