"""
EF21_P
"""
# MARINA-N project: completed

from src.algorithm import *

class EF21_P(Algorithm):
    """Gradient method that tracks a compressed copy ``w`` of the iterate ``x``.

    Every :meth:`update` takes a gradient step on ``state["x"]`` using the
    gradient stored in ``state["g"]`` (evaluated at ``state["w"]``), then moves
    ``w`` towards the new ``x`` by ``top_k_compressor(x - w, k)`` and
    re-evaluates the gradient at the new ``w``.
    """

    def __init__(self, args=None):
        super().__init__(args)

    def script_directory(self):
        """Return the absolute path of the directory containing this file."""
        return os.path.dirname(os.path.abspath(__file__))

    def fill_alg_params_dict(self, state, oracle, data, alg_param):
        """Derive the algorithm-specific entries of ``alg_param``.

        Reads ``dim``, ``num_workers``, ``step_size_init``, ``factor`` and
        ``print_status``; additionally ``max_iters``, ``x_star`` and ``L_0``
        for the ``"const"`` rule, and ``f_star`` for the ``"polyak"`` rule.

        Writes ``k``, ``cost_per_float``, ``comm_cost``, ``B_star``,
        ``step_size`` (a callable ``(f_wt, sqnorm_f_wt) -> step``) and, for
        the ``"const"`` rule only, ``fixed_step_size``.

        Returns:
            ``d_copy(alg_param)``; the dict passed in is also mutated.

        Raises:
            KeyError: if ``alg_param["step_size_init"]`` is neither
                ``"const"`` nor ``"polyak"``.
        """
        dim = alg_param["dim"]

        # Clamp to at least one coordinate: with more workers than dimensions
        # the plain quotient truncates to 0, which makes alpha == 0 (division
        # by zero in B_star below) and lets nothing through the compressor.
        alg_param["k"] = max(1, int(dim / alg_param["num_workers"]))

        # NOTE(review): presumably the bits of the value itself, one extra
        # bit, and log2(dim) bits to address the coordinate -- confirm.
        alg_param["cost_per_float"] = NUM_BITS_PER_FLOAT + 1 + np.log2(dim)
        alg_param["comm_cost"] = alg_param["cost_per_float"] * alg_param["k"]

        # alpha is the fraction k/dim of coordinates passed to the compressor.
        alpha = alg_param["k"] / dim
        alg_param["B_star"] = 1 + 2 * ((np.sqrt(1 - alpha) + 1 - alpha) / alpha)

        if alg_param["step_size_init"] == "const":
            # Baseline R_0 / (sqrt(T * B_star) * L_0), scaled by `factor`.
            initial_distance = twonorm(state["x"] - alg_param["x_star"])
            horizon = alg_param["max_iters"]
            stepsize_baseline = initial_distance / (
                np.sqrt(horizon * alg_param["B_star"]) * alg_param["L_0"]
            )
            alg_param["fixed_step_size"] = np.float64(stepsize_baseline) * alg_param["factor"]

        # NOTE(review): both rules read from the `alg_param` object passed in,
        # not from the copy returned below -- confirm that `d_copy` keeps this
        # binding and that this is intended.
        def const_step(f_wt, sqnorm_f_wt):
            """Return the precomputed constant step size."""
            return alg_param["fixed_step_size"]

        def polyak_step(f_wt, sqnorm_f_wt):
            """Return (f_wt - f_star) * factor / (B_star * sqnorm_f_wt)."""
            if sqnorm_f_wt == 0:
                # Zero gradient: the update direction vanishes, so any step is
                # equivalent; return 0 instead of dividing by zero (NaN/inf
                # would otherwise poison the iterate).
                return 0.0
            return (
                (f_wt - alg_param["f_star"]) * alg_param["factor"]
                / (alg_param["B_star"] * sqnorm_f_wt)
            )

        # Plain dict lookup on purpose: an unknown rule name raises KeyError.
        alg_param["step_size"] = {
            "const": const_step,
            "polyak": polyak_step,
        }[alg_param["step_size_init"]]

        my_print(f"stepsize is set: {alg_param['step_size']}", alg_param['print_status'])

        return d_copy(alg_param)

    # Algorithm dependent function
    def init_states_dict(self, state, oracle, data, alg_param):
        """Initialise ``w`` as a copy of ``x`` and ``g`` as the gradient at ``w``.

        Returns:
            ``d_copy(state)``; the dict passed in is also mutated.
        """
        state["w"] = state["x"].copy()
        state["g"] = oracle["grad"](state["w"], data["X"], data["y"]).copy()
        return d_copy(state)

    # Algorithm dependent function
    def init_collectable_metrics_dict(self, state, collectable_metric, alg_param, oracle, data):
        """Seed every requested metric with its value before the first update.

        Only keys already present in ``collectable_metric`` are filled:
        ``iters`` and ``bits`` start at 0, ``r-L_0`` at the norm of
        ``state["g"]``, and ``func_diff`` at ``f(x) - f_star``.

        Returns:
            ``d_copy(collectable_metric)``.
        """
        if "iters" in collectable_metric:
            collectable_metric["iters"] = [0]

        if "bits" in collectable_metric:
            collectable_metric["bits"] = [0]

        if "r-L_0" in collectable_metric:
            collectable_metric["r-L_0"] = [twonorm(state["g"])]

        if "func_diff" in collectable_metric:
            initial_gap = oracle["f"](state["x"], data["X"], data["y"]) - alg_param["f_star"]
            collectable_metric["func_diff"] = [initial_gap]

        return d_copy(collectable_metric)

    # Algorithm dependent function
    def update_collectable_metrics_dict(self, state, collectable_metric, alg_param, oracle, data, comm_cost_single_iter):
        """Append one entry to every requested metric.

        ``iters`` grows by 1 and ``bits`` by ``comm_cost_single_iter``
        (both cumulative); ``r-L_0`` records the norm of ``state["g"]`` and
        ``func_diff`` records ``f(x) - f_star``.

        Returns:
            ``d_copy(collectable_metric)``.
        """
        if "iters" in collectable_metric:
            collectable_metric["iters"].append(collectable_metric["iters"][-1] + 1)

        if "r-L_0" in collectable_metric:
            collectable_metric["r-L_0"].append(twonorm(state["g"]))

        if "bits" in collectable_metric:
            collectable_metric["bits"].append(collectable_metric["bits"][-1] + comm_cost_single_iter)

        if "func_diff" in collectable_metric:
            current_gap = oracle["f"](state["x"], data["X"], data["y"]) - alg_param["f_star"]
            collectable_metric["func_diff"].append(current_gap)

        return d_copy(collectable_metric)

    # Algorithm dependent function
    def update(self, state, data, collectable_metric, alg_param, oracle, update_collectable_metrics_dict):
        """Perform one iteration and record its metrics.

        Args:
            update_collectable_metrics_dict: callable with the signature of
                :meth:`update_collectable_metrics_dict`; the passed callable
                is used, not the method on ``self``.

        Returns:
            Tuple ``(state, collectable_metric, alg_param)``, each passed
            through ``d_copy``.
        """
        # The step-size rule receives the objective at w and the squared norm
        # of the stored gradient (which was also evaluated at w).
        objective_at_w = oracle["f"](state["w"], data["X"], data["y"])
        step_size = alg_param["step_size"](objective_at_w, sqnorm(state["g"]))

        # Out-of-place arithmetic: the previous arrays are not modified.
        state["x"] = state["x"] - step_size * state["g"]  # x = x - step_size*grad_0

        # Move w towards the new x by the compressed difference only.
        state["w"] = state["w"] + top_k_compressor(state["x"] - state["w"], alg_param["k"])

        comm_cost_single_iter = alg_param["comm_cost"]
        state["g"] = oracle["grad"](state["w"], data["X"], data["y"]).copy()

        collectable_metric = update_collectable_metrics_dict(
            state, d_copy(collectable_metric), alg_param, oracle, data, comm_cost_single_iter
        )
        return d_copy(state), d_copy(collectable_metric), d_copy(alg_param)
        
if __name__ == "__main__":
    # Script entry point: build the algorithm with its default arguments and
    # hand control to the `run` method inherited from the base class.
    algorithm = EF21_P()
    algorithm.run()
    
    
    