from dataset_loaders import *
from expe_utils import RESULTS_FOLDER, TABLES_FOLDER, ensure_folder_exists, save_run_time
from expe_params import RW_LEARNER_CV_PARAMS, RKS_CV_PARAMS, RWSIGN_CV_PARAMS
from expe_params import RWRELU_CV_PARAMS, RWEXPSIGN_CV_PARAMS, RWEXPRELU_CV_PARAMS, RWSTUMPS_CV_PARAMS
from learners import *
from loss import MSE
from models import RWSign, RWExpSign, RWRelu, RWExpRelu, RWStumps
from rkhs_weightings import RKHSWeightingClassifier, RKHSWeightingRegressor 
from rkhs_weightings import RKHSWeightingRandomSearchCV 
from rks import RKSClassifier, RKSRegressor
from bounds import rkhs_rademacher_bound, l2p_rademacher_bound

import argparse
from copy import deepcopy
import numpy as np
import pandas as pd
import pickle as pkl
from scipy.stats import uniform, loguniform
from sklearn.metrics import accuracy_score, mean_squared_error
from time import time
from tqdm import trange

import visuals

# Command-line interface of the experiment.
parser = argparse.ArgumentParser(
    description='Experiment for comparing RKHS weightings and random kitchen sinks.'
)
parser.add_argument(
    '--norun', action='store_true',
    help='Do not run the experiment. Only generate the table.',
)
parser.add_argument(
    '--final', action='store_true',
    help='Do the long experiment.',
)
parser.add_argument(
    '--print-dataset-info', action='store_true',
    help='Print the dataset information.',
)

args = parser.parse_args()
FINAL = args.final
RNG = np.random.default_rng(0)
# (N_ITER, N_RUNS, N_CV_ITER) for the final run vs. the test run.
# NOTE(review): both settings are currently identical, so --final only changes
# the output file names (via get_filename) and the `final` flag given to the
# dataset loaders — confirm whether the test run was meant to be shorter.
N_ITER, N_RUNS, N_CV_ITER = (500, 10, 50) if FINAL else (500, 10, 50)

# Classification benchmarks; MNIST is restricted to the digits 1 and 7.
CLASSIFICATION_LOADERS = [
    BreastCancerLoader(final=FINAL),
    AdultsLoader(final=FINAL),
    MNISTLoader(final=FINAL, digits=[1, 7]),
    SkinSegmentationLoader(final=FINAL),
    BankMarketingLoader(final=FINAL),
    MagicGammaTelescopeLoader(final=FINAL),
    PhishingLoader(final=FINAL),
]

# Regression benchmarks; every loader is asked to scale both inputs and targets.
REGRESSION_KWARGS = {'final': FINAL, 'scale_x': True, 'scale_y': True}
REGRESSION_LOADERS = [
    DiabetesLoader(**REGRESSION_KWARGS),
    CaliforniaHousingLoader(**REGRESSION_KWARGS),
    ConcreteLoader(**REGRESSION_KWARGS),
    WineLoader(**REGRESSION_KWARGS),
    ConductivityLoader(**REGRESSION_KWARGS),
    AbaloneLoader(**REGRESSION_KWARGS),
]

# Model instantiations, paired index-by-index with their CV search grids.
INSTANTIATIONS = [RWSign, RWExpSign, RWRelu, RWExpRelu, RWStumps]
RW_CV_PARAMS = [
    RWSIGN_CV_PARAMS,
    RWEXPSIGN_CV_PARAMS,
    RWRELU_CV_PARAMS,
    RWEXPRELU_CV_PARAMS,
    RWSTUMPS_CV_PARAMS,
]
# Pin the number of Monte-Carlo samples / random features / learner iterations
# to N_ITER. These grids are imported objects and are modified in place.
for params in RW_CV_PARAMS:
    params['n_mc'] = [N_ITER]
RKS_CV_PARAMS['n_neurons'] = [N_ITER]
RW_LEARNER_CV_PARAMS['n_iter'] = [N_ITER]

def get_filename(task='classification'):
    """Return the base output name (no folder, no extension) for *task*.

    When the script is not run with ``--final``, a ``-test`` suffix is
    appended, so test outputs go to different files than final ones.
    """
    suffix = '' if FINAL else '-test'
    return 'rkhs-vs-rks-' + task + suffix

def get_csv_filename(task='classification'):
    """Return the path of the results CSV for *task* inside ``RESULTS_FOLDER``."""
    return ''.join((RESULTS_FOLDER, get_filename(task), '.csv'))

def get_table_filename(task='classification'):
    """Return the extension-less path prefix of the LaTeX tables for *task* inside ``TABLES_FOLDER``."""
    return ''.join((TABLES_FOLDER, get_filename(task)))

def get_pickle_path(model_name, task='classification'):
    """Return the ``.pkl`` path used to store *model_name* for *task*.

    The ``rkhs-vs-rks-models/`` sub-folder of ``RESULTS_FOLDER`` is created
    through ``ensure_folder_exists`` before the path is returned.
    """
    folder = ''.join((RESULTS_FOLDER, 'rkhs-vs-rks-models/'))
    ensure_folder_exists(folder)
    return f'{folder}{get_filename(task)}-{model_name}.pkl'

def get_model_name(instantiation, dataset_name, rks=False):
    """Build ``'<ClassName>-<dataset_name>'`` (plus ``'-rks'`` when *rks* is true)."""
    parts = [instantiation.__name__, dataset_name]
    if rks:
        parts.append('rks')
    return '-'.join(parts)

def save_model(model, model_name, task='classification'):
    """Pickle *model* to the path given by ``get_pickle_path(model_name, task)``.

    Args:
        model: Any picklable object.
        model_name: Identifier used in the file name (see ``get_model_name``).
        task: ``'classification'`` (default, the previous hard-wired value) or
            ``'regression'``. Passing it keeps regression models from being
            written to the classification file names.
    """
    model_path = get_pickle_path(model_name, task=task)
    with open(model_path, 'wb') as f:
        pkl.dump(model, f)

def load_model(model_name, task='classification'):
    """Unpickle and return the model stored at ``get_pickle_path(model_name, task)``.

    Args:
        model_name: Identifier used in the file name (see ``get_model_name``).
        task: ``'classification'`` (default, the previous hard-wired value) or
            ``'regression'``; must match the value used when saving.

    Returns:
        The unpickled object.

    Note:
        ``pickle.load`` can execute arbitrary code. Only load files this
        script produced itself, never files from an untrusted source.
    """
    model_path = get_pickle_path(model_name, task=task)
    with open(model_path, 'rb') as f:
        model = pkl.load(f)
    return model

def run_expe(task='classification'):
    """Compare RKHS weightings (RW) with random kitchen sinks (RKS) on one task.

    For each dataset loader of *task*, each instantiation in ``INSTANTIATIONS``
    (paired with its grid in ``RW_CV_PARAMS``) and each seed ``0..N_RUNS-1``,
    a 5-fold randomized CV search with ``N_CV_ITER`` candidates is fitted for
    both RW and RKS. For each selected model the following are recorded:
    - train/test MSE of the raw output;
    - train/test error (classification only);
    - refit time;
    - the Th. 6.2 / Th. 6.5 bounds.

    The whole results table is rewritten to ``get_csv_filename(task)`` after
    every seed.

    Args:
        task: ``'classification'`` or ``'regression'``.

    Raises:
        ValueError: if *task* is not one of the two supported values.
    """
    results = {
        'Instantiation' : [],
        'Dataset' : [],
        'prediction model' : [],
        'n_mc' : [],
        'Train time (s)' : [],
        'Train MSE' : [],
        'Test MSE' : [],
        'Th. 6.2' : [],
        'Th. 6.5' : []
    }
    if task == 'classification':
        dataset_loaders = CLASSIFICATION_LOADERS
        RKHS_CLASS = RKHSWeightingClassifier
        RKS_CLASS = RKSClassifier
        results.update({
            'Train error' : [],
            'Test error' : []
        })
    elif task == 'regression':
        dataset_loaders = REGRESSION_LOADERS
        RKHS_CLASS = RKHSWeightingRegressor
        RKS_CLASS = RKSRegressor
    else:
        raise ValueError('task must be either classification or regression')
    for dataset_loader in dataset_loaders:
        X_train, X_test, y_train, y_test = dataset_loader.load()
        # Largest absolute target over train and test; fed to the Lipschitz constants below.
        y_max = np.max(np.abs(np.concatenate([y_train, y_test])))
        print(f'Running for dataset {dataset_loader.name}...')
        for (instantiation, rw_cv_params) in zip(INSTANTIATIONS, RW_CV_PARAMS):
            print(f'    Running for instantiation {instantiation.__name__}...')
            # The run index doubles as the random seed of both CV searches.
            for rng in trange(N_RUNS, delay=5):
                rw_cv = RKHSWeightingRandomSearchCV(RKHS_CLASS,
                                    LeastSquaresLearner, 
                                    instantiation, 
                                    RW_LEARNER_CV_PARAMS, 
                                    rw_cv_params,
                                    folds=5, 
                                    n_iter=N_CV_ITER,
                                    rng=rng,
                                    verbose=False)

                # RKHS Weighting
                rw_cv.fit(X_train, y_train)
                rkhs_weighting_clf = rw_cv.best_estimator_
                rkhs_weighting_clf.refit_time_ = rw_cv.refit_time_
                # Rebuilt only to access its loss for the RKS Lipschitz constant.
                learner = LeastSquaresLearner(**rw_cv.best_learner_params_)

                # Random Kitchen Sinks (same instantiation and model grid as RW)
                rks_cv = RKHSWeightingRandomSearchCV(RKS_CLASS,
                                                model_class=instantiation,
                                                learner_param_grid=RKS_CV_PARAMS,
                                                model_param_grid=rw_cv_params,
                                                folds=5, 
                                                n_iter=N_CV_ITER,
                                                rng=rng,
                                                verbose=False
                                                )
                rks_cv.fit(X_train, y_train)
                rks_clf = rks_cv.best_estimator_
                rks_clf.refit_time_ = rks_cv.refit_time_

                rho_rw = MSE().lipschitz(rkhs_weighting_clf.model.max_output(), y_max=y_max)
                # tau is estimated once from the RW model and reused for the RKS bound.
                tau = rkhs_weighting_clf.model.tau_approx(X_test)

                for clf, model_name in ((rkhs_weighting_clf, 'RW'), (rks_clf, 'RKS')):
                    results['Instantiation'].append(instantiation.__name__)
                    results['Dataset'].append(dataset_loader.name)
                    results['prediction model'].append(model_name)
                    results['n_mc'].append(N_ITER)
                    if task == 'classification':
                        results['Train error'].append(1 - accuracy_score(clf.predict(X_train), y_train))
                        results['Test error'].append(1 - accuracy_score(clf.predict(X_test), y_test))
                    results['Train time (s)'].append(clf.refit_time_)
                    results['Train MSE'].append(mean_squared_error(y_train, clf.raw_output(X_train)))
                    results['Test MSE'].append(mean_squared_error(y_test, clf.raw_output(X_test)))
                    if model_name == 'RW':
                        results['Th. 6.2'].append(rkhs_rademacher_bound(clf.model, rho_rw))
                        results['Th. 6.5'].append(l2p_rademacher_bound(clf.model, rho_rw, tau, X_train))
                    elif model_name == 'RKS':
                        # The Lipschitz constant needs a bound on |output| (as y_max bounds
                        # |y|); the plain max of the raw outputs underestimates it whenever
                        # predictions are negative.
                        max_abs_output = np.max(np.abs(clf.raw_output(X_test)))
                        rho_rks = learner.loss.lipschitz(max_abs_output, y_max=y_max)
                        # No Th. 6.2 value is computed for RKS; an empty cell keeps all
                        # result lists the same length.
                        results['Th. 6.2'].append('')
                        results['Th. 6.5'].append(l2p_rademacher_bound(clf, rho_rks, tau, X_train))

                # Rewrite the full table after every seed (keeps partial results on disk).
                df = pd.DataFrame(results)
                df.to_csv(get_csv_filename(task=task), index=False)

def clean_value(x, sig=3):
    """Round *x* to *sig* significant figures for display in a LaTeX table.

    Returns:
        - *x* unchanged if it is NaN/None, infinite, or zero;
        - ``'>10\\textsuperscript{k}'`` (k = order of magnitude of the rounded
          value) if the rounded value exceeds 100;
        - the rounded float if it is below 0.1, so the table's ``float_format``
          formats it;
        - otherwise the rounded value as a string.
    """
    # Non-finite values have no order of magnitude: log10 would return inf and
    # int() would raise OverflowError, so pass them through like NaN.
    if pd.isna(x) or x == 0 or np.isinf(x):
        return x
    value = round(x, -int(np.floor(np.log10(abs(x)))) + (sig - 1))
    if value > 100:
        exponent = int(np.floor(np.log10(value)))
        return '>10\\textsuperscript{' + str(exponent) +'}'
    elif value < 0.1:
        return value
    else:
        return str(value)

def remove_index_column(table):
    """Return *table* without its ``'index'`` column; return it unchanged when there is none."""
    return table.drop(columns='index') if 'index' in table.columns else table

def _latex_column_format(table):
    """Return a tabular spec with one 'l' per index level and one 'c' per data column of *table*."""
    return 'l' * table.index.nlevels + 'c' * table.shape[1]


def _order_columns(table, metric_order):
    """Return *table* with its (metric, model) columns sorted by *metric_order*, then by model."""
    metric_index = pd.CategoricalIndex(
        table.columns.get_level_values(0),
        categories=metric_order,
        ordered=True
    )
    table.columns = pd.MultiIndex.from_arrays(
        [metric_index, table.columns.get_level_values(1)]
    )
    table = table.sort_index(axis=1, level=[0, 1])
    table.columns.names = [None, None]
    return table


def _format_mean_std(mean, std):
    """Format one cell as ``'mean ± std'``; fall back to ``str(mean)`` if either is non-numeric or std is NaN."""
    try:
        mean_float = float(mean)
        std_float = float(std)
    except (TypeError, ValueError):
        return str(mean)
    if np.isnan(std_float):
        return str(mean)
    return f"{mean_float:.3f} ± {std_float:.3f}"


def make_table(task='classification'):
    """Build the LaTeX result tables for *task* from the CSV written by ``run_expe``.

    Rows are (Dataset, Instantiation) and columns are (metric, RW/RKS).
    Writes ``get_table_filename(task)`` with three suffixes:
        * ``.tex``: means of every metric; the Th. columns go through
          ``clean_value``, and the RKS Th. 6.5 column is dropped;
        * ``-std.tex``: standard deviations of the train/test error
          (classification) or train/test MSE (regression);
        * ``-with-std.tex``: those std-tracked metrics as ``mean ± std``.

    Raises:
        ValueError: if *task* is neither ``'classification'`` nor ``'regression'``.
    """
    models = ['RW', 'RKS']
    if task == 'classification':
        metrics = ['Train error', 'Test error', 'Train MSE', 'Test MSE', 'Th. 6.2', 'Th. 6.5', 'Train time (s)']
        std_metrics = ['Train error', 'Test error']
    elif task == 'regression':
        metrics = ['Train MSE', 'Test MSE', 'Th. 6.2', 'Th. 6.5', 'Train time (s)']
        std_metrics = ['Train MSE', 'Test MSE']
    else:
        raise ValueError('task must be either classification or regression')

    df = pd.read_csv(get_csv_filename(task=task))
    df = df[df['prediction model'].isin(models)]
    # Shorten long dataset names to keep the LaTeX tables narrow.
    for long_name, short_name in (('bank marketing', 'marketing'),
                                  ('breast cancer', 'cancer'),
                                  ('magic gamma telescope', 'telescope'),
                                  ('skin segmentation', 'skin')):
        df = df.replace(long_name, short_name, regex=True)

    pivot_kwargs = dict(index=['Dataset', 'Instantiation'], columns='prediction model')
    pivot_table = df.pivot_table(values=metrics, aggfunc='mean', **pivot_kwargs)
    std_pivot_table = df.pivot_table(values=std_metrics, aggfunc='std', **pivot_kwargs)
    pivot_table = remove_index_column(pivot_table)
    std_pivot_table = remove_index_column(std_pivot_table)

    # The Th. 6.5 value recorded for RKS is deliberately left out of the table.
    col_to_remove = ('Th. 6.5', 'RKS')
    if col_to_remove in pivot_table.columns:
        pivot_table = pivot_table.drop(columns=[col_to_remove])

    for col in [c for c in pivot_table.columns if 'Th.' in c[0]]:
        pivot_table[col] = pivot_table[col].apply(clean_value)

    pivot_table = _order_columns(pivot_table, metrics)
    std_pivot_table = _order_columns(std_pivot_table, std_metrics)

    latex_kwargs = dict(index=True, multicolumn=True, multirow=True)
    pivot_table.to_latex(
        get_table_filename(task=task) + '.tex',
        float_format="%.3f",
        column_format=_latex_column_format(pivot_table),
        **latex_kwargs
    )
    # Sized from the std table itself (it has fewer columns than the mean table).
    std_pivot_table.to_latex(
        get_table_filename(task=task) + '-std.tex',
        float_format="%.3f",
        column_format=_latex_column_format(std_pivot_table),
        **latex_kwargs
    )

    # Keep only columns that also appear in the std table
    common_cols = [col for col in pivot_table.columns if col in std_pivot_table.columns]
    pivot_table = pivot_table.loc[:, common_cols]

    # Build merged table with "mean ± std" strings
    merged = pd.DataFrame(index=pivot_table.index)
    for col in pivot_table.columns:
        # Align by row label, not position, in case the std table lacks rows the
        # mean table has (e.g. all-NaN std rows dropped by pivot_table).
        stds = std_pivot_table[col].reindex(pivot_table.index)
        merged[col] = [_format_mean_std(m, s)
                       for m, s in zip(pivot_table[col].values, stds.values)]

    # Preserve MultiIndex column names for LaTeX output
    merged.columns = pivot_table.columns
    merged.columns.names = [None, None]

    merged.to_latex(
        get_table_filename(task=task) + '-with-std.tex',
        escape=False,
        column_format=_latex_column_format(merged),
        **latex_kwargs
    )



if __name__ == '__main__':
    tasks = ('classification', 'regression')
    # Run (and time) the experiments unless --norun was given; tables are always rebuilt.
    if not args.norun:
        for task in tasks:
            start = time()
            run_expe(task=task)
            save_run_time(RESULTS_FOLDER, get_filename(task), start)
    for task in tasks:
        make_table(task=task)
    if args.print_dataset_info:
        for header, loaders in (('Classification datasets:', CLASSIFICATION_LOADERS),
                                ('Regression datasets:', REGRESSION_LOADERS)):
            print(header)
            for dataset_loader in loaders:
                dataset_loader.info()
