import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
from scipy.stats import uniform, loguniform
from time import time
import argparse

from sklearn.metrics import accuracy_score, mean_squared_error

from learners import LeastSquaresLearner
from rkhs_weightings import RKHSWeightingRandomSearchCV, RKHSWeightingClassifier, RKHSWeightingRegressor
from models import RWRelu
from loss import MSE
from rks import _RKSEstimator, RKSClassifier, RKSRegressor

from dataset_loaders import *
from expe_utils import ensure_folder_exists, get_color_list, save_run_time, make_legend
from expe_utils import RESULTS_FOLDER, FIGURES_FOLDER, MARKERS
from expe_params import RWRELU_CV_PARAMS, RKS_CV_PARAMS, RW_LEARNER_CV_PARAMS
import visuals # Must be at the end of the imports for some reason


# Command-line interface. NOTE: the arguments are parsed at module import time,
# so importing this file from another program also consumes sys.argv.
parser = argparse.ArgumentParser(description='Experiment for comparing RKHS weightings and random kitchen sinks for various numbers of random features.')
parser.add_argument('--norun', action='store_true', help='Do not run the experiment. Only generate the figures.')
parser.add_argument('--final', action='store_true', help='Do the long experiment.')
parser.add_argument('--noshow', action='store_true', help='Do not show the plots.')
args = parser.parse_args()

# Run-mode flags derived from the command line.
FINAL = args.final  # forwarded to every dataset loader as `final=`
TEST = not FINAL  # test mode is simply "not --final"
RUN_EXPE = not args.norun
GENERATE_FIGURES = True  # figures are always generated; there is no flag to disable them

# Seeded NumPy generator. It is not referenced anywhere else in the visible
# code; the experiment loop passes integer seeds to the search instead.
RNG = np.random.default_rng(0)

# Base name shared by the results CSV and the figure files.
FILENAME = 'few-features'

# NOTE(review): both branches below currently assign identical values to
# N_RUNS, N_CV_ITER and N_ITERS; the only effective difference is the '-test'
# suffix appended to FILENAME in test mode.
if TEST:
    N_RUNS = 10  # number of seeds, used as range(N_RUNS)
    N_CV_ITER = 50  # passed as n_iter to RKHSWeightingRandomSearchCV
    # 20 log-spaced values between 10 and 1000, truncated to integers by int().
    N_ITERS = [int(x) for x in np.logspace(math.log10(10), math.log10(1000), num=20)]
    FILENAME += '-test'

else: 
    N_RUNS = 10
    N_CV_ITER = 50
    N_ITERS = [int(x) for x in np.logspace(math.log10(10), math.log10(1000), num=20)]

# Datasets treated as classification tasks. The loaders are instantiated at
# import time. Their `name` attributes are also what do_one_expe() and
# make_figures() use to decide whether a dataset is a classification dataset.
# The trailing '####' markers are the author's own annotations; their meaning
# is not documented in this file.
CLASSIFICATION_LOADERS = [
                   BreastCancerLoader(final=FINAL), 
                   AdultsLoader(final=FINAL), ###################
                   MNISTLoader(final=FINAL, digits=[1,7]), 
                   SkinSegmentationLoader(final=FINAL),
                   BankMarketingLoader(final=FINAL),##################
                   MagicGammaTelescopeLoader(final=FINAL),
                   PhishingLoader(final=FINAL)
                   ]

# Keyword arguments shared by every regression loader.
# NOTE(review): presumably scale_x / scale_y standardise inputs and targets --
# confirm against dataset_loaders.
REGRESSION_KWARGS = {'final' : FINAL, 'scale_x' : True, 'scale_y' : True}
# Datasets treated as regression tasks (any dataset whose name is not in
# CLASSIFICATION_LOADERS is handled as regression downstream).
REGRESSION_LOADERS = [
                DiabetesLoader(**REGRESSION_KWARGS),
                CaliforniaHousingLoader(**REGRESSION_KWARGS), ##################
                ConcreteLoader(**REGRESSION_KWARGS),
                WineLoader(**REGRESSION_KWARGS),
                ConductivityLoader(**REGRESSION_KWARGS),
                AbaloneLoader(**REGRESSION_KWARGS), ####################
            ]



def do_one_expe(estimator_class, dataset_loader):
    """Run the random-search experiment for one estimator class on one dataset.

    For every seed in ``range(N_RUNS)`` and every budget in ``N_ITERS``, a
    ``RKHSWeightingRandomSearchCV`` is fitted on the training split and its
    best estimator is scored on both the training and the test split.

    Args:
        estimator_class: Estimator class handed to the search. Subclasses of
            ``_RKSEstimator`` receive the budget as the ``n_neurons`` learner
            parameter; any other class is combined with
            ``LeastSquaresLearner`` and receives it as ``n_iter``.
        dataset_loader: Object exposing a ``name`` attribute and a ``load()``
            method returning ``(X_train, X_test, y_train, y_test)``.

    Returns:
        A ``pandas.DataFrame`` with one row per (seed, budget) pair. Every row
        keeps index label 0, exactly as the single-row frames are built; an
        empty frame is returned when there is nothing to run.
    """
    print('Running experiment for estimator:', estimator_class.__name__, 'on dataset:', dataset_loader.name)
    X_train, X_test, y_train, y_test = dataset_loader.load()

    # These facts do not change across seeds or budgets, so compute them once.
    uses_rks = issubclass(estimator_class, _RKSEstimator)
    is_classification = dataset_loader.name in [d.name for d in CLASSIFICATION_LOADERS]

    # Collect single-row frames and concatenate once at the end: growing a
    # DataFrame with pd.concat inside the loop re-copies all previous rows on
    # every iteration.
    row_frames = []
    for seed in range(N_RUNS):
        start = time()
        for n_iter in N_ITERS:
            if uses_rks:
                cv = RKHSWeightingRandomSearchCV(estimator_class,
                                                 model_class=RWRelu,
                                                 rng=seed,
                                                 learner_param_grid={**RKS_CV_PARAMS,
                                                                     'n_neurons': [n_iter]},
                                                 model_param_grid=RWRELU_CV_PARAMS,
                                                 verbose=False,
                                                 n_iter=N_CV_ITER)
            else:
                cv = RKHSWeightingRandomSearchCV(estimator_class,
                                                 LeastSquaresLearner,
                                                 RWRelu,
                                                 rng=seed,
                                                 learner_param_grid={**RW_LEARNER_CV_PARAMS,
                                                                     'n_iter': [n_iter]},
                                                 model_param_grid=RWRELU_CV_PARAMS,
                                                 verbose=False,
                                                 n_iter=N_CV_ITER)
            cv.fit(X_train, y_train)
            estimator = cv.best_estimator_

            partial_results = {
                'algorithm': estimator_class.__name__,
                'dataset': dataset_loader.name,
                'T': n_iter,
            }
            # Record the hyper-parameters selected by the search.
            partial_results.update(cv.best_learner_params_)
            partial_results.update(cv.best_model_params_)
            partial_results['Training time'] = cv.refit_time_
            # MSE is computed on the raw outputs for every task, including
            # classification.
            partial_results['Training MSE'] = mean_squared_error(y_train, estimator.raw_output(X_train))
            partial_results['Test MSE'] = mean_squared_error(y_test, estimator.raw_output(X_test))
            if is_classification:
                partial_results['Training error'] = 1 - accuracy_score(y_train, estimator.predict(X_train))
                partial_results['Test error'] = 1 - accuracy_score(y_test, estimator.predict(X_test))
            else:
                # Misclassification error is left blank for regression datasets.
                partial_results['Training error'] = ''
                partial_results['Test error'] = ''
            row_frames.append(pd.DataFrame(partial_results, index=[0]))
        elapsed = time() - start
        print("Finished testing seed {} of {} for {} in {} seconds.".format(seed+1, N_RUNS, estimator_class.__name__, elapsed))

    # pd.concat raises on an empty list, so keep returning an empty frame
    # when no run was performed.
    if not row_frames:
        return pd.DataFrame({})
    return pd.concat(row_frames)

def get_results_df_path():
    """Return the path of the results CSV, creating the results folder first."""
    ensure_folder_exists(RESULTS_FOLDER)
    return f'{RESULTS_FOLDER}{FILENAME}.csv'

def get_figure_path():
    """Return the extension-less figure path prefix, creating the figures folder first."""
    ensure_folder_exists(FIGURES_FOLDER)
    return f'{FIGURES_FOLDER}{FILENAME}'

def get_results_df():
    """Load the saved results and return them with display-friendly names."""
    return clean_df(get_raw_df())

def get_raw_df():
    """Read the results CSV unchanged, using its first column as the index."""
    csv_path = get_results_df_path()
    return pd.read_csv(csv_path, index_col=0)

def clean_df(df: pd.DataFrame):
    """Swap internal class names for display labels, in place.

    Every cell of ``df`` equal to one of the known learner or estimator class
    names is replaced by its human-readable label. The frame is modified in
    place and the same object is returned.
    """
    # (class name, display label) pairs, applied in this order.
    display_labels = (
        ('SFGDLearner', 'SFGD'),
        ('LeastSquaresLearner', 'Least squares fit'),
        ('LassoLearner', 'Lasso fit'),
        ('RKSClassifier', 'RKS'),
        ('RKSRegressor', 'RKS'),
        ('RKHSWeightingClassifier', 'RKHS weighting'),
        ('RKHSWeightingRegressor', 'RKHS weighting'),
        ('OptimalStepsizeLearner', 'Optimal stepsize descent'),
    )
    for class_name, label in display_labels:
        df.replace(class_name, label, inplace=True)
    return df

def expe():
    """Run every experiment and checkpoint the accumulated results to CSV.

    Classification datasets are processed first, then regression datasets.
    On each dataset the RKHS weighting estimator runs before its RKS
    counterpart. The results CSV is rewritten after every single run, which
    leaves the results gathered so far on disk.
    """
    # (loaders, estimator classes to run on each loader, in order)
    plan = (
        (CLASSIFICATION_LOADERS, (RKHSWeightingClassifier, RKSClassifier)),
        (REGRESSION_LOADERS, (RKHSWeightingRegressor, RKSRegressor)),
    )

    all_results = pd.DataFrame({})
    for loaders, estimator_classes in plan:
        for loader in loaders:
            for estimator_class in estimator_classes:
                run_results = do_one_expe(estimator_class, loader)
                all_results = pd.concat((all_results, run_results), ignore_index=True)
                all_results.to_csv(get_results_df_path())


def make_figures():
    """Plot, for each dataset, the mean test metric against T for every algorithm.

    The saved results are loaded through ``get_results_df()``. For each
    dataset one figure is drawn with one curve per algorithm (mean over runs,
    with the standard deviation as error bars) on a logarithmic x axis, and
    saved as ``<figure path>-<classification|regression>-<dataset>.pdf``.
    Classification datasets use the 'Test error' column, all others use
    'Test MSE'. Figures are also shown on screen unless ``--noshow`` was
    given. A legend is finally produced through ``make_legend`` from the
    handles of the last figure.

    When the results contain no rows, nothing is plotted and no legend is
    produced.
    """
    results_df = get_results_df()
    algos = results_df['algorithm'].unique()
    datasets = results_df['dataset'].unique()
    if len(datasets) == 0:
        # Without any dataset the loop below never runs, so there would be no
        # legend handles to pass on; bail out instead of failing on an
        # unbound name.
        print('No results to plot; skipping figure generation.')
        return

    colors = get_color_list(len(algos))
    # Invariant across datasets, so built once.
    classification_names = [d.name for d in CLASSIFICATION_LOADERS]
    group_columns = ['algorithm', 'T', 'dataset']
    handles, labels = [], []

    for dataset in datasets:
        if dataset in classification_names:
            metric = 'Test error'
            # metric = 'Training error'
            dtype = 'classification'
        else:
            metric = 'Test MSE'
            # metric = 'Training MSE'
            dtype = 'regression'
        mask = results_df['dataset'] == dataset

        # Only keep the grouping columns and the metric, so that mean and std
        # are computed on a numeric column only.
        partial_df = results_df[mask][group_columns + [metric]]
        grouped = partial_df.groupby(by=group_columns, as_index=False)
        mean_df = grouped.mean()
        std_df = grouped.std()

        plt.figure(figsize=(3,3))
        for i, algo in enumerate(algos):
            mean_algo_df = mean_df.loc[(mean_df['algorithm'] == algo) & (mean_df['dataset'] == dataset)]
            std_algo_df = std_df.loc[(std_df['algorithm'] == algo) & (std_df['dataset'] == dataset)]
            x = mean_algo_df['T']
            y = mean_algo_df[metric]
            error = std_algo_df[metric]
            # Only the plain line carries the label, so each algorithm
            # appears once in the legend.
            plt.plot(x, y, label=algo, color=colors[i], marker=MARKERS[i])
            plt.errorbar(x, y, yerr=error, color=colors[i], marker=MARKERS[i], alpha=0.5)
        # plt.xlabel('T')
        # plt.ylabel(metric)
        plt.xscale('log')
        plt.title(dataset)
        # plt.legend()
        plt.tight_layout()
        plt.savefig(f'{get_figure_path()}-{dtype}-{dataset}.pdf')
        # Grab the legend entries before the figure is shown and closed.
        handles, labels = plt.gca().get_legend_handles_labels()
        if not args.noshow:
            plt.show()
        plt.close('all')

    make_legend(handles, labels, get_figure_path())

    
if __name__ == '__main__':
    # Run the experiment unless --norun was given, recording when it started
    # so that save_run_time can log the run time.
    if RUN_EXPE:
        run_started_at = time()
        expe()
        save_run_time(RESULTS_FOLDER, FILENAME, run_started_at)
    # Figures are built from the results CSV on disk.
    if GENERATE_FIGURES:
        make_figures()

