############
import os
# Threads each worker process may use inside the numerical back-ends.
# The variables are set here, ahead of the project imports further down, so
# that they are already in the environment when the numerical libraries load
# (NOTE(review): assumes NumPy is first imported by those project modules).
NUM_THREADS = 4
for _thread_var in (
    "OPENBLAS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "OMP_NUM_THREADS",
):
    os.environ[_thread_var] = str(NUM_THREADS)

# Number of CPUs this process is allowed to run on.  sched_getaffinity only
# exists on some platforms (e.g. Linux); fall back to the plain CPU count.
if hasattr(os, "sched_getaffinity"):
    NUM_CPU = len(os.sched_getaffinity(0))
else:
    NUM_CPU = os.cpu_count() or 1

# One worker process per NUM_THREADS CPUs, but never fewer than one: a pool
# of size 0 is rejected by multiprocessing.
NUM_PROCESS = max(1, NUM_CPU // NUM_THREADS)
print(f'NUM_PROCESS={NUM_PROCESS}')


import time
import multiprocessing as mp
import pickle

from NGDOracle import *
from NGD import *
from GradientTrack import *
from BRIDGE import *
from ClippedGossip import *


import argparse

# Command-line configuration of one simulation run.
parser = argparse.ArgumentParser(description='Arguments for the simulation')
# The two type selectors are parsed as floats and truncated to int below, so
# both "1" and "1.0" are accepted.  They are required: without them the
# int() conversions below would fail on None with an unhelpful TypeError.
parser.add_argument('--net_typ', type=float, required=True,
                    help='network type forwarded to generate_network (1 or 2)')
parser.add_argument('--byz_typ', type=float, required=True,
                    help='Byzantine type forwarded to generate_byz_data')
parser.add_argument('--is_iid', type=int, default=1,
                    help='0/1 flag forwarded to generate_byz_data as is_iid (default: 1)')

args = parser.parse_args()
net_typ = int(args.net_typ)
byz_typ = int(args.byz_typ)
is_iid = bool(args.is_iid)




# ---- Data dimensions ----
n_workers = 100
n_samples = 100
n_features = 50
n_test = int(0.2*n_samples)      # extra samples generated for the validation split

# ---- Ground-truth coefficients ----
# Column vector of length n_features: the first s entries are 1, the rest 0.
s = int(0.2*n_features)
coefs_true = np.vstack([np.ones((s, 1)), np.zeros((n_features - s, 1))])

# ---- Optimisation settings shared by every method ----
model_typ = 'linear'
lr_constant = 0.01
max_niter = 500
tol = 1e-5

# ---- Number of replications handed to the process pool ----
R = 20

# Directory to save the results.
# NOTE: the name shadows the builtin dir(); it is kept unchanged because the
# result-writing code later in the script refers to it.
dir = './Results/'
# exist_ok=True makes creation idempotent, so several runs of this script
# started at the same time cannot fail on a check-then-create race.
os.makedirs(dir, exist_ok=True)


# Connectivity settings to sweep for each supported network type.  Each value
# is later passed to generate_network as both its `d` and its `q` argument.
_DQS_BY_NET_TYP = {
    1: [5, 30],
    2: [0.2, 0.4],
}
if net_typ not in _DQS_BY_NET_TYP:
    # Fail here with a clear message instead of a NameError on `dqs` later.
    raise ValueError(
        f'Unsupported net_typ={net_typ}; expected one of {sorted(_DQS_BY_NET_TYP)}'
    )
dqs = _DQS_BY_NET_TYP[net_typ]

# Byzantine ratios to sweep.
byz_ratios = [0.1, 0.15, 0.2, 0.25, 0.3]



# Base seed: depends only on the command-line configuration, so it is computed
# once for the whole sweep.
outer_seed = 10*net_typ+byz_typ

for dq in dqs:

    # One network per connectivity setting; dq is supplied as both q and d.
    W = generate_network(n_workers, typ=net_typ, q=dq, d=dq, random_state = outer_seed)

    for j, byz_ratio in enumerate(byz_ratios):

        # ---- Stage 1: choose `cn` for AdaptiveNGD on a train/validation split ----
        param_grid = {'cn': np.linspace(np.sqrt(n_samples)/25, np.sqrt(n_samples)/5, 5)}
        temp_model = ByzantiumDFL(n_workers=n_workers, model_typ=model_typ, coefs_true=coefs_true, random_state=outer_seed+j)
        # n_test extra samples are generated and split off again as the validation set.
        _, _, Xs_all, ys_all = temp_model.generate_byz_data(byz_ratio=byz_ratio, n_samples=n_samples+n_test, is_iid = is_iid, byz_typ=byz_typ)
        Xs, Xs_val, ys, ys_val = list_train_test_split(Xs_all, ys_all, test_size=n_test, random_state=outer_seed+j)
        _, X_star, y_star = get_matform(Xs,ys,coefs_true,temp_model.n_workers)
        _, X_star_val, y_star_val = get_matform(Xs_val,ys_val, coefs_true,temp_model.n_workers)

        ngd = NGD(W,lr_constant,n_workers=n_workers,max_niter=max_niter,coefs_true=coefs_true,tol=tol,
                   model_typ=model_typ,byz_labels=temp_model.byz_labels,random_state=outer_seed+j)
        ngd.fit(X_star, y_star)
        # AdaptiveNGD is initialised from the fitted NGD coefficients.
        angd = AdaptiveNGD(W,lr_constant,n_workers=n_workers,max_niter=max_niter,coefs_true=coefs_true,tol=tol,
                           model_typ=model_typ,byz_labels=temp_model.byz_labels,random_state=outer_seed+j,
                           coefs_init_star=ngd.coefs_,max_nrefit=2)
        # Only the refitted estimator is needed; its `cn` is reused in every replication.
        _, _, _, angd = grid_search(X_star, y_star, X_star_val,y_star_val, angd, angd.coefs_init_star, param_grid,refit=True)
        cn = angd.cn


        # ---- Stage 2: R independent replications, run in parallel ----
        def map_fun(r):
            """Run replication ``r`` of the current (dq, byz_ratio) setting.

            Fresh data are generated with a seed derived from ``r``, every
            method is fitted on them, and the recorded score histories are
            returned.  The function reads ``W``, ``cn``, ``byz_ratio`` and the
            experiment constants from module scope.

            Parameters
            ----------
            r : int
                Replication index; the seed used is ``outer_seed*10 + r``.

            Returns
            -------
            tuple
                ``(scores, scores_nbyz)`` where ``scores`` holds each method's
                ``history_score`` and ``scores_nbyz`` holds the squared
                coefficient errors of the global and oracle least-squares
                fits followed by each method's ``history_score_nbyz``.
                Method order: NGD, AdaptiveNGD, BRIDGE (default aggregation),
                BRIDGE (trimmed mean), ClippedGossip and, when
                ``net_typ == 2``, BTGradientTrack (default aggregation) and
                BTGradientTrack (trimmed mean).
            """
            random_state = outer_seed*10+r

            model = ByzantiumDFL(n_workers=n_workers, model_typ=model_typ, coefs_true=coefs_true, random_state=random_state)
            X,y,Xs,ys = model.generate_byz_data(byz_ratio=byz_ratio,n_samples=n_samples,is_iid = is_iid,byz_typ=byz_typ)

            # Oracle baseline: least squares on the workers not flagged in byz_labels.
            X_oracle = np.array(
                [X_w for k, X_w in enumerate(Xs) if not model.byz_labels[k]]
            ).reshape(-1, n_features)
            y_oracle = np.array(
                [y_w for k, y_w in enumerate(ys) if not model.byz_labels[k]]
            ).reshape(-1, 1)
            coefs_oracle = np.linalg.solve(X_oracle.T @ X_oracle, X_oracle.T @ y_oracle)
            # Global baseline: least squares on the full (X, y) returned by the generator.
            coefs_global = np.linalg.solve(X.T @ X, X.T @ y)

            _, X_star, y_star = get_matform(Xs,ys,coefs_true,model.n_workers)

            # Keyword arguments shared by every estimator below.
            common = dict(n_workers=n_workers, max_niter=max_niter, coefs_true=coefs_true, tol=tol,
                          model_typ=model_typ, byz_labels=model.byz_labels, random_state=random_state)

            ngd = NGD(W, lr_constant, **common)
            ngd.fit(X_star, y_star)

            # AdaptiveNGD starts from the NGD solution and uses the tuned cn.
            angd = AdaptiveNGD(W, lr_constant, coefs_init_star=ngd.coefs_, max_nrefit=2, **common)
            angd.cn = cn
            angd.fit(X_star, y_star)
            angd.refit(X_star, y_star)

            bridge_m = BRIDGE(W, lr_constant, **common)
            bridge_m.fit(X_star, y_star)

            bridge_t = BRIDGE(W, lr_constant, agg_method='trimmed_mean', proportiontocut=byz_ratio, **common)
            bridge_t.fit(X_star, y_star)

            # The gradient-tracking baselines are only run for network type 2.
            if net_typ == 2:
                slbrn_m = BTGradientTrack(W, lr_constant, **common)
                slbrn_m.fit(X_star, y_star)

                slbrn_t = BTGradientTrack(W, lr_constant, agg_method='trimmed_mean',
                                          proportiontocut=byz_ratio, **common)
                slbrn_t.fit(X_star, y_star)

            clipped = ClippedGossip(W, lr_constant, alpha=0.1, delta_max=2*byz_ratio, **common)
            clipped.fit(X_star, y_star)

            methods = [ngd, angd, bridge_m, bridge_t, clipped]
            if net_typ == 2:
                methods += [slbrn_m, slbrn_t]

            scores = [method.history_score for method in methods]
            scores_nbyz = [method.history_score_nbyz for method in methods]

            scores_global = np.sum((coefs_global-coefs_true)**2)
            scores_oracle = np.sum((coefs_oracle-coefs_true)**2)

            scores_nbyz = [scores_global, scores_oracle] + scores_nbyz

            return (scores, scores_nbyz)



        tic1 = time.time()
        # map_fun and the globals it reads (W, cn, byz_ratio, ...) live in this
        # script's module scope and there is no __main__ guard, so the workers
        # must inherit the parent's state.  Only the 'fork' start method does
        # that; request it explicitly instead of relying on the platform default.
        with mp.get_context('fork').Pool(NUM_PROCESS) as pool:
            Results = pool.map(map_fun, range(R))


        filename = f'Result_ntype{net_typ}_btype{byz_typ}_dq{dq}_byz_ratio{byz_ratio}_is_iid{is_iid}'+'.save'

        # Saved payload: the tuned cn followed by the per-replication results.
        with open(dir+filename, 'wb') as f:
            pickle.dump([cn, Results], f)

        print(f'ntype={net_typ} | btype={byz_typ} | d/q={dq} | byz_ratio={byz_ratio} is finished, {time.time()-tic1:.2f} seconds elapsed.')

        



