import argparse
import os
from time import time

import numpy as np
import pandas as pd
from scipy.stats import uniform, loguniform
from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.svm import SVC, SVR
from sklearn.base import BaseEstimator

from dataset_loaders import *
from expe_utils import RESULTS_FOLDER, TABLES_FOLDER, save_run_time
from expe_params import RWSIGN_CV_PARAMS, RWRELU_CV_PARAMS, RWEXPSIGN_CV_PARAMS, RWEXPRELU_CV_PARAMS, RWSTUMPS_CV_PARAMS, RW_LEARNER_CV_PARAMS, RKS_CV_PARAMS
from learners import *
from base_model import *
from models import RWExpRelu, RWSign, RWRelu, RWExpSign, RWStumps
from rks import RKSClassifier, RKSRegressor
from rkhs_weightings import RKHSWeightingRandomSearchCV, RKHSWeightingClassifier, RKHSWeightingRegressor

# ---------------------------------------------------------------------------
# Command-line interface and experiment-wide settings.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Experiment for comparing RKHS weightings and random kitchen sinks to other algorithms.'
)
parser.add_argument(
    '--norun',
    action='store_true',
    help='Do not run the experiment. Only generate the figures.',
)
parser.add_argument(
    '--final',
    action='store_true',
    help='Do the long experiment.',
)
args = parser.parse_args()

# Mode flags derived from the command line.
FINAL = args.final
TEST = not FINAL
RUN_EXPE = not args.norun
# Switches for the two families of experiments run by launch_experiments.
RUN_SFGD = True
RUN_OTHER = True

GENERATE_TABLES = True
PRINT_BEST_PARAMS = False
FILENAME = 'sota'
RNG = np.random.default_rng(0)

# Search spaces handed to RandomizedSearchCV for the scikit-learn baselines.
SVM_PARAMS = {
    'C': loguniform(0.001, 1000),
    'gamma': loguniform(0.001, 1000),
}
AB_PARAMS = {'n_estimators': [10, 25, 50, 100, 150, 200, 250, 500]}

# Test and final modes currently share the same budgets; the only difference
# is the '-test' suffix added to the output file name in test mode.
N_RUNS = 10
RW_LEARNER_CV_PARAMS['n_iter'] = [2000]
RKS_CV_PARAMS['n_neurons'] = [2000]
N_CV_ITER = 50
if TEST:
    FILENAME += '-test'

def get_updated_filename(task='classification'):
    """Return the base file name (no folder, no extension) used for `task`.

    The name is the module-level FILENAME followed by '-classification' or
    '-regression'.

    Args:
        task: either 'classification' or 'regression'.

    Returns:
        The task-specific base file name.

    Raises:
        ValueError: if `task` is not one of the two supported tasks. The
            function used to fall through and return None, which only
            surfaced later as a TypeError while building a path.
    """
    if task not in ('classification', 'regression'):
        raise ValueError(
            "Unknown task {!r}: expected 'classification' or 'regression'.".format(task)
        )
    return FILENAME + '-' + task
    
def get_table_path(filename):
    """Return the path of the LaTeX table `filename` inside TABLES_FOLDER."""
    return ''.join((TABLES_FOLDER, filename, '.tex'))

def get_table_with_std_path(filename):
    """Return the path of the '-with-std' LaTeX table for `filename`."""
    return ''.join((TABLES_FOLDER, filename, '-with-std.tex'))

# Dataset loaders and, for each dataset, the model classes to try.
#
# The three classification lists (and likewise the three regression lists)
# are parallel: get_n_total_fits and launch_experiments zip them together, so
# entry i of each *_INSTANTIATIONS list belongs to entry i of the matching
# *_LOADERS list. Keep them the same length and in the same order; zip would
# silently drop trailing entries of a longer list.

# Classification datasets.
CLASSIFICATION_LOADERS = [
                   AdultsLoader(final=FINAL), 
                   BreastCancerLoader(final=FINAL), 
                   BankMarketingLoader(final=FINAL),
                   MNISTLoader(final=FINAL, digits=[1,7]), 
                   PhishingLoader(final=FINAL),
                   SkinSegmentationLoader(final=FINAL),
                   MagicGammaTelescopeLoader(final=FINAL)
                   ]

# Model classes passed to run_one_sfgd_experiment, one list per dataset above.
RW_CLASSIFICATION_INSTANTIATIONS = [
    [RWExpRelu, RWRelu, RWStumps], # Adults
    [RWSign, RWExpSign], # Breast Cancer
    [RWExpRelu, RWRelu, RWStumps], # Bank Marketing
    [RWExpRelu, RWRelu, RWStumps], # MNIST
    [RWExpRelu, RWRelu],  # Phishing
    [RWExpRelu, RWRelu], # Skin Segmentation
    [RWExpRelu, RWRelu] # Magic Gamma Telescope
]

# Model classes passed to run_one_rks_experiment, one list per dataset above.
RKS_CLASSIFICATION_INSTANTIATIONS = [
    [RWRelu, RWStumps], # Adults
    [RWSign], # Breast Cancer
    [RWRelu, RWStumps], # Bank Marketing
    [RWRelu, RWStumps], # MNIST
    [RWRelu],  # Phishing
    [RWRelu], # Skin Segmentation
    [RWRelu] # Magic Gamma Telescope
]

# Regression datasets; every loader receives the same keyword arguments.
REGRESSION_KWARGS = {'final' : FINAL, 'scale_x' : True, 'scale_y' : True}
REGRESSION_LOADERS = [
                AbaloneLoader(**REGRESSION_KWARGS),
                CaliforniaHousingLoader(**REGRESSION_KWARGS),
                ConcreteLoader(**REGRESSION_KWARGS),
                ConductivityLoader(**REGRESSION_KWARGS),
                DiabetesLoader(**REGRESSION_KWARGS),
                WineLoader(**REGRESSION_KWARGS)
            ]

# Model classes passed to run_one_sfgd_experiment, one list per dataset above.
RW_REGRESSION_INSTANTIATIONS = [
    [RWExpRelu, RWRelu, RWSign], # Abalone
    [RWExpRelu, RWRelu, RWStumps], # Housing
    [RWExpRelu, RWRelu, RWStumps], # Concrete
    [RWExpRelu, RWRelu, RWStumps], # Conductivity
    [RWExpRelu, RWRelu, RWStumps], # Diabetes
    [RWSign, RWExpSign] # Wine
]

# Model classes passed to run_one_rks_experiment, one list per dataset above.
RKS_REGRESSION_INSTANTIATIONS = [
    [RWRelu, RWSign], # Abalone
    [RWRelu, RWStumps], # Housing
    [RWRelu, RWStumps], # Concrete
    [RWRelu, RWStumps], # Conductivity
    [RWRelu, RWStumps], # Diabetes
    [RWSign] # Wine
]

class Algorithm:
    """Bundle a baseline estimator with its search space and display name.

    Attributes:
        clf: the estimator handed to RandomizedSearchCV in
            run_other_experiment.
        params: the parameter distributions searched for that estimator.
        name: label written to the 'Algorithm' results column; the name
            'SVM' is treated specially by launch_experiments.
    """

    def __init__(self, clf: BaseEstimator, params, name):
        self.clf, self.params, self.name = clf, params, name

class TimeTracker:
    """Print progress and a remaining-time estimate after each finished fit.

    The estimate assumes every remaining fit takes the average duration of
    the fits completed so far.
    """

    def __init__(self, n_total_fits):
        self.start = time()
        self.n_completed_fits = 0
        self.n_total_fits = n_total_fits

    def update(self):
        """Record one more completed fit and print elapsed/remaining hours."""
        self.n_completed_fits += 1
        elapsed = time() - self.start
        done, total = self.n_completed_fits, self.n_total_fits
        # Average time per completed fit, times the number of fits left.
        remaining = elapsed / done * (total - done)
        print("Fit {} of {} done. Elapsed time : {} hours.".format(done, total, elapsed/3600))
        print("Estimated time remaining : {} hours.".format(remaining/3600))


# Learner classes combined with every RKHS-weighting model class.
LEARNERS = [LeastSquaresLearner]

# Cross-validation parameters of each model class, keyed by the class name
# (the experiment functions look them up through `model_class.__name__`).
INSTANTIATIONS_PARAMS_DICT = {
    'RWSign': RWSIGN_CV_PARAMS,
    'RWExpRelu': RWEXPRELU_CV_PARAMS,
    'RWRelu': RWRELU_CV_PARAMS,
    'RWExpSign': RWEXPSIGN_CV_PARAMS,
    'RWStumps': RWSTUMPS_CV_PARAMS,
}

# scikit-learn baselines evaluated by run_other_experiment.
CLASSIFICATION_ALGOS = [
    Algorithm(AdaBoostClassifier(), AB_PARAMS, "AdaBoost"),
    Algorithm(SVC(), SVM_PARAMS, "SVM"),
]

REGRESSION_ALGOS = [
    Algorithm(AdaBoostRegressor(), AB_PARAMS, "AdaBoost"),
    Algorithm(SVR(), SVM_PARAMS, "SVM"),
]

def get_n_total_fits(task):
    """Return how many experiments launch_experiments will run for `task`.

    The count mirrors the loops of launch_experiments so that the
    TimeTracker estimate stays accurate:

    * when RUN_SFGD is set, each dataset contributes one fit per run for
      every (learner, RKHS-weighting model class) pair and one per run for
      every RKS model class;
    * when RUN_OTHER is set, each baseline contributes one fit per run,
      except an algorithm named 'SVM', which is only run on the first run.

    Args:
        task: 'classification' or 'regression'.

    Returns:
        The total number of fits, or 0 for any other task value.
    """
    if task == 'classification':
        loaders = CLASSIFICATION_LOADERS
        rw_lists = RW_CLASSIFICATION_INSTANTIATIONS
        rks_lists = RKS_CLASSIFICATION_INSTANTIATIONS
        algorithms = CLASSIFICATION_ALGOS
    elif task == 'regression':
        loaders = REGRESSION_LOADERS
        rw_lists = RW_REGRESSION_INSTANTIATIONS
        rks_lists = RKS_REGRESSION_INSTANTIATIONS
        algorithms = REGRESSION_ALGOS
    else:
        return 0

    n_runs = max(N_RUNS, 0)
    # SVM is fitted on the first run only, so it counts once (if any run happens).
    svm_runs = min(n_runs, 1)

    n_total_fits = 0
    # zip is used exactly as in launch_experiments (shortest list wins).
    for _, rw_instantiations, rks_instantiations in zip(loaders, rw_lists, rks_lists):
        if RUN_SFGD:
            n_total_fits += len(rw_instantiations) * n_runs * len(LEARNERS)  # RKHS Weightings
            n_total_fits += len(rks_instantiations) * n_runs  # RKS
        if RUN_OTHER:
            for algo in algorithms:
                n_total_fits += svm_runs if algo.name == 'SVM' else n_runs
    return n_total_fits

def evaluate_model(model, loader, task):
    """Compute metrics of a fitted `model` on the splits given by `loader`.

    The data is obtained from `loader.load()`, which is unpacked as
    (X_train, X_test, y_train, y_test).

    Returns a dict whose values are one-element lists:

    * 'Inference time': seconds spent predicting on the train and test
      sets together;
    * for task 'classification': 'Train error' and 'Test error'
      (1 - accuracy, rounded to 3 decimals);
    * for task 'regression': train/test MSE and train/test R^2.

    Any other task value yields only the inference time.
    """
    X_train, X_test, y_train, y_test = loader.load()

    # Both prediction passes are timed as a single measurement.
    tic = time()
    train_pred = model.predict(X_train)
    test_pred = model.predict(X_test)
    results = {'Inference time': [time() - tic]}

    if task == 'classification':
        train_error = 1 - accuracy_score(y_train, train_pred)
        test_error = 1 - accuracy_score(y_test, test_pred)
        results['Train error'] = [round(train_error, 3)]
        results['Test error'] = [round(test_error, 3)]
    elif task == 'regression':
        results['Train MSE'] = [mean_squared_error(y_train, train_pred)]
        results['Test MSE'] = [mean_squared_error(y_test, test_pred)]
        results['Train $R^2$'] = [r2_score(y_train, train_pred)]
        results['Test $R^2$'] = [r2_score(y_test, test_pred)]
    return results

def save_results_to_csv(results, path):
    """Append the rows described by `results` to the CSV file at `path`.

    `results` is turned into a DataFrame. If the file does not exist yet it
    is created; otherwise the new rows are concatenated after the stored
    ones (columns sorted), the row index is renumbered from 0 and the file
    is rewritten.
    """
    new_rows = pd.DataFrame.from_dict(results)
    if not os.path.exists(path):
        new_rows.to_csv(path)
        return

    stored_rows = pd.read_csv(path, index_col=0)
    combined = pd.concat((stored_rows, new_rows), sort=True)
    # Renumber rows so the index stays unique across appended experiments.
    combined.index = pd.Index(np.arange(len(combined.index)))
    combined.to_csv(path)

def run_one_sfgd_experiment(dataset_loader, learner_class, model_class, path, task):
    """Run one RKHS-weighting experiment and append its results to `path`.

    A RKHSWeightingRandomSearchCV (5 folds, N_CV_ITER iterations) is fitted
    on the training split of `dataset_loader`, searching over
    RW_LEARNER_CV_PARAMS and the parameters registered for `model_class` in
    INSTANTIATIONS_PARAMS_DICT. The best estimator is then scored with
    evaluate_model, and the metrics, the refit time and the selected
    hyperparameters are saved as one CSV row.

    `task` ('classification' or 'regression') selects the estimator class.
    """
    if task == 'classification':
        estimator_class = RKHSWeightingClassifier
    elif task == 'regression':
        estimator_class = RKHSWeightingRegressor

    instantiation_name = model_class.__name__
    results = {
        'Dataset': [dataset_loader.name],
        'Instantiation': [instantiation_name],
        'Algorithm': ['RKHS Weighting'],
    }

    X_train, _, y_train, _ = dataset_loader.load()

    search = RKHSWeightingRandomSearchCV(
        estimator_class,
        learner_class,
        model_class,
        RW_LEARNER_CV_PARAMS,
        INSTANTIATIONS_PARAMS_DICT[instantiation_name],
        folds=5,
        n_iter=N_CV_ITER,
        rng=RNG,
        verbose=False,
    )
    search.fit(X_train, y_train)
    results['Train time (s)'] = search.refit_time_

    best_model = search.best_estimator_
    # Record the selected hyperparameters as extra columns.
    for param_name, param_value in search.best_learner_params_.items():
        results[param_name] = param_value
    for param_name, param_value in search.best_model_params_.items():
        results[param_name] = param_value

    results.update(evaluate_model(best_model, dataset_loader, task=task))
    save_results_to_csv(results, path)

def run_one_rks_experiment(dataset_loader, model_class, path, task):
    """Run one random-kitchen-sinks experiment and append its results to `path`.

    A RKHSWeightingRandomSearchCV (5 folds, N_CV_ITER iterations) is fitted
    on the training split of `dataset_loader` with RKS_CV_PARAMS as learner
    grid and the parameters registered for `model_class` in
    INSTANTIATIONS_PARAMS_DICT as model grid. The best estimator is scored
    with evaluate_model and the outcome is saved as one CSV row.

    `task` ('classification' or 'regression') selects the estimator class.
    """
    if task == 'classification':
        estimator_class = RKSClassifier
    elif task == 'regression':
        estimator_class = RKSRegressor

    instantiation_name = model_class.__name__
    results = {
        'Dataset': [dataset_loader.name],
        'Instantiation': [instantiation_name],
        'Algorithm': ['RKS'],
    }

    X_train, _, y_train, _ = dataset_loader.load()

    # Random Kitchen Sinks
    search = RKHSWeightingRandomSearchCV(
        estimator_class,
        model_class=model_class,
        learner_param_grid=RKS_CV_PARAMS,
        model_param_grid=INSTANTIATIONS_PARAMS_DICT[instantiation_name],
        folds=5,
        n_iter=N_CV_ITER,
        rng=RNG,
        verbose=False,
    )
    search.fit(X_train, y_train)
    best_model = search.best_estimator_
    results['Train time (s)'] = search.refit_time_

    results.update(evaluate_model(best_model, dataset_loader, task=task))
    save_results_to_csv(results, path)
    return


def run_other_experiment(dataset_loader, algorithm, path, task):
    """Run one experiment for a scikit-learn baseline and append it to `path`.

    The hyperparameters in `algorithm.params` are tuned with a 5-fold
    RandomizedSearchCV (N_CV_ITER iterations, fixed random_state) on the
    training split of `dataset_loader`. The refitted best estimator is
    scored with evaluate_model, and the metrics, the refit time and the
    selected hyperparameters are saved as one CSV row.
    """
    results = {
        'Dataset': [dataset_loader.name],
        'Instantiation': [''],
        'Algorithm': [algorithm.name],
    }

    X_train, _, y_train, _ = dataset_loader.load()
    search = RandomizedSearchCV(
        algorithm.clf,
        algorithm.params,
        cv=5,
        verbose=0,
        n_jobs=-1,
        n_iter=N_CV_ITER,
        random_state=0,
    )
    search.fit(X_train, y_train)
    results['Train time (s)'] = search.refit_time_

    best_model = search.best_estimator_
    # Record the selected hyperparameters as extra columns.
    for param_name, param_value in search.best_params_.items():
        results[param_name] = param_value

    results.update(evaluate_model(best_model, dataset_loader, task=task))
    save_results_to_csv(results, path)
    return

def launch_experiments(task):
    """Run every experiment of `task`, writing the results to a fresh CSV.

    Any existing results file for the task is deleted first. Then, for each
    dataset and each of the N_RUNS runs:

    * if RUN_SFGD is set, one RKHS-weighting experiment per
      (learner, model class) pair and one RKS experiment per RKS model class;
    * if RUN_OTHER is set, one experiment per baseline algorithm, except
      that the algorithm named 'SVM' is only run on the first run.

    Progress is reported through a TimeTracker after each experiment.
    """
    results_path = get_results_path(task)
    if os.path.exists(results_path):
        os.remove(results_path)
    timetracker = TimeTracker(get_n_total_fits(task))

    if task == 'regression':
        dataset_loaders, rw_instantiations_list, rks_instantiations_list, algorithms = (
            REGRESSION_LOADERS,
            RW_REGRESSION_INSTANTIATIONS,
            RKS_REGRESSION_INSTANTIATIONS,
            REGRESSION_ALGOS,
        )
    elif task == 'classification':
        dataset_loaders, rw_instantiations_list, rks_instantiations_list, algorithms = (
            CLASSIFICATION_LOADERS,
            RW_CLASSIFICATION_INSTANTIATIONS,
            RKS_CLASSIFICATION_INSTANTIATIONS,
            CLASSIFICATION_ALGOS,
        )

    per_dataset = zip(dataset_loaders, rw_instantiations_list, rks_instantiations_list)
    for loader, rw_instantiations, rks_instantiations in per_dataset:
        print("Running experiments for dataset " + loader.name)
        for n_run in range(N_RUNS):
            if RUN_SFGD:
                for learner_class in LEARNERS:
                    for model_class in rw_instantiations:
                        run_one_sfgd_experiment(loader, learner_class, model_class, results_path, task)
                        timetracker.update()
                for model_class in rks_instantiations:
                    run_one_rks_experiment(loader, model_class, results_path, task)
                    timetracker.update()
            if RUN_OTHER:
                for algo in algorithms:
                    # SVM is only fitted during the first run.
                    if algo.name == 'SVM' and n_run != 0:
                        continue
                    run_other_experiment(loader, algo, results_path, task)
                    timetracker.update()

def ensure_folder_exists(path):
    """Create the folder `path`, including missing parents, if nothing exists there.

    Nothing is done when `path` already exists. `os.makedirs` is used
    instead of `os.mkdir` so that nested folders can be created, and
    `exist_ok=True` keeps the call from failing if the folder appears
    between the existence check and the creation.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def get_results_path(task='classification'):
    """Return the CSV results path for `task`, creating RESULTS_FOLDER if needed."""
    ensure_folder_exists(RESULTS_FOLDER)
    csv_name = get_updated_filename(task) + '.csv'
    return RESULTS_FOLDER + csv_name

def get_df(task):
    """Load the results of `task` and keep only the columns used in tables.

    Classification keeps the train/test errors and regression keeps the
    train/test R^2, both alongside the identifying columns and the timing
    columns. For any other task value all columns are kept. The result is
    passed through clean_df.
    """
    id_columns = ['Algorithm', 'Dataset', 'Instantiation']
    timing_columns = ['Train time (s)', 'Inference time']

    df = get_raw_df(get_results_path(task))
    if task == 'classification':
        df = df[id_columns + ['Train error', 'Test error'] + timing_columns]
    elif task == 'regression':
        df = df[id_columns + ['Train $R^2$', 'Test $R^2$'] + timing_columns]
    return clean_df(df)

def get_raw_df(path):
    """Read the CSV file at `path`, using its first column as the row index."""
    raw_df = pd.read_csv(path, index_col=0)
    return raw_df

def clean_df(df):
    """Return a copy of `df` with values rewritten for display in tables.

    The replacements are applied in order: empty strings become 'N/A',
    learner and instantiation names are shortened, NaN values and 'nan'
    substrings become empty strings, and long dataset names are shortened.
    """
    # (to_replace, value, regex) triples, applied sequentially; order matters
    # because '' is mapped to 'N/A' before NaN is mapped to ''.
    replacements = (
        ('', 'N/A', False),
        ('LassoLearner', 'Lasso fit', True),
        ('LeastSquaresLearner', 'LS fit', True),
        ('Instantiation1', 'I1', True),
        ('Instantiation3', 'I2', True),
        (np.nan, '', True),
        ('nan', '', True),
        ('bank marketing', 'marketing', True),
        ('breast cancer', 'cancer', True),
        ('magic gamma telescope', 'telescope', True),
        ('california housing', 'housing', True),
        ('skin segmentation', 'skin', True),
    )
    for to_replace, value, use_regex in replacements:
        df = df.replace(to_replace, value, regex=use_regex)
    return df

def get_table_from_df(df):
    """Return the per-(Dataset, Algorithm, Instantiation) means, rounded to 3 decimals.

    Also sets the global pandas option 'display.precision' to 3.
    """
    group_keys = ['Dataset', 'Algorithm', 'Instantiation']
    averaged = df.groupby(group_keys).mean().round(3)
    pd.set_option("display.precision", 3)
    return averaged

def get_std_table_from_df(df):
    """Return the per-(Dataset, Algorithm, Instantiation) standard deviations, rounded to 3 decimals.

    Also sets the global pandas option 'display.precision' to 3.
    """
    group_keys = ['Dataset', 'Algorithm', 'Instantiation']
    deviations = df.groupby(group_keys).std().round(3)
    pd.set_option("display.precision", 3)
    return deviations

def clean_table(table):
    """Return `table` with NaN values and 'nan' substrings replaced by ''."""
    for missing_marker in (np.nan, 'nan'):
        table = table.replace(missing_marker, '', regex=True)
    return table

def clean_column(c, rnd, max_value):
    """Format a numeric Series as strings, abbreviating very large values.

    Values below `max_value` are rounded to `rnd` decimals and converted to
    strings. Values greater than or equal to `max_value` are replaced by the
    LaTeX snippet '>10\\textsuperscript{k}', where k is the floor of their
    base-10 logarithm.

    Args:
        c: numeric pandas Series (any index).
        rnd: number of decimals kept for the regular values.
        max_value: threshold from which values are abbreviated; expected to
            be positive, since the logarithm of the selected values is taken.

    Returns:
        A new Series of strings with the same index as `c`.
    """
    indices = c >= max_value
    formatted = c.round(rnd).astype(str)
    if indices.any():
        exponents = np.floor(np.log10(c[indices])).astype(int)
        # Build the labels positionally. The previous loop indexed the
        # filtered Series with 0..k-1, which pandas treats as index labels,
        # so it failed unless the selected rows happened to carry them.
        labels = ['>10' + '\\textsuperscript{' + str(e) + '}' for e in exponents]
        formatted.loc[indices] = labels
    return formatted

def get_table_with_std_from_table_and_std(table, std_table):
    """Combine means and standard deviations into 'mean ± std' strings.

    Cells equal to ' ± ' are blanked, and a trailing ' ± nan' (no standard
    deviation available) is removed so that only the mean remains.
    """
    means_as_text = table.astype(str)
    stds_as_text = std_table.astype(str)
    combined = means_as_text + ' ± ' + stds_as_text
    combined = combined.replace(' ± ', '', regex=False)
    return combined.replace(' ± nan', '', regex=True)

def save_table_to_latex(table, path):
    """Write `table` to `path` as a LaTeX tabular.

    Columns are right-aligned, cell contents are not escaped (they may hold
    LaTeX markup) and floats are printed with 3 decimals.
    """
    latex_options = {
        'column_format': 'rrrrrrrrr',
        'escape': False,
        'float_format': "%.3f",
    }
    table.to_latex(path, **latex_options)

def generate_and_save_tables(task):
    """Build and save the LaTeX tables of `task` from its results file.

    Two files are written: the table of means, and a table where each mean
    is followed by its standard deviation.
    """
    df = get_df(task)
    base_name = get_updated_filename(task)
    mean_table_path = get_table_path(base_name)
    mean_std_table_path = get_table_with_std_path(base_name)

    mean_table = get_table_from_df(df)
    std_table = get_std_table_from_df(df)
    mean_std_table = get_table_with_std_from_table_and_std(mean_table, std_table)

    save_table_to_latex(mean_table, mean_table_path)
    save_table_to_latex(mean_std_table, mean_std_table_path)
    print("tables saved to ./tables/")

def print_best_params(path):
    """Print the selected 'max_theta' and 'regularization' of every stored result."""
    shown_columns = ['Dataset', 'Algorithm', 'max_theta', 'regularization']
    print(get_raw_df(path)[shown_columns])

def run_expe(task='classification'):
    """Run the pipeline for `task` according to the module-level switches.

    * RUN_EXPE: launch the experiments and record the total run time;
    * GENERATE_TABLES: build the LaTeX tables from the results file;
    * PRINT_BEST_PARAMS: print the selected hyperparameters.
    """
    results_path = get_results_path(task)

    if RUN_EXPE:
        started_at = time()
        launch_experiments(task)
        save_run_time(RESULTS_FOLDER, get_updated_filename(task), started_at)
        print("results saved in " + results_path)

    if GENERATE_TABLES:
        generate_and_save_tables(task)

    if PRINT_BEST_PARAMS:
        print_best_params(results_path)

if __name__ == '__main__':
    # Regression experiments run first, then classification.
    for task_name in ('regression', 'classification'):
        run_expe(task_name)
