from src.utils import *
from src.oracle_functions import *

# TODO: put the general asserts that are common to all experiments here

###################################################
######### Global project-dependent params #########
###################################################
# Names accepted for experiments and for algorithms. The two lists currently
# hold the same entries but are kept as independent list objects.
ALLOWABLE_EXPERIMENTS = ["GD", "MARINA-P", "EF21-P", "3PCv2-P"]
ALLOWABLE_ALGORITHMS = list(ALLOWABLE_EXPERIMENTS)
# ALLOWABLE_SAMPLINGS = ["FBS", "NICE", "imp", "NICE_original", "imp_original", ""]
ALLOWABLE_SAMPLINGS = ["full"]
ALLOWABLE_COMPRESSORS = [
    "sameRandK",
    "indRandK",
    "PermK",
    "TopK",
    "I",
    "sameRandK-TopK",
    "indRandK-TopK",
    "PermRandK-TopK",
]

ALLOWABLE_STOP_CRITERIA = [
    'epochs',
    'bits',
    'comms',
    'iters',
    'arg_res',
    'func_diff',
    'sqnorm',
    'total_cost',
]

# Every stop criterion is also a collectable metric; the "r-" entries are
# collectable only.
ALLOWABLE_COLLECTABLE_METRICS = ALLOWABLE_STOP_CRITERIA + [
    'r-L_0,pm',
    'r-L_0',
    'r-wtL_0',
    'r-L_hat',
    'r-L_0+r-L_0,pm',
]
# "<metric>_grad_comp" variants: when full grad is computed
ALLOWABLE_COLLECTABLE_METRICS.extend(
    [f"{metric}_grad_comp" for metric in ALLOWABLE_COLLECTABLE_METRICS]
)

ALLOWABLE_AVG_KEYS = ["func_diff", "arg_res", "sqnorm", "total_cost"]
# "<key>_grad_comp" variants: when full grad is computed
ALLOWABLE_AVG_KEYS.extend([f"{key}_grad_comp" for key in ALLOWABLE_AVG_KEYS])
# Collectable metrics that are not averaged keys (order is unspecified).
ALLOWABLE_NON_AVG_KEYS = list(set(ALLOWABLE_COLLECTABLE_METRICS) - set(ALLOWABLE_AVG_KEYS))

ALLOWABLE_DATASETS = {"X", "y", "X_mean", "y_mean"}

def _at_most(current, limit):
    """Return ``current <= limit`` (used for counter-like criteria)."""
    return current <= limit


def _at_least(current, tol):
    """Return ``current >= tol`` (used for tolerance-like criteria)."""
    return current >= tol


# Maps a stop-criterion name to a predicate taking (current value, threshold).
# Counters (epochs, bits, comms, iters) are compared with "<=" against their
# maximum; residual-type quantities are compared with ">=" against a tolerance.
# NOTE(review): 'total_cost' is listed in ALLOWABLE_STOP_CRITERIA but has no
# entry here -- confirm whether that is intended.
ALLOWABLE_STOP_CRITERIA_CONDITIONS = {
    'epochs': _at_most,
    'bits': _at_most,
    'comms': _at_most,
    'iters': _at_most,
    'arg_res': _at_least,
    'func_diff': _at_least,
    'sqnorm': _at_least,
}
# Parameter names accepted when validating the computable / loadable parameter
# lists of an experiment.
# The "r-" prefix marks a "real" quantity, i.e. one that relates to the
# experimentally observed value (e.g. 'r-L_0,pm' is the "real" L_0,pm).
# Fix: a comma was missing between 'r-L_0+r-L_0,pm' and 'r-wtL_0', so implicit
# string concatenation produced one bogus entry and dropped both real names.
ALLOWABLE_PARAMS = {
    'L_0',
    'L_0,i',
    'wtL_0',
    'L_pm',
    'L_0,pm',
    'r-L_0,pm',
    'r-L_0',
    'r-L_0+r-L_0,pm',
    'r-wtL_0',
    'r-L_hat',
    'f_star',
    'x_star',
}
# Dataset names recognised as LIBSVM classification datasets.
LIBSVM_CLASSIFICATION_DATASETS = {
    "a9a",
    "w8a",
    "mushrooms",
    "ijcnn1",
    "covtype",
    "phishing",
    "rcv1",
    "real-sim",
    "news20.binary",
    "cod-rna",
    "dna",
    "svmguide3",
    "svmguide1",
    "svmguide2",
    "splice",
    "madelon",
    "gisette",
    "dexter",
    "dorothea",
    "colon-cancer",
    "leukemia",
    "lung-cancer",
    "rcv1.binary",
    "sector",
    "usps",
    "mnist",
}

SUPPORTED_LOSS_FUNCS = {"log-reg", "quadratic", "l1_norm"}

# specific dataset developed for auxpage project
QUADRATIC_DATASETS = {"synthetic_dense", "synthetic_sparse", "synthetic_sparse_zero"}

# Same dataset names as for the quadratic loss, held in an independent set.
L1_NORM_DATASETS = set(QUADRATIC_DATASETS)

ALLOWABLE_PLOT_FAMILIES = ['ALL', "SINGLE_RELEASE"]

# Ray parameter
NUM_CORES = 48

##########################
# Auxiliary biased stuff #
##########################
def cost_function_biased(c, k, p):
    """Cost function in the biased case.

    Evaluates ``(p + (1 - p)*c) * (1 + ((1 + sqrt(1 - p)) / p - 1) * k)``.
    At ``p = 1`` the value is 1 regardless of ``c`` and ``k``.
    """
    mixing = p + (1 - p)*c
    slowdown = 1 + ((1 + np.sqrt(1 - p)) / p - 1) * k
    return mixing*slowdown
def get_stepsize_biased(L_0, delta, p):
    """Return ``1 / (L_0 + delta * sqrt((1 - p) / (1 - sqrt(1 - p))**2))``.

    For ``p = 1`` the correction term vanishes and the result is ``1 / L_0``.
    """
    gap = 1-p
    shrink = (1-np.sqrt(gap))**2
    return 1/(L_0 + delta*np.sqrt(gap/shrink))
def get_optimal_params_biased(c, kappa):
    """Pick the value of ``p`` with the smallest biased cost among symbolic roots.

    Solves ``expr == 0`` for ``p`` with sympy, keeps the real parts of the
    roots, and compares ``cost_function_biased`` at the candidates. The root at
    index 0 is never considered (presumably a known spurious root -- TODO
    confirm).

    * three roots: the cheaper of roots 1 and 2 is chosen; if even that costs
      more than 1, ``p = 1`` is returned instead, together with the cost
      evaluated at ``p = 1``;
    * two roots: root 1 is returned with its cost;
    * any other number of roots is an error.

    Returns:
        Tuple ``(p_opt, cost_opt)``; ``cost_opt`` is always the cost at
        ``p_opt``.

    Raises:
        ValueError: if the solver returns neither two nor three roots.
    """
    p = smp.symbols('p')
    # NOTE(review): presumably the stationarity condition of the biased cost
    # in p -- confirm against the derivation.
    expr = 0.5*(c*(-kappa*(-2*p**2 + smp.sqrt(1-p)*p + 2*smp.sqrt(1-p) + 2)/p**2 - 2) + kappa*(-1/smp.sqrt(1-p) - 2) + 2)
    sol = np.array(list(map(complex, smp.solve(expr, p, domain=smp.Interval.Lopen(0, 1)))))
    # Only the real parts are used; imaginary parts are discarded unchecked.
    sol_real = np.real(sol)
    if sol_real.shape[0]==3:
        cost_1 = cost_function_biased(c, kappa, sol_real[1])
        cost_2 = cost_function_biased(c, kappa, sol_real[2])
        if cost_1 <= cost_2:
            p_opt = sol_real[1]
            cost_opt = cost_1
        else:
            p_opt = sol_real[2]
            cost_opt = cost_2
        if cost_opt>1:
            # Falling back to p = 1: report the cost of p = 1 as well, so the
            # returned pair stays consistent (previously the rejected cost
            # of the discarded root was returned).
            p_opt = 1
            cost_opt = cost_function_biased(c, kappa, p_opt)
    elif sol_real.shape[0]==2:
        cost_opt = cost_function_biased(c, kappa, sol_real[1])
        p_opt = sol_real[1]
    else:
        raise ValueError(
            f"expected 2 or 3 roots, got {sol_real.shape[0]} "
            f"for c={c}, kappa={kappa}: {sol_real}"
        )
    return p_opt, cost_opt
##########################

############################
# Auxiliary unbiased stuff #
############################
def cost_prime_unbiased (c,k,p):
    """Return ``sign`` of the expression ``term1 / term2 + term3`` at ``(c, k, p)``.

    The expression has the same form as the one solved symbolically in
    ``get_optimal_params_unbiased`` (with ``k`` in place of ``kappa``).
    """
    odds = (1 - p) / p
    numerator = (-((1 - p) / p**2) - p**(-1)) * (c * (1 - p) + p) * k
    denominator = 2 * np.sqrt(odds)
    remainder = (1 - c) * (1 + np.sqrt(odds) * k)
    return sign(numerator / denominator + remainder)

def p_hat_unbiased(c):
    """Return ``3c / (3c + 1)``."""
    return (3*c)/(3*c + 1)

def cost_prime_grid_unbiased(c_vals,k_vals):
    """Evaluate ``cost_prime_unbiased`` on a ``k_vals`` x ``c_vals`` grid.

    For each ``c`` the probability used is ``p_hat_unbiased(c)``.

    Returns:
        Float array of shape ``(k_vals.shape[0], c_vals.shape[0])`` whose
        entry ``[i, j]`` is ``cost_prime_unbiased(c_vals[j], k_vals[i], p_hat_j)``.
    """
    p_hats = np.array([p_hat_unbiased(c) for c in c_vals], dtype=np.float64)
    grid = np.zeros((k_vals.shape[0], c_vals.shape[0]))
    for row, kappa in enumerate(k_vals):
        for col, (c, p_hat) in enumerate(zip(c_vals, p_hats)):
            grid[row, col] = cost_prime_unbiased(c, kappa, p_hat)
    return grid

def cost_function_unbiased(c, k, p):
    """Cost function in the unbiased case.

    Evaluates ``(p + (1 - p)*c) * (1 + k*sqrt((1 - p)/p))``. At ``p = 1`` the
    value is 1 regardless of ``c`` and ``k``.
    """
    mixing = p + (1 - p) * c
    slowdown = 1 + k * np.sqrt((1 - p) / p)
    return mixing * slowdown
def get_stepsize_unbiased(L_0, delta, p):
    """Return ``1 / (L_0 + delta * sqrt((1 - p) / p))``.

    For ``p = 1`` the correction term vanishes and the result is ``1 / L_0``.
    """
    correction = delta*np.sqrt((1-p)/(p))
    return 1/(L_0 + correction)

def get_optimal_params_unbiased(c, kappa):
    """Pick the value of ``p`` with the smallest unbiased cost.

    Solves ``expr == 0`` symbolically for ``p``, keeps the real parts of the
    roots, adds ``p = 1.0`` as an extra candidate, and returns the candidate
    with the smallest ``cost_function_unbiased`` together with that cost.

    Returns:
        Tuple ``(p_opt, cost_function_unbiased(c, kappa, p_opt))``.
    """
    p = smp.symbols('p')
    # Same form as the quantity whose sign cost_prime_unbiased returns.
    fraction = (-((1 - p) / p**2) - p**(-1)) * (c * (1 - p) + p) * kappa / (2 * smp.sqrt((1 - p) / p))
    remainder = (1 - c) * (1 + smp.sqrt((1 - p) / p) * kappa)
    expr = fraction + remainder

    roots = smp.solve(expr, p, domain=smp.Interval.Lopen(0, 1))
    # Only the real parts are used; imaginary parts are discarded unchecked.
    real_roots = np.real(np.array([complex(root) for root in roots]))

    # p = 1.0 is always a candidate.
    candidates = np.append(real_roots, 1.0)
    candidate_costs = np.array(
        [cost_function_unbiased(c, kappa, candidate) for candidate in candidates],
        dtype=np.float64,
    )
    best_p = candidates[np.argmin(candidate_costs)]
    return best_p, cost_function_unbiased(c, kappa, best_p)
####################################################################################

####################################################################################
class Experiment():
    def __init__(self):
        """Create an experiment object without setting any attribute."""
    
    def init_regularizers(self):
        """Select the regularizer callables for ``arg_values["regularizer_type"]``.

        Sets ``self.regularizer``, ``self.regularizer_grad`` and
        ``self.regularizer_hess_bound``. An unknown regularizer type raises
        ``KeyError``.
        """
        # (value, gradient, Hessian bound) for every supported type.
        callables_by_type = {
            "str-cvx": (regularizer_scvx,
                        regularizer_scvx_grad,
                        regularizer_scvx_hess_bound),
            "non-cvx": (regularizer_noncvx,
                        regularizer_noncvx_grad,
                        regularizer_noncvx_hess_bound),
        }
        (self.regularizer,
         self.regularizer_grad,
         self.regularizer_hess_bound) = callables_by_type[self.arg_values["regularizer_type"]]
    
    def init_oracles(self):
        """Select regularizer and loss oracles and build ``self.oracle_dict``.

        The regularizer callables are chosen by
        ``arg_values['regularizer_type']`` and the loss oracles by
        ``arg_values['loss_func']``; an unknown value of either raises
        ``KeyError``. Oracles that a loss function does not provide are set to
        ``None``.

        ``self.oracle_dict`` maps an oracle name to a wrapper that forwards its
        arguments to the selected oracle together with
        ``self.alg_params_dict["la"]`` and the matching regularizer callable.
        The wrappers read those attributes at call time, not when this method
        runs.
        """
        my_print("Defining oracles...", self.arg_values["print_status"])

        # (value, gradient, Hessian, Hessian bound) for every regularizer type.
        regularizers_by_type = {
            "str-cvx": (regularizer_scvx,
                        regularizer_scvx_grad,
                        regularizer_scvx_hess,
                        regularizer_scvx_hess_bound),
            "non-cvx": (regularizer_noncvx,
                        regularizer_noncvx_grad,
                        regularizer_noncvx_hess,
                        regularizer_noncvx_hess_bound),
        }
        (self.regularizer,
         self.regularizer_grad,
         self.regularizer_hess,
         self.regularizer_hess_bound) = regularizers_by_type[self.arg_values['regularizer_type']]

        # Per loss function, in this order:
        #   loss, grad, hess, hess_bound,
        #   local_losses    - per-worker losses computed at different points,
        #                     each corresponding to the same worker
        #   local_grads     - per-worker grads computed at different points,
        #                     each corresponding to the same worker
        #   non_local_grads - per-worker grads computed at the same point
        # NOTE(review): for "quadratic" the Hessian bound reuses quad_hess_ij,
        # while the "hess_bound" wrapper below passes one argument fewer than
        # the "hess" wrapper -- confirm the signature matches.
        oracles_by_loss = {
            "log-reg": (logreg_loss_distributed,
                        logreg_grad_distributed,
                        logreg_hess_distributed,
                        logreg_hess_bound_distributed,
                        None,
                        None,
                        None),
            "quadratic": (quad_loss_ij,
                          quad_grad_ij,
                          quad_hess_ij,
                          quad_hess_ij,
                          None,
                          quad_local_grads,
                          None),
            "l1_norm": (l1_norm_loss_i_distributed,
                        l1_norm_grad_i_distributed,
                        None,
                        None,
                        l1_norm_local_losses_i_distributed,
                        l1_norm_local_grads_i_distributed,
                        l1_norm_non_local_grads_i_distributed),
        }
        (self.oracle_loss,
         self.oracle_grad,
         self.oracle_hess,
         self.oracle_hess_bound,
         self.local_losses,
         self.local_grads,
         self.non_local_grads) = oracles_by_loss[self.arg_values['loss_func']]

        def la():
            # Looked up on every oracle call, so later changes are picked up.
            return self.alg_params_dict["la"]

        self.oracle_dict = {
            "f": lambda w, X, y: self.oracle_loss(w, X, y, la(), self.regularizer),
            "grad": lambda w, X, y: self.oracle_grad(w, X, y, la(), self.regularizer_grad),
            "hess": lambda w, X, y: self.oracle_hess(w, X, y, la(), self.regularizer_hess),
            "hess_bound": lambda w, X: self.oracle_hess_bound(w, X, la(), self.regularizer_hess_bound),
            "local_losses": lambda W, X, Y: self.local_losses(W, X, Y, la(), self.regularizer),
            "local_grads": lambda W, X, Y: self.local_grads(W, X, Y, la(), self.regularizer_grad),
            "non_local_grads": lambda w, X, Y: self.non_local_grads(w, X, Y, la(), self.regularizer_grad),
        }
    
    def load_prepared_datasets(self):
        """Fill every entry of ``self.data_dict`` via ``load_param``.

        The file path of a part is
        ``dataset_path + <part name> + exp_data_extension``.
        """
        # This function was updated for MARINA-N project
        # If there are bugs, see implementation in the source code for the hfh project
        for part_name in self.data_dict:
            part_path = self.dataset_path + part_name + self.exp_data_extension
            loaded_part = load_param(part_path, part_name, self.arg_values["print_status"])
            self.data_dict[part_name] = loaded_part
    
    
    def get_part_dataset(self, data_part_name, inds):
        """Return the result of ``load_selected_sparse_matrices`` for one data part.

        The file path is ``dataset_path + data_part_name + exp_data_extension``;
        ``inds`` is forwarded to the loader unchanged.
        """
        # This function was updated for MARINA-N project
        # If there are bugs, see implementation in the source code for the hfh project
        part_path = self.dataset_path + data_part_name + self.exp_data_extension
        verbose = self.arg_values["print_status"]
        return load_selected_sparse_matrices(part_path, data_part_name, inds, verbose)
        
    def load_w_init(self):
        self.x_0 = np.array(np.load(self.data_path + 'w_init' + self.w_init_extension + '.npy'), dtype=np.float64)

    def load_parameters(self):
        """Fill every entry of ``self.alg_params_dict`` via ``load_param``.

        Parameters are read from the directory
        ``<data_path>comp_params<exp_params_extension>/``, which is also stored
        in ``self.comp_params_path`` (only when there is at least one
        parameter to load).
        """
        for name in self.alg_params_dict:
            params_dir = self.data_path + 'comp_params' + self.exp_params_extension + "/"
            self.comp_params_path = params_dir
            file_path = params_dir + name + self.exp_params_extension
            self.alg_params_dict[name] = load_param(file_path, name, self.arg_values["print_status"])
    
    def init_comp_params_dict(self):
        """Parse ``arg_values["comp_params_str"]`` into parameter containers.

        The string must be a Python literal for a list of strings. On success
        sets ``self.comp_params_dict`` (each name mapped to ``None``) and
        ``self.comp_params_set``.

        Raises:
            ValueError: if the string is not a valid literal, or does not
                evaluate to a list of strings. (Previously a message was
                printed and execution continued into an unbound variable or a
                half-initialised object.)
            AssertionError: if a name is not in ``ALLOWABLE_PARAMS``.
        """
        raw = self.arg_values["comp_params_str"]
        try:
            comp_params_list = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError) as err:
            # literal_eval reports malformed input with any of these types.
            raise ValueError(
                f"The string is not a valid list representation: {raw!r}"
            ) from err

        is_list_of_str = isinstance(comp_params_list, list) and all(
            isinstance(item, str) for item in comp_params_list
        )
        if not is_list_of_str:
            raise ValueError(
                f"The list does not contain only string elements: {raw!r}"
            )

        self.comp_params_dict = {key: None for key in comp_params_list}
        self.comp_params_set = set(comp_params_list)
        assert self.comp_params_set.issubset(ALLOWABLE_PARAMS), (
            f"unknown parameters: {sorted(self.comp_params_set.difference(ALLOWABLE_PARAMS))}"
        )
    
    # Project-dependent functions
    def init_exp_param_extension(self):
        """Build ``self.exp_params_extension`` and set ``arg_values["la"]``.

        Only the "l1_norm" and "log-reg" loss functions are handled, and
        ``arg_values["la_init"]`` must satisfy ``is_float``; it is then stored
        as a float in ``arg_values["la"]``.

        Raises:
            ValueError: for any other loss function, or when ``la_init`` does
                not pass ``is_float``.
        """
        loss_func = self.arg_values["loss_func"]
        if loss_func != "l1_norm" and loss_func != "log-reg":
            raise ValueError("other options are not supported")
        if not is_float(self.arg_values["la_init"]):
            raise ValueError("other options are not supported")

        self.arg_values["la"] = float(self.arg_values["la_init"])

        if loss_func == "l1_norm":
            self.exp_params_extension = '_{6}_{0}_d{1}_nw{2}_ns{3}_nsc{7}_b{4}_la{5}'.format(
                self.arg_values["dataset"],
                self.arg_values["dim"],
                self.arg_values["num_workers"],
                self.arg_values["num_samples"],
                myrepr(self.arg_values["batchsize"]),
                myrepr(self.arg_values["la"]),
                self.arg_values["loss_func"],
                myrepr(self.arg_values["noise_scale"]),
            )
        else:
            self.exp_params_extension = '_{0}_{1}_nw{2}_la{3}'.format(
                self.arg_values["loss_func"],
                self.arg_values["dataset"],
                self.arg_values["num_workers"],
                myrepr(self.arg_values["la"]),
            )
    
    def init_exp_data_extension(self):
        """Build ``self.exp_data_extension`` for the configured loss function.

        * "log-reg": ``_<dataset>_nw<num_workers>``
        * "l1_norm": ``_<dataset>_d<dim>_nw<num_workers>_ns<num_samples>_nsc<noise_scale>``

        Only the selected format is evaluated, so a "log-reg" run no longer
        needs the l1_norm-only arguments (``dim``, ``num_samples``,
        ``noise_scale``) to be present.

        Raises:
            KeyError: for any other loss function, or if an argument needed
                by the selected format is missing.
        """
        args = self.arg_values
        # Builders are lazy: only the one matching loss_func is called.
        builders = {
            "log-reg": lambda: "_{0}_nw{1}".format(args["dataset"], args["num_workers"]),
            "l1_norm": lambda: '_{0}_d{1}_nw{2}_ns{3}_nsc{4}'.format(
                args["dataset"],
                args["dim"],
                args["num_workers"],
                args["num_samples"],
                myrepr(args["noise_scale"]),
            ),
        }
        self.exp_data_extension = builders[args["loss_func"]]()
    
    def init_w_init_extension(self):
        """Build ``self.w_init_extension`` for the configured loss function.

        * "log-reg": ``_<loss_func>_<dataset>``
        * "quadratic" and "l1_norm": ``_<loss_func>_<dataset>_d<dim>``

        Only the selected format is evaluated, so a "log-reg" run no longer
        needs ``arg_values["dim"]`` to be present.

        Raises:
            KeyError: for any other loss function, or if an argument needed
                by the selected format is missing.
        """
        args = self.arg_values

        def with_dim():
            return '_{0}_{1}_d{2}'.format(args["loss_func"], args["dataset"], args["dim"])

        # Builders are lazy: only the one matching loss_func is called.
        builders = {
            "log-reg": lambda: '_{0}_{1}'.format(args["loss_func"], args["dataset"]),
            "quadratic": with_dim,
            "l1_norm": with_dim,
        }
        self.w_init_extension = builders[args["loss_func"]]()
        
    
    def init_exp_name_extension(self):
        """Build ``self.exp_name_extension`` from the algorithm-specific arguments.

        The result is ``self.exp_params_extension`` followed by a suffix that
        depends on ``arg_values['alg_name']`` ("MARINA-P", "EF21-P" or
        "3PCv2-P").

        Raises:
            ValueError: if ``arg_values['exp_name']`` is not in
                ``ALLOWABLE_EXPERIMENTS``.
        """
        if self.arg_values['exp_name'] not in ALLOWABLE_EXPERIMENTS:
            raise ValueError("other options are not supported")

        alg_name = self.arg_values['alg_name']
        # NOTE(review): any other algorithm name (including the allowable
        # "GD") leaves exp_name_extension unset -- confirm this is intended.
        if alg_name == "MARINA-P":
            self.exp_name_extension = self.exp_params_extension + "{3}_p{4}_s{2}_ss{0}_f{1}".format(
                self.arg_values['step_size_init'],
                myrepr(self.arg_values['factor']),
                self.arg_values['sampling'],
                self.arg_values['compressor'],
                myrepr(self.arg_values['prob']),
            )
        elif alg_name == "EF21-P":
            self.exp_name_extension = self.exp_params_extension + "{3}_s{2}_ss{0}_f{1}".format(
                self.arg_values['step_size_init'],
                myrepr(self.arg_values['factor']),
                self.arg_values['sampling'],
                self.arg_values['compressor'],
            )
        elif alg_name == "3PCv2-P":
            # k = dim / num_workers, truncated to an int; the suffix records
            # int(k*qC) and int(k*(1 - qC)) joined by "-".
            k = int(self.arg_values['dim']/self.arg_values['num_workers'])
            self.exp_name_extension = self.exp_params_extension + "{3}_k{4}_s{2}_ss{0}_f{1}".format(
                self.arg_values['step_size_init'],
                myrepr(self.arg_values['factor']),
                self.arg_values['sampling'],
                self.arg_values['compressor'],
                f"{int(k*self.arg_values['qC'])}-{int(k*(1 - self.arg_values['qC']))}",
            )
        
    def init_dataset_path(self):
        """Set ``self.dataset_path`` for the configured loss function.

        "quadratic" and "l1_norm" use the sub-directory
        ``<data_path>data<exp_data_extension>/``; "log-reg" uses ``data_path``
        itself. Any other loss function raises ``KeyError``.
        """
        generated_dir = self.data_path + 'data' + self.exp_data_extension + "/"
        path_by_loss = dict.fromkeys(("quadratic", "l1_norm"), generated_dir)
        path_by_loss["log-reg"] = self.data_path
        self.dataset_path = path_by_loss[self.arg_values["loss_func"]]
    
    def extract_str_from_param(self, str):
        """Return the single key of ``self.alg_params_dict`` selected for ``str``.

        Candidates come from ``extract_str_multiple`` (called with the dict
        keys and the patterns ``str`` and ``"_" + exp_name``) passed through
        ``str_filter`` with ``"_func_opt"``. When more than one candidate
        remains, they are filtered again depending on the sampling: with
        ``"imp"`` for "NICE" sampling, with ``"NICE"`` when the sampling name
        contains "imp". Exactly one candidate must remain in the end.

        Raises:
            AssertionError: if there is no candidate at first, or not exactly
                one at the end.
        """
        # NOTE(review): the parameter name shadows the builtin ``str``; it is
        # kept because it is part of the call signature.
        param_names = self.alg_params_dict.keys()
        patterns = [str, "_"+self.arg_values['exp_name']]
        candidates = str_filter(extract_str_multiple(param_names, patterns), "_func_opt")
        assert len(candidates)>0

        if len(candidates)!=1:
            sampling = self.arg_values['sampling']
            if sampling == "NICE":
                candidates = str_filter(candidates, "imp")
            elif "imp" in sampling:
                candidates = str_filter(candidates, "NICE")

        chosen = candidates[0]
        assert len(candidates)==1
        return chosen
    
    def init_alg_params_dict(self):
        """Set ``self.alg_params_dict`` from ``arg_values['loadable_params']``.

        The dictionary is produced by ``parse_params_to_dict``; all its keys
        must belong to ``ALLOWABLE_PARAMS`` (checked with an assertion).
        """
        loadable = self.arg_values['loadable_params']
        self.alg_params_dict = parse_params_to_dict(loadable, ALLOWABLE_PARAMS)
        parsed_names = set(self.alg_params_dict.keys())
        assert parsed_names.issubset(ALLOWABLE_PARAMS)
        
    def init_load_params_dict(self):
        """Parse ``arg_values["loadable_params"]`` into parameter containers.

        The string must be a Python literal for a list of strings. On success
        sets ``self.load_params_dict`` (each name mapped to ``None``),
        ``self.loadable_params_list`` and ``self.loadable_params_set``.

        Raises:
            ValueError: if the string is not a valid literal, or does not
                evaluate to a list of strings. (Previously a message was
                printed and execution continued into an unbound variable or a
                half-initialised object.)
            AssertionError: if a name is not in ``ALLOWABLE_PARAMS``.
        """
        raw = self.arg_values["loadable_params"]
        try:
            load_params_list = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError) as err:
            # literal_eval reports malformed input with any of these types.
            raise ValueError(
                f"The string is not a valid list representation: {raw!r}"
            ) from err

        is_list_of_str = isinstance(load_params_list, list) and all(
            isinstance(item, str) for item in load_params_list
        )
        if not is_list_of_str:
            raise ValueError(
                f"The list does not contain only string elements: {raw!r}"
            )

        self.load_params_dict = {key: None for key in load_params_list}
        self.loadable_params_list = load_params_list
        self.loadable_params_set = set(load_params_list)
        assert self.loadable_params_set.issubset(ALLOWABLE_PARAMS), (
            f"unknown parameters: {sorted(self.loadable_params_set.difference(ALLOWABLE_PARAMS))}"
        )
        
    def save_comp_params(self):
        self.comp_params_path = self.data_path + 'comp_params' + self.exp_params_extension + "/"
        if not os.path.exists(self.comp_params_path):
            os.mkdir(self.comp_params_path)
        for param in self.comp_params_dict.keys():
            param_path = self.comp_params_path + param + self.exp_params_extension
            save_param(param_path, param, self.comp_params_dict[param], self.arg_values["print_status"])
            
    def log_peak_memory_usage(self):
        """Report the peak resident set size of this process in megabytes.

        The value comes from ``resource.getrusage(RUSAGE_SELF).ru_maxrss`` and
        is printed through ``my_print`` according to
        ``arg_values["print_status"]``.
        """
        import sys  # local import: only needed for the platform check below

        peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is in kilobytes on Linux but in bytes on macOS, so a fixed
        # divisor of 1024 would overstate the macOS figure by a factor of 1024.
        units_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024
        my_print(f"Peak Memory Usage: {peak_rss / units_per_mb} MB", self.arg_values["print_status"])