import os
import pickle

import numpy as np

from algorithms import *

# Experiment configuration.
# NOTE(review): M and n are not referenced in the visible part of this script;
# presumably they record how the data set was simulated -- confirm.
M = 100  # total number of clients
n = 100  # total number of samples per client
K = 5  # number of clients chosen at each round
B_bar = 10  # mini-batch size for computing local gradients

# SGD settings
n_iter = 1000  # number of iterations per training run
# Candidate learning rates for SGD; every method is trained with each of them
# and the best one is selected per method after training.
eta_list = [1.0, 0.5, 0.1, 0.05, 0.01]
n_repeat = 5  # number of times the experiment is repeated

# Heterogeneity levels. The i-th entry is paired with the i-th data set in the
# loaded pickle (simulated_data[i]).
nu_list = [1.0, 3.0, 10.0]
result_path = "result_fixed/"  # string prefix (directory) under which results are saved
data_path = "data/"  # string prefix (directory) from which the data is loaded

# Load the simulated data sets (indexed below by position in nu_list).
# The file is a pickle despite its ".txt" extension, hence the binary mode.
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# data files that come from a trusted source.
with open(data_path + "simulated_data_fixed.txt", "rb") as f:
    simulated_data = pickle.load(f)

# Create the output directory before any training starts, so that a missing
# folder does not surface as a FileNotFoundError only after the first (long)
# batch of training runs. result_path is used as a plain string prefix for the
# file names, so only its directory component is created.
_result_dir = os.path.dirname(result_path)
if _result_dir:
    os.makedirs(_result_dir, exist_ok=True)


def _log_mean_std(values):
    """Return (mean, std) over axis 0 of the element-wise log of *values*.

    The logarithm is computed once and reused for both statistics.
    """
    log_values = np.log(values)
    return log_values.mean(0), log_values.std(0)


def _save_result(name, nu, array):
    """Pickle *array* to ``result_path + name + "_nu=" + str(nu) + ".txt"``."""
    with open(result_path + name + "_nu=" + str(nu) + ".txt", "wb") as f:
        pickle.dump(array, f)


# Run the full experiment once per heterogeneity level.
for i_data, nu in enumerate(nu_list):
    print("nu={}".format(nu))
    # Re-seed for every nu so each heterogeneity level starts from the same
    # NumPy random state.
    np.random.seed(111)

    data_list = simulated_data[i_data]

    # Parameter for Adaptive OSMD-sampler
    alpha_OSMD = 0.4

    # One curve of length n_iter per (repeat, learning rate) pair.
    curve_shape = (n_repeat, len(eta_list), n_iter)
    loss_list_IS = np.zeros(curve_shape)
    loss_list_optimal = np.zeros(curve_shape)
    loss_list_Adaptive_OSMD = np.zeros(curve_shape)

    regret_list_IS = np.zeros(curve_shape)
    regret_list_Adaptive_OSMD = np.zeros(curve_shape)

    for k_rep in range(n_repeat):
        for j, eta_SGD in enumerate(eta_list):
            print("learning rate: {}".format(eta_SGD))
            # The call order is kept fixed: the trainers presumably draw from
            # the global NumPy RNG seeded above, so reordering would change
            # the results.
            loss_list_IS[k_rep, j, :], regret_list_IS[k_rep, j, :] = train_IS(
                K, B_bar, data_list, n_iter, eta_SGD
            )
            loss_list_optimal[k_rep, j, :] = train_optimal(
                K, B_bar, data_list, n_iter, eta_SGD
            )
            (
                loss_list_Adaptive_OSMD[k_rep, j, :],
                regret_list_Adaptive_OSMD[k_rep, j, :],
            ) = train_Ada_OSMD(K, B_bar, data_list, eta_SGD, alpha_OSMD, n_iter)
        print("Repeat: {} finished!".format(k_rep+1))

    # Mean and std (over repeats) of the log-loss; shape (len(eta_list), n_iter).
    log_loss_list_IS_mean, log_loss_list_IS_std = _log_mean_std(loss_list_IS)
    log_loss_list_optimal_mean, log_loss_list_optimal_std = _log_mean_std(
        loss_list_optimal
    )
    log_loss_list_Adaptive_OSMD_mean, log_loss_list_Adaptive_OSMD_std = _log_mean_std(
        loss_list_Adaptive_OSMD
    )

    # Mean and std (over repeats) of the log of the regret accumulated along
    # the iteration axis.
    log_regret_list_IS_mean, log_regret_list_IS_std = _log_mean_std(
        regret_list_IS.cumsum(2)
    )
    log_regret_list_Adaptive_OSMD_mean, log_regret_list_Adaptive_OSMD_std = (
        _log_mean_std(regret_list_Adaptive_OSMD.cumsum(2))
    )

    # Choose the best learning rate separately for each method: the one with
    # the smallest mean log-loss at the final iteration. The regret curves are
    # reported for that same learning rate.
    min_ind = np.argmin(log_loss_list_IS_mean[:, -1])
    log_loss_list_IS_mean = log_loss_list_IS_mean[min_ind, :]
    log_loss_list_IS_std = log_loss_list_IS_std[min_ind, :]
    log_regret_list_IS_mean = log_regret_list_IS_mean[min_ind, :]
    log_regret_list_IS_std = log_regret_list_IS_std[min_ind, :]

    min_ind = np.argmin(log_loss_list_Adaptive_OSMD_mean[:, -1])
    log_loss_list_Adaptive_OSMD_mean = log_loss_list_Adaptive_OSMD_mean[min_ind, :]
    log_loss_list_Adaptive_OSMD_std = log_loss_list_Adaptive_OSMD_std[min_ind, :]
    log_regret_list_Adaptive_OSMD_mean = log_regret_list_Adaptive_OSMD_mean[min_ind, :]
    log_regret_list_Adaptive_OSMD_std = log_regret_list_Adaptive_OSMD_std[min_ind, :]

    min_ind = np.argmin(log_loss_list_optimal_mean[:, -1])
    log_loss_list_optimal_mean = log_loss_list_optimal_mean[min_ind, :]
    log_loss_list_optimal_std = log_loss_list_optimal_std[min_ind, :]

    # Save results. The keys are the file-name stems; insertion order fixes
    # the order in which the files are written.
    results = {
        "log_loss_list_IS_mean": log_loss_list_IS_mean,
        "log_loss_list_IS_std": log_loss_list_IS_std,
        "log_loss_list_optimal_mean": log_loss_list_optimal_mean,
        "log_loss_list_optimal_std": log_loss_list_optimal_std,
        "log_loss_list_Adaptive_OSMD_mean": log_loss_list_Adaptive_OSMD_mean,
        "log_loss_list_Adaptive_OSMD_std": log_loss_list_Adaptive_OSMD_std,
        "log_regret_list_IS_mean": log_regret_list_IS_mean,
        "log_regret_list_IS_std": log_regret_list_IS_std,
        "log_regret_list_Adaptive_OSMD_mean": log_regret_list_Adaptive_OSMD_mean,
        "log_regret_list_Adaptive_OSMD_std": log_regret_list_Adaptive_OSMD_std,
    }
    for name, array in results.items():
        _save_result(name, nu, array)