from DL.utils import *
from pickle import dump, load
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import pandas as pd
from time import time
from mpi4py import MPI


# MPI context: the script is executed by every rank of COMM_WORLD.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Make float64 the default floating-point dtype for new torch tensors.
# The deprecated torch.set_default_tensor_type(torch.DoubleTensor) call was
# dropped: set_default_dtype is its documented replacement and was already
# being applied right after it.
torch.set_default_dtype(torch.float64)
np.set_printoptions(precision=2)
eps = 1e-4
# Shared standardiser; it is re-fitted on each use, so its fitted attributes
# always describe the most recent fit.
scaler = StandardScaler()

# Load the UJIndoorLoc training table and turn it into an all-numeric frame.
kind = 'UJI'
# Identifier-like columns that must be treated as categories, not numbers.
lst = ['FLOOR', 'BUILDINGID', 'SPACEID', 'RELATIVEPOSITION', 'USERID', 'PHONEID']
# Columns discarded before modelling.
drop_lst = ['LATITUDE', 'TIMESTAMP']
df = (
    pd.read_csv('UJIndoorLoc/trainingData.csv')
    # Casting to object makes get_dummies pick these columns up for encoding.
    .astype({col: 'object' for col in lst})
    .drop(columns=drop_lst)
)
# One-hot encode the categorical columns, dropping the first level of each.
df = pd.get_dummies(df, drop_first=True)
# Predictor names: every remaining column except the response LONGITUDE.
names = df.columns.tolist()
names.remove('LONGITUDE')

# Experiment constants.
n = 2000                  # training-set size (also used in the output file name)
p = df.shape[1] - 1       # number of predictor columns (all but LONGITUDE)
n_test = 1000             # test-set size
eff_p = None              # not used in the visible part of the script
nrep = 10                 # repetitions performed by each MPI rank
batch_size = 50           # passed to gen_data_loader_from_data
r = 100                   # second argument of FC_no_bias; presumably the hidden width -- TODO confirm
raw = False               # forwarded to variable_selection

# kind = 'BGS'
# df = pd.read_csv('BGSboy.csv', index_col=0)
# # df = df[['WT2', 'HT2', 'WT9', 'HT9', 'LG9', 'ST18', 'HT18']]
# df = df.drop(['Soma'], axis=1)
# names = df.columns.to_numpy().tolist()
# names.remove('BMI18')
# X, y = df.drop('BMI18', axis=1).to_numpy(), df['BMI18'].to_numpy()

# n, p = 44, df.shape[1]-1
# n_test = 22
# eff_p = None
# nrep = 10
# batch_size = 10
# r = 20
# raw = False

# Per-rank accumulators, one slot per repetition:
#   test_err   - test error rescaled to the original units of the response
#   models     - fitted model of each repetition
#   importance - running sum of the per-repetition variable-selection output
test_err = np.zeros(nrep)
models = np.empty(nrep, dtype=object)
importance = np.zeros(p)

t = time()
for k in range(nrep):
    # For BGS dataset
    # df = shuffle(df)
    # X, y = df.drop('BMI18', axis=1).to_numpy(), df['BMI18'].to_numpy()
    # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
    # X_train = scaler.fit_transform(X_train)
    # X_test = scaler.transform(X_test)
    # y_train = scaler.fit_transform(y_train.reshape((-1, 1))).reshape(-1)
    # y_test = scaler.transform(y_test.reshape((-1, 1))).reshape(-1)

    # For UJI dataset
    # Draw a fresh random subsample of n + n_test rows for this repetition.
    dt = shuffle(df)[:n + n_test]
    X = dt.drop('LONGITUDE', axis=1).to_numpy().astype('float')
    y = dt['LONGITUDE'].to_numpy().astype('float')
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=n_test, random_state=2)
    # Standardise with statistics estimated on the training rows only, so no
    # information from the test rows leaks into the fit (same order of
    # operations as the BGS variant above).
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    y_train = scaler.fit_transform(y_train.reshape((-1, 1))).reshape(-1)
    y_test = scaler.transform(y_test.reshape((-1, 1))).reshape(-1)

    # `scaler` was last fitted on y_train, so scale_ holds the single standard
    # deviation of the response. Taking element 0 yields a plain float and
    # avoids assigning a length-1 array into a scalar slot of test_err.
    y_var = float(scaler.scale_[0]) ** 2

    # LASSO
    # lasso_cv = LassoCV(cv=5).fit(X_train, y_train)
    # lasso = Lasso(alpha=lasso_cv.alpha_)
    # y_pred_lasso = lasso.fit(X_train, y_train).predict(X_test)
    # test_err[k] = np.mean((y_test - y_pred_lasso) ** 2) * y_var
    # fi = variable_selection(lasso, plot=False)

    # OMP
    # omp_cv = OrthogonalMatchingPursuitCV(cv=5).fit(X_train, y_train)
    # omp = OrthogonalMatchingPursuit(n_nonzero_coefs=omp_cv.n_nonzero_coefs_)
    # y_pred_omp = omp.fit(X_train, y_train).predict(X_test)
    # test_err[k] = np.mean((y_test - y_pred_omp) ** 2) * y_var
    # fi = variable_selection(omp, plot=False, raw=raw)

    # RF
    # rf = RandomForestRegressor(n_estimators=40)
    # y_pred_rf = rf.fit(X_train, y_train).predict(X_test)
    # test_err[k] = np.mean((y_test - y_pred_rf) ** 2) * y_var
    # fi = variable_selection(rf, plot=False, raw=raw)

    # GB
    # gb = GradientBoostingRegressor(n_estimators=60)
    # y_pred_gb = gb.fit(X_train, y_train).predict(X_test)
    # test_err[k] = np.mean((y_test - y_pred_gb) ** 2) * y_var
    # fi = variable_selection(gb, plot=False, raw=raw)

    # NN
    lam = 0.005
    train_loader, test_loader, _ = gen_data_loader_from_data(X_train, y_train, X_test, y_test, batch_size)
    # Two training stages with identical hyper-parameters: 800 epochs from a
    # freshly built network, then 200 more epochs continuing from that model.
    # NOTE(review): the input width is p + 1 and variable_selection is called
    # with exclude=0 -- presumably the loader adds one extra leading column;
    # confirm against DL.utils.
    train_err, err, model = run(train_loader, test_loader, FC_no_bias(p+1, r, 1),
                                lam=lam, lr=0.005, num_epochs=800, verbose=False)
    train_err, err, model = run(train_loader, test_loader, model,
                                lam=lam, lr=0.005, num_epochs=200, verbose=False)
    # `err` is presumably the test MSE on the standardised response (verify in
    # DL.utils.run); multiplying by the response variance restores the units.
    test_err[k] = err * y_var
    fi = variable_selection(model, plot=False, exclude=0, raw=raw)

    models[k] = model
    # importance[fi.argsort()[-10:]] += 1  # cutoff is top 10
    importance += fi  # data-driven cutoff

# Collect every rank's per-repetition results on rank 0; on the other ranks
# comm.gather returns None and nothing further happens.
model_res = comm.gather(models, root=0)
err_res = comm.gather(test_err, root=0)
fi_res = comm.gather(importance, root=0)
if rank == 0:
    # Flatten the per-rank arrays into one array covering all repetitions and
    # add the per-rank importance vectors element-wise.
    models = np.concatenate(model_res)
    test_err = np.concatenate(err_res)
    importance = np.sum(fi_res, axis=0)
    # Mean and standard deviation of the test error followed by '&'
    # (the exact layout is decided by `paste`, defined outside this file).
    print(paste(test_err.mean(), test_err.std()), end='&')
    print()
    # Proper f-string; the `4f` spec prints exactly what the former
    # "%4f" % value expression printed.
    print(f"Total cost time: {time() - t:4f} ")
    # Persist the pooled results; the file name records the data set, the
    # training size and the per-rank repetition count.
    with open(f'DL/exp_{kind}_NN_n{n}nrep{nrep}_MPI.pkl', 'wb') as output:
        dump({'test_err': test_err, 'if': importance, 'models': models}, output)
