import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import warnings
# Silence every warning category for the whole process.
# NOTE(review): this also hides NumPy RuntimeWarnings (e.g. division by zero
# or invalid value) that the adaptive-stepsize code below can trigger.
warnings.filterwarnings("ignore")
# Module-level flag; it is not read anywhere in the visible part of this file.
PLOT = True


def calculating_func_val(A_arr, b_arr, x):
    """Return the averaged least-squares objective evaluated at ``x``.

    The value is ``0.5 * mean_i ||A_i x - b_i||^2`` where ``A_i`` and ``b_i``
    are the slices of ``A_arr`` and ``b_arr`` along their first axis.

    Args:
        A_arr: Stack of matrices; the script builds it with shape
            ``(n_workers, n_dim, m_dim)``.
        b_arr: Stack of right-hand sides; the script builds it with shape
            ``(n_workers, n_dim)``.
        x: Point at which the objective is evaluated, length ``m_dim``.

    Returns:
        The scalar objective value.
    """
    residuals = A_arr @ x - b_arr
    # One squared Euclidean norm per slice (the norm is taken along axis 1).
    squared_norms = np.linalg.norm(residuals, axis=1) ** 2
    return 0.5 * np.average(squared_norms)

def GradientDescent(x_0, K=200000):
    """Run ``K`` gradient-descent steps with stepsize ``1 / L_scl``.

    Reads the module globals ``A_arr``, ``Atrans_arr``, ``b_arr`` and
    ``L_scl``. Each step averages the per-slice gradients
    ``A_i^T (A_i x - b_i)`` over the first axis and moves against that average.

    Args:
        x_0: Starting point (1-D array).
        K: Number of iterations to perform.

    Returns:
        The point reached after ``K`` steps (``x_0`` itself when ``K == 0``).
    """
    point = x_0
    for _ in range(K):
        # Residual A_i x - b_i for every slice, one row per slice.
        residual = np.squeeze(A_arr @ point[:, np.newaxis]) - b_arr
        # Per-slice gradient A_i^T (A_i x - b_i).
        per_slice_grad = np.squeeze(Atrans_arr @ residual[:, :, np.newaxis])
        mean_grad = np.average(per_slice_grad, axis=0)
        point = point - (1 / L_scl) * mean_grad
    return point


def FedExProx_const_full_batch(x_0, alpha=1, K=100):
    """Run the full-batch proximal iteration with a constant extrapolation.

    Every iteration evaluates, for all clients at once,
    ``z_i = (A_i^T A_i + I / gamma)^{-1} (x / gamma + A_i^T b_i)`` using the
    precomputed global ``AtransA_shift_inv_arr``, averages the ``z_i`` and
    updates ``x <- x + alpha * (mean(z_i) - x)``.

    Reads the module globals ``A_arr``, ``b_arr``, ``gamma``,
    ``AtransA_shift_inv_arr``, ``Atransb_arr``, ``x_star`` and
    ``f_val_x_star``.

    Args:
        x_0: Starting point (1-D array).
        alpha: Constant extrapolation parameter.
        K: Number of iterations.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``.
    """
    x = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    for _ in range(K):
        # Calculate Prox (closed form for each quadratic client objective)
        z = AtransA_shift_inv_arr @ ((1 / gamma) * x[:, np.newaxis] + Atransb_arr)
        # Drop the trailing singleton column explicitly instead of squeezing.
        z_avg = np.average(z, axis=0)[:, 0]
        x = x + alpha * (z_avg - x)
        iterates_list.append(x - x_star)
        func_val_list.append(calculating_func_val(A_arr, b_arr, x) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


def FedExProx_const_minibatch(x_0, tau=1, alpha=1, K=100):
    """Run the minibatch proximal iteration with a constant extrapolation.

    Each iteration samples ``tau`` distinct client indices out of
    ``n_workers`` without replacement, evaluates
    ``z_i = (A_i^T A_i + I / gamma)^{-1} (x / gamma + A_i^T b_i)`` for the
    sampled clients, and updates ``x <- x + alpha * (mean(z_i) - x)``.

    Reads the module globals ``A_arr``, ``b_arr``, ``gamma``, ``n_workers``,
    ``AtransA_shift_inv_arr``, ``Atransb_arr``, ``x_star`` and
    ``f_val_x_star``. Sampling uses NumPy's global random state.

    Args:
        x_0: Starting point (1-D array).
        tau: Number of clients sampled per iteration.
        alpha: Constant extrapolation parameter.
        K: Number of iterations.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``; the
        function value is always the objective over all clients.
    """
    x = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    for _ in range(K):
        # Subsampling
        idx_chosen = np.random.choice(np.arange(n_workers), tau, replace=False)
        # Calculate Prox for the sampled clients only
        z = AtransA_shift_inv_arr[idx_chosen, :, :] @ (
            (1 / gamma) * x[:, np.newaxis] + Atransb_arr[idx_chosen, :, :]
        )
        # Drop the trailing singleton column explicitly instead of squeezing.
        z_avg = np.average(z, axis=0)[:, 0]
        x = x + alpha * (z_avg - x)
        iterates_list.append(x - x_star)
        func_val_list.append(calculating_func_val(A_arr, b_arr, x) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


def FedExProx_GraDS_minibatch(gamma, x_0, tau, K=100, epi=5e-9):
    """Run the minibatch proximal iteration with the "GraDS" adaptive step.

    Each iteration samples ``tau`` distinct clients, evaluates their proximal
    points ``z_i`` and extrapolates with

        ``alpha_k = mean_i ||z_i - x||^2 / (||mean_i z_i - x||^2 + epi)``

    (``alpha_k = 1`` when ``tau == 1``, where the two norms coincide).

    Reads the module globals ``A_arr``, ``b_arr``, ``n_workers``,
    ``AtransA_shift_inv_arr``, ``Atransb_arr``, ``x_star`` and
    ``f_val_x_star``. Sampling uses NumPy's global random state.

    Args:
        gamma: Proximal stepsize. ``AtransA_shift_inv_arr`` is precomputed by
            the script from its own ``gamma``, so the value passed here is
            expected to be the same one.
        x_0: Starting point (1-D array).
        tau: Number of clients sampled per iteration.
        K: Number of iterations.
        epi: Small constant added to the denominator of the adaptive step.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``.
    """
    x = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    for _ in range(K):
        # Subsampling
        idx_chosen = np.random.choice(np.arange(n_workers), tau, replace=False)
        # Calculate Prox for the sampled clients only
        z = AtransA_shift_inv_arr[idx_chosen, :, :] @ (
            (1 / gamma) * x[:, np.newaxis] + Atransb_arr[idx_chosen, :, :]
        )
        z_avg = np.average(z, axis=0)[:, 0]
        # Unlike the constant variants, the stepsize is determined adaptively.
        if tau == 1:
            alpha_k_grads = 1
        else:
            # z[:, :, 0] has one row per sampled client; x broadcasts over rows.
            alpha_numer = np.average(np.linalg.norm(z[:, :, 0] - x, axis=1) ** 2)
            alpha_denom = np.linalg.norm(z_avg - x) ** 2 + epi
            alpha_k_grads = alpha_numer / alpha_denom

        x = x + alpha_k_grads * (z_avg - x)
        iterates_list.append(x - x_star)
        func_val_list.append(calculating_func_val(A_arr, b_arr, x) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


def FedExProx_StoPS_minibatch(gamma, x_0, tau, K=100, epi=5e-9):
    """Run the minibatch proximal iteration with the "StoPS" adaptive step.

    Each iteration samples ``tau`` distinct clients, evaluates their proximal
    points ``z_i`` and forms, per sampled client,

        ``M_i = 0.5 * ||A_i z_i - b_i||^2 + ||z_i - x||^2 / (2 * gamma)``.

    The extrapolation parameter is a ratio of ``mean_i M_i`` to
    ``||mean_i z_i - x||^2 / gamma`` (plus ``epi``).

    Reads the module globals ``A_arr``, ``b_arr``, ``n_workers``,
    ``AtransA_shift_inv_arr``, ``Atransb_arr``, ``x_star`` and
    ``f_val_x_star``. Sampling uses NumPy's global random state.

    Args:
        gamma: Proximal stepsize. ``AtransA_shift_inv_arr`` is precomputed by
            the script from its own ``gamma``, so the value passed here is
            expected to be the same one.
        x_0: Starting point (1-D array).
        tau: Number of clients sampled per iteration.
        K: Number of iterations.
        epi: Small constant added to the denominator of the adaptive step.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``.
    """
    x = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    for _ in range(K):
        # Subsampling
        idx_chosen = np.random.choice(np.arange(n_workers), tau, replace=False)
        A_arr_sub = A_arr[idx_chosen, :, :]
        b_arr_sub = b_arr[idx_chosen, :]
        # Calculate Prox for the sampled clients only
        z = AtransA_shift_inv_arr[idx_chosen, :, :] @ (
            (1 / gamma) * x[:, np.newaxis] + Atransb_arr[idx_chosen, :, :]
        )
        z_avg = np.average(z, axis=0)[:, 0]
        # Shared denominator term: ||mean(z_i) - x||^2 / gamma.
        denom = (1 / gamma) * np.linalg.norm(z_avg - x) ** 2
        if tau == 1:
            # Function value of the single sampled client at its prox point.
            # (Evaluating the objective over all clients here would not match
            # the per-client quantity used in the tau > 1 branch.)
            f_single = calculating_func_val(A_arr_sub, b_arr_sub, x=z_avg)
            numer = f_single + (1 / (2 * gamma)) * np.linalg.norm(z_avg - x) ** 2
            # NOTE(review): this branch has no factor 2 while the tau > 1
            # branch does; kept as in the original, confirm which is intended.
            alpha_k_stops = numer / (denom + epi)
        else:
            # First part: 0.5 * ||A_i z_i - b_i||^2 per sampled client. Using
            # the residual directly avoids the cancellation of the expanded
            # form z'A'Az - 2 b'Az + b'b when the residual is tiny.
            residual = (A_arr_sub @ z)[:, :, 0] - b_arr_sub
            ft = 0.5 * np.linalg.norm(residual, axis=1) ** 2
            # Second part: ||z_i - x||^2 / (2 * gamma) per sampled client.
            st = (1 / (2 * gamma)) * np.linalg.norm(z[:, :, 0] - x, axis=1) ** 2
            numer = np.average(ft + st)
            alpha_k_stops = 2 * numer / (denom + epi)

        x = x + alpha_k_stops * (z_avg - x)
        iterates_list.append(x - x_star)
        func_val_list.append(calculating_func_val(A_arr, b_arr, x) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


def FedExP_full_batch(x_0, local_iter=1, K=100):
    """Run full-batch FedExP: local gradient steps plus adaptive extrapolation.

    Every outer iteration lets each of the ``n_workers`` clients take
    ``local_iter`` gradient steps of size ``1 / (6 * local_iter * L_scl)`` on
    its own quadratic starting from the current point, then moves towards the
    average of the local results with the factor

        ``max(mean_i ||x_i - x||^2 / ||mean_i x_i - x||^2, 1)``.

    Reads the module globals ``A_arr``, ``b_arr``, ``n_workers``, ``L_scl``,
    ``AtransA_arr``, ``Atransb_arr``, ``x_star`` and ``f_val_x_star``.

    Args:
        x_0: Starting point (1-D array).
        local_iter: Number of local gradient steps per outer iteration.
        K: Number of outer iterations.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``.
    """
    x_outer = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    local_stepsize = 1 / (6 * local_iter * L_scl)
    for _ in range(K):
        # One copy of the current point per client, as column vectors.
        x_inner = np.tile(x_outer, [n_workers, 1])[:, :, np.newaxis]
        x_inner_ups = x_inner
        for _ in range(local_iter):
            # Local gradient A_i^T A_i x_i - A_i^T b_i for every client.
            grad_list = AtransA_arr @ x_inner_ups - Atransb_arr
            x_inner_ups = x_inner_ups - local_stepsize * grad_list

        avg_outer = np.average(x_inner_ups, axis=0)[:, 0]
        # Calculating the adaptive stepsize
        adpt_numer = np.average(
            np.linalg.norm((x_inner_ups - x_inner)[:, :, 0], axis=1) ** 2
        )
        adpt_denom = np.linalg.norm(avg_outer - x_outer) ** 2
        if adpt_denom > 0:
            adpt_exp = np.max([adpt_numer / adpt_denom, 1])
        else:
            # The averaged update is exactly zero, so the factor is irrelevant;
            # avoid inf * 0 = nan poisoning the iterate.
            adpt_exp = 1
        x_outer = x_outer + adpt_exp * (avg_outer - x_outer)
        iterates_list.append(x_outer - x_star)
        # Store the optimality gap, consistent with the first entry.
        func_val_list.append(calculating_func_val(A_arr, b_arr, x_outer) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


def FedExP_minibatch(x_0, tau=1, local_iter=1, K=100):
    """Run minibatch FedExP: sampled local steps plus adaptive extrapolation.

    Every outer iteration samples ``tau`` distinct clients out of
    ``n_workers``, lets each take ``local_iter`` gradient steps of size
    ``1 / (6 * local_iter * L_scl)`` from the current point, then moves towards
    the average of the local results with the factor

        ``max(mean_i ||x_i - x||^2 / ||mean_i x_i - x||^2, 1)``

    (the factor is fixed to 1 when ``tau == 1``).

    Reads the module globals ``A_arr``, ``b_arr``, ``n_workers``, ``L_scl``,
    ``AtransA_arr``, ``Atransb_arr``, ``x_star`` and ``f_val_x_star``.
    Sampling uses NumPy's global random state.

    Args:
        x_0: Starting point (1-D array).
        tau: Number of clients sampled per outer iteration.
        local_iter: Number of local gradient steps per outer iteration.
        K: Number of outer iterations.

    Returns:
        A pair ``(iterates, func_vals)`` of arrays with ``K + 1`` entries:
        ``x_k - x_star`` and ``f(x_k) - f_val_x_star`` for ``k = 0..K``.
    """
    x_outer = x_0
    # Record the initial point relative to x_star as well, so that every row
    # of the returned trajectory holds the same quantity.
    iterates_list = [x_0 - x_star]
    func_val_list = [calculating_func_val(A_arr, b_arr, x_0) - f_val_x_star]
    local_stepsize = 1 / (6 * local_iter * L_scl)
    for _ in range(K):
        # Randomly choosing the indices
        idx_chosen = np.random.choice(np.arange(n_workers), tau, replace=False)

        # One copy of the current point per sampled client, as column vectors.
        x_inner = np.tile(x_outer, [tau, 1])[:, :, np.newaxis]
        x_inner_ups = x_inner
        for _ in range(local_iter):
            # Local gradient A_i^T A_i x_i - A_i^T b_i for the sampled clients.
            grad_list = AtransA_arr[idx_chosen, :, :] @ x_inner_ups - Atransb_arr[idx_chosen, :, :]
            x_inner_ups = x_inner_ups - local_stepsize * grad_list

        avg_outer = np.average(x_inner_ups, axis=0)[:, 0]

        # Calculating the adaptive stepsize
        adpt_exp = 1
        if tau != 1:
            adpt_numer = np.average(
                np.linalg.norm((x_inner_ups - x_inner)[:, :, 0], axis=1) ** 2
            )
            adpt_denom = np.linalg.norm(avg_outer - x_outer) ** 2
            # When the averaged update is exactly zero the factor is
            # irrelevant; keep 1 to avoid inf * 0 = nan poisoning the iterate.
            if adpt_denom > 0:
                adpt_exp = np.max([adpt_numer / adpt_denom, 1])
        x_outer = x_outer + adpt_exp * (avg_outer - x_outer)
        iterates_list.append(x_outer - x_star)
        # Store the optimality gap, consistent with the first entry.
        func_val_list.append(calculating_func_val(A_arr, b_arr, x_outer) - f_val_x_star)
    return np.array(iterates_list), np.array(func_val_list)


# Command-line interface. Only the parser is built at import time; the
# arguments are parsed inside the ``__main__`` guard so that importing this
# module does not consume ``sys.argv``.
parser = argparse.ArgumentParser(description='Experiment')
parser.add_argument('--minibatch',  '-t', type=int, help='Minibatch used', default=30)
parser.add_argument('--gamma',      '-g', type=float, help='Value of local stepsize used', default=1)
parser.add_argument('--local_iter', '-l', type=int, help='Local iterations used for FedExP', default=1)
parser.add_argument('--iterations', '-K', type=int, help='Number of iterations to run', default=100)
parser.add_argument('--method',     '-m', type=int, help='Method used', default=1)
parser.add_argument('--gd_iter',    '-a', type=int, help='Initial iterations of GD to determine minimum', default=10)

if __name__ == "__main__":
    args = parser.parse_args()

    # Hyperparameters
    method = int(args.method)  # Experiment index (1..5, see the dispatch below)
    K = int(args.iterations)  # Number of iterations
    local_iter = int(args.local_iter)
    minibatch = int(args.minibatch)  # The minibatch size used
    K_gd = int(args.gd_iter)

    # gamma is a stepsize, so fractional values must be accepted. Integral
    # values are kept as ``int`` so that the generated file names
    # ("..._gamma_1_...") stay the same as before.
    gamma = float(args.gamma)
    if not gamma > 0:
        raise ValueError("gamma must be strictly positive")
    if gamma.is_integer():
        gamma = int(gamma)

    # Problem dimensions
    n_workers = 30
    n_dim = 20
    m_dim = 900
    tau = minibatch

    # Validate the minibatch size before doing any expensive work.
    if minibatch > n_workers:
        raise RuntimeError("The specified minibatch exceed the total number of clients")
    if minibatch < 1:
        raise ValueError("The specified minibatch must be at least 1")

    # Fixing the random seed for generating the dataset
    seed_dataset = 40
    np.random.seed(seed_dataset)

    # Randomly generate the dataset (the draw order fixes the data for a seed)
    A_arr = np.array([np.random.rand(n_dim, m_dim) for i in range(n_workers)])
    b_arr = np.array([np.random.rand(n_dim) for i in range(n_workers)])

    # Initial point
    x_0 = np.zeros(m_dim)

    # Calculating A_i^T and A_i^T @ A_i for every client
    Atrans_arr = np.transpose(A_arr, axes=[0, 2, 1])
    AtransA_arr = Atrans_arr @ A_arr

    # Calculate A_i^T b_i as column vectors
    Atransb_arr = Atrans_arr @ b_arr[:, :, np.newaxis]

    # The smoothness constant of every client (largest eigenvalue of A_i^T A_i)
    L_i_arr = np.max(np.linalg.eigh(AtransA_arr)[0], axis=1)

    # Smoothness constant of f: largest eigenvalue of the averaged Hessian
    # (eigh returns eigenvalues in ascending order).
    AtransA_avg_arr = np.average(AtransA_arr, axis=0)
    L_scl = np.linalg.eigh(AtransA_avg_arr)[0][-1]
    Lmax_scl = np.max(L_i_arr)

    # Reference solution: K_gd gradient-descent steps from x_0
    x_star = GradientDescent(x_0, K=K_gd)
    f_val_x_star = calculating_func_val(A_arr, b_arr, x_star)
    np.save("x_star.npy", x_star)

    # Calculating the smoothness of individual moreau envelope
    L_mo_i_arr = L_i_arr / (1 + gamma * L_i_arr)
    L_max_mo = np.max(L_mo_i_arr)

    # Calculating the smoothness of global moreau envelope
    AtransA_shift_arr = AtransA_arr + (1 / gamma) * np.eye(m_dim)
    AtransA_shift_inv_arr = np.linalg.inv(AtransA_shift_arr)
    AtransA_shift_inv_avg_arr = np.average(AtransA_shift_inv_arr, axis=0)
    mat_inside = (1 / gamma) * np.eye(m_dim) - ((1 / gamma) ** 2) * AtransA_shift_inv_avg_arr
    LM_scl = np.linalg.eigh(mat_inside)[0][-1]
    alpha_star = 1 + (1 / (gamma * LM_scl))

    # Result directory (created under the current working directory)
    cur_dir = os.getcwd()
    result_dir = "result"
    result_path = os.path.join(cur_dir, result_dir)
    os.makedirs(result_path, exist_ok=True)

    full_batch = minibatch == n_workers
    # ``stem`` is the method-specific middle part of the output file names;
    # it stays None when ``method`` selects nothing, in which case nothing
    # is run or saved.
    stem = None

    # Running FedExP
    if method == 1:
        if full_batch:
            iter_arr, fval_arr = FedExP_full_batch(x_0, local_iter, K=K)
            stem = "FedExP_Full_Batch_localiter_{}".format(local_iter)
        else:
            iter_arr, fval_arr = FedExP_minibatch(x_0, minibatch, local_iter, K=K)
            stem = "FedExP_Mini_Batch_localiter_{}_tau_{}".format(local_iter, minibatch)

    # Running FedExProx
    elif method == 2:
        if full_batch:
            iter_arr, fval_arr = FedExProx_const_full_batch(x_0, alpha=alpha_star, K=K)
            stem = "FedExProx_Full_Batch_gamma_{}_alpha_{}".format(gamma, alpha_star)
        else:
            n = n_workers
            L_tau = ((n - tau) / (tau * (n - 1))) * L_max_mo + ((n * (tau - 1)) / (tau * (n - 1))) * LM_scl
            alpha_star = 1 / (gamma * L_tau)
            iter_arr, fval_arr = FedExProx_const_minibatch(x_0, tau=tau, alpha=alpha_star, K=K)
            stem = "FedExProx_MiniBatch_gamma_{}_tau_{}_alpha_{}".format(gamma, tau, alpha_star)

    # Running FedProx (constant extrapolation alpha = 1)
    elif method == 3:
        if full_batch:
            iter_arr, fval_arr = FedExProx_const_full_batch(x_0, alpha=1, K=K)
            stem = "FedProx_Full_Batch_gamma_{}_alpha_{}".format(gamma, 1)
        else:
            iter_arr, fval_arr = FedExProx_const_minibatch(x_0, tau=tau, alpha=1, K=K)
            stem = "FedProx_MiniBatch_gamma_{}_tau_{}_alpha_{}".format(gamma, tau, 1)

    # Running FedExProx-GraDS
    elif method == 4:
        if full_batch:
            iter_arr, fval_arr = FedExProx_GraDS_minibatch(gamma, x_0, tau=n_workers, K=K, epi=1e-9)
            stem = "FedExProx_GraDS_Full_Batch_gamma_{}".format(gamma)
        else:
            iter_arr, fval_arr = FedExProx_GraDS_minibatch(gamma, x_0, tau=tau, K=K, epi=1e-8)
            stem = "FedExProx_MiniBatch_gamma_{}_tau_{}_GraDS".format(gamma, tau)

    # Running FedExProx-StoPS
    elif method == 5:
        if full_batch:
            iter_arr, fval_arr = FedExProx_StoPS_minibatch(gamma, x_0, tau=n_workers, K=K, epi=1e-10)
            stem = "FedExProx_StoPS_Full_Batch_gamma_{}".format(gamma)
        else:
            iter_arr, fval_arr = FedExProx_StoPS_minibatch(gamma, x_0, tau=tau, K=K, epi=1e-10)
            stem = "FedExProx_MiniBatch_gamma_{}_tau_{}_StoPS".format(gamma, tau)

    # Save the trajectory and the function-value gaps of the selected run
    if stem is not None:
        iter_name = "LinearRegression_{}_iteration.npy".format(stem)
        fval_name = "LinearRegression_{}_func_val.npy".format(stem)
        iter_save_name = os.path.join(result_path, iter_name)
        fval_save_name = os.path.join(result_path, fval_name)
        np.save(iter_save_name, iter_arr)
        np.save(fval_save_name, fval_arr)