from DL.utils import *
from pickle import dump
from mpi4py import MPI
from time import time
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor


# Global numeric configuration.
# Use float64 for all floating-point torch tensors. torch.set_default_tensor_type
# is deprecated as of PyTorch 2.1; PyTorch recommends torch.set_default_dtype
# (plus set_default_device when needed) instead, so only the latter is called.
torch.set_default_dtype(torch.float64)
# Print numpy arrays with 2 decimal places.
np.set_printoptions(precision=2)
# Small constant; not referenced elsewhere in this script.
eps = 1e-4

# MPI setup: every rank executes the same experiment loop below, and the
# per-rank results are gathered on rank 0 at the end of the script.
comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()

# Experiment settings
# n:          training sample size for one repetition
# p:          number of input features
# r:          hidden layer size of the fitted network
# batch_size: mini-batch size used for training
# n_test:     test sample size (one test set per noise level)
# nrep:       repetitions per MPI rank; after gathering on rank 0 the
#             results hold nrep * (number of ranks) repetitions
# lams:       regularization strengths tried for every fit
# sigmas:     noise levels passed to gen_friedman1 (presumably the std of
#             Gaussian noise on y -- TODO confirm against DL.utils)
n, p, r = 500, 100, 16
batch_size = 20
n_test = 1000
nrep = 10
scaler = StandardScaler()
lams = np.array([0, 0.03, 0.06, 0.1])
nlam = len(lams)
# Alternative sweep over the number of features instead of the noise level:
# ps = np.arange(20, 120, 20)
# nsetting = len(ps)
sigmas = np.array([0, 0.5, 1, 5])
nsetting = len(sigmas)
# Per-rank result containers, indexed [setting, repetition].
test_err = np.zeros((nsetting, nrep))
snrs = np.zeros(nsetting)  # not filled in by this script
models = np.empty((nsetting, nrep), dtype=object)
# Variable-selection counts (enable together with the matching lines in the loop):
# right = np.zeros((nsetting, nrep))
# wrong = np.zeros((nsetting, nrep))

# Main experiment: for each noise level draw one test set and nrep disjoint
# training sets of size n, fit the network once per lambda in `lams`, and keep
# the lowest error (rescaled to the original y units) and its model.
t = time()
for i, sigma in enumerate(sigmas):
    t_setting = time()
    # gen_friedman1 returns three arrays: the test call keeps the third as the
    # target, the training call keeps the second. Presumably these are
    # (X, noisy y, noiseless f(X)) -- TODO confirm against DL.utils.
    X_test, _, y_test = gen_friedman1(n_test, p, sigma)
    dt_X_train, dt_y_train, _ = gen_friedman1(n * nrep, p, sigma)

    # Separate scalers for features and target. Reusing one scaler and
    # refitting it made the later `scale_` read depend on which fit ran last.
    x_scaler, y_scaler = StandardScaler(), StandardScaler()
    dt_X_train = x_scaler.fit_transform(dt_X_train)
    X_test = x_scaler.transform(X_test)
    dt_y_train = y_scaler.fit_transform(dt_y_train.reshape(-1, 1)).ravel()
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).ravel()
    # Converts a squared error on the standardized y scale back to original
    # y units (y_scaler.scale_ has shape (1,) because y was reshaped to one column).
    y_var_scale = y_scaler.scale_[0] ** 2

    # NN
    for k in range(nrep):
        rows = slice(k * n, (k + 1) * n)
        X_train, y_train = dt_X_train[rows], dt_y_train[rows]
        train_loader, test_loader, _ = gen_data_loader_from_data(X_train, y_train, X_test, y_test, batch_size)
        # NOTE(review): `err` appears to be measured on test_loader (built from
        # X_test/y_test), so lambda is selected on the test set and the reported
        # error is optimistically biased; a validation split would avoid this.
        best_err = np.inf
        for lam in lams:
            # Input size p + 1: presumably the loader appends an intercept
            # column -- TODO confirm against DL.utils.
            _, err, model = run(train_loader, test_loader, FC_no_bias(p + 1, r, 1),
                                lam=lam, lr=0.01, num_epochs=200, verbose=False)
            if err < best_err:
                models[i, k], best_err = model, err

        # Presumably `err` is an MSE on the standardized scale -- TODO confirm.
        test_err[i, k] = best_err * y_var_scale
        # right[i, k], wrong[i, k] = variable_selection(model, 5, False, 0)

    # LASSO
    # for k in range(nrep):
    #     X_train, y_train = dt_X_train[k * n:(k + 1) * n], dt_y_train[k * n:(k + 1) * n]
    #     lasso_cv = LassoCV(cv=5).fit(X_train, y_train)
    #     lasso = Lasso(alpha=lasso_cv.alpha_)
    #     y_pred_lasso = lasso.fit(X_train, y_train).predict(X_test)
    #     test_err[i, k] = np.mean((y_test - y_pred_lasso) ** 2) * y_var_scale
    #     right[i, k], wrong[i, k] = variable_selection(lasso, 5, False)

    # OMP
    # for k in range(nrep):
    #     X_train, y_train = dt_X_train[k * n:(k + 1) * n], dt_y_train[k * n:(k + 1) * n]
    #     omp_cv = OrthogonalMatchingPursuitCV(cv=5).fit(X_train, y_train)
    #     omp = OrthogonalMatchingPursuit(n_nonzero_coefs=omp_cv.n_nonzero_coefs_)
    #     y_pred_omp = omp.fit(X_train, y_train).predict(X_test)
    #     test_err[i, k] = np.mean((y_test - y_pred_omp) ** 2) * y_var_scale
    #     right[i, k], wrong[i, k] = variable_selection(omp, 5, False)

    # RF
    # for k in range(nrep):
    #     X_train, y_train = dt_X_train[k * n:(k + 1) * n], dt_y_train[k * n:(k + 1) * n]
    #     rf = RandomForestRegressor(n_estimators=150)
    #     y_pred_rf = rf.fit(X_train, y_train).predict(X_test)
    #     test_err[i, k] = np.mean((y_test - y_pred_rf) ** 2) * y_var_scale
    #     right[i, k], wrong[i, k] = variable_selection(rf, 5, False)

    # GB
    # for k in range(nrep):
    #     X_train, y_train = dt_X_train[k * n:(k + 1) * n], dt_y_train[k * n:(k + 1) * n]
    #     gb = GradientBoostingRegressor(n_estimators=200)
    #     y_pred_gb = gb.fit(X_train, y_train).predict(X_test)
    #     test_err[i, k] = np.mean((y_test - y_pred_gb) ** 2) * y_var_scale
    #     right[i, k], wrong[i, k] = variable_selection(gb, 5, False)

    # Time spent on this noise level only (the total is reported after gathering).
    print(f"Sigma {sigma} cost time: {time() - t_setting:.4f} ")

# Collect every rank's (nsetting, nrep) result arrays on rank 0 and stack them
# along the repetition axis, giving (nsetting, nrep * size) arrays there.
# On all other ranks comm.gather returns None and nothing is saved.
model_res = comm.gather(models, root=0)
err_res = comm.gather(test_err, root=0)
if rank == 0:
    models = np.concatenate(model_res, axis=1)
    test_err = np.concatenate(err_res, axis=1)

    print(f"Total cost time: {time() - t:.4f} ")
    # NOTE: the file name records the per-rank nrep; the saved arrays hold
    # nrep * size repetitions in total.
    with open(f'DL/exp_Friedman_NN_n{n}nrep{nrep}_MPI.pkl', 'wb') as output:
        dump({'models': models, 'lams': lams, 'sigmas': sigmas, 'test_err': test_err}, output)
