import sys
from pathlib import Path
import argparse
sys.path.append(str(Path.cwd()))

import numpy as np
import pandas as pd

from data import get_dataset
from nrgboost.tree.ensemble import LogitBaggingTreeEstimator

# Dataset names accepted for the positional `dataset_name` argument.
datasets = ['abalone', 'california', 'protein', 'adult', 'miniboone', 'mnist', 'covertype']

# The summary must be passed as `description`: the first positional parameter
# of ArgumentParser is `prog`, so passing the text positionally would replace
# the program name in the usage line and leave the description empty.
parser = argparse.ArgumentParser(description='Sample DEF model')
parser.add_argument(
    'dataset_name', type=str, choices=datasets,
    help='dataset whose best model is sampled',
)
parser.add_argument(
    '-c', '--criterion', type=str, choices=['likelihood', 'brier'], default='likelihood',
    help='criterion sub-directory of def/ to read results and models from '
         '(default: %(default)s)',
)
parser.add_argument(
    '-s', '--seed', type=int, default=99999,
    help='base seed, combined with the fold and sampling fold when sampling '
         '(default: %(default)s)',
)
parser.add_argument(
    '-f', '--fold', type=int, default=0,
    help='fold index used in the fold_<n> result/model/sample paths '
         '(default: %(default)s)',
)
parser.add_argument(
    '--sfold_start', type=int, default=0,
    help='first sampling fold to generate, inclusive (default: %(default)s)',
)
parser.add_argument(
    '--sfold_stop', type=int, default=5,
    help='last sampling fold to generate, exclusive (default: %(default)s)',
)
args = parser.parse_args()

# Unpack the parsed command-line options into module-level settings.
fold = args.fold
dataset_name = args.dataset_name
criterion = args.criterion
sampling_seed = args.seed
# Half-open range of sampling folds: sfold_start inclusive, sfold_stop exclusive.
sampling_folds = range(args.sfold_start, args.sfold_stop)
print(fold, dataset_name, criterion)

# Names of the leading index columns of results.csv; the same names are used
# to build the path of the selected model.
param_cols = ['leaves', 'ffrac', 'min_data']

# Results, models and samples share the same <criterion>/<dataset>/fold_<n>
# suffix under their respective def/ sub-directories.
_fold_suffix = Path(criterion, dataset_name, f'fold_{fold}')
results_path = Path('def', 'results') / _fold_suffix / 'results.csv'
model_path = Path('def', 'models') / _fold_suffix
base_samples_path = Path('def', 'samples') / _fold_suffix
# Only the output directory is created here; results and models must already exist.
base_samples_path.mkdir(parents=True, exist_ok=True)

# Load the data splits only to learn how many rows each split has; the
# cumulative sizes give the offsets used to cut the generated samples.
# NOTE(review): fold 0 is loaded regardless of --fold. Only the split lengths
# are used, so this is harmless if all folds have equal split sizes -- confirm.
dataset = get_dataset(dataset_name)
train_df, val_df, test_df = dataset.load_fold(0, fixed_point=True)
split_sizes = [len(train_df), len(val_df), len(test_df)]
start_val, start_test, num_total = np.cumsum(split_sizes)

# results.csv is indexed by the hyper-parameter columns, so idxmax() on the
# 'val' column yields the index entry (one value per param_cols name) of the
# row with the largest 'val'.
results_df = pd.read_csv(results_path, index_col=[*range(len(param_cols))])
best = results_df['val'].idxmax()

# The model file lives at <model_path>/<name>_<value>/.../<name>_<value>.pkl.
# The extension is appended as text (not with_suffix) because a value such as
# 0.5 contains a dot that with_suffix would treat as an existing suffix.
best_model_path = model_path.joinpath(
    *(f'{name}_{value}' for name, value in zip(param_cols, best))
)
best_model_path = f'{best_model_path}.pkl'

model = LogitBaggingTreeEstimator.load(best_model_path)

# Output file name -> slice of the generated samples stored in that file.
# The boundaries reuse the cumulative split sizes computed from the dataset.
split_slices = (
    ('train.npy', slice(None, start_val)),
    ('val.npy', slice(start_val, start_test)),
    ('test.npy', slice(start_test, None)),
)

for sampling_fold in sampling_folds:
    # Each sampling fold gets its own directory under the fold's samples path.
    samples_path = base_samples_path / f'sampling_fold_{sampling_fold}'
    samples_path.mkdir(parents=True, exist_ok=True)

    # The seed combines the sampling fold, the fold and the base seed so each
    # (sampling fold, fold) pair is given a distinct seed sequence.
    seed = [sampling_fold, fold, sampling_seed]
    samples = model.sample_points(num_total, seed=seed)

    for file_name, rows in split_slices:
        np.save(samples_path / file_name, samples[rows])
