import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import warnings

from funcs.CP import SplitCP, RO_StabCP, LOO_StabCP
from funcs.utils import get_data
from funcs.post import post_nn, figure_nn

# Seed NumPy's global RNG; train_test_split (called without random_state)
# draws from it, so the random splits are reproducible run to run.
np.random.seed(42)
# Suppress every warning category for the whole run so the console only
# shows the progress messages.
warnings.filterwarnings("ignore")

T = 100  # number of random train/test repetitions per dataset
m = 100  # number of held-out test points drawn in each repetition

# Training epochs handed to the 'NN' model for each dataset.
R_boston = 30
R_diabetes = 30

X_boston_full, Y_boston_full = get_data('boston')
X_diabetes_full, Y_diabetes_full = get_data('diabetes')

# Result buffers indexed as [repetition, method, value]. The method axis is
# ordered SplitCP, RO_StabCP, LOO_StabCP; the last axis receives the three
# values each CP function returns (presumably coverage, length and time,
# matching the summary columns printed at the end — TODO confirm in funcs.CP).
save_Boston, save_Diabetes = np.zeros((T, 3, 3)), np.zeros((T, 3, 3))


def _nn_params(epochs):
    """Return a fresh hyper-parameter dict for the 'NN' model.

    A new dict (with its own 'hidden' list) is built on every call, so the
    per-dataset configurations never share mutable state.
    """
    return {'hidden': [20], 'activation': 'logistic', 'lr': 0.001, 'epochs': epochs}


# Alternative architecture that was also tried: use 'hidden': [10, 5]
# instead of [20] in the dict above.
params_nn_boston = _nn_params(R_boston)
params_nn_diabetes = _nn_params(R_diabetes)

def _run_nn_experiment(dataset_name, X_full, Y_full, params_nn, save, n_test):
    """Run the repeated random-split CP experiment for one dataset, filling ``save`` in place.

    For every repetition ``t`` in ``range(len(save))``, ``(X_full, Y_full)`` is
    split at random into a training set and a held-out test set of ``n_test``
    points. The outputs of ``SplitCP``, ``RO_StabCP`` and ``LOO_StabCP`` (all
    run with the ``'NN'`` model and ``params_nn``) are then stored in
    ``save[t, 0, :]``, ``save[t, 1, :]`` and ``save[t, 2, :]`` respectively.

    Parameters
    ----------
    dataset_name : str
        Label used in the progress messages.
    X_full, Y_full : array-like
        Full feature matrix and response, passed to ``train_test_split``.
    params_nn : dict
        Hyper-parameters forwarded unchanged to every CP method.
    save : numpy.ndarray
        Output buffer indexed as ``save[repetition, method, :]``. The length of
        its first axis is the number of repetitions.
    n_test : int
        Number of test points drawn in each repetition.
    """
    n_reps = len(save)
    print(f'Starting {dataset_name}...')
    for t in range(n_reps):
        # Report the first ten repetitions, then every tenth one.
        if (t + 1) % 10 == 0 or t < 10:
            print(f'Iteration {t + 1} out of {n_reps}')
        # int() matters: sklearn treats a float test_size as a fraction of
        # the data rather than an absolute number of test points.
        X_train, X_test, Y_train, Y_test = train_test_split(X_full, Y_full, test_size=int(n_test))
        D_train, D_test = (X_train, Y_train), (X_test, Y_test)

        # Keep this call order fixed: reordering would change how any shared
        # random state is consumed and therefore the recorded results.
        for k, cp_method in enumerate((SplitCP, RO_StabCP, LOO_StabCP)):
            save[t, k, :] = cp_method(D_train, D_test, 'NN', params_nn=params_nn)


_run_nn_experiment('Boston', X_boston_full, Y_boston_full, params_nn_boston, save_Boston, m)
_run_nn_experiment('Diabetes', X_diabetes_full, Y_diabetes_full, params_nn_diabetes, save_Diabetes, m)
    
# Hand both datasets' raw result buffers to the post-processing helper.
saves = [save_Boston, save_Diabetes]
results = post_nn(saves, T)

# Serif, Times-style typography with explicit title/label/tick sizes for the
# figure produced by figure_nn.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'axes.titlesize': 17,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 11.3,
})
sns.set_palette("Set2")
figure_nn(results)

print('Results from Real Data Examples:')
# Mean and standard deviation of each metric per (Dataset, Method) pair,
# rounded to three decimals.
summary_spec = {metric: ['mean', 'std'] for metric in ('Coverage', 'Length', 'Time')}
print(results.groupby(['Dataset', 'Method']).agg(summary_spec).round(3))
