import sys
from pathlib import Path
sys.path.append(str(Path.cwd()))

import argparse
import numpy as np
import pandas as pd
import sklearn.metrics as m

from data import get_dataset
import hpt

from nrgboost.preprocessing import fit_discretize_dataset, transform_dataset, infer_discretization_types
from nrgboost.tree.ensemble import GenerativeBoostingTreeEstimator
from nrgboost.tree.generative import constraints as c

# Command-line interface: a positional dataset name, the number of trees to
# boost, and the half-open range of folds [start_fold, stop_fold) to process.
parser = argparse.ArgumentParser(
    prog='NRGBoost',
    description='Run NRGBoost hyperparameter tuning',
)
parser.add_argument('dataset_name', type=str)
parser.add_argument('-T', '--num_trees', default=200, type=int)
parser.add_argument('--start_fold', default=0, type=int)
parser.add_argument('--stop_fold', default=5, type=int)
args = parser.parse_args()

# Dataset handle and the task type implied by its class count: a positive
# count means classification, more than two classes means multiclass.
dataset_name = args.dataset_name
dataset = get_dataset(dataset_name)
num_classes = dataset.num_classes

is_classification = num_classes > 0
is_multiclass = num_classes > 2

# covertype is the only dataset with its own sampling budget and leaf grid.
is_covertype = dataset_name == 'covertype'

num_trees = args.num_trees
model_seed = 1984

# Settings that stay constant over the whole hyperparameter search.
fixed_params = dict(
    max_ratio_in_leaf=2,
    num_model_samples=320_000 if is_covertype else 80_000,
    p_refresh=0.1,
    num_chains=64 if is_covertype else 16,
    burn_in=100,
    line_search=True,
    categorical_split_one_vs_all=False,
    feature_frac=1,
    temperature=1,
    mixing_coef=0.1,
    num_threads=16,
)

# Searched hyperparameters: max leaves takes the values 64, 256, 1024, 4096
# (plus 16384 for covertype); shrinkage is searched inside shrinkage_range
# with num_shrinkage_evals evaluations.
leaves_exponent_stop = 15 if is_covertype else 13
max_leaves_values = 2 ** np.arange(6, leaves_exponent_stop, 2)
shrinkage_range = (0.01, 0.5)
num_shrinkage_evals = 6

# Per-fold hyperparameter search: for every requested fold, fit (or reload)
# one model per (max_leaves, shrinkage) candidate, score it on validation and
# test, and write all scores to nrgb/results/<dataset>/fold_<fold>/results.csv.
for fold in range(args.start_fold, args.stop_fold):
    train_df, val_df, test_df = dataset.load_fold(fold, fixed_point=True)

    # The discretization is fitted on the training split only; the fitted
    # transforms are then applied unchanged to the validation and test splits.
    discretization_types = infer_discretization_types(train_df)

    # `uniform` is returned because of return_uniform=True but is not used below.
    train_data, uniform, transforms = fit_discretize_dataset(train_df, return_uniform=True, discretization_types=discretization_types)
    val_data = transform_dataset(val_df, transforms)
    test_data = transform_dataset(test_df, transforms)

    # Define metric eval. Column 0 is the target in every branch. `discrete`
    # records whether scores are computed against the discretized target
    # (classification) or against the raw DataFrame values (regression).
    if not is_classification:
        discrete = False
        target_col = train_df.columns[0]
        bin_edges = transforms['num'][target_col]
        if target_col in discretization_types['discrete_numerical'] or target_col in discretization_types['discrete_quantized']:
            # NOTE(review): the -1 presumably reflects integer-valued bins
            # with an exclusive upper edge -- confirm against the
            # discretization code.
            bin_centers = (bin_edges[:-1] + bin_edges[1:] - 1)/2
        else:
            bin_centers = (bin_edges[:-1] + bin_edges[1:])/2

        def metric(y, p):
            """Return 1 - MSE/Var(y) for the bin-center-weighted prediction.

            Assumes `p` holds log-probabilities over the target bins on its
            last axis -- TODO confirm against `predict`.
            """
            yh = np.sum(bin_centers*np.exp(p), -1)
            return 1 - np.mean((y - yh)**2)/np.var(y)

    elif not is_multiclass:
        discrete = True

        def metric(y, p):
            """Return ROC AUC using the class-1 entry of `p` as the score.

            Scores are floored at -1000, presumably to keep -inf
            log-probabilities out of `roc_auc_score` -- TODO confirm.
            """
            return m.roc_auc_score(y, np.maximum(p[...,1], -1000))
    else:
        discrete = True

        def metric(y, p):
            """Return accuracy of the argmax over the last axis of `p`."""
            yh = np.argmax(p, -1)
            return np.mean(yh == y)

    def compute_metric(y, ps):
        """Score every prediction array in `ps` against `y`.

        Returns a 1-D array with one metric value per element of `ps`.
        Asserts that no prediction contains NaN.
        """
        scores = []
        for p in ps:
            assert not np.any(np.isnan(p))
            scores.append(metric(y, p))

        return np.array(scores)

    # setup folders
    base_models_folder = Path(f'nrgb/models/{dataset_name}/fold_{fold}')
    results_folder = Path(f'nrgb/results/{dataset_name}/fold_{fold}')

    # setup HPT: outer search over max leaves, inner search over shrinkage.
    # NOTE(review): the same GoldenSearch object is iterated once per
    # max_leaves value; this relies on it restarting on each iteration --
    # confirm in hpt.
    max_leaves_grid = hpt.EarlyStoppingLineSearch(max_leaves_values, direction='max')
    shrinkage_grid = hpt.GoldenSearch(*shrinkage_range, num_evals=num_shrinkage_evals, log=True, direction='max')

    # Scores per candidate, keyed by (max_leaves, shrinkage).
    results = {}

    # Split the fixed settings: the leaf-ratio limit becomes a constraint
    # object, the mixing coefficient goes to the initial model, and the
    # remainder is forwarded to fit() as keyword arguments.
    param = {**fixed_params}
    constraints = [c.max_ratio_in_leaf(param.pop('max_ratio_in_leaf'))]
    mixing_coef = param.pop('mixing_coef')

    # run HPT
    for max_leaves in max_leaves_grid:
        models_folder = base_models_folder / f'leaves_{max_leaves}'
        models_folder.mkdir(parents=True, exist_ok=True)

        for shrinkage in shrinkage_grid:
            # Rounded so the value yields a stable file name and result key.
            shrinkage = round(shrinkage, 3)
            print(f'Fitting {max_leaves} max leaves, {shrinkage} shrinkage...')

            model_path = models_folder / f'shrinkage_{shrinkage}.pkl'

            # Models saved by an earlier run are reused instead of refitted.
            if model_path.exists():
                print('Model already exists. Loading...')
                bst = GenerativeBoostingTreeEstimator.load(model_path)
            else:
                bst = GenerativeBoostingTreeEstimator.from_distribution_marginals(
                    train_data, mixing_coef)
                bst.fit(
                    train_data,
                    num_trees=num_trees,
                    max_leaves=max_leaves,
                    shrinkage=shrinkage,
                    **param,
                    constraints=constraints,
                    seed=model_seed)

                print('Saving model...')
                bst.save(model_path)

            # compute_metric iterates over the prediction result and returns
            # one score per element (presumably one per ensemble size, given
            # cumulative=True -- TODO confirm).
            preds = bst.predict(val_data, slice_dims=[0], cumulative=True)
            val_metrics = compute_metric(val_data[:,0] if discrete else val_df.iloc[:,0], preds)

            preds = bst.predict(test_data, slice_dims=[0], cumulative=True)
            test_metrics = compute_metric(test_data[:,0] if discrete else test_df.iloc[:,0], preds)

            results[(max_leaves, shrinkage)] = pd.DataFrame({
                    'val': val_metrics,
                    'test': test_metrics,
                })

            # Feed the best validation score of this candidate back to the
            # inner search so it can propose the next shrinkage value.
            best_val = np.max(val_metrics)
            shrinkage_grid.fval(best_val)
            print(f'Val Metric: {best_val}')

        # Report the inner search's `best` to the outer search.
        max_leaves_grid.fval(shrinkage_grid.best)

    # One table per fold, indexed by (max_leaves, shrinkage, row index of
    # each per-candidate frame).
    results_folder.mkdir(parents=True, exist_ok=True)
    results_df = pd.concat(results.values(), axis=0, keys=results.keys(), names=['max_leaves', 'shrinkage'])
    results_df.to_csv(results_folder / 'results.csv')


# Read the per-fold result tables back from disk and stack them under an
# additional outermost 'fold' index level.
folds = range(args.start_fold, args.stop_fold)
per_fold_frames = [
    pd.read_csv(
        f'nrgb/results/{dataset_name}/fold_{fold}/results.csv',
        index_col=[0, 1, 2],
    )
    for fold in folds
]
compiled_results = pd.concat(per_fold_frames, keys=folds, names=['fold'])

def metrics(df):
    """Return the row of `df` with the highest 'val' value.

    The returned Series holds the columns of that row plus an extra entry
    'best' containing the index label at which the maximum was found.
    """
    best_label = df['val'].idxmax()
    best_row = df.loc[best_label]
    # Keep the winning index label alongside the scores.
    best_row['best'] = best_label
    return best_row

# For every fold, print the row with the best validation score.
fold_groups = compiled_results.sort_index(level='max_leaves').groupby('fold')
print(fold_groups.apply(metrics))