"""
MARINA_P
"""
# MARINA-N project: in progress

from src.algorithm import *

class MARINA_P(Algorithm):
    def __init__(self, args=None):
        """Create the algorithm object.

        Args:
            args: forwarded unchanged to the base-class constructor
                (``None`` by default).
        """
        # No set-up of its own: everything is delegated to the base class.
        super().__init__(args)

    def script_directory(self):
        return os.path.dirname(os.path.abspath(__file__))

    def sameRandKCompressor(self, alg_param, vec):
        """Compress ``vec`` once and give every worker the same result.

        A single set of ``alg_param["k"]`` distinct coordinate indices is drawn
        from ``alg_param["rs_compressor"]`` with the probabilities in
        ``alg_param["randk_probs"]``. ``vec_sparsificator`` is applied to a copy
        of ``vec`` with those indices (presumably keeping only the selected
        coordinates -- confirm against its definition) and the outcome is
        scaled by ``dim / k``.

        Args:
            alg_param: settings dict; reads ``rs_compressor``, ``dim``, ``k``,
                ``randk_probs`` and ``num_workers``.
            vec: vector to compress; it is copied, not modified here.

        Returns:
            The scaled vector tiled ``num_workers`` times along a new leading
            axis.
        """
        dim = alg_param["dim"]
        k = alg_param["k"]
        chosen = alg_param["rs_compressor"].choice(
            a=np.arange(dim),
            size=k,
            replace=False,
            p=alg_param["randk_probs"],
        )
        scaled = (dim / k) * vec_sparsificator(vec.copy(), chosen)
        # Every worker receives an identical row.
        return np.tile(scaled, (alg_param["num_workers"], 1))
    
    def indRandKCompressor(self, alg_param, vec):
        """Compress ``vec`` separately for every worker.

        For each worker, in order, a fresh set of ``alg_param["k"]`` distinct
        coordinate indices is drawn from ``alg_param["rs_compressor"]`` with the
        probabilities in ``alg_param["randk_probs"]``; ``vec_sparsificator`` is
        applied to a copy of ``vec`` with those indices (presumably keeping only
        the selected coordinates -- confirm against its definition) and the
        outcome is scaled by ``dim / k``.

        Args:
            alg_param: settings dict; reads ``num_workers``, ``dim``, ``k``,
                ``rs_compressor`` and ``randk_probs``.
            vec: vector to compress; it is copied, not modified here.

        Returns:
            A ``float64`` array of shape ``(num_workers, dim)`` whose i-th row
            is the result of the i-th draw.
        """
        num_workers = alg_param["num_workers"]
        dim = alg_param["dim"]
        coordinates = np.arange(dim)

        def one_draw():
            # One draw: sample k indices, sparsify a copy of vec, rescale.
            picked = alg_param["rs_compressor"].choice(
                a=coordinates,
                size=alg_param["k"],
                replace=False,
                p=alg_param["randk_probs"],
            )
            return (dim / alg_param["k"]) * vec_sparsificator(vec.copy(), picked)

        per_worker = np.zeros((num_workers, dim), dtype=np.float64)
        for worker in range(num_workers):
            per_worker[worker] = one_draw()
        return per_worker
    
    def PermKCompressor(self, alg_param, vec):
        """Compress ``vec`` for every worker using one shared random permutation.

        A single permutation of ``range(alg_param["dim"])`` is drawn from
        ``alg_param["rs_compressor"]``. Row ``i`` of the result is
        ``permk_compressor(copy of vec, permutation, alg_param["k"], i)``;
        how that helper uses the worker index is defined outside this file
        (presumably each worker gets its own slice of the permutation --
        confirm against its definition).

        Args:
            alg_param: settings dict; reads ``rs_compressor``, ``dim``,
                ``num_workers`` and ``k``.
            vec: vector to compress; a fresh copy is passed for every worker.

        Returns:
            A ``float64`` array of shape ``(num_workers, dim)``.
        """
        shared_permutation = alg_param["rs_compressor"].permutation(alg_param["dim"])
        num_workers = alg_param["num_workers"]
        per_worker = np.zeros((num_workers, alg_param["dim"]), dtype=np.float64)
        for worker in range(num_workers):
            per_worker[worker] = permk_compressor(
                vec.copy(), shared_permutation, alg_param["k"], worker
            )
        return per_worker
    
    def fill_alg_params_dict(self, state, oracle, data, alg_param):
        """Complete ``alg_param`` with the derived compressor and step-size settings.

        Entries written to ``alg_param``:

        * ``k`` -- ``dim // num_workers``; the number of coordinates drawn by
          the Rand-K compressors and the size passed to ``permk_compressor``.
        * ``cost_per_float`` -- ``NUM_BITS_PER_FLOAT + 1 + log2(dim)``.
        * ``randk_probs`` -- uniform probabilities over the ``dim`` coordinates.
        * ``omega`` -- ``dim / k - 1``.
        * ``fixed_step_size`` -- only when ``step_size_init == "const"``:
          ``factor * R_0 / sqrt(T * B)`` with ``R_0 = ||x_0 - x_star||``,
          ``T = max_iters`` and
          ``B = wtL_0**2 + 2 * wtL_0 * L_0 * sqrt((1 - prob) * omega / prob)``.
        * ``step_size`` -- callable ``(sf_wit, wtB_star) -> step size``;
          ``"const"`` returns ``fixed_step_size``, ``"polyak"`` returns
          ``(sf_wit - f_star) * factor / wtB_star``.
        * ``Q_matrix`` -- callable ``(alg_param, vec) -> matrix`` bound to one
          of the compressor methods of this class.

        Args:
            state: must hold the starting point under ``"x"`` when the step
                size rule is ``"const"``.
            oracle: not used here.
            data: not used here.
            alg_param: settings dict, modified in place.

        Returns:
            ``d_copy(alg_param)``.

        Raises:
            ValueError: if ``dim // num_workers`` is smaller than 1, which
                would make ``omega`` and the ``dim / k`` scaling undefined.
            KeyError: if ``step_size_init`` or ``compressor`` is not one of
                the supported names, or a required entry is missing.
        """
        #completed
        dim = alg_param["dim"]
        num_workers = alg_param["num_workers"]

        # Integer floor division avoids the float round trip of int(dim / n).
        k = int(dim // num_workers)
        if k < 1:
            raise ValueError(
                f"dim ({dim}) must be at least num_workers ({num_workers}) "
                "so that k = dim // num_workers is positive"
            )
        alg_param["k"] = k
        # presumably: float payload + 1 extra bit + log2(dim) bits for the
        # coordinate index -- TODO confirm the accounting model
        alg_param["cost_per_float"] = NUM_BITS_PER_FLOAT+1 + np.log2(dim)
        alg_param["randk_probs"] = np.full(dim, 1 / dim)

        alg_param["omega"] = dim/k - 1

        rule_name = alg_param['step_size_init']
        if rule_name == "const":
            x_0 = state["x"]
            p = alg_param["prob"]
            T = alg_param["max_iters"]
            omega = alg_param["omega"]
            wtL_0 = alg_param["wtL_0"]
            L_0 = alg_param["L_0"]
            wtB_star_0 = wtL_0**2 + 2*wtL_0*L_0*np.sqrt((1-p)*omega/(p))
            R_0 = twonorm(x_0 - alg_param["x_star"])
            stepsize_baseline = (R_0)/(np.sqrt(T*wtB_star_0))
            alg_param["fixed_step_size"] = np.float64(stepsize_baseline)*alg_param['factor']

        # NOTE(review): both rules close over the dict passed in here, not over
        # the copy returned below. If d_copy is a deep copy, later edits to the
        # returned dict's "factor" / "f_star" / "fixed_step_size" are not seen
        # by these callables -- confirm that is intended.
        alg_param["step_size"] = {
            "const": lambda sf_wit, wtB_star: alg_param["fixed_step_size"],
            "polyak": lambda sf_wit, wtB_star: (sf_wit - alg_param["f_star"])*alg_param['factor']/wtB_star
            }[rule_name]

        # Kept as lambdas rather than bound methods: the stdlib deep copy of a
        # bound method also deep-copies its instance, whereas plain functions
        # are copied by reference (relevant if d_copy is copy.deepcopy).
        alg_param["Q_matrix"] = {
            "sameRandK": lambda alg_param, vec: self.sameRandKCompressor(alg_param, vec),
            "indRandK": lambda alg_param, vec: self.indRandKCompressor(alg_param, vec),
            "PermK": lambda alg_param, vec: self.PermKCompressor(alg_param, vec)
            }[alg_param["compressor"]]

        # Report the rule (and the constant value, when there is one) instead
        # of the repr of the callable, which carries no information.
        if rule_name == "const":
            step_size_report = f"const, value {alg_param['fixed_step_size']}"
        else:
            step_size_report = rule_name
        my_print(f"stepsize is set: {step_size_report}", alg_param['print_status'])

        return d_copy(alg_param)
    
    # Algorithm-dependent function
    def init_states_dict(self, state, oracle, data, alg_param):
        """Seed the worker iterates and gradient entries of ``state`` from ``state["x"]``.

        Entries written to ``state``:

        * ``W`` -- ``state["x"]`` tiled ``num_workers`` times (one row per worker).
        * ``w`` -- a copy of ``state["x"]``.
        * ``gs_W`` -- ``oracle["local_grads"](W, data["X"], data["y"])``.
        * ``gs_w`` -- a copy of ``gs_W``.
        * ``g_a_W`` / ``g_a_w`` -- means of ``gs_W`` / ``gs_w`` over axis 0.

        Args:
            state: dict holding the starting point under ``"x"``; modified in place.
            oracle: dict providing the ``"local_grads"`` callable.
            data: dict providing ``"X"`` and ``"y"``.
            alg_param: settings dict; reads ``num_workers``.

        Returns:
            ``d_copy(state)``.
        """
        start = state["x"]
        state["W"] = np.tile(start, (alg_param["num_workers"], 1)).copy()
        state["w"] = start.copy()

        state["gs_W"] = oracle["local_grads"](state["W"], data["X"], data["y"]).copy()
        # Initialised as a copy of gs_W rather than by a second oracle call.
        state["gs_w"] = state["gs_W"].copy()

        for mean_key, stacked_key in (("g_a_W", "gs_W"), ("g_a_w", "gs_w")):
            state[mean_key] = np.mean(state[stacked_key], axis=0).copy()
        return d_copy(state)
    
    # Algorithm-dependent function
    def init_collectable_metrics_dict(self, state, collectable_metric, alg_param, oracle, data):
        """Start the history of every requested metric with its value at the initial state.

        Only keys already present in ``collectable_metric`` are touched; each
        becomes a one-element list. Supported keys and their first value:

        * ``"iters"``, ``"bits"`` -- ``0``.
        * ``"func_diff"`` -- ``oracle["f"](state["x"], X, y) - alg_param["f_star"]``.
        * ``"r-L_0"`` -- ``twonorm(state["g_a_w"])``.
        * ``"r-wtL_0"`` -- square root of the mean squared row norm of ``state["gs_w"]``.
        * ``"r-L_hat"`` -- ``twonorm(state["g_a_W"])``.
        * ``"r-L_0,pm"`` -- square root of the mean squared row norm of
          ``state["gs_W"]`` minus ``state["g_a_w"]``.
        * ``"r-L_0+r-L_0,pm"`` -- ``sqrt(r-L_0 ** 2 + r-L_0,pm ** 2)``.

        The combined metric is computed from ``state`` directly, so it can be
        requested without also requesting its two components.

        Args:
            state: dict with ``x``, ``gs_w``, ``gs_W``, ``g_a_w``, ``g_a_W``.
            collectable_metric: dict whose keys select the metrics; modified in place.
            alg_param: settings dict; reads ``f_star`` and ``num_workers``.
            oracle: dict providing the ``"f"`` callable.
            data: dict providing ``"X"`` and ``"y"``.

        Returns:
            ``d_copy(collectable_metric)``.
        """
        # MARINA-N: completed
        def rms_row_norm(rows):
            # Square root of the mean squared Euclidean norm of the rows.
            return np.sqrt(np.mean(np.linalg.norm(rows, ord=2, axis=1)**2))

        def r_L_0():
            return twonorm(state["g_a_w"])

        def r_L_0_pm():
            return rms_row_norm(state["gs_W"] - np.tile(state["g_a_w"], (alg_param["num_workers"], 1)))

        # Evaluated lazily, in this order, and only for the requested keys.
        initial_value = {
            "iters": lambda: 0,
            "bits": lambda: 0,
            "func_diff": lambda: oracle["f"](state["x"], data["X"], data["y"]) - alg_param["f_star"],
            "r-L_0": r_L_0,
            "r-wtL_0": lambda: rms_row_norm(state["gs_w"]),
            "r-L_hat": lambda: twonorm(state["g_a_W"]),
            "r-L_0,pm": r_L_0_pm,
            "r-L_0+r-L_0,pm": lambda: np.sqrt((r_L_0())**2 + (r_L_0_pm())**2),
        }
        for name, compute in initial_value.items():
            if name in collectable_metric:
                collectable_metric[name] = [compute()]
        return d_copy(collectable_metric)
    
    # Algorithm-dependent function
    def update_collectable_metrics_dict(self, state, collectable_metric, alg_param, oracle, data, comm_cost_single_iter):
        """Append the current value of every requested metric to its history.

        Only keys already present in ``collectable_metric`` are touched.
        Supported keys and the value appended:

        * ``"iters"`` -- previous entry plus 1.
        * ``"bits"`` -- previous entry plus ``comm_cost_single_iter``.
        * ``"func_diff"`` -- ``oracle["f"](state["x"], X, y) - alg_param["f_star"]``.
        * ``"r-L_0"`` -- ``twonorm(state["g_a_w"])``.
        * ``"r-wtL_0"`` -- square root of the mean squared row norm of ``state["gs_w"]``.
        * ``"r-L_hat"`` -- ``twonorm(state["g_a_W"])``.
        * ``"r-L_0,pm"`` -- square root of the mean squared row norm of
          ``state["gs_W"]`` minus ``state["g_a_w"]``.
        * ``"r-L_0+r-L_0,pm"`` -- ``sqrt(r-L_0 ** 2 + r-L_0,pm ** 2)``.

        The combined metric is computed from ``state`` directly, so it can be
        requested without also requesting its two components.

        Args:
            state: dict with ``x``, ``gs_w``, ``gs_W``, ``g_a_w``, ``g_a_W``.
            collectable_metric: dict of metric histories (lists); modified in place.
            alg_param: settings dict; reads ``f_star`` and ``num_workers``.
            oracle: dict providing the ``"f"`` callable.
            data: dict providing ``"X"`` and ``"y"``.
            comm_cost_single_iter: amount added to the running ``"bits"`` total.

        Returns:
            ``d_copy(collectable_metric)``.
        """
        # MARINA-N: completed
        def rms_row_norm(rows):
            # Square root of the mean squared Euclidean norm of the rows.
            return np.sqrt(np.mean(np.linalg.norm(rows, ord=2, axis=1)**2))

        def r_L_0():
            return twonorm(state["g_a_w"])

        def r_L_0_pm():
            return rms_row_norm(state["gs_W"] - np.tile(state["g_a_w"], (alg_param["num_workers"], 1)))

        # Evaluated lazily, in this order, and only for the requested keys.
        next_value = {
            "iters": lambda: collectable_metric["iters"][-1]+1,
            "bits": lambda: collectable_metric["bits"][-1]+comm_cost_single_iter,
            "func_diff": lambda: oracle["f"](state["x"], data["X"], data["y"]) - alg_param["f_star"],
            "r-L_0": r_L_0,
            "r-wtL_0": lambda: rms_row_norm(state["gs_w"]),
            "r-L_hat": lambda: twonorm(state["g_a_W"]),
            "r-L_0,pm": r_L_0_pm,
            "r-L_0+r-L_0,pm": lambda: np.sqrt((r_L_0())**2 + (r_L_0_pm())**2),
        }
        for name, compute in next_value.items():
            if name in collectable_metric:
                collectable_metric[name].append(compute())
        return d_copy(collectable_metric)
        
    # Algorithm-dependent function
    def update(self, state, data, collectable_metric, alg_param, oracle, update_collectable_metrics_dict):
        """Run one iteration of the method and record its metrics.

        Steps, in order:

        1. Save ``x`` as ``x_prev``, obtain the step size from
           ``alg_param["step_size"]`` and set ``x = x_prev - step * g_a_W``.
        2. Draw a Bernoulli variable with success probability
           ``alg_param["prob"]`` from ``alg_param["rs_bernoulli"]``.
           On success every row of ``W`` is reset to the new ``x`` and the
           iteration is charged ``cost_per_float * dim``; otherwise
           ``alg_param["Q_matrix"](alg_param, x - x_prev)`` is added to ``W``
           and the iteration is charged ``cost_per_float * k``.
        3. Recompute ``w`` (row mean of ``W``), the gradients ``gs_W`` and
           ``gs_w`` from the oracle, and their means ``g_a_W`` and ``g_a_w``.
        4. Call ``update_collectable_metrics_dict`` with the new state.

        Args:
            state: dict with ``x``, ``W``, ``gs_w`` and ``g_a_W``; modified in place.
            data: dict providing ``"X"`` and ``"y"``.
            collectable_metric: dict of metric histories. The last entries of
                ``"r-wtL_0"`` and ``"r-L_hat"`` are used when those metrics are
                collected; otherwise the same quantities are computed from
                ``state``.
            alg_param: settings dict prepared by ``fill_alg_params_dict``.
            oracle: dict providing ``"local_losses"``, ``"local_grads"`` and
                ``"non_local_grads"``.
            update_collectable_metrics_dict: callable invoked as
                ``(state, collectable_metric, alg_param, oracle, data, comm_cost)``
                that returns the updated metrics dict.

        Returns:
            Tuple ``(d_copy(state), d_copy(collectable_metric), d_copy(alg_param))``.
        """
        # MARINA-N: in progress
        state["x_prev"] = state["x"].copy()

        # The step-size rule must not depend on which optional metrics are being
        # collected: use the recorded value when it exists (unchanged behaviour
        # for existing runs), otherwise compute the same quantity from state.
        if "r-wtL_0" in collectable_metric:
            r_wtL_0 = collectable_metric["r-wtL_0"][-1]
        else:
            r_wtL_0 = np.sqrt(np.mean(np.linalg.norm(state["gs_w"], ord=2, axis=1)**2))
        if "r-L_hat" in collectable_metric:
            r_L_hat = collectable_metric["r-L_hat"][-1]
        else:
            r_L_hat = twonorm(state["g_a_W"])

        wt_B_star = r_wtL_0**2 + 2*r_wtL_0*r_L_hat*np.sqrt((1-alg_param["prob"])*alg_param["omega"]/alg_param["prob"])
        sf_wit = np.mean(oracle["local_losses"](state["W"], data["X"], data["y"]))

        step_size = alg_param["step_size"](sf_wit, wt_B_star)
        state["x"] = state["x_prev"] - step_size*state["g_a_W"] # gradient step

        c_t = alg_param["rs_bernoulli"].binomial(1, alg_param["prob"], 1)[0] # sample a Bernoulli random variable
        if c_t == 1:
            # Every worker row is reset to the new iterate.
            state["W"] = np.tile(state["x"], (alg_param["num_workers"], 1)).copy()
            comm_cost_single_iter = alg_param["cost_per_float"]*alg_param["dim"]
        else:
            # Every worker row moves by its compressed copy of the step; the
            # addition builds a new array, so the old W is not modified in place.
            state["W"] = state["W"] + alg_param["Q_matrix"](alg_param, state["x"] - state["x_prev"])
            comm_cost_single_iter = alg_param["cost_per_float"]*alg_param["k"]

        state["w"] = np.mean(state["W"], axis=0).copy()

        state["gs_W"] = oracle["local_grads"](state["W"], data["X"], data["y"]).copy()
        state["gs_w"] = oracle["non_local_grads"](state["w"], data["X"], data["y"]).copy()

        state["g_a_W"] = np.mean(state["gs_W"], axis=0).copy()
        state["g_a_w"] = np.mean(state["gs_w"], axis=0).copy()

        collectable_metric = update_collectable_metrics_dict(state, d_copy(collectable_metric), alg_param, oracle, data, comm_cost_single_iter)
        return d_copy(state), d_copy(collectable_metric), d_copy(alg_param)
        
if __name__ == "__main__":
    # Script entry point: build the algorithm with its default arguments
    # (args=None) and start it.
    algorithm = MARINA_P()
    algorithm.run()
    
    
    