import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
from scipy.stats import uniform, loguniform
from time import time
import argparse

from sklearn.metrics import accuracy_score, mean_squared_error

from learners import LassoLearner
from rkhs_weightings import RKHSWeightingRandomSearchCV, RKHSWeightingClassifier, RKHSWeightingRegressor
from models import RWRelu
from loss import MSE
from rks import _RKSEstimator, RKSClassifier, RKSRegressor

from dataset_loaders import *
from expe_utils import ensure_folder_exists, get_color_list, save_run_time, make_legend
from expe_utils import RESULTS_FOLDER, FIGURES_FOLDER, MARKERS
from expe_params import RWRELU_CV_PARAMS
import visuals # Must be at the end of the imports for some reason


# Command-line interface: three boolean switches that control what this script does.
parser = argparse.ArgumentParser(description='Experiment comparing RKHS weightings and random kitchen sinks using their respective Lasso fit.')
parser.add_argument('--norun', action='store_true', help='Do not run the experiment. Only generate the figures.')
parser.add_argument('--final', action='store_true', help='Do the long experiment.')
parser.add_argument('--noshow', action='store_true', help='Do not show the plots.')
# NOTE: parsing happens at module level, so merely importing this file reads sys.argv.
args = parser.parse_args()

# FINAL is forwarded to every dataset loader constructed below.
FINAL = args.final
# TEST is the complement of FINAL; when set, '-test' is appended to FILENAME further down.
TEST = not FINAL
# RUN_EXPE gates the call to expe() in the __main__ section.
RUN_EXPE = not args.norun
# Figure generation is always enabled; no command-line switch turns it off.
GENERATE_FIGURES = True

# Seeded NumPy generator. It is not referenced by the code visible in this file,
# which passes integer seeds to the search objects instead.
RNG = np.random.default_rng(0)

# Base name shared by the results CSV and the generated figure files.
FILENAME = 'few-features-lasso'

# Dataset loaders treated as classification tasks. Their `name` attributes are
# also used in do_one_expe() and make_figures() to decide whether a dataset is
# a classification dataset.
# NOTE(review): the trailing '#####' markers come from the original file; their
# meaning is not documented here.
CLASSIFICATION_LOADERS = [
                   BreastCancerLoader(final=FINAL), 
                   AdultsLoader(final=FINAL), ###################
                   MNISTLoader(final=FINAL, digits=[1,7]), 
                   SkinSegmentationLoader(final=FINAL),
                   BankMarketingLoader(final=FINAL),##################
                   MagicGammaTelescopeLoader(final=FINAL),
                   PhishingLoader(final=FINAL)
                   ]

# Keyword arguments shared by every regression loader.
REGRESSION_KWARGS = {'final' : FINAL, 'scale_x' : True, 'scale_y' : True}
# Dataset loaders treated as regression tasks (everything not listed in
# CLASSIFICATION_LOADERS is handled as regression by the rest of the file).
REGRESSION_LOADERS = [
                DiabetesLoader(**REGRESSION_KWARGS),
                CaliforniaHousingLoader(**REGRESSION_KWARGS), ##################
                ConcreteLoader(**REGRESSION_KWARGS),
                WineLoader(**REGRESSION_KWARGS),
                ConductivityLoader(**REGRESSION_KWARGS),
                AbaloneLoader(**REGRESSION_KWARGS), ####################
            ]

# Log-uniform distributions from which do_one_expe() draws one regularization
# strength per seed: the first for _RKSEstimator subclasses, the second for
# every other estimator class.
RKS_REGULARISATION_DIST = loguniform(1e-4, 1e4)
RW_REGULARISATION_DIST = loguniform(1e-8, 1e0)

# Experiment sizes:
#   N_RUNS    - number of seeds (rows) per estimator/dataset pair,
#   N_CV_ITER - `n_iter` passed to RKHSWeightingRandomSearchCV,
#   N_ITER    - value given to the learner grid ('n_neurons' or 'n_iter').
# NOTE(review): both branches currently assign the same values; the only
# difference in test mode is the '-test' suffix on FILENAME.
if TEST:
    N_RUNS = 100
    N_CV_ITER = 50
    N_ITER = 1000
    FILENAME += '-test'
else: 
    N_RUNS = 100
    N_CV_ITER = 50
    N_ITER = 1000

def do_one_expe(estimator_class, dataset_loader):
    """Run ``N_RUNS`` seeded random searches for one estimator class on one dataset.

    For each seed in ``range(N_RUNS)`` a regularization strength is drawn
    (from ``RKS_REGULARISATION_DIST`` when ``estimator_class`` is a subclass of
    ``_RKSEstimator``, from ``RW_REGULARISATION_DIST`` otherwise), a
    ``RKHSWeightingRandomSearchCV`` is fitted on the training split, and its
    ``best_estimator_`` is scored on both splits.

    Parameters
    ----------
    estimator_class : type
        Estimator class handed to ``RKHSWeightingRandomSearchCV``.
    dataset_loader
        Object exposing a ``name`` attribute and a ``load()`` method returning
        ``X_train, X_test, y_train, y_test``.

    Returns
    -------
    pandas.DataFrame
        One row per seed with the columns 'algorithm', 'dataset', 'T',
        'regularization', 'Nonzero coefficients', 'Training time',
        'Training MSE', 'Test MSE', 'Training error' and 'Test error'.
        The two error columns hold ``1 - accuracy`` for datasets listed in
        ``CLASSIFICATION_LOADERS`` and an empty string otherwise. Every row
        keeps index label 0 (callers re-index). The frame is empty when
        ``N_RUNS`` is 0.
    """
    print('Running experiment for estimator:', estimator_class.__name__, 'on dataset:', dataset_loader.name)
    X_train, X_test, y_train, y_test = dataset_loader.load()

    # Both facts are constant across seeds, so compute them once.
    is_rks = issubclass(estimator_class, _RKSEstimator)
    is_classification = dataset_loader.name in [d.name for d in CLASSIFICATION_LOADERS]

    # Collect the per-seed rows and concatenate once at the end; concatenating
    # inside the loop copies the whole accumulated frame on every iteration.
    run_frames = []
    for seed in range(N_RUNS):
        start = time()
        if is_rks:
            reg = RKS_REGULARISATION_DIST.rvs(random_state=seed)
            cv = RKHSWeightingRandomSearchCV(estimator_class,
                                        model_class=RWRelu,
                                        rng=seed,
                                        learner_param_grid={'n_neurons' : [N_ITER],
                                                            'regularization' : [reg],
                                                            'solver' : ['lasso']},
                                        model_param_grid=RWRELU_CV_PARAMS,
                                        verbose=False,
                                        n_iter=N_CV_ITER)
        else:
            reg = RW_REGULARISATION_DIST.rvs(random_state=seed)
            cv = RKHSWeightingRandomSearchCV(estimator_class,
                                        LassoLearner,
                                        RWRelu,
                                        rng=seed,
                                        learner_param_grid={'n_iter' : [N_ITER],
                                                            'regularization' : [reg]},
                                        model_param_grid=RWRELU_CV_PARAMS,
                                        verbose=False,
                                        n_iter=N_CV_ITER)
        cv.fit(X_train, y_train)
        estimator = cv.best_estimator_

        # RKS estimators expose get_n_centers() directly; the others expose it
        # on their `model` attribute.
        if is_rks:
            n_nonzero = estimator.get_n_centers()
        else:
            n_nonzero = estimator.model.get_n_centers()

        partial_results = {
            'algorithm': estimator_class.__name__,
            'dataset': dataset_loader.name,
            'T': N_ITER,
            'regularization': reg,
            'Nonzero coefficients': n_nonzero,
            'Training time': cv.refit_time_,
            'Training MSE': mean_squared_error(y_train, estimator.raw_output(X_train)),
            'Test MSE': mean_squared_error(y_test, estimator.raw_output(X_test)),
        }
        if is_classification:
            partial_results['Training error'] = 1 - accuracy_score(y_train, estimator.predict(X_train))
            partial_results['Test error'] = 1 - accuracy_score(y_test, estimator.predict(X_test))
        else:
            # Misclassification rates are not defined for regression datasets.
            partial_results['Training error'] = ''
            partial_results['Test error'] = ''
        run_frames.append(pd.DataFrame(partial_results, index=[0]))

        elapsed = time() - start
        print("Finished testing seed {} of {} for {} in {} seconds.".format(seed+1, N_RUNS, estimator_class.__name__, elapsed))

    if not run_frames:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame({})
    return pd.concat(run_frames)

def get_results_df_path():
    """Return the path of the results CSV: RESULTS_FOLDER + FILENAME + '.csv'.

    ``ensure_folder_exists`` is called on RESULTS_FOLDER before the path is
    built.
    """
    ensure_folder_exists(RESULTS_FOLDER)
    return ''.join((RESULTS_FOLDER, FILENAME, '.csv'))

def get_figure_path():
    """Return the extension-less figure path prefix: FIGURES_FOLDER + FILENAME.

    ``ensure_folder_exists`` is called on FIGURES_FOLDER before the prefix is
    built.
    """
    ensure_folder_exists(FIGURES_FOLDER)
    return ''.join((FIGURES_FOLDER, FILENAME))

def get_results_df():
    """Load the results CSV and return it with display names applied by clean_df()."""
    return clean_df(get_raw_df())

def get_raw_df():
    """Read the results CSV as written by expe(), using its first column as the index."""
    csv_path = get_results_df_path()
    return pd.read_csv(csv_path, index_col=0)

def clean_df(df: pd.DataFrame):
    """Replace class names by their display names, in place, and return ``df``.

    The replacement is applied to every cell of the frame that equals one of
    the class names below; the same object that was passed in is returned.
    """
    display_names = (
        ('SFGDLearner', 'SFGD'),
        ('LeastSquaresLearner', 'Least squares fit'),
        ('LassoLearner', 'Lasso fit'),
        ('RKSClassifier', 'RKS'),
        ('RKSRegressor', 'RKS'),
        ('RKHSWeightingClassifier', 'RKHS weighting'),
        ('RKHSWeightingRegressor', 'RKHS weighting'),
        ('OptimalStepsizeLearner', 'Optimal stepsize descent'),
    )
    for class_name, shown_as in display_names:
        df.replace(class_name, shown_as, inplace=True)
    return df

def expe():
    """Run every estimator/dataset pair and checkpoint the results CSV after each one.

    For each classification loader the RKHS weighting classifier is run first
    and the RKS classifier second; the regression loaders follow with the
    corresponding regressors. After every single run the accumulated frame is
    re-indexed and written to the path given by get_results_df_path().
    """
    schedule = (
        (CLASSIFICATION_LOADERS, (RKHSWeightingClassifier, RKSClassifier)),
        (REGRESSION_LOADERS, (RKHSWeightingRegressor, RKSRegressor)),
    )
    accumulated = pd.DataFrame({})
    for loaders, estimator_classes in schedule:
        for loader in loaders:
            for estimator_class in estimator_classes:
                run_df = do_one_expe(estimator_class, loader)
                accumulated = pd.concat((accumulated, run_df), ignore_index=True)
                # Write after every run so partial results are kept on disk.
                accumulated.to_csv(get_results_df_path())

def make_figures():
    """Draw one scatter plot per dataset plus a shared legend.

    Each figure plots the number of nonzero coefficients (log-scaled x axis)
    against 'Test error' for datasets listed in CLASSIFICATION_LOADERS and
    against 'Test MSE' for all others, with one colour/marker per algorithm.
    Figures are saved as ``<figure path>-<dtype>-<dataset>.pdf`` and shown
    unless ``--noshow`` was given. The legend handles of the last figure are
    passed to ``make_legend``.

    Rows with fewer than one nonzero coefficient are dropped first
    (presumably because they cannot be placed on a log-scaled axis). If no
    row is left, nothing is drawn and the function returns early.
    """
    results_df = get_results_df()
    results_df = results_df[results_df['Nonzero coefficients'] >= 1]
    datasets = results_df['dataset'].unique()
    if len(datasets) == 0:
        # Without this guard `handles`/`labels` would be unbound below.
        print('No results with at least one nonzero coefficient; no figures generated.')
        return
    algos = results_df['algorithm'].unique()
    colors = get_color_list(len(algos))

    # Loop-invariant lookups, computed once.
    classification_names = [d.name for d in CLASSIFICATION_LOADERS]
    figure_path = get_figure_path()

    for dataset in datasets:
        if dataset in classification_names:
            metric = 'Test error'
            dtype = 'classification'
        else:
            metric = 'Test MSE'
            dtype = 'regression'
        plt.figure(figsize=(3,3))
        plot_df = results_df.loc[results_df['dataset'] == dataset].copy()
        for i, algo in enumerate(algos):
            algo_df = plot_df[plot_df['algorithm'] == algo]
            plt.scatter(algo_df['Nonzero coefficients'], algo_df[metric], color=colors[i], marker=MARKERS[i], label=algo)
        plt.title(f'{dataset}')
        plt.xscale('log')
        # Lay out after switching to the log scale so the margins are computed
        # for the tick labels that are actually drawn.
        plt.tight_layout()
        plt.savefig(f'{figure_path}-{dtype}-{dataset}.pdf')
        # Grab the legend entries before the figure is closed.
        handles, labels = plt.gca().get_legend_handles_labels()
        if not args.noshow:
            plt.show()
        plt.close('all')
    make_legend(handles, labels, figure_path)
    
# Script entry point: optionally run the experiment, then build the figures.
if __name__ == '__main__':
    if RUN_EXPE:
        # Time the whole experiment; `start` is handed to save_run_time together
        # with the results folder and file name.
        start = time()
        expe()
        save_run_time(RESULTS_FOLDER, FILENAME, start)
    if GENERATE_FIGURES:
        # Reads the CSV written by expe() (or by an earlier run when --norun is used).
        make_figures()

