import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

from funcs.CP import SplitCP, RO_StabCP, LOO_StabCP
from funcs.utils import DGP_lin, DGP_nonlin, kernel_matrix
from funcs.post import post_kernel, figure_kernel

# Seed the legacy global NumPy RNG first so every later draw in this script
# (data generation, SGD shuffles) is reproducible across runs.
np.random.seed(42)
warnings.filterwarnings("ignore")

# Simulation sizes: T Monte-Carlo repetitions; n, m, d are forwarded to the
# DGP functions (presumably train size, test size, feature dimension —
# confirm in funcs.utils). R is used both as the number of SGD shuffles drawn
# and as the epoch count.
T = 100
n = m = d = 100
R = 15

# One zero-initialised buffer per (DGP, feature map) pair, indexed as
# [repetition, method, metric]. The method axis is filled as
# 0 = SplitCP, 1 = RO_StabCP, 2 = LOO_StabCP. The metric axis holds the three
# values each CP routine returns (presumably coverage / length / time, as
# summarised at the end — confirm in funcs.CP).
save_linear_no, save_linear_RBF, save_linear_poly = (
    np.zeros((T, 3, 3)) for _ in range(3)
)
save_nonlinear_no, save_nonlinear_RBF, save_nonlinear_poly = (
    np.zeros((T, 3, 3)) for _ in range(3)
)

# Pair each data-generating process with the three buffers it fills, ordered
# (raw features, RBF-kernel features, polynomial-kernel features).
_designs = (
    (DGP_lin, (save_linear_no, save_linear_RBF, save_linear_poly)),
    (DGP_nonlin, (save_nonlinear_no, save_nonlinear_RBF, save_nonlinear_poly)),
)

for t in range(T):
    # Report the first ten repetitions, then every tenth one.
    if t < 10 or (t + 1) % 10 == 0:
        print(f'Iteration {t+1} out of {T}')

    for dgp, (out_no, out_rbf, out_poly) in _designs:
        D, D_test = dgp(n, m, d, cov = 'ar1')
        X, X_test = D[0], D_test[0]
        y, y_test = D[1], D_test[1]

        # Kernel features: both train and test inputs are evaluated against the
        # training inputs X, then paired with the original responses.
        K_rbf = kernel_matrix(X, X, kernel = 'RBF', params_RBF = {'gamma': 0.1})
        K_rbf_test = kernel_matrix(X_test, X, kernel = 'RBF', params_RBF = {'gamma': 0.1})
        K_poly = kernel_matrix(X, X, kernel = 'poly', params_poly = {'degree': 2, 'coef0': 1})
        K_poly_test = kernel_matrix(X_test, X, kernel = 'poly', params_poly = {'degree': 2, 'coef0': 1})

        # Fresh SGD settings per dataset: R random orderings of n + 1 indices
        # (presumably the n training points plus one test point — confirm in
        # funcs.CP). Kernel features use a much smaller learning rate.
        # The raw shuffles must be drawn before the kernel ones to keep the
        # RNG stream unchanged.
        sgd_raw = {'shuffles': [np.random.permutation(n + 1) for _ in range(R)], 'lr': 0.001, 'epochs': R}
        sgd_kern = {'shuffles': [np.random.permutation(n + 1) for _ in range(R)], 'lr': 0.00001, 'epochs': R}

        # (buffer, train, test, SGD settings, extra kwargs for the stability
        # methods). Only RO_StabCP / LOO_StabCP get kernel=True, and only on
        # kernel features; SplitCP never receives the flag.
        runs = (
            (out_no, D, D_test, sgd_raw, {}),
            (out_rbf, (K_rbf, y), (K_rbf_test, y_test), sgd_kern, {'kernel': True}),
            (out_poly, (K_poly, y), (K_poly_test, y_test), sgd_kern, {'kernel': True}),
        )
        for out, train, test, sgd, stab_kw in runs:
            out[t, 0, :] = SplitCP(train, test, 'SGD', params_sgd=sgd)
            out[t, 1, :] = RO_StabCP(train, test, 'SGD', params_sgd=sgd, **stab_kw)
            out[t, 2, :] = LOO_StabCP(train, test, 'SGD', params_sgd=sgd, **stab_kw)

# Hand the buffers to post_kernel in a fixed order: the linear DGP
# (raw, RBF, poly), then the nonlinear DGP (raw, RBF, poly).
saves = [
    save_linear_no, save_linear_RBF, save_linear_poly,
    save_nonlinear_no, save_nonlinear_RBF, save_nonlinear_poly,
]
results = post_kernel(saves, T)

# Figure styling; update() assigns the keys in this order through the same
# validated setter as individual item assignment.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'axes.titlesize': 17,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 11.3,
})
sns.set_palette("Set2")
figure_kernel(results)

# Mean and standard deviation of each metric per (DGP, Kernel, Method),
# rounded to three decimals.
summary_spec = {metric: ['mean', 'std'] for metric in ('Coverage', 'Length', 'Time')}
print('Results from Simulation:')
print(results.groupby(['DGP', 'Kernel', 'Method']).agg(summary_spec).round(3))