import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
from time import time
import argparse

from sklearn.metrics import accuracy_score, mean_squared_error

from learners import SFGDLearner, LeastSquaresLearner, LassoLearner, BaseLearner
from rkhs_weightings import RKHSWeightingGridSearchCV, RKHSWeightingClassifier
from models import RWSign
from loss import MSE
from rks import _RKSEstimator, RKSClassifier

from dataset_loaders import MNISTLoader
from expe_utils import ensure_folder_exists, get_color_list, save_run_time
from expe_utils import RESULTS_FOLDER, FIGURES_FOLDER, MARKERS
import visuals # Must be at the end of the imports for some reason

# Command-line flags controlling which stages of the script are executed.
parser = argparse.ArgumentParser(
    description='Basic experiment for comparing various RKHS weighting learning algorithm.'
)
parser.add_argument(
    '--norun',
    action='store_true',
    help='Do not run the experiment. Only generate the figures.',
)
parser.add_argument(
    '--final',
    action='store_true',
    help='Do the long experiment.',
)
parser.add_argument(
    '--noshow',
    action='store_true',
    help='Do not show the plots.',
)
args = parser.parse_args()

# The quick test configuration is the default; --final selects the long one.
TEST = not args.final
RUN_EXPE = not args.norun
GENERATE_FIGURES = True

# Generator seeded with 0 (not referenced by the rest of the code visible here).
RNG = np.random.default_rng(0)

# Base name for the results CSV and the figure files. Test runs get a '-test'
# suffix, so they are written under a different name than the final runs.
FILENAME = 'algo-time-compar' + ('-test' if TEST else '')

# Data splits requested from MNISTLoader with digits=[1, 7] and scale_x=True.
X_train, X_test, y_train, y_test = MNISTLoader(digits=[1, 7], scale_x=True).load()

# Settings that are identical in test and final mode.
REGULARIZATION = [1e-6]
LASSO_REGULARIZATION = [1e-6]
BATCH_SIZE = 100

# Only the number of repetitions and the resolution of the budget grid differ:
# the budgets are log-spaced between 100 and 2000 and truncated to integers.
N_RUNS = 1 if TEST else 10
N_ITERS = [
    int(budget)
    for budget in np.logspace(math.log10(100), math.log10(2000), num=4 if TEST else 20)
]

# LEARNERS and LEARNER_PARAMS are paired positionally (they are zipped together
# in expe), so the two lists must have the same length.
LEARNERS = [
    SFGDLearner,
    LeastSquaresLearner,
    LassoLearner,
    RKSClassifier,
]
LEARNER_PARAMS = [
    {'batch_size': [BATCH_SIZE], 'loss': [MSE()], 'B': [1000]},
    {},
    {},
    {},
]

def do_expe_for_one_learner(learner_class: type, **learner_params):
    """Run the grid-search experiment for one learner class and collect its metrics.

    For every seed in ``range(N_RUNS)`` and every budget in ``N_ITERS``, an
    ``RKHSWeightingGridSearchCV`` is fitted on ``(X_train, y_train)`` and one
    result row is recorded for its ``best_estimator_``.

    Args:
        learner_class: The learner *class* (not an instance) to evaluate.
            Subclasses of ``_RKSEstimator`` are passed to the grid search
            directly and receive the budget as ``n_neurons``; every other class
            is wrapped in ``RKHSWeightingClassifier`` and receives the budget
            as ``n_iter``.
        **learner_params: Entries of the learner parameter grid. A
            ``'regularization'`` entry is always set (replacing any passed in):
            ``LASSO_REGULARIZATION`` for ``LassoLearner``, ``REGULARIZATION``
            otherwise.

    Returns:
        A ``pd.DataFrame`` with one row per (seed, budget) pair and the columns
        ``algorithm``, ``T``, ``Nonzero coefficients``, ``Training time``,
        ``Training MSE``, ``Training error``, ``Test MSE`` and ``Test error``.
        Rows are numbered 0..n-1; the frame is empty when no run is performed.
    """
    if learner_class is LassoLearner:
        regularization = LASSO_REGULARIZATION
    else:
        regularization = REGULARIZATION
    learner_params = {**learner_params, 'regularization': regularization}
    model_params = {}

    # Loop-invariant facts about the learner, computed once.
    is_rks = issubclass(learner_class, _RKSEstimator)
    algorithm_name = learner_class.__name__

    # Rows are gathered in a list and turned into a DataFrame once at the end;
    # calling pd.concat inside the loop re-copies all previous rows every time.
    rows = []
    for seed in range(N_RUNS):
        start = time()
        for n_iter in N_ITERS:
            if is_rks:
                cv = RKHSWeightingGridSearchCV(learner_class,
                                               model_class=RWSign,
                                               rng=seed,
                                               learner_param_grid={**learner_params, 'n_neurons': [n_iter]},
                                               model_param_grid=model_params,
                                               verbose=True)
            else:
                cv = RKHSWeightingGridSearchCV(RKHSWeightingClassifier,
                                               learner_class,
                                               RWSign,
                                               rng=seed,
                                               learner_param_grid={**learner_params, 'n_iter': [n_iter]},
                                               model_param_grid=model_params,
                                               verbose=True)
            cv.fit(X_train, y_train)
            estimator = cv.best_estimator_

            if is_rks:
                n_nonzero = len(estimator.features)
            else:
                n_nonzero = estimator.model.get_n_centers()

            rows.append({
                'algorithm': algorithm_name,
                'T': n_iter,
                'Nonzero coefficients': n_nonzero,
                'Training time': cv.refit_time_,
                'Training MSE': mean_squared_error(y_train, estimator.raw_output(X_train)),
                'Training error': 1 - accuracy_score(y_train, estimator.predict(X_train)),
                'Test MSE': mean_squared_error(y_test, estimator.raw_output(X_test)),
                'Test error': 1 - accuracy_score(y_test, estimator.predict(X_test)),
            })
        elapsed = time() - start
        print(f"Finished testing seed {seed + 1} of {N_RUNS} for {algorithm_name} in {elapsed} seconds.")

    return pd.DataFrame(rows)

def get_results_df_path():
    """Return the path of the results CSV: RESULTS_FOLDER + FILENAME + '.csv'.

    ``ensure_folder_exists`` is called on RESULTS_FOLDER before the path is built.
    """
    ensure_folder_exists(RESULTS_FOLDER)
    csv_name = FILENAME + '.csv'
    return RESULTS_FOLDER + csv_name

def get_figure_path():
    """Return the common prefix of the figure files: FIGURES_FOLDER + FILENAME.

    ``ensure_folder_exists`` is called on FIGURES_FOLDER before the prefix is
    built. Callers append their own suffix and extension.
    """
    ensure_folder_exists(FIGURES_FOLDER)
    figure_prefix = FIGURES_FOLDER + FILENAME
    return figure_prefix

def get_results_df():
    """Load the saved results and return them with display-friendly algorithm names."""
    return clean_df(get_raw_df())

def get_raw_df():
    """Read the results CSV as-is, using its first column as the index."""
    csv_path = get_results_df_path()
    return pd.read_csv(csv_path, index_col=0)

def clean_df(df: pd.DataFrame):
    """Replace learner class names with their display names, in place.

    Every cell of ``df`` equal to one of the class names below is replaced;
    all other values are left untouched.

    Args:
        df: The results frame. It is modified in place.

    Returns:
        The same ``df`` object, for convenience.
    """
    display_names = (
        ('SFGDLearner', 'SFGD'),
        ('LeastSquaresLearner', 'Least squares fit'),
        ('LassoLearner', 'Lasso fit'),
        ('RKSClassifier', 'RKS'),
        ('OptimalStepsizeLearner', 'Optimal stepsize descent'),
    )
    for class_name, label in display_names:
        df.replace(class_name, label, inplace=True)
    return df

def expe():
    """Run the experiment for every configured learner and save the results as CSV.

    LEARNERS and LEARNER_PARAMS are paired positionally. The per-learner result
    frames are stacked with a fresh 0..n-1 index and written to the path given
    by ``get_results_df_path``.
    """
    # Start from an empty frame so there is always something to concatenate.
    frames = [pd.DataFrame({})]
    frames.extend(
        do_expe_for_one_learner(learner, **param_grid)
        for learner, param_grid in zip(LEARNERS, LEARNER_PARAMS)
    )
    all_results = pd.concat(frames, ignore_index=True)
    all_results.to_csv(get_results_df_path())

def _plot_metric_by_algorithm(mean, std, algos, colors, x_column, y_column,
                              xlabel, ylabel, file_suffix, with_x_error=False):
    """Draw one comparison figure (one curve per algorithm), save it, optionally show it.

    Args:
        mean: Per-(algorithm, T) means, with an 'algorithm' column.
        std: Per-(algorithm, T) standard deviations, laid out like ``mean``.
        algos: Algorithm names to plot, in legend order.
        colors: One color per entry of ``algos``.
        x_column: Column of ``mean`` used for the x values.
        y_column: Column of ``mean`` used for the y values.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        file_suffix: Text appended to ``get_figure_path()`` to form the file name.
        with_x_error: If true, horizontal error bars are drawn from
            ``std[x_column]`` in addition to the vertical ones.
    """
    plt.figure(figsize=(4, 4))
    for i, algo in enumerate(algos):
        mean_df = mean.loc[mean['algorithm'] == algo]
        std_df = std.loc[std['algorithm'] == algo]
        x = mean_df[x_column]
        y = mean_df[y_column]
        # xerr=None is matplotlib's default, i.e. no horizontal error bars.
        x_error = std_df[x_column] if with_x_error else None
        y_error = std_df[y_column]
        plt.plot(x, y, label=algo, color=colors[i], marker=MARKERS[i])
        plt.errorbar(x, y, xerr=x_error, yerr=y_error, color=colors[i], marker=MARKERS[i], alpha=0.5)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.savefig(get_figure_path() + file_suffix)
    if not args.noshow:
        plt.show()
    plt.close('all')


def make_figures():
    """Generate the comparison figures from the saved results.

    The results are averaged over runs for each (algorithm, T) pair and the
    standard deviation is used for the error bars. Five PDF figures are written
    under the prefix returned by ``get_figure_path``: training error, test
    error, training time and number of nonzero coefficients against T, plus
    test error against training time. Each figure is shown unless ``--noshow``
    was given.
    """
    results_df = get_results_df()
    algos = results_df['algorithm'].unique()
    colors = get_color_list(len(algos))
    grouped = results_df.groupby(by=['algorithm', 'T'], as_index=False)
    mean = grouped.mean()
    std = grouped.std()

    # (x column, y column, x label, y label, file suffix, draw x error bars)
    figure_specs = (
        ('T', 'Training error', 'T', 'Training error', '-train_01.pdf', False),
        ('T', 'Test error', 'T', 'Test error', '-test_01.pdf', False),
        ('T', 'Training time', 'T', 'Training time (s)', '-train_time.pdf', False),
        ('Training time', 'Test error', 'Training time (s)', 'Test error',
         '-test_error_vs_train_time.pdf', True),
        ('T', 'Nonzero coefficients', 'T', 'Nonzero coefficients',
         '-nonzero_coefficients.pdf', False),
    )
    for x_column, y_column, xlabel, ylabel, file_suffix, with_x_error in figure_specs:
        _plot_metric_by_algorithm(mean, std, algos, colors, x_column, y_column,
                                  xlabel, ylabel, file_suffix, with_x_error=with_x_error)

# Script entry point: optionally run the experiment, then build the figures.
if __name__ == '__main__':
    # Skipped when --norun is given; figures are then built from the CSV of a previous run.
    if RUN_EXPE:
        start = time()
        expe()
        # The start timestamp is handed to save_run_time (presumably to record
        # the total duration of the run; see expe_utils).
        save_run_time(RESULTS_FOLDER, FILENAME, start)
    # GENERATE_FIGURES is set to True in the configuration section of this file.
    if GENERATE_FIGURES:
        make_figures()

