import sys
from pathlib import Path
sys.path.append(str(Path.cwd()))

import argparse
import numpy as np
import pandas as pd
import sklearn.metrics as m

from data import get_dataset
import hpt

from nrgboost.preprocessing import fit_discretize_dataset, transform_dataset, infer_discretization_types
from nrgboost.tree.ensemble import fit_generative_bagging_estimator, LogitBaggingTreeEstimator
from nrgboost.tree.generative import constraints as c

# Datasets this tuning script supports; also the allowed values of the
# positional ``dataset_name`` argument.
datasets = ['abalone', 'california', 'protein', 'adult', 'miniboone', 'mnist', 'covertype']

parser = argparse.ArgumentParser(
                    prog='DEF',
                    description='Run DEF hyperparameter tuning')
parser.add_argument('dataset_name', type=str, choices=datasets,
                    help='dataset to tune on')
parser.add_argument('-c', '--criterion', type=str, choices=['likelihood', 'brier'], default='likelihood',
                    help='criterion passed to the generative tree fitter')
parser.add_argument('-T', '--num_trees', type=int, default=1000,
                    help='number of trees per bagging ensemble')
parser.add_argument('--start_fold', type=int, default=0,
                    help='first fold to run (inclusive)')
parser.add_argument('--stop_fold', type=int, default=5,
                    help='fold at which to stop (exclusive)')
args = parser.parse_args()
# An empty fold range would otherwise run no folds and only fail later, when
# the per-fold results are concatenated ("No objects to concatenate").
if args.start_fold >= args.stop_fold:
    parser.error('--start_fold must be smaller than --stop_fold')

# Module-level settings derived from the command line.
dataset_name = args.dataset_name
criterion = args.criterion
num_trees = args.num_trees
# Fixed bagging seed; each fit is seeded with [fold, bagging_seed].
bagging_seed = 1984

dataset = get_dataset(dataset_name)
num_classes = dataset.num_classes

# A non-positive class count is treated as regression; more than two classes
# is multiclass classification.
is_classification = num_classes > 0
is_multiclass = num_classes > 2

# Hyperparameter grids.
# max_leaves: 2**14, 2**12, 2**10, 2**8 (starting at 2**12 for 'adult').
top_leaf_exponent = 12 if dataset_name == 'adult' else 14
max_leaves_values = 2**np.arange(top_leaf_exponent, 7, -2)
# Exponents p: each model uses a feature fraction of num_columns**-p,
# so p = 0 keeps every column.
feature_fraction_values = [0.5, 0.25, 0]
# Minimum data per leaf; 0 means no min_data_in_leaf constraint is added.
min_data_in_leaf_values = [0, 1, 3, 10, 30]

# For each requested fold:
#   1. discretize the data;
#   2. run a nested early-stopping search over
#      (max_leaves, feature_frac_power, min_data_in_leaf);
#   3. write the val/test score of every evaluated configuration to
#      def/results/{criterion}/{dataset_name}/fold_{fold}/results.csv.
# Fitted models are cached under def/models/... and reused on re-runs.
for fold in range(args.start_fold, args.stop_fold):
    base_models_folder = Path(f'def/models/{criterion}/{dataset_name}/fold_{fold}')
    results_folder = Path(f'def/results/{criterion}/{dataset_name}/fold_{fold}')
    results_folder.mkdir(parents=True, exist_ok=True)

    train_df, val_df, test_df = dataset.load_fold(fold, fixed_point=True)
    
    # Fit the discretization on the training split only, then apply the same
    # transforms to the validation and test splits. `uniform` is passed to the
    # fitter as its reference (presumably a uniform distribution over the
    # discretized domain — TODO confirm in nrgboost.preprocessing).
    discretization_types = infer_discretization_types(train_df)
    train_data, uniform, transforms = fit_discretize_dataset(train_df, return_uniform=True, discretization_types=discretization_types)
    val_data = transform_dataset(val_df, transforms)
    test_data = transform_dataset(test_df, transforms)
    
    
    # Choose the evaluation metric. `discrete` selects whether labels come from
    # the discretized data (classification) or the raw target column (regression).
    if not is_classification:
        # Regression: the target is column 0. Predictions are per-bin scores
        # over its discretization bins.
        discrete = False
        target_col = train_df.columns[0]
        # Bin centers are computed from the bin edges in transforms['num'].
        # For discrete targets the upper edge is reduced by 1 before averaging
        # (presumably integer bins [lo, hi) whose last member is hi - 1 —
        # TODO confirm).
        if target_col in discretization_types['discrete_numerical'] or target_col in discretization_types['discrete_quantized']:
            bin_centers = (transforms['num'][target_col][:-1] + transforms['num'][target_col][1:] - 1)/2
        else:
            bin_centers = (transforms['num'][target_col][:-1] + transforms['num'][target_col][1:])/2
        bin_widths = transforms['num'][target_col][1:] - transforms['num'][target_col][:-1]
        log_vol = np.log(bin_widths)
        # Reassigned so that the NaN fallback below fills a uniform
        # log-probability over the target bins.
        num_classes = bin_centers.size
        
        def metric(y, p):
            """Return R^2 = 1 - MSE / Var(y) of the expected bin center.

            The per-bin scores p (presumably log-densities — TODO confirm) are
            shifted by the log bin widths and normalized over the last axis.
            The point prediction is the probability-weighted mean of the
            bin centers.
            """
            p = p + log_vol
            p -= np.logaddexp.reduce(p, axis=-1, keepdims=True)
            yh = np.sum(bin_centers*np.exp(p), -1)
            return 1 - np.mean((y - yh)**2)/np.var(y)
                
    elif not is_multiclass:
        discrete = True
        def metric(y, p):
            """Return the ROC AUC of the class-1 score p[..., 1].

            The score is clipped at -1000 so that -inf values (presumably
            log-probabilities — TODO confirm) stay finite.
            """
            return m.roc_auc_score(y, np.maximum(p[...,1], -1000))
    else:
        discrete = True
        def metric(y, p):
            """Return the accuracy of the argmax over the last axis of p."""
            yh = np.argmax(p, -1)
            return np.mean(yh == y)
    
    # One early-stopping line search per hyperparameter (hpt module); a higher
    # validation metric is better. Only the max_leaves search sets patience
    # explicitly, to 1.
    max_leaves_grid = hpt.EarlyStoppingLineSearch(max_leaves_values, patience=1, direction='max')
    feature_frac_power_grid = hpt.EarlyStoppingLineSearch(feature_fraction_values, direction='max')
    min_data_in_leaf_grid = hpt.EarlyStoppingLineSearch(min_data_in_leaf_values, direction='max')
    
    # (max_leaves, feature_frac_power, min_data_in_leaf) -> Series(val, test)
    results = {}
    
    for max_leaves in max_leaves_grid:
        for feature_frac_power in feature_frac_power_grid:
            # Per-tree feature fraction: num_columns ** -power, rounded to 3 decimals.
            feature_frac = round(train_df.shape[-1]**-feature_frac_power, 3)
            for min_data_in_leaf in min_data_in_leaf_grid:
                # Cache path:
                # .../fold_{fold}/leaves_{L}/ffrac_{power}/min_data_{n}.pkl
                model_path = base_models_folder / f'leaves_{max_leaves}' / f'ffrac_{feature_frac_power}' 
                model_path.mkdir(parents=True, exist_ok=True)
                model_path /= f'min_data_{min_data_in_leaf}.pkl'
                
                print(f'Max Leaves: {max_leaves}, FFrac: {feature_frac}, Min Data: {min_data_in_leaf}')
                # Reuse a cached model when present. Only training is skipped;
                # evaluation below always runs.
                if model_path.exists():
                    print('Model already exists. Skipping')
                    trees = LogitBaggingTreeEstimator.load(model_path)
                else:
                    # mnist always gets a min_ref_in_leaf(1e-300) constraint.
                    # A min_data_in_leaf constraint is added only when positive.
                    constraints = [c.min_ref_in_leaf(1e-300)] if dataset_name=='mnist' else []
                    if min_data_in_leaf > 0:
                        constraints.append(c.min_data_in_leaf(min_data_in_leaf))

                    trees = fit_generative_bagging_estimator(
                        train_data,
                        uniform,
                        num_trees=num_trees,
                        criterion=criterion,
                        max_leaves=max_leaves,
                        constraints= constraints,
                        bagging_frac=1,
                        feature_frac=feature_frac,
                        categorical_split_one_vs_all=False,
                        seed=[fold, bagging_seed],
                        num_jobs=8,
                        )
                
                    print('Saving len:', len(trees.trees))
                    trees.save(model_path)
        
                # Predict the target dimension (slice_dims=[0]). NaN predictions
                # are replaced with the uniform value -log(num_classes) before
                # scoring.
                preds = trees.predict(val_data, slice_dims=[0], chunksize=100)
                preds[np.isnan(preds)] = -np.log(num_classes)
                val_metric = metric(val_data[:,0] if discrete else val_df.iloc[:,0], preds)
                
                preds = trees.predict(test_data, slice_dims=[0], chunksize=100)
                preds[np.isnan(preds)] = -np.log(num_classes)
                test_metric = metric(test_data[:,0] if discrete else test_df.iloc[:,0], preds)
    
                results[(max_leaves, feature_frac_power, min_data_in_leaf)] = pd.Series({
                        'val': val_metric,
                        'test': test_metric,
                    })
                
                print(f'Val: {val_metric}')
                
                # Report the score to the innermost search. Each outer search
                # is fed the best score its inner search found.
                min_data_in_leaf_grid.fval(val_metric)
            feature_frac_power_grid.fval(min_data_in_leaf_grid.best)
        max_leaves_grid.fval(feature_frac_power_grid.best)


    # One row per evaluated configuration, indexed by
    # (max_leaves, feature_frac_power, min_data_in_leaf), with 'val'/'test' columns.
    results_df = pd.concat(results.values(), axis=1, keys=results.keys(), names = ['max_leaves', 'feature_frac_power', 'min_data_in_leaf']).T
    results_df.to_csv(results_folder / 'results.csv')


# Gather every fold's per-configuration results into one frame indexed by
# (fold, max_leaves, feature_frac_power, min_data_in_leaf).
folds = range(args.start_fold, args.stop_fold)
per_fold_frames = []
for fold_idx in folds:
    fold_csv = f'def/results/{criterion}/{dataset_name}/fold_{fold_idx}/results.csv'
    per_fold_frames.append(pd.read_csv(fold_csv, index_col=[0, 1, 2]))
compiled_results = pd.concat(per_fold_frames, keys=folds, names=['fold'])

def metrics(df):
    """Select the row of ``df`` with the highest validation score.

    Intended for ``groupby('fold').apply``: ``df`` holds one fold's results,
    with ``'val'`` and ``'test'`` columns.

    Returns a new Series containing that row's values plus a ``'best'``
    entry holding the winning index label. Ties go to the first maximal row
    (``idxmax`` semantics), so the row order of ``df`` decides ties.
    """
    best = df['val'].idxmax()
    # Copy the row. ``df.loc[best]`` may be a view of (or be flagged as a
    # slice of) ``df``, and adding 'best' should neither touch nor trigger a
    # SettingWithCopyWarning about the caller's frame.
    row = df.loc[best].copy()
    row['best'] = best
    return row

# Report, for each fold, the validation-selected configuration and its scores.
# Sorting by max_leaves first means the first-maximum tie-break in `metrics`
# favours the smallest max_leaves.
ordered_results = compiled_results.sort_index(level='max_leaves')
print(ordered_results.groupby('fold').apply(metrics))